Merged
32 commits
0994db6
Deprecate replace_modules_for_calibration
sairampillai Sep 22, 2025
c7a9943
Refactor MoECalibrationContext
sairampillai Sep 22, 2025
e374313
Use moe_context in pipeline and by default and add tests
sairampillai Sep 24, 2025
7fefaac
Update documentation
sairampillai Sep 24, 2025
858a0f6
Deprecate replace_modules_for_calibration
sairampillai Sep 22, 2025
ef8e0b7
Refactor MoECalibrationContext
sairampillai Sep 22, 2025
b04f957
Use moe_context in pipeline and by default and add tests
sairampillai Sep 24, 2025
ba42881
Update documentation
sairampillai Sep 24, 2025
6a38a0f
Update docstrings to fix review comments
sairampillai Sep 26, 2025
61c5141
Merge branch 'moe_calibration_refactor' of https://github.com/sairamp…
sairampillai Sep 26, 2025
9a131cb
Fix style and quality checks
sairampillai Sep 26, 2025
4520421
Deprecate replace_modules_for_calibration
sairampillai Sep 22, 2025
1c15741
Refactor MoECalibrationContext
sairampillai Sep 22, 2025
d8fecb9
Use moe_context in pipeline and by default and add tests
sairampillai Sep 24, 2025
d099bf3
Update documentation
sairampillai Sep 24, 2025
1bb9f62
Merge branch 'moe_calibration_refactor' of https://github.com/sairamp…
sairampillai Sep 26, 2025
d4a6a11
Fix style and quality checks
sairampillai Sep 26, 2025
b19e3b6
Merge branch 'moe_calibration_refactor' of https://github.com/sairamp…
sairampillai Sep 26, 2025
779d79a
Simplify MoE calibration registration and implementation
sairampillai Oct 14, 2025
87e4484
Use simplified implementation and update context entrypoint
sairampillai Oct 14, 2025
ebafc53
Make module replacement verbose and explicit
sairampillai Oct 17, 2025
c451635
Update modeling and test files with latest moe_context signature
sairampillai Oct 17, 2025
f271a51
Update examples of models where moe_context was added
sairampillai Oct 17, 2025
774bb81
Merge branch 'main' into moe_calibration_refactor
sairampillai Oct 17, 2025
32546bb
Simplify calibrate fallback calls
sairampillai Oct 27, 2025
64675e0
Merge branch 'moe_calibration_refactor' of https://github.com/sairamp…
sairampillai Oct 27, 2025
02932be
Merge branch 'main' into moe_calibration_refactor
sairampillai Oct 29, 2025
92dd5fd
Update docstrings with class location
sairampillai Oct 29, 2025
40185f3
Fix linting
sairampillai Oct 29, 2025
166606b
Refactor to remove from_original() and use init
sairampillai Oct 29, 2025
73722eb
Fix prepare.py
sairampillai Oct 29, 2025
0348c56
Fix linting error
sairampillai Oct 29, 2025
18 changes: 12 additions & 6 deletions examples/multimodal_vision/llama4_example.py
@@ -3,18 +3,24 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = Llama4Processor.from_pretrained(model_id)
# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = replace_modules_for_calibration(model)
# MoE calibration is now handled automatically by the pipeline.
# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
# will be applied during calibration to enable proper
# expert calibration and vLLM compatibility.
# These replace the original `Llama4TextMoe` class from
# `transformers.models.llama4.modeling_llama4`.
#
# NOTE: This restructuring is specifically required for vLLM compatibility.
# To define custom calibration logic, create a new calibration module in
# modeling/llama4.py that inherits from `MoECalibrationModule`, and register
# it using the `@register_moe_calibration` decorator with the appropriate
# module class name (e.g., "Llama4TextMoe").

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 512
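For reference, a minimal sketch of what registering a custom calibration module might look like, following the pattern described in the comment above (the class name and forward logic are illustrative assumptions; only `MoECalibrationModule`, `register_moe_calibration`, `is_permanent`, and the `(original, config, calibrate_all_experts)` constructor convention come from this PR):

```python
import torch

from llmcompressor.modeling.moe_context import (
    MoECalibrationModule,
    register_moe_calibration,
)


@register_moe_calibration("Llama4TextMoe")
class MyLlama4TextMoeCalibration(MoECalibrationModule):
    """Hypothetical custom calibration module; see SequentialLlama4TextMoe for
    the real unpacking logic required for vLLM compatibility."""

    is_permanent = True  # keep the replacement after calibration ends

    def __init__(self, original, config, calibrate_all_experts: bool = True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor):
        # Custom calibration-time routing/expert logic would go here.
        return self.original(hidden_states)
```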
8 changes: 4 additions & 4 deletions examples/quantization_w4a4_fp4/README.md
@@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model!

# Quantizing MoEs

To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to:
When quantizing MoEs, calibration of the MoE blocks is handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline applies the appropriate MoE calibration context, which:

1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
2. Ensure experts are quantized correctly as not all experts are activated during calibration
1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
2. Ensures experts are quantized correctly, since not all experts are activated during calibration.

Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration, which temporarily updates the model definition to use `CalibrationQwen3MoeSparseMoeBlock` and changes how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
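
A minimal sketch of the resulting `oneshot` call (mirroring `qwen_30b_a3b.py`; `model`, `ds`, `recipe`, and the length/sample constants are assumed to be defined as in that example):

```python
from llmcompressor import oneshot

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # Default is True; set False to calibrate only the experts the router selects.
    moe_calibrate_all_experts=True,
)
```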


12 changes: 6 additions & 6 deletions examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,18 +3,18 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = Llama4Processor.from_pretrained(model_id)
# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = replace_modules_for_calibration(model)
# MoE calibration is now handled automatically by the pipeline.
# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
# will be applied during calibration to enable
# proper expert calibration and vLLM compatibility.
# These replace the original `Llama4TextMoe` class from
# `transformers.models.llama4.modeling_llama4`.

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
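Because the Llama4 calibration module is permanent, the restructured modules remain in the model after `oneshot` returns. A quick sanity check — a sketch, where only the import path comes from this PR and the rest is illustrative:

```python
from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe

# After oneshot(...), the packed Llama4TextMoe blocks should have been replaced
# permanently, which is what makes the saved checkpoint loadable by vLLM.
replaced = [
    name
    for name, module in model.named_modules()
    if isinstance(module, SequentialLlama4TextMoe)
]
print(f"{len(replaced)} MoE blocks use SequentialLlama4TextMoe")
assert replaced, "expected MoE blocks to be replaced during calibration"
```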
15 changes: 10 additions & 5 deletions examples/quantization_w4a4_fp4/qwen_30b_a3b.py
@@ -59,18 +59,23 @@ def tokenize(sample):
)

# Apply quantization.
# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
# during calibration.
# MoE calibration is now handled automatically by the pipeline.
# We set `moe_calibrate_all_experts` to True to ensure all experts receive
# calibration data. This temporarily updates the model definition to use
# `CalibrationQwen3MoeSparseMoeBlock` (from `llmcompressor.modeling.qwen3_moe`)
# which replaces the original `Qwen3MoeSparseMoeBlock` class from
# `transformers.models.qwen3_moe.modeling_qwen3_moe`. This updates how the
# forward pass is handled in the MoE block during calibration.
# Feel free to update the definition under
# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with
# this behaviour and evaluate its impact on quantization performance
# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py to play around with
# this behavior and evaluate its impact on quantization performance.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
calibrate_moe_context=True,
moe_calibrate_all_experts=True,
)


8 changes: 6 additions & 2 deletions examples/quantization_w8a8_fp8/llama4_fp8_block_example.py
@@ -1,7 +1,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

@@ -10,7 +9,12 @@
# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = replace_modules_for_calibration(model)
# MoE calibration is now handled automatically by the pipeline.
# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
# will be applied during calibration to enable
# proper expert calibration and vLLM compatibility.
# These replace the original `Llama4TextMoe` class from
# `transformers.models.llama4.modeling_llama4`.
# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with block size 128 via ptq
7 changes: 5 additions & 2 deletions examples/quantizing_moe/deepseek_r1_example.py
@@ -2,7 +2,6 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

# Select model and load it.
@@ -20,7 +19,11 @@
model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = replace_modules_for_calibration(model)
# MoE calibration is now handled automatically by the pipeline.
# The `CalibrationDeepseekV3MoE` modules (from `llmcompressor.modeling.deepseek_v3`)
# will be applied during calibration to enable proper expert calibration.
# These replace the original `DeepseekV3MoE` class from
# `transformers.models.deepseek_v3.modeling_deepseek_v3`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
22 changes: 12 additions & 10 deletions src/llmcompressor/args/dataset_arguments.py
@@ -126,16 +126,6 @@ class DatasetArguments(CustomDatasetArguments):
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
calibrate_moe_context: bool = field(
default=False,
metadata={
"help": "If during calibration, the MoE context should be enabled "
"for the given model. This usually involves updating all MoE modules "
"in the model for the duration of calibration. See moe_context under "
"modeling/prepare.py for a list of supported MoEs and their updated "
"module definitions"
},
)
shuffle_calibration_samples: bool | None = field(
default=True,
metadata={
@@ -181,6 +171,18 @@ class DatasetArguments(CustomDatasetArguments):
),
},
)
moe_calibrate_all_experts: bool = field(
default=True,
metadata={
"help": (
"Whether to calibrate all experts during MoE model calibration. "
"When True, all experts will see all tokens during calibration, "
"ensuring proper quantization statistics for all experts. "
"When False, only routed experts will be used. "
"Only relevant for MoE models. Default is True."
),
},
)
# --- pipeline arguments --- #
pipeline: str | None = field(
default="independent",
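A conceptual sketch (not the library implementation) of what `moe_calibrate_all_experts` changes in an MoE forward pass during calibration; all tensor and argument names below are illustrative:

```python
import torch


def calibration_moe_forward(hidden_states, experts, router_scores, top_k_mask,
                            calibrate_all_experts: bool = True):
    """Illustrative only: all-expert vs. routed-only execution during calibration."""
    output = torch.zeros_like(hidden_states)
    for i, expert in enumerate(experts):
        routed = top_k_mask[:, i]  # tokens the router assigned to expert i
        if calibrate_all_experts:
            # Every expert sees every token, so quantization observers collect
            # statistics for all experts, but only routed tokens contribute to
            # the output.
            expert_out = expert(hidden_states)
            output[routed] += expert_out[routed] * router_scores[routed, i, None]
        else:
            # Standard routed execution: rarely selected experts may see few tokens.
            output[routed] += expert(hidden_states[routed]) * router_scores[routed, i, None]
    return output
```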
23 changes: 15 additions & 8 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -20,6 +20,7 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.pipelines import CalibrationPipeline

__all__ = ["Oneshot", "oneshot"]
@@ -209,11 +210,16 @@ def apply_recipe_modifiers(
user_pipeline = self.dataset_args.pipeline
modifiers = session.lifecycle.recipe.modifiers
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
pipeline(
# Apply MoE calibration context for the entire calibration process
with moe_calibration_context(
self.model,
calibration_dataloader,
self.dataset_args,
)
calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
):
pipeline(
self.model,
calibration_dataloader,
self.dataset_args,
)

session.finalize()

@@ -252,7 +258,7 @@ def oneshot(
overwrite_cache: bool = False,
preprocessing_num_workers: Optional[int] = None,
min_tokens_per_module: Optional[float] = None,
calibrate_moe_context: bool = False,
moe_calibrate_all_experts: bool = True,
quantization_aware_calibration: bool = True,
# Miscellaneous arguments
output_dir: Optional[str] = None,
@@ -316,9 +322,10 @@
preprocessing.
:param min_tokens_per_module: Minimum percentage of tokens per
module, relevant for MoE models.
:param calibrate_moe_context: If during calibration, the MoE context should be
enabled for the given model. This usually involves updating all MoE modules
in the model for the duration of calibration.
:param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
model calibration. When True, all experts will see all tokens during
calibration, ensuring proper quantization statistics. When False, only
routed experts will be used. Only relevant for MoE models. Default is True.
:param quantization_aware_calibration: Whether to enable quantization-aware
calibration in the sequential pipeline. When True, quantization is applied
during forward pass in calibration. When False, quantization is disabled
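For use outside of `oneshot`, the same context manager can be entered directly — a minimal sketch based on the call above (the dataloader loop and device handling are assumptions, not part of this PR):

```python
import torch

from llmcompressor.modeling.moe_context import moe_calibration_context

# Replace MoE modules with their calibration versions for the duration of the
# block; permanent replacements (e.g. Llama4) remain in place after it exits.
with moe_calibration_context(model, calibrate_all_experts=True):
    with torch.no_grad():
        for batch in calibration_dataloader:
            model(**{k: v.to(model.device) for k, v in batch.items()})
```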
27 changes: 21 additions & 6 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -4,17 +4,25 @@
DeepseekV3MoE as OriginalDeepseekV3MoE,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)


class DeepseekV3MoECalibrate(torch.nn.Module):
@register_moe_calibration("DeepseekV3MoE")
class CalibrationDeepseekV3MoE(MoECalibrationModule):
"""
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
Calibration version of DeepseekV3MoE that sends all tokens to all experts.
"""

is_permanent = True

def __init__(
self,
config: DeepseekV3Config,
original: OriginalDeepseekV3MoE,
calibrate_all_experts: bool,
config: DeepseekV3Config,
calibrate_all_experts: bool = True,
):
super().__init__()
self.config = config
@@ -65,11 +73,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


# Legacy function for backward compatibility
def replace(
config: DeepseekV3Config,
module: OriginalDeepseekV3MoE,
calibrate_all_experts: bool,
):
return DeepseekV3MoECalibrate(
config=config, original=module, calibrate_all_experts=calibrate_all_experts
"""
Legacy replacement function.
Use CalibrationDeepseekV3MoE instead.
"""
return CalibrationDeepseekV3MoE(
module,
config,
calibrate_all_experts=calibrate_all_experts,
)
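
With this change the legacy `replace` helper is a thin wrapper; constructing the calibration module directly is equivalent — a sketch, assuming `module` is an existing `DeepseekV3MoE` and `config` its `DeepseekV3Config`:

```python
from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE, replace

# New style: build the calibration module directly from the original MoE block.
calib_moe = CalibrationDeepseekV3MoE(module, config, calibrate_all_experts=True)

# Legacy style, kept for backward compatibility; returns the same module type.
calib_moe_legacy = replace(config=config, module=module, calibrate_all_experts=True)
```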
47 changes: 36 additions & 11 deletions src/llmcompressor/modeling/llama4.py
@@ -11,27 +11,47 @@
Llama4TextMoe,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.utils.dev import skip_weights_initialize


class SequentialLlama4TextMoe(torch.nn.Module):
@register_moe_calibration("Llama4TextMoe")
class SequentialLlama4TextMoe(MoECalibrationModule):
"""
Calibration version of Llama4TextMoe that unpacks experts for sequential processing.

This module:
1. Unpacks the packed expert weights (3D -> 2D) for calibration
2. Optionally sends all tokens to all experts during calibration
3. Stays in unpacked form (permanent) for vLLM compatibility
"""

is_permanent = True

def __init__(
self,
config: Llama4TextConfig,
original: Llama4TextMoe,
calibrate_all_experts: bool,
config: Llama4Config,
calibrate_all_experts: bool = True,
):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts

self.experts = SequentialLlama4TextExperts(config, original.experts)
# Extract text config from multimodal config if needed
text_config = (
config.get_text_config() if hasattr(config, "get_text_config") else config
)
self.top_k = text_config.num_experts_per_tok
self.hidden_dim = text_config.hidden_size
self.num_experts = text_config.num_local_experts

self.experts = SequentialLlama4TextExperts(text_config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert
self.calibrate_all_experts = calibrate_all_experts

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_scores, router_logits = self.router(hidden_states) # transformers>=4.54

@@ -74,9 +94,14 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
self[i].down_proj.weight.data = down.t().contiguous()


# Legacy function for backward compatibility
def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
"""
Legacy replacement function.
Use SequentialLlama4TextMoe instead.
"""
return SequentialLlama4TextMoe(
config=config.get_text_config(),
original=module,
module,
config,
calibrate_all_experts=calibrate_all_experts,
)