diff --git a/google/ads/googleads/client.py b/google/ads/googleads/client.py index 8a8b70894..239731110 100644 --- a/google/ads/googleads/client.py +++ b/google/ads/googleads/client.py @@ -36,6 +36,7 @@ _logger = logging.getLogger(__name__) _SERVICE_CLIENT_TEMPLATE = "{}Client" +_ASYNC_SERVICE_CLIENT_TEMPLATE = "{}AsyncClient" _VALID_API_VERSIONS = ["v22", "v21", "v20", "v19"] _MESSAGE_TYPES = ["common", "enums", "errors", "resources", "services"] @@ -360,6 +361,7 @@ def get_service( name: str, version: str = _DEFAULT_VERSION, interceptors: Union[list, None] = None, + is_async: bool = False, ) -> Any: """Returns a service client instance for the specified service_name. @@ -372,6 +374,8 @@ def get_service( interceptors: an optional list of interceptors to include in requests. NOTE: this parameter is not intended for non-Google use and is not officially supported. + is_async: whether or not to retrieve the async version of the + service client being requested. Returns: A service client instance associated with the given service_name. @@ -391,8 +395,14 @@ def get_service( try: service_module: Any = import_module(f"{services_path}.{snaked}") + + if is_async: + service_name = _ASYNC_SERVICE_CLIENT_TEMPLATE.format(name) + else: + service_name = _SERVICE_CLIENT_TEMPLATE.format(name) + service_client_class: Any = util.get_nested_attr( - service_module, _SERVICE_CLIENT_TEMPLATE.format(name) + service_module, service_name ) except (AttributeError, ModuleNotFoundError): raise ValueError( diff --git a/tests/client_test.py b/tests/client_test.py index c95d61c54..31a5b12aa 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -490,18 +490,49 @@ def test_load_from_string_versioned(self): def test_get_service(self): # Retrieve service names for all defined service clients. for ver in valid_versions: - services_path = f"google.ads.googleads.{ver}" - service_names = [ - f'{name.rsplit("ServiceClient")[0]}Service' - for name in dir(import_module(services_path)) - if "ServiceClient" in name + services_filepath = f"google/ads/googleads/{ver}/services/services" + # Retrieve list of all the services that exist under the + # {version}/services/services directory. + service_dir_names = [ + name for name in os.listdir(services_filepath) if name.endswith("_service") ] - client = self._create_test_client() + client = self._create_test_client(version=ver) + + for dir_name in service_dir_names: + # Converts from snake case to title case, for example: + # google_ads_service --> GoogleAdsService + service_name = ''.join( + [part.capitalize() for part in dir_name.split("_")] + ) + + # Load each service module + svc = client.get_service(service_name) + self.assertEqual(svc.__class__.__name__, f"{service_name}Client") + + def test_get_async_service(self): + # Retrieve service names for all defined service clients. + for ver in valid_versions: + services_filepath = f"google/ads/googleads/{ver}/services/services" + # Retrieve list of all the services that exist under the + # {version}/services/services directory. + service_dir_names = [ + name for name in os.listdir(services_filepath) if name.endswith("_service") + ] + + client = self._create_test_client(version=ver) + + for dir_name in service_dir_names: + # Converts from snake case to title case, for example: + # google_ads_service --> GoogleAdsService + service_name = ''.join( + [part.capitalize() for part in dir_name.split("_")] + ) + + # Load each service module + svc = client.get_service(service_name, is_async=True) + self.assertEqual(svc.__class__.__name__, f"{service_name}AsyncClient") - # Iterate through retrieval of all service clients by name. - for service_name in service_names: - client.get_service(service_name, version=ver) def test_get_service_custom_endpoint(self): service_name = "GoogleAdsService"