From 7fd6dbd0db356ca2f2fbdf6c7d590eab0dc8a7d2 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 17 Sep 2025 14:01:11 +0000 Subject: [PATCH 1/8] Add file to linearize and quantize the gpt-oss models --- src/llmcompressor/modeling/gpt_oss.py | 240 ++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 src/llmcompressor/modeling/gpt_oss.py diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py new file mode 100644 index 000000000..eae925978 --- /dev/null +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -0,0 +1,240 @@ +import gc +import torch +from torch import nn +import os +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.utils.dev import skip_weights_initialize +from llmcompressor.modifiers.quantization import QuantizationModifier + +def convert_model_for_quantization_gptoss(model): + to_delete = [] + + for name, module in model.named_modules(): + if not (hasattr(module, "experts") and hasattr(module, "router")): + continue + experts = module.experts + if not (hasattr(experts, "gate_up_proj") and hasattr(experts, "down_proj")): + continue + + gup = experts.gate_up_proj # [E, H, 2I] + dwn = experts.down_proj # [E, I, H] + assert gup.dim() == 3 and dwn.dim() == 3 + E, gH, g2i = gup.shape + Ed, dI, dH = dwn.shape + assert E == Ed and gH == dH + assert g2i % 2 == 0 + intermediate = g2i // 2 + hidden = gH + + parent, child_name = _get_parent_and_child(model, name) + top_k = int(max(1, min(_get_top_k(model.config) or 1, E))) + seq = SequentialGPTOSSMoE( + hidden_size=hidden, + intermediate_size=intermediate, + top_k=top_k, + original_moe=module, + ) + parent._modules[child_name] = seq + to_delete.append(module) + print(f"[GPT-OSS] Patched {name} -> SequentialGPTOSSMoE (E={E}, inter={intermediate}, hidden={hidden})", flush=True) + + for m in to_delete: + del m + if to_delete: + gc.collect() + try: + torch.cuda.empty_cache() + except Exception: + pass + + +def _get_parent_and_child(model, dotted_name: str): + parts = dotted_name.split(".") + parent = model + for p in parts[:-1]: + parent = getattr(parent, p) + return parent, parts[-1] + + +def _get_hidden_size(config): + return getattr(config, "hidden_size", None) or getattr(config, "n_embd", None) + + +def _get_top_k(config): + # GPT-OSS MoE: experts per token + return getattr(config, "num_experts_per_tok", None) or getattr(config, "num_experts_per_token", 1) + + +class GPTOSSMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.alpha = 1.702 + self.limit = 7.0 + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True) + + def forward(self, x): + gate = self.gate_proj(x) + up = self.up_proj(x) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + act = (up + 1) * glu + return self.down_proj(act) + + +class SequentialGPTOSSMoE(nn.Module): + """ + Replaces GPT-OSS fused-expert MoE with per-expert GPTOSSMLP modules. + Copies weights from fused tensors and reuses the original router and optional shared_expert. + """ + def __init__(self, hidden_size, intermediate_size, top_k, original_moe): + super().__init__() + self.hidden_size = hidden_size + self.intermediate = intermediate_size + self.top_k = top_k + self.router = original_moe.router + self.shared_expert = getattr(original_moe, "shared_expert", None) + + # Number of experts + E = original_moe.experts.gate_up_proj.shape[0] + self.num_experts = E + + # Build per-expert MLPs + self.experts = nn.ModuleList() + with skip_weights_initialize(): + for _ in range(E): + self.experts.append(GPTOSSMLP(hidden_size, intermediate_size)) + + gup = original_moe.experts.gate_up_proj # [E, H, 2I] + gup_b = original_moe.experts.gate_up_proj_bias # [E, 2I] + dwn = original_moe.experts.down_proj # [E, I, H] + dwn_b = original_moe.experts.down_proj_bias # [E, H] + + for i in range(E): + gup_i = gup[i] # [H, 2I] + gate_w = gup_i[:, ::2] # [H, I] + up_w = gup_i[:, 1::2] # [H, I] + down_w = dwn[i] # [I, H] + + mlp = self.experts[i] + mlp.gate_proj.weight.data.copy_(gate_w.T) # [I, H] + mlp.up_proj.weight.data.copy_(up_w.T) # [I, H] + mlp.down_proj.weight.data.copy_(down_w.T) # [H, I] + + gate_b = gup_b[i] # [2I] + mlp.gate_proj.bias.data.copy_(gate_b[::2]) # [I] + mlp.up_proj.bias.data.copy_(gate_b[1::2]) # [I] + mlp.down_proj.bias.data.copy_(dwn_b[i]) # [H] + + + + def forward(self, hidden_states): + B, T, H = hidden_states.shape + x = hidden_states.reshape(-1, H) + + # Use the original router (it returns scores and indices already softmaxed over top-k) + router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k] + + out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x) + + # Accumulate expert outputs for chosen experts only + for j in range(self.top_k): + idx = router_indices[:, j] + w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) + for e in range(self.num_experts): + mask = (idx == e) + if not torch.any(mask): + continue + out[mask] += self.experts[e](x[mask]) * w[mask] + + out = out.view(B, T, H) + router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder + return out, router_scores + + +model_id = "/mnt/nvme4/openai/gpt-oss-120b-BF16" + +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, +) +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + +convert_model_for_quantization_gptoss(model) + +# ----------------------------- +# Calibration data & preprocessing +# ----------------------------- +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 128 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + +ds = ds.map(preprocess) + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# ----------------------------- +# Quantization recipe +# ----------------------------- +recipe = QuantizationModifier( + targets="Linear", + scheme="FP8_DYNAMIC", + ignore=[ + "re:.*lm_head", + 're:.*self_attn', + 're:.*attn', + 're:.*attention.*', + 're:.*router', + ], +) + +SAVE_DIR = f"/proving-grounds/machine/shubhra/gpt_oss_120b/{os.path.basename(model_id)}-{recipe.scheme}_ns{NUM_CALIBRATION_SAMPLES}_fixed_CTgb2df366" + +# Oneshot quantization +oneshot( + model=model, + tokenizer=tokenizer, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=SAVE_DIR, +) + +# Save compressed +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) From 9a685b44e501e6d211fd8a6b79d577c79282a449 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 17 Sep 2025 11:36:58 -0400 Subject: [PATCH 2/8] Update src/llmcompressor/modeling/gpt_oss.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llmcompressor/modeling/gpt_oss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index eae925978..5505a0b6d 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -214,10 +214,10 @@ def tokenize(sample): scheme="FP8_DYNAMIC", ignore=[ "re:.*lm_head", - 're:.*self_attn', - 're:.*attn', - 're:.*attention.*', - 're:.*router', + "re:.*self_attn", + "re:.*attn", + "re:.*attention.*", + "re:.*router", ], ) From 8e11dcbe552b1e12b10cf49551ee8a9c926eb8a9 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 17 Sep 2025 11:38:35 -0400 Subject: [PATCH 3/8] Update src/llmcompressor/modeling/gpt_oss.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llmcompressor/modeling/gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index 5505a0b6d..2a8b71308 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -47,8 +47,8 @@ def convert_model_for_quantization_gptoss(model): gc.collect() try: torch.cuda.empty_cache() - except Exception: - pass + except Exception as e: + print(f"[GPT-OSS] Warning: Failed to empty CUDA cache: {e}", flush=True) def _get_parent_and_child(model, dotted_name: str): From 858ea4851f3ebb583d57fd06c2a9221d09c66dfe Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 17 Sep 2025 11:40:12 -0400 Subject: [PATCH 4/8] Update src/llmcompressor/modeling/gpt_oss.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llmcompressor/modeling/gpt_oss.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index 2a8b71308..d51f1beff 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -148,10 +148,9 @@ def forward(self, hidden_states): for j in range(self.top_k): idx = router_indices[:, j] w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) - for e in range(self.num_experts): + unique_experts = torch.unique(idx) + for e in unique_experts: mask = (idx == e) - if not torch.any(mask): - continue out[mask] += self.experts[e](x[mask]) * w[mask] out = out.view(B, T, H) From 4587569ee3c3b3928be921c396970d93971f4f43 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 17 Sep 2025 15:42:20 +0000 Subject: [PATCH 5/8] Remove hardcoded paths --- src/llmcompressor/modeling/gpt_oss.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index d51f1beff..fae0fa3fd 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -47,8 +47,8 @@ def convert_model_for_quantization_gptoss(model): gc.collect() try: torch.cuda.empty_cache() - except Exception as e: - print(f"[GPT-OSS] Warning: Failed to empty CUDA cache: {e}", flush=True) + except Exception: + pass def _get_parent_and_child(model, dotted_name: str): @@ -148,9 +148,10 @@ def forward(self, hidden_states): for j in range(self.top_k): idx = router_indices[:, j] w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) - unique_experts = torch.unique(idx) - for e in unique_experts: + for e in range(self.num_experts): mask = (idx == e) + if not torch.any(mask): + continue out[mask] += self.experts[e](x[mask]) * w[mask] out = out.view(B, T, H) @@ -158,7 +159,7 @@ def forward(self, hidden_states): return out, router_scores -model_id = "/mnt/nvme4/openai/gpt-oss-120b-BF16" +model_id = "unsloth/gpt-oss-120b-BF16" model = AutoModelForCausalLM.from_pretrained( model_id, @@ -213,14 +214,14 @@ def tokenize(sample): scheme="FP8_DYNAMIC", ignore=[ "re:.*lm_head", - "re:.*self_attn", - "re:.*attn", - "re:.*attention.*", - "re:.*router", + 're:.*self_attn', + 're:.*attn', + 're:.*attention.*', + 're:.*router', ], ) -SAVE_DIR = f"/proving-grounds/machine/shubhra/gpt_oss_120b/{os.path.basename(model_id)}-{recipe.scheme}_ns{NUM_CALIBRATION_SAMPLES}_fixed_CTgb2df366" +SAVE_DIR = f"{model_id.split('/')[-1]}-FP8-Dynamic" # Oneshot quantization oneshot( From 35fa394b042208e878a42291174616e62d51f107 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Thu, 18 Sep 2025 13:51:40 +0000 Subject: [PATCH 6/8] Remove dataset loading and processing --- src/llmcompressor/modeling/gpt_oss.py | 56 ++++----------------------- 1 file changed, 8 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index fae0fa3fd..7f7d25f58 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -2,7 +2,6 @@ import torch from torch import nn import os -from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot @@ -47,8 +46,8 @@ def convert_model_for_quantization_gptoss(model): gc.collect() try: torch.cuda.empty_cache() - except Exception: - pass + except Exception as e: + print(f"[GPT-OSS] Warning: Failed to empty CUDA cache: {e}", flush=True) def _get_parent_and_child(model, dotted_name: str): @@ -148,10 +147,9 @@ def forward(self, hidden_states): for j in range(self.top_k): idx = router_indices[:, j] w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) - for e in range(self.num_experts): + unique_experts = torch.unique(idx) + for e in unique_experts: mask = (idx == e) - if not torch.any(mask): - continue out[mask] += self.experts[e](x[mask]) * w[mask] out = out.view(B, T, H) @@ -171,41 +169,6 @@ def forward(self, hidden_states): convert_model_for_quantization_gptoss(model) -# ----------------------------- -# Calibration data & preprocessing -# ----------------------------- -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 128 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - -ds = ds.map(preprocess) - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - -ds = ds.map(tokenize, remove_columns=ds.column_names) - # ----------------------------- # Quantization recipe # ----------------------------- @@ -214,10 +177,10 @@ def tokenize(sample): scheme="FP8_DYNAMIC", ignore=[ "re:.*lm_head", - 're:.*self_attn', - 're:.*attn', - 're:.*attention.*', - 're:.*router', + "re:.*self_attn", + "re:.*attn", + "re:.*attention.*", + "re:.*router", ], ) @@ -227,10 +190,7 @@ def tokenize(sample): oneshot( model=model, tokenizer=tokenizer, - dataset=ds, recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=SAVE_DIR, ) From e3fd08a8b495969c693a580cd5f70ab27f3b2698 Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 24 Sep 2025 15:58:07 +0000 Subject: [PATCH 7/8] Address review comments --- src/llmcompressor/modeling/gpt_oss.py | 63 +++++++++++++++++---------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index 7f7d25f58..c72667968 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -7,6 +7,7 @@ from llmcompressor import oneshot from llmcompressor.utils.dev import skip_weights_initialize from llmcompressor.modifiers.quantization import QuantizationModifier +from compressed_tensors.utils import align_module_device, update_offload_parameter def convert_model_for_quantization_gptoss(model): to_delete = [] @@ -28,6 +29,7 @@ def convert_model_for_quantization_gptoss(model): intermediate = g2i // 2 hidden = gH + dtype = gup.dtype parent, child_name = _get_parent_and_child(model, name) top_k = int(max(1, min(_get_top_k(model.config) or 1, E))) seq = SequentialGPTOSSMoE( @@ -35,6 +37,7 @@ def convert_model_for_quantization_gptoss(model): intermediate_size=intermediate, top_k=top_k, original_moe=module, + dtype=dtype, ) parent._modules[child_name] = seq to_delete.append(module) @@ -45,6 +48,7 @@ def convert_model_for_quantization_gptoss(model): if to_delete: gc.collect() try: + torch.cuda.synchronize() torch.cuda.empty_cache() except Exception as e: print(f"[GPT-OSS] Warning: Failed to empty CUDA cache: {e}", flush=True) @@ -68,15 +72,15 @@ def _get_top_k(config): class GPTOSSMLP(nn.Module): - def __init__(self, hidden_size, intermediate_size): + def __init__(self, hidden_size, intermediate_size, dtype=None): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.alpha = 1.702 self.limit = 7.0 - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True) + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype) def forward(self, x): gate = self.gate_proj(x) @@ -93,7 +97,7 @@ class SequentialGPTOSSMoE(nn.Module): Replaces GPT-OSS fused-expert MoE with per-expert GPTOSSMLP modules. Copies weights from fused tensors and reuses the original router and optional shared_expert. """ - def __init__(self, hidden_size, intermediate_size, top_k, original_moe): + def __init__(self, hidden_size, intermediate_size, top_k, original_moe, dtype=None): super().__init__() self.hidden_size = hidden_size self.intermediate = intermediate_size @@ -106,32 +110,43 @@ def __init__(self, hidden_size, intermediate_size, top_k, original_moe): self.num_experts = E # Build per-expert MLPs - self.experts = nn.ModuleList() - with skip_weights_initialize(): + self.experts = nn.ModuleList() + with skip_weights_initialize(), align_module_device(original_moe.experts): for _ in range(E): - self.experts.append(GPTOSSMLP(hidden_size, intermediate_size)) + self.experts.append(GPTOSSMLP(hidden_size, intermediate_size, dtype=dtype)) gup = original_moe.experts.gate_up_proj # [E, H, 2I] gup_b = original_moe.experts.gate_up_proj_bias # [E, 2I] dwn = original_moe.experts.down_proj # [E, I, H] dwn_b = original_moe.experts.down_proj_bias # [E, H] - for i in range(E): - gup_i = gup[i] # [H, 2I] - gate_w = gup_i[:, ::2] # [H, I] - up_w = gup_i[:, 1::2] # [H, I] - down_w = dwn[i] # [I, H] - - mlp = self.experts[i] - mlp.gate_proj.weight.data.copy_(gate_w.T) # [I, H] - mlp.up_proj.weight.data.copy_(up_w.T) # [I, H] - mlp.down_proj.weight.data.copy_(down_w.T) # [H, I] - - gate_b = gup_b[i] # [2I] - mlp.gate_proj.bias.data.copy_(gate_b[::2]) # [I] - mlp.up_proj.bias.data.copy_(gate_b[1::2]) # [I] - mlp.down_proj.bias.data.copy_(dwn_b[i]) # [H] - + with align_module_device(self.experts): + for i, mlp in enumerate(self.experts): + update_offload_parameter( + mlp.gate_proj, "weight", + original_moe.experts.gate_up_proj[i, :, ::2].T + ) + update_offload_parameter( + mlp.up_proj, "weight", + original_moe.experts.gate_up_proj[i, :, 1::2].T + ) + update_offload_parameter( + mlp.down_proj, "weight", + original_moe.experts.down_proj[i].T + ) + + update_offload_parameter( + mlp.gate_proj, "bias", + original_moe.experts.gate_up_proj_bias[i, ::2] + ) + update_offload_parameter( + mlp.up_proj, "bias", + original_moe.experts.gate_up_proj_bias[i, 1::2] + ) + update_offload_parameter( + mlp.down_proj, "bias", + original_moe.experts.down_proj_bias[i] + ) # [H] def forward(self, hidden_states): From b6200cd4c85515512de6a21565eb278d190ca59c Mon Sep 17 00:00:00 2001 From: Shubhra Pandit Date: Wed, 24 Sep 2025 16:01:56 +0000 Subject: [PATCH 8/8] Address review comments --- src/llmcompressor/modeling/gpt_oss.py | 47 ++++++++++++++------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index c72667968..5b6055694 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -19,29 +19,30 @@ def convert_model_for_quantization_gptoss(model): if not (hasattr(experts, "gate_up_proj") and hasattr(experts, "down_proj")): continue - gup = experts.gate_up_proj # [E, H, 2I] - dwn = experts.down_proj # [E, I, H] - assert gup.dim() == 3 and dwn.dim() == 3 - E, gH, g2i = gup.shape - Ed, dI, dH = dwn.shape - assert E == Ed and gH == dH - assert g2i % 2 == 0 - intermediate = g2i // 2 - hidden = gH - - dtype = gup.dtype - parent, child_name = _get_parent_and_child(model, name) - top_k = int(max(1, min(_get_top_k(model.config) or 1, E))) - seq = SequentialGPTOSSMoE( - hidden_size=hidden, - intermediate_size=intermediate, - top_k=top_k, - original_moe=module, - dtype=dtype, - ) - parent._modules[child_name] = seq - to_delete.append(module) - print(f"[GPT-OSS] Patched {name} -> SequentialGPTOSSMoE (E={E}, inter={intermediate}, hidden={hidden})", flush=True) + with align_module_device(experts): + gup = experts.gate_up_proj # [E, H, 2I] + dwn = experts.down_proj # [E, I, H] + assert gup.dim() == 3 and dwn.dim() == 3 + E, gH, g2i = gup.shape + Ed, dI, dH = dwn.shape + assert E == Ed and gH == dH + assert g2i % 2 == 0 + intermediate = g2i // 2 + hidden = gH + + dtype = gup.dtype + parent, child_name = _get_parent_and_child(model, name) + top_k = int(max(1, min(_get_top_k(model.config) or 1, E))) + seq = SequentialGPTOSSMoE( + hidden_size=hidden, + intermediate_size=intermediate, + top_k=top_k, + original_moe=module, + dtype=dtype, + ) + parent._modules[child_name] = seq + to_delete.append(module) + print(f"[GPT-OSS] Patched {name} -> SequentialGPTOSSMoE (E={E}, inter={intermediate}, hidden={hidden})", flush=True) for m in to_delete: del m