121 changes: 48 additions & 73 deletions vllm/config/speculative.py
@@ -5,7 +5,7 @@
import hashlib
from typing import TYPE_CHECKING, Any, Literal

-from pydantic import SkipValidation, model_validator
+from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

@@ -62,7 +62,7 @@
enforce_eager: bool | None = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
-num_speculative_tokens: SkipValidation[int] = None # type: ignore
+num_speculative_tokens: int = Field(default=None, validate_default=True, ge=1)
"""The number of speculative tokens, if provided. It will default to the
Comment on lines 62 to 66

P1: Allow None for num_speculative_tokens when inferring from n_predict

num_speculative_tokens is now declared as int with Field(default=None, ge=1) but without | None. Pydantic treats this field as a required integer and will raise a validation error before _configure runs when the value is omitted or left as None, even though later logic (e.g. inferring from n_predict in _configure and _verify_args) still relies on the field being optional. This prevents building a SpeculativeConfig that relies on the draft model to supply the token count, a configuration that previously worked. Consider making the type int | None so that validation only triggers when a value is provided.
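A minimal, self-contained sketch of the failure mode described in this comment, assuming pydantic v2 semantics (the class names are illustrative, not vLLM code):

from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass

@dataclass
class AsDeclaredInThisPR:
    # `int` with a None default: validate_default=True makes pydantic
    # validate the default against `int`, so omitting the field fails
    # at construction time, before any model_validator runs.
    num_speculative_tokens: int = Field(default=None, validate_default=True, ge=1)

@dataclass
class AsSuggested:
    # `int | None` keeps the field optional; ge=1 still applies
    # whenever an actual integer is supplied.
    num_speculative_tokens: int | None = Field(
        default=None, validate_default=True, ge=1
    )

try:
    AsDeclaredInThisPR()  # ValidationError: None is not a valid integer
except ValidationError:
    pass

AsSuggested()                          # OK: None passes through
AsSuggested(num_speculative_tokens=4)  # OK: constraint checked on real values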

number in the draft model config if present, otherwise, it is required."""
model: str | None = None
@@ -76,7 +76,7 @@
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
-draft_tensor_parallel_size: int | None = None
+draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
@@ -89,7 +89,7 @@
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
-max_model_len: int | None = None
+max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: str | None = None
@@ -102,7 +102,7 @@
will use the default version."""

# Advanced control
-disable_by_batch_size: int | None = None
+disable_by_batch_size: int | None = Field(default=None, ge=2)
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
@@ -112,10 +112,10 @@
only affects the EAGLE method of speculation."""

# Ngram proposer configuration
-prompt_lookup_max: int | None = None
+prompt_lookup_max: int | None = Field(default=None, ge=0)
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
-prompt_lookup_min: int | None = None
+prompt_lookup_min: int | None = Field(default=None, ge=0)
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""

@@ -215,7 +215,8 @@

return hf_config

-def __post_init__(self):
+@model_validator(mode="after")
+def _configure(self) -> Self:
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
@@ -232,9 +233,8 @@

if self.model is None and self.num_speculative_tokens is not None:
if self.method == "mtp":
-assert self.target_model_config is not None, (
-    "target_model_config must be present for mtp"
-)
+if self.target_model_config is None:
+    raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
# remove this when the issue is fixed.
@@ -268,18 +268,24 @@
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
-assert self.prompt_lookup_max is not None
+if self.prompt_lookup_max is None:
+    raise ValueError(
+        "prompt_lookup_max must be provided when using the ngram method."
+    )

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:273:89: Line too long (89 > 88)
Comment on lines +271 to +274 (Contributor)

Severity: high

This ValueError is raised when self.prompt_lookup_max is None, indicating it should be provided when using the ngram method. However, the check self.prompt_lookup_min is None on line 270 already implies that self.prompt_lookup_min is None at this point, so the error message should also mention that prompt_lookup_min is missing.

Consider rephrasing the error message to indicate that both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method:

                    raise ValueError(
                        "Both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method."
                    )

self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
-assert self.prompt_lookup_min is not None
+if self.prompt_lookup_min is None:
+    raise ValueError(
+        "prompt_lookup_min must be provided when using the ngram method."
+    )

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:279:89: Line too long (89 > 88)
Comment on lines +277 to +280 (Contributor)

Severity: high

Similar to the previous comment: this ValueError is raised when self.prompt_lookup_min is None, but the check self.prompt_lookup_max is None on line 276 already implies that self.prompt_lookup_max is None at this point, so the error message should also mention that prompt_lookup_max is missing.

Consider rephrasing the error message to indicate that both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method:

                    raise ValueError(
                        "Both prompt_lookup_max and prompt_lookup_min must be provided when using the ngram method."
                    )

self.prompt_lookup_max = self.prompt_lookup_min

# Validate values
-if self.prompt_lookup_min < 1:
+if self.prompt_lookup_min is None or self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
)
-if self.prompt_lookup_max < 1:
+if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
)
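For readers tracing the control flow, the resolved ngram-window behavior after this hunk amounts to the following standalone sketch (illustrative only, not vLLM source; it assumes the collapsed context above the hunk defaults both bounds to 5 when neither is set):

def resolve_ngram_window(
    prompt_lookup_min: int | None, prompt_lookup_max: int | None
) -> tuple[int, int]:
    # Neither bound given: fall back to the defaults set in the hidden
    # context above this hunk (assumed to be 5/5).
    if prompt_lookup_min is None and prompt_lookup_max is None:
        prompt_lookup_min = prompt_lookup_max = 5
    # Only one bound given: mirror it onto the other.
    elif prompt_lookup_min is None:
        prompt_lookup_min = prompt_lookup_max
    elif prompt_lookup_max is None:
        prompt_lookup_max = prompt_lookup_min
    # Final validation mirrors the diff's None-tolerant range checks.
    if prompt_lookup_min is None or prompt_lookup_min < 1:
        raise ValueError(f"prompt_lookup_min={prompt_lookup_min} must be > 0")
    if prompt_lookup_max is None or prompt_lookup_max < 1:
        raise ValueError(f"prompt_lookup_max={prompt_lookup_max} must be > 0")
    return prompt_lookup_min, prompt_lookup_max

assert resolve_ngram_window(None, None) == (5, 5)
assert resolve_ngram_window(None, 7) == (7, 7)
assert resolve_ngram_window(2, None) == (2, 2)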
@@ -323,13 +329,8 @@
hf_overrides=SpeculativeConfig.hf_config_override,
)

# Automatically detect the method
if self.method in ("eagle", "eagle3"):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
@@ -340,97 +341,84 @@
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
self.method = "mtp"
-if self.num_speculative_tokens > 1:
+if self.num_speculative_tokens and self.num_speculative_tokens > 1:
logger.warning(
"Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer"
",which may result in lower acceptance rate"
"Enabling num_speculative_tokens > 1 will run multiple "
"times of forward on same MTP layer, which may result in "
"lower acceptance rate"
)
-elif self.draft_model_config.hf_config.model_type in (
-    "longcat_flash_mtp"
-):
+elif self.draft_model_config.hf_config.model_type == "longcat_flash_mtp":

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:350:89: Line too long (89 > 88)
self.method = "longcat_flash_mtp"
-if self.num_speculative_tokens > 1:
+if self.num_speculative_tokens and self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have "
"one layer. Might need some code changes "
"to support multiple layers."
"LongCat MTP models only have one layer. Might need some "
"code changes to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or mtp."
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or mtp."
)

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
"when using V0."
"Chunked prefill and EAGLE are not compatible when using V0."

Check failure on line 368 in vllm/config/speculative.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/config/speculative.py:368:89: E501 Line too long (89 > 88)
)

from vllm.transformers_utils.configs import SpeculatorsConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig

-if isinstance(
+if not isinstance(
    self.draft_model_config.hf_config,
    (EAGLEConfig, SpeculatorsConfig),
):
-    pass
-else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config,
method=self.method,
model_type="eagle",
)
self.draft_model_config.hf_config = eagle_config
Comment on lines +374 to 383 (Contributor)

Severity: high

The code checks whether self.draft_model_config.hf_config is an instance of EAGLEConfig or SpeculatorsConfig; if it is not, it creates an EAGLEConfig, but if it is, the block does nothing. This looks like an incomplete implementation, as these cases might need different handling or some other action. If no action is required, the if not isinstance(...) check is unnecessary and can be removed.

Consider removing the if not isinstance(...) check and creating the eagle_config directly without the conditional, or adding a comment explaining why no action is needed when the condition is false:

                eagle_config = EAGLEConfig(
                    self.draft_model_config.hf_config,
                    method=self.method,
                    model_type="eagle",
                )
                self.draft_model_config.hf_config = eagle_config


-if self.num_speculative_tokens is not None and hasattr(
-    self.draft_model_config.hf_config, "num_lookahead_tokens"
+if (
+    self.num_speculative_tokens is not None
+    and hasattr(self.draft_model_config.hf_config, "num_lookahead_tokens")

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:387:89: Line too long (90 > 88)
):
self.draft_model_config.hf_config.num_lookahead_tokens = (
self.num_speculative_tokens
)

-n_predict = getattr(
-    self.draft_model_config.hf_config, "n_predict", None
-)
+n_predict = getattr(self.draft_model_config.hf_config, "n_predict", None)

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:393:89: Line too long (89 > 88)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
):
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}"
)

-if self.speculative_token_tree is None:
+if self.speculative_token_tree is None and self.num_speculative_tokens is not None:

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:406:89: Line too long (99 > 88)
# Generate chain of tokens.
self.speculative_token_tree = str(
[(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
)
-else:
+elif self.speculative_token_tree is not None:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t))
)

-self.draft_tensor_parallel_size = (
-    SpeculativeConfig._verify_and_get_draft_tp(
-        self.target_parallel_config,
-        self.draft_tensor_parallel_size,
-        self.draft_model_config.hf_config,
-    )
+self.draft_tensor_parallel_size = SpeculativeConfig._verify_and_get_draft_tp(
+    self.target_parallel_config,
+    self.draft_tensor_parallel_size,
+    self.draft_model_config.hf_config,
+)

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:418:89: Line too long (93 > 88)

self.draft_model_config.max_model_len = (
@@ -441,12 +429,12 @@
)
)

-self.draft_parallel_config = (
-    SpeculativeConfig.create_draft_parallel_config(
-        self.target_parallel_config, self.draft_tensor_parallel_size
-    )
+self.draft_parallel_config = SpeculativeConfig.create_draft_parallel_config(
+    self.target_parallel_config, self.draft_tensor_parallel_size
+)

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config/speculative.py:432:89: Line too long (92 > 88)

+return self

@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: int | None,
@@ -550,24 +538,11 @@
"n_predict parameter."
)

-if self.num_speculative_tokens <= 0:
-    raise ValueError(
-        "Expected num_speculative_tokens to be greater "
-        f"than zero ({self.num_speculative_tokens})."
-    )

if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config
)

-if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
-    raise ValueError(
-        "Expect the batch size threshold of disabling "
-        "speculative decoding is > 1, but got "
-        f"{self.disable_by_batch_size=}"
-    )

eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
if (
self.method == "eagle3"