11"""Tests for middleware utilities."""
22
33from typing import Any
4+ from unittest .mock import patch
45
6+ import pytest
57from fastapi import FastAPI , Response
68from starlette .datastructures import Headers
79from starlette .requests import Request
@@ -17,18 +19,73 @@ class ExampleJsonResponseMiddleware(JsonResponseMiddleware):
1719 def __init__ (self , app : ASGIApp ):
1820 """Initialize the middleware."""
1921 self .app = app
22+ # Use default expected_data_type (dict)
2023
2124 def should_transform_response (self , request : Request , scope : Scope ) -> bool :
2225 """Transform JSON responses based on content type."""
2326 return Headers (scope = scope ).get ("content-type" , "" ) == "application/json"
2427
2528 def transform_json (self , data : Any , request : Request ) -> Any :
2629 """Add a test field to the response."""
27- if isinstance (data , dict ):
28- data ["transformed" ] = True
30+ data ["transformed" ] = True
2931 return data
3032
3133
34+ class ExampleStringJsonResponseMiddleware (JsonResponseMiddleware ):
35+ """Example implementation that expects string JSON responses."""
36+
37+ def __init__ (self , app : ASGIApp ):
38+ """Initialize the middleware."""
39+ self .app = app
40+ self .expected_data_type = str
41+
42+ def should_transform_response (self , request : Request , scope : Scope ) -> bool :
43+ """Transform JSON responses based on content type."""
44+ return Headers (scope = scope ).get ("content-type" , "" ) == "application/json"
45+
46+ def transform_json (self , data : Any , request : Request ) -> Any :
47+ """Transform string responses by adding a prefix."""
48+ if isinstance (data , str ):
49+ return f"transformed: { data } "
50+ return data
51+
52+
53+ class ExampleListJsonResponseMiddleware (JsonResponseMiddleware ):
54+ """Example implementation that expects list JSON responses."""
55+
56+ def __init__ (self , app : ASGIApp ):
57+ """Initialize the middleware."""
58+ self .app = app
59+ self .expected_data_type = list
60+
61+ def should_transform_response (self , request : Request , scope : Scope ) -> bool :
62+ """Transform JSON responses based on content type."""
63+ return Headers (scope = scope ).get ("content-type" , "" ) == "application/json"
64+
65+ def transform_json (self , data : Any , request : Request ) -> Any :
66+ """Transform list responses by adding a new item."""
67+ if isinstance (data , list ):
68+ return data + ["transformed" ]
69+ return data
70+
71+
72+ class ExampleAnyJsonResponseMiddleware (JsonResponseMiddleware ):
73+ """Example implementation that transforms any JSON response type."""
74+
75+ def __init__ (self , app : ASGIApp ):
76+ """Initialize the middleware."""
77+ self .app = app
78+ self .expected_data_type = None # Transform any JSON type
79+
80+ def should_transform_response (self , request : Request , scope : Scope ) -> bool :
81+ """Transform JSON responses based on content type."""
82+ return Headers (scope = scope ).get ("content-type" , "" ) == "application/json"
83+
84+ def transform_json (self , data : Any , request : Request ) -> Any :
85+ """Transform any JSON response by wrapping it."""
86+ return {"transformed" : True , "data" : data }
87+
88+
3289def test_json_response_middleware ():
3390 """Test that JSON responses are properly transformed."""
3491 app = FastAPI ()
@@ -119,3 +176,131 @@ async def test_endpoint():
119176 assert response .headers ["content-type" ] == "application/json"
120177 data = response .json ()
121178 assert data == {"error" : "Received invalid JSON from upstream server" }
179+
180+
181+ @pytest .mark .parametrize (
182+ "content,expected_data" ,
183+ [
184+ ('"hello world"' , "hello world" ),
185+ ('[1, 2, 3, "test"]' , [1 , 2 , 3 , "test" ]),
186+ ("42" , 42 ),
187+ ("true" , True ),
188+ ("null" , None ),
189+ ],
190+ )
191+ def test_json_response_middleware_non_dict_json (content , expected_data ):
192+ """Test that non-dict JSON responses are not transformed by default middleware."""
193+ app = FastAPI ()
194+ app .add_middleware (ExampleJsonResponseMiddleware )
195+
196+ @app .get ("/test" )
197+ async def test_endpoint ():
198+ return Response (content = content , media_type = "application/json" )
199+
200+ client = TestClient (app )
201+ response = client .get ("/test" )
202+ assert response .status_code == 200
203+ assert response .headers ["content-type" ] == "application/json"
204+ data = response .json ()
205+ assert data == expected_data # Should remain unchanged
206+
207+
208+ @pytest .mark .parametrize (
209+ "middleware_class,expected_data_type,test_data,expected_result,should_transform" ,
210+ [
211+ # String middleware tests
212+ (
213+ ExampleStringJsonResponseMiddleware ,
214+ "this is a string" ,
215+ "transformed: this is a string" ,
216+ True ,
217+ ),
218+ (
219+ ExampleStringJsonResponseMiddleware ,
220+ {"message" : "not a string" },
221+ {"message" : "not a string" },
222+ False ,
223+ ),
224+ # List middleware tests
225+ (
226+ ExampleListJsonResponseMiddleware ,
227+ [1 , 2 , 3 ],
228+ [1 , 2 , 3 , "transformed" ],
229+ True ,
230+ ),
231+ (
232+ ExampleListJsonResponseMiddleware ,
233+ "not a list" ,
234+ "not a list" ,
235+ False ,
236+ ),
237+ # Dict middleware tests (default)
238+ (
239+ ExampleJsonResponseMiddleware ,
240+ {"message" : "test" },
241+ {"message" : "test" , "transformed" : True },
242+ True ,
243+ ),
244+ (
245+ ExampleJsonResponseMiddleware ,
246+ "not a dict" ,
247+ "not a dict" ,
248+ False ,
249+ ),
250+ ],
251+ )
252+ def test_json_response_middleware_type_specific (
253+ middleware_class , test_data , expected_result , should_transform
254+ ):
255+ """Test that middleware transforms only expected data types."""
256+ with patch .object (
257+ middleware_class , "transform_json" , return_value = expected_result
258+ ) as mock_method :
259+ app = FastAPI ()
260+ app .add_middleware (middleware_class )
261+
262+ @app .get ("/test" )
263+ async def test_endpoint ():
264+ return test_data
265+
266+ client = TestClient (app )
267+ response = client .get ("/test" )
268+
269+ data = response .json ()
270+ assert response .status_code == 200
271+ assert response .headers ["content-type" ] == "application/json"
272+ assert mock_method .call_count == (1 if should_transform else 0 )
273+ if should_transform :
274+ assert mock_method .call_args [0 ][0 ] == test_data
275+ assert data == expected_result
276+
277+
278+ @pytest .mark .parametrize (
279+ "test_data" ,
280+ [
281+ {"message" : "test" },
282+ "hello world" ,
283+ [1 , 2 , 3 ],
284+ 42 ,
285+ True ,
286+ None ,
287+ ],
288+ )
289+ def test_json_response_middleware_expected_none_type (test_data ):
290+ """Test that middleware with expected_data_type=None transforms all JSON response types."""
291+ app = FastAPI ()
292+ app .add_middleware (ExampleAnyJsonResponseMiddleware )
293+
294+ @app .get ("/test" )
295+ async def test_endpoint ():
296+ return test_data
297+
298+ client = TestClient (app )
299+ response = client .get ("/test" )
300+ assert response .status_code == 200
301+ assert response .headers ["content-type" ] == "application/json"
302+ data = response .json ()
303+
304+ # Verify the simplified transformation behavior
305+ assert data ["transformed" ] is True
306+ assert data ["data" ] == test_data
0 commit comments