@@ -618,7 +618,7 @@ def __call__(
618618 self ,
619619 chat : Chat ,
620620 / ,
621- ) -> t .Awaitable [Chat | None ]: ...
621+ ) -> t .Awaitable [Chat | None ] | Chat | None : ...
622622
623623
624624@runtime_checkable
@@ -642,7 +642,7 @@ def __call__(
642642 self ,
643643 chats : list [Chat ],
644644 / ,
645- ) -> t .Awaitable [list [Chat ]]: ...
645+ ) -> t .Awaitable [list [Chat ]] | list [ Chat ] : ...
646646
647647
648648@runtime_checkable
@@ -773,7 +773,9 @@ async def traced_watch_callback(chats: list[Chat]) -> None:
773773 chat_count = len (chats ),
774774 chat_ids = [str (c .uuid ) for c in chats ],
775775 ):
776- await callback (chats )
776+ result = callback (chats )
777+ if inspect .isawaitable (result ):
778+ await result
777779
778780 return traced_watch_callback
779781
@@ -1100,11 +1102,6 @@ async def process(chat: Chat) -> Chat | None:
11001102 ```
11011103 """
11021104 for callback in callbacks :
1103- if not asyncio .iscoroutinefunction (callback ):
1104- raise TypeError (
1105- f"Callback '{ get_qualified_name (callback )} ' must be an async function" ,
1106- )
1107-
11081105 if allow_duplicates :
11091106 continue
11101107
@@ -1147,11 +1144,6 @@ async def process(chats: list[Chat]) -> list[Chat]:
11471144 ```
11481145 """
11491146 for callback in callbacks :
1150- if not asyncio .iscoroutinefunction (callback ):
1151- raise TypeError (
1152- f"Callback '{ get_qualified_name (callback )} ' must be an async function" ,
1153- )
1154-
11551147 if allow_duplicates :
11561148 continue
11571149
@@ -1565,9 +1557,8 @@ async def complete() -> None:
15651557 exit_stack .push_async_callback (complete )
15661558
15671559 result = callback (state .chat )
1568-
15691560 if inspect .isawaitable (result ):
1570- result = await result # type: ignore [assignment]
1561+ result = await result
15711562
15721563 if result is None or isinstance (result , Chat ):
15731564 state .chat = result or state .chat
0 commit comments