From e6e7a8676fe7257d0530dbfe8e8b87221cb087f5 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 30 Sep 2025 14:57:12 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 813440675 --- pathwaysutils/_initialize.py | 4 +++- pathwaysutils/persistence/orbax_handler.py | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pathwaysutils/_initialize.py b/pathwaysutils/_initialize.py index 0e42e59..a8df2a0 100644 --- a/pathwaysutils/_initialize.py +++ b/pathwaysutils/_initialize.py @@ -92,7 +92,9 @@ def initialize() -> None: profiling.monkey_patch_jax() # TODO: b/365549911 - Remove when OCDBT-compatible if _is_persistence_enabled(): - orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1)) + orbax_handler.register_pathways_handlers( + timeout=datetime.timedelta(hours=1), + ) # Turn off JAX compilation cache because Pathways handles its own # compilation cache. diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index c0e72e8..89b908b 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -49,16 +49,18 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler): def __init__( self, - read_timeout: datetime.timedelta | None = None, + timeout: datetime.timedelta | None = None, use_ocdbt: bool = False, ): - """Constructor. + """Orbax array handler for Pathways on Cloud with Persistence API. Args: - read_timeout: Duration indicating the timeout for reading arrays + timeout: Duration indicating the timeout for reading and writing arrays use_ocdbt: allows using Tensorstore OCDBT driver. """ - self._read_timeout = read_timeout + if timeout is None: + timeout = datetime.timedelta(hours=1) + self.timeout = timeout if use_ocdbt: raise ValueError("OCDBT not supported for Pathways.") @@ -92,7 +94,7 @@ async def serialize( self._wait_for_directory_creation_signals() locations, names = extract_parent_dir_and_name(infos) - f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + f = functools.partial(helper.write_one_array, timeout=self.timeout) futures_results = list(map(f, locations, names, values)) return [ @@ -181,7 +183,7 @@ async def deserialize( grouped_global_shapes, grouped_shardings, global_mesh.devices, - timeout=self._read_timeout, + timeout=self.timeout, ) # each persistence call is awaited serially. read_future.result() @@ -191,7 +193,7 @@ async def deserialize( def register_pathways_handlers( - read_timeout: datetime.timedelta | None = None, + timeout: datetime.timedelta | None = None, ): """Function that must be called before saving or restoring with Pathways.""" logger.debug( @@ -200,7 +202,7 @@ def register_pathways_handlers( type_handlers.register_type_handler( jax.Array, CloudPathwaysArrayHandler( - read_timeout=read_timeout, + timeout=timeout, ), override=True, )