Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions temporalio/client/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,11 @@ async def start_workflow_update(
):
break

# Add response link if its a Nexus operation
nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context()
if nexus_ctx is not None and resp.HasField("link"):
nexus_ctx._add_response_link(resp.link)

# Build the handle. If the user's wait stage is COMPLETED, make sure we
# poll for result.
handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle(
Expand Down Expand Up @@ -852,6 +857,23 @@ async def _build_update_workflow_execution_request(
)
),
)
# Only set Nexus fields for StartWorkflowUpdateInput, skip for UpdateWithStartUpdateWorkflowInput
if isinstance(input, StartWorkflowUpdateInput):
if input.request_id:
req.request.request_id = input.request_id
if input.links:
req.request.links.extend(input.links)
if input.callbacks:
req.request.completion_callbacks.extend(
temporalio.api.common.v1.Callback(
nexus=temporalio.api.common.v1.Callback.Nexus(
url=callback.url,
header=callback.headers,
),
links=input.links or [],
)
for callback in input.callbacks
)
if input.args:
req.request.input.args.payloads.extend(
await data_converter.encode(input.args)
Expand Down
4 changes: 4 additions & 0 deletions temporalio/client/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ class StartWorkflowUpdateInput:
ret_type: type | None
rpc_metadata: Mapping[str, str | bytes]
rpc_timeout: timedelta | None
# The following options are for Nexus Operation-backed updates. Experimental and unstable
callbacks: Sequence[Callback] | None = None
links: Sequence[temporalio.api.common.v1.Link] | None = None
request_id: str | None = None


@dataclass
Expand Down
14 changes: 14 additions & 0 deletions temporalio/client/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
ReturnType,
SelfType,
)
from ._callback import Callback
from ._exceptions import (
WorkflowContinuedAsNewError,
WorkflowFailureError,
Expand Down Expand Up @@ -896,6 +897,8 @@ async def start_update(
rpc_timeout: timedelta | None = None,
) -> WorkflowUpdateHandle[Any]: ...

# draft-review: check why this doesnt currently support run_id and first_execution_run_id
# If it can be supported, wire it up for nexus operation-backed updates as well
async def start_update(
self,
update: str | Callable,
Expand Down Expand Up @@ -955,6 +958,12 @@ async def _start_update(
result_type: type | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# run_id: str | None = None,
# first_execution_run_id: str | None = None,
# The following options are for Nexus Operation-backed updates. Experimental and unstable
callbacks: Sequence[Callback] | None = None,
links: Sequence[temporalio.api.common.v1.Link] | None = None,
request_id: str | None = None,
) -> WorkflowUpdateHandle[Any]:
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
raise ValueError("ADMITTED wait stage not supported")
Expand All @@ -976,6 +985,11 @@ async def _start_update(
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
wait_for_stage=wait_for_stage,
# run_id=run_id,
# first_execution_run_id=first_execution_run_id,
callbacks=callbacks,
links=links,
request_id=request_id,
)
)

Expand Down
5 changes: 4 additions & 1 deletion temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@
wait_for_worker_shutdown_sync,
)
from ._operation_handlers import (
CancelUpdateWorkflowOptions,
CancelWorkflowRunOptions,
TemporalOperationHandler,
)
from ._temporal_client import TemporalNexusClient, TemporalOperationResult
from ._token import WorkflowHandle
from ._token import UpdateHandle, WorkflowHandle

__all__ = (
"workflow_run_operation",
"CancelWorkflowRunOptions",
"CancelUpdateWorkflowOptions",
"Info",
"LoggerAdapter",
"NexusCallback",
Expand All @@ -49,6 +51,7 @@
"wait_for_worker_shutdown",
"wait_for_worker_shutdown_sync",
"WorkflowHandle",
"UpdateHandle",
"TemporalNexusClient",
"TemporalOperationStartHandlerFunc",
"TemporalOperationHandler",
Expand Down
40 changes: 40 additions & 0 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,43 @@ async def _start_nexus_backing_workflow(
)

return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)


async def _start_nexus_backed_workflow_update( # pyright: ignore[reportUnusedFunction]
*,
temporal_context: _TemporalStartOperationContext,
workflow_id: str,
update: str | Callable,
arg: Any = temporalio.common._arg_unset,
args: Sequence[Any] = [],
id: str | None = None,
result_type: type | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# run_id: str | None = None,
# first_execution_run_id: str | None = None,
) -> temporalio.client.WorkflowUpdateHandle[Any]:
# Default update ID to the Nexus request ID for retry-safety (matches sdk-go).
update_id = id or temporal_context.nexus_context.request_id
token = OperationToken(
type=OperationTokenType.UPDATE_WORKFLOW,
namespace=temporal_context.client.namespace,
workflow_id=workflow_id,
update_id=update_id,
).encode()
workflow_handle = temporal_context.client.get_workflow_handle(workflow_id)
return await workflow_handle._start_update(
update,
arg,
args=args,
wait_for_stage=temporalio.client.WorkflowUpdateStage.ACCEPTED, # hardcoded as nexus only supports async updates
id=update_id,
result_type=result_type,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
callbacks=temporal_context._get_callbacks(token),
links=temporal_context._get_request_links(),
request_id=temporal_context.nexus_context.request_id,
# run_id=run_id,
# first_execution_run_id=first_execution_run_id,
)
44 changes: 44 additions & 0 deletions temporalio/nexus/_operation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ class CancelWorkflowRunOptions:
"""The ID of the workflow to cancel."""


@dataclass(frozen=True)
class CancelUpdateWorkflowOptions:
"""Options for cancelling the workflow update backing a Nexus operation.

These options are built by :py:class:`TemporalOperationHandler` and passed to
:py:meth:`TemporalOperationHandler.cancel_workflow_update`.

.. warning::
This API is experimental and unstable.
"""

workflow_id: str
"""The ID of the workflow where the update is running."""
update_id: str
"""The ID of the update to cancel."""


class TemporalOperationHandler(OperationHandler[InputT, OutputT], ABC):
"""Operation handler for Nexus operations that interact with Temporal.
Implementations override the start_operation method.
Expand Down Expand Up @@ -190,6 +207,13 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
workflow_id=operation_token.workflow_id
)
await self.cancel_workflow_run(cancel_ctx, options)
case OperationTokenType.UPDATE_WORKFLOW:
assert operation_token.update_id is not None
cancel_options = CancelUpdateWorkflowOptions(
workflow_id=operation_token.workflow_id,
update_id=operation_token.update_id,
)
await self.cancel_workflow_update(cancel_ctx, cancel_options)

async def cancel_workflow_run(
self,
Expand All @@ -205,3 +229,23 @@ async def cancel_workflow_run(
options.workflow_id
)
await workflow_handle.cancel()

# draft-review: maybe just move it inline, no need for a function just to error out
# check after review in case theres some other way to override/supply custom cancels
async def cancel_workflow_update(
self,
ctx: TemporalCancelOperationContext, # pyright: ignore[reportUnusedParameter]
options: CancelUpdateWorkflowOptions, # pyright: ignore[reportUnusedParameter]
) -> None:
"""Cancels the workflow update backing the Nexus operation.

.. warning::
This API is experimental and unstable.
"""
raise HandlerError(
"""
Cancellation is not natively supported for update-workflow Nexus operations.
Override a TemporalOperationHandler and implement this method to run cancellable workflow updates.
""",
type=HandlerErrorType.NOT_IMPLEMENTED,
)
133 changes: 132 additions & 1 deletion temporalio/nexus/_temporal_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
overload,
)

from nexusrpc import HandlerError, HandlerErrorType
from nexusrpc import HandlerError, HandlerErrorType, OperationError, OperationErrorState
from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync
from typing_extensions import Self

import temporalio.common
from temporalio.nexus._operation_context import (
_start_nexus_backed_workflow_update,
_start_nexus_backing_workflow,
_TemporalStartOperationContext,
)
from temporalio.nexus._token import UpdateHandle
from temporalio.types import (
MethodAsyncNoParam,
MethodAsyncSingleParam,
Expand All @@ -35,6 +37,7 @@

if TYPE_CHECKING:
import temporalio.client
import temporalio.workflow


_ResultT = TypeVar("_ResultT")
Expand Down Expand Up @@ -279,6 +282,85 @@ async def start_workflow(
"""
...

# Overload for no-param update
@overload
async def start_workflow_update(
self,
workflow_id: str,
update: temporalio.workflow.UpdateMethodMultiParam[[Any], ReturnType],
*,
id: str | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> TemporalOperationResult[ReturnType]: ...

# Overload for single-param update
@overload
async def start_workflow_update(
self,
workflow_id: str,
update: temporalio.workflow.UpdateMethodMultiParam[
[Any, ParamType], ReturnType
],
arg: ParamType,
*,
id: str | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> TemporalOperationResult[ReturnType]: ...

# Overload for multi-param update
@overload
async def start_workflow_update(
self,
workflow_id: str,
update: temporalio.workflow.UpdateMethodMultiParam[MultiParamSpec, ReturnType],
*,
args: MultiParamSpec.args, # type: ignore
id: str | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> TemporalOperationResult[ReturnType]: ...

# Overload for string-name update
@overload
async def start_workflow_update(
self,
workflow_id: str,
update: str,
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: str | None = None,
result_type: type[ReturnType] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> TemporalOperationResult[ReturnType]: ...

# draft-review: check why run_id and first_execution_run_id are not used
# for update workflow in python sdk
@abstractmethod
async def start_workflow_update(
self,
workflow_id: str,
update: str | Callable,
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: str | None = None,
result_type: type | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# run_id: str | None = None,
# first_execution_run_id: str | None = None,
) -> TemporalOperationResult[Any]:
"""Start a workflow update as the backing asynchronous Nexus operation.

.. warning::
This API is experimental and unstable.
"""
...


class _TemporalNexusClient(TemporalNexusClient): # pyright: ignore[reportUnusedClass]
"""Nexus-aware wrapper around a Temporal Client.
Expand Down Expand Up @@ -377,3 +459,52 @@ async def start_workflow(
)

return TemporalOperationResult.async_token(wf_handle.to_token())

async def start_workflow_update(
self,
workflow_id: str,
update: str | Callable,
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: str | None = None,
result_type: type | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# run_id: str | None = None,
# first_execution_run_id: str | None = None,
) -> TemporalOperationResult[Any]:
"""Start a workflow update as the backing asynchronous Nexus operation."""
if not self._temporal_context.nexus_context.callback_url:
raise HandlerError(
"callback URL is required for a workflow update Nexus operation",
type=HandlerErrorType.BAD_REQUEST,
)
with self._reserve_async_start():
update_handle = await _start_nexus_backed_workflow_update(
temporal_context=self._temporal_context,
workflow_id=workflow_id,
update=update,
arg=arg,
args=args,
id=id,
result_type=result_type,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
# run_id=run_id,
# first_execution_run_id=first_execution_run_id,
)
# If the update has already completed, return the result synchronously
# This is in-line with the Go implementation as well
if update_handle._known_outcome is not None:
try:
result = await update_handle.result()
except temporalio.client.WorkflowUpdateFailedError as err:
raise OperationError(
str(err), state=OperationErrorState.FAILED
) from err
return TemporalOperationResult.sync(result)
nexus_handle: UpdateHandle[Any] = (
UpdateHandle._unsafe_from_client_workflow_update_handle(update_handle)
)
return TemporalOperationResult.async_token(nexus_handle.to_token())
Loading
Loading