Skip to content

Commit 4b7e941

Browse files
authored
fix issue #447 (#449)
1 parent d373057 commit 4b7e941

File tree

6 files changed

+145
-57
lines changed

6 files changed

+145
-57
lines changed

src/fast_agent/agents/mcp_agent.py

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from fast_agent.core.exceptions import PromptExitError
4242
from fast_agent.core.logging.logger import get_logger
4343
from fast_agent.interfaces import FastAgentLLMProtocol
44+
from fast_agent.mcp.common import SEP
4445
from fast_agent.mcp.mcp_aggregator import MCPAggregator, ServerStatus
4546
from fast_agent.skills.registry import format_skills_for_prompt
4647
from fast_agent.tools.elicitation import (
@@ -397,40 +398,51 @@ async def list_tools(self) -> ListToolsResult:
397398
Returns:
398399
ListToolsResult with available tools
399400
"""
400-
# Get all tools from the aggregator
401-
result = await self._aggregator.list_tools()
402-
aggregator_tools = list(result.tools)
401+
aggregator_result = await self._aggregator.list_tools()
402+
aggregator_tools = list(aggregator_result.tools or [])
403403

404404
# Apply filtering if tools are specified in config
405405
if self.config.tools is not None:
406-
filtered_tools = []
406+
filtered_tools: list[Tool] = []
407407
for tool in aggregator_tools:
408408
# Extract server name from tool name, handling server names with hyphens
409409
server_name = None
410410
for configured_server in self.config.tools.keys():
411-
if tool.name.startswith(f"{configured_server}-"):
411+
if tool.name.startswith(f"{configured_server}{SEP}"):
412412
server_name = configured_server
413413
break
414414

415-
# Check if this server has tool filters
416-
if server_name and server_name in self.config.tools:
417-
# Check if tool matches any pattern for this server
418-
for pattern in self.config.tools[server_name]:
419-
if self._matches_pattern(tool.name, pattern, server_name):
420-
filtered_tools.append(tool)
421-
break
415+
if not server_name:
416+
continue
417+
418+
# Check if tool matches any pattern for this server
419+
for pattern in self.config.tools[server_name]:
420+
if self._matches_pattern(tool.name, pattern, server_name):
421+
filtered_tools.append(tool)
422+
break
422423
aggregator_tools = filtered_tools
423424

424-
result.tools = aggregator_tools
425+
# Start with filtered aggregator tools and merge in subclass/local tools
426+
merged_tools: list[Tool] = list(aggregator_tools)
427+
existing_names = {tool.name for tool in merged_tools}
425428

426-
if self._bash_tool and all(tool.name != self._bash_tool.name for tool in result.tools):
427-
result.tools.append(self._bash_tool)
429+
local_tools = (await ToolAgent.list_tools(self)).tools
430+
for tool in local_tools:
431+
if tool.name not in existing_names:
432+
merged_tools.append(tool)
433+
existing_names.add(tool.name)
428434

429-
# Append human input tool if enabled and available
430-
if self.config.human_input and getattr(self, "_human_input_tool", None):
431-
result.tools.append(self._human_input_tool)
435+
if self._bash_tool and self._bash_tool.name not in existing_names:
436+
merged_tools.append(self._bash_tool)
437+
existing_names.add(self._bash_tool.name)
432438

433-
return result
439+
if self.config.human_input:
440+
human_tool = getattr(self, "_human_input_tool", None)
441+
if human_tool and human_tool.name not in existing_names:
442+
merged_tools.append(human_tool)
443+
existing_names.add(human_tool.name)
444+
445+
return ListToolsResult(tools=merged_tools)
434446

435447
async def call_tool(self, name: str, arguments: Dict[str, Any] | None = None) -> CallToolResult:
436448
"""
@@ -449,6 +461,9 @@ async def call_tool(self, name: str, arguments: Dict[str, Any] | None = None) ->
449461
if name == HUMAN_INPUT_TOOL_NAME:
450462
# Call the elicitation-backed human input tool
451463
return await self._call_human_input_tool(arguments)
464+
465+
if name in self._execution_tools:
466+
return await super().call_tool(name, arguments)
452467
else:
453468
return await self._aggregator.call_tool(name, arguments)
454469

@@ -703,38 +718,61 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
703718
tool_results: dict[str, CallToolResult] = {}
704719
tool_loop_error: str | None = None
705720

706-
# Cache available tool names (original, not namespaced) for display
707-
available_tools = [
708-
namespaced_tool.tool.name
709-
for namespaced_tool in self._aggregator._namespaced_tool_map.values()
710-
]
711-
if self._shell_runtime.tool:
712-
available_tools.append(self._shell_runtime.tool.name)
721+
# Cache available tool names exactly as advertised to the LLM for display/highlighting
722+
try:
723+
listed_tools = await self.list_tools()
724+
except Exception as exc: # pragma: no cover - defensive guard, should not happen
725+
self.logger.warning(f"Failed to list tools before execution: {exc}")
726+
listed_tools = ListToolsResult(tools=[])
727+
728+
available_tools: list[str] = []
729+
seen_tool_names: set[str] = set()
730+
for tool_schema in listed_tools.tools:
731+
if tool_schema.name in seen_tool_names:
732+
continue
733+
available_tools.append(tool_schema.name)
734+
seen_tool_names.add(tool_schema.name)
713735

714-
available_tools = list(dict.fromkeys(available_tools))
736+
# Cache namespaced tools for routing/metadata
737+
namespaced_tools = self._aggregator._namespaced_tool_map
715738

716739
# Process each tool call using our aggregator
717740
for correlation_id, tool_request in request.tool_calls.items():
718741
tool_name = tool_request.params.name
719742
tool_args = tool_request.params.arguments or {}
720743

721-
# Get the original tool name for display (not namespaced)
722-
namespaced_tool = self._aggregator._namespaced_tool_map.get(tool_name)
723-
display_tool_name = namespaced_tool.tool.name if namespaced_tool else tool_name
724-
725-
tool_available = False
726-
if tool_name == HUMAN_INPUT_TOOL_NAME:
727-
tool_available = True
728-
elif self._bash_tool and tool_name == self._bash_tool.name:
729-
tool_available = True
730-
elif namespaced_tool:
731-
tool_available = True
732-
else:
733-
tool_available = any(
734-
candidate.tool.name == tool_name
735-
for candidate in self._aggregator._namespaced_tool_map.values()
744+
# Determine which tool we are calling (namespaced MCP, local, etc.)
745+
namespaced_tool = namespaced_tools.get(tool_name)
746+
local_tool = self._execution_tools.get(tool_name)
747+
candidate_namespaced_tool = None
748+
if namespaced_tool is None and local_tool is None:
749+
candidate_namespaced_tool = next(
750+
(
751+
candidate
752+
for candidate in namespaced_tools.values()
753+
if candidate.tool.name == tool_name
754+
),
755+
None,
736756
)
737757

758+
# Select display/highlight names
759+
display_tool_name = tool_name
760+
highlight_name = tool_name
761+
if namespaced_tool is not None:
762+
display_tool_name = namespaced_tool.namespaced_tool_name
763+
highlight_name = namespaced_tool.namespaced_tool_name
764+
elif candidate_namespaced_tool is not None:
765+
display_tool_name = candidate_namespaced_tool.namespaced_tool_name
766+
highlight_name = candidate_namespaced_tool.namespaced_tool_name
767+
768+
tool_available = (
769+
tool_name == HUMAN_INPUT_TOOL_NAME
770+
or (self._shell_runtime.tool and tool_name == self._shell_runtime.tool.name)
771+
or namespaced_tool is not None
772+
or local_tool is not None
773+
or candidate_namespaced_tool is not None
774+
)
775+
738776
if not tool_available:
739777
error_message = f"Tool '{display_tool_name}' is not available"
740778
self.logger.error(error_message)
@@ -748,7 +786,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
748786
# Find the index of the current tool in available_tools for highlighting
749787
highlight_index = None
750788
try:
751-
highlight_index = available_tools.index(display_tool_name)
789+
highlight_index = available_tools.index(highlight_name)
752790
except ValueError:
753791
# Tool not found in list, no highlighting
754792
pass
@@ -757,7 +795,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
757795
if (
758796
self._shell_runtime_enabled
759797
and self._shell_runtime.tool
760-
and display_tool_name == self._shell_runtime.tool.name
798+
and tool_name == self._shell_runtime.tool.name
761799
):
762800
metadata = self._shell_runtime.metadata(tool_args.get("command"))
763801

@@ -778,9 +816,10 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
778816

779817
# Show tool result (like ToolAgent does)
780818
skybridge_config = None
781-
if namespaced_tool:
819+
skybridge_tool = namespaced_tool or candidate_namespaced_tool
820+
if skybridge_tool:
782821
skybridge_config = await self._aggregator.get_skybridge_config(
783-
namespaced_tool.server_name
822+
skybridge_tool.server_name
784823
)
785824

786825
if not getattr(result, "_suppress_display", False):

src/fast_agent/agents/tool_agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
logger.warning(f"Failed to initialize human-input tool: {e}")
5656

5757
for tool in working_tools:
58+
(tool)
5859
if isinstance(tool, FastMCPTool):
5960
fast_tool = tool
6061
elif callable(tool):
@@ -137,12 +138,13 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
137138
tool_results: dict[str, CallToolResult] = {}
138139
tool_loop_error: str | None = None
139140
# TODO -- use gather() for parallel results, update display
140-
available_tools = [t.name for t in (await self.list_tools()).tools]
141+
tool_schemas = (await self.list_tools()).tools
142+
available_tools = [t.name for t in tool_schemas]
141143
for correlation_id, tool_request in request.tool_calls.items():
142144
tool_name = tool_request.params.name
143145
tool_args = tool_request.params.arguments or {}
144146

145-
if tool_name not in self._execution_tools:
147+
if tool_name not in available_tools and tool_name not in self._execution_tools:
146148
error_message = f"Tool '{tool_name}' is not available"
147149
logger.error(error_message)
148150
tool_loop_error = self._mark_tool_loop_error(

src/fast_agent/context_dependent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ContextDependent:
1717

1818
def __init__(self, context: "Context | None" = None, **kwargs: dict[str, Any]) -> None:
1919
self._context = context
20-
super().__init__(**kwargs)
20+
super().__init__()
2121

2222
@property
2323
def context(self) -> "Context":

src/fast_agent/ui/console_display.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ async def show_assistant_message(
629629
model: Optional model name for right info
630630
additional_message: Optional additional styled message to append
631631
"""
632-
if not self.config or not self.config.logger.show_chat:
632+
if self.config and not self.config.logger.show_chat:
633633
return
634634

635635
# Extract text from PromptMessageExtended if needed
@@ -749,7 +749,7 @@ def _display_mermaid_diagrams(self, diagrams: List[MermaidDiagram]) -> None:
749749

750750
async def show_mcp_ui_links(self, links: List[UILink]) -> None:
751751
"""Display MCP-UI links beneath the chat like mermaid links."""
752-
if not self.config or not self.config.logger.show_chat:
752+
if self.config and not self.config.logger.show_chat:
753753
return
754754

755755
if not links:
@@ -776,7 +776,7 @@ def show_user_message(
776776
name: str | None = None,
777777
) -> None:
778778
"""Display a user message in the new visual style."""
779-
if not self.config or not self.config.logger.show_chat:
779+
if self.config and not self.config.logger.show_chat:
780780
return
781781

782782
# Build right side with model and turn
@@ -803,7 +803,7 @@ def show_system_message(
803803
server_count: int = 0,
804804
) -> None:
805805
"""Display the system prompt in a formatted panel."""
806-
if not self.config or not self.config.logger.show_chat:
806+
if self.config and not self.config.logger.show_chat:
807807
return
808808

809809
# Build right side info
@@ -844,7 +844,7 @@ async def show_prompt_loaded(
844844
highlight_server: Optional server name to highlight
845845
arguments: Optional dictionary of arguments passed to the prompt template
846846
"""
847-
if not self.config or not self.config.logger.show_tools:
847+
if self.config and not self.config.logger.show_tools:
848848
return
849849

850850
# Build the server list with highlighting

src/fast_agent/ui/tool_display.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def show_tool_result(
3636
) -> None:
3737
"""Display a tool result in the console."""
3838
config = self._display.config
39-
if not config or not config.logger.show_tools:
39+
if config and not config.logger.show_tools:
4040
return
4141

4242
from fast_agent.mcp.helpers.content_helpers import get_text, is_text_content
@@ -200,7 +200,7 @@ def show_tool_call(
200200
) -> None:
201201
"""Display a tool call header and body."""
202202
config = self._display.config
203-
if not config or not config.logger.show_tools:
203+
if config and not config.logger.show_tools:
204204
return
205205

206206
tool_args = tool_args or {}
@@ -267,7 +267,7 @@ def show_tool_call(
267267
async def show_tool_update(self, updated_server: str, *, agent_name: str | None = None) -> None:
268268
"""Show a background tool update notification."""
269269
config = self._display.config
270-
if not config or not config.logger.show_tools:
270+
if config and not config.logger.show_tools:
271271
return
272272

273273
try:
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import TYPE_CHECKING, Any, Dict, List
2+
3+
import pytest
4+
5+
from fast_agent.agents.agent_types import AgentConfig
6+
from fast_agent.agents.mcp_agent import McpAgent
7+
from fast_agent.context import Context
8+
9+
if TYPE_CHECKING:
10+
from mcp.types import CallToolResult
11+
12+
13+
def _make_agent_config() -> AgentConfig:
14+
return AgentConfig(name="test-agent", instruction="do things", servers=[])
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_local_tools_listed_and_callable() -> None:
19+
calls: List[Dict[str, Any]] = []
20+
21+
def sample_tool(video_id: str) -> str:
22+
calls.append({"video_id": video_id})
23+
return f"transcript for {video_id}"
24+
25+
config = _make_agent_config()
26+
context = Context()
27+
28+
class LocalToolAgent(McpAgent):
29+
def __init__(self) -> None:
30+
super().__init__(
31+
config=config,
32+
connection_persistence=False,
33+
context=context,
34+
tools=[sample_tool],
35+
)
36+
37+
agent = LocalToolAgent()
38+
39+
tool_names = {tool.name for tool in (await agent.list_tools()).tools}
40+
assert "sample_tool" in tool_names
41+
42+
result: CallToolResult = await agent.call_tool("sample_tool", {"video_id": "1234"})
43+
assert not result.isError
44+
assert calls == [{"video_id": "1234"}]
45+
assert [block.text for block in result.content or []] == ["transcript for 1234"]
46+
47+
await agent._aggregator.close()

0 commit comments

Comments
 (0)