Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/strands/experimental/bidi/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> Non
self._agent.system_prompt,
self._agent.tool_registry.get_all_tool_specs(),
self._agent.messages,
**timeout_error.restart_config,
)
self._task_pool.create(self._run_model())
except Exception as exception:
Expand Down
11 changes: 10 additions & 1 deletion src/strands/experimental/bidi/models/bidi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,13 @@ class BidiModelTimeoutError(Exception):
to create a seamless, uninterrupted experience for the user.
"""

pass
def __init__(self, message: str, **restart_config: Any) -> None:
"""Initialize error.

Args:
message: Timeout message from model.
**restart_config: Configure restart specific behaviors in the call to model start.
"""
super().__init__(self, message)

self.restart_config = restart_config
25 changes: 20 additions & 5 deletions src/strands/experimental/bidi/models/gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
BidiUsageEvent,
ModalityUsage,
)
from .bidi_model import BidiModel
from .bidi_model import BidiModel, BidiModelTimeoutError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,6 +92,7 @@ def __init__(
# Connection state (initialized in start())
self._live_session: Any = None
self._live_session_context_manager: Any = None
self._live_session_handle: str | None = None
self._connection_id: str | None = None

def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -175,8 +176,8 @@ async def start(
)
self._live_session = await self._live_session_context_manager.__aenter__()

# Send initial message history if provided
if messages:
# Gemini itself restores message history when resuming from session
if messages and "live_session_handle" not in kwargs:
await self._send_message_history(messages)

async def _send_message_history(self, messages: Messages) -> None:
Expand Down Expand Up @@ -227,7 +228,22 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut

Returns:
List of event dicts (empty list if no events to emit).

Raises:
BidiModelTimeoutError: If gemini responds with go away message.
"""
if message.go_away:
raise BidiModelTimeoutError(
message.go_away.model_dump_json(), live_session_handle=self._live_session_handle
)

if message.session_resumption_update:
resumption_update = message.session_resumption_update
if resumption_update.resumable and resumption_update.new_handle:
self._live_session_handle = resumption_update.new_handle
logger.debug("session_handle=<%s> | updating gemini session handle", self._live_session_handle)
return []

# Handle interruption first (from server_content)
if message.server_content and message.server_content.interrupted:
return [BidiInterruptionEvent(reason="user_speech")]
Expand Down Expand Up @@ -491,8 +507,7 @@ def _build_live_config(
if self.config:
config_dict.update({k: v for k, v in self.config.items() if k != "audio"})

# Override with any kwargs from start()
config_dict.update(kwargs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing as we will want to be explicit with what arguments we unpack from kwargs and place into config.

config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")}

# Add system instruction if provided
if system_prompt:
Expand Down
6 changes: 3 additions & 3 deletions src/strands/experimental/bidi/models/novasonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,13 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]:
event_data = await output.receive()

except ValidationException as error:
if "InternalErrorCode=531" in str(error):
if "InternalErrorCode=531" in error.message:
# nova also times out if user is silent for 175 seconds
raise BidiModelTimeoutError(error) from error
raise BidiModelTimeoutError(error.message) from error
raise

except ModelTimeoutException as error:
raise BidiModelTimeoutError(error) from error
raise BidiModelTimeoutError(error.message) from error

if not event_data:
continue
Expand Down
5 changes: 3 additions & 2 deletions tests/strands/experimental/bidi/agent/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def loop(agent):

@pytest.mark.asyncio
async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerator):
timeout_error = BidiModelTimeoutError("test timeout")
timeout_error = BidiModelTimeoutError("test timeout", test_restart_config=1)
text_event = BidiTextInputEvent(text="test after restart")

agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])])
Expand All @@ -63,10 +63,11 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato

agent.model.stop.assert_called_once()
assert agent.model.start.call_count == 2
agent.model.start.assert_any_call(
agent.model.start.assert_called_with(
agent.system_prompt,
agent.tool_registry.get_all_tool_specs(),
agent.messages,
test_restart_config=1,
)


Expand Down
45 changes: 45 additions & 0 deletions tests/strands/experimental/bidi/models/test_gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
from google.genai import types as genai_types

from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError
from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel
from strands.experimental.bidi.types.events import (
BidiAudioInputEvent,
Expand Down Expand Up @@ -279,6 +280,34 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator):
assert event.connection_id == model._connection_id


@pytest.mark.asyncio
async def test_receive_timeout(mock_genai_client, model, agenerator):
mock_resumption_response = unittest.mock.Mock()
mock_resumption_response.go_away = None
mock_resumption_response.session_resumption_update = unittest.mock.Mock()
mock_resumption_response.session_resumption_update.resumable = True
mock_resumption_response.session_resumption_update.new_handle = "h1"

mock_timeout_response = unittest.mock.Mock()
mock_timeout_response.go_away = unittest.mock.Mock()
mock_timeout_response.go_away.model_dump_json.return_value = "test timeout"

_, mock_live_session, _ = mock_genai_client
mock_live_session.receive = unittest.mock.Mock(
return_value=agenerator([mock_resumption_response, mock_timeout_response])
)

await model.start()

with pytest.raises(BidiModelTimeoutError, match=r"test timeout"):
async for _ in model.receive():
pass

tru_handle = model._live_session_handle
exp_handle = "h1"
assert tru_handle == exp_handle


@pytest.mark.asyncio
async def test_event_conversion(mock_genai_client, model):
"""Test conversion of all Gemini Live event types to standard format."""
Expand All @@ -288,6 +317,8 @@ async def test_event_conversion(mock_genai_client, model):
# Test text output (converted to transcript via model_turn.parts)
mock_text = unittest.mock.Mock()
mock_text.data = None
mock_text.go_away = None
mock_text.session_resumption_update = None
mock_text.tool_call = None

# Create proper server_content structure with model_turn
Expand Down Expand Up @@ -319,6 +350,8 @@ async def test_event_conversion(mock_genai_client, model):
# Test multiple text parts (should concatenate)
mock_multi_text = unittest.mock.Mock()
mock_multi_text.data = None
mock_multi_text.go_away = None
mock_multi_text.session_resumption_update = None
mock_multi_text.tool_call = None

mock_server_content_multi = unittest.mock.Mock()
Expand Down Expand Up @@ -347,6 +380,8 @@ async def test_event_conversion(mock_genai_client, model):
mock_audio = unittest.mock.Mock()
mock_audio.text = None
mock_audio.data = b"audio_data"
mock_audio.go_away = None
mock_audio.session_resumption_update = None
mock_audio.tool_call = None
mock_audio.server_content = None

Expand All @@ -373,6 +408,8 @@ async def test_event_conversion(mock_genai_client, model):
mock_tool = unittest.mock.Mock()
mock_tool.text = None
mock_tool.data = None
mock_tool.go_away = None
mock_tool.session_resumption_update = None
mock_tool.tool_call = mock_tool_call
mock_tool.server_content = None

Expand Down Expand Up @@ -404,6 +441,8 @@ async def test_event_conversion(mock_genai_client, model):
mock_tool_multi = unittest.mock.Mock()
mock_tool_multi.text = None
mock_tool_multi.data = None
mock_tool_multi.go_away = None
mock_tool_multi.session_resumption_update = None
mock_tool_multi.tool_call = mock_tool_call_multi
mock_tool_multi.server_content = None

Expand Down Expand Up @@ -431,6 +470,8 @@ async def test_event_conversion(mock_genai_client, model):
mock_interrupt = unittest.mock.Mock()
mock_interrupt.text = None
mock_interrupt.data = None
mock_interrupt.go_away = None
mock_interrupt.session_resumption_update = None
mock_interrupt.tool_call = None
mock_interrupt.server_content = mock_server_content

Expand Down Expand Up @@ -549,6 +590,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key
mock_audio = unittest.mock.Mock()
mock_audio.text = None
mock_audio.data = b"audio_data"
mock_audio.go_away = None
mock_audio.session_resumption_update = None
mock_audio.tool_call = None
mock_audio.server_content = None

Expand Down Expand Up @@ -577,6 +620,8 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke
mock_audio = unittest.mock.Mock()
mock_audio.text = None
mock_audio.data = b"audio_data"
mock_audio.go_away = None
mock_audio.session_resumption_update = None
mock_audio.tool_call = None
mock_audio.server_content = None

Expand Down
Loading