From 6e18305dbe04714ef82a2d45d0300062647c0887 Mon Sep 17 00:00:00 2001 From: Gwendal Roulleau Date: Tue, 12 May 2026 22:14:05 +0200 Subject: [PATCH] Add support for configurable ONNX SessionOptions across feature extraction and inference modules --- src/livekit/wakeword/data/features.py | 5 ++++- src/livekit/wakeword/inference/model.py | 11 +++++++---- src/livekit/wakeword/models/feature_extractor.py | 11 +++++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/livekit/wakeword/data/features.py b/src/livekit/wakeword/data/features.py index 6588f56..ffab2da 100644 --- a/src/livekit/wakeword/data/features.py +++ b/src/livekit/wakeword/data/features.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np +from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions from ..config import WakeWordConfig from ..models.feature_extractor import MelSpectrogramFrontend, SpeechEmbedding @@ -68,13 +69,15 @@ def extract_features_from_directory( return np.stack(all_features, axis=0) # (N_clips, 16, 96) -def run_extraction(config: WakeWordConfig) -> None: +def run_extraction(config: WakeWordConfig, sess_options: SessionOptions) -> None: """Extract and save features for all splits of a wake word config.""" mel_frontend = MelSpectrogramFrontend( onnx_path=get_mel_model_path(), + sess_options=sess_options, ) speech_embedding = SpeechEmbedding( onnx_path=get_embedding_model_path(), + sess_options=sess_options, ) model_dir = config.model_output_dir diff --git a/src/livekit/wakeword/inference/model.py b/src/livekit/wakeword/inference/model.py index 4d9db7c..4f16e02 100644 --- a/src/livekit/wakeword/inference/model.py +++ b/src/livekit/wakeword/inference/model.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np +from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions from ..models.feature_extractor import MelSpectrogramFrontend, SpeechEmbedding from ..resources import get_embedding_model_path, get_mel_model_path @@ -37,6 +38,7 @@ class WakeWordModel: def __init__( self, models: list[str | Path] | None = None, + sess_options: SessionOptions | None = None, ): """Initialize the wake word detection model. @@ -58,17 +60,17 @@ def __init__( "This should not happen - please reinstall livekit-wakeword." ) - self._mel_frontend = MelSpectrogramFrontend(onnx_path=mel_path) - self._speech_embedding = SpeechEmbedding(onnx_path=embedding_path) + self._mel_frontend = MelSpectrogramFrontend(onnx_path=mel_path, sess_options=sess_options) + self._speech_embedding = SpeechEmbedding(onnx_path=embedding_path, sess_options=sess_options) # name -> (onnx_session, input_name) self._classifiers: dict[str, tuple] = {} if models: for model_path in models: - self.load_model(model_path) + self.load_model(model_path, sess_options=sess_options) - def load_model(self, model_path: str | Path, model_name: str | None = None) -> None: + def load_model(self, model_path: str | Path, model_name: str | None = None, sess_options: SessionOptions = None ) -> None: """Load a wake word classifier model. Args: @@ -87,6 +89,7 @@ def load_model(self, model_path: str | Path, model_name: str | None = None) -> N session = ort.InferenceSession( str(model_path), providers=["CPUExecutionProvider"], + sess_options=sess_options ) input_name = session.get_inputs()[0].name self._classifiers[model_name] = (session, input_name) diff --git a/src/livekit/wakeword/models/feature_extractor.py b/src/livekit/wakeword/models/feature_extractor.py index b364106..76c8c71 100644 --- a/src/livekit/wakeword/models/feature_extractor.py +++ b/src/livekit/wakeword/models/feature_extractor.py @@ -13,6 +13,7 @@ from pathlib import Path import numpy as np +from onnxruntime.capi.onnxruntime_pybind11_state import SessionOptions logger = logging.getLogger(__name__) @@ -28,20 +29,21 @@ class MelSpectrogramFrontend: Output: (batch, time_frames, 32) """ - def __init__(self, onnx_path: str | Path): + def __init__(self, onnx_path: str | Path, sess_options: SessionOptions = None): if not Path(onnx_path).exists(): raise FileNotFoundError( f"Mel ONNX model not found: {onnx_path}\n" "This should not happen - please reinstall livekit-wakeword." ) - self._init_onnx(onnx_path) + self._init_onnx(onnx_path, sess_options) - def _init_onnx(self, onnx_path: str | Path) -> None: + def _init_onnx(self, onnx_path: str | Path, sess_options: SessionOptions = None) -> None: import onnxruntime as ort self._onnx_session = ort.InferenceSession( str(onnx_path), providers=["CPUExecutionProvider"], + sess_options=sess_options, ) self._input_name = self._onnx_session.get_inputs()[0].name logger.info(f"Loaded mel ONNX model from {onnx_path}") @@ -89,7 +91,7 @@ class SpeechEmbedding: ONNX output: (batch, 1, 1, 96) — 96-dim embedding """ - def __init__(self, onnx_path: str | Path): + def __init__(self, onnx_path: str | Path, sess_options: SessionOptions = None): import onnxruntime as ort if not Path(onnx_path).exists(): @@ -101,6 +103,7 @@ def __init__(self, onnx_path: str | Path): self._session = ort.InferenceSession( str(onnx_path), providers=["CPUExecutionProvider"], + sess_options=sess_options, ) self._input_name = self._session.get_inputs()[0].name logger.info(f"Loaded embedding ONNX model from {onnx_path}")