@@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
308308 return server
309309
310310
311- def loopback (server_factory = None , client_factory = None ):
311+ def loopback (server_factory = None , client_factory = None , blocking = True ):
312312 """
313313 Create a connected socket pair and force two connected SSL sockets
314314 to talk to each other via memory BIOs.
@@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None):
324324
325325 handshake (client , server )
326326
327- server .setblocking (True )
328- client .setblocking (True )
327+ server .setblocking (blocking )
328+ client .setblocking (blocking )
329329 return server , client
330330
331331
@@ -3131,11 +3131,122 @@ def test_memoryview_really_doesnt_overfill(self):
31313131 self ._doesnt_overfill_test (_make_memoryview )
31323132
31333133
3134+ @pytest .fixture
3135+ def nonblocking_tls_connections_pair ():
3136+ """Return a non-blocking TLS loopback connections pair."""
3137+ return loopback (blocking = False )
3138+
3139+
3140+ @pytest .fixture
3141+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3142+ """Return a non-blocking TLS server socket connected to loopback."""
3143+ return nonblocking_tls_connections_pair [0 ]
3144+
3145+
3146+ @pytest .fixture
3147+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3148+ """Return a non-blocking TLS client socket connected to loopback."""
3149+ return nonblocking_tls_connections_pair [1 ]
3150+
3151+
31343152class TestConnectionSendall (object ):
31353153 """
31363154 Tests for `Connection.sendall`.
31373155 """
31383156
3157+ def test_want_write (
3158+ self ,
3159+ monkeypatch ,
3160+ nonblocking_tls_server_connection ,
3161+ nonblocking_tls_client_connection ,
3162+ ):
3163+ msg = b"x"
3164+ garbage_size = 1024 * 1024 * 64
3165+ garbage_payload = msg * garbage_size
3166+ large_payload = b"p" * garbage_size * 2
3167+ payload_size = len (large_payload )
3168+
3169+ for i in range (garbage_size ):
3170+ try :
3171+ nonblocking_tls_client_connection .send (msg )
3172+ except WantWriteError :
3173+ break
3174+ else :
3175+ pytest .fail (
3176+ "Failed to fill socket buffer, cannot test "
3177+ "'want write' in `sendall()`"
3178+ )
3179+
3180+ def consume_garbage (conn ):
3181+ if patched_ssl_write .want_write_counter < 5 :
3182+ # NOTE: Ensure that sendall will make a few internal retries
3183+ return
3184+
3185+ assert not consume_garbage .garbage_consumed
3186+
3187+ consume_garbage .consumed += conn .recv (garbage_size )
3188+ if len (consume_garbage .consumed ) < garbage_size :
3189+ return
3190+
3191+ assert consume_garbage .consumed == garbage_payload
3192+
3193+ consume_garbage .garbage_consumed = True
3194+ consume_garbage .garbage_consumed = False
3195+ consume_garbage .consumed = b""
3196+
3197+ consumed_payload = b""
3198+ def consume_payload (conn ):
3199+ consumed_payload += conn .recv (payload_size )
3200+ # FIXME: invoke conn.renegotiate()?
3201+
3202+ original_ssl_write = _lib .SSL_write
3203+ def patched_ssl_write (ctx , data , size ):
3204+ consume_data_on_server = (
3205+ consume_payload if consume_garbage .garbage_consumed
3206+ else consume_garbage
3207+ )
3208+ consume_data_on_server (nonblocking_tls_server_connection )
3209+ write_result = original_ssl_write (ctx , data , size )
3210+ try :
3211+ nonblocking_tls_client_connection ._raise_ssl_error (
3212+ ctx , write_result ,
3213+ )
3214+ except WantWriteError :
3215+ patched_ssl_write .want_write_counter += 1
3216+ consume_data_on_server = (
3217+ consume_payload if consume_garbage .garbage_consumed
3218+ else consume_garbage
3219+ )
3220+ consume_data_on_server (nonblocking_tls_server_connection )
3221+ #breakpoint()
3222+ # NOTE: We don't re-raise it as the calling code will do
3223+ # NOTE: the same after the call.
3224+ return write_result
3225+
3226+ patched_ssl_write .want_write_counter = 0
3227+
3228+ # NOTE: Make the client think it needs a handshake so that it'll
3229+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3230+ # NOTE: that originates from `sendall()`:
3231+ nonblocking_tls_client_connection .set_connect_state ()
3232+ try :
3233+ nonblocking_tls_client_connection .do_handshake ()
3234+ except WantWriteError :
3235+ assert True # Sanity check
3236+ except :
3237+ assert False # This should never happen (see the note above)
3238+
3239+ monkeypatch .setattr (_lib , "SSL_write" , patched_ssl_write )
3240+
3241+ nonblocking_tls_client_connection .sendall (large_payload )
3242+
3243+ assert consume_garbage .garbage_consumed
3244+
3245+ # NOTE: Read the leftover data from the very last `SSL_write()`
3246+ consume_payload (nonblocking_tls_server_connection )
3247+
3248+ assert consumed_payload == large_payload
3249+
31393250 def test_wrong_args (self ):
31403251 """
31413252 When called with arguments other than a string argument for its first
0 commit comments