45 changes: 45 additions & 0 deletions vllm_omni/model_executor/custom_process_mixin.py
@@ -0,0 +1,45 @@
from collections.abc import Callable

import torch


class CustomProcessMixin:
"""
Mixin that lets Omni model stages register custom preprocess and postprocess hooks.
"""

def set_custom_preprocess(self, preprocess_fn: Callable) -> None:
"""
Set a preprocess function for the stage.
Args:
preprocess_fn: The preprocess function to register.
"""
self.preprocess = preprocess_fn

def set_custom_postprocess(self, postprocess_fn: Callable) -> None:
"""
Set a postprocess function for the stage.
Args:
postprocess_fn: The postprocess function to register.
"""
self.postprocess = postprocess_fn

def preprocess(
self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **input_dict: object
) -> tuple[torch.Tensor, torch.Tensor, dict]:
"""
Process the input_ids and input_embeds for the given input_dict.
Returns the processed input_ids, input_embeds, and the input_dict.

If the stage is not applicable, return the original input_ids, input_embeds, and an empty dict.
"""
raise NotImplementedError("Preprocess is not implemented for this stage.")

def postprocess(self, model_output, **info_dict: object):
"""
Postprocess the model output.
Returns the postprocessed model output and a dictionary of values to save.
Args:
model_output: The model output to postprocess.
**info_dict: Additional per-request information for postprocessing.
"""
raise NotImplementedError("Postprocess is not implemented for this stage.")
277 changes: 121 additions & 156 deletions vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
@@ -25,6 +25,7 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import (
@@ -50,10 +51,11 @@
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin, SupportsMRoPE
nn.Module, SupportsMultiModal, SupportsPP, SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, CustomProcessMixin
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.has_preprocess = False
self.have_multimodal_outputs = True
config: Qwen2_5OmniConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -84,6 +86,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.token2wav = None

elif self.model_stage == "talker":
# register the process function for the talker stage
self.has_preprocess = True
self.set_custom_preprocess(self.talker_preprocess)
self.thinker = None
# Initialize talker model wrapper (handles projection + LM)
self.talker = init_vllm_registered_model(
@@ -244,13 +249,14 @@ def forward(
thinker_inputs_embeds = inputs_embeds

# Run thinker
thinker_output = self.thinker(
input_ids=thinker_input_ids,
positions=thinker_positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=thinker_inputs_embeds,
**kwargs,
)
with torch.inference_mode():
thinker_output = self.thinker(
input_ids=thinker_input_ids,
positions=thinker_positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=thinker_inputs_embeds,
**kwargs,
)

if isinstance(thinker_output, tuple):
embeds, text_hidden_states = thinker_output
@@ -265,162 +271,21 @@

# 2) Talker (if codec not provided)
if self.model_stage == "talker":
# Mixed-mode support: In a single step, both Prefill*n and Decode*n are supported.
# Rules:
# - Prefill segments are wrapped with special tokens: [BOS][PAD...][EOS]
# - Decode segments consist of a single non-special token.
# - If additional_information is provided
# (can be a list split by request or a concatenated tensor plus a list of shapes),
# then for each request, reconstruct the thinker→talker input embeddings for the Prefill segments;
# - For Decode segments, if per-request auxiliary decode embeddings are provided (optional), add them;
# otherwise, keep the original embedding.

if input_ids is None and additional_information is None:
input_ids = torch.zeros(
inputs_embeds.shape[0],
dtype=torch.long,
device=inputs_embeds.device,
)
additional_information = {}
# mock data for profile
if input_ids is None:
input_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device)
self.thinker_reply_part = torch.zeros_like(inputs_embeds)
is_profile = True
else:
is_profile = False

# Ensure we have base embeddings when only ids are provided
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.talker.get_input_embeddings(input_ids)

# ------- Request-scoped additional information (no cross-request concat) -------
request_ids: Optional[list[str]] = kwargs.get("request_ids") # ordered
request_token_spans: Optional[list[tuple[int, int]]] = kwargs.get("request_token_spans")
addi_by_req: Optional[dict] = kwargs.get("additional_information_by_req_id")
runtime_addi = kwargs.get("runtime_additional_information")

# Normalize runtime_addi into a mapping by request_id for convenience
runtime_addi_by_req: dict[str, dict] = {}
if (
isinstance(request_ids, list)
and isinstance(runtime_addi, list)
and len(runtime_addi) == len(request_ids)
):
for i, rid in enumerate(request_ids):
if isinstance(rid, str) and isinstance(runtime_addi[i], dict):
runtime_addi_by_req[rid] = runtime_addi[i]
elif isinstance(request_ids, list) and isinstance(runtime_addi, dict):
for rid in request_ids:
if isinstance(rid, str) and isinstance(runtime_addi.get(rid), dict):
runtime_addi_by_req[rid] = runtime_addi[rid]

# Containers to return per-request updates (e.g., thinker_reply_part_per_request)
update_by_req_id: dict[str, dict] = {}

# ------- Prefill: span_len > 1 -------
if (
not is_profile
and isinstance(request_ids, list)
and isinstance(request_token_spans, list)
and isinstance(addi_by_req, dict)
):
for idx_req, rid in enumerate(request_ids):
s, e = request_token_spans[idx_req]
span_len = int(e) - int(s)
if span_len <= 1:
continue
info = addi_by_req.get(rid, {}) if isinstance(rid, str) else {}
if not isinstance(info, dict):
info = {}
pe = info.get("prompt_embeds") # Tensor [P,H]
tr = info.get("thinker_result") # Tensor [K,H]
ptoks = info.get("prompt_token_ids") # list[int]
otoks = info.get("thinker_output_token_ids") # list[int]

if not isinstance(pe, torch.Tensor):
pe = torch.zeros(
0,
self.talker.config.hidden_size,
dtype=inputs_embeds.dtype,
device=self._module_device(self.model),
)
if not isinstance(tr, torch.Tensor):
tr = torch.zeros(
0,
self.talker.config.hidden_size,
dtype=inputs_embeds.dtype,
device=self._module_device(self.model),
)
if not isinstance(ptoks, (list, torch.Tensor)):
ptoks = []
if not isinstance(otoks, (list, torch.Tensor)):
otoks = []

req_input_ids, req_embeds = self._thinker_to_talker_prefill(
voice_type=voice_type,
output_prompt_embeds=tr.to(inputs_embeds.dtype).to(self._module_device(self.model)),
output_token_ids=otoks,
thinker_prompt_embeds=pe.to(inputs_embeds.dtype).to(self._module_device(self.model)),
prompt_token_ids=ptoks,
)
seg_len = min(span_len, req_embeds.shape[0])
inputs_embeds[s : s + seg_len] = req_embeds[:seg_len]
if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len:
input_ids[s : s + seg_len] = req_input_ids

# Prepare per-request reply queue for subsequent decode: drop first row
if tr.ndim == 2 and tr.shape[0] > 0:
update_by_req_id.setdefault(rid, {})["thinker_reply_part_per_request"] = (
tr[1:].detach().to("cpu").contiguous()
)

# ------- Decode: span_len == 1 -------
if not is_profile and isinstance(request_ids, list) and isinstance(request_token_spans, list):
for idx_req, rid in enumerate(request_ids):
s, e = request_token_spans[idx_req]
if (int(e) - int(s)) != 1:
continue
# choose step vector in priority order
step_vec = None
# A) runtime queue
q = None
if isinstance(rid, str):
q = runtime_addi_by_req.get(rid, {}).get("thinker_reply_part_per_request")
if isinstance(q, torch.Tensor) and q.numel() > 0:
step_vec = q[0:1]
new_q = q[1:].detach().to("cpu").contiguous()
update_by_req_id.setdefault(rid, {})["thinker_reply_part_per_request"] = new_q
else:
# B) per-request provided decode vector (optional)
info = addi_by_req.get(rid, {}) if isinstance(addi_by_req, dict) else {}
dv = info.get("decode_output_prompt_embeds") if isinstance(info, dict) else None
if isinstance(dv, torch.Tensor) and dv.numel() > 0:
step_vec = dv[0:1] if dv.ndim == 2 else dv.view(1, -1)
elif (
hasattr(self, "thinker_reply_part")
and isinstance(self.thinker_reply_part, torch.Tensor)
and self.thinker_reply_part.numel() > 0
):
# C) fallback shared pool
step_vec = self.thinker_reply_part[0:1]
self.thinker_reply_part = self.thinker_reply_part[1:]

if isinstance(step_vec, torch.Tensor) and step_vec.numel() > 0:
one_id = input_ids[s : s + 1]
_, one_embed = self._thinker_to_talker_decode_one_step(
output_prompt_embeds=step_vec.to(inputs_embeds.dtype).to(self._module_device(self.model)),
output_token_ids=one_id,
)
inputs_embeds[s] = one_embed[0]
# TODO(Peiqi): temporary hack here to support voice_type.
if not hasattr(self, "voice_type"):
self.voice_type = voice_type

with torch.inference_mode():
talker_hidden = self.talker(
input_ids=input_ids,
positions=positions[0],
inputs_embeds=inputs_embeds,
)
multimodal_outputs: dict = None
# Return updates if any
if update_by_req_id:
multimodal_outputs = {"additional_information_update_by_req_id": update_by_req_id}

if sampling_metadata is not None:
# the padding token id is set to text model's pad token id,
@@ -429,7 +294,7 @@ def forward(

return OmniOutput(
text_hidden_states=talker_hidden,
multimodal_outputs=multimodal_outputs,
multimodal_outputs=None,
)

if self.model_stage == "code2wav":
@@ -741,6 +606,74 @@ def _get_text_spk_token_id(self, voice_type: str):
return talker_hf_config.tts_text_start_token_id
return self.tts_text_spk_token_ids[voice_type]

def talker_preprocess(
self,
input_ids: torch.Tensor,
input_embeds: torch.Tensor,
**info_dict: object,
):
# Mixed-mode support: In a single step, both Prefill*n and Decode*n are supported.
# Rules:
# - Prefill segments are wrapped with special tokens: [BOS][PAD...][EOS]
# - Decode segments consist of a single non-special token.
# - If additional_information is provided (can be a list split by request or a
# concatenated tensor plus a list of shapes), then for each request, reconstruct
# the thinker→talker input embeddings for the Prefill segments;
# - For Decode segments, if per-request auxiliary decode embeddings are provided (optional),
# add them; otherwise, keep the original embedding.

# Ensure we have base embeddings when only ids are provided
if input_embeds is None and input_ids is not None:
input_embeds = self.talker.get_input_embeddings(input_ids)

span_len = input_ids.shape[0]
if span_len > 1:
# prefill
return self.thinker_to_talker_process(input_ids, input_embeds, **info_dict)
else:
# decode
return self.thinker_to_talker_decode_one_step(input_ids, input_embeds, **info_dict)

def thinker_to_talker_process(
self,
input_ids: torch.Tensor,
input_embeds: torch.Tensor,
**info_dict: object,
):
update_dict = {}

prompt_embeds = info_dict.get("prompt_embeds") # Tensor [P,H]
thinker_result = info_dict.get("thinker_result") # Tensor [K,H]
prompt_token_ids = info_dict.get("prompt_token_ids") # list[int]
thinker_output_token_ids = info_dict.get("thinker_output_token_ids") # list[int]

if not isinstance(prompt_embeds, torch.Tensor):
prompt_embeds = torch.zeros(
0, self.talker.config.hidden_size, dtype=input_embeds.dtype, device=self._module_device(self.model)
)
if not isinstance(thinker_result, torch.Tensor):
thinker_result = torch.zeros(
0, self.talker.config.hidden_size, dtype=input_embeds.dtype, device=self._module_device(self.model)
)
if not isinstance(prompt_token_ids, (list, torch.Tensor)):
prompt_token_ids = []
if not isinstance(thinker_output_token_ids, (list, torch.Tensor)):
thinker_output_token_ids = []

# TODO(Peiqi): add voice_type support
req_input_ids, req_embeds = self._thinker_to_talker_prefill(
voice_type=self.voice_type,
output_prompt_embeds=thinker_result.to(input_embeds.dtype).to(self._module_device(self.model)),
output_token_ids=thinker_output_token_ids,
thinker_prompt_embeds=prompt_embeds.to(input_embeds.dtype).to(self._module_device(self.model)),
prompt_token_ids=prompt_token_ids,
)

if thinker_result.ndim == 2 and thinker_result.shape[0] > 0:
update_dict["thinker_reply_part"] = thinker_result[1:].detach().to("cpu").contiguous()

return req_input_ids, req_embeds, update_dict

def _thinker_to_talker_prefill(
self,
voice_type: str,
@@ -786,6 +719,38 @@ def _thinker_to_talker_prefill(
)
return prompt_token_ids_processed, prompt_embeds

def thinker_to_talker_decode_one_step(self, input_ids, input_embeds, **info_dict):
update_dict = {}
# choose step vector in priority order
step_vec = None
# A) runtime-provided per-request queue
q = info_dict.get("thinker_reply_part", None)
if isinstance(q, torch.Tensor) and q.numel() > 0:
step_vec = q[0:1]
new_q = q[1:].detach().to("cpu").contiguous()
update_dict["thinker_reply_part"] = new_q
else:
# B) per-request provided decode vector (optional)
dv = info_dict.get("decode_output_prompt_embeds") if isinstance(info_dict, dict) else None
if isinstance(dv, torch.Tensor) and dv.numel() > 0:
step_vec = dv[0:1] if dv.ndim == 2 else dv.view(1, -1)
elif (
hasattr(self, "thinker_reply_part")
and isinstance(self.thinker_reply_part, torch.Tensor)
and self.thinker_reply_part.numel() > 0
):
# C) fallback shared pool
step_vec = self.thinker_reply_part[0:1]
self.thinker_reply_part = self.thinker_reply_part[1:]

if isinstance(step_vec, torch.Tensor) and step_vec.numel() > 0:
one_id = input_ids[0:1]
_, one_embed = self._thinker_to_talker_decode_one_step(
output_prompt_embeds=step_vec.to(input_embeds.dtype).to(self._module_device(self.model)),
output_token_ids=one_id,
)
input_embeds[0] = one_embed[0]
return input_ids[0:1], input_embeds[0:1], update_dict

def _thinker_to_talker_decode_one_step(
self,
output_prompt_embeds,
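
For context, a hedged sketch of the intended call sequence around these hooks. The `run_talker_step` driver and its argument plumbing are assumptions for illustration; the actual runner integration is not part of this diff.

import torch


def run_talker_step(model, input_ids, inputs_embeds, positions, **info_dict):
    # Hypothetical driver: preprocess hook -> talker forward, mirroring the
    # has_preprocess flag and the talker call added in this PR.
    update_dict = {}
    if getattr(model, "has_preprocess", False):
        # For the talker stage, preprocess is talker_preprocess, registered
        # via set_custom_preprocess() in __init__.
        input_ids, inputs_embeds, update_dict = model.preprocess(
            input_ids, inputs_embeds, **info_dict
        )
    with torch.inference_mode():
        hidden = model.talker(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
    # update_dict (e.g. a trimmed "thinker_reply_part") is returned so the
    # caller can persist per-request state between prefill and decode steps.
    return hidden, update_dict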