From af8855dfa53f83e69931fe7aad153aed5d2167d1 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 21 Oct 2025 14:23:16 +0000
Subject: [PATCH 1/4] Initial plan


From 8857c8640f41839eb97a2bdfc1cbb1b99dcdfe9b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 21 Oct 2025 14:31:54 +0000
Subject: [PATCH 2/4] Fix asyncio future cancellation in SSE stream to prevent shutdown hangs

Co-authored-by: enzofrnt <63660254+enzofrnt@users.noreply.github.com>
---
 django_eventstream/views.py | 101 +++++++++++++++++++-----------------
 tests/test_stream.py        |  50 ++++++++++++++++++
 2 files changed, 103 insertions(+), 48 deletions(-)

diff --git a/django_eventstream/views.py b/django_eventstream/views.py
index d670694..7735e04 100644
--- a/django_eventstream/views.py
+++ b/django_eventstream/views.py
@@ -223,59 +223,64 @@ async def stream(event_request, listener):
 
             while True:
                 f = asyncio.ensure_future(listener.aevent.wait())
-                while True:
-                    done, _ = await asyncio.wait([f], timeout=20)
-                    if f in done:
-                        break
-                    body = "event: keep-alive\ndata:\n\n"
-                    yield body
+                try:
+                    while True:
+                        done, _ = await asyncio.wait([f], timeout=20)
+                        if f in done:
+                            break
+                        body = "event: keep-alive\ndata:\n\n"
+                        yield body
+
+                    lm.lock.acquire()
+
+                    channel_items = listener.channel_items
+                    overflow = listener.overflow
+                    error_data = listener.error
+
+                    listener.aevent.clear()
+                    listener.channel_items = {}
+                    listener.overflow = False
+
+                    lm.lock.release()
+
+                    body = ""
+                    for channel, items in channel_items.items():
+                        for item in items:
+                            if channel in last_ids:
+                                if item.id is not None:
+                                    last_ids[channel] = item.id
+                                else:
+                                    del last_ids[channel]
+                            if last_ids:
+                                event_id = make_id(last_ids)
+                            else:
+                                event_id = None
+                            body += sse_encode_event(
+                                item.type, item.data, event_id=event_id
+                            )
 
-                lm.lock.acquire()
+                    more = True
 
-                channel_items = listener.channel_items
-                overflow = listener.overflow
-                error_data = listener.error
+                    if error_data:
+                        condition = error_data["condition"]
+                        text = error_data["text"]
+                        extra = error_data.get("extra")
+                        body += sse_encode_error(condition, text, extra=extra)
+                        more = False
 
-                listener.aevent.clear()
-                listener.channel_items = {}
-                listener.overflow = False
+                    if body or not more:
+                        yield body
 
-                lm.lock.release()
+                    if not more:
+                        break
 
-                body = ""
-                for channel, items in channel_items.items():
-                    for item in items:
-                        if channel in last_ids:
-                            if item.id is not None:
-                                last_ids[channel] = item.id
-                            else:
-                                del last_ids[channel]
-                        if last_ids:
-                            event_id = make_id(last_ids)
-                        else:
-                            event_id = None
-                        body += sse_encode_event(
-                            item.type, item.data, event_id=event_id
-                        )
-
-                more = True
-
-                if error_data:
-                    condition = error_data["condition"]
-                    text = error_data["text"]
-                    extra = error_data.get("extra")
-                    body += sse_encode_error(condition, text, extra=extra)
-                    more = False
-
-                if body or not more:
-                    yield body
-
-                if not more:
-                    break
-
-                if overflow:
-                    # check db
-                    break
+                    if overflow:
+                        # check db
+                        break
+                finally:
+                    # Always cancel the future to prevent it from lingering
+                    if not f.done():
+                        f.cancel()
 
             event_request.channel_last_ids = last_ids
     finally:
diff --git a/tests/test_stream.py b/tests/test_stream.py
index e6baa86..406f538 100644
--- a/tests/test_stream.py
+++ b/tests/test_stream.py
@@ -58,6 +58,56 @@ def __assert_all_events_are_retrieved_only_once(self):
             CHANNEL_NAME, EVENTS_LIMIT, limit=EVENTS_LIMIT + 1
         )
 
+    @patch("django_eventstream.eventstream.get_storage")
+    async def test_stream_cancellation_during_wait(self, mock_get_storage):
+        """Test that stream properly cleans up when cancelled during event wait."""
+        mock_get_storage.return_value = self.storage
+
+        # Create a real listener (not mocked) to test actual wait behavior
+        listener = Listener()
+
+        request = EventRequest()
+        request.is_next = False
+        request.is_recover = False
+        request.channels = [CHANNEL_NAME]
+
+        # Get current ID using sync_to_async
+        get_current_id = sync_to_async(self.storage.get_current_id)
+        current_id = await get_current_id(CHANNEL_NAME)
+        request.channel_last_ids = {CHANNEL_NAME: str(current_id)}
+
+        # Start streaming - this will wait for events since we're caught up
+        stream_task = asyncio.create_task(
+            self.__collect_response(stream(request, listener))
+        )
+
+        # Give it time to enter the wait loop
+        await asyncio.sleep(0.5)
+
+        # Cancel the stream
+        stream_task.cancel()
+
+        try:
+            await stream_task
+            raise ValueError("stream completed unexpectedly")
+        except asyncio.CancelledError:
+            pass
+
+        # Verify no tasks are left running
+        pending_tasks = [task for task in asyncio.all_tasks()
+                         if not task.done() and task != asyncio.current_task()]
+
+        # Allow brief time for cleanup
+        await asyncio.sleep(0.1)
+
+        # Check again after cleanup time
+        pending_tasks_after = [task for task in asyncio.all_tasks()
+                               if not task.done() and task != asyncio.current_task()]
+
+        # The number of pending tasks should not increase after cancellation
+        self.assertLessEqual(len(pending_tasks_after), len(pending_tasks),
+                             "Stream cancellation should not leave lingering tasks")
+
     async def __initialise_test(self, mock_get_storage):
         mock_get_storage.return_value = self.storage
 

From 1135332f7c395e48b6dca83cfc640753f6f3ac83 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 21 Oct 2025 14:40:13 +0000
Subject: [PATCH 3/4] Add try-finally protection for lock operations to prevent deadlocks

Co-authored-by: enzofrnt <63660254+enzofrnt@users.noreply.github.com>
---
 django_eventstream/views.py | 38 +++++++++++++++++++------------------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/django_eventstream/views.py b/django_eventstream/views.py
index 7735e04..e561b02 100644
--- a/django_eventstream/views.py
+++ b/django_eventstream/views.py
@@ -207,14 +207,16 @@ async def stream(event_request, listener):
             # FIXME: reconcile without re-reading from db
 
             lm.lock.acquire()
-            conflict = False
-            if len(listener.channel_items) > 0:
-                # items were queued while reading from the db. toss them and
-                # read from db again
-                listener.aevent.clear()
-                listener.channel_items = {}
-                conflict = True
-            lm.lock.release()
+            try:
+                conflict = False
+                if len(listener.channel_items) > 0:
+                    # items were queued while reading from the db. toss them and
+                    # read from db again
+                    listener.aevent.clear()
+                    listener.channel_items = {}
+                    conflict = True
+            finally:
+                lm.lock.release()
 
             if conflict:
                 continue
@@ -232,16 +234,16 @@ async def stream(event_request, listener):
                         yield body
 
                     lm.lock.acquire()
-
-                    channel_items = listener.channel_items
-                    overflow = listener.overflow
-                    error_data = listener.error
-
-                    listener.aevent.clear()
-                    listener.channel_items = {}
-                    listener.overflow = False
-
-                    lm.lock.release()
+                    try:
+                        channel_items = listener.channel_items
+                        overflow = listener.overflow
+                        error_data = listener.error
+
+                        listener.aevent.clear()
+                        listener.channel_items = {}
+                        listener.overflow = False
+                    finally:
+                        lm.lock.release()
 
                     body = ""
                     for channel, items in channel_items.items():

From a54be9dcac0d062046d95b7f76de9aac112118d3 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 21 Oct 2025 19:35:11 +0000
Subject: [PATCH 4/4] Add stress test to verify fixes under high load

Co-authored-by: enzofrnt <63660254+enzofrnt@users.noreply.github.com>
---
Review note: the two bare `except:` clauses guarding the current-id lookup in
tests/stress_test.py are narrowed to `except Exception:`. A bare except also
catches asyncio.CancelledError and KeyboardInterrupt (both BaseException
subclasses), so a cancellation or Ctrl-C during the lookup would be swallowed
and the test would silently continue with current_id = "0". The change is
line-neutral, so the hunk header and diffstat below are unchanged.

 .gitignore           |   1 +
 tests/stress_test.py | 222 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 223 insertions(+)
 create mode 100644 tests/stress_test.py

diff --git a/.gitignore b/.gitignore
index 8894874..57da3cb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ db.sqlite3
 db.sqlite3-journal
 build/
 dist/
+test_reproduction.py
diff --git a/tests/stress_test.py b/tests/stress_test.py
new file mode 100644
index 0000000..e569ef6
--- /dev/null
+++ b/tests/stress_test.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python
+"""
+Stress test for django-eventstream to verify fixes under load.
+
+This test simulates high-load scenarios to ensure:
+1. No asyncio task leaks under rapid connect/disconnect
+2. No lock deadlocks under concurrent access
+3. Proper cleanup on stream cancellation
+
+Run this manually to verify the fixes work under stress.
+"""
+
+import os
+import sys
+import asyncio
+import time
+import statistics
+
+# Setup Django
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings')
+import django
+django.setup()
+
+from unittest.mock import patch
+from asgiref.sync import sync_to_async
+from django_eventstream.views import Listener, stream
+from django_eventstream.eventrequest import EventRequest
+from django_eventstream.storage import DjangoModelStorage
+
+
+async def stress_test_rapid_cancellations(num_iterations=50):
+    """
+    Stress test: rapid connection/cancellation cycles.
+
+    This simulates what happens in production when clients
+    rapidly connect and disconnect.
+    """
+    print(f"\nStress Test: {num_iterations} rapid connect/disconnect cycles")
+    print("-" * 70)
+
+    storage = DjangoModelStorage()
+    get_current_id = sync_to_async(storage.get_current_id)
+
+    with patch("django_eventstream.eventstream.get_storage", return_value=storage):
+        try:
+            current_id = await get_current_id("stress_test_channel")
+        except Exception:  # not bare: CancelledError must propagate
+            # Channel doesn't exist yet, that's OK
+            current_id = "0"
+
+        initial_tasks = len([t for t in asyncio.all_tasks()
+                             if not t.done() and t != asyncio.current_task()])
+
+        start_time = time.time()
+
+        for i in range(num_iterations):
+            listener = Listener()
+            request = EventRequest()
+            request.is_next = False
+            request.is_recover = False
+            request.channels = ["stress_test_channel"]
+            request.channel_last_ids = {"stress_test_channel": str(current_id)}
+
+            # Start stream
+            task = asyncio.create_task(collect_stream(stream(request, listener)))
+
+            # Brief wait before cancel
+            await asyncio.sleep(0.01)
+
+            # Cancel
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+            # Print progress
+            if (i + 1) % 10 == 0:
+                print(f" Completed {i + 1}/{num_iterations} cycles...")
+
+        elapsed = time.time() - start_time
+
+        # Give time for cleanup
+        await asyncio.sleep(0.5)
+
+        final_tasks = len([t for t in asyncio.all_tasks()
+                           if not t.done() and t != asyncio.current_task()])
+
+        print(f"\nResults:")
+        print(f" Time elapsed: {elapsed:.2f}s")
+        print(f" Cycles/second: {num_iterations/elapsed:.1f}")
+        print(f" Initial tasks: {initial_tasks}")
+        print(f" Final tasks: {final_tasks}")
+        print(f" Task leak: {final_tasks - initial_tasks}")
+
+        if final_tasks - initial_tasks <= 0:
+            print(" ✓ PASS: No task leaks detected")
+            return True
+        else:
+            print(f" ✗ FAIL: {final_tasks - initial_tasks} tasks leaked")
+            return False
+
+
+async def stress_test_concurrent_streams(num_concurrent=20):
+    """
+    Stress test: multiple concurrent streams.
+
+    This simulates multiple clients connected simultaneously,
+    then all disconnecting at once.
+    """
+    print(f"\nStress Test: {num_concurrent} concurrent streams")
+    print("-" * 70)
+
+    storage = DjangoModelStorage()
+    get_current_id = sync_to_async(storage.get_current_id)
+
+    with patch("django_eventstream.eventstream.get_storage", return_value=storage):
+        try:
+            current_id = await get_current_id("stress_test_channel")
+        except Exception:  # not bare: CancelledError must propagate
+            current_id = "0"
+
+        tasks = []
+
+        print(f" Starting {num_concurrent} concurrent streams...")
+
+        # Start all streams
+        for i in range(num_concurrent):
+            listener = Listener()
+            request = EventRequest()
+            request.is_next = False
+            request.is_recover = False
+            request.channels = ["stress_test_channel"]
+            request.channel_last_ids = {"stress_test_channel": str(current_id)}
+
+            task = asyncio.create_task(collect_stream(stream(request, listener)))
+            tasks.append(task)
+
+        # Let them all run
+        await asyncio.sleep(1.0)
+
+        print(f" Cancelling all {num_concurrent} streams...")
+
+        # Cancel all at once
+        for task in tasks:
+            task.cancel()
+
+        # Wait for all cancellations
+        cancelled = 0
+        for task in tasks:
+            try:
+                await task
+            except asyncio.CancelledError:
+                cancelled += 1
+
+        # Give time for cleanup
+        await asyncio.sleep(0.5)
+
+        remaining_tasks = len([t for t in asyncio.all_tasks()
+                               if not t.done() and t != asyncio.current_task()])
+
+        print(f"\nResults:")
+        print(f" Streams started: {len(tasks)}")
+        print(f" Streams cancelled: {cancelled}")
+        print(f" Remaining tasks: {remaining_tasks}")
+
+        if cancelled == len(tasks):
+            print(" ✓ PASS: All streams cancelled successfully")
+            return True
+        else:
+            print(f" ✗ FAIL: Only {cancelled}/{len(tasks)} cancelled")
+            return False
+
+
+async def collect_stream(stream_iter):
+    """Helper to collect stream output."""
+    response = ""
+    async for chunk in stream_iter:
+        response += chunk
+    return response
+
+
+async def main():
+    """Run all stress tests."""
+    print("=" * 70)
+    print("Django EventStream Stress Tests")
+    print("=" * 70)
+    print("\nThese tests verify the fixes work correctly under load.")
+
+    results = []
+
+    try:
+        # Test 1: Rapid cancellations
+        results.append(await stress_test_rapid_cancellations(50))
+
+        # Test 2: Concurrent streams
+        results.append(await stress_test_concurrent_streams(20))
+
+        # Summary
+        print("\n" + "=" * 70)
+        print("STRESS TEST SUMMARY")
+        print("=" * 70)
+        print(f"Tests run: {len(results)}")
+        print(f"Tests passed: {sum(results)}")
+        print(f"Tests failed: {len(results) - sum(results)}")
+
+        if all(results):
+            print("\n✓ All stress tests passed!")
+            return 0
+        else:
+            print("\n✗ Some stress tests failed")
+            return 1
+
+    except Exception as e:
+        print(f"\n✗ Error during stress testing: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(asyncio.run(main()))