Skip to content

Commit 5b5ed49

Browse files
hgao327The tunix Authors
authored andcommitted
Update agentic APIs
PiperOrigin-RevId: 814823652
1 parent cf656b9 commit 5b5ed49

File tree

3 files changed

+11
-24
lines changed

3 files changed

+11
-24
lines changed

tunix/rl/experimental/agentic/parser/tool_parser/gemini_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tunix.rl.experimental.agentic.tools import base_tool
77

88
BaseTool = base_tool.BaseTool
9-
ToolCall = tool_parser_base.ToolCall
9+
ToolCall = base_tool.ToolCall
1010
ToolParser = tool_parser_base.ToolParser
1111

1212

tunix/rl/experimental/agentic/parser/tool_parser/qwen_parser.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import json
8-
from typing import Any, List
8+
from typing import Any
99

1010
from tunix.rl.experimental.agentic.parser.tool_parser import tool_parser_base
1111
from tunix.rl.experimental.agentic.tools import base_tool
@@ -36,13 +36,10 @@ def parse(self, model_response: str) -> list[ToolCall]:
3636
model_response (str): Text containing tool calls
3737
3838
Returns:
39-
ToolInputs: Parsed tool calls
39+
list[ToolCall]: Parsed tool calls
4040
"""
4141
tool_calls_dicts = self.parse_qwen_tool_calls(model_response)
42-
tool_calls = [
43-
ToolCall(name=tc["name"], arguments=tc["arguments"])
44-
for tc in tool_calls_dicts
45-
]
42+
tool_calls = [ToolCall(**tool_call) for tool_call in tool_calls_dicts]
4643
return tool_calls
4744

4845
def parse_qwen_tool_calls(self, text: str) -> list[dict[str, Any]]:
@@ -92,7 +89,7 @@ def parse_qwen_tool_calls(self, text: str) -> list[dict[str, Any]]:
9289

9390
def get_tool_prompt(
9491
self,
95-
tools: List[BaseTool],
92+
tools: list[BaseTool],
9693
*,
9794
schema_style: str = "openai",
9895
) -> str:

tunix/rl/experimental/agentic/parser/tool_parser/tool_parser_base.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
ABC = abc.ABC
1414

1515

16-
@dataclass
17-
class ToolCall:
18-
name: str
19-
arguments: dict[str, Any]
20-
21-
2216
class ToolParser(ABC):
2317
"""Abstract base class for all tool parsers.
2418
@@ -28,7 +22,7 @@ class ToolParser(ABC):
2822
"""
2923

3024
@abstractmethod
31-
def parse(self, model_response: str) -> list[ToolCall]:
25+
def parse(self, model_response: str) -> list[base_tool.ToolCall]:
3226
"""Parse model output and return a list of tool calls.
3327
3428
Args:
@@ -42,17 +36,13 @@ def parse(self, model_response: str) -> list[ToolCall]:
4236
@abstractmethod
4337
def get_tool_prompt(
4438
self,
45-
tools: List[BaseTool],
46-
*,
47-
schema_style: Literal["openai", "mcp", "gemini"] = "openai",
39+
tools_schema: str,
4840
) -> str:
4941
"""Generate tool-usage instruction prompt from a list of tools.
5042
5143
Args:
52-
tools: List of tool instances (BaseTool).
53-
schema_style: "openai" -> use tool.json (OpenAI function-calling style)
54-
"mcp" -> use tool.to_mcp_json() (MCP-compatible format) "gemini" ->
55-
use Gemini-compatible schema
44+
tools_schema: A string containing the tool schemas, generated by
45+
`_tools_schema_dump`.
5646
5747
Returns:
5848
str: Prompt text to feed into the model (includes tool schemas).
@@ -79,9 +69,9 @@ def _tools_schema_dump(
7969
schemas = [t.to_mcp_json() for t in tools]
8070
elif schema_style == "gemini":
8171
# Gemini also uses JSON schema, same as OpenAI
82-
schemas = [t.json for t in tools]
72+
schemas = [t.get_json_schema() for t in tools]
8373
else:
84-
schemas = [t.json for t in tools]
74+
schemas = [t.get_json_schema() for t in tools]
8575
return json.dumps(schemas, ensure_ascii=False, indent=2)
8676

8777
def parse_tool_outputs(self) -> dict[str, Any]:

0 commit comments

Comments
 (0)