diff --git a/configs/prod.yaml b/configs/prod.yaml index 164faca..15c7213 100644 --- a/configs/prod.yaml +++ b/configs/prod.yaml @@ -44,6 +44,14 @@ custom_negative_phrases: - "hey look at" - "hey look it" +# Real recordings of the wake word to inject into positive_train alongside +# the TTS clips — use when biasing toward a specific voice or demographic. +# Each directory must contain 16 kHz mono .wav files; each file is copied +# `multiplier` times before augmentation. See docs/data-generation.md. +# custom_positive_samples: +# - path: ./data/my_recordings +# multiplier: 50 + # ============================================================================ # TTS Parameters (VITS + SLERP speaker blending) # ============================================================================ diff --git a/docs/data-generation.md b/docs/data-generation.md index 5d74cca..591f633 100644 --- a/docs/data-generation.md +++ b/docs/data-generation.md @@ -104,6 +104,39 @@ Generated audio is silence-trimmed via WebRTC VAD. If the VAD strips too aggress **Diversification:** Defaults cover many `voice_design_prompts` × `cfg_values` × `inference_timesteps_list` (see `VoxCpmTtsConfig` in `config.py`). Clip *i* cycles through that Cartesian product so resumes stay aligned with `start_index`. Output is **16 kHz** `clip_%06d.wav` (model native rate is resampled with librosa). +## Custom Positive Samples + +Real human recordings of the target phrase can be injected into `positive_train` alongside the synthetic TTS clips. This is the usual way to bias the model toward a specific voice (yours, a customer's, a target demographic) without giving up the diversity of the TTS speaker pool. + +### Configuration + +```yaml +custom_positive_samples: + - path: ./data/my_recordings + multiplier: 50 + - path: ./data/other_voices + multiplier: 10 +``` + +Each entry is a directory of **16 kHz mono `.wav`** files. Every file is copied `multiplier` times into `positive_train/`, appended after the TTS clips using the same `clip_NNNNNN.wav` numbering. Augmentation (RIR, background noise, EQ, distortion) then runs over every copy independently, so duplicates are not wasted — each one sees different acoustic conditions per round. + +- **Why a multiplier instead of a sampling weight?** The training sampler cycles through positives deterministically (see `batch_n_per_class`), so oversampling by duplication is how you increase per-voice exposure in this architecture. A multiplier of 50 over 143 recordings yields 7,150 positive copies that augment into ~14,300 unique features with `augmentation.rounds: 2`. +- **No resampling.** Sample-rate or channel mismatches raise `ValueError` so a misconfigured source surfaces early rather than silently producing bad training data. Pre-convert with `sox in.wav -r 16000 -c 1 out.wav` or an equivalent ffmpeg one-liner. +- **Resume-safe.** Each copy checks for its output path before writing, so interrupted runs pick up where they left off. The injection `start_index` is pinned to `n_samples`, so layout is deterministic even if TTS skipped some clips due to OOM. +- **Train split only.** Custom recordings do not enter `positive_test`; eval runs against held-out TTS. Keep a separate held-out dir outside `custom_positive_samples` if you want to measure recall on real voices. + +### Minimal layout + +``` +data/ +└── my_recordings/ + ├── take_001.wav # 16 kHz mono + ├── take_002.wav + └── ... +``` + +Non-`.wav` files in the source directory are ignored with a warning. A missing path raises `FileNotFoundError` — typos are not silent. + ## Adversarial Phrase Generation `generate_adversarial_phrases()` creates phonetically similar but incorrect phrases to train the model to reject near-misses. The resulting phrases are fed back through the active TTS backend to produce negative clips. diff --git a/src/livekit/wakeword/config.py b/src/livekit/wakeword/config.py index 4d694d5..695d676 100644 --- a/src/livekit/wakeword/config.py +++ b/src/livekit/wakeword/config.py @@ -59,6 +59,23 @@ class AugmentationConfig(BaseModel): rir_paths: list[str] = Field(default_factory=lambda: ["./data/rirs"]) +class CustomPositiveSource(BaseModel): + """User-supplied directory of positive WAV files to inject into ``positive_train``. + + Each ``.wav`` file in ``path`` is copied ``multiplier`` times into the + ``positive_train`` split, appended after the TTS-generated clips. Copies + enter the standard augmentation pipeline — each one gets different RIR / + background / EQ per round — so duplication is not wasted. + """ + + path: str = Field(description="Directory of 16 kHz mono .wav files (absolute or relative)") + multiplier: int = Field( + default=1, + ge=1, + description="Number of copies per source file (oversampling factor)", + ) + + class ModelConfig(BaseModel): model_type: ModelType = ModelType.conv_attention model_size: ModelSize = ModelSize.small @@ -126,6 +143,10 @@ class WakeWordConfig(BaseModel): piper_tts: PiperTtsConfig = Field(default_factory=PiperTtsConfig) voxcpm_tts: VoxCpmTtsConfig = Field(default_factory=VoxCpmTtsConfig) custom_negative_phrases: list[str] = Field(default_factory=list) + custom_positive_samples: list[CustomPositiveSource] = Field( + default_factory=list, + description="User-supplied positive audio sources injected into positive_train after TTS", + ) # TTS parameters (Piper VITS + SLERP speaker blending) noise_scales: list[float] = Field(default_factory=lambda: [0.98]) diff --git a/src/livekit/wakeword/data/generate.py b/src/livekit/wakeword/data/generate.py index e87e9f5..92549a5 100644 --- a/src/livekit/wakeword/data/generate.py +++ b/src/livekit/wakeword/data/generate.py @@ -7,7 +7,7 @@ import re from pathlib import Path -from ..config import WakeWordConfig +from ..config import CustomPositiveSource, WakeWordConfig from .piper.text import expand_unknown_words, get_cmudict from .tts import get_tts_backend from .tts.piper_backend import PiperVitsBackend @@ -50,6 +50,111 @@ def _count_original_clips(directory: Path) -> int: return sum(1 for f in directory.iterdir() if _ORIGINAL_CLIP_RE.match(f.name)) +def _copy_custom_positives( + split_dir: Path, + sources: list[CustomPositiveSource], + start_index: int, + expected_sample_rate: int = 16000, +) -> int: + """Append user-supplied ``.wav`` files to ``split_dir`` as ``clip_NNNNNN.wav``. + + Each file in each source is copied ``source.multiplier`` times. Output + indexing starts at *start_index* and increases monotonically across sources. + Existing output files are skipped so repeated calls are idempotent + (resume-safe after interruption). + + Input files must match *expected_sample_rate* and be mono — mismatches raise + rather than silently resampling, so configuration errors surface early. + + Returns: + Number of new files written on this call. + + Raises: + FileNotFoundError: If a source ``path`` does not exist. + NotADirectoryError: If a source ``path`` exists but is not a directory. + ValueError: If any ``.wav`` has the wrong sample rate or is not mono. + """ + import shutil + + import soundfile as sf + + if not sources: + return 0 + + split_dir.mkdir(parents=True, exist_ok=True) + + written = 0 + next_index = start_index + + for source in sources: + src_path = Path(source.path) + if not src_path.exists(): + raise FileNotFoundError(f"Custom positive source path does not exist: {src_path}") + if not src_path.is_dir(): + raise NotADirectoryError( + f"Custom positive source path must be a directory, got: {src_path}" + ) + + wav_files = sorted(src_path.glob("*.wav")) + non_wav = [p.name for p in src_path.iterdir() if p.is_file() and p.suffix.lower() != ".wav"] + if non_wav: + sample = ", ".join(non_wav[:3]) + (" ..." if len(non_wav) > 3 else "") + logger.warning( + "Ignoring %d non-.wav file(s) in %s: %s", + len(non_wav), + src_path, + sample, + ) + + if not wav_files: + logger.warning("No .wav files found in custom positive source: %s", src_path) + continue + + # Validate sample rate and channels up front so we don't partially + # copy before discovering a bad file deep in the list. + for wav in wav_files: + info = sf.info(str(wav)) + if info.samplerate != expected_sample_rate: + raise ValueError( + f"Custom positive {wav} has sample rate {info.samplerate}, " + f"expected {expected_sample_rate}. Pre-convert the file " + f"(e.g. `sox in.wav -r 16000 -c 1 out.wav`)." + ) + if info.channels != 1: + raise ValueError( + f"Custom positive {wav} has {info.channels} channels, " + f"expected 1 (mono). Pre-convert the file " + f"(e.g. `sox in.wav -r 16000 -c 1 out.wav`)." + ) + + logger.info( + "Injecting %d file(s) from %s × multiplier %d = %d copies", + len(wav_files), + src_path, + source.multiplier, + len(wav_files) * source.multiplier, + ) + + for wav in wav_files: + for _ in range(source.multiplier): + out_path = split_dir / f"clip_{next_index:06d}.wav" + next_index += 1 + if out_path.exists(): + continue + shutil.copy2(wav, out_path) + written += 1 + + if written: + logger.info( + "Custom positive injection: wrote %d new clip(s) (directory now contains clips 0..%d)", + written, + next_index - 1, + ) + else: + logger.info("Custom positive injection: all clips already present, nothing to do") + return written + + def _phoneme_replacements( phones: list[str], max_replace: int | None = None, @@ -373,6 +478,18 @@ def run_generate(config: WakeWordConfig) -> None: batch_size=config.tts_batch_size, ) + # --- Custom positive samples (user-supplied recordings) --- + # Appended after TTS so the combined directory keeps the clip_NNNNNN.wav + # layout. start_index is pinned to config.n_samples (not the live count) + # so layout is deterministic even if TTS skipped clips due to OOM — gaps + # are harmless to the augmentation regex. + if config.custom_positive_samples: + _copy_custom_positives( + split_dir=model_dir / "positive_train", + sources=config.custom_positive_samples, + start_index=config.n_samples, + ) + # --- Adversarial negative splits --- neg_train_dir = model_dir / "negative_train" neg_test_dir = model_dir / "negative_test" diff --git a/tests/test_custom_positives.py b/tests/test_custom_positives.py new file mode 100644 index 0000000..e564807 --- /dev/null +++ b/tests/test_custom_positives.py @@ -0,0 +1,350 @@ +"""Tests for custom positive sample injection.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf + +from livekit.wakeword.config import CustomPositiveSource, WakeWordConfig +from livekit.wakeword.data.generate import _copy_custom_positives + + +def _make_wav( + path: Path, + duration_s: float = 1.0, + sample_rate: int = 16000, + channels: int = 1, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + n_frames = int(duration_s * sample_rate) + shape = (n_frames,) if channels == 1 else (n_frames, channels) + data = (np.random.randn(*shape) * 0.1).astype(np.float32) + sf.write(str(path), data, sample_rate) + + +class TestCustomPositiveSourceModel: + def test_default_multiplier_is_one(self) -> None: + src = CustomPositiveSource(path="/tmp/anywhere") + assert src.multiplier == 1 + + def test_rejects_zero_multiplier(self) -> None: + with pytest.raises(ValueError): + CustomPositiveSource(path="/tmp/anywhere", multiplier=0) + + def test_rejects_negative_multiplier(self) -> None: + with pytest.raises(ValueError): + CustomPositiveSource(path="/tmp/anywhere", multiplier=-1) + + +class TestCopyCustomPositives: + def test_empty_sources_is_noop(self, tmp_path: Path) -> None: + split_dir = tmp_path / "positive_train" + written = _copy_custom_positives(split_dir, [], start_index=10) + assert written == 0 + # Empty sources should not even create the output directory. + assert not split_dir.exists() + + def test_basic_copy_appends_at_start_index(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "sample1.wav") + _make_wav(src / "sample2.wav") + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=1) + written = _copy_custom_positives(split_dir, [source], start_index=5) + + assert written == 2 + assert (split_dir / "clip_000005.wav").exists() + assert (split_dir / "clip_000006.wav").exists() + # Numbering does not collide with pre-existing range. + assert not (split_dir / "clip_000004.wav").exists() + + def test_multiplier_duplicates_each_file(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "only.wav") + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=3) + written = _copy_custom_positives(split_dir, [source], start_index=0) + + assert written == 3 + names = sorted(p.name for p in split_dir.glob("clip_*.wav")) + assert names == ["clip_000000.wav", "clip_000001.wav", "clip_000002.wav"] + + def test_multiple_sources_continuous_numbering(self, tmp_path: Path) -> None: + src1 = tmp_path / "first" + src2 = tmp_path / "second" + _make_wav(src1 / "a.wav") + _make_wav(src2 / "b.wav") + + split_dir = tmp_path / "positive_train" + sources = [ + CustomPositiveSource(path=str(src1), multiplier=2), + CustomPositiveSource(path=str(src2), multiplier=3), + ] + written = _copy_custom_positives(split_dir, sources, start_index=100) + + assert written == 5 + for i in range(100, 105): + assert (split_dir / f"clip_{i:06d}.wav").exists() + + def test_resume_skips_existing_outputs(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "a.wav") + _make_wav(src / "b.wav") + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=2) + + first = _copy_custom_positives(split_dir, [source], start_index=0) + assert first == 4 + # A second call with identical inputs should be a no-op. + second = _copy_custom_positives(split_dir, [source], start_index=0) + assert second == 0 + assert len(list(split_dir.glob("clip_*.wav"))) == 4 + + def test_resume_fills_gap_when_partial(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "a.wav") + _make_wav(src / "b.wav") + _make_wav(src / "c.wav") + + split_dir = tmp_path / "positive_train" + split_dir.mkdir() + # Simulate a partial previous run: first two targets are already on disk. + (split_dir / "clip_000010.wav").write_bytes(b"existing-a") + (split_dir / "clip_000011.wav").write_bytes(b"existing-b") + + source = CustomPositiveSource(path=str(src), multiplier=1) + written = _copy_custom_positives(split_dir, [source], start_index=10) + + # Only the third slot should be newly written. + assert written == 1 + assert (split_dir / "clip_000010.wav").read_bytes() == b"existing-a" + assert (split_dir / "clip_000011.wav").read_bytes() == b"existing-b" + assert (split_dir / "clip_000012.wav").exists() + + def test_rejects_wrong_sample_rate(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "bad.wav", sample_rate=48000) + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=1) + with pytest.raises(ValueError, match="sample rate 48000"): + _copy_custom_positives(split_dir, [source], start_index=0) + # No partial output should be left behind. + assert not any(split_dir.glob("clip_*.wav")) + + def test_rejects_stereo(self, tmp_path: Path) -> None: + src = tmp_path / "recordings" + _make_wav(src / "stereo.wav", channels=2) + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=1) + with pytest.raises(ValueError, match="2 channels"): + _copy_custom_positives(split_dir, [source], start_index=0) + + def test_missing_source_path_raises(self, tmp_path: Path) -> None: + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(tmp_path / "nope"), multiplier=1) + with pytest.raises(FileNotFoundError, match="does not exist"): + _copy_custom_positives(split_dir, [source], start_index=0) + + def test_source_path_is_file_raises(self, tmp_path: Path) -> None: + file_as_source = tmp_path / "single.wav" + _make_wav(file_as_source) + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(file_as_source), multiplier=1) + with pytest.raises(NotADirectoryError, match="must be a directory"): + _copy_custom_positives(split_dir, [source], start_index=0) + + def test_warns_and_skips_non_wav_files( + self, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + src = tmp_path / "recordings" + _make_wav(src / "good.wav") + (src / "notes.txt").write_text("ignored") + (src / "cover.mp3").touch() + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=1) + with caplog.at_level(logging.WARNING): + written = _copy_custom_positives(split_dir, [source], start_index=0) + + assert written == 1 + assert any( + "non-.wav" in rec.message + and "notes.txt" in rec.message + or "non-.wav" in rec.message + and "cover.mp3" in rec.message + for rec in caplog.records + ) + + def test_empty_source_dir_warns_no_files( + self, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + src = tmp_path / "recordings" + src.mkdir() + + split_dir = tmp_path / "positive_train" + source = CustomPositiveSource(path=str(src), multiplier=1) + with caplog.at_level(logging.WARNING): + written = _copy_custom_positives(split_dir, [source], start_index=0) + assert written == 0 + assert any("No .wav files" in rec.message for rec in caplog.records) + + +class _FakeTtsBackend: + """Minimal TTS stub that writes placeholder clips for integration testing.""" + + def validate_artifacts(self) -> None: + return None + + def synthesize_clips( + self, + phrases: list[str], + output_dir: Path, + n_samples: int, + start_index: int = 0, + batch_size: int = 50, + ) -> list[Path]: + output_dir.mkdir(parents=True, exist_ok=True) + written = [] + for i in range(start_index, n_samples): + p = output_dir / f"clip_{i:06d}.wav" + p.touch() # empty placeholder, distinguishable from real copies by size + written.append(p) + return written + + +class TestRunGenerateIntegration: + """End-to-end: verify run_generate wires custom_positive_samples correctly.""" + + def test_injection_happens_after_positive_train_tts( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + rec_dir = tmp_path / "recordings" + _make_wav(rec_dir / "voice_a.wav") + _make_wav(rec_dir / "voice_b.wav") + + cfg = WakeWordConfig( + model_name="hey_integration", + target_phrases=["hey integration"], + n_samples=3, + n_samples_val=1, + n_background_samples=0, + n_background_samples_val=0, + data_dir=str(tmp_path / "data"), + output_dir=str(tmp_path / "output"), + custom_positive_samples=[ + CustomPositiveSource(path=str(rec_dir), multiplier=4), + ], + ) + + from livekit.wakeword.data import generate as gen_mod + + monkeypatch.setattr(gen_mod, "get_tts_backend", lambda c: _FakeTtsBackend()) + gen_mod.run_generate(cfg) + + pos_train = cfg.model_output_dir / "positive_train" + all_clips = sorted(p.name for p in pos_train.glob("clip_*.wav")) + # 3 TTS placeholders + 2 files × multiplier 4 = 11 total clips + assert len(all_clips) == 11 + assert all_clips[0] == "clip_000000.wav" + assert all_clips[-1] == "clip_000010.wav" + + # TTS clips (indices 0..2) are empty placeholders from _FakeTtsBackend. + # Custom clips (indices 3..10) are real wav copies with non-zero size. + for i in range(3): + assert (pos_train / f"clip_{i:06d}.wav").stat().st_size == 0 + for i in range(3, 11): + assert (pos_train / f"clip_{i:06d}.wav").stat().st_size > 0 + + # positive_test is unaffected (no custom injection there in v1) + pos_test_clips = list((cfg.model_output_dir / "positive_test").glob("clip_*.wav")) + assert len(pos_test_clips) == cfg.n_samples_val + + def test_run_generate_is_idempotent_with_custom_positives( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + rec_dir = tmp_path / "recordings" + _make_wav(rec_dir / "voice.wav") + + cfg = WakeWordConfig( + model_name="hey_resume", + target_phrases=["hey resume"], + n_samples=2, + n_samples_val=1, + n_background_samples=0, + n_background_samples_val=0, + data_dir=str(tmp_path / "data"), + output_dir=str(tmp_path / "output"), + custom_positive_samples=[ + CustomPositiveSource(path=str(rec_dir), multiplier=3), + ], + ) + + from livekit.wakeword.data import generate as gen_mod + + monkeypatch.setattr(gen_mod, "get_tts_backend", lambda c: _FakeTtsBackend()) + + gen_mod.run_generate(cfg) + first_listing = sorted( + (p.name, p.stat().st_size) + for p in (cfg.model_output_dir / "positive_train").glob("clip_*.wav") + ) + + # Second invocation should be a complete no-op on positive_train + gen_mod.run_generate(cfg) + second_listing = sorted( + (p.name, p.stat().st_size) + for p in (cfg.model_output_dir / "positive_train").glob("clip_*.wav") + ) + + assert first_listing == second_listing + assert len(first_listing) == 2 + 3 # 2 TTS + 1 file × multiplier 3 + + +class TestConfigIntegration: + def test_config_accepts_custom_positive_samples(self, tmp_path: Path) -> None: + cfg = WakeWordConfig( + model_name="hey_test", + target_phrases=["hey test"], + custom_positive_samples=[ + CustomPositiveSource(path=str(tmp_path), multiplier=10), + ], + ) + assert len(cfg.custom_positive_samples) == 1 + assert cfg.custom_positive_samples[0].multiplier == 10 + + def test_config_default_is_empty_list(self) -> None: + cfg = WakeWordConfig( + model_name="hey_test", + target_phrases=["hey test"], + ) + assert cfg.custom_positive_samples == [] + + def test_config_yaml_parsing(self, tmp_path: Path) -> None: + rec_dir = tmp_path / "recordings" + rec_dir.mkdir() + yaml_text = f""" +model_name: hey_test +target_phrases: ["hey test"] +custom_positive_samples: + - path: {rec_dir} + multiplier: 25 +""" + yaml_path = tmp_path / "cfg.yaml" + yaml_path.write_text(yaml_text) + + from livekit.wakeword.config import load_config + + cfg = load_config(yaml_path) + assert len(cfg.custom_positive_samples) == 1 + assert cfg.custom_positive_samples[0].path == str(rec_dir) + assert cfg.custom_positive_samples[0].multiplier == 25