Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/om1_speech/processor/connection_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def handle_new_connection(self, connection_id: str):
if self.ws_server
else None
),
error_callback=lambda: (
self.ws_server.close_connection(connection_id)
if self.ws_server
else None
),
)

self.asr_processors[connection_id] = asr_processor
Expand Down
73 changes: 48 additions & 25 deletions src/om1_speech/riva/asr_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ class ASRProcessor(ASRProcessorInterface):
Command line arguments and configuration parameters
callback : Optional[Callable], optional
Callback function to receive ASR results (default: None)
error_callback : Optional[Callable], optional
Callback function to be called when an error occurs (default: None)
"""

def __init__(
self, model_args: argparse.Namespace, callback: Optional[Callable] = None
self,
model_args: argparse.Namespace,
callback: Optional[Callable] = None,
error_callback: Optional[Callable] = None,
):
self.model: Optional[ASRService] = None # type: ignore
self.model_config: Optional[StreamingRecognitionConfig] = None # type: ignore
self.args = model_args
self.callback = callback
self.error_callback = error_callback
self.running: bool = True

# ASR settings
Expand Down Expand Up @@ -154,15 +160,20 @@ def process_audio(self, audio_source: Any):

Continuously processes audio chunks from the source and generates
transcriptions. Final transcriptions are logged and sent to the
callback function if provided.
callback function if provided. If an error occurs during processing,
the error_callback is invoked (if provided) to close the connection.

Parameters
----------
audio_source : Any
Source object that provides audio chunks for processing
"""
if self.model is None or self.model_config is None:
raise RuntimeError("ASR model is not initialized.")
error_msg = "ASR model is not initialized."
logger.error(error_msg)
if self.error_callback:
self.error_callback()
raise RuntimeError(error_msg)

logger.info("Waiting for first audio chunk to initialize sample rate...")
while self.running:
Expand All @@ -179,29 +190,41 @@ def process_audio(self, audio_source: Any):
break
time.sleep(0.01)

responses = self.model.streaming_response_generator(
audio_chunks=self._yield_audio_chunks(audio_source),
streaming_config=self.model_config,
)

for response in responses:
if not response.results:
continue

result = response.results[0]
if not result.alternatives:
continue

transcript = result.alternatives[0].transcript.strip()
if not transcript:
continue
try:
responses = self.model.streaming_response_generator(
audio_chunks=self._yield_audio_chunks(audio_source),
streaming_config=self.model_config,
)

if result.is_final:
logging.info(f"Final ASR Result: {transcript}")
if self.callback:
self.callback(json.dumps({"asr_reply": transcript}))
else:
logging.info(f"Interim ASR Result: {transcript}")
for response in responses:
if not self.running:
logger.info("ASR processor stopped by user request")
break

if not response.results:
continue

result = response.results[0]
if not result.alternatives:
continue

transcript = result.alternatives[0].transcript.strip()
if not transcript:
continue

if result.is_final:
logging.info(f"Final ASR Result: {transcript}")
if self.callback:
self.callback(json.dumps({"asr_reply": transcript}))
else:
logging.info(f"Interim ASR Result: {transcript}")
except Exception as e:
logger.error(f"Error in ASR processing: {e}")
if self.error_callback:
logger.info("Invoking error callback to close connection")
self.error_callback()
self.running = False
raise

def stop(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@


class MockASRProcessor(ASRProcessorInterface):
def __init__(self, args, callback=None):
def __init__(self, args, callback=None, error_callback=None):
self.args = args
self.callback = callback
self.error_callback = error_callback
self.stopped = False

def on_audio(self, audio: bytes) -> bytes:
Expand Down
Loading