@@ -26,6 +26,7 @@
 from pydantic_graph.nodes import End, NodeRunEndT
 
 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
+from ._run_context import CURRENT_RUN_CONTEXT
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
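The new import brings in a module-level `ContextVar`. The diff doesn't include `_run_context.py` itself, so the following is only a hypothetical sketch of a declaration consistent with how `CURRENT_RUN_CONTEXT` is used in the hunks below; the `Any` annotation is a placeholder, not something shown in the diff:

```python
from contextvars import ContextVar
from typing import Any

# Hypothetical sketch of what ._run_context declares: a ContextVar holding the
# RunContext of whichever agent run is active in the current (async) context,
# or None when no run is in flight.
CURRENT_RUN_CONTEXT: ContextVar[Any] = ContextVar('CURRENT_RUN_CONTEXT', default=None)
```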
@@ -438,25 +439,29 @@ async def stream(
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
-        async with ctx.deps.model.request_stream(
-            message_history, model_settings, model_request_parameters, run_context
-        ) as streamed_response:
-            self._did_stream = True
-            ctx.state.usage.requests += 1
-            agent_stream = result.AgentStream[DepsT, T](
-                _raw_stream_response=streamed_response,
-                _output_schema=ctx.deps.output_schema,
-                _model_request_parameters=model_request_parameters,
-                _output_validators=ctx.deps.output_validators,
-                _run_ctx=build_run_context(ctx),
-                _usage_limits=ctx.deps.usage_limits,
-                _tool_manager=ctx.deps.tool_manager,
-            )
-            yield agent_stream
-            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-            # otherwise usage won't be properly counted:
-            async for _ in agent_stream:
-                pass
+        token = CURRENT_RUN_CONTEXT.set(run_context)
+        try:
+            async with ctx.deps.model.request_stream(
+                message_history, model_settings, model_request_parameters, run_context
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.requests += 1
+                agent_stream = result.AgentStream[DepsT, T](
+                    _raw_stream_response=streamed_response,
+                    _output_schema=ctx.deps.output_schema,
+                    _model_request_parameters=model_request_parameters,
+                    _output_validators=ctx.deps.output_validators,
+                    _run_ctx=build_run_context(ctx),
+                    _usage_limits=ctx.deps.usage_limits,
+                    _tool_manager=ctx.deps.tool_manager,
+                )
+                yield agent_stream
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in agent_stream:
+                    pass
+        finally:
+            CURRENT_RUN_CONTEXT.reset(token)
 
         model_response = streamed_response.get()
 
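The hunk above wraps the streaming request in a token-based set/reset of `CURRENT_RUN_CONTEXT`, so the run context stays ambient for everything awaited inside the `async with`, including user code that runs while `agent_stream` is yielded, and it is restored even if the stream raises. A minimal sketch of that pattern, with a stand-in `RunContext` class (only the `ContextVar` mechanics mirror the diff):

```python
import asyncio
from contextvars import ContextVar


class RunContext:
    """Illustrative stand-in for pydantic-ai's RunContext."""


CURRENT_RUN_CONTEXT: ContextVar[RunContext | None] = ContextVar('CURRENT_RUN_CONTEXT', default=None)


async def with_ambient_run_context(run_context: RunContext) -> None:
    token = CURRENT_RUN_CONTEXT.set(run_context)
    try:
        # Anything awaited here (the model request, tool calls, code consuming
        # a yielded stream) can call CURRENT_RUN_CONTEXT.get() instead of
        # having the run context threaded through as an explicit parameter.
        assert CURRENT_RUN_CONTEXT.get() is run_context
    finally:
        # reset(token) restores whatever value was current before set(),
        # which is safer than set(None) if runs ever nest.
        CURRENT_RUN_CONTEXT.reset(token)


asyncio.run(with_ambient_run_context(RunContext()))
```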
@@ -469,8 +474,12 @@ async def _make_request(
         if self._result is not None:
             return self._result  # pragma: no cover
 
-        model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
-        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
+        model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
+        token = CURRENT_RUN_CONTEXT.set(run_context)
+        try:
+            model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
+        finally:
+            CURRENT_RUN_CONTEXT.reset(token)
         ctx.state.usage.requests += 1
 
         return self._finish_handling(ctx, model_response)
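`_make_request` gets the same treatment: the previously discarded `run_context` (the old `_` binding) is now kept and installed around the non-streaming `model.request` call. Using a `ContextVar` rather than a module global gives isolation under concurrency: each asyncio task runs in its own copy of the context, so interleaved agent runs don't clobber each other's value. An illustrative demo of that property (the names here are made up for the demo, not part of the diff):

```python
import asyncio
from contextvars import ContextVar

CURRENT: ContextVar[str | None] = ContextVar('CURRENT', default=None)


async def one_run(name: str) -> None:
    token = CURRENT.set(name)
    try:
        await asyncio.sleep(0)  # yield so the other task interleaves
        assert CURRENT.get() == name  # each task still sees its own value
    finally:
        CURRENT.reset(token)


async def main() -> None:
    # Two concurrent "runs" on one event loop: neither sees the other's value,
    # because asyncio.gather schedules each coroutine as a task with its own
    # copy of the contextvars context.
    await asyncio.gather(one_run('a'), one_run('b'))


asyncio.run(main())
```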