Skip to content

Commit 3608dd8

Browse files
committed
feat: enhance JsonResponseMiddleware to support expected data types for transformation
- Added `expected_data_type` attribute to `JsonResponseMiddleware` to specify the type of JSON responses to transform. - Implemented logic to skip transformation for unexpected data types, with appropriate logging. - Introduced example middleware classes for handling string, list, and any JSON response types, demonstrating the new functionality. - Updated tests to validate behavior for various expected data types and transformation scenarios.
1 parent 9a18c49 commit 3608dd8

File tree

2 files changed

+206
-4
lines changed

2 files changed

+206
-4
lines changed

src/stac_auth_proxy/utils/middleware.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class JsonResponseMiddleware(ABC):
1818

1919
app: ASGIApp
2020

21+
# Expected data type for JSON responses. Only responses matching this type will be transformed.
22+
# If None, all JSON responses will be transformed regardless of type.
23+
expected_data_type: Optional[type] = dict
24+
2125
@abstractmethod
2226
def should_transform_response(
2327
self, request: Request, scope: Scope
@@ -97,8 +101,21 @@ async def transform_response(message: Message) -> None:
97101
)
98102
await response(scope, receive, send)
99103
return
100-
transformed = self.transform_json(data, request=request)
101-
body = json.dumps(transformed).encode()
104+
105+
if self.expected_data_type is None or isinstance(
106+
data, self.expected_data_type
107+
):
108+
transformed = self.transform_json(data, request=request)
109+
body = json.dumps(transformed).encode()
110+
else:
111+
logger.warning(
112+
"Received JSON response with unexpected data type %r from upstream server (%r %r), "
113+
"skipping transformation (expected: %r)",
114+
type(data).__name__,
115+
request.method,
116+
request.url,
117+
self.expected_data_type.__name__,
118+
)
102119

103120
# Update content-length header
104121
headers["content-length"] = str(len(body))

tests/test_middleware.py

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Tests for middleware utilities."""
22

33
from typing import Any
4+
from unittest.mock import patch
45

6+
import pytest
57
from fastapi import FastAPI, Response
68
from starlette.datastructures import Headers
79
from starlette.requests import Request
@@ -17,18 +19,73 @@ class ExampleJsonResponseMiddleware(JsonResponseMiddleware):
1719
def __init__(self, app: ASGIApp):
1820
"""Initialize the middleware."""
1921
self.app = app
22+
# Use default expected_data_type (dict)
2023

2124
def should_transform_response(self, request: Request, scope: Scope) -> bool:
2225
"""Transform JSON responses based on content type."""
2326
return Headers(scope=scope).get("content-type", "") == "application/json"
2427

2528
def transform_json(self, data: Any, request: Request) -> Any:
2629
"""Add a test field to the response."""
27-
if isinstance(data, dict):
28-
data["transformed"] = True
30+
data["transformed"] = True
2931
return data
3032

3133

34+
class ExampleStringJsonResponseMiddleware(JsonResponseMiddleware):
35+
"""Example implementation that expects string JSON responses."""
36+
37+
def __init__(self, app: ASGIApp):
38+
"""Initialize the middleware."""
39+
self.app = app
40+
self.expected_data_type = str
41+
42+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
43+
"""Transform JSON responses based on content type."""
44+
return Headers(scope=scope).get("content-type", "") == "application/json"
45+
46+
def transform_json(self, data: Any, request: Request) -> Any:
47+
"""Transform string responses by adding a prefix."""
48+
if isinstance(data, str):
49+
return f"transformed: {data}"
50+
return data
51+
52+
53+
class ExampleListJsonResponseMiddleware(JsonResponseMiddleware):
54+
"""Example implementation that expects list JSON responses."""
55+
56+
def __init__(self, app: ASGIApp):
57+
"""Initialize the middleware."""
58+
self.app = app
59+
self.expected_data_type = list
60+
61+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
62+
"""Transform JSON responses based on content type."""
63+
return Headers(scope=scope).get("content-type", "") == "application/json"
64+
65+
def transform_json(self, data: Any, request: Request) -> Any:
66+
"""Transform list responses by adding a new item."""
67+
if isinstance(data, list):
68+
return data + ["transformed"]
69+
return data
70+
71+
72+
class ExampleAnyJsonResponseMiddleware(JsonResponseMiddleware):
73+
"""Example implementation that transforms any JSON response type."""
74+
75+
def __init__(self, app: ASGIApp):
76+
"""Initialize the middleware."""
77+
self.app = app
78+
self.expected_data_type = None # Transform any JSON type
79+
80+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
81+
"""Transform JSON responses based on content type."""
82+
return Headers(scope=scope).get("content-type", "") == "application/json"
83+
84+
def transform_json(self, data: Any, request: Request) -> Any:
85+
"""Transform any JSON response by wrapping it."""
86+
return {"transformed": True, "data": data}
87+
88+
3289
def test_json_response_middleware():
3390
"""Test that JSON responses are properly transformed."""
3491
app = FastAPI()
@@ -119,3 +176,131 @@ async def test_endpoint():
119176
assert response.headers["content-type"] == "application/json"
120177
data = response.json()
121178
assert data == {"error": "Received invalid JSON from upstream server"}
179+
180+
181+
@pytest.mark.parametrize(
182+
"content,expected_data",
183+
[
184+
('"hello world"', "hello world"),
185+
('[1, 2, 3, "test"]', [1, 2, 3, "test"]),
186+
("42", 42),
187+
("true", True),
188+
("null", None),
189+
],
190+
)
191+
def test_json_response_middleware_non_dict_json(content, expected_data):
192+
"""Test that non-dict JSON responses are not transformed by default middleware."""
193+
app = FastAPI()
194+
app.add_middleware(ExampleJsonResponseMiddleware)
195+
196+
@app.get("/test")
197+
async def test_endpoint():
198+
return Response(content=content, media_type="application/json")
199+
200+
client = TestClient(app)
201+
response = client.get("/test")
202+
assert response.status_code == 200
203+
assert response.headers["content-type"] == "application/json"
204+
data = response.json()
205+
assert data == expected_data # Should remain unchanged
206+
207+
208+
@pytest.mark.parametrize(
209+
"middleware_class,expected_data_type,test_data,expected_result,should_transform",
210+
[
211+
# String middleware tests
212+
(
213+
ExampleStringJsonResponseMiddleware,
214+
"this is a string",
215+
"transformed: this is a string",
216+
True,
217+
),
218+
(
219+
ExampleStringJsonResponseMiddleware,
220+
{"message": "not a string"},
221+
{"message": "not a string"},
222+
False,
223+
),
224+
# List middleware tests
225+
(
226+
ExampleListJsonResponseMiddleware,
227+
[1, 2, 3],
228+
[1, 2, 3, "transformed"],
229+
True,
230+
),
231+
(
232+
ExampleListJsonResponseMiddleware,
233+
"not a list",
234+
"not a list",
235+
False,
236+
),
237+
# Dict middleware tests (default)
238+
(
239+
ExampleJsonResponseMiddleware,
240+
{"message": "test"},
241+
{"message": "test", "transformed": True},
242+
True,
243+
),
244+
(
245+
ExampleJsonResponseMiddleware,
246+
"not a dict",
247+
"not a dict",
248+
False,
249+
),
250+
],
251+
)
252+
def test_json_response_middleware_type_specific(
253+
middleware_class, test_data, expected_result, should_transform
254+
):
255+
"""Test that middleware transforms only expected data types."""
256+
with patch.object(
257+
middleware_class, "transform_json", return_value=expected_result
258+
) as mock_method:
259+
app = FastAPI()
260+
app.add_middleware(middleware_class)
261+
262+
@app.get("/test")
263+
async def test_endpoint():
264+
return test_data
265+
266+
client = TestClient(app)
267+
response = client.get("/test")
268+
269+
data = response.json()
270+
assert response.status_code == 200
271+
assert response.headers["content-type"] == "application/json"
272+
assert mock_method.call_count == (1 if should_transform else 0)
273+
if should_transform:
274+
assert mock_method.call_args[0][0] == test_data
275+
assert data == expected_result
276+
277+
278+
@pytest.mark.parametrize(
279+
"test_data",
280+
[
281+
{"message": "test"},
282+
"hello world",
283+
[1, 2, 3],
284+
42,
285+
True,
286+
None,
287+
],
288+
)
289+
def test_json_response_middleware_expected_none_type(test_data):
290+
"""Test that middleware with expected_data_type=None transforms all JSON response types."""
291+
app = FastAPI()
292+
app.add_middleware(ExampleAnyJsonResponseMiddleware)
293+
294+
@app.get("/test")
295+
async def test_endpoint():
296+
return test_data
297+
298+
client = TestClient(app)
299+
response = client.get("/test")
300+
assert response.status_code == 200
301+
assert response.headers["content-type"] == "application/json"
302+
data = response.json()
303+
304+
# Verify the simplified transformation behavior
305+
assert data["transformed"] is True
306+
assert data["data"] == test_data

0 commit comments

Comments
 (0)