11import time
22import random
3- from typing import TYPE_CHECKING , Type , Union , TypeVar , cast
4- from typing_extensions import ParamSpec
3+ from typing import TYPE_CHECKING , Type , Union , TypeVar , Callable , Annotated , Coroutine , cast
4+ from typing_extensions import ParamSpec , TypeAlias
55
66import anyio
77
8+ from runwayml ._utils import PropertyInfo
9+
810from .._models import BaseModel
9- from ..types .task_retrieve_response import TaskRetrieveResponse
11+ from ..types .task_retrieve_response import (
12+ Failed ,
13+ Pending ,
14+ Running ,
15+ Cancelled ,
16+ Succeeded ,
17+ Throttled ,
18+ TaskRetrieveResponse ,
19+ )
1020
1121if TYPE_CHECKING :
1222 from .._client import RunwayML , AsyncRunwayML
@@ -40,10 +50,6 @@ class NewTaskCreatedResponse(AwaitableTaskResponseMixin, BaseModel):
4050 id : str
4151
4252
43- class AwaitableTaskRetrieveResponse (AwaitableTaskResponseMixin , TaskRetrieveResponse ):
44- pass
45-
46-
4753class AsyncAwaitableTaskResponseMixin :
4854 async def wait_for_task_output (self , timeout : Union [float , None ] = 60 * 10 ) -> TaskRetrieveResponse : # type: ignore[empty-body]
4955 """
@@ -67,10 +73,6 @@ class AsyncNewTaskCreatedResponse(AsyncAwaitableTaskResponseMixin, BaseModel):
6773 id : str
6874
6975
70- class AsyncAwaitableTaskRetrieveResponse (AsyncAwaitableTaskResponseMixin , TaskRetrieveResponse ):
71- pass
72-
73-
7476def create_waitable_resource (base_class : Type [T ], client : "RunwayML" ) -> Type [NewTaskCreatedResponse ]:
7577 class WithClient (base_class ): # type: ignore[valid-type,misc]
7678 id : str
@@ -125,3 +127,74 @@ class TaskTimeoutError(Exception):
125127 def __init__ (self , task_details : TaskRetrieveResponse ):
126128 self .task_details = task_details
127129 super ().__init__ (f"Task timed out" )
130+
131+
132+
133+ class AwaitablePending (AwaitableTaskResponseMixin , Pending ): ...
134+ class AwaitableThrottled (AwaitableTaskResponseMixin , Throttled ): ...
135+ class AwaitableCancelled (AwaitableTaskResponseMixin , Cancelled ): ...
136+ class AwaitableRunning (AwaitableTaskResponseMixin , Running ): ...
137+ class AwaitableFailed (AwaitableTaskResponseMixin , Failed ): ...
138+ class AwaitableSucceeded (AwaitableTaskResponseMixin , Succeeded ): ...
139+
140+ AwaitableTaskRetrieveResponse : TypeAlias = Annotated [
141+ Union [AwaitablePending , AwaitableThrottled , AwaitableCancelled , AwaitableRunning , AwaitableFailed , AwaitableSucceeded ],
142+ PropertyInfo (discriminator = "status" )
143+ ]
144+
145+ class AsyncAwaitablePending (AsyncAwaitableTaskResponseMixin , Pending ): ...
146+ class AsyncAwaitableThrottled (AsyncAwaitableTaskResponseMixin , Throttled ): ...
147+ class AsyncAwaitableCancelled (AsyncAwaitableTaskResponseMixin , Cancelled ): ...
148+ class AsyncAwaitableRunning (AsyncAwaitableTaskResponseMixin , Running ): ...
149+ class AsyncAwaitableFailed (AsyncAwaitableTaskResponseMixin , Failed ): ...
150+ class AsyncAwaitableSucceeded (AsyncAwaitableTaskResponseMixin , Succeeded ): ...
151+
152+ AsyncAwaitableTaskRetrieveResponse : TypeAlias = Annotated [
153+ Union [AsyncAwaitablePending , AsyncAwaitableThrottled , AsyncAwaitableCancelled , AsyncAwaitableRunning , AsyncAwaitableFailed , AsyncAwaitableSucceeded ],
154+ PropertyInfo (discriminator = "status" )
155+ ]
156+
157+ def _make_sync_wait_for_task_output (client : "RunwayML" ) -> Callable [["AwaitableTaskResponseMixin" , Union [float , None ]], TaskRetrieveResponse ]:
158+ """Create a wait_for_task_output method bound to the given client."""
159+ def wait_for_task_output (self : "AwaitableTaskResponseMixin" , timeout : Union [float , None ] = 60 * 10 ) -> TaskRetrieveResponse :
160+ start_time = time .time ()
161+ while True :
162+ time .sleep (POLL_TIME + random .random () * POLL_JITTER - POLL_JITTER / 2 )
163+ task_details = client .tasks .retrieve (self .id ) # type: ignore[attr-defined]
164+ if task_details .status == "SUCCEEDED" :
165+ return task_details
166+ if task_details .status == "FAILED" :
167+ raise TaskFailedError (task_details )
168+ if timeout is not None and time .time () - start_time > timeout :
169+ raise TaskTimeoutError (task_details )
170+ return wait_for_task_output
171+
172+
173+ def inject_sync_wait_method (client : "RunwayML" , response : T ) -> T :
174+ """Inject the wait_for_task_output method onto the response instance."""
175+ import types
176+ response .wait_for_task_output = types .MethodType (_make_sync_wait_for_task_output (client ), response ) # type: ignore[attr-defined]
177+ return response
178+
179+
180+ def _make_async_wait_for_task_output (client : "AsyncRunwayML" ) -> Callable [["AsyncAwaitableTaskResponseMixin" , Union [float , None ]], Coroutine [None , None , TaskRetrieveResponse ]]:
181+ """Create an async wait_for_task_output method bound to the given client."""
182+ async def wait_for_task_output (self : "AsyncAwaitableTaskResponseMixin" , timeout : Union [float , None ] = 60 * 10 ) -> TaskRetrieveResponse :
183+ start_time = anyio .current_time ()
184+ while True :
185+ await anyio .sleep (POLL_TIME + random .random () * POLL_JITTER - POLL_JITTER / 2 )
186+ task_details = await client .tasks .retrieve (self .id ) # type: ignore[attr-defined]
187+ if task_details .status == "SUCCEEDED" :
188+ return task_details
189+ if task_details .status == "FAILED" or task_details .status == "CANCELLED" :
190+ raise TaskFailedError (task_details )
191+ if timeout is not None and anyio .current_time () - start_time > timeout :
192+ raise TaskTimeoutError (task_details )
193+ return wait_for_task_output
194+
195+
196+ def inject_async_wait_method (client : "AsyncRunwayML" , response : T ) -> T :
197+ """Inject the async wait_for_task_output method onto the response instance."""
198+ import types
199+ response .wait_for_task_output = types .MethodType (_make_async_wait_for_task_output (client ), response ) # type: ignore[attr-defined]
200+ return response
0 commit comments