@@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
321321 return server
322322
323323
324- def loopback (server_factory = None , client_factory = None ):
324+ def loopback (server_factory = None , client_factory = None , blocking = True ):
325325 """
326326 Create a connected socket pair and force two connected SSL sockets
327327 to talk to each other via memory BIOs.
@@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):
337337
338338 handshake (client , server )
339339
340- server .setblocking (True )
341- client .setblocking (True )
340+ server .setblocking (blocking )
341+ client .setblocking (blocking )
342342 return server , client
343343
344344
@@ -3297,11 +3297,134 @@ def test_memoryview_really_doesnt_overfill(self):
32973297 self ._doesnt_overfill_test (_make_memoryview )
32983298
32993299
3300+ @pytest .fixture
3301+ def nonblocking_tls_connections_pair ():
3302+ """Return a non-blocking TLS loopback connections pair."""
3303+ return loopback (blocking = False )
3304+
3305+
3306+ @pytest .fixture
3307+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3308+ """Return a non-blocking TLS server socket connected to loopback."""
3309+ return nonblocking_tls_connections_pair [0 ]
3310+
3311+
3312+ @pytest .fixture
3313+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3314+ """Return a non-blocking TLS client socket connected to loopback."""
3315+ return nonblocking_tls_connections_pair [1 ]
3316+
3317+
33003318class TestConnectionSendall :
33013319 """
33023320 Tests for `Connection.sendall`.
33033321 """
33043322
3323+ def test_want_write (
3324+ self ,
3325+ monkeypatch ,
3326+ nonblocking_tls_server_connection ,
3327+ nonblocking_tls_client_connection ,
3328+ ):
3329+ msg = b"x"
3330+ garbage_size = 1024 * 1024 * 64
3331+ large_payload = b"p" * garbage_size * 2
3332+ payload_size = len (large_payload )
3333+
3334+ sent_garbage_size = 0
3335+ try :
3336+ sent_garbage_size += nonblocking_tls_client_connection .send (
3337+ msg * garbage_size ,
3338+ )
3339+ except WantWriteError :
3340+ pass
3341+ for i in range (garbage_size ):
3342+ try :
3343+ sent_garbage_size += nonblocking_tls_client_connection .send (
3344+ msg ,
3345+ )
3346+ except WantWriteError :
3347+ break
3348+ else :
3349+ pytest .fail (
3350+ "Failed to fill socket buffer, cannot test "
3351+ "'want write' in `sendall()`"
3352+ )
3353+ garbage_payload = sent_garbage_size * msg
3354+
3355+ def consume_garbage (conn ):
3356+ assert patched_ssl_write .want_write_counter >= 1
3357+ assert not consume_garbage .garbage_consumed
3358+
3359+ while len (consume_garbage .consumed ) < sent_garbage_size :
3360+ try :
3361+ consume_garbage .consumed += conn .recv (
3362+ sent_garbage_size - len (consume_garbage .consumed ),
3363+ )
3364+ except WantReadError :
3365+ pass
3366+
3367+ assert consume_garbage .consumed == garbage_payload
3368+
3369+ consume_garbage .garbage_consumed = True
3370+
3371+ consume_garbage .garbage_consumed = False
3372+ consume_garbage .consumed = b""
3373+
3374+ def consume_payload (conn ):
3375+ try :
3376+ consume_payload .consumed += conn .recv (payload_size )
3377+ except WantReadError :
3378+ pass
3379+
3380+ consume_payload .consumed = b""
3381+
3382+ original_ssl_write = _lib .SSL_write
3383+
3384+ def patched_ssl_write (ctx , data , size ):
3385+ write_result = original_ssl_write (ctx , data , size )
3386+ try :
3387+ nonblocking_tls_client_connection ._raise_ssl_error (
3388+ ctx ,
3389+ write_result ,
3390+ )
3391+ except WantWriteError :
3392+ patched_ssl_write .want_write_counter += 1
3393+ consume_data_on_server = (
3394+ consume_payload
3395+ if consume_garbage .garbage_consumed
3396+ else consume_garbage
3397+ )
3398+
3399+ consume_data_on_server (nonblocking_tls_server_connection )
3400+ # NOTE: We don't re-raise it as the calling code will do
3401+ # NOTE: the same after the call.
3402+ return write_result
3403+
3404+ patched_ssl_write .want_write_counter = 0
3405+
3406+ # NOTE: Make the client think it needs a handshake so that it'll
3407+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3408+ # NOTE: that originates from `sendall()`:
3409+ nonblocking_tls_client_connection .set_connect_state ()
3410+ try :
3411+ nonblocking_tls_client_connection .do_handshake ()
3412+ except WantWriteError :
3413+ assert True # Sanity check
3414+ except :
3415+ assert False # This should never happen (see the note above)
3416+
3417+ with monkeypatch .context () as mp_ctx :
3418+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3419+ nonblocking_tls_client_connection .sendall (large_payload )
3420+
3421+ assert consume_garbage .garbage_consumed
3422+
3423+ # NOTE: Read the leftover data from the very last `SSL_write()`
3424+ consume_payload (nonblocking_tls_server_connection )
3425+
3426+ assert consume_payload .consumed == large_payload
3427+
33053428 def test_wrong_args (self ):
33063429 """
33073430 When called with arguments other than a string argument for its first
0 commit comments