11import tqdm
22import contextlib
3- from compressed_tensors .utils import replace_module , match_named_modules
3+ from compressed_tensors .utils import replace_module , delete_offload_module , register_offload_module
44from transformers import PreTrainedModel
55
66from llmcompressor .modeling .deepseek_v3 import replace as replace_deepseekv3
@@ -47,22 +47,26 @@ def update_qwen3_moe(model, stack):
4747
def update_gpt_oss_moe(model: PreTrainedModel, stack):
    """Temporarily swap every ``GptOssExperts`` submodule for a linearized form.

    Each module in ``model`` is scanned for direct children whose class name is
    ``GptOssExperts``. For every match, a context manager is entered on
    ``stack`` that registers a ``GptOssExpertsLinear`` wrapper on the parent
    under the same child name, and registers the module returned by
    ``to_original()`` again when the context exits.

    :param model: model whose module tree is searched for expert modules
    :param stack: object exposing ``enter_context`` (presumably a
        ``contextlib.ExitStack`` owned by the caller -- confirm at call site);
        the replacements last until that stack is closed
    """

    @contextlib.contextmanager
    def replace_context(parent, name, module):
        """Swap ``parent.<name>`` for its linear form for the context's duration."""
        linear = GptOssExpertsLinear(module)
        # Drop the local reference before the parent's registration is removed.
        del module
        delete_offload_module(parent, name)
        register_offload_module(parent, name, linear)

        try:
            yield
        finally:
            # Always restore, even if the managed body raised; otherwise the
            # model would be left holding the temporary linear module.
            restored = linear.to_original()
            del linear
            delete_offload_module(parent, name)
            register_offload_module(parent, name, restored)

    # TODO: need to consider when replace module is duplicated in structure
    # Snapshot the module tree up front: the contexts entered below mutate
    # parent/child registrations while we are still walking it.
    modules = list(model.named_modules())
    for _, module in tqdm.tqdm(modules, desc="Checking modules for replacements"):
        children = list(module.named_children())
        for child_name, child in children:
            if child.__class__.__name__ == "GptOssExperts":
                stack.enter_context(replace_context(module, child_name, child))
6670
6771
6872
@@ -78,3 +82,27 @@ def moe_calibration_context(model: PreTrainedModel, stack):
7882 cls_name = model .__class__ .__name__
7983 if cls_name in moe_context :
8084 moe_context .get (cls_name )(model , stack )
85+
86+
87+
88+ # import torch
89+ # from accelerate.hooks import (
90+ # AlignDevicesHook,
91+ # def replace_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
92+ # hook = getattr(base, name)._hf_hook
93+ # delete_offload_module(base, name)
94+
95+ # weights_map = PrefixedDataset(
96+ # hook.weights_map.dataset, prefix=hook.weights_map.prefix.removesuffix(name + ".")
97+ # )
98+
99+ # parent_hook = AlignDevicesHook(
100+ # execution_device=hook.execution_device,
101+ # offload=hook.offload,
102+ # io_same_device=False,
103+ # weights_map=weights_map,
104+ # offload_buffers=offload_buffers,
105+ # place_submodules=place_submodules,
106+ # skip_keys=None,
107+ # tied_params_map=hook.tied_params_map,
108+ # )
0 commit comments