diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index c5d4678022..fda8315dd1 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -193,12 +193,14 @@ def parse_oss_maintenance_start_msg(response): @staticmethod def parse_oss_maintenance_completed_msg(response): # Expected message format is: - # SMIGRATED + # SMIGRATED [ , ...] id = response[1] - node_address = safe_str(response[2]) - slots = response[3] + nodes_to_slots_mapping_data = response[2] + nodes_to_slots_mapping = {} + for node, slots in nodes_to_slots_mapping_data: + nodes_to_slots_mapping[safe_str(node)] = safe_str(slots) - return OSSNodeMigratedNotification(id, node_address, slots) + return OSSNodeMigratedNotification(id, nodes_to_slots_mapping) @staticmethod def parse_maintenance_start_msg(response, notification_type): diff --git a/redis/connection.py b/redis/connection.py index c9a3221b0b..e8dc39a0d6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -319,6 +319,8 @@ def __init__( oss_cluster_maint_notifications_handler, parser, ) + self._processed_start_maint_notifications = set() + self._skipped_end_maint_notifications = set() @abstractmethod def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]: @@ -667,6 +669,22 @@ def maintenance_state(self) -> MaintenanceState: def maintenance_state(self, state: "MaintenanceState"): self._maintenance_state = state + def add_maint_start_notification(self, id: int): + self._processed_start_maint_notifications.add(id) + + def get_processed_start_notifications(self) -> set: + return self._processed_start_maint_notifications + + def add_skipped_end_notification(self, id: int): + self._skipped_end_maint_notifications.add(id) + + def get_skipped_end_notifications(self) -> set: + return self._skipped_end_maint_notifications + + def reset_received_notifications(self): + self._processed_start_maint_notifications.clear() + self._skipped_end_maint_notifications.clear() + def getpeername(self): """ Returns the peer name of the connection. diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index 68994944ff..da5ac9c217 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -5,7 +5,7 @@ import threading import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union from redis.typing import Number @@ -463,9 +463,7 @@ class OSSNodeMigratedNotification(MaintenanceNotification): Args: id (int): Unique identifier for this notification - node_address (Optional[str]): Address of the node that has completed migration - in the format "host:port" - slots (Optional[List[int]]): List of slots that have been migrated + nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots """ DEFAULT_TTL = 30 @@ -473,12 +471,10 @@ class OSSNodeMigratedNotification(MaintenanceNotification): def __init__( self, id: int, - node_address: str, - slots: Optional[List[int]] = None, + nodes_to_slots_mapping: Dict[str, str], ): super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL) - self.node_address = node_address - self.slots = slots + self.nodes_to_slots_mapping = nodes_to_slots_mapping def __repr__(self) -> str: expiry_time = self.creation_time + self.ttl @@ -486,8 +482,7 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"id={self.id}, " - f"node_address={self.node_address}, " - f"slots={self.slots}, " + f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}, " f"ttl={self.ttl}, " f"creation_time={self.creation_time}, " f"expires_at={expiry_time}, " @@ -899,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification): return if notification_type: - self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE) + self.handle_maintenance_start_notification( + MaintenanceState.MAINTENANCE, notification + ) else: self.handle_maintenance_completed_notification() def handle_maintenance_start_notification( - self, maintenance_state: MaintenanceState + self, maintenance_state: MaintenanceState, notification: MaintenanceNotification ): if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -918,6 +915,11 @@ def handle_maintenance_start_notification( ) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relaxed_timeout) + if isinstance(notification, OSSNodeMigratingNotification): + # add the notification id to the set of processed start maint notifications + # this is used to skip the unrelaxing of the timeouts if we have received more than + # one start notification before the the final end notification + self.connection.add_maint_start_notification(notification.id) def handle_maintenance_completed_notification(self): # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled @@ -931,6 +933,9 @@ def handle_maintenance_completed_notification(self): # timeouts by providing -1 as the relaxed timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE + # reset the sets that keep track of received start maint + # notifications and skipped end maint notifications + self.connection.reset_received_notifications() class OSSMaintNotificationsHandler: @@ -999,40 +1004,55 @@ def handle_oss_maintenance_completed_notification( # Updates the cluster slots cache with the new slots mapping # This will also update the nodes cache with the new nodes mapping - new_node_host, new_node_port = notification.node_address.split(":") + additional_startup_nodes_info = [] + for node_address, _ in notification.nodes_to_slots_mapping.items(): + new_node_host, new_node_port = node_address.split(":") + additional_startup_nodes_info.append( + (new_node_host, int(new_node_port)) + ) self.cluster_client.nodes_manager.initialize( disconnect_startup_nodes_pools=False, - additional_startup_nodes_info=[(new_node_host, int(new_node_port))], + additional_startup_nodes_info=additional_startup_nodes_info, ) - # mark for reconnect all in use connections to the node - this will force them to - # disconnect after they complete their current commands - # Some of them might be used by sub sub and we don't know which ones - so we disconnect - # all in flight connections after they are done with current command execution - for conn in ( - current_node.redis_connection.connection_pool._get_in_use_connections() - ): - conn.mark_for_reconnect() + with current_node.redis_connection.connection_pool._lock: + # mark for reconnect all in use connections to the node - this will force them to + # disconnect after they complete their current commands + # Some of them might be used by sub sub and we don't know which ones - so we disconnect + # all in flight connections after they are done with current command execution + for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): + conn.mark_for_reconnect() - if ( - current_node - not in self.cluster_client.nodes_manager.nodes_cache.values() - ): - # disconnect all free connections to the node - this node will be dropped - # from the cluster, so we don't need to revert the timeouts - for conn in current_node.redis_connection.connection_pool._get_free_connections(): - conn.disconnect() - else: - if self.config.is_relaxed_timeouts_enabled(): - # reset the timeouts for the node to which the connection is connected - # TODO: add check if other maintenance ops are in progress for the same node - CAE-1038 - # and if so, don't reset the timeouts - for conn in ( - *current_node.redis_connection.connection_pool._get_in_use_connections(), - *current_node.redis_connection.connection_pool._get_free_connections(), - ): - conn.reset_tmp_settings(reset_relaxed_timeout=True) - conn.update_current_socket_timeout(relaxed_timeout=-1) - conn.maintenance_state = MaintenanceState.NONE + if ( + current_node + not in self.cluster_client.nodes_manager.nodes_cache.values() + ): + # disconnect all free connections to the node - this node will be dropped + # from the cluster, so we don't need to revert the timeouts + for conn in current_node.redis_connection.connection_pool._get_free_connections(): + conn.disconnect() + else: + if self.config.is_relaxed_timeouts_enabled(): + # reset the timeouts for the node to which the connection is connected + # Perform check if other maintenance ops are in progress for the same node + # and if so, don't reset the timeouts and wait for the last maintenance + # to complete + for conn in ( + *current_node.redis_connection.connection_pool._get_in_use_connections(), + *current_node.redis_connection.connection_pool._get_free_connections(), + ): + if ( + len(conn.get_processed_start_notifications()) + > len(conn.get_skipped_end_notifications()) + 1 + ): + # we have received more start notifications than end notifications + # for this connection - we should not reset the timeouts + # and add the notification id to the set of skipped end notifications + conn.add_skipped_end_notification(notification.id) + else: + conn.reset_tmp_settings(reset_relaxed_timeout=True) + conn.update_current_socket_timeout(relaxed_timeout=-1) + conn.maintenance_state = MaintenanceState.NONE + conn.reset_received_notifications() # mark the notification as processed self._processed_notifications.add(notification) diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index 7358f078d8..1b219f2aaf 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -11,37 +11,51 @@ class RespTranslator: """Helper class to translate between RESP and other encodings.""" @staticmethod - def str_or_list_to_resp(txt: str) -> str: - """ - Convert specific string or list to RESP format. - """ - if re.match(r"^<.*>$", txt): - items = txt[1:-1].split(",") - return f"*{len(items)}\r\n" + "\r\n".join( - f"${len(x)}\r\n{x}" for x in items + def oss_maint_notification_to_resp(txt: str) -> str: + """Convert query to RESP format.""" + if txt.startswith("SMIGRATED"): + # Format: SMIGRATED SeqID host:port slot1,range1-range2 host1:port1 slot2,range3-range4 + # SMIGRATED 93923 abc.com:6789 123,789-1000 abc.com:4545 1000-2000 abc.com:4323 900,910,920 + # SMIGRATED - simple string + # SeqID - integer + # host and slots info are provided as array of arrays + # host:port - simple string + # slots - simple string + + parts = txt.split() + notification = parts[0] + seq_id = parts[1] + hosts_and_slots = parts[2:] + resp = ( + ">3\r\n" # Push message with 3 elements + f"+{notification}\r\n" # Element 1: Command + f":{seq_id}\r\n" # Element 2: SeqID + f"*{len(hosts_and_slots) // 2}\r\n" # Element 3: Array of host:port, slots pairs ) + for i in range(0, len(hosts_and_slots), 2): + resp += "*2\r\n" + resp += f"+{hosts_and_slots[i]}\r\n" + resp += f"+{hosts_and_slots[i + 1]}\r\n" else: - return f"${len(txt)}\r\n{txt}" - - @staticmethod - def cluster_slots_to_resp(resp: str) -> str: - """Convert query to RESP format.""" - return ( - f"*{len(resp.split())}\r\n" - + "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split()) - + "\r\n" - ) - - @staticmethod - def oss_maint_notification_to_resp(resp: str) -> str: - """Convert query to RESP format.""" - return ( - f">{len(resp.split())}\r\n" - + "\r\n".join( - f"{RespTranslator.str_or_list_to_resp(x)}" for x in resp.split() + # SMIGRATING + # Format: SMIGRATING SeqID slot,range1-range2 + # SMIGRATING 93923 123,789-1000 + # SMIGRATING - simple string + # SeqID - integer + # slots - simple string + + parts = txt.split() + notification = parts[0] + seq_id = parts[1] + slots = parts[2] + + resp = ( + ">3\r\n" # Push message with 3 elements + f"+{notification}\r\n" # Element 1: Command + f":{seq_id}\r\n" # Element 2: SeqID + f"+{slots}\r\n" # Element 3: Array of [host:port, slots] pairs ) - + "\r\n" - ) + return resp @dataclass diff --git a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index 8e2cf55efb..e49f5c6131 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -1,6 +1,4 @@ -from asyncio import Queue from dataclasses import dataclass -from threading import Thread from typing import List, Optional, cast from redis import ConnectionPool, RedisCluster @@ -29,6 +27,32 @@ ClusterNode("127.0.0.1", NODE_PORT_3), ] +CLUSTER_SLOTS_INTERCEPTOR_NAME = "test_topology" + + +class TestRespTranslatorHelper: + def test_oss_maint_notification_to_resp(self): + resp = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 123,456,5000-7000" + ) + assert resp == ">3\r\n+SMIGRATING\r\n:12\r\n+123,456,5000-7000\r\n" + + resp = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000" + ) + assert ( + resp + == ">3\r\n+SMIGRATED\r\n:12\r\n*1\r\n*2\r\n+127.0.0.1:15380\r\n+123,456,5000-7000\r\n" + ) + resp = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000 127.0.0.1:15381 7000-8000 127.0.0.1:15382 8000-9000" + ) + + assert ( + resp + == ">3\r\n+SMIGRATED\r\n:12\r\n*3\r\n*2\r\n+127.0.0.1:15380\r\n+123,456,5000-7000\r\n*2\r\n+127.0.0.1:15381\r\n+7000-8000\r\n*2\r\n+127.0.0.1:15382\r\n+8000-9000\r\n" + ) + class TestClusterMaintNotificationsBase: """Base class for cluster maintenance notifications handling tests.""" @@ -409,6 +433,7 @@ class TestClusterMaintNotificationsHandlingBase(TestClusterMaintNotificationsBas def setup_method(self): """Set up test fixtures with mocked sockets.""" self.proxy_helper = ProxyInterceptorHelper() + self.proxy_helper.cleanup_interceptors(CLUSTER_SLOTS_INTERCEPTOR_NAME) # Create maintenance notifications config self.config = MaintNotificationsConfig( @@ -419,6 +444,7 @@ def setup_method(self): def teardown_method(self): """Clean up test fixtures.""" self.cluster.close() + # interceptors that are changed during the tests are collected in the proxy helper self.proxy_helper.cleanup_interceptors() @@ -513,7 +539,7 @@ def test_receive_smigrating_notification(self): # send a notification to node 1 notification = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <123,456,5000-7000>" + "SMIGRATING 12 123,456,5000-7000" ) self.proxy_helper.send_notification(NODE_PORT_1, notification) @@ -570,7 +596,7 @@ def test_receive_smigrating_with_disabled_relaxed_timeout(self): # send a notification to node 1 notification = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <123,456,5000-7000>" + "SMIGRATING 12 123,456,5000-7000" ) self.proxy_helper.send_notification(NODE_PORT_1, notification) @@ -596,7 +622,7 @@ def test_receive_smigrated_notification(self): self._warm_up_connection_pools(self.cluster, created_connections_count=3) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), @@ -605,7 +631,36 @@ def test_receive_smigrated_notification(self): ) # send a notification to node 1 notification = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 12 127.0.0.1:15380 <123,456,5000-7000>" + "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000" + ) + self.proxy_helper.send_notification(NODE_PORT_1, notification) + + # execute a command that will receive the notification + res = self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + # validate the cluster topology was updated + new_node = self.cluster.nodes_manager.get_node( + host="0.0.0.0", port=NODE_PORT_NEW + ) + assert new_node is not None + + def test_receive_smigrated_notification_with_two_nodes(self): + """Test receiving an OSS maintenance completed notification.""" + # create three connections in each node's connection pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + self.proxy_helper.set_cluster_slots( + CLUSTER_SLOTS_INTERCEPTOR_NAME, + [ + SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # send a notification to node 1 + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000 127.0.0.1:15382 110-120" ) self.proxy_helper.send_notification(NODE_PORT_1, notification) @@ -628,7 +683,7 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <123,2000-3000>" + "SMIGRATING 12 123,2000-3000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) # execute command with node 1 connection @@ -649,7 +704,7 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): ) smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 13 <8000-9000>" + "SMIGRATING 13 8000-9000" ) self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) @@ -674,17 +729,17 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): ], ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15381 <123,2000-3000>" + "SMIGRATED 14 0.0.0.0:15381 123,2000-3000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15381 <8000-9000>" + "SMIGRATED 15 0.0.0.0:15381 8000-9000" ) self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), @@ -725,7 +780,7 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): ) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), @@ -774,7 +829,7 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): node_3 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <0-5460>" + "SMIGRATING 12 0-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) # execute command with node 1 connection @@ -795,7 +850,7 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): ) smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 13 <5461-10922>" + "SMIGRATING 13 5461-10922" ) self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) @@ -821,16 +876,16 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15382 <0-5460>" + "SMIGRATED 14 0.0.0.0:15382 0-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15382 <5461-10922>" + "SMIGRATED 15 0.0.0.0:15382 5461-10922" ) self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", 15382, 0, 5460), SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), @@ -871,7 +926,7 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): ) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", 15382, 0, 5460), SlotsRange("0.0.0.0", 15383, 5461, 10922), @@ -909,7 +964,7 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( self._warm_up_connection_pools(self.cluster, created_connections_count=1) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <1000-2000>" + "SMIGRATING 12 1000-2000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) # execute command with node 1 connection @@ -927,7 +982,7 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( ) smigrating_node_1_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 13 <3000-4000>" + "SMIGRATING 13 3000-4000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1_2) # execute command with node 1 connection @@ -944,26 +999,26 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( ], ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 <1000-2000>" + "SMIGRATED 14 0.0.0.0:15380 1000-2000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") # this functionality is part of CAE-1038 and will be added later # validate the timeout is still relaxed - # self._validate_connections_states( - # self.cluster, - # [ - # ConnectionStateExpectation( - # node_port=NODE_PORT_1, - # changed_connections_count=1, - # state=MaintenanceState.MAINTENANCE, - # relaxed_timeout=self.config.relaxed_timeout, - # ), - # ], - # ) + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15381 <3000-4000>" + "SMIGRATED 15 0.0.0.0:15381 3000-4000" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1_2) # execute command with node 1 connection @@ -1002,7 +1057,7 @@ def test_smigrating_smigrated_with_sharded_pubsub( assert msg is not None and msg["type"] == "ssubscribe" smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <5200-5460>" + "SMIGRATING 12 5200-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) @@ -1020,7 +1075,7 @@ def test_smigrating_smigrated_with_sharded_pubsub( ) self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), @@ -1029,7 +1084,7 @@ def test_smigrating_smigrated_with_sharded_pubsub( ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 <5200-5460>" + "SMIGRATED 14 0.0.0.0:15380 5200-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) # execute command with node 1 connection @@ -1090,7 +1145,7 @@ def test_smigrating_smigrated_with_std_pubsub( assert msg is not None and msg["type"] == "subscribe" smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 <5200-5460>" + "SMIGRATING 12 5200-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) @@ -1105,7 +1160,7 @@ def test_smigrating_smigrated_with_std_pubsub( assert pubsub.connection._socket_connect_timeout == 30 self.proxy_helper.set_cluster_slots( - "test_topology", + CLUSTER_SLOTS_INTERCEPTOR_NAME, [ SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), @@ -1114,7 +1169,7 @@ def test_smigrating_smigrated_with_std_pubsub( ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 <5200-5460>" + "SMIGRATED 14 0.0.0.0:15380 5200-5460" ) self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) # execute command with node 1 connection diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index adb9ebb5ea..47a27a48cf 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -493,45 +493,46 @@ class TestOSSNodeMigratedNotification: def test_init_with_defaults(self): """Test OSSNodeMigratedNotification initialization with default values.""" with patch("time.monotonic", return_value=1000): + nodes_to_slots_mapping = {"127.0.0.1:6380": "1-100"} notification = OSSNodeMigratedNotification( - id=1, node_address="127.0.0.1:6380" + id=1, nodes_to_slots_mapping=nodes_to_slots_mapping ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address == "127.0.0.1:6380" - assert notification.slots is None + assert notification.nodes_to_slots_mapping == nodes_to_slots_mapping def test_init_with_all_parameters(self): """Test OSSNodeMigratedNotification initialization with all parameters.""" with patch("time.monotonic", return_value=1000): - slots = [1, 2, 3, 4, 5] - node_address = "127.0.0.1:6380" + nodes_to_slots_mapping = { + "127.0.0.1:6380": "1-100", + "127.0.0.1:6381": "101-200", + } notification = OSSNodeMigratedNotification( id=1, - node_address=node_address, - slots=slots, + nodes_to_slots_mapping=nodes_to_slots_mapping, ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address == node_address - assert notification.slots == slots + assert notification.nodes_to_slots_mapping == nodes_to_slots_mapping def test_default_ttl(self): """Test that DEFAULT_TTL is used correctly.""" assert OSSNodeMigratedNotification.DEFAULT_TTL == 30 - notification = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) assert notification.ttl == 30 def test_repr(self): """Test OSSNodeMigratedNotification string representation.""" with patch("time.monotonic", return_value=1000): - node_address = "127.0.0.1:6380" + nodes_to_slots_mapping = {"127.0.0.1:6380": "1-100"} notification = OSSNodeMigratedNotification( id=1, - node_address=node_address, - slots=[1, 2, 3], + nodes_to_slots_mapping=nodes_to_slots_mapping, ) with patch("time.monotonic", return_value=1010): # 10 seconds later @@ -546,26 +547,30 @@ def test_equality_same_id_and_type(self): """Test equality for notifications with same id and type.""" notification1 = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", - slots=[1, 2, 3], + nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"}, ) notification2 = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6381", - slots=[4, 5, 6], + nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"}, ) # Should be equal because id and type are the same assert notification1 == notification2 def test_equality_different_id(self): """Test inequality for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") - notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") + notification1 = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) + notification2 = OSSNodeMigratedNotification( + id=2, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) assert notification1 != notification2 def test_equality_different_type(self): """Test inequality for notifications of different types.""" - notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification1 = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) notification2 = NodeMigratedNotification(id=1) assert notification1 != notification2 @@ -573,29 +578,39 @@ def test_hash_same_id_and_type(self): """Test hash for notifications with same id and type.""" notification1 = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", - slots=[1, 2, 3], + nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"}, ) notification2 = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6381", - slots=[4, 5, 6], + nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"}, ) # Should have same hash because id and type are the same assert hash(notification1) == hash(notification2) def test_hash_different_id(self): """Test hash for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") - notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") + notification1 = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) + notification2 = OSSNodeMigratedNotification( + id=2, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) assert hash(notification1) != hash(notification2) def test_in_set(self): """Test that notifications can be used in sets.""" - notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") - notification2 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") - notification3 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") - notification4 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") + notification1 = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) + notification2 = OSSNodeMigratedNotification( + id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + ) + notification3 = OSSNodeMigratedNotification( + id=2, nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"} + ) + notification4 = OSSNodeMigratedNotification( + id=2, nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"} + ) notification_set = {notification1, notification2, notification3, notification4} assert ( @@ -849,7 +864,9 @@ def test_handle_notification_migrating(self): self.handler, "handle_maintenance_start_notification" ) as mock_handle: self.handler.handle_notification(notification) - mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + mock_handle.assert_called_once_with( + MaintenanceState.MAINTENANCE, notification + ) def test_handle_notification_migrated(self): """Test handling of NodeMigratedNotification.""" @@ -869,7 +886,9 @@ def test_handle_notification_failing_over(self): self.handler, "handle_maintenance_start_notification" ) as mock_handle: self.handler.handle_notification(notification) - mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + mock_handle.assert_called_once_with( + MaintenanceState.MAINTENANCE, notification + ) def test_handle_notification_failed_over(self): """Test handling of NodeFailedOverNotification.""" @@ -896,7 +915,7 @@ def test_handle_maintenance_start_notification_disabled(self): handler = MaintNotificationsConnectionHandler(self.mock_connection, config) result = handler.handle_maintenance_start_notification( - MaintenanceState.MAINTENANCE + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) ) assert result is None @@ -907,7 +926,7 @@ def test_handle_maintenance_start_notification_moving_state(self): self.mock_connection.maintenance_state = MaintenanceState.MOVING result = self.handler.handle_maintenance_start_notification( - MaintenanceState.MAINTENANCE + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) ) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() @@ -916,7 +935,9 @@ def test_handle_maintenance_start_notification_success(self): """Test successful maintenance start notification handling for migrating.""" self.mock_connection.maintenance_state = MaintenanceState.NONE - self.handler.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE) + self.handler.handle_maintenance_start_notification( + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) + ) assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)