@@ -236,6 +236,7 @@ def __init__(
236236 encoding : str = "utf-8" ,
237237 encoding_errors : str = "strict" ,
238238 decode_responses : bool = False ,
239+ check_server_ready : bool = False ,
239240 parser_class = DefaultParser ,
240241 socket_read_size : int = 65536 ,
241242 health_check_interval : int = 0 ,
@@ -302,6 +303,7 @@ def __init__(
302303 self .redis_connect_func = redis_connect_func
303304 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
304305 self .handshake_metadata = None
306+ self .check_server_ready = check_server_ready
305307 self ._sock = None
306308 self ._socket_read_size = socket_read_size
307309 self .set_parser (parser_class )
@@ -382,15 +384,15 @@ def connect_check_health(self, check_health: bool = True):
382384 if self ._sock :
383385 return
384386 try :
385- sock = self .retry .call_with_retry (
386- lambda : self ._connect (), lambda error : self .disconnect (error )
387+ self .retry .call_with_retry (
388+ lambda : self ._connect_check_server_ready (),
389+ lambda error : self .disconnect (error ),
387390 )
388391 except socket .timeout :
389392 raise TimeoutError ("Timeout connecting to server" )
390393 except OSError as e :
391394 raise ConnectionError (self ._error_message (e ))
392395
393- self ._sock = sock
394396 try :
395397 if self .redis_connect_func is None :
396398 # Use the default on_connect function
@@ -412,8 +414,27 @@ def connect_check_health(self, check_health: bool = True):
412414 if callback :
413415 callback (self )
414416
417+ def _connect_check_server_ready (self ):
418+ self ._connect ()
419+
420+ # Doing handshake since connect and send operations work even when Redis is not ready
421+ if self .check_server_ready :
422+ try :
423+ self .send_command ("PING" , check_health = False )
424+
425+ response = str_if_bytes (self ._sock .recv (1024 ))
426+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
427+ raise ResponseError (f"Invalid PING response: { response } " )
428+ except (ConnectionResetError , ResponseError ) as err :
429+ try :
430+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
431+ except OSError :
432+ pass
433+ self ._sock .close ()
434+ raise ConnectionError (self ._error_message (err ))
435+
415436 @abstractmethod
416- def _connect (self ):
437+ def _connect (self ) -> None :
417438 pass
418439
419440 @abstractmethod
@@ -752,7 +773,7 @@ def repr_pieces(self):
752773 pieces .append (("client_name" , self .client_name ))
753774 return pieces
754775
755- def _connect (self ):
776+ def _connect (self ) -> None :
756777 "Create a TCP socket connection"
757778 # we want to mimic what socket.create_connection does to support
758779 # ipv4/ipv6, but we want to set options prior to calling
@@ -782,7 +803,8 @@ def _connect(self):
782803
783804 # set the socket_timeout now that we're connected
784805 sock .settimeout (self .socket_timeout )
785- return sock
806+ self ._sock = sock
807+ return
786808
787809 except OSError as _ :
788810 err = _
@@ -1095,15 +1117,15 @@ def __init__(
10951117 self .ssl_ciphers = ssl_ciphers
10961118 super ().__init__ (** kwargs )
10971119
1098- def _connect (self ):
1120+ def _connect (self ) -> None :
10991121 """
11001122 Wrap the socket with SSL support, handling potential errors.
11011123 """
1102- sock = super ()._connect ()
1124+ super ()._connect ()
11031125 try :
1104- return self ._wrap_socket_with_ssl (sock )
1126+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
11051127 except (OSError , RedisError ):
1106- sock .close ()
1128+ self . _sock .close ()
11071129 raise
11081130
11091131 def _wrap_socket_with_ssl (self , sock ):
@@ -1200,7 +1222,7 @@ def repr_pieces(self):
12001222 pieces .append (("client_name" , self .client_name ))
12011223 return pieces
12021224
1203- def _connect (self ):
1225+ def _connect (self ) -> None :
12041226 "Create a Unix domain socket connection"
12051227 sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
12061228 sock .settimeout (self .socket_connect_timeout )
@@ -1215,7 +1237,7 @@ def _connect(self):
12151237 sock .close ()
12161238 raise
12171239 sock .settimeout (self .socket_timeout )
1218- return sock
1240+ self . _sock = sock
12191241
12201242 def _host_error (self ):
12211243 return self .path
0 commit comments