From 2c2f00d4172dc338867af76132c3427ddfad977f Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Tue, 7 Oct 2025 22:46:14 +0000 Subject: [PATCH 1/2] feat: Add retries to client.py --- src/art/client.py | 78 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index fe7dbdde..b1d4e5f5 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -1,7 +1,20 @@ +import inspect import os -from typing import Any, Iterable, Literal, TypedDict, cast +from typing import ( + Any, + Awaitable, + Callable, + Iterable, + Literal, + ParamSpec, + TypedDict, + TypeVar, + cast, + overload, +) import httpx +import tenacity from openai import AsyncOpenAI, BaseModel, _exceptions from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options from openai._compat import cached_property @@ -11,12 +24,68 @@ from openai._utils import is_mapping, maybe_transform from openai._version import __version__ from openai.pagination import AsyncCursorPage -from openai.resources.files import AsyncFiles # noqa: F401 -from openai.resources.models import AsyncModels # noqa: F401 from typing_extensions import override from .trajectories import TrajectoryGroup +P = ParamSpec("P") +R = TypeVar("R") + + +@overload +def retry_status_codes(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ... + + +@overload +def retry_status_codes(fn: Callable[P, R]) -> Callable[P, R]: ... + + +def retry_status_codes( + fn: Callable[P, R] | Callable[P, Awaitable[R]], +) -> Callable[P, R] | Callable[P, Awaitable[R]]: + def _is_retryable_status(exc: BaseException) -> bool: + if isinstance(exc, _exceptions.APIStatusError): + response = exc.response + if response is not None: + status = response.status_code + return status in {429, *range(500, 600)} + return False + + if inspect.iscoroutinefunction(fn): + async_fn = cast(Callable[P, Awaitable[R]], fn) + + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async for attempt in tenacity.AsyncRetrying( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_random_exponential(multiplier=0.5, max=2.0), + retry=tenacity.retry_if_exception(_is_retryable_status), + reraise=True, + ): + with attempt: + return await async_fn(*args, **kwargs) + + # Unreachable if tenacity produces at least one attempt + raise RuntimeError("retry attempt sequence unexpectedly exhausted") + + return async_wrapper + + sync_fn = cast(Callable[P, R], fn) + + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + for attempt in tenacity.Retrying( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_random_exponential(multiplier=0.5, max=2.0), + retry=tenacity.retry_if_exception(_is_retryable_status), + reraise=True, + ): + with attempt: + return sync_fn(*args, **kwargs) + + # Unreachable if tenacity produces at least one attempt + raise RuntimeError("retry attempt sequence unexpectedly exhausted") + + return sync_wrapper + class Model(BaseModel): id: str @@ -113,6 +182,7 @@ def checkpoints(self) -> "Checkpoints": class Checkpoints(AsyncAPIResource): + @retry_status_codes def list( self, *, @@ -137,6 +207,7 @@ def list( model=Checkpoint, ) + @retry_status_codes async def delete( self, *, model_id: str, steps: Iterable[int] ) -> DeleteCheckpointsResponse: @@ -174,6 +245,7 @@ def events(self) -> "TrainingJobEvents": class TrainingJobEvents(AsyncAPIResource): + @retry_status_codes def list( self, *, From 3678d431b5e0c6e148bb2f9f7040e442c19dbea9 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Wed, 8 Oct 2025 00:15:15 +0000 Subject: [PATCH 2/2] chore: Support for AsyncPaginator --- src/art/client.py | 78 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index b1d4e5f5..1eded4e5 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -2,6 +2,7 @@ import os from typing import ( Any, + AsyncIterable, Awaitable, Callable, Iterable, @@ -30,6 +31,13 @@ P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") + + +@overload +def retry_status_codes( + fn: Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]], +) -> Callable[P, AsyncIterable[R]]: ... @overload @@ -41,9 +49,13 @@ def retry_status_codes(fn: Callable[P, R]) -> Callable[P, R]: ... def retry_status_codes( - fn: Callable[P, R] | Callable[P, Awaitable[R]], -) -> Callable[P, R] | Callable[P, Awaitable[R]]: - def _is_retryable_status(exc: BaseException) -> bool: + fn: ( + Callable[P, R] + | Callable[P, Awaitable[R]] + | Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]] + ), +) -> Callable[P, R | AsyncIterable[R]] | Callable[P, Awaitable[R]]: + def is_retryable_status(exc: BaseException) -> bool: if isinstance(exc, _exceptions.APIStatusError): response = exc.response if response is not None: @@ -51,40 +63,64 @@ def _is_retryable_status(exc: BaseException) -> bool: return status in {429, *range(500, 600)} return False + stop = tenacity.stop_after_attempt(3) + wait = tenacity.wait_random_exponential(multiplier=0.5, max=2.0) + retry = tenacity.retry_if_exception(is_retryable_status) + reraise = True + + async def retrying_awaitable(awaitable_fn: Callable[[], Awaitable[T]]) -> T: + async for attempt in tenacity.AsyncRetrying( + stop=stop, + wait=wait, + retry=retry, + reraise=reraise, + ): + with attempt: + return await awaitable_fn() + + # Unreachable if tenacity produces at least one attempt + raise RuntimeError("retry attempt sequence unexpectedly exhausted") + if inspect.iscoroutinefunction(fn): async_fn = cast(Callable[P, Awaitable[R]], fn) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - async for attempt in tenacity.AsyncRetrying( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_random_exponential(multiplier=0.5, max=2.0), - retry=tenacity.retry_if_exception(_is_retryable_status), - reraise=True, - ): - with attempt: - return await async_fn(*args, **kwargs) - - # Unreachable if tenacity produces at least one attempt - raise RuntimeError("retry attempt sequence unexpectedly exhausted") + return await retrying_awaitable(lambda: async_fn(*args, **kwargs)) return async_wrapper + async def retrying_async_iterable( + async_paginator: AsyncPaginator[R, AsyncCursorPage[R]], + ) -> AsyncIterable[R]: + page = await retrying_awaitable(lambda: async_paginator) + for item in page._get_page_items(): + yield item + while page.has_next_page(): + page = await retrying_awaitable(lambda: page.get_next_page()) + for item in page._get_page_items(): + yield item + sync_fn = cast(Callable[P, R], fn) - def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + def sync_or_async_paginator_wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> R | AsyncIterable[R]: for attempt in tenacity.Retrying( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_random_exponential(multiplier=0.5, max=2.0), - retry=tenacity.retry_if_exception(_is_retryable_status), - reraise=True, + stop=stop, + wait=wait, + retry=retry, + reraise=reraise, ): with attempt: - return sync_fn(*args, **kwargs) + result = sync_fn(*args, **kwargs) + if isinstance(result, AsyncPaginator): + return retrying_async_iterable(result) + return result # Unreachable if tenacity produces at least one attempt raise RuntimeError("retry attempt sequence unexpectedly exhausted") - return sync_wrapper + return sync_or_async_paginator_wrapper class Model(BaseModel):