@@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
308308    return  server 
309309
310310
311- def  loopback (server_factory = None , client_factory = None ):
311+ def  loopback (server_factory = None , client_factory = None ,  blocking = True ):
312312    """ 
313313    Create a connected socket pair and force two connected SSL sockets 
314314    to talk to each other via memory BIOs. 
@@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None):
324324
325325    handshake (client , server )
326326
327-     server .setblocking (True )
328-     client .setblocking (True )
327+     server .setblocking (blocking )
328+     client .setblocking (blocking )
329329    return  server , client 
330330
331331
@@ -3131,11 +3131,131 @@ def test_memoryview_really_doesnt_overfill(self):
31313131        self ._doesnt_overfill_test (_make_memoryview )
31323132
31333133
3134+ @pytest .fixture  
3135+ def  nonblocking_tls_connections_pair ():
3136+     """Return a non-blocking TLS loopback connections pair.""" 
3137+     return  loopback (blocking = False )
3138+ 
3139+ 
3140+ @pytest .fixture  
3141+ def  nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3142+     """Return a non-blocking TLS server socket connected to loopback.""" 
3143+     return  nonblocking_tls_connections_pair [0 ]
3144+ 
3145+ 
3146+ @pytest .fixture  
3147+ def  nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3148+     """Return a non-blocking TLS client socket connected to loopback.""" 
3149+     return  nonblocking_tls_connections_pair [1 ]
3150+ 
3151+ 
31343152class  TestConnectionSendall (object ):
31353153    """ 
31363154    Tests for `Connection.sendall`. 
31373155    """ 
31383156
3157+     def  test_want_write (
3158+             self ,
3159+             monkeypatch ,
3160+             nonblocking_tls_server_connection ,
3161+             nonblocking_tls_client_connection ,
3162+     ):
3163+         msg  =  b"x" 
3164+         garbage_size  =  1024  *  1024  *  64 
3165+         large_payload  =  b"p"  *  garbage_size  *  2 
3166+         payload_size  =  len (large_payload )
3167+ 
3168+         sent_garbage_size  =  0 
3169+         try :
3170+             sent_garbage_size  +=  nonblocking_tls_client_connection .send (
3171+                 msg  *  garbage_size ,
3172+             )
3173+         except  WantWriteError :
3174+             pass 
3175+         for  i  in  range (garbage_size ):
3176+             try :
3177+                 sent_garbage_size  +=  nonblocking_tls_client_connection .send (
3178+                     msg ,
3179+                 )
3180+             except  WantWriteError :
3181+                 break 
3182+         else :
3183+             pytest .fail (
3184+                 "Failed to fill socket buffer, cannot test " 
3185+                 "'want write' in `sendall()`" 
3186+             )
3187+         garbage_payload  =  sent_garbage_size  *  msg 
3188+ 
3189+ 
3190+         def  consume_garbage (conn ):
3191+             assert  patched_ssl_write .want_write_counter  >=  1 
3192+             assert  not  consume_garbage .garbage_consumed 
3193+ 
3194+             while  len (consume_garbage .consumed ) <  sent_garbage_size :
3195+                 try :
3196+                     consume_garbage .consumed  +=  conn .recv (
3197+                         sent_garbage_size  -  len (consume_garbage .consumed ),
3198+                     )
3199+                 except  WantReadError :
3200+                     pass 
3201+ 
3202+             assert  consume_garbage .consumed  ==  garbage_payload 
3203+ 
3204+             consume_garbage .garbage_consumed  =  True 
3205+ 
3206+         consume_garbage .garbage_consumed  =  False 
3207+         consume_garbage .consumed  =  b"" 
3208+ 
3209+         def  consume_payload (conn ):
3210+             try :
3211+                 consume_payload .consumed  +=  conn .recv (payload_size )
3212+             except  WantReadError :
3213+                 pass 
3214+         consume_payload .consumed  =  b"" 
3215+ 
3216+         original_ssl_write  =  _lib .SSL_write 
3217+         def  patched_ssl_write (ctx , data , size ):
3218+             write_result  =  original_ssl_write (ctx , data , size )
3219+             try :
3220+                 nonblocking_tls_client_connection ._raise_ssl_error (
3221+                     ctx , write_result ,
3222+                 )
3223+             except  WantWriteError :
3224+                 patched_ssl_write .want_write_counter  +=  1 
3225+                 consume_data_on_server  =  (
3226+                     consume_payload  if  consume_garbage .garbage_consumed 
3227+                     else  consume_garbage 
3228+                 )
3229+ 
3230+                 consume_data_on_server (nonblocking_tls_server_connection )
3231+                 # NOTE: We don't re-raise it as the calling code will do 
3232+                 # NOTE: the same after the call. 
3233+             return  write_result 
3234+ 
3235+         patched_ssl_write .want_write_counter  =  0 
3236+ 
3237+         # NOTE: Make the client think it needs a handshake so that it'll 
3238+         # NOTE: attempt to `do_handshake()` on the next `SSL_write()` 
3239+         # NOTE: that originates from `sendall()`: 
3240+         nonblocking_tls_client_connection .set_connect_state ()
3241+         try :
3242+             nonblocking_tls_client_connection .do_handshake ()
3243+         except  WantWriteError :
3244+             assert  True   # Sanity check 
3245+         except :
3246+             assert  False   # This should never happen (see the note above) 
3247+ 
3248+         with  monkeypatch .context () as  mp_ctx :
3249+             mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3250+             nonblocking_tls_client_connection .sendall (large_payload )
3251+ 
3252+         assert  consume_garbage .garbage_consumed 
3253+ 
3254+         # NOTE: Read the leftover data from the very last `SSL_write()` 
3255+         consume_payload (nonblocking_tls_server_connection )
3256+ 
3257+         assert  consume_payload .consumed  ==  large_payload 
3258+ 
31393259    def  test_wrong_args (self ):
31403260        """ 
31413261        When called with arguments other than a string argument for its first 
0 commit comments