diff --git a/snapcast/control/group.py b/snapcast/control/group.py index 98bbad1..f32e5d6 100644 --- a/snapcast/control/group.py +++ b/snapcast/control/group.py @@ -51,7 +51,10 @@ async def set_stream(self, stream_id): @property def stream_status(self): """Get stream status.""" - return self._server.stream(self.stream).status + try: + return self._server.stream(self.stream).status + except KeyError: + return "unknown" @property def muted(self): diff --git a/snapcast/control/server.py b/snapcast/control/server.py index 6a0e570..75cbc61 100644 --- a/snapcast/control/server.py +++ b/snapcast/control/server.py @@ -272,6 +272,8 @@ def group(self, group_identifier): def stream(self, stream_identifier): """Get a stream.""" + if stream_identifier not in self._streams: + raise KeyError(f'Stream "{stream_identifier}" not found') return self._streams[stream_identifier] def client(self, client_identifier): @@ -373,6 +375,20 @@ def _on_group_name_changed(self, data): def _on_group_stream_changed(self, data): """Handle group stream change.""" group = self._groups.get(data.get('id')) + stream_id = data.get('stream_id', None) + + if stream_id not in self._streams: + def update_callback(found): + self._on_update_callback_func() + if not found: + return + group.update_stream(data) + for client_id in group.clients: + self._clients.get(client_id).callback() + + self._synchronize_if_stream_missing(stream_id, update_callback) + return + group.update_stream(data) for client_id in group.clients: self._clients.get(client_id).callback() @@ -442,11 +458,24 @@ def _on_stream_update(self, data): if data.get('stream', {}).get('uri', {}).get('query', {}).get('codec') == 'null': _LOGGER.debug('stream %s is input-only, ignore', data.get('id')) else: - _LOGGER.info('stream %s not found, synchronize', data.get('id')) - - async def async_sync(): - self.synchronize((await self.status())[0]) - asyncio.ensure_future(async_sync()) + self._synchronize_if_stream_missing(data.get('id'), self._on_update_callback_func) + + def _synchronize_if_stream_missing(self, stream_id, callback=None): + """Ensure stream exists, otherwise synchronize.""" + if stream_id is None: + return + if stream_id not in self._streams: + _LOGGER.info('stream "%s" not found, synchronize', stream_id) + + async def async_sync(): + self.synchronize((await self.status())[0]) + found = stream_id in self._streams + if not found: + _LOGGER.warning('stream "%s" still not found after synchronization', stream_id) + if callback and callable(callback): + callback(found, stream_id) + + asyncio.ensure_future(async_sync()) def set_on_update_callback(self, func): """Set on update callback function.""" diff --git a/tests/test_group.py b/tests/test_group.py index bd99ec2..31b13dc 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -101,3 +101,35 @@ def test_set_callback(self): self.group.set_callback(cb) self.group.update_mute({'mute': True}) cb.assert_called_with(self.group) + + def test_bad_stream_status(self): + # Simulate a server where the requested stream id is missing + class DummyClient: + def __init__(self, identifier, friendly_name): + self.identifier = identifier + self.friendly_name = friendly_name + + class DummyServer: + def __init__(self): + self._streams = {} + # provide clients list used by Snapgroup.friendly_name + self.clients = [DummyClient('a', 'A'), DummyClient('b', 'B')] + + def stream(self, stream_identifier): + return self._streams[stream_identifier] + + def client(self, identifier): + # return a client-like object for friendly_name lookup + for c in self.clients: + if c.identifier == identifier: + return c + raise KeyError(identifier) + + # Replace the group's server with the dummy and set an unknown stream id + self.group._server = DummyServer() + + # Updating the stream should not raise; accessing stream_status should + # not raise KeyError because the stream id is not present on the server. + self.group.update_stream({'stream_id': 'no stream'}) + self.assertEqual(self.group.stream_status, 'unknown') + diff --git a/tests/test_server.py b/tests/test_server.py index 01ac701..65aff74 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -74,6 +74,21 @@ 'title': 'Happy!', } } + }, + { + 'id': 'stream2', + 'status': 'playing', + 'uri': { + 'query': { + 'name': 'stream2' + } + }, + 'properties': { + 'canControl': False, + 'metadata': { + 'title': 'Happy2!', + } + } } ] } @@ -168,7 +183,7 @@ def test_init(self): self.assertEqual(self.server.version, '0.26.0') self.assertEqual(len(self.server.clients), 1) self.assertEqual(len(self.server.groups), 1) - self.assertEqual(len(self.server.streams), 1) + self.assertEqual(len(self.server.streams), 2) self.assertEqual(self.server.group('test').identifier, 'test') self.assertEqual(self.server.stream('stream').identifier, 'stream') self.assertEqual(self.server.client('test').identifier, 'test') @@ -282,13 +297,29 @@ def test_on_group_mute(self): self.server._on_group_mute(data) self.assertEqual(self.server.group('test').muted, True) - def test_on_group_stream_changed(self): + @mock.patch.object(Snapserver, '_synchronize_if_stream_missing') + def test_on_group_stream_changed(self, mock_sync): + data = { + 'id': 'test', + 'stream_id': 'stream2' + } + self.server._on_group_stream_changed(data) + self.assertEqual(self.server.group('test').stream, 'stream2') + + mock_sync.assert_not_called() + + @mock.patch.object(Snapserver, '_synchronize_if_stream_missing') + def test_on_group_stream_changed_no_stream(self, mock_sync): data = { 'id': 'test', 'stream_id': 'other' } self.server._on_group_stream_changed(data) - self.assertEqual(self.server.group('test').stream, 'other') + self.assertEqual(self.server.group('test').stream, 'stream2') + + mock_sync.assert_called_once() + _, args, _ = mock_sync.mock_calls[0] + self.assertEqual('other', args[0]) def test_on_client_connect(self): cb = mock.MagicMock() @@ -345,7 +376,8 @@ def test_on_client_latency_changed(self): self.server._on_client_latency_changed(data) self.assertEqual(self.server.client('test').latency, 50) - def test_on_stream_update(self): + @mock.patch.object(Snapserver, '_synchronize_if_stream_missing') + def test_on_stream_update(self, mock_sync): data = { 'id': 'stream', 'stream': { @@ -360,6 +392,26 @@ def test_on_stream_update(self): } self.server._on_stream_update(data) self.assertEqual(self.server.stream('stream').status, 'idle') + mock_sync.assert_not_called() + + @mock.patch.object(Snapserver, '_synchronize_if_stream_missing') + def test_on_stream_update_new(self, mock_sync): + data = { + 'id': 'stream_new', + 'stream': { + 'id': 'stream_new', + 'status': 'idle', + 'uri': { + 'query': { + 'name': 'stream_new' + } + } + } + } + self.server._on_stream_update(data) + mock_sync.assert_called_once() + _, args, _ = mock_sync.mock_calls[0] + self.assertEqual('stream_new', args[0]) def test_on_meta_update(self): data = { diff --git a/tox.ini b/tox.ini index 739e98d..e515fd7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py310, py311, lint +envlist = py310, py311, py313, lint skip_missing_interpreters = True [tool:pytest]