Skip to content

Commit 4a6c2c0

Browse files
Adding retries for the overall connect - socket connect + handshake. Fix for pubsub reconnect issues. (#3863)
* Adding retries for the overall connect - socket connect + handshake. Fix for pubsub reconnect issues. * Update tests/test_connection.py Co-authored-by: Copilot <[email protected]> * Update tests/test_asyncio/test_connection.py Co-authored-by: Copilot <[email protected]> * Apply suggestions from code review Co-authored-by: Copilot <[email protected]> * Fix linters --------- Co-authored-by: Copilot <[email protected]>
1 parent 742b13b commit 4a6c2c0

File tree

4 files changed

+56
-10
lines changed

4 files changed

+56
-10
lines changed

redis/asyncio/connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,14 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
296296

297297
async def connect(self):
298298
"""Connects to the Redis server if not already connected"""
299-
await self.connect_check_health(check_health=True)
299+
# try once the socket connect with the handshake, retry the whole
300+
# connect/handshake flow based on retry policy
301+
await self.retry.call_with_retry(
302+
lambda: self.connect_check_health(
303+
check_health=True, retry_socket_connect=False
304+
),
305+
lambda error: self.disconnect(),
306+
)
300307

301308
async def connect_check_health(
302309
self, check_health: bool = True, retry_socket_connect: bool = True

redis/connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,14 @@ def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
843843

844844
def connect(self):
845845
"Connects to the Redis server if not already connected"
846-
self.connect_check_health(check_health=True)
846+
# try once the socket connect with the handshake, retry the whole
847+
# connect/handshake flow based on retry policy
848+
self.retry.call_with_retry(
849+
lambda: self.connect_check_health(
850+
check_health=True, retry_socket_connect=False
851+
),
852+
lambda error: self.disconnect(error),
853+
)
847854

848855
def connect_check_health(
849856
self, check_health: bool = True, retry_socket_connect: bool = True

tests/test_asyncio/test_connection.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,33 @@ async def mock_connect():
155155
await conn.disconnect()
156156

157157

158-
async def test_connect_without_retry_on_os_error():
159-
"""Test that the _connect function is not being retried in case of a OSError"""
158+
async def test_connect_without_retry_on_non_retryable_error():
159+
"""
160+
Test that the _connect function is not being retried in case of a CancelledError -
161+
error that is not in the list of retry-able errors"""
160162
with patch.object(Connection, "_connect") as _connect:
161-
_connect.side_effect = OSError("")
163+
_connect.side_effect = asyncio.CancelledError("")
162164
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
163-
with pytest.raises(ConnectionError):
165+
with pytest.raises(asyncio.CancelledError):
164166
await conn.connect()
165167
assert _connect.call_count == 1
166168

167169

170+
async def test_connect_with_retries():
171+
"""
172+
Test that retries occur for the entire connect+handshake flow when OSError happens during the handshake phase.
173+
"""
174+
with patch.object(asyncio.StreamWriter, "writelines") as writelines:
175+
writelines.side_effect = OSError(ECONNREFUSED)
176+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
177+
with pytest.raises(ConnectionError):
178+
await conn.connect()
179+
# the handshake commands are the failing ones
180+
# validate that we don't execute too many commands on each retry
181+
# 3 retries --> 3 commands
182+
assert writelines.call_count == 3
183+
184+
168185
async def test_connect_timeout_error_without_retry():
169186
"""Test that the _connect function is not being retried if retry_on_timeout is
170187
set to False"""

tests/test_connection.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,31 @@ def mock_connect():
124124
assert conn._connect.call_count == 3
125125
self.clear(conn)
126126

127-
def test_connect_without_retry_on_os_error(self):
128-
"""Test that the _connect function is not being retried in case of a OSError"""
127+
def test_connect_without_retry_on_non_retryable_error(self):
128+
"""Test that the _connect function is not being retried in case of a non-retryable error"""
129129
with patch.object(Connection, "_connect") as _connect:
130-
_connect.side_effect = OSError("")
130+
_connect.side_effect = RedisError("")
131131
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
132-
with pytest.raises(ConnectionError):
132+
with pytest.raises(RedisError):
133133
conn.connect()
134134
assert _connect.call_count == 1
135135
self.clear(conn)
136136

137+
def test_connect_with_retries(self):
138+
"""
139+
Validate that retries occur for the entire connect+handshake flow when OSError
140+
happens during the handshake phase.
141+
"""
142+
with patch.object(socket.socket, "sendall") as sendall:
143+
sendall.side_effect = OSError(ECONNREFUSED)
144+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
145+
with pytest.raises(ConnectionError):
146+
conn.connect()
147+
# the handshake commands are the failing ones
148+
# validate that we don't execute too many commands on each retry
149+
# 3 retries --> 3 commands
150+
assert sendall.call_count == 3
151+
137152
def test_connect_timeout_error_without_retry(self):
138153
"""Test that the _connect function is not being retried if retry_on_timeout is
139154
set to False"""

0 commit comments

Comments
 (0)