diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 56c55e4a9..74b65ba10 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -35,7 +35,6 @@ from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, - BidiConnectionCloseEvent, BidiImageInputEvent, BidiInputEvent, BidiOutputEvent, diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 79030e03f..2e9a13b54 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -81,10 +81,7 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store API key for later use self.api_key = self._client_config.get("api_key") @@ -113,10 +110,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: provider_voice = None if "speech_config" in config and isinstance(config["speech_config"], dict): provider_voice = ( - config["speech_config"] - .get("voice_config", {}) - .get("prebuilt_voice_config", {}) - .get("voice_name") + config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") ) # Define default audio configuration @@ -283,8 +277,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut BidiAudioStreamEvent( audio=audio_b64, format="pcm", - sample_rate=cast(SampleRate, self.config["audio"]["output_rate"]), - channels=cast(Channel, self.config["audio"]["channels"]), + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), ) ] @@ -494,8 +488,8 @@ def _build_live_config( to configure any Gemini Live API parameter directly. """ config_dict: dict[str, Any] = {} - if self._provider_config: - config_dict.update({k: v for k, v in self._provider_config.items() if k != "audio"}) + if self.config: + config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) # Override with any kwargs from start() config_dict.update(kwargs) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 2a16ee91e..24c932ab0 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, AsyncGenerator, Literal, cast +from typing import Any, AsyncGenerator, cast import boto3 from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput @@ -117,10 +117,7 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store session and region for later use self._session = self._client_config["boto_session"] @@ -167,15 +164,15 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), } - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} + user_audio_config = config.get("audio", {}) + merged_audio = {**default_audio_config, **user_audio_config} resolved = { "audio": merged_audio, **{k: v for k, v in config.items() if k != "audio"}, } - if user_audio: + if user_audio_config: logger.debug("audio_config | merged user-provided config with defaults") else: logger.debug("audio_config | using default Nova Sonic audio configuration") @@ -507,13 +504,11 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N if "audioOutput" in nova_event: # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - # Channels from config is guaranteed to be 1 or 2 - channels = cast(Literal[1, 2], self.config["audio"]["channels"]) return BidiAudioStreamEvent( audio=audio_content, format="pcm", - sample_rate=cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - channels=channels, + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), ) # Handle text output (transcripts) diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py index 59fb55f5b..bfe3ad533 100644 --- a/src/strands/experimental/bidi/models/openai.py +++ b/src/strands/experimental/bidi/models/openai.py @@ -111,19 +111,17 @@ def __init__( self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self._provider_config = self._resolve_provider_config(provider_config or {}) - - # Extract and store audio config for IO coordination - self.config: dict[str, Any] = {"audio": self._provider_config["audio"]} + self.config = self._resolve_provider_config(provider_config or {}) # Store client config values for later use self.api_key = self._client_config["api_key"] self.organization = self._client_config.get("organization") self.project = self._client_config.get("project") + self.timeout_s = self._client_config["timeout_s"] if self.timeout_s > OPENAI_MAX_TIMEOUT_S: raise ValueError( - f"timeout_s=<{timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" ) # Connection state (initialized in start()) @@ -139,7 +137,7 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: if "api_key" not in resolved: resolved["api_key"] = os.getenv("OPENAI_API_KEY") - + if not resolved.get("api_key"): raise ValueError( "OpenAI API key is required. Provide via client_config={'api_key': '...'} " @@ -149,12 +147,15 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: env_org = os.getenv("OPENAI_ORGANIZATION") if env_org: resolved["organization"] = env_org - + if "project" not in resolved: env_project = os.getenv("OPENAI_PROJECT") if env_project: resolved["project"] = env_project + if "timeout_s" not in resolved: + resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S + return resolved def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: @@ -167,8 +168,8 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: # Define default audio configuration default_audio: AudioConfig = { - "input_rate": DEFAULT_SAMPLE_RATE, - "output_rate": DEFAULT_SAMPLE_RATE, + "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), "channels": 1, "format": "pcm", "voice": provider_voice or "alloy", @@ -288,7 +289,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] "turn_detection", } - for key, value in self._provider_config.items(): + for key, value in self.config.items(): if key == "audio": continue elif key in supported_params: @@ -297,15 +298,15 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) audio_config = self.config["audio"] - + if "voice" in audio_config: config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] - + if "input_rate" in audio_config: config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ "input_rate" ] - + if "output_rate" in audio_config: config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ "output_rate" diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 48c1d9e09..dec83dbe3 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -97,9 +97,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key): assert model_default.api_key is None assert model_default._live_session is None # Check default config includes transcription - assert model_default.provider_config["response_modalities"] == ["AUDIO"] - assert "outputAudioTranscription" in model_default.provider_config - assert "inputAudioTranscription" in model_default.provider_config + assert model_default.config["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.config + assert "inputAudioTranscription" in model_default.config # Test with API key model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) @@ -110,10 +110,10 @@ def test_model_initialization(mock_genai_client, model_id, api_key): provider_config = {"temperature": 0.7, "top_p": 0.9} model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config) # Custom config should be merged with defaults - assert model_custom.provider_config["temperature"] == 0.7 - assert model_custom.provider_config["top_p"] == 0.9 + assert model_custom.config["temperature"] == 0.7 + assert model_custom.config["top_p"] == 0.9 # Defaults should still be present - assert "response_modalities" in model_custom.provider_config + assert "response_modalities" in model_custom.config # Connection Tests @@ -465,8 +465,8 @@ def test_audio_config_partial_override(mock_genai_client, model_id, api_key): """Test partial audio configuration override.""" _ = mock_genai_client - config = {"audio": {"output_rate": 48000, "voice": "Puck"}} - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "voice": "Puck"}} + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -482,7 +482,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): """Test full audio configuration override.""" _ = mock_genai_client - config = { + provider_config = { "audio": { "input_rate": 48000, "output_rate": 48000, @@ -491,7 +491,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): "voice": "Aoede", } } - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -500,22 +500,6 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key): assert model.config["audio"]["voice"] == "Aoede" -def test_audio_config_voice_priority(mock_genai_client, model_id, api_key): - """Test that config audio voice takes precedence over provider_config voice.""" - _ = mock_genai_client - - provider_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}} - config = {"audio": {"voice": "Aoede"}} - - model = BidiGeminiLiveModel( - model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config, config=config - ) - - # Build config and verify config audio voice takes precedence - built_config = model._build_live_config() - assert built_config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede" - - # Helper Method Tests @@ -557,8 +541,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key _, _, _ = mock_genai_client # Create model with custom audio configuration - config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) await model.start() # Test audio output event uses custom configuration diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 3f4f6c2bc..39524e434 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -568,6 +568,7 @@ async def test_default_audio_rates_in_events(model_id, region): # Error Handling Tests +@pytest.mark.asyncio async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): mock_output = AsyncMock() mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") @@ -575,7 +576,7 @@ async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): await nova_model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): async for _ in nova_model.receive(): pass @@ -588,7 +589,7 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock await nova_model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): async for _ in nova_model.receive(): pass diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index ab1705cd9..85a1cc097 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -130,8 +130,8 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" - config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) + provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -145,7 +145,7 @@ def test_audio_config_partial_override(api_key, model_name): def test_audio_config_full_override(api_key, model_name): """Test full audio configuration override.""" - config = { + provider_config = { "audio": { "input_rate": 48000, "output_rate": 48000, @@ -154,7 +154,7 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config) + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -163,23 +163,9 @@ def test_audio_config_full_override(api_key, model_name): assert model.config["audio"]["voice"] == "shimmer" -def test_audio_config_voice_priority(api_key, model_name): - """Test that config audio voice takes precedence over provider_config voice.""" - provider_config = {"audio": {"output": {"voice": "alloy"}}} - config = {"audio": {"voice": "nova"}} - - model = BidiOpenAIRealtimeModel( - model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config, config=config - ) - - # Build config and verify config audio voice takes precedence - built_config = model._build_session_config(None, None) - assert built_config["audio"]["output"]["voice"] == "nova" - - def test_audio_config_extracts_voice_from_provider_config(api_key, model_name): """Test that voice is extracted from provider_config when config audio not provided.""" - provider_config = {"audio": {"output": {"voice": "fable"}}} + provider_config = {"audio": {"voice": "fable"}} model = BidiOpenAIRealtimeModel( model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config @@ -537,7 +523,7 @@ async def test_receive_timeout(mock_time, model): await model.start() - with pytest.raises(BidiModelTimeoutError): + with pytest.raises(BidiModelTimeoutError, match=r"timeout_s=<1>"): async for _ in model.receive(): pass @@ -712,7 +698,7 @@ async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): # Create model with custom sample rate custom_sample_rate = 48000 - provider_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}} + provider_config = {"audio": {"output_rate": custom_sample_rate}} model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) await model.start()