Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/strands/experimental/bidi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ..types.agent import BidiAgentInput
from ..types.events import (
BidiAudioInputEvent,
BidiConnectionCloseEvent,
BidiImageInputEvent,
BidiInputEvent,
BidiOutputEvent,
Expand Down
18 changes: 6 additions & 12 deletions src/strands/experimental/bidi/models/gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ def __init__(
self._client_config = self._resolve_client_config(client_config or {})

# Resolve provider config with defaults
self._provider_config = self._resolve_provider_config(provider_config or {})

# Extract and store audio config for IO coordination
self.config: dict[str, Any] = {"audio": self._provider_config["audio"]}
self.config = self._resolve_provider_config(provider_config or {})

# Store API key for later use
self.api_key = self._client_config.get("api_key")
Expand Down Expand Up @@ -113,10 +110,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
provider_voice = None
if "speech_config" in config and isinstance(config["speech_config"], dict):
provider_voice = (
config["speech_config"]
.get("voice_config", {})
.get("prebuilt_voice_config", {})
.get("voice_name")
config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name")
)

# Define default audio configuration
Expand Down Expand Up @@ -283,8 +277,8 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut
BidiAudioStreamEvent(
audio=audio_b64,
format="pcm",
sample_rate=cast(SampleRate, self.config["audio"]["output_rate"]),
channels=cast(Channel, self.config["audio"]["channels"]),
sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]),
channels=cast(AudioChannel, self.config["audio"]["channels"]),
)
]

Expand Down Expand Up @@ -494,8 +488,8 @@ def _build_live_config(
to configure any Gemini Live API parameter directly.
"""
config_dict: dict[str, Any] = {}
if self._provider_config:
config_dict.update({k: v for k, v in self._provider_config.items() if k != "audio"})
if self.config:
config_dict.update({k: v for k, v in self.config.items() if k != "audio"})

# Override with any kwargs from start()
config_dict.update(kwargs)
Expand Down
19 changes: 7 additions & 12 deletions src/strands/experimental/bidi/models/novasonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Literal, cast
from typing import Any, AsyncGenerator, cast

import boto3
from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput
Expand Down Expand Up @@ -117,10 +117,7 @@ def __init__(
self._client_config = self._resolve_client_config(client_config or {})

# Resolve provider config with defaults
self._provider_config = self._resolve_provider_config(provider_config or {})

# Extract and store audio config for IO coordination
self.config: dict[str, Any] = {"audio": self._provider_config["audio"]}
self.config = self._resolve_provider_config(provider_config or {})

# Store session and region for later use
self._session = self._client_config["boto_session"]
Expand Down Expand Up @@ -167,15 +164,15 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
"voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]),
}

user_audio = config.get("audio", {})
merged_audio = {**default_audio, **user_audio}
user_audio_config = config.get("audio", {})
merged_audio = {**default_audio_config, **user_audio_config}

resolved = {
"audio": merged_audio,
**{k: v for k, v in config.items() if k != "audio"},
}

if user_audio:
if user_audio_config:
logger.debug("audio_config | merged user-provided config with defaults")
else:
logger.debug("audio_config | using default Nova Sonic audio configuration")
Expand Down Expand Up @@ -507,13 +504,11 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N
if "audioOutput" in nova_event:
# Audio is already base64 string from Nova Sonic
audio_content = nova_event["audioOutput"]["content"]
# Channels from config is guaranteed to be 1 or 2
channels = cast(Literal[1, 2], self.config["audio"]["channels"])
return BidiAudioStreamEvent(
audio=audio_content,
format="pcm",
sample_rate=cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]),
channels=channels,
sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]),
channels=cast(AudioChannel, self.config["audio"]["channels"]),
)

# Handle text output (transcripts)
Expand Down
27 changes: 14 additions & 13 deletions src/strands/experimental/bidi/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,17 @@ def __init__(
self._client_config = self._resolve_client_config(client_config or {})

# Resolve provider config with defaults
self._provider_config = self._resolve_provider_config(provider_config or {})

# Extract and store audio config for IO coordination
self.config: dict[str, Any] = {"audio": self._provider_config["audio"]}
self.config = self._resolve_provider_config(provider_config or {})

# Store client config values for later use
self.api_key = self._client_config["api_key"]
self.organization = self._client_config.get("organization")
self.project = self._client_config.get("project")
self.timeout_s = self._client_config["timeout_s"]

if self.timeout_s > OPENAI_MAX_TIMEOUT_S:
raise ValueError(
f"timeout_s=<{timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit"
f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit"
)

# Connection state (initialized in start())
Expand All @@ -139,7 +137,7 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:

if "api_key" not in resolved:
resolved["api_key"] = os.getenv("OPENAI_API_KEY")

if not resolved.get("api_key"):
raise ValueError(
"OpenAI API key is required. Provide via client_config={'api_key': '...'} "
Expand All @@ -149,12 +147,15 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:
env_org = os.getenv("OPENAI_ORGANIZATION")
if env_org:
resolved["organization"] = env_org

if "project" not in resolved:
env_project = os.getenv("OPENAI_PROJECT")
if env_project:
resolved["project"] = env_project

if "timeout_s" not in resolved:
resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S

return resolved

def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
Expand All @@ -167,8 +168,8 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:

# Define default audio configuration
default_audio: AudioConfig = {
"input_rate": DEFAULT_SAMPLE_RATE,
"output_rate": DEFAULT_SAMPLE_RATE,
"input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
"output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
"channels": 1,
"format": "pcm",
"voice": provider_voice or "alloy",
Expand Down Expand Up @@ -288,7 +289,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec]
"turn_detection",
}

for key, value in self._provider_config.items():
for key, value in self.config.items():
if key == "audio":
continue
elif key in supported_params:
Expand All @@ -297,15 +298,15 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec]
logger.warning("parameter=<%s> | ignoring unsupported session parameter", key)

audio_config = self.config["audio"]

if "voice" in audio_config:
config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"]

if "input_rate" in audio_config:
config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[
"input_rate"
]

if "output_rate" in audio_config:
config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[
"output_rate"
Expand Down
40 changes: 12 additions & 28 deletions tests/strands/experimental/bidi/models/test_gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key):
assert model_default.api_key is None
assert model_default._live_session is None
# Check default config includes transcription
assert model_default.provider_config["response_modalities"] == ["AUDIO"]
assert "outputAudioTranscription" in model_default.provider_config
assert "inputAudioTranscription" in model_default.provider_config
assert model_default.config["response_modalities"] == ["AUDIO"]
assert "outputAudioTranscription" in model_default.config
assert "inputAudioTranscription" in model_default.config

# Test with API key
model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
Expand All @@ -110,10 +110,10 @@ def test_model_initialization(mock_genai_client, model_id, api_key):
provider_config = {"temperature": 0.7, "top_p": 0.9}
model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config)
# Custom config should be merged with defaults
assert model_custom.provider_config["temperature"] == 0.7
assert model_custom.provider_config["top_p"] == 0.9
assert model_custom.config["temperature"] == 0.7
assert model_custom.config["top_p"] == 0.9
# Defaults should still be present
assert "response_modalities" in model_custom.provider_config
assert "response_modalities" in model_custom.config


# Connection Tests
Expand Down Expand Up @@ -465,8 +465,8 @@ def test_audio_config_partial_override(mock_genai_client, model_id, api_key):
"""Test partial audio configuration override."""
_ = mock_genai_client

config = {"audio": {"output_rate": 48000, "voice": "Puck"}}
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config)
provider_config = {"audio": {"output_rate": 48000, "voice": "Puck"}}
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config)

# Overridden values
assert model.config["audio"]["output_rate"] == 48000
Expand All @@ -482,7 +482,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key):
"""Test full audio configuration override."""
_ = mock_genai_client

config = {
provider_config = {
"audio": {
"input_rate": 48000,
"output_rate": 48000,
Expand All @@ -491,7 +491,7 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key):
"voice": "Aoede",
}
}
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config)
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config)

assert model.config["audio"]["input_rate"] == 48000
assert model.config["audio"]["output_rate"] == 48000
Expand All @@ -500,22 +500,6 @@ def test_audio_config_full_override(mock_genai_client, model_id, api_key):
assert model.config["audio"]["voice"] == "Aoede"


def test_audio_config_voice_priority(mock_genai_client, model_id, api_key):
"""Test that config audio voice takes precedence over provider_config voice."""
_ = mock_genai_client

provider_config = {"speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Puck"}}}}
config = {"audio": {"voice": "Aoede"}}

model = BidiGeminiLiveModel(
model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config, config=config
)

# Build config and verify config audio voice takes precedence
built_config = model._build_live_config()
assert built_config["speech_config"]["voice_config"]["prebuilt_voice_config"]["voice_name"] == "Aoede"


# Helper Method Tests


Expand Down Expand Up @@ -557,8 +541,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key
_, _, _ = mock_genai_client

# Create model with custom audio configuration
config = {"audio": {"output_rate": 48000, "channels": 2}}
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, config=config)
provider_config = {"audio": {"output_rate": 48000, "channels": 2}}
model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config)
await model.start()

# Test audio output event uses custom configuration
Expand Down
5 changes: 3 additions & 2 deletions tests/strands/experimental/bidi/models/test_novasonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,14 +568,15 @@ async def test_default_audio_rates_in_events(model_id, region):


# Error Handling Tests
@pytest.mark.asyncio
async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream):
mock_output = AsyncMock()
mock_output.receive.side_effect = ModelTimeoutException("Connection timeout")
mock_stream.await_output.return_value = (None, mock_output)

await nova_model.start()

with pytest.raises(BidiModelTimeoutError):
with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"):
async for _ in nova_model.receive():
pass

Expand All @@ -588,7 +589,7 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock

await nova_model.start()

with pytest.raises(BidiModelTimeoutError):
with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"):
async for _ in nova_model.receive():
pass

Expand Down
28 changes: 7 additions & 21 deletions tests/strands/experimental/bidi/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_audio_config_defaults(api_key, model_name):

def test_audio_config_partial_override(api_key, model_name):
"""Test partial audio configuration override."""
config = {"audio": {"output_rate": 48000, "voice": "echo"}}
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config)
provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}}
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config)

# Overridden values
assert model.config["audio"]["output_rate"] == 48000
Expand All @@ -145,7 +145,7 @@ def test_audio_config_partial_override(api_key, model_name):

def test_audio_config_full_override(api_key, model_name):
"""Test full audio configuration override."""
config = {
provider_config = {
"audio": {
"input_rate": 48000,
"output_rate": 48000,
Expand All @@ -154,7 +154,7 @@ def test_audio_config_full_override(api_key, model_name):
"voice": "shimmer",
}
}
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, config=config)
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config)

assert model.config["audio"]["input_rate"] == 48000
assert model.config["audio"]["output_rate"] == 48000
Expand All @@ -163,23 +163,9 @@ def test_audio_config_full_override(api_key, model_name):
assert model.config["audio"]["voice"] == "shimmer"


def test_audio_config_voice_priority(api_key, model_name):
"""Test that config audio voice takes precedence over provider_config voice."""
provider_config = {"audio": {"output": {"voice": "alloy"}}}
config = {"audio": {"voice": "nova"}}

model = BidiOpenAIRealtimeModel(
model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config, config=config
)

# Build config and verify config audio voice takes precedence
built_config = model._build_session_config(None, None)
assert built_config["audio"]["output"]["voice"] == "nova"


def test_audio_config_extracts_voice_from_provider_config(api_key, model_name):
"""Test that voice is extracted from provider_config when config audio not provided."""
provider_config = {"audio": {"output": {"voice": "fable"}}}
provider_config = {"audio": {"voice": "fable"}}

model = BidiOpenAIRealtimeModel(
model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config
Expand Down Expand Up @@ -537,7 +523,7 @@ async def test_receive_timeout(mock_time, model):

await model.start()

with pytest.raises(BidiModelTimeoutError):
with pytest.raises(BidiModelTimeoutError, match=r"timeout_s=<1>"):
async for _ in model.receive():
pass

Expand Down Expand Up @@ -712,7 +698,7 @@ async def test_custom_audio_sample_rate(mock_websockets_connect, api_key):

# Create model with custom sample rate
custom_sample_rate = 48000
provider_config = {"audio": {"output": {"format": {"rate": custom_sample_rate}}}}
provider_config = {"audio": {"output_rate": custom_sample_rate}}
model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config)

await model.start()
Expand Down
Loading