diff --git a/rosapi/src/rosapi/params.py b/rosapi/src/rosapi/params.py index f5e1283c3..a01f89533 100644 --- a/rosapi/src/rosapi/params.py +++ b/rosapi/src/rosapi/params.py @@ -46,26 +46,22 @@ from rosapi.proxy import get_nodes if TYPE_CHECKING: - from rcl_interfaces.srv import ( - GetParameters_Request, - GetParameters_Response, - ListParameters_Request, - ListParameters_Response, - SetParameters_Request, - SetParameters_Response, - ) + from rcl_interfaces.srv import ListParameters_Request, ListParameters_Response from rclpy.client import Client from rclpy.node import Node from rclpy.task import Future + from rosbridge_library.internal.type_support import ROSServiceT + """ Methods to interact with the param server. Values have to be passed as JSON in order to facilitate dynamically typed SRV messages """ # Constants -DEFAULT_PARAM_TIMEOUT_SEC = 5.0 +DEFAULT_PARAM_TIMEOUT_SEC = 1.0 _node = None _timeout_sec = DEFAULT_PARAM_TIMEOUT_SEC +_client_cache: dict[tuple[type, str], Client] = {} _parameter_type_mapping = [ "", @@ -93,8 +89,9 @@ def init(node: Node, timeout_sec: float = DEFAULT_PARAM_TIMEOUT_SEC) -> None: :type timeout_sec: float | int, optional :raises ValueError: If the timeout is not a positive number. """ - global _node, _timeout_sec + global _node, _timeout_sec, _client_cache _node = node + _client_cache = {} if not isinstance(timeout_sec, int | float) or timeout_sec <= 0: msg = "Parameter timeout must be a positive number" @@ -102,6 +99,45 @@ def init(node: Node, timeout_sec: float = DEFAULT_PARAM_TIMEOUT_SEC) -> None: _timeout_sec = timeout_sec +def _get_or_create_client( + service_type: type[ROSServiceT], service_name: str +) -> Client[ROSServiceT.Request, ROSServiceT.Response]: + """Get existing client from cache or create new one.""" + assert _node is not None + cache_key = (service_type, service_name) + + if cache_key in _client_cache: + client = _client_cache[cache_key] + if client.service_is_ready(): + return client + + _node.destroy_client(client) + del _client_cache[cache_key] + + client = _node.create_client( + service_type, + service_name, + callback_group=MutuallyExclusiveCallbackGroup(), + ) + + if client.service_is_ready(): + _client_cache[cache_key] = client + return client + + _node.destroy_client(client) + msg = f"Service {service_name} is not available" + raise Exception(msg) + + +def clear_client_cache() -> None: + """Clear all cached clients.""" + assert _node is not None + global _client_cache + for client in _client_cache.values(): + _node.destroy_client(client) + _client_cache = {} + + async def set_param(node_name: str, name: str, value: str, params_glob: list[str]) -> None: """Set a parameter in a given node.""" if params_glob and not any(fnmatch.fnmatch(str(name), glob) for glob in params_glob): @@ -148,27 +184,16 @@ async def _set_param( assert value is not None setattr(parameter.value, _parameter_type_mapping[parameter_type], loads(value)) - assert _node is not None - client: Client[SetParameters_Request, SetParameters_Response] = _node.create_client( - SetParameters, - f"{node_name}/set_parameters", - callback_group=MutuallyExclusiveCallbackGroup(), - ) - - if not client.service_is_ready(): - _node.destroy_client(client) - msg = f"Service {client.srv_name} is not available" - raise Exception(msg) + client = _get_or_create_client(SetParameters, f"{node_name}/set_parameters") request = SetParameters.Request() request.parameters = [parameter] future = client.call_async(request) + assert _node is not None await futures_wait_for(_node, [future], _timeout_sec) - _node.destroy_client(client) - if not future.done(): future.cancel() msg = "Timeout occurred" @@ -209,27 +234,16 @@ async def _get_param(node_name: str, name: str) -> ParameterValue: Internal helper function for get_param. """ - assert _node is not None - client: Client[GetParameters_Request, GetParameters_Response] = _node.create_client( - GetParameters, - f"{node_name}/get_parameters", - callback_group=MutuallyExclusiveCallbackGroup(), - ) - - if not client.service_is_ready(): - _node.destroy_client(client) - msg = f"Service {client.srv_name} is not available" - raise Exception(msg) + client = _get_or_create_client(GetParameters, f"{node_name}/get_parameters") request = GetParameters.Request() request.names = [name] future = client.call_async(request) + assert _node is not None await futures_wait_for(_node, [future], _timeout_sec) - _node.destroy_client(client) - if not future.done(): future.cancel() msg = "Timeout occurred"