diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 06a7603cc0..ee4733de74 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
-import copy
 import json
 import logging
 import os
@@ -400,7 +399,8 @@ def change_bias(
         old_state_dict = torch.load(
             input_file, map_location=env.DEVICE, weights_only=True
         )
-        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
+        # Only copy model_params, not the entire state dict, to avoid memory bloat
+        model_state_dict = old_state_dict.get("model", old_state_dict)
         model_params = model_state_dict["_extra_state"]["model_params"]
     elif input_file.endswith(".pth"):
         old_model = torch.jit.load(input_file, map_location=env.DEVICE)
@@ -495,14 +495,49 @@ def change_bias(
         output_path = (
             output if output is not None else input_file.replace(".pt", "_updated.pt")
         )
-        wrapper = ModelWrapper(model)
-        if "model" in old_state_dict:
-            old_state_dict["model"] = wrapper.state_dict()
-            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
+        if multi_task:
+            # For multi-task models, save only the selected branch as a single-head model
+            single_head_model = updated_model
+            wrapper = ModelWrapper(single_head_model)
+
+            # Create single-head model parameters
+            single_head_params = model_params["model_dict"][model_branch].copy()
+
+            # Save only the selected branch with single-head structure
+            if "model" in old_state_dict:
+                # For multi-task models, don't include optimizer state to reduce file size
+                state_to_save = {
+                    "model": wrapper.state_dict(),
+                }
+                # Update the model's extra state to reflect single-head parameters
+                state_to_save["model"]["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(state_to_save, output_path)
+            else:
+                state_to_save = wrapper.state_dict()
+                state_to_save["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(state_to_save, output_path)
         else:
-            old_state_dict = wrapper.state_dict()
-            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
-        torch.save(old_state_dict, output_path)
+            # For single-task models, keep existing behavior
+            wrapper = ModelWrapper(model)
+            if "model" in old_state_dict:
+                old_state_dict["model"] = wrapper.state_dict()
+                old_state_dict["model"]["_extra_state"] = model_state_dict[
+                    "_extra_state"
+                ]
+            else:
+                old_state_dict = wrapper.state_dict()
+                old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
+            torch.save(old_state_dict, output_path)
     else:
         # for .pth
         output_path = (
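
For reference, below is a minimal sketch (not part of the patch) of how a checkpoint written by the new multi-task branch could be inspected. It only assumes the layout produced by the torch.save(state_to_save, output_path) calls above; the checkpoint path is a hypothetical example.

    # Sketch: inspect a single-head checkpoint exported by the multi-task branch.
    # Assumes the layout written above: a dict with a "model" key (or a bare
    # state dict) whose "_extra_state" carries "model_params" and "train_infos".
    import torch

    ckpt_path = "model_updated.pt"  # hypothetical path to the converted checkpoint

    state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model_state = state.get("model", state)  # handle both save layouts
    extra = model_state["_extra_state"]

    # The exported params describe a single head: the multi-task "model_dict"
    # wrapper is gone, because only the selected branch was saved.
    assert "model_dict" not in extra["model_params"]
    print("train_infos:", extra["train_infos"])  # e.g. {"lr": 0.001, "step": 0}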