Skip to content

Commit e7c47c8

Browse files
authored
Merge pull request #2254 from trungutt/trungutt/surface-finish-reason
Surface finish_reason on assistant messages and token usage events
2 parents 181e19e + 54e4fb0 commit e7c47c8

6 files changed

Lines changed: 143 additions & 14 deletions

File tree

pkg/chat/chat.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ type Message struct {
8989
// Cost is the cost of this message in dollars (only set for assistant messages)
9090
Cost float64 `json:"cost,omitempty"`
9191

92+
// FinishReason indicates why the model stopped generating for this message.
93+
// "stop" = natural end, "tool_calls" = tool invocation, "length" = token limit.
94+
// Only set for assistant messages.
95+
FinishReason FinishReason `json:"finish_reason,omitempty"`
96+
9297
// CacheControl indicates whether this message is a cached message (only used by anthropic)
9398
CacheControl bool `json:"cache_control,omitempty"`
9499
}

pkg/runtime/event.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,14 @@ type Usage struct {
283283
}
284284

285285
// MessageUsage contains per-message usage data to include in TokenUsageEvent.
286-
// It embeds chat.Usage and adds Cost and Model fields.
286+
// It embeds chat.Usage and adds Cost, Model, and FinishReason fields.
287287
type MessageUsage struct {
288288
chat.Usage
289289
chat.RateLimit
290290

291-
Cost float64
292-
Model string
291+
Cost float64
292+
Model string
293+
FinishReason chat.FinishReason `json:"finish_reason,omitempty"`
293294
}
294295

295296
// NewTokenUsageEvent creates a TokenUsageEvent with the given usage data.

pkg/runtime/loop.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ func (r *LocalRuntime) recordAssistantMessage(
455455
Usage: res.Usage,
456456
Model: messageModel,
457457
Cost: messageCost,
458+
FinishReason: res.FinishReason,
458459
}
459460

460461
addAgentMessage(sess, a, &assistantMessage, events)
@@ -465,9 +466,10 @@ func (r *LocalRuntime) recordAssistantMessage(
465466
return nil
466467
}
467468
msgUsage := &MessageUsage{
468-
Usage: *res.Usage,
469-
Cost: messageCost,
470-
Model: messageModel,
469+
Usage: *res.Usage,
470+
Cost: messageCost,
471+
Model: messageModel,
472+
FinishReason: res.FinishReason,
471473
}
472474
if res.RateLimit != nil {
473475
msgUsage.RateLimit = *res.RateLimit

pkg/runtime/runtime.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"go.opentelemetry.io/otel/trace"
1515

1616
"github.com/docker/docker-agent/pkg/agent"
17+
"github.com/docker/docker-agent/pkg/chat"
1718
"github.com/docker/docker-agent/pkg/config/types"
1819
"github.com/docker/docker-agent/pkg/hooks"
1920
"github.com/docker/docker-agent/pkg/modelsdev"
@@ -861,6 +862,32 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio
861862
}
862863
usage := SessionUsage(sess, contextLimit)
863864
usage.Cost = sess.TotalCost()
865+
866+
// Reconstruct LastMessage from the parent session's last assistant
867+
// message so that FinishReason (and other per-message fields) are
868+
// available on session restore. We intentionally iterate
869+
// sess.Messages (not GetAllMessages) so the result reflects the
870+
// parent agent's state: this event carries the parent session_id,
871+
// and sub-agents emit their own token_usage events with their own
872+
// session_id during live streaming.
873+
for i := len(sess.Messages) - 1; i >= 0; i-- {
874+
item := &sess.Messages[i]
875+
if !item.IsMessage() || item.Message.Message.Role != chat.MessageRoleAssistant {
876+
continue
877+
}
878+
msg := &item.Message.Message
879+
lm := &MessageUsage{
880+
Model: msg.Model,
881+
Cost: msg.Cost,
882+
FinishReason: msg.FinishReason,
883+
}
884+
if msg.Usage != nil {
885+
lm.Usage = *msg.Usage
886+
}
887+
usage.LastMessage = lm
888+
break
889+
}
890+
864891
send(NewTokenUsageEvent(sess.ID, r.CurrentAgentName(), usage))
865892
}
866893

pkg/runtime/runtime_test.go

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ func TestSimple(t *testing.T) {
281281
AgentChoice("root", sess.ID, "Hello"),
282282
MessageAdded(sess.ID, msgAdded.Message, "root"),
283283
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 3, OutputTokens: 2, ContextLength: 5, LastMessage: &MessageUsage{
284-
Usage: chat.Usage{InputTokens: 3, OutputTokens: 2},
285-
Model: "test/mock-model",
284+
Usage: chat.Usage{InputTokens: 3, OutputTokens: 2},
285+
Model: "test/mock-model",
286+
FinishReason: chat.FinishReasonStop,
286287
}}),
287288
StreamStopped(sess.ID, "root"),
288289
}
@@ -324,8 +325,9 @@ func TestMultipleContentChunks(t *testing.T) {
324325
AgentChoice("root", sess.ID, "you?"),
325326
MessageAdded(sess.ID, msgAdded.Message, "root"),
326327
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 8, OutputTokens: 12, ContextLength: 20, LastMessage: &MessageUsage{
327-
Usage: chat.Usage{InputTokens: 8, OutputTokens: 12},
328-
Model: "test/mock-model",
328+
Usage: chat.Usage{InputTokens: 8, OutputTokens: 12},
329+
Model: "test/mock-model",
330+
FinishReason: chat.FinishReasonStop,
329331
}}),
330332
StreamStopped(sess.ID, "root"),
331333
}
@@ -363,8 +365,9 @@ func TestWithReasoning(t *testing.T) {
363365
AgentChoice("root", sess.ID, "Hello, how can I help you?"),
364366
MessageAdded(sess.ID, msgAdded.Message, "root"),
365367
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 10, OutputTokens: 15, ContextLength: 25, LastMessage: &MessageUsage{
366-
Usage: chat.Usage{InputTokens: 10, OutputTokens: 15},
367-
Model: "test/mock-model",
368+
Usage: chat.Usage{InputTokens: 10, OutputTokens: 15},
369+
Model: "test/mock-model",
370+
FinishReason: chat.FinishReasonStop,
368371
}}),
369372
StreamStopped(sess.ID, "root"),
370373
}
@@ -404,8 +407,9 @@ func TestMixedContentAndReasoning(t *testing.T) {
404407
AgentChoice("root", sess.ID, " How can I help you today?"),
405408
MessageAdded(sess.ID, msgAdded.Message, "root"),
406409
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 15, OutputTokens: 20, ContextLength: 35, LastMessage: &MessageUsage{
407-
Usage: chat.Usage{InputTokens: 15, OutputTokens: 20},
408-
Model: "test/mock-model",
410+
Usage: chat.Usage{InputTokens: 15, OutputTokens: 20},
411+
Model: "test/mock-model",
412+
FinishReason: chat.FinishReasonStop,
409413
}}),
410414
StreamStopped(sess.ID, "root"),
411415
}
@@ -982,6 +986,59 @@ func TestEmitStartupInfo_CostIncludesSubSessions(t *testing.T) {
982986
"cost should include sub-session costs (TotalCost, not OwnCost)")
983987
}
984988

989+
func TestEmitStartupInfo_LastMessageFinishReason(t *testing.T) {
990+
// When restoring a session whose last assistant message has a
991+
// FinishReason, the emitted TokenUsageEvent.LastMessage must carry
992+
// that FinishReason so the UI can identify the final response.
993+
prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}}
994+
root := agent.New("root", "agent",
995+
agent.WithModel(prov),
996+
agent.WithDescription("Root"),
997+
)
998+
tm := team.New(team.WithAgents(root))
999+
1000+
rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"),
1001+
WithModelStore(mockModelStoreWithLimit{limit: 128_000}))
1002+
require.NoError(t, err)
1003+
1004+
sess := session.New()
1005+
sess.InputTokens = 500
1006+
sess.OutputTokens = 200
1007+
1008+
sess.Messages = append(sess.Messages, session.Item{
1009+
Message: &session.Message{
1010+
AgentName: "root",
1011+
Message: chat.Message{
1012+
Role: chat.MessageRoleAssistant,
1013+
Content: "final answer",
1014+
Cost: 0.02,
1015+
Model: "test/startup-model",
1016+
FinishReason: chat.FinishReasonStop,
1017+
Usage: &chat.Usage{InputTokens: 500, OutputTokens: 200},
1018+
},
1019+
},
1020+
})
1021+
1022+
events := make(chan Event, 20)
1023+
rt.EmitStartupInfo(t.Context(), sess, events)
1024+
close(events)
1025+
1026+
var tokenEvent *TokenUsageEvent
1027+
for event := range events {
1028+
if te, ok := event.(*TokenUsageEvent); ok {
1029+
tokenEvent = te
1030+
}
1031+
}
1032+
1033+
require.NotNil(t, tokenEvent, "should emit TokenUsageEvent")
1034+
require.NotNil(t, tokenEvent.Usage.LastMessage, "LastMessage should be populated on session restore")
1035+
assert.Equal(t, chat.FinishReasonStop, tokenEvent.Usage.LastMessage.FinishReason)
1036+
assert.Equal(t, "test/startup-model", tokenEvent.Usage.LastMessage.Model)
1037+
assert.InDelta(t, 0.02, tokenEvent.Usage.LastMessage.Cost, 0.0001)
1038+
assert.Equal(t, int64(500), tokenEvent.Usage.LastMessage.InputTokens)
1039+
assert.Equal(t, int64(200), tokenEvent.Usage.LastMessage.OutputTokens)
1040+
}
1041+
9851042
func TestEmitStartupInfo_NilSessionNoTokenEvent(t *testing.T) {
9861043
// When sess is nil, no TokenUsageEvent should be emitted.
9871044
prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}}

pkg/runtime/streaming.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type streamResult struct {
2626
ThinkingSignature string
2727
ThoughtSignature []byte
2828
Stopped bool
29+
FinishReason chat.FinishReason
2930
Usage *chat.Usage
3031
RateLimit *chat.RateLimit
3132
}
@@ -44,6 +45,7 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
4445
var toolCalls []tools.ToolCall
4546
var messageUsage *chat.Usage
4647
var messageRateLimit *chat.RateLimit
48+
var providerFinishReason chat.FinishReason
4749

4850
toolCallIndex := make(map[string]int) // toolCallID -> index in toolCalls slice
4951
emittedPartial := make(map[string]bool) // toolCallID -> whether we've emitted a partial event
@@ -109,11 +111,19 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
109111
ThinkingSignature: thinkingSignature,
110112
ThoughtSignature: thoughtSignature,
111113
Stopped: true,
114+
FinishReason: choice.FinishReason,
112115
Usage: messageUsage,
113116
RateLimit: messageRateLimit,
114117
}, nil
115118
}
116119

120+
// Track the provider's explicit finish reason (e.g. tool_calls) so we
121+
// can prefer it over inference after the loop. stop/length are already
122+
// handled by the early return above.
123+
if choice.FinishReason != "" {
124+
providerFinishReason = choice.FinishReason
125+
}
126+
117127
// Handle tool calls
118128
if len(choice.Delta.ToolCalls) > 0 {
119129
// Process each tool call delta
@@ -191,13 +201,40 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
191201
// If the stream completed without producing any content or tool calls, likely because of a token limit, stop to avoid breaking the request loop
192202
// NOTE(krissetto): this can likely be removed once compaction works properly with all providers (aka dmr)
193203
stoppedDueToNoOutput := fullContent.Len() == 0 && len(toolCalls) == 0
204+
205+
// Prefer the provider's explicit finish reason when available (e.g.
206+
// tool_calls). Only fall back to inference when no explicit reason was
207+
// received (stream ended with bare EOF):
208+
// - tool calls present → tool_calls (model was requesting tools)
209+
// - content but no tool calls → stop (natural completion)
210+
// - no output at all → null (unknown; likely token limit)
211+
finishReason := providerFinishReason
212+
if finishReason == "" {
213+
switch {
214+
case len(toolCalls) > 0:
215+
finishReason = chat.FinishReasonToolCalls
216+
case fullContent.Len() > 0:
217+
finishReason = chat.FinishReasonStop
218+
default:
219+
finishReason = chat.FinishReasonNull
220+
}
221+
}
222+
// Ensure finish reason agrees with the actual stream output.
223+
switch {
224+
case finishReason == chat.FinishReasonToolCalls && len(toolCalls) == 0:
225+
finishReason = chat.FinishReasonNull
226+
case finishReason == chat.FinishReasonStop && len(toolCalls) > 0:
227+
finishReason = chat.FinishReasonToolCalls
228+
}
229+
194230
return streamResult{
195231
Calls: toolCalls,
196232
Content: fullContent.String(),
197233
ReasoningContent: fullReasoningContent.String(),
198234
ThinkingSignature: thinkingSignature,
199235
ThoughtSignature: thoughtSignature,
200236
Stopped: stoppedDueToNoOutput,
237+
FinishReason: finishReason,
201238
Usage: messageUsage,
202239
RateLimit: messageRateLimit,
203240
}, nil

0 commit comments

Comments (0)