From 0994db68b3ba37d5bcb76fc0ade561683d185589 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:01 +0530 Subject: [PATCH 01/26] Deprecate replace_modules_for_calibration --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 113fd4364..286d61d27 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +import warnings + import tqdm from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -20,6 +22,14 @@ def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + # This function is deprecated. Use moe_calibration_context instead. + warnings.warn( + "replace_modules_for_calibration is deprecated. \ + Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: From c7a99430d20507fdd50044118ac26b97702e8e39 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:21 +0530 Subject: [PATCH 02/26] Refactor MoECalibrationContext --- src/llmcompressor/modeling/moe_context.py | 155 ++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/llmcompressor/modeling/moe_context.py diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..835f36d13 --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,155 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications +2. PermanentMoECalibration: permanently modifies the model +""" + +import contextlib +from abc import ABC, abstractmethod +from typing import Dict, TypeVar, Union + +from transformers import PreTrainedModel + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationContext(ABC): + """ + Abstract base class for MoE calibration. + This provides a standardized interface for MoE model calibration. + """ + + @abstractmethod + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """ + Apply MoE calibration modifications to the model. + :param model: The model to modify + :param calibrate_all_experts: Whether to calibrate all + experts or only routed ones + """ + pass + + @abstractmethod + def restore(self, model: PreTrainedModel) -> None: + """ + Restore the model to its original state. + :param model: The model to restore + """ + pass + + +class ContextualMoECalibration(MoECalibrationContext): + """ + MoE calibration that uses context managers for temporary modifications. + This is suitable for models that need to be restored after calibration. + """ + + def __init__(self, model_class_name: str, update_function): + """ + Initialize the context manager-based MoE calibration. 
+ :param model_class_name: The class name of the model this context applies to + :param update_function: Function that applies the MoE modifications + """ + self.model_class_name = model_class_name + self.update_function = update_function + self._stack = None + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply MoE calibration modifications using context managers.""" + if self._stack is None: + self._stack = contextlib.ExitStack() + + self.update_function(model, self._stack, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore the model by exiting the context stack.""" + if self._stack is not None: + self._stack.close() + self._stack = None + + +class PermanentMoECalibration(MoECalibrationContext): + """ + MoE calibration context that permanently modifies the model. + This is suitable for models that can be loaded in their modified form + (e.g., Llama4 in vLLM). + """ + + def __init__(self, model_class_name: str, replacement_function): + """ + Initialize the permanent MoE calibration. + :param model_class_name: The class name of the model this context applies to + :param replacement_function: Function that permanently replaces MoE modules + """ + self.model_class_name = model_class_name + self.replacement_function = replacement_function + self._original_modules = {} + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply permanent MoE calibration modifications.""" + # Store original modules for potential restoration + for name, module in model.named_modules(): + if module.__class__.__name__ == self.model_class_name: + self._original_modules[name] = module + + # Apply the replacement + self.replacement_function(model, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore original modules (if needed).""" + # For permanent MoE calibrations, restoration is typically not needed + # as the model is meant to stay in its modified form + pass + + +# Registry for MoE calibrations +_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} + + +def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: + """ + Register a MoE calibration context for a model class. + :param model_class_name: The class name of the model + :param context: The MoE calibration context to register + """ + _MOE_CONTEXTS[model_class_name] = context + + +def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: + """ + Get the registered MoE calibration context for a model class. + :param model_class_name: The class name of the model + :return: The MoE calibration context or None if not found + """ + return _MOE_CONTEXTS.get(model_class_name) + + +def list_supported_models() -> list: + """ + List all model classes that have registered MoE calibration contexts. + :return: List of supported model class names + """ + return list(_MOE_CONTEXTS.keys()) + + +# Convenience function for backward compatibility +def create_context_manager_context(model_class_name: str, update_function): + """ + Create a context manager-based MoE calibration. + :param model_class_name: The class name of the model + :param update_function: Function that applies the MoE modifications + :return: A ContextualMoECalibration instance + """ + return ContextualMoECalibration(model_class_name, update_function) + + +def create_permanent_context(model_class_name: str, replacement_function): + """ + Create a permanent MoE calibration. 
+ :param model_class_name: The class name of the model + :param replacement_function: Function that permanently replaces MoE modules + :return: A PermanentMoECalibration instance + """ + return PermanentMoECalibration(model_class_name, replacement_function) From e374313a87908dcaed8160bff3a2ac06c1f0eb3a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 20:23:03 +0530 Subject: [PATCH 03/26] Use moe_context in pipeline and by default and add tests --- src/llmcompressor/args/dataset_arguments.py | 35 +++- src/llmcompressor/modeling/moe_context.py | 176 ++++++++++++++++-- src/llmcompressor/modeling/prepare.py | 96 +++++++--- src/llmcompressor/pipelines/basic/pipeline.py | 4 +- .../pipelines/layer_sequential/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 3 +- .../modeling/test_calib_deepseek_v3.py | 57 +++--- .../modeling/test_calib_llama4.py | 57 +++--- .../modeling/test_calib_qwen3.py | 3 +- 9 files changed, 318 insertions(+), 116 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index da816584d..e3e92fdf4 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,6 +7,7 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. 
" + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 835f36d13..27e51c50e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -8,13 +8,72 @@ import contextlib from abc import ABC, abstractmethod -from typing import Dict, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union +import tqdm +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.utils.helpers import patch_attr + T = TypeVar("T", bound="MoECalibrationContext") +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). + + Attributes: + calibration_type: Type of calibration - MoECalibrationType.PERMANENT or + MoECalibrationType.CONTEXTUAL + target_class_name: The class name of the MoE module to replace + replace_function: Function that creates the replacement module + target_attribute: For contextual calibration, the attribute to replace + description: Optional description of the model configuration + """ + + calibration_type: MoECalibrationType + target_class_name: str + replace_function: Callable + target_attribute: Optional[str] = None + description: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if ( + self.calibration_type == MoECalibrationType.CONTEXTUAL + and self.target_attribute is None + ): + raise ValueError("target_attribute is required for contextual calibration") + + if ( + self.calibration_type == MoECalibrationType.PERMANENT + and self.target_attribute is not None + ): + raise ValueError( + "target_attribute should not be set for permanent calibration" + ) + + +# Registry of MoE model configurations +# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY +MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} + + class MoECalibrationContext(ABC): """ Abstract base class for MoE calibration. @@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N """Apply MoE calibration modifications using context managers.""" if self._stack is None: self._stack = contextlib.ExitStack() + self._stack.__enter__() self.update_function(model, self._stack, calibrate_all_experts) def restore(self, model: PreTrainedModel) -> None: """Restore the model by exiting the context stack.""" if self._stack is not None: - self._stack.close() + self._stack.__exit__(None, None, None) self._stack = None @@ -134,22 +194,108 @@ def list_supported_models() -> list: return list(_MOE_CONTEXTS.keys()) -# Convenience function for backward compatibility -def create_context_manager_context(model_class_name: str, update_function): +# Generic factory functions for creating MoE updaters +def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): """ - Create a context manager-based MoE calibration. 
- :param model_class_name: The class name of the model - :param update_function: Function that applies the MoE modifications - :return: A ContextualMoECalibration instance + Create a permanent MoE updater function for the given target class. + + Args: + target_class_name: The class name to look for in the model + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with PermanentMoECalibration """ - return ContextualMoECalibration(model_class_name, update_function) + def update_function(model: PreTrainedModel, calibrate_all_experts: bool): + """Update MoE modules for calibration.""" + for name, module in tqdm.tqdm(list(model.named_modules())): + if module.__class__.__name__ == target_class_name: + new_module = replace_function( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) + replace_module(model, name, new_module) + + return update_function -def create_permanent_context(model_class_name: str, replacement_function): + +def create_contextual_moe_updater( + target_class_name: str, target_attr: str, replace_function: Callable +): """ - Create a permanent MoE calibration. - :param model_class_name: The class name of the model - :param replacement_function: Function that permanently replaces MoE modules - :return: A PermanentMoECalibration instance + Create a contextual MoE updater function for the given target class and attribute. + + Args: + target_class_name: The class name to look for in the model + target_attr: The attribute name to replace within the target class + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with ContextualMoECalibration + """ + + def update_function( + model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool + ): + """Update MoE modules for calibration using context managers.""" + for module in model.modules(): + if module.__class__.__name__ == target_class_name: + stack.enter_context( + patch_attr( + module, + target_attr, + replace_function( + config=model.config, + module=getattr(module, target_attr), + calibrate_all_experts=calibrate_all_experts, + ), + ) + ) + + return update_function + + +def register_moe_model(model_class_name: str, config: MoEModelConfig): """ - return PermanentMoECalibration(model_class_name, replacement_function) + Register a MoE model with its configuration. + + Args: + model_class_name: The model class name + config: MoEModelConfig dataclass instance with calibration parameters + """ + if config.calibration_type == MoECalibrationType.PERMANENT: + updater = create_permanent_moe_updater( + config.target_class_name, config.replace_function + ) + context = PermanentMoECalibration(config.target_class_name, updater) + elif config.calibration_type == MoECalibrationType.CONTEXTUAL: + updater = create_contextual_moe_updater( + config.target_class_name, config.target_attribute, config.replace_function + ) + context = ContextualMoECalibration(model_class_name, updater) + else: + raise ValueError(f"Unknown MoE type: {config.calibration_type}") + + register_moe_context(model_class_name, context) + + +def register_moe_model_from_dict(model_class_name: str, config_dict: dict): + """ + Register a MoE model from a dictionary configuration (backward compatibility). 
+ + Args: + model_class_name: The model class name + config_dict: Dictionary with calibration parameters + """ + # Convert string calibration_type to enum + if "calibration_type" in config_dict and isinstance( + config_dict["calibration_type"], str + ): + config_dict["calibration_type"] = MoECalibrationType( + config_dict["calibration_type"] + ) + + config = MoEModelConfig(**config_dict) + register_moe_model(model_class_name, config) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 286d61d27..d3c985535 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,4 @@ +import contextlib import warnings import tqdm @@ -6,10 +7,15 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.modeling.moe_context import ( + MoECalibrationType, + MoEModelConfig, + get_moe_context, + register_moe_model, +) from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.utils.helpers import patch_attr -__all__ = ["replace_modules_for_calibration"] +__all__ = ["moe_calibration_context"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,8 +30,8 @@ def replace_modules_for_calibration( ) -> PreTrainedModel: # This function is deprecated. Use moe_calibration_context instead. warnings.warn( - "replace_modules_for_calibration is deprecated. \ - Use moe_calibration_context instead.", + "replace_modules_for_calibration is deprecated. " + "Use moe_calibration_context instead.", DeprecationWarning, stacklevel=2, ) @@ -45,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! 
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, - calibrate_all_experts: bool = False, + calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. + + Args: + model: The model to apply MoE calibration to + calibrate_all_experts: Whether to calibrate all experts or only routed ones + + Yields: + The model with MoE calibration applied + """ cls_name = model.__class__.__name__ - if cls_name in moe_context: - moe_context.get(cls_name)(model, stack, calibrate_all_experts) + moe_context = get_moe_context(cls_name) + + if moe_context is None: + # No MoE context registered for this model, yield unchanged + yield model + return + + # Apply MoE calibration + moe_context.apply(model, calibrate_all_experts) + + try: + yield model + finally: + # Restore original state + moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b15305..edcb46b09 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -46,9 +46,7 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb..bab732f48 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,8 +82,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py 
index f00a37e43..9827052b7 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -85,8 +85,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index ca7fb06af..239365fdd 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial import pytest @@ -9,7 +10,7 @@ DeepseekV3MoECalibrate, OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + 
assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 4eb609ca9..d5363d35c 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,3 +1,4 @@ +import contextlib import os from functools import partial @@ -10,7 +11,7 @@ Llama4TextMoe, SequentialLlama4TextMoe, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 822af18db..a20acf8a8 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: 
stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From 7fefaacbda6d461da51adfe517ab7e46d8cf6a38 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 04/26] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. 
Ensures experts are quantized correctly as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 28b57dda9..3a52986ad 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c3..9e5d1ca63 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,9 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# to enable proper expert calibration. # Select calibration dataset. 
DATASET_ID = "HuggingFaceH4/ultrachat_200k" From 858a0f6cf97ccd043e8b46ba59395c03ad9c5f8a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:01 +0530 Subject: [PATCH 05/26] Deprecate replace_modules_for_calibration --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 113fd4364..286d61d27 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +import warnings + import tqdm from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -20,6 +22,14 @@ def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + # This function is deprecated. Use moe_calibration_context instead. + warnings.warn( + "replace_modules_for_calibration is deprecated. \ + Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: From ef8e0b752e32f0addf1085c2490bc9d15c43b22d Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:21 +0530 Subject: [PATCH 06/26] Refactor MoECalibrationContext --- src/llmcompressor/modeling/moe_context.py | 155 ++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/llmcompressor/modeling/moe_context.py diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..835f36d13 --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,155 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications +2. PermanentMoECalibration: permanently modifies the model +""" + +import contextlib +from abc import ABC, abstractmethod +from typing import Dict, TypeVar, Union + +from transformers import PreTrainedModel + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationContext(ABC): + """ + Abstract base class for MoE calibration. + This provides a standardized interface for MoE model calibration. + """ + + @abstractmethod + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """ + Apply MoE calibration modifications to the model. + :param model: The model to modify + :param calibrate_all_experts: Whether to calibrate all + experts or only routed ones + """ + pass + + @abstractmethod + def restore(self, model: PreTrainedModel) -> None: + """ + Restore the model to its original state. + :param model: The model to restore + """ + pass + + +class ContextualMoECalibration(MoECalibrationContext): + """ + MoE calibration that uses context managers for temporary modifications. + This is suitable for models that need to be restored after calibration. + """ + + def __init__(self, model_class_name: str, update_function): + """ + Initialize the context manager-based MoE calibration. 
+ :param model_class_name: The class name of the model this context applies to + :param update_function: Function that applies the MoE modifications + """ + self.model_class_name = model_class_name + self.update_function = update_function + self._stack = None + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply MoE calibration modifications using context managers.""" + if self._stack is None: + self._stack = contextlib.ExitStack() + + self.update_function(model, self._stack, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore the model by exiting the context stack.""" + if self._stack is not None: + self._stack.close() + self._stack = None + + +class PermanentMoECalibration(MoECalibrationContext): + """ + MoE calibration context that permanently modifies the model. + This is suitable for models that can be loaded in their modified form + (e.g., Llama4 in vLLM). + """ + + def __init__(self, model_class_name: str, replacement_function): + """ + Initialize the permanent MoE calibration. + :param model_class_name: The class name of the model this context applies to + :param replacement_function: Function that permanently replaces MoE modules + """ + self.model_class_name = model_class_name + self.replacement_function = replacement_function + self._original_modules = {} + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply permanent MoE calibration modifications.""" + # Store original modules for potential restoration + for name, module in model.named_modules(): + if module.__class__.__name__ == self.model_class_name: + self._original_modules[name] = module + + # Apply the replacement + self.replacement_function(model, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore original modules (if needed).""" + # For permanent MoE calibrations, restoration is typically not needed + # as the model is meant to stay in its modified form + pass + + +# Registry for MoE calibrations +_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} + + +def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: + """ + Register a MoE calibration context for a model class. + :param model_class_name: The class name of the model + :param context: The MoE calibration context to register + """ + _MOE_CONTEXTS[model_class_name] = context + + +def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: + """ + Get the registered MoE calibration context for a model class. + :param model_class_name: The class name of the model + :return: The MoE calibration context or None if not found + """ + return _MOE_CONTEXTS.get(model_class_name) + + +def list_supported_models() -> list: + """ + List all model classes that have registered MoE calibration contexts. + :return: List of supported model class names + """ + return list(_MOE_CONTEXTS.keys()) + + +# Convenience function for backward compatibility +def create_context_manager_context(model_class_name: str, update_function): + """ + Create a context manager-based MoE calibration. + :param model_class_name: The class name of the model + :param update_function: Function that applies the MoE modifications + :return: A ContextualMoECalibration instance + """ + return ContextualMoECalibration(model_class_name, update_function) + + +def create_permanent_context(model_class_name: str, replacement_function): + """ + Create a permanent MoE calibration. 
+ :param model_class_name: The class name of the model + :param replacement_function: Function that permanently replaces MoE modules + :return: A PermanentMoECalibration instance + """ + return PermanentMoECalibration(model_class_name, replacement_function) From b04f957137188d45455a89fcdff6b209963de068 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 20:23:03 +0530 Subject: [PATCH 07/26] Use moe_context in pipeline and by default and add tests --- src/llmcompressor/args/dataset_arguments.py | 35 +++- src/llmcompressor/modeling/moe_context.py | 176 ++++++++++++++++-- src/llmcompressor/modeling/prepare.py | 96 +++++++--- src/llmcompressor/pipelines/basic/pipeline.py | 4 +- .../pipelines/layer_sequential/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 3 +- .../modeling/test_calib_deepseek_v3.py | 57 +++--- .../modeling/test_calib_llama4.py | 57 +++--- .../modeling/test_calib_qwen3.py | 3 +- 9 files changed, 318 insertions(+), 116 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index da816584d..e3e92fdf4 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,6 +7,7 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. 
" + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 835f36d13..27e51c50e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -8,13 +8,72 @@ import contextlib from abc import ABC, abstractmethod -from typing import Dict, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union +import tqdm +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.utils.helpers import patch_attr + T = TypeVar("T", bound="MoECalibrationContext") +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). + + Attributes: + calibration_type: Type of calibration - MoECalibrationType.PERMANENT or + MoECalibrationType.CONTEXTUAL + target_class_name: The class name of the MoE module to replace + replace_function: Function that creates the replacement module + target_attribute: For contextual calibration, the attribute to replace + description: Optional description of the model configuration + """ + + calibration_type: MoECalibrationType + target_class_name: str + replace_function: Callable + target_attribute: Optional[str] = None + description: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if ( + self.calibration_type == MoECalibrationType.CONTEXTUAL + and self.target_attribute is None + ): + raise ValueError("target_attribute is required for contextual calibration") + + if ( + self.calibration_type == MoECalibrationType.PERMANENT + and self.target_attribute is not None + ): + raise ValueError( + "target_attribute should not be set for permanent calibration" + ) + + +# Registry of MoE model configurations +# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY +MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} + + class MoECalibrationContext(ABC): """ Abstract base class for MoE calibration. @@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N """Apply MoE calibration modifications using context managers.""" if self._stack is None: self._stack = contextlib.ExitStack() + self._stack.__enter__() self.update_function(model, self._stack, calibrate_all_experts) def restore(self, model: PreTrainedModel) -> None: """Restore the model by exiting the context stack.""" if self._stack is not None: - self._stack.close() + self._stack.__exit__(None, None, None) self._stack = None @@ -134,22 +194,108 @@ def list_supported_models() -> list: return list(_MOE_CONTEXTS.keys()) -# Convenience function for backward compatibility -def create_context_manager_context(model_class_name: str, update_function): +# Generic factory functions for creating MoE updaters +def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): """ - Create a context manager-based MoE calibration. 
- :param model_class_name: The class name of the model - :param update_function: Function that applies the MoE modifications - :return: A ContextualMoECalibration instance + Create a permanent MoE updater function for the given target class. + + Args: + target_class_name: The class name to look for in the model + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with PermanentMoECalibration """ - return ContextualMoECalibration(model_class_name, update_function) + def update_function(model: PreTrainedModel, calibrate_all_experts: bool): + """Update MoE modules for calibration.""" + for name, module in tqdm.tqdm(list(model.named_modules())): + if module.__class__.__name__ == target_class_name: + new_module = replace_function( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) + replace_module(model, name, new_module) + + return update_function -def create_permanent_context(model_class_name: str, replacement_function): + +def create_contextual_moe_updater( + target_class_name: str, target_attr: str, replace_function: Callable +): """ - Create a permanent MoE calibration. - :param model_class_name: The class name of the model - :param replacement_function: Function that permanently replaces MoE modules - :return: A PermanentMoECalibration instance + Create a contextual MoE updater function for the given target class and attribute. + + Args: + target_class_name: The class name to look for in the model + target_attr: The attribute name to replace within the target class + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with ContextualMoECalibration + """ + + def update_function( + model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool + ): + """Update MoE modules for calibration using context managers.""" + for module in model.modules(): + if module.__class__.__name__ == target_class_name: + stack.enter_context( + patch_attr( + module, + target_attr, + replace_function( + config=model.config, + module=getattr(module, target_attr), + calibrate_all_experts=calibrate_all_experts, + ), + ) + ) + + return update_function + + +def register_moe_model(model_class_name: str, config: MoEModelConfig): """ - return PermanentMoECalibration(model_class_name, replacement_function) + Register a MoE model with its configuration. + + Args: + model_class_name: The model class name + config: MoEModelConfig dataclass instance with calibration parameters + """ + if config.calibration_type == MoECalibrationType.PERMANENT: + updater = create_permanent_moe_updater( + config.target_class_name, config.replace_function + ) + context = PermanentMoECalibration(config.target_class_name, updater) + elif config.calibration_type == MoECalibrationType.CONTEXTUAL: + updater = create_contextual_moe_updater( + config.target_class_name, config.target_attribute, config.replace_function + ) + context = ContextualMoECalibration(model_class_name, updater) + else: + raise ValueError(f"Unknown MoE type: {config.calibration_type}") + + register_moe_context(model_class_name, context) + + +def register_moe_model_from_dict(model_class_name: str, config_dict: dict): + """ + Register a MoE model from a dictionary configuration (backward compatibility). 
+ + Args: + model_class_name: The model class name + config_dict: Dictionary with calibration parameters + """ + # Convert string calibration_type to enum + if "calibration_type" in config_dict and isinstance( + config_dict["calibration_type"], str + ): + config_dict["calibration_type"] = MoECalibrationType( + config_dict["calibration_type"] + ) + + config = MoEModelConfig(**config_dict) + register_moe_model(model_class_name, config) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 286d61d27..d3c985535 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,4 @@ +import contextlib import warnings import tqdm @@ -6,10 +7,15 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.modeling.moe_context import ( + MoECalibrationType, + MoEModelConfig, + get_moe_context, + register_moe_model, +) from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.utils.helpers import patch_attr -__all__ = ["replace_modules_for_calibration"] +__all__ = ["moe_calibration_context"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,8 +30,8 @@ def replace_modules_for_calibration( ) -> PreTrainedModel: # This function is deprecated. Use moe_calibration_context instead. warnings.warn( - "replace_modules_for_calibration is deprecated. \ - Use moe_calibration_context instead.", + "replace_modules_for_calibration is deprecated. " + "Use moe_calibration_context instead.", DeprecationWarning, stacklevel=2, ) @@ -45,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! 
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, - calibrate_all_experts: bool = False, + calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. + + Args: + model: The model to apply MoE calibration to + calibrate_all_experts: Whether to calibrate all experts or only routed ones + + Yields: + The model with MoE calibration applied + """ cls_name = model.__class__.__name__ - if cls_name in moe_context: - moe_context.get(cls_name)(model, stack, calibrate_all_experts) + moe_context = get_moe_context(cls_name) + + if moe_context is None: + # No MoE context registered for this model, yield unchanged + yield model + return + + # Apply MoE calibration + moe_context.apply(model, calibrate_all_experts) + + try: + yield model + finally: + # Restore original state + moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b15305..edcb46b09 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -46,9 +46,7 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb..bab732f48 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,8 +82,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py 
index f00a37e43..9827052b7 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -85,8 +85,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index ca7fb06af..239365fdd 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial import pytest @@ -9,7 +10,7 @@ DeepseekV3MoECalibrate, OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + 
assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 4eb609ca9..d5363d35c 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,3 +1,4 @@ +import contextlib import os from functools import partial @@ -10,7 +11,7 @@ Llama4TextMoe, SequentialLlama4TextMoe, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 822af18db..a20acf8a8 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: 
stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From ba42881c3367a2b04be889ba122a71fd806803ec Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 08/26] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. 
Ensures experts are quantized correctly as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 28b57dda9..3a52986ad 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c3..9e5d1ca63 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,9 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# to enable proper expert calibration. # Select calibration dataset. 
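# --- Illustrative sketch (not part of this patch) ------------------------------
# What "handled automatically by the pipeline" means in practice, assuming this
# patch series is applied: the calibration pipelines enter
# `moe_calibration_context` on their own, but the context manager can also be
# used directly, e.g. to inspect the temporarily swapped-in MoE modules.
# `model` is assumed to be an already-loaded MoE `PreTrainedModel`.
from llmcompressor.modeling.prepare import moe_calibration_context

with moe_calibration_context(model, calibrate_all_experts=True):
    # Inside the context, the registered MoE calibration blocks
    # (e.g. `DeepseekV3MoECalibrate`) are in place so every expert sees data.
    moe_module_names = {
        name: module.__class__.__name__ for name, module in model.named_modules()
    }
# Contextual replacements (e.g. Qwen3) are reverted on exit; permanent ones
# (e.g. DeepSeek-V3, Llama4) stay in place, matching the registry above.
# --------------------------------------------------------------------------------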
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"

From 6a38a0fbccaeda55e901eec3e0a4576707d4b220 Mon Sep 17 00:00:00 2001
From: Sairam Pillai
Date: Fri, 26 Sep 2025 18:11:53 +0530
Subject: [PATCH 09/26] Update docstrings to fix review comments

---
 examples/multimodal_vision/llama4_example.py | 5 +++++
 src/llmcompressor/modeling/moe_context.py    | 7 ++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
index 6832ababb..aa88304c3 100644
--- a/examples/multimodal_vision/llama4_example.py
+++ b/examples/multimodal_vision/llama4_example.py
@@ -12,6 +12,11 @@
 # MoE calibration is now handled automatically by the pipeline.
 # The `SequentialLlama4TextMoe` modules will be applied during calibration
 # to enable proper expert calibration and vLLM compatibility.
+#
+# NOTE: This restructuring is specifically required for vLLM compatibility.
+# Users can customize the calibration behavior as needed.
+# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`).
+# Then, update `MOE_EXPERTS_REPLACEMENTS` in prepare.py to reference your custom function.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512
diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
index 27e51c50e..ff2d9cf14 100644
--- a/src/llmcompressor/modeling/moe_context.py
+++ b/src/llmcompressor/modeling/moe_context.py
@@ -2,8 +2,8 @@
 Standardized interface for MoE model calibration.
 MoE calibration context is used to apply MoE calibration modifications to the model.
 There are two types of MoE calibration contexts:
-1. ContextualMoECalibration: uses context managers for temporary modifications
-2. PermanentMoECalibration: permanently modifies the model
+1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution
+2. PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution
 """
 
 import contextlib
 from abc import ABC, abstractmethod
@@ -41,7 +41,8 @@ class MoEModelConfig:
         calibration_type: Type of calibration - MoECalibrationType.PERMANENT or
             MoECalibrationType.CONTEXTUAL
         target_class_name: The class name of the MoE module to replace
-        replace_function: Function that creates the replacement module
+        replace_function: Function that creates the replacement module,
+            generally defined in modeling/model_name.py
         target_attribute: For contextual calibration, the attribute to replace
         description: Optional description of the model configuration
     """

From 9a131cbeff498a8f160d3cb6f36ece01eb0d94dd Mon Sep 17 00:00:00 2001
From: Sairam Pillai
Date: Fri, 26 Sep 2025 18:39:30 +0530
Subject: [PATCH 10/26] Fix style and quality checks

---
 examples/multimodal_vision/llama4_example.py |   6 +-
 src/llmcompressor/modeling/moe_context.py    | 297 ------------------
 2 files changed, 4 insertions(+), 299 deletions(-)

diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
index f3738429d..c7838bfbf 100644
--- a/examples/multimodal_vision/llama4_example.py
+++ b/examples/multimodal_vision/llama4_example.py
@@ -15,8 +15,10 @@
 #
 # NOTE: This restructuring is specifically required for vLLM compatibility.
 # Users can customize the calibration behavior as needed.
-# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`).
-# Then, update `MOE_EXPERTS_REPLACEMENTS` in prepare.py to reference your custom function.
+# To define custom calibration logic, implement your function in
+# modeling/llama4.py (e.g., `SequentialLlama4TextMoe`).
+# Then, update `MOE_EXPERTS_REPLACEMENTS` in prepare.py to reference your
+# custom function.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512
diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
index 558f9dcbf..c65a3a2b6 100644
--- a/src/llmcompressor/modeling/moe_context.py
+++ b/src/llmcompressor/modeling/moe_context.py
@@ -283,303 +283,6 @@ def register_moe_model(model_class_name: str, config: MoEModelConfig):
     register_moe_context(model_class_name, context)
 
 
-def register_moe_model_from_dict(model_class_name: str, config_dict: dict):
-    """
-    Register a MoE model from a dictionary configuration (backward compatibility).
-
-    Args:
-        model_class_name: The model class name
-        config_dict: Dictionary with calibration parameters
-    """
-    # Convert string calibration_type to enum
-    if "calibration_type" in config_dict and isinstance(
-        config_dict["calibration_type"], str
-    ):
-        config_dict["calibration_type"] = MoECalibrationType(
-            config_dict["calibration_type"]
-        )
-
-    config = MoEModelConfig(**config_dict)
-    register_moe_model(model_class_name, config)
-"""
-Standardized interface for MoE model calibration.
-MoE calibration context is used to apply MoE calibration modifications to the model.
-There are two types of MoE calibration contexts:
-1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution
-2. 
PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Callable, Dict, Optional, TypeVar, Union - -from transformers import PreTrainedModel - -T = TypeVar("T", bound="MoECalibrationContext") - - -class MoECalibrationType(Enum): - """Enumeration of supported MoE calibration types.""" - - PERMANENT = "permanent" - CONTEXTUAL = "contextual" - - -@dataclass -class MoEModelConfig: - """ - Configuration for MoE model calibration. - - This dataclass defines the parameters needed to configure MoE calibration - for a specific model architecture. It follows the same pattern used by - other model configuration systems in the project (e.g., SmoothQuant, AWQ). - - Attributes: - calibration_type: Type of calibration - MoECalibrationType.PERMANENT or - MoECalibrationType.CONTEXTUAL - target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module, - generally defined in modeling/model_name.py - target_attribute: For contextual calibration, the attribute to replace - description: Optional description of the model configuration - """ - - calibration_type: MoECalibrationType - target_class_name: str - replace_function: Callable - target_attribute: Optional[str] = None - description: Optional[str] = None - - def __post_init__(self): - """Validate configuration after initialization.""" - if ( - self.calibration_type == MoECalibrationType.CONTEXTUAL - and self.target_attribute is None - ): - raise ValueError("target_attribute is required for contextual calibration") - - if ( - self.calibration_type == MoECalibrationType.PERMANENT - and self.target_attribute is not None - ): - raise ValueError( - "target_attribute should not be set for permanent calibration" - ) - - -# Registry of MoE model configurations -# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY -MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} - - -class MoECalibrationContext(ABC): - """ - Abstract base class for MoE calibration. - This provides a standardized interface for MoE model calibration. - """ - - @abstractmethod - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """ - Apply MoE calibration modifications to the model. - :param model: The model to modify - :param calibrate_all_experts: Whether to calibrate all - experts or only routed ones - """ - pass - - @abstractmethod - def restore(self, model: PreTrainedModel) -> None: - """ - Restore the model to its original state. - :param model: The model to restore - """ - pass - - -class ContextualMoECalibration(MoECalibrationContext): - """ - MoE calibration that uses context managers for temporary modifications. - This is suitable for models that need to be restored after calibration. - """ - - def __init__(self, model_class_name: str, update_function): - """ - Initialize the context manager-based MoE calibration. 
- :param model_class_name: The class name of the model this context applies to - :param update_function: Function that applies the MoE modifications - """ - self.model_class_name = model_class_name - self.update_function = update_function - self._stack = None - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply MoE calibration modifications using context managers.""" - if self._stack is None: - self._stack = contextlib.ExitStack() - self._stack.__enter__() - - self.update_function(model, self._stack, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore the model by exiting the context stack.""" - if self._stack is not None: - self._stack.__exit__(None, None, None) - self._stack = None - - -class PermanentMoECalibration(MoECalibrationContext): - """ - MoE calibration context that permanently modifies the model. - This is suitable for models that can be loaded in their modified form - (e.g., Llama4 in vLLM). - """ - - def __init__(self, model_class_name: str, replacement_function): - """ - Initialize the permanent MoE calibration. - :param model_class_name: The class name of the model this context applies to - :param replacement_function: Function that permanently replaces MoE modules - """ - self.model_class_name = model_class_name - self.replacement_function = replacement_function - self._original_modules = {} - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply permanent MoE calibration modifications.""" - # Store original modules for potential restoration - for name, module in model.named_modules(): - if module.__class__.__name__ == self.model_class_name: - self._original_modules[name] = module - - # Apply the replacement - self.replacement_function(model, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore original modules (if needed).""" - # For permanent MoE calibrations, restoration is typically not needed - # as the model is meant to stay in its modified form - pass - - -# Registry for MoE calibrations -_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} - - -def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: - """ - Register a MoE calibration context for a model class. - :param model_class_name: The class name of the model - :param context: The MoE calibration context to register - """ - _MOE_CONTEXTS[model_class_name] = context - - -def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: - """ - Get the registered MoE calibration context for a model class. - :param model_class_name: The class name of the model - :return: The MoE calibration context or None if not found - """ - return _MOE_CONTEXTS.get(model_class_name) - - -def list_supported_models() -> list: - """ - List all model classes that have registered MoE calibration contexts. - :return: List of supported model class names - """ - return list(_MOE_CONTEXTS.keys()) - - -# Generic factory functions for creating MoE updaters -def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): - """ - Create a permanent MoE updater function for the given target class. 
- - Args: - target_class_name: The class name to look for in the model - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with PermanentMoECalibration - """ - - def update_function(model: PreTrainedModel, calibrate_all_experts: bool): - """Update MoE modules for calibration.""" - for name, module in tqdm.tqdm(list(model.named_modules())): - if module.__class__.__name__ == target_class_name: - new_module = replace_function( - config=model.config, - module=module, - calibrate_all_experts=calibrate_all_experts, - ) - replace_module(model, name, new_module) - - return update_function - - -def create_contextual_moe_updater( - target_class_name: str, target_attr: str, replace_function: Callable -): - """ - Create a contextual MoE updater function for the given target class and attribute. - - Args: - target_class_name: The class name to look for in the model - target_attr: The attribute name to replace within the target class - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with ContextualMoECalibration - """ - - def update_function( - model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool - ): - """Update MoE modules for calibration using context managers.""" - for module in model.modules(): - if module.__class__.__name__ == target_class_name: - stack.enter_context( - patch_attr( - module, - target_attr, - replace_function( - config=model.config, - module=getattr(module, target_attr), - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) - - return update_function - - -def register_moe_model(model_class_name: str, config: MoEModelConfig): - """ - Register a MoE model with its configuration. - - Args: - model_class_name: The model class name - config: MoEModelConfig dataclass instance with calibration parameters - """ - if config.calibration_type == MoECalibrationType.PERMANENT: - updater = create_permanent_moe_updater( - config.target_class_name, config.replace_function - ) - context = PermanentMoECalibration(config.target_class_name, updater) - elif config.calibration_type == MoECalibrationType.CONTEXTUAL: - updater = create_contextual_moe_updater( - config.target_class_name, config.target_attribute, config.replace_function - ) - context = ContextualMoECalibration(model_class_name, updater) - else: - raise ValueError(f"Unknown MoE type: {config.calibration_type}") - - register_moe_context(model_class_name, context) - - def register_moe_model_from_dict(model_class_name: str, config_dict: dict): """ Register a MoE model from a dictionary configuration (backward compatibility). From 452042112d5ca4fe6382f746124b5e6a71abfc39 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:01 +0530 Subject: [PATCH 11/26] Deprecate replace_modules_for_calibration --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index e966761bd..138ad773c 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +import warnings + import tqdm from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -20,6 +22,14 @@ def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + # This function is deprecated. Use moe_calibration_context instead. 
+ warnings.warn( + "replace_modules_for_calibration is deprecated. \ + Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: From 1c15741db15d4b0a144775394bd932a64c9f67cf Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:21 +0530 Subject: [PATCH 12/26] Refactor MoECalibrationContext --- src/llmcompressor/modeling/moe_context.py | 155 ++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/llmcompressor/modeling/moe_context.py diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..835f36d13 --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,155 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications +2. PermanentMoECalibration: permanently modifies the model +""" + +import contextlib +from abc import ABC, abstractmethod +from typing import Dict, TypeVar, Union + +from transformers import PreTrainedModel + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationContext(ABC): + """ + Abstract base class for MoE calibration. + This provides a standardized interface for MoE model calibration. + """ + + @abstractmethod + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """ + Apply MoE calibration modifications to the model. + :param model: The model to modify + :param calibrate_all_experts: Whether to calibrate all + experts or only routed ones + """ + pass + + @abstractmethod + def restore(self, model: PreTrainedModel) -> None: + """ + Restore the model to its original state. + :param model: The model to restore + """ + pass + + +class ContextualMoECalibration(MoECalibrationContext): + """ + MoE calibration that uses context managers for temporary modifications. + This is suitable for models that need to be restored after calibration. + """ + + def __init__(self, model_class_name: str, update_function): + """ + Initialize the context manager-based MoE calibration. + :param model_class_name: The class name of the model this context applies to + :param update_function: Function that applies the MoE modifications + """ + self.model_class_name = model_class_name + self.update_function = update_function + self._stack = None + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply MoE calibration modifications using context managers.""" + if self._stack is None: + self._stack = contextlib.ExitStack() + + self.update_function(model, self._stack, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore the model by exiting the context stack.""" + if self._stack is not None: + self._stack.close() + self._stack = None + + +class PermanentMoECalibration(MoECalibrationContext): + """ + MoE calibration context that permanently modifies the model. + This is suitable for models that can be loaded in their modified form + (e.g., Llama4 in vLLM). + """ + + def __init__(self, model_class_name: str, replacement_function): + """ + Initialize the permanent MoE calibration. 
+ :param model_class_name: The class name of the model this context applies to + :param replacement_function: Function that permanently replaces MoE modules + """ + self.model_class_name = model_class_name + self.replacement_function = replacement_function + self._original_modules = {} + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply permanent MoE calibration modifications.""" + # Store original modules for potential restoration + for name, module in model.named_modules(): + if module.__class__.__name__ == self.model_class_name: + self._original_modules[name] = module + + # Apply the replacement + self.replacement_function(model, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore original modules (if needed).""" + # For permanent MoE calibrations, restoration is typically not needed + # as the model is meant to stay in its modified form + pass + + +# Registry for MoE calibrations +_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} + + +def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: + """ + Register a MoE calibration context for a model class. + :param model_class_name: The class name of the model + :param context: The MoE calibration context to register + """ + _MOE_CONTEXTS[model_class_name] = context + + +def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: + """ + Get the registered MoE calibration context for a model class. + :param model_class_name: The class name of the model + :return: The MoE calibration context or None if not found + """ + return _MOE_CONTEXTS.get(model_class_name) + + +def list_supported_models() -> list: + """ + List all model classes that have registered MoE calibration contexts. + :return: List of supported model class names + """ + return list(_MOE_CONTEXTS.keys()) + + +# Convenience function for backward compatibility +def create_context_manager_context(model_class_name: str, update_function): + """ + Create a context manager-based MoE calibration. + :param model_class_name: The class name of the model + :param update_function: Function that applies the MoE modifications + :return: A ContextualMoECalibration instance + """ + return ContextualMoECalibration(model_class_name, update_function) + + +def create_permanent_context(model_class_name: str, replacement_function): + """ + Create a permanent MoE calibration. 
+ :param model_class_name: The class name of the model + :param replacement_function: Function that permanently replaces MoE modules + :return: A PermanentMoECalibration instance + """ + return PermanentMoECalibration(model_class_name, replacement_function) From d8fecb9527340b4b38e99a1beb86783f8cd34793 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 20:23:03 +0530 Subject: [PATCH 13/26] Use moe_context in pipeline and by default and add tests --- src/llmcompressor/args/dataset_arguments.py | 35 +++- src/llmcompressor/modeling/moe_context.py | 176 ++++++++++++++++-- src/llmcompressor/modeling/prepare.py | 94 +++++++--- src/llmcompressor/pipelines/basic/pipeline.py | 4 +- .../pipelines/layer_sequential/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 3 +- .../modeling/test_calib_deepseek_v3.py | 57 +++--- .../modeling/test_calib_llama4.py | 57 +++--- .../modeling/test_calib_qwen3.py | 3 +- 9 files changed, 317 insertions(+), 115 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index da816584d..e3e92fdf4 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,6 +7,7 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. 
" + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 835f36d13..27e51c50e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -8,13 +8,72 @@ import contextlib from abc import ABC, abstractmethod -from typing import Dict, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union +import tqdm +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.utils.helpers import patch_attr + T = TypeVar("T", bound="MoECalibrationContext") +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). + + Attributes: + calibration_type: Type of calibration - MoECalibrationType.PERMANENT or + MoECalibrationType.CONTEXTUAL + target_class_name: The class name of the MoE module to replace + replace_function: Function that creates the replacement module + target_attribute: For contextual calibration, the attribute to replace + description: Optional description of the model configuration + """ + + calibration_type: MoECalibrationType + target_class_name: str + replace_function: Callable + target_attribute: Optional[str] = None + description: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if ( + self.calibration_type == MoECalibrationType.CONTEXTUAL + and self.target_attribute is None + ): + raise ValueError("target_attribute is required for contextual calibration") + + if ( + self.calibration_type == MoECalibrationType.PERMANENT + and self.target_attribute is not None + ): + raise ValueError( + "target_attribute should not be set for permanent calibration" + ) + + +# Registry of MoE model configurations +# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY +MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} + + class MoECalibrationContext(ABC): """ Abstract base class for MoE calibration. @@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N """Apply MoE calibration modifications using context managers.""" if self._stack is None: self._stack = contextlib.ExitStack() + self._stack.__enter__() self.update_function(model, self._stack, calibrate_all_experts) def restore(self, model: PreTrainedModel) -> None: """Restore the model by exiting the context stack.""" if self._stack is not None: - self._stack.close() + self._stack.__exit__(None, None, None) self._stack = None @@ -134,22 +194,108 @@ def list_supported_models() -> list: return list(_MOE_CONTEXTS.keys()) -# Convenience function for backward compatibility -def create_context_manager_context(model_class_name: str, update_function): +# Generic factory functions for creating MoE updaters +def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): """ - Create a context manager-based MoE calibration. 
- :param model_class_name: The class name of the model - :param update_function: Function that applies the MoE modifications - :return: A ContextualMoECalibration instance + Create a permanent MoE updater function for the given target class. + + Args: + target_class_name: The class name to look for in the model + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with PermanentMoECalibration """ - return ContextualMoECalibration(model_class_name, update_function) + def update_function(model: PreTrainedModel, calibrate_all_experts: bool): + """Update MoE modules for calibration.""" + for name, module in tqdm.tqdm(list(model.named_modules())): + if module.__class__.__name__ == target_class_name: + new_module = replace_function( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) + replace_module(model, name, new_module) + + return update_function -def create_permanent_context(model_class_name: str, replacement_function): + +def create_contextual_moe_updater( + target_class_name: str, target_attr: str, replace_function: Callable +): """ - Create a permanent MoE calibration. - :param model_class_name: The class name of the model - :param replacement_function: Function that permanently replaces MoE modules - :return: A PermanentMoECalibration instance + Create a contextual MoE updater function for the given target class and attribute. + + Args: + target_class_name: The class name to look for in the model + target_attr: The attribute name to replace within the target class + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with ContextualMoECalibration + """ + + def update_function( + model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool + ): + """Update MoE modules for calibration using context managers.""" + for module in model.modules(): + if module.__class__.__name__ == target_class_name: + stack.enter_context( + patch_attr( + module, + target_attr, + replace_function( + config=model.config, + module=getattr(module, target_attr), + calibrate_all_experts=calibrate_all_experts, + ), + ) + ) + + return update_function + + +def register_moe_model(model_class_name: str, config: MoEModelConfig): """ - return PermanentMoECalibration(model_class_name, replacement_function) + Register a MoE model with its configuration. + + Args: + model_class_name: The model class name + config: MoEModelConfig dataclass instance with calibration parameters + """ + if config.calibration_type == MoECalibrationType.PERMANENT: + updater = create_permanent_moe_updater( + config.target_class_name, config.replace_function + ) + context = PermanentMoECalibration(config.target_class_name, updater) + elif config.calibration_type == MoECalibrationType.CONTEXTUAL: + updater = create_contextual_moe_updater( + config.target_class_name, config.target_attribute, config.replace_function + ) + context = ContextualMoECalibration(model_class_name, updater) + else: + raise ValueError(f"Unknown MoE type: {config.calibration_type}") + + register_moe_context(model_class_name, context) + + +def register_moe_model_from_dict(model_class_name: str, config_dict: dict): + """ + Register a MoE model from a dictionary configuration (backward compatibility). 
+ + Args: + model_class_name: The model class name + config_dict: Dictionary with calibration parameters + """ + # Convert string calibration_type to enum + if "calibration_type" in config_dict and isinstance( + config_dict["calibration_type"], str + ): + config_dict["calibration_type"] = MoECalibrationType( + config_dict["calibration_type"] + ) + + config = MoEModelConfig(**config_dict) + register_moe_model(model_class_name, config) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 138ad773c..d3c985535 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,4 @@ +import contextlib import warnings import tqdm @@ -6,10 +7,15 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.modeling.moe_context import ( + MoECalibrationType, + MoEModelConfig, + get_moe_context, + register_moe_model, +) from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.utils.helpers import patch_attr -__all__ = ["replace_modules_for_calibration"] +__all__ = ["moe_calibration_context"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,8 +30,8 @@ def replace_modules_for_calibration( ) -> PreTrainedModel: # This function is deprecated. Use moe_calibration_context instead. warnings.warn( - "replace_modules_for_calibration is deprecated. \ - Use moe_calibration_context instead.", + "replace_modules_for_calibration is deprecated. " + "Use moe_calibration_context instead.", DeprecationWarning, stacklevel=2, ) @@ -45,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! 
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. + + Args: + model: The model to apply MoE calibration to + calibrate_all_experts: Whether to calibrate all experts or only routed ones + + Yields: + The model with MoE calibration applied + """ cls_name = model.__class__.__name__ - if cls_name in moe_context: - moe_context.get(cls_name)(model, stack, calibrate_all_experts) + moe_context = get_moe_context(cls_name) + + if moe_context is None: + # No MoE context registered for this model, yield unchanged + yield model + return + + # Apply MoE calibration + moe_context.apply(model, calibrate_all_experts) + + try: + yield model + finally: + # Restore original state + moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b15305..edcb46b09 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -46,9 +46,7 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb..bab732f48 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,8 +82,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index f00a37e43..9827052b7 100644 --- 
a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -85,8 +85,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index ca7fb06af..239365fdd 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial import pytest @@ -9,7 +10,7 @@ DeepseekV3MoECalibrate, OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not 
all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 4eb609ca9..d5363d35c 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,3 +1,4 @@ +import contextlib import os from functools import partial @@ -10,7 +11,7 @@ Llama4TextMoe, SequentialLlama4TextMoe, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 822af18db..a20acf8a8 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) 
stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From d099bf3488b14bae882455296084aa1276404fe6 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 14/26] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. Ensures experts are quantized correctly as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. 
This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 28b57dda9..3a52986ad 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c3..9e5d1ca63 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,9 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# to enable proper expert calibration. # Select calibration dataset. 
DATASET_ID = "HuggingFaceH4/ultrachat_200k" From d4a6a11421b430afa404dc89372be7dd9e70fbf2 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Fri, 26 Sep 2025 18:39:30 +0530 Subject: [PATCH 15/26] Fix style and quality checks --- examples/multimodal_vision/llama4_example.py | 8 +++++--- src/llmcompressor/modeling/moe_context.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index aa88304c3..c7838bfbf 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -12,11 +12,13 @@ # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules will be applied during calibration # to enable proper expert calibration and vLLM compatibility. -# +# # NOTE: This restructuring is specifically required for vLLM compatibility # Users can customize the calibration behavior as needed by modifying the -# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). -# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your custom function. +# To define custom calibration logic, implement your function in +# modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). +# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your +# custom function. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 72e495049..9de74c20a 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -2,9 +2,9 @@ Standardized interface for MoE model calibration. MoE calibration context is used to apply MoE calibration modifications to the model. There are two types of MoE calibration contexts: -1. ContextualMoECalibration: uses context managers for temporary modifications +1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution -2. PermanentMoECalibration: permanently modifies the model and stays in its modified +2. PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution """ @@ -43,7 +43,7 @@ class MoEModelConfig: calibration_type: Type of calibration - MoECalibrationType.PERMANENT or MoECalibrationType.CONTEXTUAL target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module, + replace_function: Function that creates the replacement module generally defined in modeling/model_name.py target_attribute: For contextual calibration, the attribute to replace description: Optional description of the model configuration From 779d79ae87592931692099436b723919bcb9dc6d Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 14 Oct 2025 22:32:58 +0530 Subject: [PATCH 16/26] Simplify MoE calibration registration and implementation Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/moe_context.py | 407 ++++++++-------------- 1 file changed, 141 insertions(+), 266 deletions(-) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 9de74c20a..38d91dcc5 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -1,304 +1,179 @@ """ -Standardized interface for MoE model calibration. 
-MoE calibration context is used to apply MoE calibration modifications to the model. -There are two types of MoE calibration contexts: -1. ContextualMoECalibration: uses context managers for temporary modifications - and restores the model to its original state after pipeline execution -2. PermanentMoECalibration: permanently modifies the model and stays in its modified - form after pipeline execution +Simplified interface for MoE model calibration. + +MoE (Mixture of Experts) models route tokens to different expert networks. +During calibration for quantization/compression, we need to ensure ALL experts +see data, not just the ones selected by the router. This module provides the +infrastructure to temporarily modify MoE modules for proper calibration. + +Key components: +- MoECalibrationModule: Abstract base class for calibration modules +- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes +- moe_calibration_context: Context manager that applies calibration to a model """ import contextlib from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Callable, Dict, Optional, TypeVar, Union +from typing import Any, Dict, Optional, Type -import tqdm -from compressed_tensors.utils import replace_module +import torch +from loguru import logger from transformers import PreTrainedModel -from llmcompressor.utils.helpers import patch_attr - -T = TypeVar("T", bound="MoECalibrationContext") - - -class MoECalibrationType(Enum): - """Enumeration of supported MoE calibration types.""" - - PERMANENT = "permanent" - CONTEXTUAL = "contextual" - - -@dataclass -class MoEModelConfig: - """ - Configuration for MoE model calibration. - - This dataclass defines the parameters needed to configure MoE calibration - for a specific model architecture. It follows the same pattern used by - other model configuration systems in the project (e.g., SmoothQuant, AWQ). +__all__ = [ + "MoECalibrationModule", + "MOE_CALIBRATION_MODULES", + "register_moe_calibration", + "moe_calibration_context", +] - Attributes: - calibration_type: Type of calibration - MoECalibrationType.PERMANENT or - MoECalibrationType.CONTEXTUAL - target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module - generally defined in modeling/model_name.py - target_attribute: For contextual calibration, the attribute to replace - description: Optional description of the model configuration - """ - - calibration_type: MoECalibrationType - target_class_name: str - replace_function: Callable - target_attribute: Optional[str] = None - description: Optional[str] = None - - def __post_init__(self): - """Validate configuration after initialization.""" - if ( - self.calibration_type == MoECalibrationType.CONTEXTUAL - and self.target_attribute is None - ): - raise ValueError("target_attribute is required for contextual calibration") - - if ( - self.calibration_type == MoECalibrationType.PERMANENT - and self.target_attribute is not None - ): - raise ValueError( - "target_attribute should not be set for permanent calibration" - ) - -# Registry of MoE model configurations -# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY -MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} - - -class MoECalibrationContext(ABC): +class MoECalibrationModule(ABC, torch.nn.Module): """ - Abstract base class for MoE calibration. - This provides a standardized interface for MoE model calibration. 
+ Abstract base class for MoE calibration modules. + + Calibration modules replace original MoE modules during the calibration + phase to ensure all experts receive data for proper quantization statistics. + + Subclasses must: + 1. Implement `from_original()` to create calibration module from original + 2. Set `is_permanent` to indicate if module should stay in calibration form + 3. Optionally implement `restore()` if is_permanent=False """ - @abstractmethod - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """ - Apply MoE calibration modifications to the model. - :param model: The model to modify - :param calibrate_all_experts: Whether to calibrate all - experts or only routed ones - """ - pass + is_permanent: bool = False + @classmethod @abstractmethod - def restore(self, model: PreTrainedModel) -> None: - """ - Restore the model to its original state. - :param model: The model to restore + def from_original( + cls, + original: torch.nn.Module, + config: Any, + calibrate_all_experts: bool = True, + ) -> "MoECalibrationModule": """ - pass - + Create a calibration module from the original MoE module. -class ContextualMoECalibration(MoECalibrationContext): - """ - MoE calibration that uses context managers for temporary modifications. - This is suitable for models that need to be restored after calibration. - """ + Args: + original: The original MoE module to convert + config: Model configuration (contains num_experts, etc.) + calibrate_all_experts: If True, send all tokens to all experts. + If False, use normal routing. - def __init__(self, model_class_name: str, update_function): - """ - Initialize the context manager-based MoE calibration. - :param model_class_name: The class name of the model this context applies to - :param update_function: Function that applies the MoE modifications + Returns: + Instance of the calibration module """ - self.model_class_name = model_class_name - self.update_function = update_function - self._stack = None - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply MoE calibration modifications using context managers.""" - if self._stack is None: - self._stack = contextlib.ExitStack() - self._stack.__enter__() - - self.update_function(model, self._stack, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore the model by exiting the context stack.""" - if self._stack is not None: - self._stack.__exit__(None, None, None) - self._stack = None - - -class PermanentMoECalibration(MoECalibrationContext): - """ - MoE calibration context that permanently modifies the model. - This is suitable for models that can be loaded in their modified form - (e.g., Llama4 in vLLM). - """ + pass - def __init__(self, model_class_name: str, replacement_function): + def restore(self) -> torch.nn.Module: """ - Initialize the permanent MoE calibration. - :param model_class_name: The class name of the model this context applies to - :param replacement_function: Function that permanently replaces MoE modules + Restore the original module structure. + + Only needed if is_permanent=False. For permanent modules, this is a no-op. 
+ + Returns: + The original module (or self if permanent) """ - self.model_class_name = model_class_name - self.replacement_function = replacement_function - self._original_modules = {} - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply permanent MoE calibration modifications.""" - # Store original modules for potential restoration - for name, module in model.named_modules(): - if module.__class__.__name__ == self.model_class_name: - self._original_modules[name] = module - - # Apply the replacement - self.replacement_function(model, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore original modules (if needed).""" - # For permanent MoE calibrations, restoration is typically not needed - # as the model is meant to stay in its modified form - pass - - -# Registry for MoE calibrations -_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} - - -def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: - """ - Register a MoE calibration context for a model class. - :param model_class_name: The class name of the model - :param context: The MoE calibration context to register - """ - _MOE_CONTEXTS[model_class_name] = context - - -def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: - """ - Get the registered MoE calibration context for a model class. - :param model_class_name: The class name of the model - :return: The MoE calibration context or None if not found - """ - return _MOE_CONTEXTS.get(model_class_name) + if self.is_permanent: + return self + raise NotImplementedError( + f"{self.__class__.__name__} has is_permanent=False but doesn't " + "implement restore()" + ) -def list_supported_models() -> list: - """ - List all model classes that have registered MoE calibration contexts. - :return: List of supported model class names - """ - return list(_MOE_CONTEXTS.keys()) +# Registry: module class name -> calibration module class +MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {} -# Generic factory functions for creating MoE updaters -def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): +def register_moe_calibration(module_class_name: str): """ - Create a permanent MoE updater function for the given target class. - + Decorator to register a MoE calibration module. + + Usage: + @register_moe_calibration("DeepseekV3MoE") + class CalibrationDeepseekV3MoE(MoECalibrationModule): + ... 
+ Args: - target_class_name: The class name to look for in the model - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with PermanentMoECalibration + module_class_name: The class name of the original module to replace """ - def update_function(model: PreTrainedModel, calibrate_all_experts: bool): - """Update MoE modules for calibration.""" - for name, module in tqdm.tqdm(list(model.named_modules())): - if module.__class__.__name__ == target_class_name: - new_module = replace_function( - config=model.config, - module=module, - calibrate_all_experts=calibrate_all_experts, - ) - replace_module(model, name, new_module) + def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]: + if not issubclass(cls, MoECalibrationModule): + raise TypeError( + f"{cls.__name__} must inherit from MoECalibrationModule" + ) + MOE_CALIBRATION_MODULES[module_class_name] = cls + return cls - return update_function + return decorator -def create_contextual_moe_updater( - target_class_name: str, target_attr: str, replace_function: Callable +@contextlib.contextmanager +def moe_calibration_context( + model: PreTrainedModel, + calibrate_all_experts: bool = True, ): """ - Create a contextual MoE updater function for the given target class and attribute. - + Context manager that applies MoE calibration to a model. + + This scans all modules in the model and replaces any MoE modules with their + calibration equivalents. After the context exits, non-permanent modules are + restored to their original form. + Args: - target_class_name: The class name to look for in the model - target_attr: The attribute name to replace within the target class - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with ContextualMoECalibration - """ - - def update_function( - model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool - ): - """Update MoE modules for calibration using context managers.""" - for module in model.modules(): - if module.__class__.__name__ == target_class_name: - stack.enter_context( - patch_attr( - module, - target_attr, - replace_function( - config=model.config, - module=getattr(module, target_attr), - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) - - return update_function - - -def register_moe_model(model_class_name: str, config: MoEModelConfig): - """ - Register a MoE model with its configuration. - - Args: - model_class_name: The model class name - config: MoEModelConfig dataclass instance with calibration parameters - """ - if config.calibration_type == MoECalibrationType.PERMANENT: - updater = create_permanent_moe_updater( - config.target_class_name, config.replace_function - ) - context = PermanentMoECalibration(config.target_class_name, updater) - elif config.calibration_type == MoECalibrationType.CONTEXTUAL: - updater = create_contextual_moe_updater( - config.target_class_name, config.target_attribute, config.replace_function - ) - context = ContextualMoECalibration(model_class_name, updater) - else: - raise ValueError(f"Unknown MoE type: {config.calibration_type}") - - register_moe_context(model_class_name, context) - - -def register_moe_model_from_dict(model_class_name: str, config_dict: dict): - """ - Register a MoE model from a dictionary configuration (backward compatibility). 
- - Args: - model_class_name: The model class name - config_dict: Dictionary with calibration parameters - """ - # Convert string calibration_type to enum - if "calibration_type" in config_dict and isinstance( - config_dict["calibration_type"], str - ): - config_dict["calibration_type"] = MoECalibrationType( - config_dict["calibration_type"] + model: The model to apply MoE calibration to + calibrate_all_experts: If True, all experts see all tokens during calibration. + If False, use normal routing (useful for some techniques) + + Yields: + The model with MoE calibration applied + + Example: + with moe_calibration_context(model): + # Run calibration - all experts will see data + for batch in dataloader: + model(**batch) + # Model is now restored (unless permanent) + """ + replaced = {} + + # Step 1: Find and replace MoE modules + for name, module in model.named_modules(): + class_name = module.__class__.__name__ + if class_name in MOE_CALIBRATION_MODULES: + calibration_cls = MOE_CALIBRATION_MODULES[class_name] + replacement = calibration_cls.from_original( + module, + model.config, + calibrate_all_experts, + ) + model.set_submodule(name, replacement) + replaced[name] = (module, replacement) + + # Log what was replaced + if replaced: + logger.info(f"Replaced {len(replaced)} MoE modules for calibration") + permanent_count = sum( + 1 for _, (_, repl) in replaced.items() if repl.is_permanent ) + if permanent_count > 0: + logger.info( + f"{permanent_count}/{len(replaced)} modules will remain in " + "calibration form (permanent)" + ) + if permanent_count < len(replaced): + logger.info( + f"{len(replaced) - permanent_count}/{len(replaced)} modules will " + "be restored after calibration" + ) - config = MoEModelConfig(**config_dict) - register_moe_model(model_class_name, config) + try: + yield model + finally: + # Step 2: Restore non-permanent modules + for name, (original, replacement) in replaced.items(): + if not replacement.is_permanent: + restored = replacement.restore() + model.set_submodule(name, restored) From 87e44843347695d562bfd73e942982ed2804e5e6 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 15 Oct 2025 00:09:57 +0530 Subject: [PATCH 17/26] Use simplified implementation and update context entrypoint Signed-off-by: Sairam Pillai --- src/llmcompressor/args/dataset_arguments.py | 27 +--- src/llmcompressor/entrypoints/oneshot.py | 23 +-- src/llmcompressor/modeling/moe_context.py | 34 +++-- src/llmcompressor/modeling/prepare.py | 133 ++++++------------ src/llmcompressor/pipelines/basic/pipeline.py | 7 +- .../pipelines/layer_sequential/pipeline.py | 3 - .../pipelines/sequential/pipeline.py | 3 - 7 files changed, 85 insertions(+), 145 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index e3e92fdf4..5bff8b692 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,7 +7,6 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ -import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -172,14 +171,15 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) - calibrate_moe_context: Optional[bool] = field( - default=None, + moe_calibrate_all_experts: bool = field( + default=True, metadata={ "help": ( - "DEPRECATED: This parameter is deprecated and will be \ - removed in a future version. " - "MoE calibration context is now handled automatically by the pipeline. 
" - "This parameter is ignored and will not affect the calibration process." + "Whether to calibrate all experts during MoE model calibration. " + "When True, all experts will see all tokens during calibration, " + "ensuring proper quantization statistics for all experts. " + "When False, only routed experts will be used. " + "Only relevant for MoE models. Default is True." ), }, ) @@ -231,16 +231,3 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None - - def __post_init__(self): - """Post-initialization hook to issue deprecation warnings.""" - if self.calibrate_moe_context is not None: - warnings.warn( - "The 'calibrate_moe_context' parameter is deprecated\ - and will be removed in a future version. " - "MoE calibration context is now handled automatically by the pipeline. " - "This parameter is ignored and will not affect\ - the calibration process.", - DeprecationWarning, - stacklevel=2, - ) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 33537f18e..9e329a928 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -19,6 +19,7 @@ from llmcompressor.core.session_functions import active_session from llmcompressor.datasets import get_calibration_dataloader from llmcompressor.entrypoints.utils import post_process, pre_process +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.pipelines import CalibrationPipeline __all__ = ["Oneshot", "oneshot"] @@ -198,11 +199,16 @@ def apply_recipe_modifiers( user_pipeline = self.dataset_args.pipeline modifiers = session.lifecycle.recipe.modifiers pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) - pipeline( + # Apply MoE calibration context for the entire calibration process + with moe_calibration_context( self.model, - calibration_dataloader, - self.dataset_args, - ) + calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts, + ): + pipeline( + self.model, + calibration_dataloader, + self.dataset_args, + ) session.finalize() @@ -241,7 +247,7 @@ def oneshot( overwrite_cache: bool = False, preprocessing_num_workers: Optional[int] = None, min_tokens_per_module: Optional[float] = None, - calibrate_moe_context: bool = False, + moe_calibrate_all_experts: bool = True, quantization_aware_calibration: bool = True, # Miscellaneous arguments output_dir: Optional[str] = None, @@ -305,9 +311,10 @@ def oneshot( preprocessing. :param min_tokens_per_module: Minimum percentage of tokens per module, relevant for MoE models. - :param calibrate_moe_context: If during calibration, the MoE context should be - enabled for the given model. This usually involves updating all MoE modules - in the model for the duration of calibration. + :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE + model calibration. When True, all experts will see all tokens during + calibration, ensuring proper quantization statistics. When False, only + routed experts will be used. Only relevant for MoE models. Default is True. :param quantization_aware_calibration: Whether to enable quantization-aware calibration in the sequential pipeline. When True, quantization is applied during forward pass in calibration. 
When False, quantization is disabled diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 38d91dcc5..ee2bf065a 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -14,7 +14,7 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Type import torch from loguru import logger @@ -31,10 +31,10 @@ class MoECalibrationModule(ABC, torch.nn.Module): """ Abstract base class for MoE calibration modules. - + Calibration modules replace original MoE modules during the calibration phase to ensure all experts receive data for proper quantization statistics. - + Subclasses must: 1. Implement `from_original()` to create calibration module from original 2. Set `is_permanent` to indicate if module should stay in calibration form @@ -68,9 +68,9 @@ def from_original( def restore(self) -> torch.nn.Module: """ Restore the original module structure. - + Only needed if is_permanent=False. For permanent modules, this is a no-op. - + Returns: The original module (or self if permanent) """ @@ -89,21 +89,19 @@ def restore(self) -> torch.nn.Module: def register_moe_calibration(module_class_name: str): """ Decorator to register a MoE calibration module. - + Usage: @register_moe_calibration("DeepseekV3MoE") class CalibrationDeepseekV3MoE(MoECalibrationModule): ... - + Args: module_class_name: The class name of the original module to replace """ def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]: if not issubclass(cls, MoECalibrationModule): - raise TypeError( - f"{cls.__name__} must inherit from MoECalibrationModule" - ) + raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule") MOE_CALIBRATION_MODULES[module_class_name] = cls return cls @@ -117,19 +115,19 @@ def moe_calibration_context( ): """ Context manager that applies MoE calibration to a model. - + This scans all modules in the model and replaces any MoE modules with their calibration equivalents. After the context exits, non-permanent modules are restored to their original form. - + + The model is modified in-place, so the same model object should be used + within the context. + Args: - model: The model to apply MoE calibration to + model: The model to apply MoE calibration to (modified in-place) calibrate_all_experts: If True, all experts see all tokens during calibration. If False, use normal routing (useful for some techniques) - - Yields: - The model with MoE calibration applied - + Example: with moe_calibration_context(model): # Run calibration - all experts will see data @@ -170,7 +168,7 @@ def moe_calibration_context( ) try: - yield model + yield finally: # Step 2: Restore non-permanent modules for name, (original, replacement) in replaced.items(): diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index d3c985535..d500e8a7b 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,21 +1,36 @@ -import contextlib -import warnings +""" +MoE model preparation - imports and registration. + +This module imports all MoE calibration modules to ensure they are registered +in the MOE_CALIBRATION_MODULES registry. The actual calibration logic is in +moe_context.py. 
+""" import tqdm -from compressed_tensors.utils import replace_module +from compressed_tensors.utils import deprecated, replace_module from transformers import PreTrainedModel -from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 -from llmcompressor.modeling.llama4 import replace as replace_llama4 -from llmcompressor.modeling.moe_context import ( - MoECalibrationType, - MoEModelConfig, - get_moe_context, - register_moe_model, +# Import MoE calibration modules to trigger registration +from llmcompressor.modeling.deepseek_v3 import ( # noqa: F401 + CalibrationDeepseekV3MoE, +) +from llmcompressor.modeling.deepseek_v3 import ( + replace as replace_deepseekv3, +) +from llmcompressor.modeling.llama4 import ( # noqa: F401 + SequentialLlama4TextMoe, +) +from llmcompressor.modeling.llama4 import ( + replace as replace_llama4, +) +from llmcompressor.modeling.moe_context import ( # noqa: F401 + moe_calibration_context, +) +from llmcompressor.modeling.qwen3_moe import ( # noqa: F401 + CalibrationQwen3MoeSparseMoeBlock, ) -from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -__all__ = ["moe_calibration_context"] +__all__ = ["moe_calibration_context", "replace_modules_for_calibration"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,18 +39,30 @@ } +@deprecated( + message=( + "The function `replace_modules_for_calibration` has been deprecated. " + "Please use `moe_calibration_context` instead. " + ) +) def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: - # This function is deprecated. Use moe_calibration_context instead. - warnings.warn( - "replace_modules_for_calibration is deprecated. " - "Use moe_calibration_context instead.", - DeprecationWarning, - stacklevel=2, - ) + """ + Deprecated function for backward compatibility. + + Use moe_calibration_context instead: + with moe_calibration_context(model, calibrate_all_experts): + # your code here + + Args: + model: The model to modify + calibrate_all_experts: Whether to calibrate all experts + Returns: + The modified model + """ for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: @@ -47,71 +74,3 @@ def replace_modules_for_calibration( replace_module(model, name, new_module) return model - - -# ------------------- module replacements; during calibration -------------------- - -# MoE model configurations - centralized registry -# Adding a new MoE model is now as simple as adding an entry here! 
-# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ -MOE_EXPERTS_REPLACEMENTS = { - "Qwen3MoeForCausalLM": MoEModelConfig( - calibration_type=MoECalibrationType.CONTEXTUAL, - target_class_name="Qwen3MoeDecoderLayer", - target_attribute="mlp", - replace_function=replace_Qwen3MoE, - description="Qwen3 MoE model with contextual calibration for MLP layers", - ), - "DeepseekV3ForCausalLM": MoEModelConfig( - calibration_type=MoECalibrationType.PERMANENT, - target_class_name="DeepseekV3MoE", - replace_function=replace_deepseekv3, - description="DeepSeek V3 MoE model with permanent calibration", - ), - "Llama4ForConditionalGeneration": MoEModelConfig( - calibration_type=MoECalibrationType.PERMANENT, - target_class_name="Llama4TextMoe", - replace_function=replace_llama4, - description=( - "Llama4 MoE model with permanent calibration for vLLM compatibility" - ), - ), -} - - -# Register all MoE models automatically -for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): - register_moe_model(model_class_name, config) - - -@contextlib.contextmanager -def moe_calibration_context( - model: PreTrainedModel, - calibrate_all_experts: bool = True, -): - """ - Context manager for MoE calibration that temporarily updates MoE modules. - - Args: - model: The model to apply MoE calibration to - calibrate_all_experts: Whether to calibrate all experts or only routed ones - - Yields: - The model with MoE calibration applied - """ - cls_name = model.__class__.__name__ - moe_context = get_moe_context(cls_name) - - if moe_context is None: - # No MoE context registered for this model, yield unchanged - yield model - return - - # Apply MoE calibration - moe_context.apply(model, calibrate_all_experts) - - try: - yield model - finally: - # Restore original state - moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index edcb46b09..605358ae9 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,4 +1,3 @@ -import contextlib from typing import TYPE_CHECKING, Union import torch @@ -7,7 +6,6 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device @@ -44,10 +42,7 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with contextlib.ExitStack() as stack: - stack.enter_context(calibration_forward_context(model)) - stack.enter_context(moe_calibration_context(model)) - + with calibration_forward_context(model): for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index bab732f48..244edde87 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -7,7 +7,6 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks, active_session -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from 
llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.layer_sequential.helpers import ( @@ -82,8 +81,6 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - stack.enter_context(moe_calibration_context(model)) - # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( model, layers[0], dataloader, model_device diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 9827052b7..261afd654 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -7,7 +7,6 @@ from tqdm import tqdm from llmcompressor.core import LifecycleCallbacks, active_session -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.registry import CalibrationPipeline @@ -85,8 +84,6 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - stack.enter_context(moe_calibration_context(model)) - # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) From ebafc53fc76deb909fcf4045de095062436aaf6a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Sat, 18 Oct 2025 04:43:30 +0530 Subject: [PATCH 18/26] Make module replacement verbose and explicit Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/moe_context.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index ee2bf065a..3e682e91e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -18,6 +18,7 @@ import torch from loguru import logger +from tqdm import tqdm from transformers import PreTrainedModel __all__ = [ @@ -137,10 +138,20 @@ def moe_calibration_context( """ replaced = {} - # Step 1: Find and replace MoE modules + # Step 1: Collect all MoE modules that need replacement + logger.info("Entering MoE calibration context") + modules_to_replace = [] for name, module in model.named_modules(): class_name = module.__class__.__name__ if class_name in MOE_CALIBRATION_MODULES: + modules_to_replace.append((name, module, class_name)) + + # Step 2: Replace modules with progress bar + if modules_to_replace: + logger.info(f"Found {len(modules_to_replace)} MoE modules to replace") + for name, module, class_name in tqdm( + modules_to_replace, desc="Replacing MoE modules for calibration" + ): calibration_cls = MOE_CALIBRATION_MODULES[class_name] replacement = calibration_cls.from_original( module, From c451635e937686438e3cd3260b64370a38dc3ba4 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Sat, 18 Oct 2025 04:54:44 +0530 Subject: [PATCH 19/26] Update modeling and test files with latest moe_context signature Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/deepseek_v3.py | 37 +++++++++++++-- src/llmcompressor/modeling/llama4.py | 47 +++++++++++++++++-- src/llmcompressor/modeling/qwen3_moe.py | 41 ++++++++++++++-- src/llmcompressor/pipelines/basic/pipeline.py | 5 +- .../modeling/test_calib_deepseek_v3.py | 22 +++++---- .../modeling/test_calib_llama4.py | 18 +++---- .../modeling/test_calib_qwen3.py | 22 +++++---- 7 files changed, 153 
insertions(+), 39 deletions(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 8cc7f4727..64f22af3b 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -4,12 +4,20 @@ DeepseekV3MoE as OriginalDeepseekV3MoE, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) + -class DeepseekV3MoECalibrate(torch.nn.Module): +@register_moe_calibration("DeepseekV3MoE") +class CalibrationDeepseekV3MoE(MoECalibrationModule): """ - Patched DeepseekV3MoE which sends all tokens to all experts for calibration + Calibration version of DeepseekV3MoE that sends all tokens to all experts. """ + is_permanent = True + def __init__( self, config: DeepseekV3Config, @@ -23,6 +31,20 @@ def __init__( self.shared_experts = original.shared_experts self.calibrate_all_experts = calibrate_all_experts + @classmethod + def from_original( + cls, + original: OriginalDeepseekV3MoE, + config: DeepseekV3Config, + calibrate_all_experts: bool = True, + ) -> "CalibrationDeepseekV3MoE": + """Create calibration module from original DeepseekV3MoE.""" + return cls( + config=config, + original=original, + calibrate_all_experts=calibrate_all_experts, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape @@ -65,11 +87,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +# Legacy function for backward compatibility def replace( config: DeepseekV3Config, module: OriginalDeepseekV3MoE, calibrate_all_experts: bool, ): - return DeepseekV3MoECalibrate( - config=config, original=module, calibrate_all_experts=calibrate_all_experts + """ + Legacy replacement function. + Use CalibrationDeepseekV3MoE.from_original() instead. + """ + return CalibrationDeepseekV3MoE.from_original( + original=module, + config=config, + calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 10f4f2033..7aa82bcad 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -11,10 +11,26 @@ Llama4TextMoe, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) from llmcompressor.utils.dev import skip_weights_initialize -class SequentialLlama4TextMoe(torch.nn.Module): +@register_moe_calibration("Llama4TextMoe") +class SequentialLlama4TextMoe(MoECalibrationModule): + """ + Calibration version of Llama4TextMoe that unpacks experts for sequential processing. + + This module: + 1. Unpacks the packed expert weights (3D -> 2D) for calibration + 2. Optionally sends all tokens to all experts during calibration + 3. 
Stays in unpacked form (permanent) for vLLM compatibility + """ + + is_permanent = True + def __init__( self, config: Llama4TextConfig, @@ -31,7 +47,25 @@ def __init__( self.shared_expert = original.shared_expert self.calibrate_all_experts = calibrate_all_experts - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]: + @classmethod + def from_original( + cls, + original: Llama4TextMoe, + config: Llama4Config, + calibrate_all_experts: bool = True, + ) -> "SequentialLlama4TextMoe": + """Create calibration module from original Llama4TextMoe.""" + # Extract text config from multimodal config if needed + text_config = ( + config.get_text_config() if hasattr(config, "get_text_config") else config + ) + return cls( + config=text_config, + original=original, + calibrate_all_experts=calibrate_all_experts, + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_outputs = self.router(hidden_states) @@ -88,9 +122,14 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): self[i].down_proj.weight.data = down.t().clone().contiguous() +# Legacy function for backward compatibility def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool): - return SequentialLlama4TextMoe( - config=config.get_text_config(), + """ + Legacy replacement function. + Use SequentialLlama4TextMoe.from_original() instead. + """ + return SequentialLlama4TextMoe.from_original( original=module, + config=config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 2b451bc49..fe9b7a5c6 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -20,8 +20,20 @@ Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) + + +@register_moe_calibration("Qwen3MoeSparseMoeBlock") +class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule): + """ + Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts. + """ + + is_permanent = False -class Qwen3MoeSparseMoeBlock(torch.nn.Module): def __init__( self, config: Qwen3MoeConfig, @@ -37,7 +49,21 @@ def __init__( self.gate = original.gate self.experts = original.experts - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + @classmethod + def from_original( + cls, + original: OriginalQwen3MoeSparseMoeBlock, + config: Qwen3MoeConfig, + calibrate_all_experts: bool = True, + ) -> "CalibrationQwen3MoeSparseMoeBlock": + """Create calibration module from original Qwen3MoeSparseMoeBlock.""" + return cls( + config=config, + original=original, + calibrate_all_experts=calibrate_all_experts, + ) + + def forward(self, hidden_states: torch.Tensor): batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) @@ -87,11 +113,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits +# Legacy function for backward compatibility def replace( config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock, calibrate_all_experts: bool, ): - return Qwen3MoeSparseMoeBlock( - config=config, original=module, calibrate_all_experts=calibrate_all_experts + """ + Legacy replacement function. 
+ Use CalibrationQwen3MoeSparseMoeBlock.from_original() instead. + """ + return CalibrationQwen3MoeSparseMoeBlock.from_original( + original=module, + config=config, + calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 605358ae9..b986e6963 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,3 +1,4 @@ +import contextlib from typing import TYPE_CHECKING, Union import torch @@ -42,7 +43,9 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with calibration_forward_context(model): + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index 239365fdd..95f74139f 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -4,13 +4,13 @@ import pytest import torch from transformers import AutoModelForCausalLM - -from llmcompressor.modeling.deepseek_v3 import ( - DeepseekV3Config, - DeepseekV3MoECalibrate, - OriginalDeepseekV3MoE, +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config +from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3MoE as OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import moe_calibration_context + +from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -29,7 +29,7 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): # Find a Deepseek MoE layer moe_layer = None for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): + if isinstance(module, CalibrationDeepseekV3MoE): moe_layer = module break @@ -75,12 +75,16 @@ def test_calib_deepseekv3_module(): with calibration_forward_context(original): true_output = original(sample)[0] - module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=True) + module = CalibrationDeepseekV3MoE.from_original( + original, config, calibrate_all_experts=True + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) - module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=False) + module = CalibrationDeepseekV3MoE.from_original( + original, config, calibrate_all_experts=False + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index d5363d35c..1b37fa2a9 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -5,13 +5,11 @@ import pytest import torch from transformers import Llama4ForConditionalGeneration +from transformers.models.llama4.configuration_llama4 import Llama4TextConfig +from transformers.models.llama4.modeling_llama4 import Llama4TextMoe 
-from llmcompressor.modeling.llama4 import ( - Llama4TextConfig, - Llama4TextMoe, - SequentialLlama4TextMoe, -) -from llmcompressor.modeling.prepare import moe_calibration_context +from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -82,12 +80,16 @@ def test_calib_llama4_module(): with calibration_forward_context(original): true_output = original(sample)[0] - module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=True) + module = SequentialLlama4TextMoe.from_original( + original, config, calibrate_all_experts=True + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) - module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=False) + module = SequentialLlama4TextMoe.from_original( + original, config, calibrate_all_experts=False + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index a20acf8a8..f3879ad45 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -4,13 +4,13 @@ import pytest import torch from transformers import AutoModelForCausalLM - -from llmcompressor.modeling.prepare import moe_calibration_context -from llmcompressor.modeling.qwen3_moe import ( - OriginalQwen3MoeSparseMoeBlock, - Qwen3MoeConfig, - Qwen3MoeSparseMoeBlock, +from transformers.models import Qwen3MoeConfig +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, ) + +from llmcompressor.modeling.moe_context import moe_calibration_context +from llmcompressor.modeling.qwen3_moe import CalibrationQwen3MoeSparseMoeBlock from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -31,7 +31,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): # Find one MoE layer moe_layer = None for name, module in model.named_modules(): - if isinstance(module, Qwen3MoeSparseMoeBlock): + if isinstance(module, CalibrationQwen3MoeSparseMoeBlock): moe_layer = module break @@ -77,12 +77,16 @@ def test_calib_qwen3_moe_module(): with calibration_forward_context(original): true_output = original(sample)[0] - module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=True) + module = CalibrationQwen3MoeSparseMoeBlock.from_original( + original, config, calibrate_all_experts=True + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) - module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=False) + module = CalibrationQwen3MoeSparseMoeBlock.from_original( + original, config, calibrate_all_experts=False + ) with calibration_forward_context(module): output = module(sample)[0] assert torch.allclose(true_output, output, atol=1e-6) From f271a51f07be7941ec9d20879960257648677008 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Sat, 18 Oct 2025 04:58:21 +0530 Subject: [PATCH 20/26] Update examples 
of models where moe_context was added Signed-off-by: Sairam Pillai --- examples/multimodal_vision/llama4_example.py | 11 +++++------ examples/quantization_w4a4_fp4/qwen_30b_a3b.py | 13 ++++++++----- .../llama4_fp8_block_example.py | 5 +++-- examples/quantizing_moe/deepseek_r1_example.py | 2 +- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index c7838bfbf..90b3cbfab 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -13,12 +13,11 @@ # The `SequentialLlama4TextMoe` modules will be applied during calibration # to enable proper expert calibration and vLLM compatibility. # -# NOTE: This restructuring is specifically required for vLLM compatibility -# Users can customize the calibration behavior as needed by modifying the -# To define custom calibration logic, implement your function in -# modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). -# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your -# custom function. +# NOTE: This restructuring is specifically required for vLLM compatibility. +# To define custom calibration logic, create a new calibration module in +# modeling/llama4.py that inherits from `MoECalibrationModule`, and register +# it using the `@register_moe_calibration` decorator with the appropriate +# module class name (e.g., "Llama4TextMoe"). DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py index bfda23218..60d2c802d 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py @@ -59,18 +59,21 @@ def tokenize(sample): ) # Apply quantization. -# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock` -# during calibration. +# MoE calibration is now handled automatically by the pipeline. +# We set `moe_calibrate_all_experts` to True to ensure all experts receive +# calibration data. This temporarily updates the model definition to use +# `CalibrationQwen3MoeSparseMoeBlock` which updates how the forward pass is +# handled in the MoE block during calibration. # Feel free to update the definition under -# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with -# this behaviour and evaluate its impact on quantization performance +# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py to play around with +# this behavior and evaluate its impact on quantization performance. oneshot( model=model, dataset=ds, recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - calibrate_moe_context=True, + moe_calibrate_all_experts=True, ) diff --git a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py index d28e02716..2b12e2f06 100644 --- a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py @@ -1,7 +1,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.utils import dispatch_for_generation @@ -10,7 +9,9 @@ # Load model. 
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. # Configure the quantization algorithm and scheme. # In this case, we: # * quantize the weights to fp8 with block size 128 via ptq diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9e5d1ca63..6a68b474c 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -20,7 +20,7 @@ ) tokenizer = AutoTokenizer.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. -# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# The `CalibrationDeepseekV3MoE` modules will be applied during calibration # to enable proper expert calibration. # Select calibration dataset. From 32546bb1ac17a5cb388990a4e26a2a4e8b093c59 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Mon, 27 Oct 2025 19:27:27 +0530 Subject: [PATCH 21/26] Simplify calibrate fallback calls Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/deepseek_v3.py | 4 ++-- src/llmcompressor/modeling/llama4.py | 4 ++-- src/llmcompressor/modeling/qwen3_moe.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 64f22af3b..62ab551bb 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -95,9 +95,9 @@ def replace( ): """ Legacy replacement function. - Use CalibrationDeepseekV3MoE.from_original() instead. + Use CalibrationDeepseekV3MoE instead. """ - return CalibrationDeepseekV3MoE.from_original( + return CalibrationDeepseekV3MoE( original=module, config=config, calibrate_all_experts=calibrate_all_experts, diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 7aa82bcad..f5173faf3 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -126,9 +126,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool): """ Legacy replacement function. - Use SequentialLlama4TextMoe.from_original() instead. + Use SequentialLlama4TextMoe instead. """ - return SequentialLlama4TextMoe.from_original( + return SequentialLlama4TextMoe( original=module, config=config, calibrate_all_experts=calibrate_all_experts, diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index fe9b7a5c6..4954652bf 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -121,9 +121,9 @@ def replace( ): """ Legacy replacement function. - Use CalibrationQwen3MoeSparseMoeBlock.from_original() instead. + Use CalibrationQwen3MoeSparseMoeBlock instead. 
""" - return CalibrationQwen3MoeSparseMoeBlock.from_original( + return CalibrationQwen3MoeSparseMoeBlock( original=module, config=config, calibrate_all_experts=calibrate_all_experts, From 92dd5fdda3db03c4b3acfc16369af0a99a337f56 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 29 Oct 2025 18:53:14 +0530 Subject: [PATCH 22/26] Update docstrings with class location Signed-off-by: Sairam Pillai --- examples/multimodal_vision/llama4_example.py | 6 ++++-- examples/quantization_w4a4_fp4/llama4_example.py | 6 ++++-- examples/quantization_w4a4_fp4/qwen_30b_a3b.py | 6 ++++-- examples/quantization_w8a8_fp8/llama4_fp8_block_example.py | 6 ++++-- examples/quantizing_moe/deepseek_r1_example.py | 6 ++++-- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index b8d3104d5..c71c06ab7 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -10,8 +10,10 @@ model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. -# The `SequentialLlama4TextMoe` modules will be applied during calibration -# to enable proper expert calibration and vLLM compatibility. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. # # NOTE: This restructuring is specifically required for vLLM compatibility. # To define custom calibration logic, create a new calibration module in diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 2c23c67ef..597dcb301 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -10,8 +10,10 @@ model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. -# The `SequentialLlama4TextMoe` modules will be applied during calibration -# to enable proper expert calibration and vLLM compatibility. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py index 60d2c802d..bbbea7dc9 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py @@ -62,8 +62,10 @@ def tokenize(sample): # MoE calibration is now handled automatically by the pipeline. # We set `moe_calibrate_all_experts` to True to ensure all experts receive # calibration data. This temporarily updates the model definition to use -# `CalibrationQwen3MoeSparseMoeBlock` which updates how the forward pass is -# handled in the MoE block during calibration. 
+# `CalibrationQwen3MoeSparseMoeBlock` (from `llmcompressor.modeling.qwen3_moe`) +# which replaces the original `Qwen3MoeSparseMoeBlock` class from +# `transformers.models.qwen3_moe.modeling_qwen3_moe`. This updates how the +# forward pass is handled in the MoE block during calibration. # Feel free to update the definition under # llm-compressor/src/llmcompressor/modeling/qwen3_moe.py to play around with # this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py index 5554d374c..c2d0d340c 100644 --- a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py @@ -10,8 +10,10 @@ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # MoE calibration is now handled automatically by the pipeline. -# The `SequentialLlama4TextMoe` modules will be applied during calibration -# to enable proper expert calibration and vLLM compatibility. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. # Configure the quantization algorithm and scheme. # In this case, we: # * quantize the weights to fp8 with block size 128 via ptq diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 6a68b474c..70bb1f049 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -20,8 +20,10 @@ ) tokenizer = AutoTokenizer.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. -# The `CalibrationDeepseekV3MoE` modules will be applied during calibration -# to enable proper expert calibration. +# The `CalibrationDeepseekV3MoE` modules (from `llmcompressor.modeling.deepseek_v3`) +# will be applied during calibration to enable proper expert calibration. +# These replace the original `DeepseekV3MoE` class from +# `transformers.models.deepseek_v3.modeling_deepseek_v3`. # Select calibration dataset. DATASET_ID = "HuggingFaceH4/ultrachat_200k" From 40185f314a50ec957fdba9bf61159e6745cbf202 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 29 Oct 2025 19:10:07 +0530 Subject: [PATCH 23/26] Fix linting Signed-off-by: Sairam Pillai --- examples/multimodal_vision/llama4_example.py | 3 ++- examples/quantization_w4a4_fp4/llama4_example.py | 3 ++- examples/quantization_w8a8_fp8/llama4_fp8_block_example.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index c71c06ab7..53b98621f 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -11,7 +11,8 @@ processor = Llama4Processor.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) -# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# will be applied during calibration to enable proper +# expert calibration and vLLM compatibility. 
# These replace the original `Llama4TextMoe` class from # `transformers.models.llama4.modeling_llama4`. # diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 597dcb301..de35bfa2f 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -11,7 +11,8 @@ processor = Llama4Processor.from_pretrained(model_id) # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) -# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# will be applied during calibration to enable +# proper expert calibration and vLLM compatibility. # These replace the original `Llama4TextMoe` class from # `transformers.models.llama4.modeling_llama4`. diff --git a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py index c2d0d340c..f7974aebe 100644 --- a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py @@ -11,7 +11,8 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) -# will be applied during calibration to enable proper expert calibration and vLLM compatibility. +# will be applied during calibration to enable +# proper expert calibration and vLLM compatibility. # These replace the original `Llama4TextMoe` class from # `transformers.models.llama4.modeling_llama4`. # Configure the quantization algorithm and scheme. From 166606bd9ce1a302657f8dc773bc4c3173afaa9a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 29 Oct 2025 19:26:41 +0530 Subject: [PATCH 24/26] Refactor to remove from_original() and use init Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/deepseek_v3.py | 22 ++--------- src/llmcompressor/modeling/llama4.py | 38 ++++++------------- src/llmcompressor/modeling/moe_context.py | 33 +++------------- src/llmcompressor/modeling/qwen3_moe.py | 22 ++--------- .../modeling/test_calib_deepseek_v3.py | 8 +--- .../modeling/test_calib_llama4.py | 8 +--- .../modeling/test_calib_qwen3.py | 4 +- 7 files changed, 32 insertions(+), 103 deletions(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 62ab551bb..c2dd8f4b6 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -20,9 +20,9 @@ class CalibrationDeepseekV3MoE(MoECalibrationModule): def __init__( self, - config: DeepseekV3Config, original: OriginalDeepseekV3MoE, - calibrate_all_experts: bool, + config: DeepseekV3Config, + calibrate_all_experts: bool = True, ): super().__init__() self.config = config @@ -31,20 +31,6 @@ def __init__( self.shared_experts = original.shared_experts self.calibrate_all_experts = calibrate_all_experts - @classmethod - def from_original( - cls, - original: OriginalDeepseekV3MoE, - config: DeepseekV3Config, - calibrate_all_experts: bool = True, - ) -> "CalibrationDeepseekV3MoE": - """Create calibration module from original DeepseekV3MoE.""" - return cls( - config=config, - original=original, - calibrate_all_experts=calibrate_all_experts, - ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape @@ -98,7 +84,7 @@ def 
replace( Use CalibrationDeepseekV3MoE instead. """ return CalibrationDeepseekV3MoE( - original=module, - config=config, + module, + config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 0b3b20ac0..2b49a652a 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -33,37 +33,23 @@ class SequentialLlama4TextMoe(MoECalibrationModule): def __init__( self, - config: Llama4TextConfig, - original: Llama4TextMoe, - calibrate_all_experts: bool, - ): - super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts - - self.experts = SequentialLlama4TextExperts(config, original.experts) - self.router = original.router - self.shared_expert = original.shared_expert - self.calibrate_all_experts = calibrate_all_experts - - @classmethod - def from_original( - cls, original: Llama4TextMoe, config: Llama4Config, calibrate_all_experts: bool = True, - ) -> "SequentialLlama4TextMoe": - """Create calibration module from original Llama4TextMoe.""" + ): + super().__init__() # Extract text config from multimodal config if needed text_config = ( config.get_text_config() if hasattr(config, "get_text_config") else config ) - return cls( - config=text_config, - original=original, - calibrate_all_experts=calibrate_all_experts, - ) + self.top_k = text_config.num_experts_per_tok + self.hidden_dim = text_config.hidden_size + self.num_experts = text_config.num_local_experts + + self.experts = SequentialLlama4TextExperts(text_config, original.experts) + self.router = original.router + self.shared_expert = original.shared_expert + self.calibrate_all_experts = calibrate_all_experts def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -115,7 +101,7 @@ def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: Use SequentialLlama4TextMoe instead. """ return SequentialLlama4TextMoe( - original=module, - config=config, + module, + config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 3e682e91e..35c4470b7 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -13,8 +13,8 @@ """ import contextlib -from abc import ABC, abstractmethod -from typing import Any, Dict, Type +from abc import ABC +from typing import Dict, Type import torch from loguru import logger @@ -37,35 +37,14 @@ class MoECalibrationModule(ABC, torch.nn.Module): phase to ensure all experts receive data for proper quantization statistics. Subclasses must: - 1. Implement `from_original()` to create calibration module from original + 1. Implement `__init__()` with signature: + (self, original, config, calibrate_all_experts=True) 2. Set `is_permanent` to indicate if module should stay in calibration form 3. Optionally implement `restore()` if is_permanent=False """ is_permanent: bool = False - @classmethod - @abstractmethod - def from_original( - cls, - original: torch.nn.Module, - config: Any, - calibrate_all_experts: bool = True, - ) -> "MoECalibrationModule": - """ - Create a calibration module from the original MoE module. - - Args: - original: The original MoE module to convert - config: Model configuration (contains num_experts, etc.) 
- calibrate_all_experts: If True, send all tokens to all experts. - If False, use normal routing. - - Returns: - Instance of the calibration module - """ - pass - def restore(self) -> torch.nn.Module: """ Restore the original module structure. @@ -153,10 +132,10 @@ def moe_calibration_context( modules_to_replace, desc="Replacing MoE modules for calibration" ): calibration_cls = MOE_CALIBRATION_MODULES[class_name] - replacement = calibration_cls.from_original( + replacement = calibration_cls( module, model.config, - calibrate_all_experts, + calibrate_all_experts=calibrate_all_experts, ) model.set_submodule(name, replacement) replaced[name] = (module, replacement) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 4954652bf..49c3fa874 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -36,9 +36,9 @@ class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule): def __init__( self, - config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock, - calibrate_all_experts: bool, + config: Qwen3MoeConfig, + calibrate_all_experts: bool = True, ): super().__init__() self.num_experts = config.num_experts @@ -49,20 +49,6 @@ def __init__( self.gate = original.gate self.experts = original.experts - @classmethod - def from_original( - cls, - original: OriginalQwen3MoeSparseMoeBlock, - config: Qwen3MoeConfig, - calibrate_all_experts: bool = True, - ) -> "CalibrationQwen3MoeSparseMoeBlock": - """Create calibration module from original Qwen3MoeSparseMoeBlock.""" - return cls( - config=config, - original=original, - calibrate_all_experts=calibrate_all_experts, - ) - def forward(self, hidden_states: torch.Tensor): batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -124,7 +110,7 @@ def replace( Use CalibrationQwen3MoeSparseMoeBlock instead. 
""" return CalibrationQwen3MoeSparseMoeBlock( - original=module, - config=config, + module, + config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index 58414b122..e4e15d300 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -75,16 +75,12 @@ def test_calib_deepseekv3_module(): with calibration_forward_context(original): true_output = original(sample) - module = CalibrationDeepseekV3MoE.from_original( - original, config, calibrate_all_experts=True - ) + module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=True) with calibration_forward_context(module): output = module(sample) assert torch.allclose(true_output, output, atol=1e-6) - module = CalibrationDeepseekV3MoE.from_original( - original, config, calibrate_all_experts=False - ) + module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=False) with calibration_forward_context(module): output = module(sample) assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 8940c985f..78fb4ee6d 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -82,17 +82,13 @@ def test_calib_llama4_module(): with calibration_forward_context(original): true_out, true_router_logits = original(sample) - module = SequentialLlama4TextMoe.from_original( - original, config, calibrate_all_experts=True - ) + module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=True) with calibration_forward_context(module): out, router_logits = module(sample) assert torch.nn.functional.mse_loss(true_out, out) < 1e-10 assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10 - module = SequentialLlama4TextMoe.from_original( - original, config, calibrate_all_experts=False - ) + module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=False) with calibration_forward_context(module): out, router_logits = module(sample) assert torch.nn.functional.mse_loss(true_out, out) < 1e-10 diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 678c20d76..d6e54776d 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -77,7 +77,7 @@ def test_calib_qwen3_moe_module(): with calibration_forward_context(original): true_output = original(sample) - module = CalibrationQwen3MoeSparseMoeBlock.from_original( + module = CalibrationQwen3MoeSparseMoeBlock( original, config, calibrate_all_experts=True ) with calibration_forward_context(module): @@ -85,7 +85,7 @@ def test_calib_qwen3_moe_module(): assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10 assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10 - module = CalibrationQwen3MoeSparseMoeBlock.from_original( + module = CalibrationQwen3MoeSparseMoeBlock( original, config, calibrate_all_experts=False ) with calibration_forward_context(module): From 73722eb15ac0b3585304398856e2db12d2e49a48 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 29 Oct 2025 19:28:08 +0530 Subject: [PATCH 25/26] Fix prepare.py Signed-off-by: Sairam Pillai --- src/llmcompressor/modeling/prepare.py | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index defd7a0c7..e93f1278d 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -29,6 +29,9 @@ from llmcompressor.modeling.qwen3_moe import ( # noqa: F401 CalibrationQwen3MoeSparseMoeBlock, ) +from llmcompressor.modeling.qwen3_vl_moe import ( + replace as replace_Qwen3VLMoE, +) __all__ = ["moe_calibration_context", "replace_modules_for_calibration"] From 0348c56d6600707cb34e136f5d77e381d7ad925a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 29 Oct 2025 19:38:47 +0530 Subject: [PATCH 26/26] Fix linting error Signed-off-by: Sairam Pillai --- src/llmcompressor/pipelines/basic/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index b986e6963..e6494fc5e 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -45,7 +45,6 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device)
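
Taken together, the patches in this series reduce MoE calibration to one extension point: subclass `MoECalibrationModule`, implement `__init__` with the `(original, config, calibrate_all_experts=True)` signature, and register the class under the class name of the original MoE block so that `moe_calibration_context` can swap it in during calibration. The sketch below illustrates that pattern under stated assumptions: `ToySparseMoeBlock`, its gate/experts internals, and the import location of `register_moe_calibration` are invented for illustration; only the required `__init__` signature, the `calibrate_all_experts` behavior, and the registry keyed by the original class name are taken from the patches above.

import torch

# Import path is an assumption for this sketch; the registry and context live in
# src/llmcompressor/modeling/moe_context.py in the patches above.
from llmcompressor.modeling.moe_context import (
    MoECalibrationModule,
    moe_calibration_context,
    register_moe_calibration,
)


@register_moe_calibration("ToySparseMoeBlock")  # keyed by the original class name
class CalibrationToySparseMoeBlock(MoECalibrationModule):
    """Calibration-time stand-in for a hypothetical `ToySparseMoeBlock`."""

    is_permanent = False  # swapped back out when the calibration context exits

    def __init__(self, original, config, calibrate_all_experts=True):
        # Required signature: (original, config, calibrate_all_experts=True)
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.gate = original.gate        # assumed Linear(hidden_dim, num_experts)
        self.experts = original.experts  # assumed ModuleList of expert MLPs
        self.calibrate_all_experts = calibrate_all_experts
        self._original = original

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, hidden_dim = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_dim)
        routing_weights = torch.softmax(self.gate(flat), dim=-1)
        topk_weights, topk_ids = torch.topk(routing_weights, self.top_k, dim=-1)

        out = torch.zeros_like(flat)
        for expert_idx, expert in enumerate(self.experts):
            routed = (topk_ids == expert_idx).any(dim=-1)
            # Router weight per token for this expert; zero if the token is not routed here.
            weight = (topk_weights * (topk_ids == expert_idx)).sum(-1, keepdim=True)

            if self.calibrate_all_experts:
                # Every token passes through the expert so it accumulates
                # calibration statistics; unrouted tokens carry zero weight and
                # therefore do not change the output.
                out += expert(flat) * weight
            elif routed.any():
                out[routed] += expert(flat[routed]) * weight[routed]

        return out.reshape(batch, seq_len, hidden_dim)

    def restore(self) -> torch.nn.Module:
        # Only relevant when is_permanent is False.
        return self._original

With the module registered, the calibration pipeline performs the swap automatically; the same swap can be requested explicitly through the context manager defined in moe_context.py, for example:

    with moe_calibration_context(model, calibrate_all_experts=True):
        ...  # run calibration forward passes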