88import logging
99from dataclasses import dataclass , replace
1010from datetime import datetime , timezone
11- from typing import AsyncIterator
11+ from typing import AsyncIterator , TypeVar
1212
1313import grpc
1414import grpc .aio
4444_logger = logging .getLogger (__name__ )
4545
4646
47+ T = TypeVar ("T" )
48+
49+
50+ class _MockStream (AsyncIterator [T ]):
51+ """A mock stream that wraps an async iterator and adds initial_metadata."""
52+
53+ def __init__ (self , stream : AsyncIterator [T ]) -> None :
54+ """Initialize the mock stream.
55+
56+ Args:
57+ stream: The stream to wrap.
58+ """
59+ self ._iterator = stream .__aiter__ ()
60+
61+ async def initial_metadata (self ) -> None :
62+ """Do nothing, just to mock the grpc call."""
63+ _logger .debug ("Called initial_metadata()" )
64+
65+ def __aiter__ (self ) -> AsyncIterator [T ]:
66+ """Return the async iterator."""
67+ return self
68+
69+ async def __anext__ (self ) -> T :
70+ """Return the next item from the stream."""
71+ return await self ._iterator .__anext__ ()
72+
73+
4774class FakeService :
4875 """Dispatch mock service for testing."""
4976
@@ -109,11 +136,11 @@ async def ListMicrogridDispatches(
109136 ),
110137 )
111138
112- async def StreamMicrogridDispatches (
139+ def StreamMicrogridDispatches (
113140 self ,
114141 request : StreamMicrogridDispatchesRequest ,
115142 timeout : int = 5 , # pylint: disable=unused-argument
116- ) -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
143+ ) -> _MockStream [StreamMicrogridDispatchesResponse ]:
117144 """Stream microgrid dispatches changes.
118145
119146 Args:
@@ -122,20 +149,28 @@ async def StreamMicrogridDispatches(
122149
123150 Returns:
124151 An async generator for dispatch changes.
125-
126- Yields:
127- An event for each dispatch change.
128152 """
129- receiver = self ._stream_channel .new_receiver ()
130-
131- async for message in receiver :
132- _logger .debug ("Received message: %s" , message )
133- if message .microgrid_id == MicrogridId (request .microgrid_id ):
134- response = StreamMicrogridDispatchesResponse (
135- event = message .event .event .value ,
136- dispatch = message .event .dispatch .to_protobuf (),
137- )
138- yield response
153+
154+ async def stream () -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
155+ """Stream microgrid dispatches changes."""
156+ _logger .debug ("Starting stream for microgrid %s" , request .microgrid_id )
157+ receiver = self ._stream_channel .new_receiver ()
158+
159+ async for message in receiver :
160+ _logger .debug ("Received message: %s" , message )
161+ if message .microgrid_id == MicrogridId (request .microgrid_id ):
162+ response = StreamMicrogridDispatchesResponse (
163+ event = message .event .event .value ,
164+ dispatch = message .event .dispatch .to_protobuf (),
165+ )
166+ yield response
167+ else :
168+ _logger .debug (
169+ "Skipping message for microgrid %s" ,
170+ message .microgrid_id ,
171+ )
172+
173+ return _MockStream (stream ())
139174
140175 # pylint: disable=too-many-branches
141176 @staticmethod
0 commit comments