55
66Useful for testing.
77"""
8+ import asyncio
89import logging
910from dataclasses import dataclass , replace
1011from datetime import datetime , timezone
11- from typing import AsyncIterator
12+ from typing import AsyncIterator , TypeVar
1213
1314import grpc
1415import grpc .aio
4445_logger = logging .getLogger (__name__ )
4546
4647
48+ T = TypeVar ("T" )
49+
50+
51+ class _MockStream (AsyncIterator [T ]):
52+ """A mock stream that wraps an async iterator and adds initial_metadata."""
53+
54+ def __init__ (self , stream : AsyncIterator [T ]) -> None :
55+ """Initialize the mock stream.
56+
57+ Args:
58+ stream: The stream to wrap.
59+ """
60+ self ._iterator = stream .__aiter__ ()
61+
62+ async def initial_metadata (self ) -> None :
63+ """Do nothing, just to mock the grpc call."""
64+ _logger .debug ("Called initial_metadata()" )
65+
66+ def __aiter__ (self ) -> AsyncIterator [T ]:
67+ """Return the async iterator."""
68+ return self
69+
70+ async def __anext__ (self ) -> T :
71+ """Return the next item from the stream."""
72+ return await self ._iterator .__anext__ ()
73+
74+
4775class FakeService :
4876 """Dispatch mock service for testing."""
4977
@@ -113,7 +141,7 @@ async def StreamMicrogridDispatches(
113141 self ,
114142 request : StreamMicrogridDispatchesRequest ,
115143 timeout : int = 5 , # pylint: disable=unused-argument
116- ) -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
144+ ) -> _MockStream [StreamMicrogridDispatchesResponse ]:
117145 """Stream microgrid dispatches changes.
118146
119147 Args:
@@ -122,20 +150,28 @@ async def StreamMicrogridDispatches(
122150
123151 Returns:
124152 An async generator for dispatch changes.
125-
126- Yields:
127- An event for each dispatch change.
128153 """
129- receiver = self ._stream_channel .new_receiver ()
130154
131- async for message in receiver :
132- _logger .debug ("Received message: %s" , message )
133- if message .microgrid_id == MicrogridId (request .microgrid_id ):
134- response = StreamMicrogridDispatchesResponse (
135- event = message .event .event .value ,
136- dispatch = message .event .dispatch .to_protobuf (),
137- )
138- yield response
155+ async def stream () -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
156+ """Stream microgrid dispatches changes."""
157+ _logger .debug ("Starting stream for microgrid %s" , request .microgrid_id )
158+ receiver = self ._stream_channel .new_receiver ()
159+
160+ async for message in receiver :
161+ _logger .debug ("Received message: %s" , message )
162+ if message .microgrid_id == MicrogridId (request .microgrid_id ):
163+ response = StreamMicrogridDispatchesResponse (
164+ event = message .event .event .value ,
165+ dispatch = message .event .dispatch .to_protobuf (),
166+ )
167+ yield response
168+ else :
169+ _logger .debug (
170+ "Skipping message for microgrid %s" ,
171+ message .microgrid_id ,
172+ )
173+
174+ return _MockStream (stream ())
139175
140176 # pylint: disable=too-many-branches
141177 @staticmethod
@@ -196,12 +232,18 @@ async def CreateMicrogridDispatch(
196232 # implicitly create the list if it doesn't exist
197233 self .dispatches .setdefault (microgrid_id , []).append (new_dispatch )
198234
235+ _logger .debug ("Created new dispatch: %s" , new_dispatch )
236+
199237 await self ._stream_sender .send (
200238 self .StreamEvent (
201239 microgrid_id ,
202240 DispatchEvent (dispatch = new_dispatch , event = Event .CREATED ),
203241 )
204242 )
243+ # Give the stream a chance to process the message
244+ await asyncio .sleep (0 )
245+
246+ _logger .debug ("Sent create event for dispatch: %s" , new_dispatch )
205247
206248 return CreateMicrogridDispatchResponse (dispatch = new_dispatch .to_protobuf ())
207249
@@ -293,6 +335,8 @@ async def UpdateMicrogridDispatch(
293335 DispatchEvent (dispatch = dispatch , event = Event .UPDATED ),
294336 )
295337 )
338+ # Give the stream a chance to process the message
339+ await asyncio .sleep (0 )
296340
297341 return UpdateMicrogridDispatchResponse (dispatch = dispatch .to_protobuf ())
298342
@@ -352,6 +396,8 @@ async def DeleteMicrogridDispatch(
352396 ),
353397 )
354398 )
399+ # Give the stream a chance to process the message
400+ await asyncio .sleep (0 )
355401
356402 return Empty ()
357403
0 commit comments