Commit 5c75e63

[PR aio-libs#11726/6cffcfd backport][3.13] Fix WebSocket compressed sends to be cancellation safe (aio-libs#11731)
1 parent 95daf0c commit 5c75e63

File tree: 6 files changed (+329, -72 lines)

CHANGES/11725.bugfix.rst (1 addition, 0 deletions)

@@ -0,0 +1 @@
+Fixed WebSocket compressed sends to be cancellation safe. Tasks are now shielded during compression to prevent compressor state corruption. This ensures that the stateful compressor remains consistent even when send operations are cancelled -- by :user:`bdraco`.
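The failure mode, in short: once a shared deflate stream has consumed a message, those bytes must reach the wire, or every later frame inflates against the wrong compressor state. A minimal standalone sketch of the shield-plus-lock pattern this commit adopts, using plain `zlib` rather than aiohttp's actual classes (all names here are illustrative):

```python
import asyncio
import zlib

# Standalone sketch, not aiohttp code. A stateful raw-deflate stream
# (wbits=-15, as in permessage-deflate) is shared across sends, so a send
# cancelled between compress() and write() would advance the compressor
# without emitting the bytes, corrupting every later frame.

class ShieldedSender:
    def __init__(self, write) -> None:
        self._compressor = zlib.compressobj(wbits=-15)
        self._lock = asyncio.Lock()
        self._tasks: set = set()
        self._write = write

    async def _compress_and_send(self, message: bytes) -> None:
        async with self._lock:  # one sender mutates the compressor at a time
            data = self._compressor.compress(message)
            data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
            self._write(data)  # state advanced AND bytes handed off together

    async def send(self, message: bytes) -> None:
        # Run the critical section as a task and shield it: cancelling
        # send() raises here, but the task still runs to completion.
        task = asyncio.get_running_loop().create_task(
            self._compress_and_send(message)
        )
        self._tasks.add(task)  # strong reference until the task finishes
        task.add_done_callback(self._tasks.discard)
        await asyncio.shield(task)
```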

aiohttp/_websocket/writer.py (140 additions, 56 deletions)
@@ -2,8 +2,9 @@

 import asyncio
 import random
+import sys
 from functools import partial
-from typing import Any, Final, Optional, Union
+from typing import Final, Optional, Set, Union

 from ..base_protocol import BaseProtocol
 from ..client_exceptions import ClientConnectionResetError
@@ -22,14 +23,18 @@

 DEFAULT_LIMIT: Final[int] = 2**16

+# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
+# Control frames (ping, pong, close) are never compressed
+WS_CONTROL_FRAME_OPCODE: Final[int] = 8
+
 # For websockets, keeping latency low is extremely important as implementations
-# generally expect to be able to send and receive messages quickly. We use a
-# larger chunk size than the default to reduce the number of executor calls
-# since the executor is a significant source of latency and overhead when
-# the chunks are small. A size of 5KiB was chosen because it is also the
-# same value python-zlib-ng choose to use as the threshold to release the GIL.
+# generally expect to be able to send and receive messages quickly. We use a
+# larger chunk size to reduce the number of executor calls and avoid task
+# creation overhead, since both are significant sources of latency when chunks
+# are small. A size of 16KiB was chosen as a balance between avoiding task
+# overhead and not blocking the event loop too long with synchronous compression.

-WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
+WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024


 class WebSocketWriter:
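The new constant encodes the opcode split defined by RFC 6455. A quick illustrative check (opcode values from the RFC, snippet not part of the commit):

```python
# RFC 6455 opcodes: 0x0-0x7 are data frames, 0x8-0xF are control frames.
WS_CONTROL_FRAME_OPCODE = 8

for name, opcode in [("TEXT", 0x1), ("BINARY", 0x2), ("CLOSE", 0x8), ("PING", 0x9), ("PONG", 0xA)]:
    kind = "control, never compressed" if opcode >= WS_CONTROL_FRAME_OPCODE else "data, compressible"
    print(f"{name} (0x{opcode:X}): {kind}")
```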
@@ -62,7 +67,9 @@ def __init__(
         self._closing = False
         self._limit = limit
         self._output_size = 0
-        self._compressobj: Any = None  # actually compressobj
+        self._compressobj: Optional[ZLibCompressor] = None
+        self._send_lock = asyncio.Lock()
+        self._background_tasks: Set[asyncio.Task[None]] = set()

     async def send_frame(
         self, message: bytes, opcode: int, compress: Optional[int] = None
@@ -71,39 +78,57 @@ async def send_frame(
         if self._closing and not (opcode & WSMsgType.CLOSE):
             raise ClientConnectionResetError("Cannot write to closing transport")

-        # RSV are the reserved bits in the frame header. They are used to
-        # indicate that the frame is using an extension.
-        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
-        rsv = 0
-        # Only compress larger packets (disabled)
-        # Does small packet needs to be compressed?
-        # if self.compress and opcode < 8 and len(message) > 124:
-        if (compress or self.compress) and opcode < 8:
-            # RSV1 (rsv = 0x40) is set for compressed frames
-            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
-            rsv = 0x40
-
-            if compress:
-                # Do not set self._compress if compressing is for this frame
-                compressobj = self._make_compress_obj(compress)
-            else:  # self.compress
-                if not self._compressobj:
-                    self._compressobj = self._make_compress_obj(self.compress)
-                compressobj = self._compressobj
-
-            message = (
-                await compressobj.compress(message)
-                + compressobj.flush(
-                    ZLibBackend.Z_FULL_FLUSH
-                    if self.notakeover
-                    else ZLibBackend.Z_SYNC_FLUSH
-                )
-            ).removesuffix(WS_DEFLATE_TRAILING)
-            # Its critical that we do not return control to the event
-            # loop until we have finished sending all the compressed
-            # data. Otherwise we could end up mixing compressed frames
-            # if there are multiple coroutines compressing data.
+        if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
+            # Non-compressed frames don't need the lock or the shield
+            self._write_websocket_frame(message, opcode, 0)
+        elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
+            # Small compressed payloads: compress synchronously in the event loop.
+            # We need the lock even though sync compression has no await points.
+            # This prevents small frames from interleaving with large frames that
+            # compress in the executor, avoiding compressor state corruption.
+            async with self._send_lock:
+                self._send_compressed_frame_sync(message, opcode, compress)
+        else:
+            # For large compressed frames, the entire compress+send
+            # operation must be atomic. If cancelled after compression but
+            # before send, the compressor state would be advanced but the
+            # data not sent, corrupting subsequent frames.
+            # Create a task to shield from cancellation.
+            # The lock is acquired inside the shielded task so the entire
+            # operation (lock + compress + send) completes atomically.
+            # Use eager_start on Python 3.12+ to avoid scheduling overhead.
+            loop = asyncio.get_running_loop()
+            coro = self._send_compressed_frame_async_locked(message, opcode, compress)
+            if sys.version_info >= (3, 12):
+                send_task = asyncio.Task(coro, loop=loop, eager_start=True)
+            else:
+                send_task = loop.create_task(coro)
+            # Keep a strong reference to prevent garbage collection
+            self._background_tasks.add(send_task)
+            send_task.add_done_callback(self._background_tasks.discard)
+            await asyncio.shield(send_task)
+
+        # It is safe to return control to the event loop when using compression
+        # after this point as we have already sent or buffered all the data.
+        # Once we have written output_size up to the limit, we call the
+        # drain helper which waits for the transport to be ready to accept
+        # more data. This is a flow control mechanism to prevent the buffer
+        # from growing too large. The drain helper will return right away
+        # if the writer is not paused.
+        if self._output_size > self._limit:
+            self._output_size = 0
+            if self.protocol._paused:
+                await self.protocol._drain_helper()

+    def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
+        """
+        Write a websocket frame to the transport.
+
+        This method handles frame header construction, masking, and writing
+        to the transport. It does not handle compression or flow control;
+        those are the responsibility of the caller.
+        """
         msg_length = len(message)

         use_mask = self.use_mask
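On the `eager_start` branch above: from Python 3.12, a task constructed with `eager_start=True` runs its coroutine immediately, up to the first suspension point, instead of waiting for the next event-loop iteration. A standalone sketch of the version-gated construction (the helper name is ours, not aiohttp's):

```python
import asyncio
import sys
from collections.abc import Coroutine
from typing import Any

async def run_shielded(coro: Coroutine[Any, Any, None]) -> None:
    """Run coro as a task whose completion survives the caller's cancellation."""
    loop = asyncio.get_running_loop()
    if sys.version_info >= (3, 12):
        # Executes synchronously up to the first await; no scheduling hop.
        task = asyncio.Task(coro, loop=loop, eager_start=True)
    else:
        task = loop.create_task(coro)
    # aiohttp additionally keeps each task in a set: if this await is
    # cancelled, nothing else references the task, and an un-referenced
    # task can be garbage-collected before it finishes.
    await asyncio.shield(task)
```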
@@ -146,26 +171,85 @@ async def send_frame(

         self._output_size += header_len + msg_length

-        # It is safe to return control to the event loop when using compression
-        # after this point as we have already sent or buffered all the data.
+    def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
+        """Get or create a compressor object for the given compression level."""
+        if compress:
+            # Do not cache in self._compressobj if compression is only for this frame
+            return ZLibCompressor(
+                level=ZLibBackend.Z_BEST_SPEED,
+                wbits=-compress,
+                max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+            )
+        if not self._compressobj:
+            self._compressobj = ZLibCompressor(
+                level=ZLibBackend.Z_BEST_SPEED,
+                wbits=-self.compress,
+                max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+            )
+        return self._compressobj

-        # Once we have written output_size up to the limit, we call the
-        # drain helper which waits for the transport to be ready to accept
-        # more data. This is a flow control mechanism to prevent the buffer
-        # from growing too large. The drain helper will return right away
-        # if the writer is not paused.
-        if self._output_size > self._limit:
-            self._output_size = 0
-            if self.protocol._paused:
-                await self.protocol._drain_helper()
+    def _send_compressed_frame_sync(
+        self, message: bytes, opcode: int, compress: Optional[int]
+    ) -> None:
+        """
+        Synchronous send for small compressed frames.

-    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
-        return ZLibCompressor(
-            level=ZLibBackend.Z_BEST_SPEED,
-            wbits=-compress,
-            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+        Used for small compressed payloads that are compressed synchronously
+        in the event loop. Since there are no await points, this is
+        inherently cancellation-safe.
+        """
+        # RSV are the reserved bits in the frame header. They are used to
+        # indicate that the frame is using an extension.
+        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
+        compressobj = self._get_compressor(compress)
+        # RSV1 (0x40) is set for compressed frames
+        # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
+        self._write_websocket_frame(
+            (
+                compressobj.compress_sync(message)
+                + compressobj.flush(
+                    ZLibBackend.Z_FULL_FLUSH
+                    if self.notakeover
+                    else ZLibBackend.Z_SYNC_FLUSH
+                )
+            ).removesuffix(WS_DEFLATE_TRAILING),
+            opcode,
+            0x40,
         )

+    async def _send_compressed_frame_async_locked(
+        self, message: bytes, opcode: int, compress: Optional[int]
+    ) -> None:
+        """
+        Async send for large compressed frames, holding the send lock.
+
+        Acquires the lock and compresses large payloads asynchronously in
+        the executor. The lock is held for the entire operation to ensure
+        the compressor state is not corrupted by concurrent sends.
+
+        MUST be run shielded from cancellation. If cancelled after
+        compression but before sending, the compressor state would be
+        advanced but the data not sent, corrupting subsequent frames.
+        """
+        async with self._send_lock:
+            # RSV are the reserved bits in the frame header. They are used to
+            # indicate that the frame is using an extension.
+            # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
+            compressobj = self._get_compressor(compress)
+            # RSV1 (0x40) is set for compressed frames
+            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
+            self._write_websocket_frame(
+                (
+                    await compressobj.compress(message)
+                    + compressobj.flush(
+                        ZLibBackend.Z_FULL_FLUSH
+                        if self.notakeover
+                        else ZLibBackend.Z_SYNC_FLUSH
+                    )
+                ).removesuffix(WS_DEFLATE_TRAILING),
+                opcode,
+                0x40,
+            )
+
     async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
         """Close the websocket, sending the specified code and message."""
         if isinstance(message, str):
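On the `.removesuffix(WS_DEFLATE_TRAILING)` calls above: RFC 7692 (section 7.2.1) has the sender strip the `00 00 ff ff` tail that a zlib sync flush appends; the receiver re-appends it before inflating. A standalone illustration with plain `zlib` (the constant is redefined here for the demo):

```python
import zlib

WS_DEFLATE_TRAILING = b"\x00\x00\xff\xff"  # redefined for this demo

comp = zlib.compressobj(wbits=-15)  # raw deflate, as in permessage-deflate
payload = comp.compress(b"hello") + comp.flush(zlib.Z_SYNC_FLUSH)
assert payload.endswith(WS_DEFLATE_TRAILING)

# The sender strips the four trailing octets before framing; the peer
# appends them again before feeding the frame to its decompressor.
frame_payload = payload.removesuffix(WS_DEFLATE_TRAILING)
```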

aiohttp/compression_utils.py (28 additions, 14 deletions)

@@ -185,7 +185,6 @@ def __init__(
         if level is not None:
             kwargs["level"] = level
         self._compressor = self._zlib_backend.compressobj(**kwargs)
-        self._compress_lock = asyncio.Lock()

     def compress_sync(self, data: bytes) -> bytes:
         return self._compressor.compress(data)
@@ -198,22 +197,37 @@ async def compress(self, data: bytes) -> bytes:
         If the data size is larger than max_sync_chunk_size, the compression
         will be done in the executor. Otherwise, the compression will be done
         in the event loop.
+
+        **WARNING: This method is NOT cancellation-safe when used with flush().**
+        If this operation is cancelled, the compressor state may be corrupted.
+        The connection MUST be closed after cancellation to avoid data corruption
+        in subsequent compress operations.
+
+        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
+        the compress() + flush() + send operations in a shield and a lock to
+        ensure atomicity.
         """
-        async with self._compress_lock:
-            # To ensure the stream is consistent in the event
-            # there are multiple writers, we need to lock
-            # the compressor so that only one writer can
-            # compress at a time.
-            if (
-                self._max_sync_chunk_size is not None
-                and len(data) > self._max_sync_chunk_size
-            ):
-                return await asyncio.get_running_loop().run_in_executor(
-                    self._executor, self._compressor.compress, data
-                )
-            return self.compress_sync(data)
+        # For large payloads, offload compression to the executor to avoid
+        # blocking the event loop
+        should_use_executor = (
+            self._max_sync_chunk_size is not None
+            and len(data) > self._max_sync_chunk_size
+        )
+        if should_use_executor:
+            return await asyncio.get_running_loop().run_in_executor(
+                self._executor, self._compressor.compress, data
+            )
+        return self.compress_sync(data)

     def flush(self, mode: Optional[int] = None) -> bytes:
+        """Flush the compressor synchronously.
+
+        **WARNING: This method is NOT cancellation-safe when called after compress().**
+        The flush() operation accesses shared compressor state. If compress() was
+        cancelled, calling flush() may produce corrupted data. The connection MUST
+        be closed after a compress() cancellation.
+
+        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
+        the compress() + flush() + send operations in a shield and a lock to
+        ensure atomicity.
+        """
         return self._compressor.flush(
             mode if mode is not None else self._zlib_backend.Z_FINISH
         )
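A small usage sketch of the threshold behavior described above. The keyword names come from this diff; anything not shown (other constructor defaults) is assumed:

```python
import asyncio

from aiohttp.compression_utils import ZLibCompressor

async def main() -> None:
    # Payloads larger than max_sync_chunk_size are compressed in the
    # default executor via run_in_executor(); smaller ones run inline
    # on the event loop. Keyword names per the diff; defaults assumed.
    comp = ZLibCompressor(level=1, max_sync_chunk_size=16 * 1024)
    inline = await comp.compress(b"a" * 1_000)       # small: stays on the loop
    offloaded = await comp.compress(b"b" * 100_000)  # large: executor hop
    stream = inline + offloaded + comp.flush()       # flush() defaults to Z_FINISH
    print(len(stream))

asyncio.run(main())
```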

docs/spelling_wordlist.txt (1 addition, 0 deletions)

@@ -304,6 +304,7 @@ SocketSocketTransport
 ssl
 SSLContext
 startup
+stateful
 subapplication
 subclassed
 subclasses

tests/conftest.py (28 additions, 1 deletion)

@@ -5,10 +5,13 @@
 import socket
 import ssl
 import sys
+import time
+from collections.abc import AsyncIterator, Callable, Iterator
+from concurrent.futures import Future, ThreadPoolExecutor
 from hashlib import md5, sha1, sha256
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, AsyncIterator, Generator, Iterator
+from typing import Any, Generator
 from unittest import mock
 from uuid import uuid4

@@ -401,3 +404,27 @@ async def cleanup_payload_pending_file_closes(
     loop_futures = [f for f in payload._CLOSE_FUTURES if f.get_loop() is loop]
     if loop_futures:
         await asyncio.gather(*loop_futures, return_exceptions=True)
+
+
+@pytest.fixture
+def slow_executor() -> Iterator[ThreadPoolExecutor]:
+    """Executor that adds delay to simulate slow operations.
+
+    Useful for testing cancellation and race conditions in compression tests.
+    """
+
+    class SlowExecutor(ThreadPoolExecutor):
+        """Executor that adds delay to operations."""
+
+        def submit(
+            self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any
+        ) -> Future[Any]:
+            def slow_fn(*args: Any, **kwargs: Any) -> Any:
+                time.sleep(0.05)  # Add delay to simulate slow operation
+                return fn(*args, **kwargs)
+
+            return super().submit(slow_fn, *args, **kwargs)
+
+    executor = SlowExecutor(max_workers=10)
+    yield executor
+    executor.shutdown(wait=True)
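A hedged sketch of how a test might consume this fixture. The test body is illustrative, and it assumes `ZLibCompressor` accepts an `executor` argument (suggested by `self._executor` in the diff, not confirmed here):

```python
import asyncio

import pytest

from aiohttp.compression_utils import ZLibCompressor

async def test_cancelled_compress(slow_executor) -> None:
    # Assumption: the constructor takes `executor`. max_sync_chunk_size=1
    # forces even tiny payloads through the executor, where SlowExecutor's
    # 50 ms delay opens a window for cancellation to land mid-compression.
    comp = ZLibCompressor(max_sync_chunk_size=1, executor=slow_executor)
    task = asyncio.ensure_future(comp.compress(b"x" * 4096))
    await asyncio.sleep(0.01)  # let the executor job start
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task
    # Per the new docstrings, the compressor may now be corrupted; a real
    # caller must close the connection instead of reusing it.
```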
