@@ -148,6 +148,7 @@ def __init__(
148148 encoding_errors : str = "strict" ,
149149 decode_responses : bool = False ,
150150 parser_class : Type [BaseParser ] = DefaultParser ,
151+ check_ready : bool = False ,
151152 socket_read_size : int = 65536 ,
152153 health_check_interval : float = 0 ,
153154 client_name : Optional [str ] = None ,
@@ -204,6 +205,7 @@ def __init__(
204205 self .health_check_interval = health_check_interval
205206 self .next_health_check : float = - 1
206207 self .encoder = encoder_class (encoding , encoding_errors , decode_responses )
208+ self .check_ready = check_ready
207209 self .redis_connect_func = redis_connect_func
208210 self ._reader : Optional [asyncio .StreamReader ] = None
209211 self ._writer : Optional [asyncio .StreamWriter ] = None
@@ -295,14 +297,48 @@ async def connect(self):
295297 """Connects to the Redis server if not already connected"""
296298 await self .connect_check_health (check_health = True )
297299
300+ async def _connect_check_ready (self ):
301+ await self ._connect ()
302+
303+ # Doing handshake since connect and send operations work even when Redis is not ready
304+ if self .check_ready :
305+ try :
306+ ping_cmd = self .pack_command ("PING" )
307+ if self .socket_timeout :
308+ await asyncio .wait_for (
309+ self ._send_packed_command (ping_cmd ), self .socket_timeout
310+ )
311+ else :
312+ await self ._send_packed_command (ping_cmd )
313+
314+ if self .socket_timeout is not None :
315+ async with async_timeout (self .socket_timeout ):
316+ response = str_if_bytes (await self ._reader .read (1024 ))
317+ else :
318+ response = str_if_bytes (await self ._reader .read (1024 ))
319+
320+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
321+ raise ResponseError (f"Invalid PING response: { response } " )
322+ except (
323+ socket .timeout ,
324+ asyncio .TimeoutError ,
325+ ResponseError ,
326+ ConnectionResetError ,
327+ ) as e :
328+ # `socket_keepalive_options` might contain invalid options
329+ # causing an error. Do not leave the connection open.
330+ self ._close ()
331+ raise ConnectionError (self ._error_message (e ))
332+
298333 async def connect_check_health (self , check_health : bool = True ):
299334 if self .is_connected :
300335 return
301336 try :
302337 await self .retry .call_with_retry (
303- lambda : self ._connect (), lambda error : self .disconnect ()
338+ lambda : self ._connect_check_ready (), lambda error : self .disconnect ()
304339 )
305340 except asyncio .CancelledError :
341+ self ._close ()
306342 raise # in 3.7 and earlier, this is an Exception, not BaseException
307343 except (socket .timeout , asyncio .TimeoutError ):
308344 raise TimeoutError ("Timeout connecting to server" )
@@ -526,8 +562,7 @@ async def send_packed_command(
526562 self ._send_packed_command (command ), self .socket_timeout
527563 )
528564 else :
529- self ._writer .writelines (command )
530- await self ._writer .drain ()
565+ await self ._send_packed_command (command )
531566 except asyncio .TimeoutError :
532567 await self .disconnect (nowait = True )
533568 raise TimeoutError ("Timeout writing to socket" ) from None
@@ -774,7 +809,7 @@ async def _connect(self):
774809 except (OSError , TypeError ):
775810 # `socket_keepalive_options` might contain invalid options
776811 # causing an error. Do not leave the connection open.
777- writer . close ()
812+ self . _close ()
778813 raise
779814
780815 def _host_error (self ) -> str :
@@ -933,7 +968,6 @@ async def _connect(self):
933968 reader , writer = await asyncio .open_unix_connection (path = self .path )
934969 self ._reader = reader
935970 self ._writer = writer
936- await self .on_connect ()
937971
938972 def _host_error (self ) -> str :
939973 return self .path
0 commit comments