Use pydantic validation in speculative.py config #27137
@@ -5,7 +5,7 @@
 import hashlib
 from typing import TYPE_CHECKING, Any, Literal

-from pydantic import SkipValidation, model_validator
+from pydantic import Field, SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self
@@ -62,7 +62,7 @@
     enforce_eager: bool | None = None
     """Override the default enforce_eager from model_config"""
     # General speculative decoding control
-    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
+    num_speculative_tokens: int = Field(default=None, validate_default=True, ge=1)
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
     model: str | None = None
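For readers unfamiliar with the pydantic pattern being introduced here, the sketch below is a standalone toy (the class and field names are illustrative, and it uses `int | None` so an omitted value is accepted); it shows how a `Field` with a `ge=` bound behaves on a pydantic dataclass: out-of-range values are rejected at construction time instead of in a hand-written check.

```python
# Standalone sketch of the Field(ge=...) pattern used in this diff; not vLLM code.
from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class DemoConfig:
    # Lower bound enforced by pydantic whenever a value is supplied.
    num_speculative_tokens: int | None = Field(default=None, ge=1)


DemoConfig(num_speculative_tokens=3)      # accepted
try:
    DemoConfig(num_speculative_tokens=0)  # rejected: violates ge=1
except ValidationError as err:
    print(err.errors()[0]["type"])        # "greater_than_equal"
```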
@@ -76,7 +76,7 @@
     If using `ngram` method, the related configuration `prompt_lookup_max` and
     `prompt_lookup_min` should be considered."""
-    draft_tensor_parallel_size: int | None = None
+    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
     """The degree of the tensor parallelism for the draft model. Can only be 1
     or the same as the target model's tensor parallel size."""
     disable_logprobs: bool = True
@@ -89,7 +89,7 @@
     """Quantization method that was used to quantize the draft model weights.
     If `None`, we assume the model weights are not quantized. Note that it only
     takes effect when using the draft model-based speculative method."""
-    max_model_len: int | None = None
+    max_model_len: int | None = Field(default=None, ge=1)
     """The maximum model length of the draft model. Used when testing the
     ability to skip speculation for some sequences."""
     revision: str | None = None
@@ -102,7 +102,7 @@
     will use the default version."""

     # Advanced control
-    disable_by_batch_size: int | None = None
+    disable_by_batch_size: int | None = Field(default=None, ge=2)
     """Disable speculative decoding for new incoming requests when the number
     of enqueued requests is larger than this value, if provided."""
     disable_padded_drafter_batch: bool = False
@@ -112,10 +112,10 @@
     only affects the EAGLE method of speculation."""

     # Ngram proposer configuration
-    prompt_lookup_max: int | None = None
+    prompt_lookup_max: int | None = Field(default=None, ge=0)
     """Maximum size of ngram token window when using Ngram proposer, required
     when method is set to ngram."""
-    prompt_lookup_min: int | None = None
+    prompt_lookup_min: int | None = Field(default=None, ge=0)
     """Minimum size of ngram token window when using Ngram proposer, if
     provided. Defaults to 1."""
@@ -215,7 +215,8 @@

         return hf_config

-    def __post_init__(self):
+    @model_validator(mode="after")
+    def _configure(self) -> Self:
         # Note: "method" is a new parameter that helps to extend the
         # configuration of non-model-based proposers, and the "model" parameter
         # will be used to set the draft model, eagle head, or additional weight
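The hunk above swaps the dataclass `__post_init__` hook for a pydantic `@model_validator(mode="after")`. A minimal sketch of that pattern, independent of `SpeculativeConfig` (the `Window` class below is made up for illustration):

```python
# Toy example of an "after" model validator on a pydantic dataclass: it runs
# once all fields have been validated and must return the (possibly mutated)
# instance, which is why _configure now ends with `return self`.
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self


@dataclass
class Window:
    lo: int | None = None
    hi: int | None = None

    @model_validator(mode="after")
    def _configure(self) -> Self:
        # Fill in defaults that depend on other fields, like __post_init__ did.
        if self.lo is None and self.hi is not None:
            self.lo = self.hi
        return self


print(Window(hi=4))  # Window(lo=4, hi=4)
```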
@@ -232,9 +233,8 @@

         if self.model is None and self.num_speculative_tokens is not None:
             if self.method == "mtp":
-                assert self.target_model_config is not None, (
-                    "target_model_config must be present for mtp"
-                )
+                if self.target_model_config is None:
+                    raise ValueError("target_model_config must be present for mtp")
                 if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                     # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                     # remove this when the issue is fixed.
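As general background (not taken from the PR description), one common reason to prefer `raise` over `assert` in this and the following hunks: assertions disappear when Python runs with `-O`, and they surface as `AssertionError` rather than a configuration error the caller can handle. A tiny illustration:

```python
# Contrast of the two failure modes; function names are illustrative only.
def check_with_assert(target_model_config) -> None:
    # Stripped entirely under python -O / PYTHONOPTIMIZE.
    assert target_model_config is not None, "target_model_config must be present for mtp"


def check_with_raise(target_model_config) -> None:
    # Always runs, and raises a ValueError callers can treat as a config error.
    if target_model_config is None:
        raise ValueError("target_model_config must be present for mtp")
```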
@@ -268,18 +268,24 @@
                 self.prompt_lookup_min = 5
                 self.prompt_lookup_max = 5
             elif self.prompt_lookup_min is None:
-                assert self.prompt_lookup_max is not None
+                if self.prompt_lookup_max is None:
+                    raise ValueError(
+                        "prompt_lookup_max must be provided when using the ngram method."
+                    )
Comment on lines +271 to +274

This error message only mentions `prompt_lookup_max`. Consider rephrasing the error message to indicate that both `prompt_lookup_max` and `prompt_lookup_min` must be provided when using the ngram method:

    raise ValueError(
        "Both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method."
    )
                 self.prompt_lookup_min = self.prompt_lookup_max
             elif self.prompt_lookup_max is None:
-                assert self.prompt_lookup_min is not None
+                if self.prompt_lookup_min is None:
+                    raise ValueError(
+                        "prompt_lookup_min must be provided when using the ngram method."
+                    )
Comment on lines +277 to +280

Similar to the previous comment, this error message only mentions `prompt_lookup_min`. Consider rephrasing the error message to indicate that both values must be provided when using the ngram method:

    raise ValueError(
        "Both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method."
    )
                 self.prompt_lookup_max = self.prompt_lookup_min

             # Validate values
-            if self.prompt_lookup_min < 1:
+            if self.prompt_lookup_min is None or self.prompt_lookup_min < 1:
                 raise ValueError(
                     f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
                 )
-            if self.prompt_lookup_max < 1:
+            if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
                 raise ValueError(
                     f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
                 )
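Read end to end, the window handling in this hunk can be paraphrased as a small free function. This is an illustrative rewrite, not the real code; it assumes the branch above the hunk (setting both values to 5) is the one taken when neither value is given.

```python
# Rough standalone paraphrase of the prompt_lookup normalization above.
def normalize_ngram_window(
    prompt_lookup_min: int | None, prompt_lookup_max: int | None
) -> tuple[int, int]:
    if prompt_lookup_min is None and prompt_lookup_max is None:
        # Assumed from the context lines just above the hunk.
        prompt_lookup_min = prompt_lookup_max = 5
    elif prompt_lookup_min is None:
        prompt_lookup_min = prompt_lookup_max
    elif prompt_lookup_max is None:
        prompt_lookup_max = prompt_lookup_min
    # Mirrors the "Validate values" block: both windows must be positive.
    if prompt_lookup_min < 1 or prompt_lookup_max < 1:
        raise ValueError("ngram window sizes must be > 0")
    return prompt_lookup_min, prompt_lookup_max


print(normalize_ngram_window(None, 7))  # (7, 7)
```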
@@ -323,13 +329,8 @@
                     hf_overrides=SpeculativeConfig.hf_config_override,
                 )

                 # Automatically detect the method
                 if self.method in ("eagle", "eagle3"):
                     pass
                 # examples:
                 # yuhuili/EAGLE-LLaMA3-Instruct-8B
                 # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                 # AngelSlim/Qwen3-8B_eagle3
                 elif "eagle-" in self.draft_model_config.model.lower():
                     self.method = "eagle"
                 elif "eagle3" in self.draft_model_config.model.lower():
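As a quick aside, the method detection kept as context here keys off substrings of the draft model name; a self-contained check using one of the example checkpoints named in the hunk's comments:

```python
# Illustration of the substring-based method detection, outside vLLM.
model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
method = None
if "eagle-" in model.lower():
    method = "eagle"
elif "eagle3" in model.lower():
    method = "eagle3"
print(method)  # eagle3
```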
@@ -340,97 +341,84 @@
                     self.method = "mlp_speculator"
                 elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
                     self.method = "mtp"
-                    if self.num_speculative_tokens > 1:
+                    if self.num_speculative_tokens and self.num_speculative_tokens > 1:
                         logger.warning(
-                            "Enabling num_speculative_tokens > 1 will run"
-                            "multiple times of forward on same MTP layer"
-                            ",which may result in lower acceptance rate"
+                            "Enabling num_speculative_tokens > 1 will run multiple "
+                            "times of forward on same MTP layer, which may result in "
+                            "lower acceptance rate"
                         )
-                elif self.draft_model_config.hf_config.model_type in (
-                    "longcat_flash_mtp"
-                ):
+                elif self.draft_model_config.hf_config.model_type == "longcat_flash_mtp":
                     self.method = "longcat_flash_mtp"
-                    if self.num_speculative_tokens > 1:
+                    if self.num_speculative_tokens and self.num_speculative_tokens > 1:
                         logger.warning(
-                            "LongCat MTP models only have "
-                            "one layer. Might need some code changes "
-                            "to support multiple layers."
+                            "LongCat MTP models only have one layer. Might need some "
+                            "code changes to support multiple layers."
                         )
                 else:
                     self.method = "draft_model"
                     raise NotImplementedError(
-                        "Speculative decoding with draft model is not "
-                        "supported yet. Please consider using other "
-                        "speculative decoding methods such as ngram, medusa, "
-                        "eagle, or mtp."
+                        "Speculative decoding with draft model is not supported yet. "
+                        "Please consider using other speculative decoding methods "
+                        "such as ngram, medusa, eagle, or mtp."
                     )

                 # Replace hf_config for EAGLE draft_model
                 if self.method in ("eagle", "eagle3"):
                     if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                         raise ValueError(
-                            "Chunked prefill and EAGLE are not compatible "
-                            "when using V0."
+                            "Chunked prefill and EAGLE are not compatible when using V0."
                         )

                     from vllm.transformers_utils.configs import SpeculatorsConfig
                     from vllm.transformers_utils.configs.eagle import EAGLEConfig

-                    if isinstance(
+                    if not isinstance(
                         self.draft_model_config.hf_config,
                         (EAGLEConfig, SpeculatorsConfig),
                     ):
-                        pass
-                    else:
                         eagle_config = EAGLEConfig(
                             self.draft_model_config.hf_config,
                             method=self.method,
                             model_type="eagle",
                         )
                         self.draft_model_config.hf_config = eagle_config
Comment on lines +374 to 383

The code block is checking if `self.draft_model_config.hf_config` is not already an `EAGLEConfig` or `SpeculatorsConfig` before converting it. Consider removing the `isinstance` guard so that only the conversion remains:

    eagle_config = EAGLEConfig(
        self.draft_model_config.hf_config,
        method=self.method,
        model_type="eagle",
    )
    self.draft_model_config.hf_config = eagle_config
-                if self.num_speculative_tokens is not None and hasattr(
-                    self.draft_model_config.hf_config, "num_lookahead_tokens"
+                if (
+                    self.num_speculative_tokens is not None
+                    and hasattr(self.draft_model_config.hf_config, "num_lookahead_tokens")
                 ):
                     self.draft_model_config.hf_config.num_lookahead_tokens = (
                         self.num_speculative_tokens
                     )

-                n_predict = getattr(
-                    self.draft_model_config.hf_config, "n_predict", None
-                )
+                n_predict = getattr(self.draft_model_config.hf_config, "n_predict", None)
                 if n_predict is not None:
                     if self.num_speculative_tokens is None:
                         # Default to max value defined in draft model config.
                         self.num_speculative_tokens = n_predict
                     elif (
                         self.num_speculative_tokens > n_predict
                         and self.num_speculative_tokens % n_predict != 0
                     ):
                         # Ensure divisibility for MTP module reuse.
                         raise ValueError(
                             f"num_speculative_tokens:{self.num_speculative_tokens}"
                             f" must be divisible by {n_predict=}"
                         )

-                if self.speculative_token_tree is None:
+                if self.speculative_token_tree is None and self.num_speculative_tokens is not None:
                     # Generate chain of tokens.
                     self.speculative_token_tree = str(
                         [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                     )
-                else:
+                elif self.speculative_token_tree is not None:
                     # Sort the token tree breadth-first.
                     tree_choices = ast.literal_eval(self.speculative_token_tree)
                     self.speculative_token_tree = str(
                         sorted(tree_choices, key=lambda t: (len(t), t))
                     )

-                self.draft_tensor_parallel_size = (
-                    SpeculativeConfig._verify_and_get_draft_tp(
-                        self.target_parallel_config,
-                        self.draft_tensor_parallel_size,
-                        self.draft_model_config.hf_config,
-                    )
+                self.draft_tensor_parallel_size = SpeculativeConfig._verify_and_get_draft_tp(
+                    self.target_parallel_config,
+                    self.draft_tensor_parallel_size,
+                    self.draft_model_config.hf_config,
                 )

                 self.draft_model_config.max_model_len = (
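The "chain of tokens" expression kept as context above is plain Python and can be checked in isolation:

```python
# What the "Generate chain of tokens" expression evaluates to for 3 tokens.
num_speculative_tokens = 3
chain = [(i + 1) * (0,) for i in range(num_speculative_tokens)]
print(chain)  # [(0,), (0, 0), (0, 0, 0)] -- a single path of depth 3
```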
@@ -441,12 +429,12 @@
                     )
                 )

-                self.draft_parallel_config = (
-                    SpeculativeConfig.create_draft_parallel_config(
-                        self.target_parallel_config, self.draft_tensor_parallel_size
-                    )
+                self.draft_parallel_config = SpeculativeConfig.create_draft_parallel_config(
+                    self.target_parallel_config, self.draft_tensor_parallel_size
                 )

+        return self
+
     @staticmethod
     def _maybe_override_draft_max_model_len(
         speculative_max_model_len: int | None,
@@ -550,24 +538,11 @@
                 "n_predict parameter."
             )

-        if self.num_speculative_tokens <= 0:
-            raise ValueError(
-                "Expected num_speculative_tokens to be greater "
-                f"than zero ({self.num_speculative_tokens})."
-            )
-
         if self.draft_model_config:
             self.draft_model_config.verify_with_parallel_config(
                 self.draft_parallel_config
             )

-        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
-            raise ValueError(
-                "Expect the batch size threshold of disabling "
-                "speculative decoding is > 1, but got "
-                f"{self.disable_by_batch_size=}"
-            )
-
         eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
         if (
             self.method == "eagle3"
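These deletions appear to be covered by the `Field(..., ge=...)` bounds added to the field declarations earlier in the diff; a quick toy check of the `ge=2` case (illustrative class, not vLLM code):

```python
from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class Demo:
    disable_by_batch_size: int | None = Field(default=None, ge=2)


try:
    Demo(disable_by_batch_size=1)  # rejected at construction, no manual check needed
except ValidationError:
    print("ge=2 bound enforced by pydantic")
```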
`num_speculative_tokens` is now declared as `int` with `Field(default=None, ge=1)` but without `| None`. Pydantic treats this field as a required integer and will raise a validation error before `_configure` runs when the value is omitted or left as `None`, even though later logic (e.g. inferring from `n_predict` in `_configure` and `_verify_args`) still relies on the field being optional. This prevents building a `SpeculativeConfig` that relies on the draft model to supply the token count, a configuration that previously worked. Consider making the type `int | None` so that validation only triggers when a value is provided.
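The concern above can be reproduced with a toy pydantic dataclass (the class and field names are illustrative; this is not the real `SpeculativeConfig`):

```python
from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class Strict:
    # int with validate_default=True: the None default is itself validated,
    # so simply omitting the field already fails.
    n: int = Field(default=None, validate_default=True, ge=1)


@dataclass
class Lenient:
    # int | None: None passes, and ge=1 only applies when an int is given.
    n: int | None = Field(default=None, validate_default=True, ge=1)


try:
    Strict()  # ValidationError: None is not a valid integer
except ValidationError:
    print("Strict() rejected a missing value")

print(Lenient())  # Lenient(n=None) -- later logic can still fill n in
```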