Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions configs/prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ custom_negative_phrases:
- "hey look at"
- "hey look it"

# Real recordings of the wake word to inject into positive_train alongside
# the TTS clips — use when biasing toward a specific voice or demographic.
# Each directory must contain 16 kHz mono .wav files; each file is copied
# `multiplier` times before augmentation. See docs/data-generation.md.
# custom_positive_samples:
#   - path: ./data/my_recordings
#     multiplier: 50

# ============================================================================
# TTS Parameters (VITS + SLERP speaker blending)
# ============================================================================
Expand Down
33 changes: 33 additions & 0 deletions docs/data-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,39 @@ Generated audio is silence-trimmed via WebRTC VAD. If the VAD strips too aggress

**Diversification:** Defaults cover many `voice_design_prompts` × `cfg_values` × `inference_timesteps_list` (see `VoxCpmTtsConfig` in `config.py`). Clip *i* cycles through that Cartesian product so resumes stay aligned with `start_index`. Output is **16 kHz** `clip_%06d.wav` (model native rate is resampled with librosa).

## Custom Positive Samples

Real human recordings of the target phrase can be injected into `positive_train` alongside the synthetic TTS clips. This is the usual way to bias the model toward a specific voice (yours, a customer's, a target demographic) without giving up the diversity of the TTS speaker pool.

### Configuration

```yaml
custom_positive_samples:
  - path: ./data/my_recordings
    multiplier: 50
  - path: ./data/other_voices
    multiplier: 10
```

Each entry is a directory of **16 kHz mono `.wav`** files. Every file is copied `multiplier` times into `positive_train/`, appended after the TTS clips using the same `clip_NNNNNN.wav` numbering. Augmentation (RIR, background noise, EQ, distortion) then runs over every copy independently, so duplicates are not wasted — each one sees different acoustic conditions per round.

- **Why a multiplier instead of a sampling weight?** The training sampler cycles through positives deterministically (see `batch_n_per_class`), so oversampling by duplication is how you increase per-voice exposure in this architecture. A multiplier of 50 over 143 recordings yields 7,150 positive copies that augment into ~14,300 unique features with `augmentation.rounds: 2`.
- **No resampling.** Sample-rate or channel mismatches raise `ValueError` so a misconfigured source surfaces early rather than silently producing bad training data. Pre-convert with `sox in.wav -r 16000 -c 1 out.wav` or an equivalent ffmpeg one-liner.
- **Resume-safe.** Each copy checks for its output path before writing, so interrupted runs pick up where they left off. The injection `start_index` is pinned to `n_samples`, so layout is deterministic even if TTS skipped some clips due to OOM.
- **Train split only.** Custom recordings do not enter `positive_test`; eval runs against held-out TTS. Keep a separate held-out dir outside `custom_positive_samples` if you want to measure recall on real voices.

### Minimal layout

```
data/
└── my_recordings/
    ├── take_001.wav   # 16 kHz mono
    ├── take_002.wav
    └── ...
```

Non-`.wav` files in the source directory are ignored with a warning. A missing path raises `FileNotFoundError` — typos are not silent.

## Adversarial Phrase Generation

`generate_adversarial_phrases()` creates phonetically similar but incorrect phrases to train the model to reject near-misses. The resulting phrases are fed back through the active TTS backend to produce negative clips.
Expand Down
21 changes: 21 additions & 0 deletions src/livekit/wakeword/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ class AugmentationConfig(BaseModel):
rir_paths: list[str] = Field(default_factory=lambda: ["./data/rirs"])


class CustomPositiveSource(BaseModel):
    """One directory of real recordings to mix into the ``positive_train`` split.

    Every ``.wav`` file found in ``path`` is duplicated ``multiplier`` times
    and appended after the TTS-generated clips. The duplicates then go through
    the normal augmentation pipeline, where each copy receives its own RIR /
    background / EQ treatment per round, so oversampling by duplication still
    yields distinct training examples.
    """

    # Source directory; expected to hold 16 kHz mono .wav files.
    path: str = Field(
        description="Directory of 16 kHz mono .wav files (absolute or relative)",
    )

    # Oversampling factor; ``ge=1`` makes zero or negative values a validation error.
    multiplier: int = Field(
        description="Number of copies per source file (oversampling factor)",
        ge=1,
        default=1,
    )


class ModelConfig(BaseModel):
model_type: ModelType = ModelType.conv_attention
model_size: ModelSize = ModelSize.small
Expand Down Expand Up @@ -126,6 +143,10 @@ class WakeWordConfig(BaseModel):
piper_tts: PiperTtsConfig = Field(default_factory=PiperTtsConfig)
voxcpm_tts: VoxCpmTtsConfig = Field(default_factory=VoxCpmTtsConfig)
custom_negative_phrases: list[str] = Field(default_factory=list)
custom_positive_samples: list[CustomPositiveSource] = Field(
default_factory=list,
description="User-supplied positive audio sources injected into positive_train after TTS",
)

# TTS parameters (Piper VITS + SLERP speaker blending)
noise_scales: list[float] = Field(default_factory=lambda: [0.98])
Expand Down
119 changes: 118 additions & 1 deletion src/livekit/wakeword/data/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from pathlib import Path

from ..config import WakeWordConfig
from ..config import CustomPositiveSource, WakeWordConfig
from .piper.text import expand_unknown_words, get_cmudict
from .tts import get_tts_backend
from .tts.piper_backend import PiperVitsBackend
Expand Down Expand Up @@ -50,6 +50,111 @@ def _count_original_clips(directory: Path) -> int:
return sum(1 for f in directory.iterdir() if _ORIGINAL_CLIP_RE.match(f.name))


def _copy_custom_positives(
    split_dir: Path,
    sources: list[CustomPositiveSource],
    start_index: int,
    expected_sample_rate: int = 16000,
) -> int:
    """Append user-supplied ``.wav`` files to ``split_dir`` as ``clip_NNNNNN.wav``.

    Each file in each source is copied ``source.multiplier`` times. Output
    indexing starts at *start_index* and increases monotonically across
    sources, with files inside a source taken in sorted order. Existing output
    files are skipped so repeated calls are idempotent (resume-safe after
    interruption).

    The work is done in two phases:

    1. Every source is resolved and every ``.wav`` is validated. Nothing is
       written during this phase, so a bad file in a later source cannot leave
       earlier sources half-injected.
    2. Files are copied. Each copy is written to a hidden temporary name and
       then renamed into place, so an interrupted run never leaves a truncated
       ``clip_NNNNNN.wav`` that a later run would mistake for a finished copy.

    The ``.wav`` extension is matched case-insensitively (``take.WAV`` is
    accepted), consistent with the non-wav warning filter. Only regular files
    are considered; sub-directories are ignored.

    Input files must match *expected_sample_rate* and be mono — mismatches
    raise rather than silently resampling, so configuration errors surface
    early.

    Args:
        split_dir: Destination directory; created if missing once there is
            at least one file to inject.
        sources: Source directories with their per-file multipliers.
        start_index: Index used for the first output clip.
        expected_sample_rate: Required sample rate of every input file, in Hz.

    Returns:
        Number of new files written on this call.

    Raises:
        FileNotFoundError: If a source ``path`` does not exist.
        NotADirectoryError: If a source ``path`` exists but is not a directory.
        ValueError: If any ``.wav`` has the wrong sample rate or is not mono.
    """
    import os
    import shutil

    import soundfile as sf

    if not sources:
        return 0

    # ---- Phase 1: resolve and validate everything before writing anything ----
    plan: list[tuple[Path, list[Path], int]] = []
    for source in sources:
        src_path = Path(source.path)
        if not src_path.exists():
            raise FileNotFoundError(f"Custom positive source path does not exist: {src_path}")
        if not src_path.is_dir():
            raise NotADirectoryError(
                f"Custom positive source path must be a directory, got: {src_path}"
            )

        # One sorted directory listing feeds both the wav list and the warning,
        # so the two can never disagree about what counts as a .wav file.
        files = sorted(p for p in src_path.iterdir() if p.is_file())
        wav_files = [p for p in files if p.suffix.lower() == ".wav"]
        non_wav = [p.name for p in files if p.suffix.lower() != ".wav"]
        if non_wav:
            sample = ", ".join(non_wav[:3]) + (" ..." if len(non_wav) > 3 else "")
            logger.warning(
                "Ignoring %d non-.wav file(s) in %s: %s",
                len(non_wav),
                src_path,
                sample,
            )

        if not wav_files:
            logger.warning("No .wav files found in custom positive source: %s", src_path)
            continue

        for wav in wav_files:
            info = sf.info(str(wav))
            if info.samplerate != expected_sample_rate:
                raise ValueError(
                    f"Custom positive {wav} has sample rate {info.samplerate}, "
                    f"expected {expected_sample_rate}. Pre-convert the file "
                    f"(e.g. `sox in.wav -r 16000 -c 1 out.wav`)."
                )
            if info.channels != 1:
                raise ValueError(
                    f"Custom positive {wav} has {info.channels} channels, "
                    f"expected 1 (mono). Pre-convert the file "
                    f"(e.g. `sox in.wav -r 16000 -c 1 out.wav`)."
                )

        plan.append((src_path, wav_files, source.multiplier))

    if not plan:
        # Every source was empty; the per-source warnings above already said so.
        return 0

    # ---- Phase 2: copy ----
    split_dir.mkdir(parents=True, exist_ok=True)

    written = 0
    next_index = start_index

    for src_path, wav_files, multiplier in plan:
        logger.info(
            "Injecting %d file(s) from %s × multiplier %d = %d copies",
            len(wav_files),
            src_path,
            multiplier,
            len(wav_files) * multiplier,
        )

        for wav in wav_files:
            for _ in range(multiplier):
                out_path = split_dir / f"clip_{next_index:06d}.wav"
                # Advance even when skipping so the layout stays deterministic.
                next_index += 1
                if out_path.exists():
                    continue

                # Hidden, non-.wav temp name: it is not picked up as a clip,
                # and os.replace() publishes the finished copy in one step.
                tmp_path = split_dir / f".{out_path.name}.partial"
                try:
                    shutil.copy2(wav, tmp_path)
                    os.replace(tmp_path, out_path)
                finally:
                    # No-op after a successful rename; removes the partial
                    # file when the copy failed or was interrupted.
                    tmp_path.unlink(missing_ok=True)
                written += 1

    if written:
        logger.info(
            "Custom positive injection: wrote %d new clip(s) (directory now contains clips 0..%d)",
            written,
            next_index - 1,
        )
    else:
        logger.info("Custom positive injection: all clips already present, nothing to do")
    return written


def _phoneme_replacements(
phones: list[str],
max_replace: int | None = None,
Expand Down Expand Up @@ -373,6 +478,18 @@ def run_generate(config: WakeWordConfig) -> None:
batch_size=config.tts_batch_size,
)

# --- Custom positive samples (user-supplied recordings) ---
# Appended after TTS so the combined directory keeps the clip_NNNNNN.wav
# layout. start_index is pinned to config.n_samples (not the live count)
# so layout is deterministic even if TTS skipped clips due to OOM — gaps
# are harmless to the augmentation regex.
if config.custom_positive_samples:
_copy_custom_positives(
split_dir=model_dir / "positive_train",
sources=config.custom_positive_samples,
start_index=config.n_samples,
)

# --- Adversarial negative splits ---
neg_train_dir = model_dir / "negative_train"
neg_test_dir = model_dir / "negative_test"
Expand Down
Loading
Loading