53 changes: 44 additions & 9 deletions deepmd/pt/entrypoints/main.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
-import copy
 import json
 import logging
 import os
@@ -400,7 +399,8 @@ def change_bias(
         old_state_dict = torch.load(
             input_file, map_location=env.DEVICE, weights_only=True
         )
-        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
+        # Only copy model_params, not the entire state dict, to avoid memory bloat
+        model_state_dict = old_state_dict.get("model", old_state_dict)
         model_params = model_state_dict["_extra_state"]["model_params"]
     elif input_file.endswith(".pth"):
         old_model = torch.jit.load(input_file, map_location=env.DEVICE)
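The hunk above drops the copy.deepcopy of the loaded checkpoint in favor of a plain reference, since only the branch parameters are copied later, and only when actually needed. A minimal sketch of the difference, assuming a checkpoint dict shaped roughly like the one change_bias loads (the keys and tensor shapes here are illustrative, not deepmd's real layout):

import copy

import torch

# Illustrative checkpoint layout; the real file also carries optimizer state.
checkpoint = {
    "model": {
        "some_weight": torch.zeros(1000, 1000),
        "_extra_state": {"model_params": {"type_map": ["O", "H"]}},
    },
}

# Old behavior: deepcopy duplicates every tensor under "model".
duplicated = copy.deepcopy(checkpoint.get("model", checkpoint))

# New behavior: a reference shares storage with the loaded checkpoint;
# nothing is duplicated until a sub-dict is explicitly .copy()-ed.
shared = checkpoint.get("model", checkpoint)
assert shared is checkpoint["model"]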
@@ -495,14 +495,49 @@ def change_bias(
         output_path = (
             output if output is not None else input_file.replace(".pt", "_updated.pt")
         )
-        wrapper = ModelWrapper(model)
-        if "model" in old_state_dict:
-            old_state_dict["model"] = wrapper.state_dict()
-            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
+        if multi_task:
+            # For multi-task models, save only the selected branch as a single-head model
+            single_head_model = updated_model
+            wrapper = ModelWrapper(single_head_model)
+
+            # Create single-head model parameters
+            single_head_params = model_params["model_dict"][model_branch].copy()
+
+            # Save only the selected branch with single-head structure
+            if "model" in old_state_dict:
+                # For multi-task models, don't include optimizer state to reduce file size
+                state_to_save = {
+                    "model": wrapper.state_dict(),
+                }
+                # Update the model's extra state to reflect single-head parameters
+                state_to_save["model"]["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(state_to_save, output_path)
+            else:
+                state_to_save = wrapper.state_dict()
+                state_to_save["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(state_to_save, output_path)
         else:
-            old_state_dict = wrapper.state_dict()
-            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
-        torch.save(old_state_dict, output_path)
+            # For single-task models, keep existing behavior
+            wrapper = ModelWrapper(model)
+            if "model" in old_state_dict:
+                old_state_dict["model"] = wrapper.state_dict()
+                old_state_dict["model"]["_extra_state"] = model_state_dict[
+                    "_extra_state"
+                ]
+            else:
+                old_state_dict = wrapper.state_dict()
+                old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
+            torch.save(old_state_dict, output_path)
     else:
         # for .pth
         output_path = (
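For the multi-task path, the hunk re-roots the selected branch's sub-dict from model_params["model_dict"] as a standalone single-head model_params before saving. A hedged sketch of that re-rooting, with invented branch names and nested keys standing in for deepmd's real schema:

import torch

# Hypothetical multi-task layout; "water"/"metal" and the nested keys are
# placeholders, not deepmd's actual configuration schema.
model_params = {
    "model_dict": {
        "water": {"descriptor": {"type": "se_e2_a"}, "fitting_net": {"type": "ener"}},
        "metal": {"descriptor": {"type": "se_e2_a"}, "fitting_net": {"type": "ener"}},
    }
}

model_branch = "water"
# A shallow copy suffices: the branch sub-dict becomes the new top-level
# model_params, and the multi-task "model_dict" wrapper is left behind.
single_head_params = model_params["model_dict"][model_branch].copy()

state_to_save = {
    "model": {
        # wrapper.state_dict() supplies the tensors in the real code path.
        "_extra_state": {
            "model_params": single_head_params,
            # Same fallback the diff uses when no train_infos are present.
            "train_infos": {"lr": 0.001, "step": 0},
        },
    }
}
torch.save(state_to_save, "model_branch_updated.pt")

Note that the saved checkpoint deliberately carries no optimizer state, matching the diff's comment about keeping the single-head output file small.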