From d6fd598abc856a5e4cb8807be6fff41936c65125 Mon Sep 17 00:00:00 2001 From: ovsds Date: Tue, 18 Nov 2025 12:18:46 +0100 Subject: [PATCH] feat: BI-6687 add base temporal worker app --- lib/dl_temporal/dl_temporal/__init__.py | 8 +- lib/dl_temporal/dl_temporal/app/__init__.py | 32 +++ lib/dl_temporal/dl_temporal/app/aiohttp.py | 118 +++++++++ lib/dl_temporal/dl_temporal/app/base.py | 68 ++++++ lib/dl_temporal/dl_temporal/app/temporal.py | 100 ++++++++ .../dl_temporal/client/__init__.py | 8 +- lib/dl_temporal/dl_temporal/client/client.py | 24 +- .../dl_temporal/client/metadata.py | 17 ++ .../worker => dl_temporal/utils}/__init__.py | 0 .../dl_temporal/utils/aiohttp/__init__.py | 18 ++ .../utils/aiohttp/handlers/__init__.py | 18 ++ .../utils/aiohttp/handlers/health/__init__.py | 16 ++ .../aiohttp/handlers/health/liveness_probe.py | 22 ++ .../handlers/health/readiness_probe.py | 63 +++++ .../utils/aiohttp/handlers/responses.py | 49 ++++ .../dl_temporal/utils/aiohttp/printer.py | 24 ++ .../dl_temporal/utils/app/__init__.py | 26 ++ lib/dl_temporal/dl_temporal/utils/app/base.py | 230 ++++++++++++++++++ .../dl_temporal/utils/app/exceptions.py | 18 ++ .../dl_temporal/utils/app/models.py | 13 + .../dl_temporal/utils/singleton.py | 89 +++++++ .../db/{worker => }/activities.py | 0 .../dl_temporal_tests/db/app/__init__.py | 0 .../dl_temporal_tests/db/app/config.yaml | 9 + .../dl_temporal_tests/db/app/conftest.py | 87 +++++++ .../dl_temporal_tests/db/app/test_handlers.py | 40 +++ .../db/{worker => app}/test_worker.py | 37 +-- .../db/client/test_check_health.py | 4 +- .../db/client/test_metadata.py | 4 +- .../dl_temporal_tests/db/conftest.py | 26 +- .../db/{worker => }/workflows.py | 2 +- .../dl_temporal_tests/unit/__init__.py | 0 .../dl_temporal_tests/unit/utils/__init__.py | 0 .../unit/utils/test_singleton.py | 220 +++++++++++++++++ lib/dl_temporal/pyproject.toml | 8 + 35 files changed, 1343 insertions(+), 55 deletions(-) create mode 100644 lib/dl_temporal/dl_temporal/app/__init__.py create mode 100644 lib/dl_temporal/dl_temporal/app/aiohttp.py create mode 100644 lib/dl_temporal/dl_temporal/app/base.py create mode 100644 lib/dl_temporal/dl_temporal/app/temporal.py rename lib/dl_temporal/{dl_temporal_tests/db/worker => dl_temporal/utils}/__init__.py (100%) create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/__init__.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/__init__.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/__init__.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/liveness_probe.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/readiness_probe.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/responses.py create mode 100644 lib/dl_temporal/dl_temporal/utils/aiohttp/printer.py create mode 100644 lib/dl_temporal/dl_temporal/utils/app/__init__.py create mode 100644 lib/dl_temporal/dl_temporal/utils/app/base.py create mode 100644 lib/dl_temporal/dl_temporal/utils/app/exceptions.py create mode 100644 lib/dl_temporal/dl_temporal/utils/app/models.py create mode 100644 lib/dl_temporal/dl_temporal/utils/singleton.py rename lib/dl_temporal/dl_temporal_tests/db/{worker => }/activities.py (100%) create mode 100644 lib/dl_temporal/dl_temporal_tests/db/app/__init__.py create mode 100644 lib/dl_temporal/dl_temporal_tests/db/app/config.yaml create mode 100644 lib/dl_temporal/dl_temporal_tests/db/app/conftest.py create mode 100644 lib/dl_temporal/dl_temporal_tests/db/app/test_handlers.py rename lib/dl_temporal/dl_temporal_tests/db/{worker => app}/test_worker.py (60%) rename lib/dl_temporal/dl_temporal_tests/db/{worker => }/workflows.py (97%) create mode 100644 lib/dl_temporal/dl_temporal_tests/unit/__init__.py create mode 100644 lib/dl_temporal/dl_temporal_tests/unit/utils/__init__.py create mode 100644 lib/dl_temporal/dl_temporal_tests/unit/utils/test_singleton.py diff --git a/lib/dl_temporal/dl_temporal/__init__.py b/lib/dl_temporal/dl_temporal/__init__.py index 8132774b1c..0fe8b4606c 100644 --- a/lib/dl_temporal/dl_temporal/__init__.py +++ b/lib/dl_temporal/dl_temporal/__init__.py @@ -13,9 +13,12 @@ from .client import ( AlreadyExists, EmptyMetadataProvider, + EmptyMetadataProviderSettings, MetadataProvider, + MetadataProviderSettings, PermissionDenied, TemporalClient, + TemporalClientDependencies, TemporalClientError, TemporalClientSettings, ) @@ -33,10 +36,13 @@ "WorkflowProtocol", "define_activity", "define_workflow", - "MetadataProvider", "EmptyMetadataProvider", + "EmptyMetadataProviderSettings", + "MetadataProvider", + "MetadataProviderSettings", "TemporalClientError", "TemporalClient", + "TemporalClientDependencies", "TemporalClientSettings", "AlreadyExists", "PermissionDenied", diff --git a/lib/dl_temporal/dl_temporal/app/__init__.py b/lib/dl_temporal/dl_temporal/app/__init__.py new file mode 100644 index 0000000000..767ad65e4e --- /dev/null +++ b/lib/dl_temporal/dl_temporal/app/__init__.py @@ -0,0 +1,32 @@ +from .aiohttp import ( + HttpServerAppFactoryMixin, + HttpServerAppMixin, + HttpServerAppSettingsMixin, + HttpServerSettings, +) +from .base import ( + BaseTemporalWorkerApp, + BaseTemporalWorkerAppFactory, + BaseTemporalWorkerAppSettings, +) +from .temporal import ( + TemporalWorkerAppFactoryMixin, + TemporalWorkerAppMixin, + TemporalWorkerAppSettingsMixin, + TemporalWorkerSettings, +) + + +__all__ = [ + "HttpServerSettings", + "HttpServerAppFactoryMixin", + "HttpServerAppMixin", + "HttpServerAppSettingsMixin", + "TemporalWorkerSettings", + "TemporalWorkerAppFactoryMixin", + "TemporalWorkerAppMixin", + "TemporalWorkerAppSettingsMixin", + "BaseTemporalWorkerApp", + "BaseTemporalWorkerAppFactory", + "BaseTemporalWorkerAppSettings", +] diff --git a/lib/dl_temporal/dl_temporal/app/aiohttp.py b/lib/dl_temporal/dl_temporal/app/aiohttp.py new file mode 100644 index 0000000000..d1b74bbd8c --- /dev/null +++ b/lib/dl_temporal/dl_temporal/app/aiohttp.py @@ -0,0 +1,118 @@ +from typing import ( + Generic, + TypeVar, +) + +import aiohttp.web +import attr +from typing_extensions import override + +import dl_settings +import dl_temporal.utils.aiohttp.handlers as aiohttp_handlers +import dl_temporal.utils.aiohttp.printer as aiohttp_printer +import dl_temporal.utils.app as app_utils +import dl_temporal.utils.singleton as singleton_utils + + +class HttpServerSettings(dl_settings.BaseSettings): + host: str + port: int + + +class HttpServerAppSettingsMixin(app_utils.BaseAppSettings): + http_server: HttpServerSettings = NotImplemented + + +@attr.define(frozen=True, kw_only=True) +class HttpServerAppMixin(app_utils.BaseApp): + ... + + +AppType = TypeVar("AppType", bound=HttpServerAppMixin) + + +@attr.define(kw_only=True, slots=False) +class HttpServerAppFactoryMixin( + app_utils.BaseAppFactory[AppType], + Generic[AppType], +): + settings: HttpServerAppSettingsMixin + + @override + @singleton_utils.singleton_class_method_result + async def _get_main_callbacks( + self, + ) -> list[app_utils.Callback]: + result = await super()._get_main_callbacks() + + result.append( + app_utils.Callback( + coroutine=aiohttp.web._run_app( + app=await self._get_aiohttp_app(), + host=self.settings.http_server.host, + port=self.settings.http_server.port, + print=aiohttp_printer.PrintLogger(), + ), + name="run_http_server", + ), + ) + + return result + + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_app( + self, + ) -> aiohttp.web.Application: + app = aiohttp.web.Application() + app.add_routes( + routes=await self._get_aiohttp_app_routes(), + ) + return app + + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_app_routes( + self, + ) -> list[aiohttp.web.RouteDef]: + result: list[aiohttp.web.RouteDef] = [] + + aiohttp_liveness_probe_handler = await self._get_aiohttp_liveness_probe_handler() + result.append( + aiohttp.web.route( + method="GET", + path="/api/v1/health/liveness", + handler=aiohttp_liveness_probe_handler.process, + ), + ) + + aiohttp_readiness_probe_handler = await self._get_aiohttp_readiness_probe_handler() + result.append( + aiohttp.web.route( + method="GET", + path="/api/v1/health/readiness", + handler=aiohttp_readiness_probe_handler.process, + ), + ) + + return result + + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_liveness_probe_handler( + self, + ) -> aiohttp_handlers.LivenessProbeHandler: + return aiohttp_handlers.LivenessProbeHandler() + + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_readiness_probe_handler( + self, + ) -> aiohttp_handlers.ReadinessProbeHandler: + subsystems = await self._get_aiohttp_subsystem_readiness_callbacks() + + return aiohttp_handlers.ReadinessProbeHandler( + subsystems=subsystems, + ) + + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_subsystem_readiness_callbacks( + self, + ) -> list[aiohttp_handlers.SubsystemReadinessCallback]: + return [] diff --git a/lib/dl_temporal/dl_temporal/app/base.py b/lib/dl_temporal/dl_temporal/app/base.py new file mode 100644 index 0000000000..8ab6c86575 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/app/base.py @@ -0,0 +1,68 @@ +from typing import ( + Generic, + TypeVar, +) + +import attr +from typing_extensions import override + +import dl_temporal.app.aiohttp as aiohttp_app +import dl_temporal.app.temporal as temporal_app +import dl_temporal.utils.aiohttp as aiohttp_utils +import dl_temporal.utils.app as app_utils +import dl_temporal.utils.singleton as singleton_utils + + +class BaseTemporalWorkerAppSettings( + temporal_app.TemporalWorkerAppSettingsMixin, + aiohttp_app.HttpServerAppSettingsMixin, + app_utils.BaseAppSettings, +): + ... + + +@attr.define(frozen=True, kw_only=True) +class BaseTemporalWorkerApp( + temporal_app.TemporalWorkerAppMixin, + aiohttp_app.HttpServerAppMixin, + app_utils.BaseApp, +): + ... + + +AppType = TypeVar("AppType", bound=BaseTemporalWorkerApp) + + +@attr.define(kw_only=True, slots=False) +class BaseTemporalWorkerAppFactory( + temporal_app.TemporalWorkerAppFactoryMixin[AppType], + aiohttp_app.HttpServerAppFactoryMixin[AppType], + app_utils.BaseAppFactory[AppType], + Generic[AppType], +): + settings: BaseTemporalWorkerAppSettings + + @override + @singleton_utils.singleton_class_method_result + async def _get_aiohttp_subsystem_readiness_callbacks( + self, + ) -> list[aiohttp_utils.SubsystemReadinessCallback]: + result = await super()._get_aiohttp_subsystem_readiness_callbacks() + + temporal_client = await self._get_temporal_client() + result.append( + aiohttp_utils.SubsystemReadinessAsyncCallback( + name="temporal_client.check_health", + is_ready=temporal_client.check_health, + ), + ) + + temporal_worker = await self._get_temporal_worker() + result.append( + aiohttp_utils.SubsystemReadinessSyncCallback( + name="temporal_worker.is_running", + is_ready=lambda: temporal_worker.is_running, + ), + ) + + return result diff --git a/lib/dl_temporal/dl_temporal/app/temporal.py b/lib/dl_temporal/dl_temporal/app/temporal.py new file mode 100644 index 0000000000..7db1e0267f --- /dev/null +++ b/lib/dl_temporal/dl_temporal/app/temporal.py @@ -0,0 +1,100 @@ +import abc +from typing import ( + Generic, + TypeVar, +) + +import attr +import temporalio.worker +from typing_extensions import override + +import dl_settings +import dl_temporal.base as base +import dl_temporal.client as client +import dl_temporal.utils.app as app_utils +import dl_temporal.utils.singleton as singleton_utils +import dl_temporal.worker as worker + + +class TemporalWorkerSettings(dl_settings.BaseSettings): + task_queue: str + + +class TemporalWorkerAppSettingsMixin(app_utils.BaseAppSettings): + temporal_client: client.TemporalClientSettings = NotImplemented + temporal_worker: TemporalWorkerSettings = NotImplemented + + +@attr.define(frozen=True, kw_only=True) +class TemporalWorkerAppMixin(app_utils.BaseApp): + ... + + +AppType = TypeVar("AppType", bound=TemporalWorkerAppMixin) + + +@attr.define(kw_only=True, slots=False) +class TemporalWorkerAppFactoryMixin( + app_utils.BaseAppFactory[AppType], + Generic[AppType], +): + settings: TemporalWorkerAppSettingsMixin + + @override + @singleton_utils.singleton_class_method_result + async def _get_main_callbacks( + self, + ) -> list[app_utils.Callback]: + result = await super()._get_main_callbacks() + + temporal_worker = await self._get_temporal_worker() + result.append(app_utils.Callback(name="temporal_worker", coroutine=temporal_worker.run())) + + return result + + @singleton_utils.singleton_class_method_result + async def _get_temporal_client( + self, + ) -> client.TemporalClient: + return await client.TemporalClient.from_dependencies( + dependencies=client.TemporalClientDependencies( + namespace=self.settings.temporal_client.namespace, + host=self.settings.temporal_client.host, + port=self.settings.temporal_client.port, + tls=self.settings.temporal_client.tls, + lazy=False, + metadata_provider=await self._get_temporal_client_metadata_provider(), + ), + ) + + @singleton_utils.singleton_class_method_result + async def _get_temporal_worker( + self, + ) -> temporalio.worker.Worker: + return worker.create_worker( + task_queue=self.settings.temporal_worker.task_queue, + client=await self._get_temporal_client(), + workflows=await self._get_temporal_workflows(), + activities=await self._get_temporal_activities(), + ) + + @abc.abstractmethod + @singleton_utils.singleton_class_method_result + async def _get_temporal_workflows( + self, + ) -> list[type[base.WorkflowProtocol]]: + ... + + @abc.abstractmethod + @singleton_utils.singleton_class_method_result + async def _get_temporal_activities( + self, + ) -> list[base.ActivityProtocol]: + ... + + @abc.abstractmethod + @singleton_utils.singleton_class_method_result + async def _get_temporal_client_metadata_provider( + self, + ) -> client.MetadataProvider: + ... diff --git a/lib/dl_temporal/dl_temporal/client/__init__.py b/lib/dl_temporal/dl_temporal/client/__init__.py index 0a91029838..a66c2cef6c 100644 --- a/lib/dl_temporal/dl_temporal/client/__init__.py +++ b/lib/dl_temporal/dl_temporal/client/__init__.py @@ -1,5 +1,6 @@ from .client import ( TemporalClient, + TemporalClientDependencies, TemporalClientSettings, ) from .exc import ( @@ -9,14 +10,19 @@ ) from .metadata import ( EmptyMetadataProvider, + EmptyMetadataProviderSettings, MetadataProvider, + MetadataProviderSettings, ) __all__ = [ - "MetadataProvider", "EmptyMetadataProvider", + "EmptyMetadataProviderSettings", + "MetadataProvider", + "MetadataProviderSettings", "TemporalClient", + "TemporalClientDependencies", "TemporalClientSettings", "TemporalClientError", "PermissionDenied", diff --git a/lib/dl_temporal/dl_temporal/client/client.py b/lib/dl_temporal/dl_temporal/client/client.py index d91e581ff7..bb5ac000bf 100644 --- a/lib/dl_temporal/dl_temporal/client/client.py +++ b/lib/dl_temporal/dl_temporal/client/client.py @@ -8,6 +8,7 @@ import temporalio.service from typing_extensions import Self +import dl_settings import dl_temporal.base as base import dl_temporal.client.exc as exc import dl_temporal.client.metadata as metadata @@ -16,8 +17,16 @@ LOGGER = logging.getLogger(__name__) +class TemporalClientSettings(dl_settings.BaseSettings): + host: str + port: int + tls: bool + namespace: str + metadata_provider: dl_settings.TypedAnnotation[metadata.MetadataProviderSettings] + + @attrs.define(kw_only=True, frozen=True) -class TemporalClientSettings: +class TemporalClientDependencies: namespace: str host: str port: int = 7233 @@ -38,15 +47,15 @@ class TemporalClient: _update_metadata_task: asyncio.Task = attrs.field(init=False) @classmethod - async def from_settings(cls, settings: TemporalClientSettings) -> Self: - metadata_provider = settings.metadata_provider + async def from_dependencies(cls, dependencies: TemporalClientDependencies) -> Self: + metadata_provider = dependencies.metadata_provider rpc_metadata = await metadata_provider.get_metadata() temporal_client = await temporalio.client.Client.connect( - target_host=settings.target_host, - namespace=settings.namespace, - lazy=settings.lazy, - tls=settings.tls, + target_host=dependencies.target_host, + namespace=dependencies.namespace, + lazy=dependencies.lazy, + tls=dependencies.tls, rpc_metadata=rpc_metadata, data_converter=base.DataConverter(), ) @@ -89,6 +98,7 @@ async def check_health(self) -> bool: timeout=datetime.timedelta(seconds=1), ) except Exception: + LOGGER.exception("Temporal client health check failed") return False async def check_auth(self) -> bool: diff --git a/lib/dl_temporal/dl_temporal/client/metadata.py b/lib/dl_temporal/dl_temporal/client/metadata.py index 9824f5de85..eb9f71fbea 100644 --- a/lib/dl_temporal/dl_temporal/client/metadata.py +++ b/lib/dl_temporal/dl_temporal/client/metadata.py @@ -5,6 +5,8 @@ import attrs +import dl_settings + LOGGER = logging.getLogger(__name__) @@ -27,7 +29,22 @@ async def get_metadata(self) -> Mapping[str, str]: pass +class MetadataProviderSettings(dl_settings.TypedBaseSettings): + ... + + +class EmptyMetadataProviderSettings(MetadataProviderSettings): + ... + + +MetadataProviderSettings.register("empty", EmptyMetadataProviderSettings) + + @attrs.define(kw_only=True, frozen=True) class EmptyMetadataProvider(MetadataProvider): + @classmethod + def from_settings(cls, settings: MetadataProviderSettings) -> "EmptyMetadataProvider": + return cls() + async def get_metadata(self) -> Mapping[str, str]: return {} diff --git a/lib/dl_temporal/dl_temporal_tests/db/worker/__init__.py b/lib/dl_temporal/dl_temporal/utils/__init__.py similarity index 100% rename from lib/dl_temporal/dl_temporal_tests/db/worker/__init__.py rename to lib/dl_temporal/dl_temporal/utils/__init__.py diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/__init__.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/__init__.py new file mode 100644 index 0000000000..9b0c749e34 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/__init__.py @@ -0,0 +1,18 @@ +from .handlers import ( + LivenessProbeHandler, + ReadinessProbeHandler, + Response, + SubsystemReadinessAsyncCallback, + SubsystemReadinessCallback, + SubsystemReadinessSyncCallback, +) + + +__all__ = [ + "LivenessProbeHandler", + "ReadinessProbeHandler", + "Response", + "SubsystemReadinessAsyncCallback", + "SubsystemReadinessCallback", + "SubsystemReadinessSyncCallback", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/__init__.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/__init__.py new file mode 100644 index 0000000000..135d7028b2 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/__init__.py @@ -0,0 +1,18 @@ +from .health import ( + LivenessProbeHandler, + ReadinessProbeHandler, + SubsystemReadinessAsyncCallback, + SubsystemReadinessCallback, + SubsystemReadinessSyncCallback, +) +from .responses import Response + + +__all__ = [ + "LivenessProbeHandler", + "ReadinessProbeHandler", + "Response", + "SubsystemReadinessAsyncCallback", + "SubsystemReadinessCallback", + "SubsystemReadinessSyncCallback", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/__init__.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/__init__.py new file mode 100644 index 0000000000..d9ff395b78 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/__init__.py @@ -0,0 +1,16 @@ +from .liveness_probe import LivenessProbeHandler +from .readiness_probe import ( + ReadinessProbeHandler, + SubsystemReadinessAsyncCallback, + SubsystemReadinessCallback, + SubsystemReadinessSyncCallback, +) + + +__all__ = [ + "LivenessProbeHandler", + "ReadinessProbeHandler", + "SubsystemReadinessAsyncCallback", + "SubsystemReadinessCallback", + "SubsystemReadinessSyncCallback", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/liveness_probe.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/liveness_probe.py new file mode 100644 index 0000000000..4cc8e7f6e6 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/liveness_probe.py @@ -0,0 +1,22 @@ +import http +import logging + +import aiohttp.web as aiohttp_web + +import dl_temporal.utils.aiohttp.handlers as aiohttp_handlers + + +logger = logging.getLogger(__name__) + + +class LivenessProbeHandler: + async def process(self, request: aiohttp_web.Request) -> aiohttp_web.Response: + return aiohttp_handlers.Response.with_data( + status=http.HTTPStatus.OK, + data={"status": "healthy"}, + ) + + +__all__ = [ + "LivenessProbeHandler", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/readiness_probe.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/readiness_probe.py new file mode 100644 index 0000000000..acc71873c2 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/health/readiness_probe.py @@ -0,0 +1,63 @@ +import http +import logging +import typing + +import aiohttp.web as aiohttp_web +import attr + +import dl_temporal.utils.aiohttp.handlers as aiohttp_handlers + + +logger = logging.getLogger(__name__) + + +@attr.define(frozen=True, kw_only=True) +class SubsystemReadinessAsyncCallback: + name: str + is_ready: typing.Callable[[], typing.Awaitable[bool]] + + +@attr.define(frozen=True, kw_only=True) +class SubsystemReadinessSyncCallback: + name: str + is_ready: typing.Callable[[], bool] + + +SubsystemReadinessCallback = SubsystemReadinessAsyncCallback | SubsystemReadinessSyncCallback + + +@attr.define(frozen=True, kw_only=True) +class ReadinessProbeHandler: + subsystems: typing.Sequence[SubsystemReadinessCallback] + + async def _check_subsystem_readiness(self, subsystem: SubsystemReadinessCallback) -> bool: + if isinstance(subsystem, SubsystemReadinessAsyncCallback): + return await subsystem.is_ready() + elif isinstance(subsystem, SubsystemReadinessSyncCallback): + return subsystem.is_ready() + else: + raise ValueError(f"Unknown subsystem type: {type(subsystem)}") + + async def process(self, request: aiohttp_web.Request) -> aiohttp_web.Response: + subsystems_status: dict[str, bool] = { + subsystem.name: await self._check_subsystem_readiness(subsystem) for subsystem in self.subsystems + } + + if all(subsystems_status.values()): + return aiohttp_handlers.Response.with_data( + status=http.HTTPStatus.OK, + data={"status": "healthy", "subsystems_status": subsystems_status}, + ) + + logger.error("Not all subsystems are healthy!", extra=subsystems_status) + + return aiohttp_handlers.Response.with_data( + status=http.HTTPStatus.INTERNAL_SERVER_ERROR, + data={"status": "unhealthy", "subsystems_status": subsystems_status}, + ) + + +__all__ = [ + "ReadinessProbeHandler", + "SubsystemReadinessAsyncCallback", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/responses.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/responses.py new file mode 100644 index 0000000000..62dbfb6503 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/handlers/responses.py @@ -0,0 +1,49 @@ +import http +import typing + +import aiohttp.typedefs as aiohttp_typedefs +import aiohttp.web as aiohttp_web +from typing_extensions import Self + +import dl_json + + +class Response(aiohttp_web.Response): + @classmethod + def with_bytes( + cls, + body: bytes, + status: int = http.HTTPStatus.OK, + reason: str | None = None, + headers: aiohttp_typedefs.LooseHeaders | None = None, + content_type: str | None = None, + ) -> Self: + return cls( + body=body, + status=status, + reason=reason, + headers=headers, + content_type=content_type, + ) + + @classmethod + def with_data( + cls, + data: typing.Any, + status: int = http.HTTPStatus.OK, + reason: str | None = None, + headers: aiohttp_typedefs.LooseHeaders | None = None, + ) -> Self: + body = dl_json.dumps_bytes(data) + return cls.with_bytes( + body=body, + status=status, + reason=reason, + headers=headers, + content_type="application/json", + ) + + +__all__ = [ + "Response", +] diff --git a/lib/dl_temporal/dl_temporal/utils/aiohttp/printer.py b/lib/dl_temporal/dl_temporal/utils/aiohttp/printer.py new file mode 100644 index 0000000000..d9eb56a915 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/aiohttp/printer.py @@ -0,0 +1,24 @@ +import logging +import sys +from typing import ( + Any, + TextIO, +) + + +AIOHTTP_LOGGER = logging.getLogger("aiohttp") + + +class PrintLogger: + def __call__( + self, + *args: Any, + sep: str = " ", + end: str = "\n", + file: TextIO | None = None, + ) -> None: + if file == sys.stderr: + AIOHTTP_LOGGER.error(sep.join([str(arg) for arg in args])) + + if file is None or file == sys.stdout: + AIOHTTP_LOGGER.info(sep.join([str(arg) for arg in args])) diff --git a/lib/dl_temporal/dl_temporal/utils/app/__init__.py b/lib/dl_temporal/dl_temporal/utils/app/__init__.py new file mode 100644 index 0000000000..33302db0dc --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/app/__init__.py @@ -0,0 +1,26 @@ +from .base import ( + BaseApp, + BaseAppFactory, + BaseAppSettings, +) +from .exceptions import ( + ApplicationError, + RunError, + ShutdownError, + StartupError, + UnexpectedFinishError, +) +from .models import Callback + + +__all__ = [ + "BaseApp", + "BaseAppSettings", + "BaseAppFactory", + "Callback", + "ApplicationError", + "StartupError", + "ShutdownError", + "RunError", + "UnexpectedFinishError", +] diff --git a/lib/dl_temporal/dl_temporal/utils/app/base.py b/lib/dl_temporal/dl_temporal/utils/app/base.py new file mode 100644 index 0000000000..6d5efe32ff --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/app/base.py @@ -0,0 +1,230 @@ +import abc +import asyncio +import contextlib +import datetime +import enum +import logging +from typing import ( + AsyncGenerator, + ClassVar, + Generic, + Iterator, + TypeVar, +) + +import attr +from typing_extensions import Self + +import dl_settings +import dl_temporal.utils.app.exceptions as app_exceptions +import dl_temporal.utils.app.models as app_models +import dl_temporal.utils.singleton as singleton_utils + + +class BaseAppSettings(dl_settings.BaseRootSettings): + ... + + +class RuntimeStatus(enum.Enum): + INITIALIZED = "initialized" + STARTING = "starting" + RUNNING = "running" + STOPPING = "stopping" + STOPPED = "stopped" + + +@attr.define(kw_only=True) +class AppState: + runtime_status: RuntimeStatus = attr.field(default=RuntimeStatus.INITIALIZED) + + +@attr.define(frozen=True, kw_only=True) +class BaseApp: + _startup_callbacks: list[app_models.Callback] = attr.field(factory=list) + _shutdown_callbacks: list[app_models.Callback] = attr.field(factory=list) + _main_callbacks: list[app_models.Callback] = attr.field(factory=list) + + logger: logging.Logger + + _state: AppState = attr.field(factory=AppState) + + @property + def startup_callbacks(self) -> Iterator[app_models.Callback]: + yield from self._startup_callbacks + + @property + def shutdown_callbacks(self) -> Iterator[app_models.Callback]: + yield from self._shutdown_callbacks + + @property + def main_callbacks(self) -> Iterator[app_models.Callback]: + yield from self._main_callbacks + + async def on_startup(self) -> None: + self._state.runtime_status = RuntimeStatus.STARTING + + for callback in self.startup_callbacks: + try: + await callback.coroutine + except Exception as e: + message = f"Failed to startup due to failed StartupCallback({callback.name})" + if callback.exception: + self.logger.exception(message) + raise app_exceptions.StartupError(message) from e + else: + self.logger.warning(message) + else: + self.logger.info(f"Successfully started StartupCallback({callback.name})") + + async def on_shutdown(self) -> None: + self._state.runtime_status = RuntimeStatus.STOPPING + + for callback in self.shutdown_callbacks: + try: + await callback.coroutine + except Exception as e: + message = f"Failed to shutdown due to failed ShutdownCallback({callback.name})" + if callback.exception: + self.logger.exception(message) + raise app_exceptions.ShutdownError(message) from e + else: + self.logger.warning(message) + else: + self.logger.info(f"Successfully shutdown ShutdownCallback({callback.name})") + + self._state.runtime_status = RuntimeStatus.STOPPED + + async def main(self) -> None: + tasks: list[asyncio.Task[None]] = [] + for callback in self.main_callbacks: + tasks.append(asyncio.create_task(callback.coroutine, name=callback.name)) + + if len(tasks) == 0: + self.logger.warning("No main callbacks provided") + return + + self._state.runtime_status = RuntimeStatus.RUNNING + + try: + for future in asyncio.as_completed(tasks): + await future + break + except asyncio.CancelledError: + self.logger.info("The main tasks execution was cancelled") + raise + except Exception as e: + self.logger.exception("An error occurred during the main tasks execution") + raise app_exceptions.RunError from e + else: + self.logger.error("Some tasks finished unexpectedly") + raise app_exceptions.UnexpectedFinishError + finally: + finished_unexpectedly: list[asyncio.Task[None]] = [] + finished_with_exception: list[asyncio.Task[None]] = [] + unfinished_tasks: list[asyncio.Task[None]] = [] + + for task in tasks: + if task.done(): + if task.exception() is None: + finished_unexpectedly.append(task) + else: + finished_with_exception.append(task) + else: + unfinished_tasks.append(task) + + if finished_unexpectedly: + self.logger.info("Tasks that finished unexpectedly:") + for task in finished_unexpectedly: + self.logger.info("- %s", task.get_name()) + + if finished_with_exception: + self.logger.info("Tasks that finished with exception:") + for task in finished_with_exception: + self.logger.info("- %s", task.get_name()) + + if unfinished_tasks: + self.logger.info("Unfinished tasks:") + for task in unfinished_tasks: + self.logger.info("- %s - cancelling...", task.get_name()) + task.cancel() + try: + await task + except asyncio.CancelledError: + self.logger.info("- %s - cancelled", task.get_name()) + + async def run(self) -> None: + await self.on_startup() + + try: + await self.main() + finally: + await self.on_shutdown() + + @contextlib.asynccontextmanager + async def run_in_task_context( + self, + readiness_timeout: datetime.timedelta = datetime.timedelta(seconds=60), + ) -> AsyncGenerator[Self, None]: + try: + run_task = asyncio.create_task(self.run(), name="run_in_task_context") + + deadline = datetime.datetime.now() + readiness_timeout + while datetime.datetime.now() < deadline and self._state.runtime_status != RuntimeStatus.RUNNING: + if self._state.runtime_status in [RuntimeStatus.STOPPING, RuntimeStatus.STOPPED]: + raise RuntimeError("Failed to wait for the application to be running") + + self.logger.info("Waiting for the application to be running") + await asyncio.sleep(1) + + yield self + finally: + self.logger.info("Cancelling the run task") + run_task.cancel() + try: + await run_task + except asyncio.CancelledError: + self.logger.info("The run task was cancelled") + + +AppType = TypeVar("AppType", bound=BaseApp) + + +@attr.define(kw_only=True, slots=False) +class BaseAppFactory(Generic[AppType]): + settings: BaseAppSettings + app_class: ClassVar[type[AppType]] # type: ignore + + async def create_application( + self, + ) -> AppType: + return self.app_class( + startup_callbacks=await self._get_startup_callbacks(), + shutdown_callbacks=await self._get_shutdown_callbacks(), + main_callbacks=await self._get_main_callbacks(), + logger=await self._get_logger(), + ) + + @singleton_utils.singleton_class_method_result + async def _get_startup_callbacks( + self, + ) -> list[app_models.Callback]: + return [] + + @singleton_utils.singleton_class_method_result + async def _get_shutdown_callbacks( + self, + ) -> list[app_models.Callback]: + return [] + + @singleton_utils.singleton_class_method_result + async def _get_main_callbacks( + self, + ) -> list[app_models.Callback]: + return [] + + @abc.abstractmethod + @singleton_utils.singleton_class_method_result + async def _get_logger( + self, + ) -> logging.Logger: + ... diff --git a/lib/dl_temporal/dl_temporal/utils/app/exceptions.py b/lib/dl_temporal/dl_temporal/utils/app/exceptions.py new file mode 100644 index 0000000000..8627db7fc3 --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/app/exceptions.py @@ -0,0 +1,18 @@ +class ApplicationError(Exception): + ... + + +class StartupError(ApplicationError): + ... + + +class ShutdownError(ApplicationError): + ... + + +class RunError(ApplicationError): + ... + + +class UnexpectedFinishError(RunError): + ... diff --git a/lib/dl_temporal/dl_temporal/utils/app/models.py b/lib/dl_temporal/dl_temporal/utils/app/models.py new file mode 100644 index 0000000000..e369586bdc --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/app/models.py @@ -0,0 +1,13 @@ +from typing import ( + Any, + Coroutine, +) + +import attr + + +@attr.define(frozen=True, kw_only=True) +class Callback: + coroutine: Coroutine[Any, Any, None] + name: str + exception: bool = attr.field(default=True) diff --git a/lib/dl_temporal/dl_temporal/utils/singleton.py b/lib/dl_temporal/dl_temporal/utils/singleton.py new file mode 100644 index 0000000000..8ce01b6d0e --- /dev/null +++ b/lib/dl_temporal/dl_temporal/utils/singleton.py @@ -0,0 +1,89 @@ +import functools +import inspect +import logging +from typing import ( + Any, + Callable, + Coroutine, + TypeVar, + cast, +) + + +LOGGER = logging.getLogger(__name__) + + +SINGLETON_FUNCTION_RESULT_ATTRIBUTE = "_singleton_function_result" +SyncFunction = Callable[..., Any] +AsyncFunction = Callable[..., Coroutine[Any, Any, Any]] +Function = SyncFunction | AsyncFunction + +SyncFunctionType = TypeVar("SyncFunctionType", bound=SyncFunction) +AsyncFunctionType = TypeVar("AsyncFunctionType", bound=AsyncFunction) +FunctionType = TypeVar("FunctionType", bound=Function) + + +# Decorator is not thread-safe, but it's ok for our use case +def singleton_function_result(func: FunctionType) -> FunctionType: + if inspect.iscoroutinefunction(func): + return _async_singleton_function_result(func) + else: + return _sync_singleton_function_result(func) + + +def _async_singleton_function_result(func: AsyncFunctionType) -> AsyncFunctionType: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + if not hasattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE): + setattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE, await func(*args, **kwargs)) + + return getattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE) + + return cast(AsyncFunctionType, wrapper) + + +def _sync_singleton_function_result(func: SyncFunctionType) -> SyncFunctionType: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if not hasattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE): + setattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE, func(*args, **kwargs)) + + return getattr(func, SINGLETON_FUNCTION_RESULT_ATTRIBUTE) + + return cast(SyncFunctionType, wrapper) + + +# Decorator is not thread-safe, but it's ok for our use case +def singleton_class_method_result(func: FunctionType) -> FunctionType: + if inspect.iscoroutinefunction(func): + return _async_singleton_class_method_result(func) + else: + return _sync_singleton_class_method_result(func) + + +def _async_singleton_class_method_result(func: AsyncFunctionType) -> AsyncFunctionType: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + class_instance = args[0] + instance_key = f"{SINGLETON_FUNCTION_RESULT_ATTRIBUTE}_{func.__name__}" + + if not hasattr(class_instance, instance_key): + setattr(class_instance, instance_key, await func(*args, **kwargs)) + + return getattr(class_instance, instance_key) + + return cast(AsyncFunctionType, wrapper) + + +def _sync_singleton_class_method_result(func: SyncFunctionType) -> SyncFunctionType: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + class_instance = args[0] + instance_key = f"{SINGLETON_FUNCTION_RESULT_ATTRIBUTE}_{func.__name__}" + + if not hasattr(class_instance, instance_key): + setattr(class_instance, instance_key, func(*args, **kwargs)) + + return getattr(class_instance, instance_key) + + return cast(SyncFunctionType, wrapper) diff --git a/lib/dl_temporal/dl_temporal_tests/db/worker/activities.py b/lib/dl_temporal/dl_temporal_tests/db/activities.py similarity index 100% rename from lib/dl_temporal/dl_temporal_tests/db/worker/activities.py rename to lib/dl_temporal/dl_temporal_tests/db/activities.py diff --git a/lib/dl_temporal/dl_temporal_tests/db/app/__init__.py b/lib/dl_temporal/dl_temporal_tests/db/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/dl_temporal/dl_temporal_tests/db/app/config.yaml b/lib/dl_temporal/dl_temporal_tests/db/app/config.yaml new file mode 100644 index 0000000000..5577c4da3b --- /dev/null +++ b/lib/dl_temporal/dl_temporal_tests/db/app/config.yaml @@ -0,0 +1,9 @@ +http_server: + host: "0.0.0.0" + port: 8080 + +temporal_client: + namespace: "test" + tls: false + metadata_provider: + type: "empty" diff --git a/lib/dl_temporal/dl_temporal_tests/db/app/conftest.py b/lib/dl_temporal/dl_temporal_tests/db/app/conftest.py new file mode 100644 index 0000000000..03186a85e8 --- /dev/null +++ b/lib/dl_temporal/dl_temporal_tests/db/app/conftest.py @@ -0,0 +1,87 @@ +import logging +import os +from typing import ( + AsyncGenerator, + ClassVar, +) + +import attr +import pytest +import pytest_asyncio +from typing_extensions import override + +import dl_temporal +import dl_temporal.app +import dl_temporal_tests.db.activities as activities +import dl_temporal_tests.db.workflows as workflows +import dl_testing + + +DIR_PATH = os.path.dirname(__file__) +LOGGER = logging.getLogger(__name__) + + +class Settings(dl_temporal.app.BaseTemporalWorkerAppSettings): + ... + + +class App(dl_temporal.app.BaseTemporalWorkerApp): + ... + + +@attr.define(kw_only=True, slots=False) +class Factory(dl_temporal.app.BaseTemporalWorkerAppFactory[App]): + settings: Settings + app_class: ClassVar[type[App]] = App + + @override + async def _get_temporal_workflows( + self, + ) -> list[type[dl_temporal.WorkflowProtocol]]: + return [workflows.Workflow] + + @override + async def _get_temporal_activities( + self, + ) -> list[dl_temporal.ActivityProtocol]: + return [activities.Activity()] + + @override + async def _get_temporal_client_metadata_provider( + self, + ) -> dl_temporal.MetadataProvider: + return dl_temporal.EmptyMetadataProvider() + + @override + async def _get_logger( + self, + ) -> logging.Logger: + return LOGGER + + +@pytest.fixture(name="app_settings") +def fixture_app_settings( + monkeypatch: pytest.MonkeyPatch, + temporal_task_queue: str, + temporal_namespace: str, + temporal_hostport: dl_testing.HostPort, +) -> Settings: + monkeypatch.setenv("CONFIG_PATH", os.path.join(DIR_PATH, "config.yaml")) + + monkeypatch.setenv("TEMPORAL_WORKER__TASK_QUEUE", temporal_task_queue) + monkeypatch.setenv("TEMPORAL_CLIENT__NAMESPACE", temporal_namespace) + monkeypatch.setenv("TEMPORAL_CLIENT__HOST", temporal_hostport.host) + monkeypatch.setenv("TEMPORAL_CLIENT__PORT", str(temporal_hostport.port)) + + return Settings() + + +@pytest_asyncio.fixture(name="app", autouse=True) +async def fixture_app( + app_settings: Settings, +) -> AsyncGenerator[dl_temporal.app.BaseTemporalWorkerApp, None]: + factory = Factory(settings=app_settings) + app = await factory.create_application() + + async with app.run_in_task_context() as app: + yield app diff --git a/lib/dl_temporal/dl_temporal_tests/db/app/test_handlers.py b/lib/dl_temporal/dl_temporal_tests/db/app/test_handlers.py new file mode 100644 index 0000000000..738c45762c --- /dev/null +++ b/lib/dl_temporal/dl_temporal_tests/db/app/test_handlers.py @@ -0,0 +1,40 @@ +import aiohttp +import pytest + +import dl_temporal +import dl_temporal.app + + +@pytest.fixture(name="temporal_task_queue") +def fixture_temporal_task_queue() -> str: + return "tests/db/app/test_handlers" + + +@pytest.mark.asyncio +async def test_liveness_probe_handler( + app_settings: dl_temporal.app.BaseTemporalWorkerAppSettings, +) -> None: + async with aiohttp.ClientSession() as session: + response = await session.get( + f"http://{app_settings.http_server.host}:{app_settings.http_server.port}/api/v1/health/liveness" + ) + assert response.status == 200 + assert await response.json() == {"status": "healthy"} + + +@pytest.mark.asyncio +async def test_readiness_probe_handler( + app_settings: dl_temporal.app.BaseTemporalWorkerAppSettings, +) -> None: + async with aiohttp.ClientSession() as session: + response = await session.get( + f"http://{app_settings.http_server.host}:{app_settings.http_server.port}/api/v1/health/readiness" + ) + assert response.status == 200 + assert await response.json() == { + "status": "healthy", + "subsystems_status": { + "temporal_client.check_health": True, + "temporal_worker.is_running": True, + }, + } diff --git a/lib/dl_temporal/dl_temporal_tests/db/worker/test_worker.py b/lib/dl_temporal/dl_temporal_tests/db/app/test_worker.py similarity index 60% rename from lib/dl_temporal/dl_temporal_tests/db/worker/test_worker.py rename to lib/dl_temporal/dl_temporal_tests/db/app/test_worker.py index cb6c5723e8..1b69f576ef 100644 --- a/lib/dl_temporal/dl_temporal_tests/db/worker/test_worker.py +++ b/lib/dl_temporal/dl_temporal_tests/db/app/test_worker.py @@ -1,47 +1,22 @@ import datetime -import logging -from typing import AsyncGenerator import uuid import pytest -import pytest_asyncio -import temporalio.worker import dl_pydantic import dl_temporal -import dl_temporal.testing as dl_temporal_testing -import dl_temporal_tests.db.worker.activities as activities -import dl_temporal_tests.db.worker.workflows as workflows -import dl_testing +import dl_temporal_tests.db.workflows as workflows -@pytest.fixture(name="temporal_queue_name") -def fixture_temporal_queue_name() -> str: - return "test_queue_name" - - -@pytest_asyncio.fixture(name="temporal_worker", autouse=True) -async def fixture_temporal_worker( - temporal_client: dl_temporal.TemporalClient, - temporal_queue_name: str, - temporal_ui_hostport: dl_testing.HostPort, -) -> AsyncGenerator[temporalio.worker.Worker, None]: - worker = dl_temporal.create_worker( - task_queue=temporal_queue_name, - client=temporal_client, - workflows=[workflows.Workflow], - activities=[activities.Activity()], - ) - logging.info(f"Temporal UI URL: http://{temporal_ui_hostport.host}:{temporal_ui_hostport.port}") - - async with dl_temporal_testing.worker_run_context(worker=worker) as worker: - yield worker +@pytest.fixture(name="temporal_task_queue") +def fixture_temporal_task_queue() -> str: + return "tests/db/app/test_worker" @pytest.mark.asyncio async def test_default( temporal_client: dl_temporal.TemporalClient, - temporal_queue_name: str, + temporal_task_queue: str, ) -> None: random_id = uuid.uuid4() @@ -61,7 +36,7 @@ async def test_default( workflows.Workflow, params, id=str(random_id), - task_queue=temporal_queue_name, + task_queue=temporal_task_queue, ) result: workflows.WorkflowResult = await workflow_handler.result() diff --git a/lib/dl_temporal/dl_temporal_tests/db/client/test_check_health.py b/lib/dl_temporal/dl_temporal_tests/db/client/test_check_health.py index fea0799d06..d3c9da3cb3 100644 --- a/lib/dl_temporal/dl_temporal_tests/db/client/test_check_health.py +++ b/lib/dl_temporal/dl_temporal_tests/db/client/test_check_health.py @@ -11,8 +11,8 @@ async def test_default(temporal_client: dl_temporal.TemporalClient) -> None: @pytest.mark.asyncio async def test_unavailable() -> None: - temporal_client = await dl_temporal.TemporalClient.from_settings( - settings=dl_temporal.TemporalClientSettings( + temporal_client = await dl_temporal.TemporalClient.from_dependencies( + dependencies=dl_temporal.TemporalClientDependencies( host="unavailable_host", port=8080, namespace="dl_temporal_tests", diff --git a/lib/dl_temporal/dl_temporal_tests/db/client/test_metadata.py b/lib/dl_temporal/dl_temporal_tests/db/client/test_metadata.py index 64f8043339..595bd7184c 100644 --- a/lib/dl_temporal/dl_temporal_tests/db/client/test_metadata.py +++ b/lib/dl_temporal/dl_temporal_tests/db/client/test_metadata.py @@ -17,8 +17,8 @@ async def test_ttl( metadata_provider.get_metadata.return_value = {"test": "test_before"} metadata_provider.ttl = ttl - temporal_client = await dl_temporal.TemporalClient.from_settings( - settings=dl_temporal.TemporalClientSettings( + temporal_client = await dl_temporal.TemporalClient.from_dependencies( + dependencies=dl_temporal.TemporalClientDependencies( host="test-host", port=1234, namespace="test-namespace", diff --git a/lib/dl_temporal/dl_temporal_tests/db/conftest.py b/lib/dl_temporal/dl_temporal_tests/db/conftest.py index 1b20ee6182..b0da73b1d0 100644 --- a/lib/dl_temporal/dl_temporal_tests/db/conftest.py +++ b/lib/dl_temporal/dl_temporal_tests/db/conftest.py @@ -39,8 +39,8 @@ async def fixture_temporal_client( temporal_namespace: str, temporal_hostport: dl_testing.HostPort, ) -> AsyncGenerator[dl_temporal.TemporalClient, None]: - client = await dl_temporal.TemporalClient.from_settings( - dl_temporal.TemporalClientSettings( + client = await dl_temporal.TemporalClient.from_dependencies( + dl_temporal.TemporalClientDependencies( host=temporal_hostport.host, port=temporal_hostport.port, namespace=temporal_namespace, @@ -48,14 +48,6 @@ async def fixture_temporal_client( ) ) - try: - await client.register_namespace( - namespace=temporal_namespace, - workflow_execution_retention_period=datetime.timedelta(days=1), - ) - except dl_temporal.AlreadyExists: - pass - await dl_utils.await_for( name="temporal client", condition=client.check_health, @@ -68,3 +60,17 @@ async def fixture_temporal_client( yield client finally: await client.close() + + +@pytest_asyncio.fixture(name="register_namespace", autouse=True) +async def fixture_register_namespace( + temporal_client: dl_temporal.TemporalClient, + temporal_namespace: str, +) -> None: + try: + await temporal_client.register_namespace( + namespace=temporal_namespace, + workflow_execution_retention_period=datetime.timedelta(days=1), + ) + except dl_temporal.AlreadyExists: + pass diff --git a/lib/dl_temporal/dl_temporal_tests/db/worker/workflows.py b/lib/dl_temporal/dl_temporal_tests/db/workflows.py similarity index 97% rename from lib/dl_temporal/dl_temporal_tests/db/worker/workflows.py rename to lib/dl_temporal/dl_temporal_tests/db/workflows.py index 39278d9e03..aeb7c38f85 100644 --- a/lib/dl_temporal/dl_temporal_tests/db/worker/workflows.py +++ b/lib/dl_temporal/dl_temporal_tests/db/workflows.py @@ -3,7 +3,7 @@ with temporalio.workflow.unsafe.imports_passed_through(): import dl_pydantic - import dl_temporal_tests.db.worker.activities as activities + import dl_temporal_tests.db.activities as activities import dl_temporal diff --git a/lib/dl_temporal/dl_temporal_tests/unit/__init__.py b/lib/dl_temporal/dl_temporal_tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/dl_temporal/dl_temporal_tests/unit/utils/__init__.py b/lib/dl_temporal/dl_temporal_tests/unit/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/dl_temporal/dl_temporal_tests/unit/utils/test_singleton.py b/lib/dl_temporal/dl_temporal_tests/unit/utils/test_singleton.py new file mode 100644 index 0000000000..4b9251f62d --- /dev/null +++ b/lib/dl_temporal/dl_temporal_tests/unit/utils/test_singleton.py @@ -0,0 +1,220 @@ +import pytest + +import dl_temporal.utils.singleton as singleton_utils + + +def test_sync_function_same_result() -> None: + @singleton_utils.singleton_function_result + def test_func() -> object: + return object() + + result1 = test_func() + result2 = test_func() + + assert result1 is result2 + + +def test_sync_function_result() -> None: + value = object() + + @singleton_utils.singleton_function_result + def test_func() -> object: + return value + + assert test_func() is value + + +def test_sync_function_with_args() -> None: + @singleton_utils.singleton_function_result + def test_func(arg: int) -> int: + return arg + + assert test_func(1) == 1 + assert test_func(2) == 1 + + +def test_sync_function_with_kwargs() -> None: + @singleton_utils.singleton_function_result + def test_func(arg: int) -> int: + return arg + + assert test_func(arg=1) == 1 + assert test_func(arg=2) == 1 + + +@pytest.mark.asyncio +async def test_async_function_same_result() -> None: + @singleton_utils.singleton_function_result + async def test_func() -> object: + return object() + + result1 = await test_func() + result2 = await test_func() + + assert result1 is result2 + + +@pytest.mark.asyncio +async def test_async_function_result() -> None: + value = object() + + @singleton_utils.singleton_function_result + async def test_func() -> object: + return value + + assert await test_func() is value + + +@pytest.mark.asyncio +async def test_async_function_with_args() -> None: + @singleton_utils.singleton_function_result + async def test_func(arg: int) -> int: + return arg + + assert await test_func(1) == 1 + assert await test_func(2) == 1 + + +@pytest.mark.asyncio +async def test_async_function_with_kwargs() -> None: + @singleton_utils.singleton_function_result + async def test_func(arg: int) -> int: + return arg + + assert await test_func(arg=1) == 1 + assert await test_func(arg=2) == 1 + + +def test_sync_class_method_same_result() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + def test_func(self) -> object: + return object() + + instance = TestClass() + result1 = instance.test_func() + result2 = instance.test_func() + + assert result1 is result2 + + +def test_sync_class_method_unique_per_instance() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + def test_func(self) -> object: + return object() + + instance1 = TestClass() + instance2 = TestClass() + + result1 = instance1.test_func() + result2 = instance2.test_func() + + assert result1 is not result2 + + +def test_sync_class_method_result() -> None: + value = object() + + class TestClass: + @singleton_utils.singleton_class_method_result + def test_func(self) -> object: + return value + + instance = TestClass() + + assert instance.test_func() is value + + +def test_sync_class_method_with_args() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + def test_func(self, arg: int) -> int: + return arg + + instance = TestClass() + + assert instance.test_func(1) == 1 + assert instance.test_func(2) == 1 + + +def test_sync_class_method_with_kwargs() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + def test_func(self, arg: int) -> int: + return arg + + instance = TestClass() + + assert instance.test_func(arg=1) == 1 + assert instance.test_func(arg=2) == 1 + + +@pytest.mark.asyncio +async def test_async_class_method_same_result() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + async def test_func(self) -> object: + return object() + + instance = TestClass() + result1 = await instance.test_func() + result2 = await instance.test_func() + + assert result1 is result2 + + +@pytest.mark.asyncio +async def test_async_class_method_unique_per_instance() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + async def test_func(self) -> object: + return object() + + instance1 = TestClass() + instance2 = TestClass() + + result1 = await instance1.test_func() + result2 = await instance2.test_func() + + assert result1 is not result2 + + +@pytest.mark.asyncio +async def test_async_class_method_result() -> None: + value = object() + + class TestClass: + @singleton_utils.singleton_class_method_result + async def test_func(self) -> object: + return value + + instance = TestClass() + + assert await instance.test_func() is value + + +@pytest.mark.asyncio +async def test_async_class_method_with_args() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + async def test_func(self, arg: int) -> int: + return arg + + instance = TestClass() + + assert await instance.test_func(1) == 1 + assert await instance.test_func(2) == 1 + + +@pytest.mark.asyncio +async def test_async_class_method_with_kwargs() -> None: + class TestClass: + @singleton_utils.singleton_class_method_result + async def test_func(self, arg: int) -> int: + return arg + + instance = TestClass() + + assert await instance.test_func(arg=1) == 1 + assert await instance.test_func(arg=2) == 1 diff --git a/lib/dl_temporal/pyproject.toml b/lib/dl_temporal/pyproject.toml index a2f6ea6a8b..2985aed0bc 100644 --- a/lib/dl_temporal/pyproject.toml +++ b/lib/dl_temporal/pyproject.toml @@ -8,8 +8,11 @@ readme = "README.md" version = "0.0.1" [tool.poetry.dependencies] +aiohttp = "*" attrs = "*" +dl-json = {path = "../dl_json"} dl-pydantic = {path = "../dl_pydantic"} +dl-settings = {path = "../dl_settings"} dl-utils = {path = "../dl_utils"} python = ">=3.10, <3.13" temporalio = "*" @@ -43,3 +46,8 @@ requires = [ [datalens.pytest.db] root_dir = "dl_temporal_tests/" target_path = "db" + +[datalens.pytest.unit] +root_dir = "dl_temporal_tests/" +skip_compose = "true" +target_path = "unit"