Skip to content

Commit cc07ac7

Browse files
syn-zhuclaude
andcommitted
fix: remove dead code and make mock STS server audience-aware
- Remove unused imports from _base.py (BaseAgent, LlmAgent, AuthCredential, etc.) - Remove add_to_agent method that would overwrite audience-aware closures created by create_header_provider in types.py - Remove LlmAgent import and add_to_agent test code from test_adk_integration.py - Make mock STS server include 'aud' claim in generated tokens when audience is provided, improving E2E test coverage for per-audience token exchange Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Simon Zhu <simon.zhu@mongodb.com>
1 parent da6dd4c commit cc07ac7

3 files changed

Lines changed: 6 additions & 34 deletions

File tree

go/core/test/e2e/mocks/mock_sts_server.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func (m *MockSTSServer) handleTokenExchange(w http.ResponseWriter, r *http.Reque
178178
return
179179
}
180180

181-
accessToken, err := m.generateMockAccessToken(req.SubjectToken)
181+
accessToken, err := m.generateMockAccessToken(req.SubjectToken, req.Audience)
182182
if err != nil {
183183
http.Error(w, fmt.Sprintf("Error generating mock access token: %v", err), http.StatusBadRequest)
184184
return
@@ -200,7 +200,7 @@ func (m *MockSTSServer) handleTokenExchange(w http.ResponseWriter, r *http.Reque
200200
m.requests = append(m.requests, req)
201201
}
202202

203-
func (m *MockSTSServer) generateMockAccessToken(subjectToken string) (string, error) {
203+
func (m *MockSTSServer) generateMockAccessToken(subjectToken string, audience string) (string, error) {
204204
// Try to parse JWT token to extract subject claim
205205
subject, err := extractSubjectFromJWT(subjectToken)
206206
if err != nil {
@@ -219,6 +219,10 @@ func (m *MockSTSServer) generateMockAccessToken(subjectToken string) (string, er
219219
"iss": "mock-sts-server",
220220
}
221221

222+
if audience != "" {
223+
tokenData["aud"] = audience
224+
}
225+
222226
// For testing purposes, we'll return a simple JSON string
223227
// In a real implementation, this would be a signed JWT
224228
tokenBytes, err := json.Marshal(tokenData)

python/packages/agentsts-adk/src/agentsts/adk/_base.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,9 @@
33
import logging
44
from typing import Any, Dict, Optional
55

6-
from google.adk.agents import BaseAgent, LlmAgent
76
from google.adk.agents.invocation_context import InvocationContext
87
from google.adk.agents.readonly_context import ReadonlyContext
9-
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, HttpAuth, HttpCredentials
10-
from google.adk.events.event import Event
118
from google.adk.plugins.base_plugin import BasePlugin
12-
from google.adk.runners import Runner
13-
from google.adk.sessions import BaseSessionService
14-
from google.adk.sessions.session import Session
159
from google.adk.tools.base_tool import BaseTool
1610
from google.adk.tools.mcp_tool import MCPTool
1711
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
@@ -67,23 +61,6 @@ def register_toolset(self, toolset: MCPToolset, audience: Optional[str]) -> None
6761
if audience and hasattr(toolset, "_mcp_session_manager"):
6862
self._audience_map[id(toolset._mcp_session_manager)] = audience
6963

70-
def add_to_agent(self, agent: BaseAgent):
71-
"""
72-
Add the plugin to an ADK LLM agent by updating its MCP toolset
73-
Call this once when setting up the agent; do not call it at runtime.
74-
"""
75-
if not isinstance(agent, LlmAgent):
76-
return
77-
78-
if not agent.tools:
79-
return
80-
81-
for tool in agent.tools:
82-
if isinstance(tool, MCPToolset):
83-
mcp_toolset = tool
84-
mcp_toolset._header_provider = self.header_provider
85-
logger.debug("Updated tool connection params to include access token from STS server")
86-
8764
def header_provider(
8865
self, readonly_context: Optional[ReadonlyContext], audience: Optional[str] = None
8966
) -> Dict[str, str]:

python/packages/agentsts-adk/tests/test_adk_integration.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from unittest.mock import AsyncMock, Mock, patch
44

55
import pytest
6-
from google.adk.agents import LlmAgent
76
from google.adk.tools.mcp_tool import MCPTool
87
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
98

@@ -98,14 +97,6 @@ async def test_downstream_token_propagation_without_sts(self):
9897
# Without STS, subject token is directly cached under session key
9998
assert plugin.token_cache["sess-2"] == "subj-token-123"
10099

101-
# propagate toolset
102-
mcp_toolset = Mock(spec=MCPToolset)
103-
agent = Mock(spec=LlmAgent)
104-
agent.tools = [mcp_toolset]
105-
plugin.add_to_agent(agent)
106-
# The toolset._header_provider should be callable
107-
assert callable(mcp_toolset._header_provider)
108-
109100
# header provider should return subject token
110101
ro_ctx = self._make_readonly_context(ic)
111102
headers = plugin.header_provider(ro_ctx)

0 commit comments

Comments
 (0)