diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index d93df605ae..53b98621f3 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,24 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable proper +# expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. +# +# NOTE: This restructuring is specifically required for vLLM compatibility. +# To define custom calibration logic, create a new calibration module in +# modeling/llama4.py that inherits from `MoECalibrationModule`, and register +# it using the `@register_moe_calibration` decorator with the appropriate +# module class name (e.g., "Llama4TextMoe"). DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37c..a0d458722c 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +MoE calibration is now handled automatically by the pipeline, so no additional steps are required to quantize MoEs. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline applies the appropriate MoE calibration context, which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. Ensures experts are quantized correctly, as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. 
However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration, which temporarily updates the model definition to use `CalibrationQwen3MoeSparseMoeBlock` and changes how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index ec476f1f0e..de35bfa2fb 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,18 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable +# proper expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py index bfda232181..bbbea7dc9c 100644 --- a/examples/quantization_w4a4_fp4/qwen_30b_a3b.py +++ b/examples/quantization_w4a4_fp4/qwen_30b_a3b.py @@ -59,18 +59,23 @@ def tokenize(sample): ) # Apply quantization. -# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock` -# during calibration. +# MoE calibration is now handled automatically by the pipeline. +# We set `moe_calibrate_all_experts` to True to ensure all experts receive +# calibration data. During calibration, the pipeline temporarily updates the +# model definition to use `CalibrationQwen3MoeSparseMoeBlock` (from +# `llmcompressor.modeling.qwen3_moe`), which replaces the original +# `Qwen3MoeSparseMoeBlock` class from +# `transformers.models.qwen3_moe.modeling_qwen3_moe` and updates how the +# forward pass is handled in the MoE block during calibration. 
# Feel free to update the definition under -# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with -# this behaviour and evaluate its impact on quantization performance +# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py to play around with +# this behavior and evaluate its impact on quantization performance. oneshot( model=model, dataset=ds, recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - calibrate_moe_context=True, + moe_calibrate_all_experts=True, ) diff --git a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py index 6d34d8e98e..f7974aebeb 100644 --- a/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/llama4_fp8_block_example.py @@ -1,7 +1,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.utils import dispatch_for_generation @@ -10,7 +9,12 @@ # Load model. model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`) +# will be applied during calibration to enable +# proper expert calibration and vLLM compatibility. +# These replace the original `Llama4TextMoe` class from +# `transformers.models.llama4.modeling_llama4`. # Configure the quantization algorithm and scheme. # In this case, we: # * quantize the weights to fp8 with block size 128 via ptq diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c32..70bb1f0499 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,11 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `CalibrationDeepseekV3MoE` modules (from `llmcompressor.modeling.deepseek_v3`) +# will be applied during calibration to enable proper expert calibration. +# These replace the original `DeepseekV3MoE` class from +# `transformers.models.deepseek_v3.modeling_deepseek_v3`. # Select calibration dataset. DATASET_ID = "HuggingFaceH4/ultrachat_200k" diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 37335293b9..f3b4c3c565 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -126,16 +126,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. 
This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: bool | None = field( default=True, metadata={ @@ -181,6 +171,18 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + moe_calibrate_all_experts: bool = field( + default=True, + metadata={ + "help": ( + "Whether to calibrate all experts during MoE model calibration. " + "When True, all experts will see all tokens during calibration, " + "ensuring proper quantization statistics for all experts. " + "When False, only routed experts will be used. " + "Only relevant for MoE models. Default is True." + ), + }, + ) # --- pipeline arguments --- # pipeline: str | None = field( default="independent", diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index b6e1e2b633..ba3f0e1bbf 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -20,6 +20,7 @@ from llmcompressor.core.session_functions import active_session from llmcompressor.datasets import get_calibration_dataloader from llmcompressor.entrypoints.utils import post_process, pre_process +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.pipelines import CalibrationPipeline __all__ = ["Oneshot", "oneshot"] @@ -209,11 +210,16 @@ def apply_recipe_modifiers( user_pipeline = self.dataset_args.pipeline modifiers = session.lifecycle.recipe.modifiers pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) - pipeline( + # Apply MoE calibration context for the entire calibration process + with moe_calibration_context( self.model, - calibration_dataloader, - self.dataset_args, - ) + calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts, + ): + pipeline( + self.model, + calibration_dataloader, + self.dataset_args, + ) session.finalize() @@ -252,7 +258,7 @@ def oneshot( overwrite_cache: bool = False, preprocessing_num_workers: Optional[int] = None, min_tokens_per_module: Optional[float] = None, - calibrate_moe_context: bool = False, + moe_calibrate_all_experts: bool = True, quantization_aware_calibration: bool = True, # Miscellaneous arguments output_dir: Optional[str] = None, @@ -316,9 +322,10 @@ def oneshot( preprocessing. :param min_tokens_per_module: Minimum percentage of tokens per module, relevant for MoE models. - :param calibrate_moe_context: If during calibration, the MoE context should be - enabled for the given model. This usually involves updating all MoE modules - in the model for the duration of calibration. + :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE + model calibration. When True, all experts will see all tokens during + calibration, ensuring proper quantization statistics. When False, only + routed experts will be used. Only relevant for MoE models. Default is True. :param quantization_aware_calibration: Whether to enable quantization-aware calibration in the sequential pipeline. When True, quantization is applied during forward pass in calibration. 
When False, quantization is disabled diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 8cc7f47271..c2dd8f4b69 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -4,17 +4,25 @@ DeepseekV3MoE as OriginalDeepseekV3MoE, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) + -class DeepseekV3MoECalibrate(torch.nn.Module): +@register_moe_calibration("DeepseekV3MoE") +class CalibrationDeepseekV3MoE(MoECalibrationModule): """ - Patched DeepseekV3MoE which sends all tokens to all experts for calibration + Calibration version of DeepseekV3MoE that sends all tokens to all experts. """ + is_permanent = True + def __init__( self, - config: DeepseekV3Config, original: OriginalDeepseekV3MoE, - calibrate_all_experts: bool, + config: DeepseekV3Config, + calibrate_all_experts: bool = True, ): super().__init__() self.config = config @@ -65,11 +73,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +# Legacy function for backward compatibility def replace( config: DeepseekV3Config, module: OriginalDeepseekV3MoE, calibrate_all_experts: bool, ): - return DeepseekV3MoECalibrate( - config=config, original=module, calibrate_all_experts=calibrate_all_experts + """ + Legacy replacement function. + Use CalibrationDeepseekV3MoE instead. + """ + return CalibrationDeepseekV3MoE( + module, + config, + calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 02deb90acc..2b49a652af 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -11,27 +11,47 @@ Llama4TextMoe, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) from llmcompressor.utils.dev import skip_weights_initialize -class SequentialLlama4TextMoe(torch.nn.Module): +@register_moe_calibration("Llama4TextMoe") +class SequentialLlama4TextMoe(MoECalibrationModule): + """ + Calibration version of Llama4TextMoe that unpacks experts for sequential processing. + + This module: + 1. Unpacks the packed expert weights (3D -> 2D) for calibration + 2. Optionally sends all tokens to all experts during calibration + 3. 
Stays in unpacked form (permanent) for vLLM compatibility + """ + + is_permanent = True + def __init__( self, - config: Llama4TextConfig, original: Llama4TextMoe, - calibrate_all_experts: bool, + config: Llama4Config, + calibrate_all_experts: bool = True, ): super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts - - self.experts = SequentialLlama4TextExperts(config, original.experts) + # Extract text config from multimodal config if needed + text_config = ( + config.get_text_config() if hasattr(config, "get_text_config") else config + ) + self.top_k = text_config.num_experts_per_tok + self.hidden_dim = text_config.hidden_size + self.num_experts = text_config.num_local_experts + + self.experts = SequentialLlama4TextExperts(text_config, original.experts) self.router = original.router self.shared_expert = original.shared_expert self.calibrate_all_experts = calibrate_all_experts - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]: + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_scores, router_logits = self.router(hidden_states) # transformers>=4.54 @@ -74,9 +94,14 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): self[i].down_proj.weight.data = down.t().contiguous() +# Legacy function for backward compatibility def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool): + """ + Legacy replacement function. + Use SequentialLlama4TextMoe instead. + """ return SequentialLlama4TextMoe( - config=config.get_text_config(), - original=module, + module, + config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 0000000000..35c4470b7a --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,167 @@ +""" +Simplified interface for MoE model calibration. + +MoE (Mixture of Experts) models route tokens to different expert networks. +During calibration for quantization/compression, we need to ensure ALL experts +see data, not just the ones selected by the router. This module provides the +infrastructure to temporarily modify MoE modules for proper calibration. + +Key components: +- MoECalibrationModule: Abstract base class for calibration modules +- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes +- moe_calibration_context: Context manager that applies calibration to a model +""" + +import contextlib +from abc import ABC +from typing import Dict, Type + +import torch +from loguru import logger +from tqdm import tqdm +from transformers import PreTrainedModel + +__all__ = [ + "MoECalibrationModule", + "MOE_CALIBRATION_MODULES", + "register_moe_calibration", + "moe_calibration_context", +] + + +class MoECalibrationModule(ABC, torch.nn.Module): + """ + Abstract base class for MoE calibration modules. + + Calibration modules replace original MoE modules during the calibration + phase to ensure all experts receive data for proper quantization statistics. + + Subclasses must: + 1. Implement `__init__()` with signature: + (self, original, config, calibrate_all_experts=True) + 2. Set `is_permanent` to indicate if module should stay in calibration form + 3. 
Optionally implement `restore()` if is_permanent=False + """ + + is_permanent: bool = False + + def restore(self) -> torch.nn.Module: + """ + Restore the original module structure. + + Only needed if is_permanent=False. For permanent modules, this is a no-op. + + Returns: + The original module (or self if permanent) + """ + if self.is_permanent: + return self + raise NotImplementedError( + f"{self.__class__.__name__} has is_permanent=False but doesn't " + "implement restore()" + ) + + +# Registry: module class name -> calibration module class +MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {} + + +def register_moe_calibration(module_class_name: str): + """ + Decorator to register a MoE calibration module. + + Usage: + @register_moe_calibration("DeepseekV3MoE") + class CalibrationDeepseekV3MoE(MoECalibrationModule): + ... + + Args: + module_class_name: The class name of the original module to replace + """ + + def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]: + if not issubclass(cls, MoECalibrationModule): + raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule") + MOE_CALIBRATION_MODULES[module_class_name] = cls + return cls + + return decorator + + +@contextlib.contextmanager +def moe_calibration_context( + model: PreTrainedModel, + calibrate_all_experts: bool = True, +): + """ + Context manager that applies MoE calibration to a model. + + This scans all modules in the model and replaces any MoE modules with their + calibration equivalents. After the context exits, non-permanent modules are + restored to their original form. + + The model is modified in-place, so the same model object should be used + within the context. + + Args: + model: The model to apply MoE calibration to (modified in-place) + calibrate_all_experts: If True, all experts see all tokens during calibration. 
+ If False, use normal routing (useful for some techniques) + + Example: + with moe_calibration_context(model): + # Run calibration - all experts will see data + for batch in dataloader: + model(**batch) + # Model is now restored (unless permanent) + """ + replaced = {} + + # Step 1: Collect all MoE modules that need replacement + logger.info("Entering MoE calibration context") + modules_to_replace = [] + for name, module in model.named_modules(): + class_name = module.__class__.__name__ + if class_name in MOE_CALIBRATION_MODULES: + modules_to_replace.append((name, module, class_name)) + + # Step 2: Replace modules with progress bar + if modules_to_replace: + logger.info(f"Found {len(modules_to_replace)} MoE modules to replace") + for name, module, class_name in tqdm( + modules_to_replace, desc="Replacing MoE modules for calibration" + ): + calibration_cls = MOE_CALIBRATION_MODULES[class_name] + replacement = calibration_cls( + module, + model.config, + calibrate_all_experts=calibrate_all_experts, + ) + model.set_submodule(name, replacement) + replaced[name] = (module, replacement) + + # Log what was replaced + if replaced: + logger.info(f"Replaced {len(replaced)} MoE modules for calibration") + permanent_count = sum( + 1 for _, (_, repl) in replaced.items() if repl.is_permanent + ) + if permanent_count > 0: + logger.info( + f"{permanent_count}/{len(replaced)} modules will remain in " + "calibration form (permanent)" + ) + if permanent_count < len(replaced): + logger.info( + f"{len(replaced) - permanent_count}/{len(replaced)} modules will " + "be restored after calibration" + ) + + try: + yield + finally: + # Step 2: Restore non-permanent modules + for name, (original, replacement) in replaced.items(): + if not replacement.is_permanent: + restored = replacement.restore() + model.set_submodule(name, restored) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index fc30794b83..e93f1278d6 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,15 +1,39 @@ +""" +MoE model preparation - imports and registration. + +This module imports all MoE calibration modules to ensure they are registered +in the MOE_CALIBRATION_MODULES registry. The actual calibration logic is in +moe_context.py. 
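+
+As an illustrative sketch (the `MyModelSparseMoeBlock` names below are
+hypothetical, not part of this change), support for a new MoE type would be
+wired in by registering a calibration module against the class name of the
+original transformers block; see deepseek_v3.py, llama4.py, and qwen3_moe.py
+for the real implementations:
+
+    from llmcompressor.modeling.moe_context import (
+        MoECalibrationModule,
+        register_moe_calibration,
+    )
+
+    @register_moe_calibration("MyModelSparseMoeBlock")
+    class CalibrationMyModelSparseMoeBlock(MoECalibrationModule):
+        is_permanent = False
+
+        def __init__(self, original, config, calibrate_all_experts=True):
+            super().__init__()
+            self.original = original  # keep the original block for restore()
+            self.config = config
+            self.calibrate_all_experts = calibrate_all_experts
+
+        def forward(self, hidden_states):
+            # Defer to the original block; a real implementation would also
+            # send tokens to every expert when calibrate_all_experts is True.
+            return self.original(hidden_states)
+
+        def restore(self):
+            return self.original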
+""" + import tqdm -from compressed_tensors.utils import replace_module +from compressed_tensors.utils import deprecated, replace_module from transformers import PreTrainedModel -from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 -from llmcompressor.modeling.llama4 import replace as replace_llama4 -from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE -from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE -from llmcompressor.utils.helpers import patch_attr - -__all__ = ["replace_modules_for_calibration"] +# Import MoE calibration modules to trigger registration +from llmcompressor.modeling.deepseek_v3 import ( # noqa: F401 + CalibrationDeepseekV3MoE, +) +from llmcompressor.modeling.deepseek_v3 import ( + replace as replace_deepseekv3, +) +from llmcompressor.modeling.llama4 import ( # noqa: F401 + SequentialLlama4TextMoe, +) +from llmcompressor.modeling.llama4 import ( + replace as replace_llama4, +) +from llmcompressor.modeling.moe_context import ( # noqa: F401 + moe_calibration_context, +) +from llmcompressor.modeling.qwen3_moe import ( # noqa: F401 + CalibrationQwen3MoeSparseMoeBlock, +) +from llmcompressor.modeling.qwen3_vl_moe import ( + replace as replace_Qwen3VLMoE, +) + +__all__ = ["moe_calibration_context", "replace_modules_for_calibration"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -19,10 +43,30 @@ } +@deprecated( + message=( + "The function `replace_modules_for_calibration` has been deprecated. " + "Please use `moe_calibration_context` instead. " + ) +) def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + """ + Deprecated function for backward compatibility. 
+ + Use moe_calibration_context instead: + with moe_calibration_context(model, calibrate_all_experts): + # your code here + + Args: + model: The model to modify + calibrate_all_experts: Whether to calibrate all experts + + Returns: + The modified model + """ for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: @@ -34,63 +78,3 @@ def replace_modules_for_calibration( replace_module(model, name, new_module) return model - - -# ------------------- module replacements; during calibration -------------------- - - -def update_qwen3_moe(model, module, stack, calibrate_all_experts): - cls_name = module.__class__.__name__ - if ( - cls_name == "Qwen3MoeDecoderLayer" - and module.mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock" - ): - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) - - -def update_qwen3_next_moe(model, module, stack, calibrate_all_experts): - cls_name = module.__class__.__name__ - if ( - cls_name == "Qwen3NextDecoderLayer" - and module.mlp.__class__.__name__ == "Qwen3NextSparseMoeBlock" - ): - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3NextMoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) - - -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, - "Qwen3NextForCausalLM": update_qwen3_next_moe, -} - - -def moe_calibration_context( - model: PreTrainedModel, - stack, - calibrate_all_experts: bool = True, -): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist - model_name = model.__class__.__name__ - if model_name in moe_context: - for module in model.modules(): - moe_context[model_name](model, module, stack, calibrate_all_experts) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 2b451bc498..49c3fa8745 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -20,13 +20,25 @@ Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, ) +from llmcompressor.modeling.moe_context import ( + MoECalibrationModule, + register_moe_calibration, +) + + +@register_moe_calibration("Qwen3MoeSparseMoeBlock") +class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule): + """ + Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts. 
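+
+    A minimal construction sketch, mirroring the updated tests (note that
+    `original` is now passed before `config`):
+
+        module = CalibrationQwen3MoeSparseMoeBlock(
+            original, config, calibrate_all_experts=True
+        )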
+ """ + + is_permanent = False -class Qwen3MoeSparseMoeBlock(torch.nn.Module): def __init__( self, - config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock, - calibrate_all_experts: bool, + config: Qwen3MoeConfig, + calibrate_all_experts: bool = True, ): super().__init__() self.num_experts = config.num_experts @@ -37,7 +49,7 @@ def __init__( self.gate = original.gate self.experts = original.experts - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor): batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) @@ -87,11 +99,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits +# Legacy function for backward compatibility def replace( config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock, calibrate_all_experts: bool, ): - return Qwen3MoeSparseMoeBlock( - config=config, original=module, calibrate_all_experts=calibrate_all_experts + """ + Legacy replacement function. + Use CalibrationQwen3MoeSparseMoeBlock instead. + """ + return CalibrationQwen3MoeSparseMoeBlock( + module, + config, + calibrate_all_experts=calibrate_all_experts, ) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b153058..e6494fc5e0 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -7,7 +7,6 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device @@ -46,10 +45,6 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb0..244edde87e 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -7,7 +7,6 @@ from torch.utils.data.dataloader import DataLoader from llmcompressor.core import LifecycleCallbacks, active_session -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.layer_sequential.helpers import ( @@ -82,9 +81,6 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) - # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( model, layers[0], dataloader, model_device diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index f00a37e430..261afd6544 100644 --- 
a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -7,7 +7,6 @@ from tqdm import tqdm from llmcompressor.core import LifecycleCallbacks, active_session -from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.registry import CalibrationPipeline @@ -85,9 +84,6 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) - # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index 7dfea10f6d..e4e15d300f 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,15 +1,16 @@ +import contextlib from functools import partial import pytest import torch from transformers import AutoModelForCausalLM - -from llmcompressor.modeling.deepseek_v3 import ( - DeepseekV3Config, - DeepseekV3MoECalibrate, - OriginalDeepseekV3MoE, +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config +from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3MoE as OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration + +from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, CalibrationDeepseekV3MoE): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - 
hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu @@ -70,12 +75,12 @@ def test_calib_deepseekv3_module(): with calibration_forward_context(original): true_output = original(sample) - module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=True) + module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=True) with calibration_forward_context(module): output = module(sample) assert torch.allclose(true_output, output, atol=1e-6) - module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=False) + module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=False) with calibration_forward_context(module): output = module(sample) assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 56b1b85033..78fb4ee6d9 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,16 +1,15 @@ +import contextlib import os from functools import partial import pytest import torch from transformers import Llama4ForConditionalGeneration +from transformers.models.llama4.configuration_llama4 import Llama4TextConfig +from transformers.models.llama4.modeling_llama4 import Llama4TextMoe -from llmcompressor.modeling.llama4 import ( - Llama4TextConfig, - Llama4TextMoe, - SequentialLlama4TextMoe, -) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe +from llmcompressor.modeling.moe_context import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +27,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, 
output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu @@ -79,13 +82,13 @@ def test_calib_llama4_module(): with calibration_forward_context(original): true_out, true_router_logits = original(sample) - module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=True) + module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=True) with calibration_forward_context(module): out, router_logits = module(sample) assert torch.nn.functional.mse_loss(true_out, out) < 1e-10 assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10 - module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=False) + module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=False) with calibration_forward_context(module): out, router_logits = module(sample) assert torch.nn.functional.mse_loss(true_out, out) < 1e-10 diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 15b9057e0f..d6e54776dc 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -4,13 +4,13 @@ import pytest import torch from transformers import AutoModelForCausalLM - -from llmcompressor.modeling.prepare import moe_calibration_context -from llmcompressor.modeling.qwen3_moe import ( - OriginalQwen3MoeSparseMoeBlock, - Qwen3MoeConfig, - Qwen3MoeSparseMoeBlock, +from transformers.models import Qwen3MoeConfig +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, ) + +from llmcompressor.modeling.moe_context import moe_calibration_context +from llmcompressor.modeling.qwen3_moe import CalibrationQwen3MoeSparseMoeBlock from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -26,13 +26,12 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + 
stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None for name, module in model.named_modules(): - if isinstance(module, Qwen3MoeSparseMoeBlock): + if isinstance(module, CalibrationQwen3MoeSparseMoeBlock): moe_layer = module break @@ -78,13 +77,17 @@ def test_calib_qwen3_moe_module(): with calibration_forward_context(original): true_output = original(sample) - module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=True) + module = CalibrationQwen3MoeSparseMoeBlock( + original, config, calibrate_all_experts=True + ) with calibration_forward_context(module): output = module(sample) assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10 assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10 - module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=False) + module = CalibrationQwen3MoeSparseMoeBlock( + original, config, calibrate_all_experts=False + ) with calibration_forward_context(module): output = module(sample) assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10