@@ -849,28 +849,30 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
849849 logger .info ("Create zero cost checkpoint manager done." )
850850
851851 def _load_flex_checkpoint (self , resume_from_checkpoint ):
852+ def get_metadata_file_name (path ):
853+ files = os .listdir (path )
854+ metadata_files = [f for f in files if f .endswith (".metadata" )]
855+ assert len (metadata_files ) > 0 , f"Found no metadata files in { path } "
856+ assert len (metadata_files ) == 1 , f"Found multiple metadata files in { path } "
857+ return metadata_files [0 ]
858+
852859 model_sharded_state_dict = self .model .sharded_state_dict ()
853860 master_weights_path = os .path .join (resume_from_checkpoint , MASTER_WEIGHT_DIC )
854861 opt_states_path = os .path .join (resume_from_checkpoint , OPTIMIZER_STATE_DIC )
855862 model_states_path = os .path .join (resume_from_checkpoint , MODEL_STATE_DIC )
856863 if not self .args .ignore_load_lr_and_optim :
857864 state_dict_metadata = {}
858865 metadata_paths = [
859- os .path .join (model_states_path , "0.metadata" ),
860- os .path .join (opt_states_path , "0.metadata" ),
861- os .path .join (master_weights_path , "0.metadata" ),
866+ os .path .join (model_states_path , get_metadata_file_name ( model_states_path ) ),
867+ os .path .join (opt_states_path , get_metadata_file_name ( opt_states_path ) ),
868+ os .path .join (master_weights_path , get_metadata_file_name ( master_weights_path ) ),
862869 ]
863870
864871 for metadata_file in metadata_paths :
865872 if not os .path .exists (metadata_file ):
866873 raise FileNotFoundError (f"Metadata file not found: { metadata_file } " )
867874 metadata = paddle .load (metadata_file )
868- if hasattr (metadata , "state_dict_metadata" ):
869- state_dict_metadata .update (metadata .state_dict_metadata )
870- else :
871- raise AttributeError (
872- f"Loaded metadata from { metadata_file } does not have 'state_dict_metadata' attribute"
873- )
875+ state_dict_metadata .update (metadata .state_dict_metadata )
874876
875877 init_optimizer (self .optimizer , model_sharded_state_dict , state_dict_metadata )
876878
@@ -915,13 +917,9 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
915917 )
916918
917919 optimizer_state_pin = {}
918-
919920 for k , v in opt_states .items ():
920- tmp = v .local_tensor
921- optimizer_state_pin [k ] = tmp .pin_memory ()
922- tmp ._clear_to_zero_allocation ()
923- del tmp
924-
921+ optimizer_state_pin [k ] = v .local_tensor .pin_memory ()
922+ del opt_states
925923 for k , v in master_weights .items ():
926924 new_v = ShardedWeight (
927925 key = v .key ,
@@ -941,21 +939,16 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
941939 )
942940
943941 master_weights_pin = {}
944-
945942 for k , v in master_weights .items ():
946- tmp = v .local_tensor
947- master_weights_pin [k ] = tmp .pin_memory ()
948- tmp ._clear_to_zero_allocation ()
949- del tmp
943+ master_weights_pin [k ] = v .local_tensor .pin_memory ()
944+ del master_weights
950945
951946 optimizer_sharded_state_dict = self .optimizer .sharded_state_dict (model_sharded_state_dict )
952-
953947 optimizer_sharded_state_dict_pin = {** master_weights_pin , ** optimizer_state_pin }
954948
955949 for k , v in optimizer_sharded_state_dict .items ():
956950 source_tensor = optimizer_sharded_state_dict_pin [k ]
957- source_tensor ._share_buffer_to (v .local_tensor )
958- del source_tensor
951+ v .local_tensor .set_value (source_tensor )
959952
960953 if isinstance (self .optimizer ._inner_opt , DygraphShardingOptimizerV2 ):
961954 color_to_comm_buffer_list = self .optimizer ._color_to_comm_buffer_list
@@ -966,7 +959,7 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
966959 state_dict = self .model .state_dict ()
967960 for k , v in state_dict .items ():
968961 new_v = paddle .zeros_like (v )
969- new_v . _share_buffer_to ( v )
962+ v . set_value ( new_v )
970963
971964 self ._load_scheduler (resume_from_checkpoint )
972965
0 commit comments