Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ db.sqlite3
db.sqlite3-journal
build/
dist/
test_reproduction.py
119 changes: 63 additions & 56 deletions django_eventstream/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,16 @@ async def stream(event_request, listener):
# FIXME: reconcile without re-reading from db

lm.lock.acquire()
conflict = False
if len(listener.channel_items) > 0:
# items were queued while reading from the db. toss them and
# read from db again
listener.aevent.clear()
listener.channel_items = {}
conflict = True
lm.lock.release()
try:
conflict = False
if len(listener.channel_items) > 0:
# items were queued while reading from the db. toss them and
# read from db again
listener.aevent.clear()
listener.channel_items = {}
conflict = True
finally:
lm.lock.release()

if conflict:
continue
Expand All @@ -223,59 +225,64 @@ async def stream(event_request, listener):

while True:
f = asyncio.ensure_future(listener.aevent.wait())
while True:
done, _ = await asyncio.wait([f], timeout=20)
if f in done:
break
body = "event: keep-alive\ndata:\n\n"
yield body
try:
while True:
done, _ = await asyncio.wait([f], timeout=20)
if f in done:
break
body = "event: keep-alive\ndata:\n\n"
yield body

lm.lock.acquire()
try:
channel_items = listener.channel_items
overflow = listener.overflow
error_data = listener.error

listener.aevent.clear()
listener.channel_items = {}
listener.overflow = False
finally:
lm.lock.release()

body = ""
for channel, items in channel_items.items():
for item in items:
if channel in last_ids:
if item.id is not None:
last_ids[channel] = item.id
else:
del last_ids[channel]
if last_ids:
event_id = make_id(last_ids)
else:
event_id = None
body += sse_encode_event(
item.type, item.data, event_id=event_id
)

lm.lock.acquire()
more = True

channel_items = listener.channel_items
overflow = listener.overflow
error_data = listener.error
if error_data:
condition = error_data["condition"]
text = error_data["text"]
extra = error_data.get("extra")
body += sse_encode_error(condition, text, extra=extra)
more = False

listener.aevent.clear()
listener.channel_items = {}
listener.overflow = False
if body or not more:
yield body

lm.lock.release()
if not more:
break

body = ""
for channel, items in channel_items.items():
for item in items:
if channel in last_ids:
if item.id is not None:
last_ids[channel] = item.id
else:
del last_ids[channel]
if last_ids:
event_id = make_id(last_ids)
else:
event_id = None
body += sse_encode_event(
item.type, item.data, event_id=event_id
)

more = True

if error_data:
condition = error_data["condition"]
text = error_data["text"]
extra = error_data.get("extra")
body += sse_encode_error(condition, text, extra=extra)
more = False

if body or not more:
yield body

if not more:
break

if overflow:
# check db
break
if overflow:
# check db
break
finally:
# Always cancel the future to prevent it from lingering
if not f.done():
f.cancel()

event_request.channel_last_ids = last_ids
finally:
Expand Down
222 changes: 222 additions & 0 deletions tests/stress_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#!/usr/bin/env python
"""
Stress test for django-eventstream to verify fixes under load.

This test simulates high-load scenarios to ensure:
1. No asyncio task leaks under rapid connect/disconnect
2. No lock deadlocks under concurrent access
3. Proper cleanup on stream cancellation

Run this manually to verify the fixes work under stress.
"""

import os
import sys
import asyncio
import time
import statistics  # NOTE(review): appears unused in this file — confirm before removing

# Setup Django: the settings module must be configured before any
# django_eventstream import triggers model loading.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings')
import django
django.setup()

from unittest.mock import patch
from asgiref.sync import sync_to_async
from django_eventstream.views import Listener, stream
from django_eventstream.eventrequest import EventRequest
from django_eventstream.storage import DjangoModelStorage


async def stress_test_rapid_cancellations(num_iterations=50):
    """
    Stress test: rapid connection/cancellation cycles.

    This simulates what happens in production when clients
    rapidly connect and disconnect.

    Args:
        num_iterations: number of connect/cancel cycles to run.

    Returns:
        True when no asyncio tasks leaked across the cycles, else False.
    """
    print(f"\nStress Test: {num_iterations} rapid connect/disconnect cycles")
    print("-" * 70)

    storage = DjangoModelStorage()
    get_current_id = sync_to_async(storage.get_current_id)

    with patch("django_eventstream.eventstream.get_storage", return_value=storage):
        try:
            current_id = await get_current_id("stress_test_channel")
        except Exception:
            # Channel doesn't exist yet, that's OK.
            # NOTE: a bare ``except:`` here would also swallow
            # asyncio.CancelledError (a BaseException since Python 3.8),
            # making this coroutine itself uncancellable at this await point.
            current_id = "0"

        # Snapshot of pending tasks, used below to detect leaks.
        initial_tasks = len([t for t in asyncio.all_tasks()
                            if not t.done() and t != asyncio.current_task()])

        start_time = time.time()

        for i in range(num_iterations):
            listener = Listener()
            request = EventRequest()
            request.is_next = False
            request.is_recover = False
            request.channels = ["stress_test_channel"]
            request.channel_last_ids = {"stress_test_channel": str(current_id)}

            # Start stream
            task = asyncio.create_task(collect_stream(stream(request, listener)))

            # Brief wait before cancel, so the stream actually starts running
            await asyncio.sleep(0.01)

            # Cancel and wait for the cancellation to complete
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

            # Print progress
            if (i + 1) % 10 == 0:
                print(f"  Completed {i + 1}/{num_iterations} cycles...")

        elapsed = time.time() - start_time

        # Give time for cleanup
        await asyncio.sleep(0.5)

        final_tasks = len([t for t in asyncio.all_tasks()
                          if not t.done() and t != asyncio.current_task()])

        leaked = final_tasks - initial_tasks
        print(f"\nResults:")
        print(f"  Time elapsed: {elapsed:.2f}s")
        print(f"  Cycles/second: {num_iterations/elapsed:.1f}")
        print(f"  Initial tasks: {initial_tasks}")
        print(f"  Final tasks: {final_tasks}")
        print(f"  Task leak: {leaked}")

        if leaked <= 0:
            print("  ✓ PASS: No task leaks detected")
            return True
        else:
            print(f"  ✗ FAIL: {leaked} tasks leaked")
            return False


async def stress_test_concurrent_streams(num_concurrent=20):
    """
    Stress test: multiple concurrent streams.

    This simulates multiple clients connected simultaneously,
    then all disconnecting at once.

    Args:
        num_concurrent: number of simultaneous streams to open.

    Returns:
        True when every stream was cancelled cleanly, else False.
    """
    print(f"\nStress Test: {num_concurrent} concurrent streams")
    print("-" * 70)

    storage = DjangoModelStorage()
    get_current_id = sync_to_async(storage.get_current_id)

    with patch("django_eventstream.eventstream.get_storage", return_value=storage):
        try:
            current_id = await get_current_id("stress_test_channel")
        except Exception:
            # Channel doesn't exist yet, that's OK.
            # NOTE: narrowed from a bare ``except:`` so that
            # asyncio.CancelledError (BaseException since 3.8) still propagates.
            current_id = "0"

        tasks = []

        print(f"  Starting {num_concurrent} concurrent streams...")

        # Start all streams
        for _ in range(num_concurrent):
            listener = Listener()
            request = EventRequest()
            request.is_next = False
            request.is_recover = False
            request.channels = ["stress_test_channel"]
            request.channel_last_ids = {"stress_test_channel": str(current_id)}

            task = asyncio.create_task(collect_stream(stream(request, listener)))
            tasks.append(task)

        # Let them all run
        await asyncio.sleep(1.0)

        print(f"  Cancelling all {num_concurrent} streams...")

        # Cancel all at once
        for task in tasks:
            task.cancel()

        # Wait for all cancellations; count the ones that ended in
        # CancelledError (i.e. were cancelled cleanly).
        cancelled = 0
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                cancelled += 1

        # Give time for cleanup
        await asyncio.sleep(0.5)

        remaining_tasks = len([t for t in asyncio.all_tasks()
                              if not t.done() and t != asyncio.current_task()])

        print(f"\nResults:")
        print(f"  Streams started: {len(tasks)}")
        print(f"  Streams cancelled: {cancelled}")
        print(f"  Remaining tasks: {remaining_tasks}")

        if cancelled == len(tasks):
            print("  ✓ PASS: All streams cancelled successfully")
            return True
        else:
            print(f"  ✗ FAIL: Only {cancelled}/{len(tasks)} cancelled")
            return False


async def collect_stream(stream_iter):
    """Drain an async iterator of text chunks and return them concatenated."""
    chunks = []
    async for piece in stream_iter:
        chunks.append(piece)
    return "".join(chunks)


async def main():
    """Run all stress tests and return a process exit code (0 ok, 1 fail)."""
    banner = "=" * 70
    print(banner)
    print("Django EventStream Stress Tests")
    print(banner)
    print("\nThese tests verify the fixes work correctly under load.")

    results = []

    try:
        # Test 1: Rapid cancellations
        results.append(await stress_test_rapid_cancellations(50))

        # Test 2: Concurrent streams
        results.append(await stress_test_concurrent_streams(20))

        # Summary
        passed = sum(results)
        failed = len(results) - passed
        print("\n" + banner)
        print("STRESS TEST SUMMARY")
        print(banner)
        print(f"Tests run: {len(results)}")
        print(f"Tests passed: {passed}")
        print(f"Tests failed: {failed}")

        if all(results):
            print("\n✓ All stress tests passed!")
            return 0
        print("\n✗ Some stress tests failed")
        return 1

    except Exception as exc:
        # Top-level boundary: report the failure and signal a non-zero exit.
        print(f"\n✗ Error during stress testing: {exc}")
        import traceback
        traceback.print_exc()
        return 1


# Script entry point: propagate main()'s integer result as the exit status.
if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
Loading