From 43bc12f9a48198d30ec5f411d24f6d7ee9c4c809 Mon Sep 17 00:00:00 2001
From: linzebing
Date: Fri, 17 Oct 2025 17:10:41 -0700
Subject: [PATCH 1/2] [fix][spec decode] Fix llama4 draft model loading with
 different quantization config

Signed-off-by: linzebing
---
 vllm/model_executor/models/llama4_eagle.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index dd6337244ca6..fdc8845e52e0 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -60,6 +60,10 @@ def __init__(
             prefix=maybe_prefix(prefix, "embed_tokens"),
         )
 
+        # Temporarily modify vllm_config.quant_config for draft model layers
+        original_quant_config = vllm_config.quant_config
+        vllm_config.quant_config = quant_config
+
         self.layers = nn.ModuleList(
             [
                 Llama4DecoderLayer(
@@ -70,6 +74,8 @@ def __init__(
                 for i in range(self.config.num_hidden_layers)
             ]
         )
+        # Restore original quant_config
+        vllm_config.quant_config = original_quant_config
         self.fc = torch.nn.Linear(
             self.config.hidden_size * 2, self.config.hidden_size, bias=False
         )

From 35b96a2c327339b7be6e6c9572c76090301d41a0 Mon Sep 17 00:00:00 2001
From: linzebing
Date: Fri, 17 Oct 2025 17:57:34 -0700
Subject: [PATCH 2/2] make gemini happy

Signed-off-by: linzebing
---
 vllm/model_executor/models/llama4_eagle.py | 27 +++++++++++-----------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index fdc8845e52e0..90273463d64e 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -63,19 +63,20 @@ def __init__(
         # Temporarily modify vllm_config.quant_config for draft model layers
         original_quant_config = vllm_config.quant_config
         vllm_config.quant_config = quant_config
-
-        self.layers = nn.ModuleList(
-            [
-                Llama4DecoderLayer(
-                    vllm_config=vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                    config=self.config,
-                )
-                for i in range(self.config.num_hidden_layers)
-            ]
-        )
-        # Restore original quant_config
-        vllm_config.quant_config = original_quant_config
+        try:
+            self.layers = nn.ModuleList(
+                [
+                    Llama4DecoderLayer(
+                        vllm_config=vllm_config,
+                        prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                        config=self.config,
+                    )
+                    for i in range(self.config.num_hidden_layers)
+                ]
+            )
+        finally:
+            # Restore original quant_config
+            vllm_config.quant_config = original_quant_config
         self.fc = torch.nn.Linear(
             self.config.hidden_size * 2, self.config.hidden_size, bias=False
         )
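
Note: taken together, the two patches implement a swap-and-restore of a shared
config attribute, made exception-safe by the try/finally added in the second
patch. Below is a minimal standalone sketch of that pattern, not part of the
patch itself; DummyConfig and build_draft_layers are hypothetical stand-ins,
not vLLM APIs.

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class DummyConfig:
        # Hypothetical stand-in for vllm_config; only quant_config matters here.
        quant_config: Any = None

    def build_draft_layers(vllm_config: DummyConfig, quant_config: Any,
                           num_layers: int) -> list:
        # Swap in the draft model's quant_config so that layer constructors
        # reading vllm_config.quant_config pick up the draft setting.
        original_quant_config = vllm_config.quant_config
        vllm_config.quant_config = quant_config
        try:
            # Stand-in for Llama4DecoderLayer construction; each "layer"
            # records the quant_config it saw at build time.
            return [("layer", i, vllm_config.quant_config)
                    for i in range(num_layers)]
        finally:
            # Restore unconditionally, even if a constructor raises, so the
            # shared config is never left pointing at the draft setting.
            vllm_config.quant_config = original_quant_config

    cfg = DummyConfig(quant_config="target-fp8")
    layers = build_draft_layers(cfg, "draft-bf16", num_layers=2)
    assert cfg.quant_config == "target-fp8"  # restored after construction

The try/finally matters because vllm_config is shared state: without it, an
exception during layer construction would leave the target model's
quant_config permanently overwritten by the draft model's.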