diff --git a/gemma/gm/ckpts/_checkpoint.py b/gemma/gm/ckpts/_checkpoint.py index 82e9c8ae..3455972c 100644 --- a/gemma/gm/ckpts/_checkpoint.py +++ b/gemma/gm/ckpts/_checkpoint.py @@ -182,6 +182,7 @@ def load_params( sharding: kd.sharding.ShardingTree | None = None, quantize: bool = False, use_ocdbt: bool = True, + block_until_ready: bool = False, ) -> Params: """Restore the params from a checkpoint. @@ -251,6 +252,9 @@ def load_params( output_with_skip = metadata.make_tree_for_params(params) restore_fn = functools.partial(ckpt.restore, path) output = _partial_restore(restore_fn, output_with_skip) + if block_until_ready: + output.block_until_ready() + ckpt.wait_until_finished() # TODO(epot): Better API. Currently this do not quantize the weights, but # just refactor the params to the QAT structure.