137 changes: 136 additions & 1 deletion transformerlab/plugins/unsloth_text_to_speech_server/audio.py
@@ -4,6 +4,13 @@
from snac import SNAC
import torch
import librosa
import numpy as np

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

from fastchat.serve.model_worker import logger


class AudioModelBase(ABC):
def __init__(self, model_name, device, context_length=2048):
@@ -343,4 +350,132 @@ def _create_voice_cloning_input(self, target_text_ids, audio_tokens, voice_promp

input_ids = torch.cat(input_sequence, dim=1)

return input_ids.to(self.device)
return input_ids.to(self.device)


class VibeVoiceAudioModel(AudioModelBase):
def __init__(self, model_name, device, context_length=2048):
super().__init__(model_name, device, context_length)
self.processor = VibeVoiceProcessor.from_pretrained(model_name)
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=None,
)
self.model.eval()
# Set noise scheduler
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
self.model.model.noise_scheduler.config,
algorithm_type="sde-dpmsolver++",
beta_schedule="squaredcos_cap_v2",
)
self.generate_kwargs = {
"max_new_tokens": None,
"cfg_scale": 1.3,
"inference_steps": 10,
"do_sample": False,
}

def tokenize(self, text, audio_path=None, sample_rate=24000, voice=None):
# VibeVoice expects a multi-speaker dialogue format
# If audio_path is provided, use it as the voice sample for Speaker 0
if audio_path:
# Load reference audio for the speaker
voice_audio = librosa.load(audio_path, sr=sample_rate)[0]
voice_samples = [[voice_audio]]
else:
# For standard TTS without voice reference, use None
# VibeVoice should handle this case internally
voice_samples = None

# Format the text as a single-speaker dialogue
formatted_text = f"Speaker 0: {text}"

inputs = self.processor(
text=[formatted_text],
voice_samples=voice_samples,
padding=True,
return_tensors="pt",
return_attention_mask=True,
)
return inputs.to(self.device)

def generate(self, inputs, **kwargs):
self.model.set_ddpm_inference_steps(num_steps=self.generate_kwargs['inference_steps'])

gen_args = {
**inputs,
"max_new_tokens": None,
"cfg_scale": self.generate_kwargs['cfg_scale'],
"tokenizer": self.processor.tokenizer,
"generation_config": {
"do_sample": self.generate_kwargs['do_sample'],
},
"audio_streamer": None,
"stop_check_fn": None,
"verbose": False,
"refresh_negative": True,
**kwargs
}

outputs = self.model.generate(**gen_args)
return outputs

def decode(self, generated, **kwargs):
# Step 1: Extract audio from VibeVoiceGenerationOutput
if hasattr(generated, "speech_outputs"):
audio = generated.speech_outputs[0]
elif hasattr(generated, "waveform"):
audio = generated.waveform
elif hasattr(generated, "audio"):
audio = generated.audio
elif isinstance(generated, (list, tuple)):
audio = generated[0]
else:
audio = generated

# Step 2: Convert tensor to numpy array
if torch.is_tensor(audio):
# Convert bfloat16 to float32 first, then to numpy
if audio.dtype == torch.bfloat16:
audio = audio.float()
audio_np = audio.cpu().numpy().astype(np.float32)
else:
audio_np = np.array(audio, dtype=np.float32)

# Step 3: Ensure audio is 1D and properly normalized
if len(audio_np.shape) > 1:
audio_np = audio_np.squeeze()

# Step 4: Check for invalid values
if np.any(np.isnan(audio_np)) or np.any(np.isinf(audio_np)):
logger.warning("Audio contains NaN or Inf values, cleaning...")
audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=1.0, neginf=-1.0)

# Step 5: Convert to 16-bit format like the demo does
audio_16bit = self._convert_to_16_bit_wav(audio_np)

# Step 6: Convert back to float32 for compatibility
final_audio = audio_16bit.astype(np.float32) / 32767.0

logger.info(f"VibeVoice audio generated: shape={final_audio.shape}, range=[{np.min(final_audio):.3f}, {np.max(final_audio):.3f}]")

return final_audio

def _convert_to_16_bit_wav(self, data):
"""Convert audio data to 16-bit format exactly like the demo."""
# Check if data is a tensor and move to cpu
if torch.is_tensor(data):
data = data.detach().cpu().numpy()

# Ensure data is numpy array
data = np.array(data)

# Normalize to range [-1, 1] if it's not already
if np.max(np.abs(data)) > 1.0:
data = data / np.max(np.abs(data))

# Scale to 16-bit integer range
data = (data * 32767).astype(np.int16)
return data
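
A minimal usage sketch of the new VibeVoiceAudioModel (tokenize → generate → decode), assuming the plugin's audio.py is importable on the path; the model id, output path, and sample rate below are illustrative placeholders rather than values taken from this PR:

import soundfile as sf

from audio import VibeVoiceAudioModel

# Hypothetical model id and device; substitute whatever the server passes in.
model = VibeVoiceAudioModel(model_name="microsoft/VibeVoice-1.5B", device="cuda")

# Plain TTS; pass audio_path pointing at a reference WAV instead to clone a voice.
inputs = model.tokenize("Hello from VibeVoice.", audio_path=None, sample_rate=24000)
generated = model.generate(inputs)
waveform = model.decode(generated)  # float32 numpy array scaled to [-1, 1]

# 24 kHz matches the tokenize() default; the model's native output rate is assumed to match.
sf.write("out.wav", waveform, 24000)
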
transformerlab/plugins/unsloth_text_to_speech_server/index.json
@@ -4,7 +4,7 @@
"description": "A text-to-speech (TTS) audio generation server, supporting efficient, high-quality speech synthesis.",
"plugin-format": "python",
"type": "loader",
"version": "0.0.7",
"version": "0.0.8",
"supports": ["Text-to-Speech", "Audio"],
"model_architectures": ["CsmForConditionalGeneration", "LlamaForCausalLM"],
"supported_hardware_architectures": ["cuda", "amd"],
9 changes: 8 additions & 1 deletion transformerlab/plugins/unsloth_text_to_speech_server/main.py
@@ -14,7 +14,7 @@
import torch
import soundfile as sf
import librosa
from audio import CsmAudioModel, OrpheusAudioModel
from audio import CsmAudioModel, OrpheusAudioModel, VibeVoiceAudioModel


from fastapi import BackgroundTasks, FastAPI, Request
@@ -93,6 +93,13 @@ def __init__(
)
logger.info("Initialized Orpheus Audio Model")

elif "vibevoice" in self.model_name.lower():
self.audio_model = VibeVoiceAudioModel(
model_name=self.model_name,
device=self.device,
)
logger.info("Initialized VibeVoice Audio Model")

else:
logger.info(
f"Model architecture {self.model_architecture} is not supported for audio generation."
3 changes: 3 additions & 0 deletions transformerlab/plugins/unsloth_text_to_speech_server/setup.sh
@@ -1,6 +1,9 @@
#!/usr/bin/env bash
uv pip install unsloth
uv pip install snac
uv pip install flash-attn
uv pip install "vibevoice @ git+https://github.com/rsxdalv/vibevoice.git@stable"
uv pip install transformers==4.55.2

if command -v rocminfo &> /dev/null; then
# Install Unsloth from source
transformerlab/plugins/unsloth_text_to_speech_trainer/index.json
@@ -4,7 +4,7 @@
"description": "A Text-to-Speech (TTS) trainer based on the unsloth audio training notebooks",
"plugin-format": "python",
"type": "trainer",
"version": "0.0.2",
"version": "0.0.3",
"model_architectures": [
"CsmForConditionalGeneration",
"LlamaForCausalLM"
16 changes: 15 additions & 1 deletion transformerlab/plugins/unsloth_text_to_speech_trainer/main.py
@@ -6,7 +6,7 @@
from transformers import TrainingArguments, Trainer
from datasets import Audio

from trainer import CsmAudioTrainer, OrpheusAudioTrainer
from trainer import CsmAudioTrainer, OrpheusAudioTrainer, VibeVoiceAudioTrainer

from transformerlab.sdk.v1.train import tlab_trainer # noqa: E402

@@ -92,6 +92,20 @@ def train_model():
audio_column_name=audio_column_name,
text_column_name=text_column_name
)
elif "vibevoice" in model_id:
model_trainer = VibeVoiceAudioTrainer(
model_name=model_id,
speaker_key=speaker_key,
context_length=max_seq_length,
device=device,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
sampling_rate=sampling_rate,
max_audio_length=max_audio_length,
audio_column_name=audio_column_name,
text_column_name=text_column_name
)
else:
raise ValueError(f"Model architecture {model_architecture} is not supported for audio training.")

transformerlab/plugins/unsloth_text_to_speech_trainer/setup.sh
@@ -1,6 +1,7 @@
#!/usr/bin/env bash
uv pip install unsloth
uv pip install snac
uv pip install "vibevoice @ git+https://github.com/rsxdalv/vibevoice@stable"

if command -v rocminfo &> /dev/null; then
# Install Unsloth from source
@@ -9,3 +10,4 @@ if command -v rocminfo &> /dev/null; then
# Install ROCm Bitsandbytes from source
git clone --recurse https://github.com/ROCm/bitsandbytes && cd bitsandbytes && git checkout rocm_enabled_multi_backend && uv pip install -r requirements-dev.txt && cmake -DCOMPUTE_BACKEND=hip -S . && make -j && uv pip install -e .
fi

110 changes: 110 additions & 0 deletions transformerlab/plugins/unsloth_text_to_speech_trainer/trainer.py
@@ -5,6 +5,8 @@
import torch
from transformers import AutoProcessor
from snac import SNAC
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


class AudioTrainerBase(ABC):
@@ -257,6 +259,114 @@
"attention_mask": attention_mask
}

except Exception as e:
print(f"Error processing example with text '{example[self.text_column_name][:50]}...': {e}")
return None


class VibeVoiceAudioTrainer(AudioTrainerBase):
def __init__(self, model_name, context_length, device, speaker_key,
lora_r, lora_alpha, lora_dropout, sampling_rate, max_audio_length,
audio_column_name="audio", text_column_name="text"):
super().__init__(model_name, context_length, device, speaker_key,
lora_r, lora_alpha, lora_dropout, sampling_rate, max_audio_length,
audio_column_name, text_column_name)

# Load model
dtype = torch.float32
self.model = VibeVoiceForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=dtype,
)

# Load processor
self.processor = VibeVoiceProcessor.from_pretrained(model_name)

# Set up LoRA for language model
self.model.model.language_model = FastLanguageModel.get_peft_model(
self.model.model.language_model,
r=lora_r,
target_modules=self.lora_target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
)

num_trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable}")

# Set noise scheduler
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
self.model.model.noise_scheduler.config,
algorithm_type="sde-dpmsolver++",
beta_schedule="squaredcos_cap_v2",
)

def preprocess_dataset(self, example):
"""
Preprocess a single example for VibeVoice training.
"""
try:
text = example[self.text_column_name]
audio_array = example[self.audio_column_name]["array"]

Check failure on line 315 in transformerlab/plugins/unsloth_text_to_speech_trainer/trainer.py — GitHub Actions / ruff — Ruff (F841): Local variable `audio_array` is assigned to but never used

# Handle voice prompts if available
voice_samples = None
if "voice_prompts" in example and example["voice_prompts"]:
voice_samples = example["voice_prompts"]

# Process with VibeVoice processor
proc = self.processor(
text=[text],
voice_samples=voice_samples,
padding=False,
truncation=False,
max_length=self.context_length,
return_tensors="pt",
)

# Get basic tensors
input_ids = proc["input_ids"][0]
attention_mask = proc.get("attention_mask", torch.ones_like(input_ids))[0]

# Handle speech tensors and masks
speech_tensors = proc.get("speech_tensors")
speech_masks = proc.get("speech_masks")
speech_semantic_tensors = proc.get("speech_semantic_tensors")

# Create acoustic input mask
speech_input_mask = proc.get("speech_input_mask")
if speech_input_mask is None:
speech_input_mask = torch.zeros_like(input_ids, dtype=torch.bool)

# Create acoustic loss mask (loss only on target audio, not prompts)
acoustic_loss_mask = torch.zeros_like(speech_input_mask, dtype=torch.bool)
if speech_masks is not None and len(speech_masks) > 0:
# Last segment is target, others are prompts
target_mask = speech_masks[-1] # Last one is target
acoustic_loss_mask = torch.cat([
torch.zeros_like(speech_input_mask[:-len(target_mask)]),
target_mask
])

# Prepare labels (same as input_ids for causal LM)
labels = input_ids.clone()

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"speech_tensors": speech_tensors,
"speech_masks": speech_masks,
"speech_semantic_tensors": speech_semantic_tensors,
"acoustic_input_mask": speech_input_mask,
"acoustic_loss_mask": acoustic_loss_mask,
}

except Exception as e:
print(f"Error processing example with text '{example[self.text_column_name][:50]}...': {e}")
return None
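
For context, a rough sketch of how VibeVoiceAudioTrainer.preprocess_dataset might be driven over a Hugging Face dataset; the dataset name, speaker key, and LoRA hyperparameters below are placeholders (main.py supplies the real values from the job config), and constructing the trainer loads the full model:

from datasets import Audio, load_dataset

from trainer import VibeVoiceAudioTrainer

# Illustrative arguments only; mirrors the keyword call in main.py.
trainer = VibeVoiceAudioTrainer(
    model_name="microsoft/VibeVoice-1.5B",  # assumed model id
    context_length=2048,
    device="cuda",
    speaker_key="speaker_id",  # hypothetical column
    lora_r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    sampling_rate=24000,
    max_audio_length=30,
    audio_column_name="audio",
    text_column_name="text",
)

ds = load_dataset("my-org/my-tts-dataset", split="train")  # hypothetical dataset
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# preprocess_dataset returns None on failure, so drop those rows explicitly
# instead of relying on Dataset.map.
processed = [ex for ex in (trainer.preprocess_dataset(x) for x in ds) if ex is not None]
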