diff --git a/src/art/client.py b/src/art/client.py index fe7dbdde..1eded4e5 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -1,7 +1,21 @@ +import inspect import os -from typing import Any, Iterable, Literal, TypedDict, cast +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Literal, + ParamSpec, + TypedDict, + TypeVar, + cast, + overload, +) import httpx +import tenacity from openai import AsyncOpenAI, BaseModel, _exceptions from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options from openai._compat import cached_property @@ -11,12 +25,103 @@ from openai._utils import is_mapping, maybe_transform from openai._version import __version__ from openai.pagination import AsyncCursorPage -from openai.resources.files import AsyncFiles # noqa: F401 -from openai.resources.models import AsyncModels # noqa: F401 from typing_extensions import override from .trajectories import TrajectoryGroup +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +@overload +def retry_status_codes( + fn: Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]], +) -> Callable[P, AsyncIterable[R]]: ... + + +@overload +def retry_status_codes(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ... + + +@overload +def retry_status_codes(fn: Callable[P, R]) -> Callable[P, R]: ... + + +def retry_status_codes( + fn: ( + Callable[P, R] + | Callable[P, Awaitable[R]] + | Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]] + ), +) -> Callable[P, R | AsyncIterable[R]] | Callable[P, Awaitable[R]]: + def is_retryable_status(exc: BaseException) -> bool: + if isinstance(exc, _exceptions.APIStatusError): + response = exc.response + if response is not None: + status = response.status_code + return status in {429, *range(500, 600)} + return False + + stop = tenacity.stop_after_attempt(3) + wait = tenacity.wait_random_exponential(multiplier=0.5, max=2.0) + retry = tenacity.retry_if_exception(is_retryable_status) + reraise = True + + async def retrying_awaitable(awaitable_fn: Callable[[], Awaitable[T]]) -> T: + async for attempt in tenacity.AsyncRetrying( + stop=stop, + wait=wait, + retry=retry, + reraise=reraise, + ): + with attempt: + return await awaitable_fn() + + # Unreachable if tenacity produces at least one attempt + raise RuntimeError("retry attempt sequence unexpectedly exhausted") + + if inspect.iscoroutinefunction(fn): + async_fn = cast(Callable[P, Awaitable[R]], fn) + + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return await retrying_awaitable(lambda: async_fn(*args, **kwargs)) + + return async_wrapper + + async def retrying_async_iterable( + async_paginator: AsyncPaginator[R, AsyncCursorPage[R]], + ) -> AsyncIterable[R]: + page = await retrying_awaitable(lambda: async_paginator) + for item in page._get_page_items(): + yield item + while page.has_next_page(): + page = await retrying_awaitable(lambda: page.get_next_page()) + for item in page._get_page_items(): + yield item + + sync_fn = cast(Callable[P, R], fn) + + def sync_or_async_paginator_wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> R | AsyncIterable[R]: + for attempt in tenacity.Retrying( + stop=stop, + wait=wait, + retry=retry, + reraise=reraise, + ): + with attempt: + result = sync_fn(*args, **kwargs) + if isinstance(result, AsyncPaginator): + return retrying_async_iterable(result) + return result + + # Unreachable if tenacity produces at least one attempt + raise RuntimeError("retry attempt sequence unexpectedly exhausted") + + return sync_or_async_paginator_wrapper + class Model(BaseModel): id: str @@ -113,6 +218,7 @@ def checkpoints(self) -> "Checkpoints": class Checkpoints(AsyncAPIResource): + @retry_status_codes def list( self, *, @@ -137,6 +243,7 @@ def list( model=Checkpoint, ) + @retry_status_codes async def delete( self, *, model_id: str, steps: Iterable[int] ) -> DeleteCheckpointsResponse: @@ -174,6 +281,7 @@ def events(self) -> "TrainingJobEvents": class TrainingJobEvents(AsyncAPIResource): + @retry_status_codes def list( self, *,