Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/livekit/wakeword/data/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import numpy as np
from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions

from ..config import WakeWordConfig
from ..models.feature_extractor import MelSpectrogramFrontend, SpeechEmbedding
Expand Down Expand Up @@ -68,13 +69,15 @@ def extract_features_from_directory(
return np.stack(all_features, axis=0) # (N_clips, 16, 96)


def run_extraction(config: WakeWordConfig) -> None:
def run_extraction(config: WakeWordConfig, sess_options: SessionOptions) -> None:
"""Extract and save features for all splits of a wake word config."""
mel_frontend = MelSpectrogramFrontend(
onnx_path=get_mel_model_path(),
sess_options=sess_options,
)
speech_embedding = SpeechEmbedding(
onnx_path=get_embedding_model_path(),
sess_options=sess_options,
)

model_dir = config.model_output_dir
Expand Down
11 changes: 7 additions & 4 deletions src/livekit/wakeword/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import numpy as np
from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions

from ..models.feature_extractor import MelSpectrogramFrontend, SpeechEmbedding
from ..resources import get_embedding_model_path, get_mel_model_path
Expand Down Expand Up @@ -37,6 +38,7 @@ class WakeWordModel:
def __init__(
self,
models: list[str | Path] | None = None,
sess_options: SessionOptions | None = None,
):
"""Initialize the wake word detection model.

Expand All @@ -58,17 +60,17 @@ def __init__(
"This should not happen - please reinstall livekit-wakeword."
)

self._mel_frontend = MelSpectrogramFrontend(onnx_path=mel_path)
self._speech_embedding = SpeechEmbedding(onnx_path=embedding_path)
self._mel_frontend = MelSpectrogramFrontend(onnx_path=mel_path, sess_options=sess_options)
self._speech_embedding = SpeechEmbedding(onnx_path=embedding_path, sess_options=sess_options)

# name -> (onnx_session, input_name)
self._classifiers: dict[str, tuple] = {}

if models:
for model_path in models:
self.load_model(model_path)
self.load_model(model_path, sess_options=sess_options)

def load_model(self, model_path: str | Path, model_name: str | None = None) -> None:
def load_model(self, model_path: str | Path, model_name: str | None = None, sess_options: SessionOptions = None ) -> None:
"""Load a wake word classifier model.

Args:
Expand All @@ -87,6 +89,7 @@ def load_model(self, model_path: str | Path, model_name: str | None = None) -> N
session = ort.InferenceSession(
str(model_path),
providers=["CPUExecutionProvider"],
sess_options=sess_options
)
input_name = session.get_inputs()[0].name
self._classifiers[model_name] = (session, input_name)
Expand Down
11 changes: 7 additions & 4 deletions src/livekit/wakeword/models/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path

import numpy as np
from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions

logger = logging.getLogger(__name__)

Expand All @@ -28,20 +29,21 @@ class MelSpectrogramFrontend:
Output: (batch, time_frames, 32)
"""

def __init__(
    self,
    onnx_path: str | Path,
    sess_options: SessionOptions | None = None,
):
    """Check that the mel ONNX model file exists, then load it.

    Args:
        onnx_path: Location of the mel-spectrogram ONNX model on disk.
        sess_options: Optional ONNX Runtime session options. Forwarded
            unchanged to ``_init_onnx``; ``None`` (the default) forwards
            no custom options.

    Raises:
        FileNotFoundError: If ``onnx_path`` does not exist.
    """
    # Fail early with an actionable message instead of letting the ONNX
    # runtime raise its own error for a missing file.
    if not Path(onnx_path).exists():
        raise FileNotFoundError(
            f"Mel ONNX model not found: {onnx_path}\n"
            "This should not happen - please reinstall livekit-wakeword."
        )
    self._init_onnx(onnx_path, sess_options)

def _init_onnx(
    self,
    onnx_path: str | Path,
    sess_options: SessionOptions | None = None,
) -> None:
    """Create the ONNX Runtime inference session for the mel model.

    Stores the session on ``self._onnx_session`` and the name of the
    model's first input on ``self._input_name``.

    Args:
        onnx_path: Location of the mel-spectrogram ONNX model on disk.
        sess_options: Optional ONNX Runtime session options passed as the
            ``sess_options`` argument of ``InferenceSession``.
    """
    # Imported lazily so that merely importing this module does not
    # import onnxruntime.
    import onnxruntime as ort

    self._onnx_session = ort.InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"],
        sess_options=sess_options,
    )
    # Remember the first graph input's name for later use.
    self._input_name = self._onnx_session.get_inputs()[0].name
    logger.info(f"Loaded mel ONNX model from {onnx_path}")
Expand Down Expand Up @@ -89,7 +91,7 @@ class SpeechEmbedding:
ONNX output: (batch, 1, 1, 96) — 96-dim embedding
"""

def __init__(self, onnx_path: str | Path):
def __init__(self, onnx_path: str | Path, sess_options: SessionOptions = None):
import onnxruntime as ort

if not Path(onnx_path).exists():
Expand All @@ -101,6 +103,7 @@ def __init__(self, onnx_path: str | Path):
self._session = ort.InferenceSession(
str(onnx_path),
providers=["CPUExecutionProvider"],
sess_options=sess_options,
)
self._input_name = self._session.get_inputs()[0].name
logger.info(f"Loaded embedding ONNX model from {onnx_path}")
Expand Down
Loading