Skip to content

Commit e1dcffb

Browse files
committed
fix load_flex_checkpoint
1 parent 1d439b0 commit e1dcffb

File tree

1 file changed

+17
-24
lines changed

1 file changed

+17
-24
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 17 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -849,28 +849,30 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
849849
logger.info("Create zero cost checkpoint manager done.")
850850

851851
def _load_flex_checkpoint(self, resume_from_checkpoint):
852+
def get_metadata_file_name(path):
    """Return the name of the only ``.metadata`` file directly inside ``path``.

    Args:
        path: Directory to scan (non-recursive).

    Returns:
        The bare file name (not joined with ``path``) of the single entry
        whose name ends with ``.metadata``.

    Raises:
        FileNotFoundError: If ``path`` contains no entry ending with
            ``.metadata``.
        ValueError: If ``path`` contains more than one such entry, so the
            choice would be ambiguous.
    """
    metadata_files = [f for f in os.listdir(path) if f.endswith(".metadata")]
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would turn an empty directory into a bare
    # IndexError and silently pick an arbitrary file when several exist.
    if not metadata_files:
        raise FileNotFoundError(f"Found no metadata files in {path}")
    if len(metadata_files) > 1:
        raise ValueError(f"Found multiple metadata files in {path}")
    return metadata_files[0]
858+
852859
model_sharded_state_dict = self.model.sharded_state_dict()
853860
master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
854861
opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
855862
model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)
856863
if not self.args.ignore_load_lr_and_optim:
857864
state_dict_metadata = {}
858865
metadata_paths = [
859-
os.path.join(model_states_path, "0.metadata"),
860-
os.path.join(opt_states_path, "0.metadata"),
861-
os.path.join(master_weights_path, "0.metadata"),
866+
os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
867+
os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
868+
os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
862869
]
863870

864871
for metadata_file in metadata_paths:
865872
if not os.path.exists(metadata_file):
866873
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
867874
metadata = paddle.load(metadata_file)
868-
if hasattr(metadata, "state_dict_metadata"):
869-
state_dict_metadata.update(metadata.state_dict_metadata)
870-
else:
871-
raise AttributeError(
872-
f"Loaded metadata from {metadata_file} does not have 'state_dict_metadata' attribute"
873-
)
875+
state_dict_metadata.update(metadata.state_dict_metadata)
874876

875877
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
876878

@@ -915,13 +917,9 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
915917
)
916918

917919
optimizer_state_pin = {}
918-
919920
for k, v in opt_states.items():
920-
tmp = v.local_tensor
921-
optimizer_state_pin[k] = tmp.pin_memory()
922-
tmp._clear_to_zero_allocation()
923-
del tmp
924-
921+
optimizer_state_pin[k] = v.local_tensor.pin_memory()
922+
del opt_states
925923
for k, v in master_weights.items():
926924
new_v = ShardedWeight(
927925
key=v.key,
@@ -941,21 +939,16 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
941939
)
942940

943941
master_weights_pin = {}
944-
945942
for k, v in master_weights.items():
946-
tmp = v.local_tensor
947-
master_weights_pin[k] = tmp.pin_memory()
948-
tmp._clear_to_zero_allocation()
949-
del tmp
943+
master_weights_pin[k] = v.local_tensor.pin_memory()
944+
del master_weights
950945

951946
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
952-
953947
optimizer_sharded_state_dict_pin = {**master_weights_pin, **optimizer_state_pin}
954948

955949
for k, v in optimizer_sharded_state_dict.items():
956950
source_tensor = optimizer_sharded_state_dict_pin[k]
957-
source_tensor._share_buffer_to(v.local_tensor)
958-
del source_tensor
951+
v.local_tensor.set_value(source_tensor)
959952

960953
if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2):
961954
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
@@ -966,7 +959,7 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
966959
state_dict = self.model.state_dict()
967960
for k, v in state_dict.items():
968961
new_v = paddle.zeros_like(v)
969-
new_v._share_buffer_to(v)
962+
v.set_value(new_v)
970963

971964
self._load_scheduler(resume_from_checkpoint)
972965

0 commit comments

Comments
 (0)