Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions src/inputs/plugins/riva_asr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import re
import time
from typing import Dict, List, Optional
from uuid import uuid4
Expand Down Expand Up @@ -56,6 +57,101 @@ class RivaASRSensorConfig(SensorConfig):
)


# Keywords that suggest speech is directed at the robot (greeter context)
# All lowercase — matching is case-insensitive
# Avoid generic words (is, are, do, will, etc.) that appear in any conversation
_DIRECTED_KEYWORDS = {
# Addressing the robot directly
"you",
"your",
"yours",
"yourself",
"bits",
"robot",
"dog",
"puppy",
"buddy",
"openmind",
"om1",
# Greetings and social
"hello",
"hi",
"hey",
"howdy",
"greetings",
"goodbye",
"bye",
"thanks",
"thank",
"please",
# Questions — only question words, not auxiliaries
"what",
"how",
"why",
"who",
"where",
"when",
"which",
# Requests directed at the robot
"tell",
"show",
"explain",
"describe",
"help",
# Conference / demo context
"gtc",
"nvidia",
"unitree",
"conference",
"demo",
"booth",
"exhibit",
# Product questions
"name",
"company",
"product",
"price",
"cost",
"buy",
"available",
"software",
"ai",
"autonomous",
"platform",
}


# Common ASR misrecognitions → correct text (case-insensitive)
_ASR_CORRECTIONS = [
(
re.compile(
r"\b(?:om one|ol one|on one|om 1|ol 1|oh and one|o one|oh one)\b",
re.IGNORECASE,
),
"OM1",
),
(re.compile(r"\b(?:open mind|pokemon)\b", re.IGNORECASE), "OpenMind"),
(re.compile(r"\bunit tree\b", re.IGNORECASE), "Unitree"),
]


def _normalize_asr_text(text: str) -> str:
"""Fix common ASR misrecognitions."""
for pattern, replacement in _ASR_CORRECTIONS:
text = pattern.sub(replacement, text)
return text


def _seems_directed_at_robot(text: str) -> bool:
"""Check if the transcript seems directed at the robot rather than overheard chatter."""
words = set(text.lower().split())
if words & _DIRECTED_KEYWORDS:
return True
if text.rstrip().endswith("?"):
return True
return False


class RivaASRInput(FuserInput[RivaASRSensorConfig, Optional[str]]):
"""
Automatic Speech Recognition (ASR) input handler.
Expand Down Expand Up @@ -146,8 +242,11 @@ def _handle_asr_message(self, raw_message: str):
try:
json_message: Dict = json.loads(raw_message)
if "asr_reply" in json_message:
asr_reply = json_message["asr_reply"]
if len(asr_reply.split()) > 1:
asr_reply = _normalize_asr_text(json_message["asr_reply"])
if len(asr_reply.split()) > 2:
if not _seems_directed_at_robot(asr_reply):
logging.info("ASR filtered as overheard chatter: %s", asr_reply)
return
self.message_buffer.put_nowait(asr_reply)
logging.info("Detected ASR message: %s", asr_reply)
except json.JSONDecodeError:
Expand Down
8 changes: 6 additions & 2 deletions src/inputs/plugins/riva_asr_rtsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from inputs.base import Message, SensorConfig
from inputs.base.loop import FuserInput
from inputs.plugins.riva_asr import _normalize_asr_text, _seems_directed_at_robot
from providers.asr_rtsp_provider import ASRRTSPProvider
from providers.io_provider import IOProvider
from providers.sleep_ticker_provider import SleepTickerProvider
Expand Down Expand Up @@ -134,8 +135,11 @@ def _handle_asr_message(self, raw_message: str):
try:
json_message: Dict = json.loads(raw_message)
if "asr_reply" in json_message:
asr_reply = json_message["asr_reply"]
if len(asr_reply.split()) > 1:
asr_reply = _normalize_asr_text(json_message["asr_reply"])
if len(asr_reply.split()) > 2:
if not _seems_directed_at_robot(asr_reply):
logging.info("ASR filtered as overheard chatter: %s", asr_reply)
return
self.message_buffer.put_nowait(asr_reply)
logging.info("Detected ASR message: %s", asr_reply)
except json.JSONDecodeError:
Expand Down
100 changes: 96 additions & 4 deletions tests/inputs/plugins/test_riva_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,102 @@
import pytest

from inputs.base import Message
from inputs.plugins.riva_asr import RivaASRInput, RivaASRSensorConfig
from inputs.plugins.riva_asr import (
RivaASRInput,
RivaASRSensorConfig,
_normalize_asr_text,
_seems_directed_at_robot,
)


class TestNormalizeAsrText:
"""Tests for _normalize_asr_text."""

def test_corrects_om_one_variants(self):
assert "OM1" in _normalize_asr_text("tell me about om one")
assert "OM1" in _normalize_asr_text("what is ol one")
assert "OM1" in _normalize_asr_text("I like om 1")

def test_corrects_open_mind(self):
assert "OpenMind" in _normalize_asr_text("this is open mind")

def test_corrects_unit_tree(self):
assert "Unitree" in _normalize_asr_text("the unit tree robot")

def test_no_correction_needed(self):
assert _normalize_asr_text("hello world") == "hello world"


class TestSeemsDirectedAtRobot:
"""Tests for _seems_directed_at_robot."""

def test_directed_keyword_match(self):
assert _seems_directed_at_robot("hello how are you") is True

def test_question_mark_detected(self):
assert _seems_directed_at_robot("something random stuff?") is True

def test_overheard_chatter_filtered(self):
assert _seems_directed_at_robot("yeah totally agree man") is False

def test_case_insensitive(self):
assert _seems_directed_at_robot("HELLO there friend") is True


def test_handle_asr_message_filters_overheard_chatter():
"""Test that _handle_asr_message filters messages not directed at robot."""
with (
patch("inputs.plugins.riva_asr.IOProvider"),
patch("inputs.plugins.riva_asr.ASRProvider") as mock_asr,
patch("inputs.plugins.riva_asr.SleepTickerProvider"),
):
mock_asr_instance = MagicMock()
mock_asr.return_value = mock_asr_instance

config = RivaASRSensorConfig()
sensor = RivaASRInput(config=config)

# Message with >2 words but not directed at robot
raw_message = '{"asr_reply": "yeah totally agree man"}'
sensor._handle_asr_message(raw_message)
assert sensor.message_buffer.qsize() == 0


def test_handle_asr_message_accepts_directed_speech():
"""Test that _handle_asr_message accepts messages directed at robot."""
with (
patch("inputs.plugins.riva_asr.IOProvider"),
patch("inputs.plugins.riva_asr.ASRProvider") as mock_asr,
patch("inputs.plugins.riva_asr.SleepTickerProvider"),
):
mock_asr_instance = MagicMock()
mock_asr.return_value = mock_asr_instance

config = RivaASRSensorConfig()
sensor = RivaASRInput(config=config)

raw_message = '{"asr_reply": "hello how are you"}'
sensor._handle_asr_message(raw_message)
assert sensor.message_buffer.qsize() == 1


def test_handle_asr_message_normalizes_text():
"""Test that _handle_asr_message applies ASR text normalization."""
with (
patch("inputs.plugins.riva_asr.IOProvider"),
patch("inputs.plugins.riva_asr.ASRProvider") as mock_asr,
patch("inputs.plugins.riva_asr.SleepTickerProvider"),
):
mock_asr_instance = MagicMock()
mock_asr.return_value = mock_asr_instance

config = RivaASRSensorConfig()
sensor = RivaASRInput(config=config)

raw_message = '{"asr_reply": "tell me about om one please"}'
sensor._handle_asr_message(raw_message)
assert sensor.message_buffer.qsize() == 1
assert "OM1" in sensor.message_buffer.get_nowait()


def test_initialization():
Expand All @@ -16,7 +111,6 @@ def test_initialization():
patch("inputs.plugins.riva_asr.TeleopsConversationProvider"),
patch("inputs.plugins.riva_asr.open_zenoh_session") as mock_zenoh,
):

mock_asr_instance = MagicMock()
mock_asr.return_value = mock_asr_instance
mock_session = MagicMock()
Expand Down Expand Up @@ -100,7 +194,6 @@ async def test_poll():
patch("inputs.plugins.riva_asr.TeleopsConversationProvider"),
patch("inputs.plugins.riva_asr.open_zenoh_session"),
):

config = RivaASRSensorConfig()
sensor = RivaASRInput(config=config)

Expand All @@ -119,7 +212,6 @@ async def test_poll_with_message():
patch("inputs.plugins.riva_asr.TeleopsConversationProvider"),
patch("inputs.plugins.riva_asr.open_zenoh_session"),
):

config = RivaASRSensorConfig()
sensor = RivaASRInput(config=config)

Expand Down
81 changes: 81 additions & 0 deletions tests/inputs/plugins/test_riva_asr_rtsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,87 @@ def test_handle_asr_message_ignores_json_without_asr_reply(
assert final_size == initial_size


def test_handle_asr_message_filters_overheard_chatter(
mock_io_provider,
mock_asr_provider,
mock_sleep_ticker_provider,
mock_teleops_conversation_provider,
mock_zenoh,
):
"""Test that _handle_asr_message filters messages not directed at robot."""
_, mock_asr_instance = mock_asr_provider
_, mock_sleep_ticker_instance = mock_sleep_ticker_provider
_, mock_teleops_conv_instance = mock_teleops_conversation_provider

config = RivaASRRTSPSensorConfig()
with (
patch("inputs.plugins.riva_asr_rtsp.IOProvider", return_value=mock_io_provider),
patch(
"inputs.plugins.riva_asr_rtsp.ASRRTSPProvider",
return_value=mock_asr_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.SleepTickerProvider",
return_value=mock_sleep_ticker_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.TeleopsConversationProvider",
return_value=mock_teleops_conv_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.open_zenoh_session",
mock_zenoh["open_session"],
),
):
instance = RivaASRRTSPInput(config=config)

# Message with >2 words but not directed at robot should be filtered
raw_message = '{"asr_reply": "yeah totally agree man"}'
initial_size = instance.message_buffer.qsize()
instance._handle_asr_message(raw_message)
assert instance.message_buffer.qsize() == initial_size


def test_handle_asr_message_normalizes_text(
mock_io_provider,
mock_asr_provider,
mock_sleep_ticker_provider,
mock_teleops_conversation_provider,
mock_zenoh,
):
"""Test that _handle_asr_message applies ASR text normalization."""
_, mock_asr_instance = mock_asr_provider
_, mock_sleep_ticker_instance = mock_sleep_ticker_provider
_, mock_teleops_conv_instance = mock_teleops_conversation_provider

config = RivaASRRTSPSensorConfig()
with (
patch("inputs.plugins.riva_asr_rtsp.IOProvider", return_value=mock_io_provider),
patch(
"inputs.plugins.riva_asr_rtsp.ASRRTSPProvider",
return_value=mock_asr_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.SleepTickerProvider",
return_value=mock_sleep_ticker_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.TeleopsConversationProvider",
return_value=mock_teleops_conv_instance,
),
patch(
"inputs.plugins.riva_asr_rtsp.open_zenoh_session",
mock_zenoh["open_session"],
),
):
instance = RivaASRRTSPInput(config=config)

raw_message = '{"asr_reply": "tell me about om one please"}'
instance._handle_asr_message(raw_message)
assert instance.message_buffer.qsize() == 1
assert "OM1" in instance.message_buffer.get_nowait()


def test_handle_asr_message_ignores_json_with_asr_reply_shorter_than_two_words(
mock_io_provider,
mock_asr_provider,
Expand Down
Loading