diff --git a/src/om1_speech/processor/connection_processor.py b/src/om1_speech/processor/connection_processor.py index 7d81fb6..91e774e 100644 --- a/src/om1_speech/processor/connection_processor.py +++ b/src/om1_speech/processor/connection_processor.py @@ -101,6 +101,11 @@ def handle_new_connection(self, connection_id: str): if self.ws_server else None ), + error_callback=lambda: ( + self.ws_server.close_connection(connection_id) + if self.ws_server + else None + ), ) self.asr_processors[connection_id] = asr_processor diff --git a/src/om1_speech/riva/asr_processor.py b/src/om1_speech/riva/asr_processor.py index 99f8c76..2bf10eb 100644 --- a/src/om1_speech/riva/asr_processor.py +++ b/src/om1_speech/riva/asr_processor.py @@ -28,15 +28,21 @@ class ASRProcessor(ASRProcessorInterface): Command line arguments and configuration parameters callback : Optional[Callable], optional Callback function to receive ASR results (default: None) + error_callback : Optional[Callable], optional + Callback function to be called when an error occurs (default: None) """ def __init__( - self, model_args: argparse.Namespace, callback: Optional[Callable] = None + self, + model_args: argparse.Namespace, + callback: Optional[Callable] = None, + error_callback: Optional[Callable] = None, ): self.model: Optional[ASRService] = None # type: ignore self.model_config: Optional[StreamingRecognitionConfig] = None # type: ignore self.args = model_args self.callback = callback + self.error_callback = error_callback self.running: bool = True # ASR settings @@ -154,7 +160,8 @@ def process_audio(self, audio_source: Any): Continuously processes audio chunks from the source and generates transcriptions. Final transcriptions are logged and sent to the - callback function if provided. + callback function if provided. If an error occurs during processing, + the error_callback is invoked (if provided) to close the connection. Parameters ---------- @@ -162,7 +169,11 @@ def process_audio(self, audio_source: Any): Source object that provides audio chunks for processing """ if self.model is None or self.model_config is None: - raise RuntimeError("ASR model is not initialized.") + error_msg = "ASR model is not initialized." + logger.error(error_msg) + if self.error_callback: + self.error_callback() + raise RuntimeError(error_msg) logger.info("Waiting for first audio chunk to initialize sample rate...") while self.running: @@ -179,29 +190,41 @@ def process_audio(self, audio_source: Any): break time.sleep(0.01) - responses = self.model.streaming_response_generator( - audio_chunks=self._yield_audio_chunks(audio_source), - streaming_config=self.model_config, - ) - - for response in responses: - if not response.results: - continue - - result = response.results[0] - if not result.alternatives: - continue - - transcript = result.alternatives[0].transcript.strip() - if not transcript: - continue + try: + responses = self.model.streaming_response_generator( + audio_chunks=self._yield_audio_chunks(audio_source), + streaming_config=self.model_config, + ) - if result.is_final: - logging.info(f"Final ASR Result: {transcript}") - if self.callback: - self.callback(json.dumps({"asr_reply": transcript})) - else: - logging.info(f"Interim ASR Result: {transcript}") + for response in responses: + if not self.running: + logger.info("ASR processor stopped by user request") + break + + if not response.results: + continue + + result = response.results[0] + if not result.alternatives: + continue + + transcript = result.alternatives[0].transcript.strip() + if not transcript: + continue + + if result.is_final: + logging.info(f"Final ASR Result: {transcript}") + if self.callback: + self.callback(json.dumps({"asr_reply": transcript})) + else: + logging.info(f"Interim ASR Result: {transcript}") + except Exception as e: + logger.error(f"Error in ASR processing: {e}") + if self.error_callback: + logger.info("Invoking error callback to close connection") + self.error_callback() + self.running = False + raise def stop(self): """ diff --git a/tests/om1_speech/processor/test_speech_connection_processor.py b/tests/om1_speech/processor/test_speech_connection_processor.py index d0fd161..83b77bb 100644 --- a/tests/om1_speech/processor/test_speech_connection_processor.py +++ b/tests/om1_speech/processor/test_speech_connection_processor.py @@ -10,9 +10,10 @@ class MockASRProcessor(ASRProcessorInterface): - def __init__(self, args, callback=None): + def __init__(self, args, callback=None, error_callback=None): self.args = args self.callback = callback + self.error_callback = error_callback self.stopped = False def on_audio(self, audio: bytes) -> bytes: