Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions tests/test_tool_scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests for tool scheduling (depends_on, tool_call_alias) and param coercion."""
from tool_registry import (
get_tool_schemas,
execute_tool,
register_tool,
ToolDef,
_coerce_params,
_SCHEDULING_PROPS,
)

# Trigger builtin tool registration
import tools # noqa: F401


class TestSchedulingPropsInjection:
    """Schema-level checks that exported tool schemas carry the scheduling fields."""

    def test_schemas_contain_scheduling_fields(self):
        """Every exported schema exposes `tool_call_alias` and `depends_on`."""
        all_schemas = get_tool_schemas()
        assert len(all_schemas) > 0
        for schema in all_schemas:
            tool_name = schema.get("name")
            declared = schema.get("properties", {})
            assert "tool_call_alias" in declared, f"Missing tool_call_alias in {tool_name}"
            assert "depends_on" in declared, f"Missing depends_on in {tool_name}"

    def test_scheduling_props_have_correct_types(self):
        """The injected fields advertise the JSON types string and array."""
        injected = get_tool_schemas()[0]["properties"]
        assert injected["tool_call_alias"]["type"] == "string"
        assert injected["depends_on"]["type"] == "array"

    def test_original_schema_not_mutated(self):
        """Mutating one returned schema must not leak into the next call."""
        polluted = get_tool_schemas()
        polluted[0]["properties"]["tool_call_alias"]["EXTRA"] = True
        fresh = get_tool_schemas()
        assert "EXTRA" not in fresh[0]["properties"]["tool_call_alias"]


class TestCoerceParams:
    """Unit tests for `_coerce_params`: string inputs vs. schema-declared types."""

    @staticmethod
    def _single_prop_schema(prop, declared_type):
        """Build a schema that declares exactly one property of the given type."""
        return {"properties": {prop: {"type": declared_type}}}

    def test_int_coercion(self):
        coerced = _coerce_params({"limit": "42"}, self._single_prop_schema("limit", "integer"))
        assert coerced == {"limit": 42}

    def test_float_coercion(self):
        coerced = _coerce_params({"rate": "3.14"}, self._single_prop_schema("rate", "number"))
        assert coerced == {"rate": 3.14}

    def test_bool_true(self):
        coerced = _coerce_params({"flag": "true"}, self._single_prop_schema("flag", "boolean"))
        assert coerced == {"flag": True}

    def test_bool_false(self):
        coerced = _coerce_params({"flag": "false"}, self._single_prop_schema("flag", "boolean"))
        assert coerced == {"flag": False}

    def test_array_coercion(self):
        coerced = _coerce_params({"items": '["a","b"]'}, self._single_prop_schema("items", "array"))
        assert coerced == {"items": ["a", "b"]}

    def test_object_coercion(self):
        coerced = _coerce_params({"meta": '{"k": 1}'}, self._single_prop_schema("meta", "object"))
        assert coerced == {"meta": {"k": 1}}

    def test_passthrough_string(self):
        # A declared string type has no coercer, so the value is untouched.
        coerced = _coerce_params({"name": "hello"}, self._single_prop_schema("name", "string"))
        assert coerced == {"name": "hello"}

    def test_invalid_json_passthrough(self):
        # Unparseable JSON for an array param is handed through unchanged.
        coerced = _coerce_params({"items": "not-json"}, self._single_prop_schema("items", "array"))
        assert coerced == {"items": "not-json"}

    def test_unknown_prop_passthrough(self):
        # Params the schema does not declare are never coerced.
        assert _coerce_params({"x": "y"}, {"properties": {}}) == {"x": "y"}


class TestExecuteToolStripsScheduling:
    """`execute_tool` must drop scheduling-only params before calling the handler."""

    _TOOL_NAME = "test_sched_tool"

    def setup_method(self):
        """Register a throwaway tool whose handler records the params it receives."""
        self._received = {}

        def _handler(params, config=None):
            self._received = dict(params)
            return "ok"

        register_tool(ToolDef(
            name=self._TOOL_NAME,
            schema={
                "name": "test_sched_tool",
                "description": "test tool",
                "properties": {"msg": {"type": "string"}},
            },
            func=_handler,
            read_only=True,
        ))

    def teardown_method(self):
        """Remove the throwaway tool so it does not leak into later tests."""
        # Imported locally: the module-level import list does not expose the
        # registry mapping, and only this cleanup needs it.
        from tool_registry import _registry

        _registry.pop(self._TOOL_NAME, None)

    def test_scheduling_params_stripped(self):
        """The handler sees `msg` but neither `tool_call_alias` nor `depends_on`."""
        execute_tool(
            "test_sched_tool",
            {"msg": "hi", "tool_call_alias": "t1", "depends_on": ["w1"]},
            config={},
        )
        assert "tool_call_alias" not in self._received
        assert "depends_on" not in self._received
        assert self._received.get("msg") == "hi"
102 changes: 102 additions & 0 deletions tests/test_tool_scheduling_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""End-to-end: the LLM sees `tool_call_alias` + `depends_on` in every tool
schema, it uses them in a tool call, and the stripping wrapper removes those
fields before the tool handler runs.

Only the LLM provider is mocked (via monkeypatching agent.stream). The tool
registry, schema injection and param stripping all run for real.
"""
from __future__ import annotations

import pytest

import tools as _tools_init # noqa: F401 - force built-in tool registration
from agent import AgentState, run
from providers import AssistantTurn
from tool_registry import ToolDef, register_tool


def _scripted_stream(captured_schemas, turns):
cursor = iter(turns)

def fake_stream(**kwargs):
captured_schemas.append(kwargs.get("tool_schemas") or [])
spec = next(cursor)
yield AssistantTurn(
text=spec.get("text", ""),
tool_calls=spec.get("tool_calls") or [],
in_tokens=1, out_tokens=1,
)

return fake_stream


@pytest.fixture
def receiver_tool():
    """Register a `receiver` tool that records the params of its first call.

    Yields the dict the handler writes into; the recorded params appear under
    the key `"seen"`. On teardown the registry entry is removed, or the
    previously registered `receiver` tool is put back if one existed.
    """
    from tool_registry import _registry

    received = {}
    previous = _registry.get("receiver")

    def _record(params, _cfg):
        # setdefault keeps the first call's params if the tool runs again.
        received.setdefault("seen", dict(params))
        # Always report success, even when the recorded params are empty
        # (an `x and "ok"` expression would return the empty dict instead).
        return "ok"

    register_tool(ToolDef(
        name="receiver",
        schema={
            "name": "receiver",
            "description": "records params for assertions",
            "input_schema": {
                "type": "object",
                "properties": {"msg": {"type": "string"}},
                "required": ["msg"],
            },
        },
        func=_record,
        read_only=True, concurrent_safe=True,
    ))
    yield received
    if previous is None:
        _registry.pop("receiver", None)
    else:
        # Restore whatever was registered under this name before the test.
        _registry["receiver"] = previous


def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool):
    """Each schema handed to the (mocked) LLM stream carries both scheduling fields."""
    sent_schemas = []
    fake = _scripted_stream(sent_schemas, [{"text": "nothing to do"}])
    monkeypatch.setattr("agent.stream", fake)

    config = {
        "model": "test",
        "permission_mode": "accept-all",
        "_session_id": "sch",
        "disabled_tools": ["Agent"],
    }
    # Drain the generator so the agent loop runs to completion.
    for _event in run("hi", AgentState(), config, "sys"):
        pass

    assert sent_schemas, "stream was not called"
    for tool_schema in sent_schemas[0]:
        tool_name = tool_schema.get("name")
        # Accept either layout: flat `properties` or nested under `input_schema`.
        props = tool_schema.get("properties")
        if not props:
            props = tool_schema.get("input_schema", {}).get("properties", {})
        assert "tool_call_alias" in props, f"{tool_name} missing tool_call_alias"
        assert "depends_on" in props, f"{tool_name} missing depends_on"


def test_scheduling_params_stripped_before_reaching_tool(monkeypatch, receiver_tool):
    """The handler must not see `tool_call_alias` or `depends_on` from the tool call."""
    tool_call = {
        "id": "r1",
        "name": "receiver",
        "input": {
            "msg": "hello",
            "tool_call_alias": "step-1",
            "depends_on": ["w1", "w2"],
        },
    }
    # First turn issues the tool call; second turn ends the conversation.
    script = [{"tool_calls": [tool_call]}, {"text": "done"}]
    schemas_seen = []
    monkeypatch.setattr("agent.stream", _scripted_stream(schemas_seen, script))

    config = {
        "model": "test",
        "permission_mode": "accept-all",
        "_session_id": "sch2",
        "disabled_tools": ["Agent"],
    }
    for _event in run("go", AgentState(), config, "sys"):
        pass

    assert "seen" in receiver_tool, "receiver handler was never called"
    delivered = receiver_tool["seen"]
    assert delivered.get("msg") == "hello"
    assert "tool_call_alias" not in delivered
    assert "depends_on" not in delivered
110 changes: 110 additions & 0 deletions tool_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,113 @@ def execute_tool(
def clear_registry() -> None:
    """Empty the module-level tool registry in place (meant for test isolation)."""
    _registry.clear()


# ── Tool scheduling support ────────────────────────────────────────────────

import copy as _copy
import json as _json

_SCHEDULING_PROPS = {
"tool_call_alias": {
"type": "string",
"description": (
"Optional alias for this tool call. "
"Other tools can reference it in depends_on."
),
},
"depends_on": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of tool_call IDs or aliases that must complete before this tool runs."
),
},
}


def _coerce_params(params: dict, schema: dict) -> dict:
    """Coerce string parameter values to their schema-declared types.

    Property declarations are read from the schema's top-level `properties`
    mapping; when that is missing or empty, the nested
    `input_schema["properties"]` mapping is used instead, so tools registered
    with either layout get coerced.

    Coercion failure is not a hard error: the original string is kept and
    passed to the tool handler, which will surface a clear type error to
    the model (e.g. `expected int, got 'abc'`) far more usefully than a
    ValueError from the registry wrapper.

    Args:
        params: parameter name -> raw value, as received from the caller.
        schema: the tool's registered schema dict.

    Returns:
        A new dict with the same keys; `params` itself is not modified.
    """
    props = schema.get("properties")
    if not isinstance(props, dict) or not props:
        # Fall back to the nested layout; tolerate a missing or malformed one.
        nested = schema.get("input_schema")
        nested_props = nested.get("properties") if isinstance(nested, dict) else None
        props = nested_props if isinstance(nested_props, dict) else {}
    return {k: _coerce_value_for(k, v, props) for k, v in params.items()}


def _coerce_value_for(key: str, value, props: dict):
"""Coerce a single value according to its declared type, else return as-is."""
prop_schema = props.get(key)
if not prop_schema or not isinstance(value, str):
return value
coercer = _COERCERS.get(prop_schema.get("type"))
if coercer is None:
return value
return coercer(value)


def _coerce_int(value):
try:
return int(value)
except ValueError:
return value # intentional: tool handler reports the real type mismatch


def _coerce_float(value):
try:
return float(value)
except ValueError:
return value


def _coerce_bool(value):
return value.lower() in ("true", "1", "yes")


def _coerce_json(value):
try:
return _json.loads(value)
except (ValueError, _json.JSONDecodeError):
return value


# Dispatch table: JSON-Schema type name -> function that coerces a string to it.
# Types without an entry (e.g. "string") are left untouched by the caller.
_COERCERS = dict(
    integer=_coerce_int,
    number=_coerce_float,
    boolean=_coerce_bool,
    array=_coerce_json,
    object=_coerce_json,
)


# Wrap get_tool_schemas to inject scheduling properties
_orig_get_tool_schemas = get_tool_schemas


def _inject_scheduling_props(container: dict) -> None:
    """Add the scheduling properties to `container["properties"]` in place.

    Properties the container already declares under the same name are kept.
    """
    props = container.get("properties")
    if not isinstance(props, dict):
        # Missing (or malformed, e.g. None) mapping: start a fresh one.
        props = container["properties"] = {}
    for prop_name, prop_spec in _SCHEDULING_PROPS.items():
        props.setdefault(prop_name, _copy.deepcopy(prop_spec))


def get_tool_schemas():
    """Return tool schemas with scheduling properties injected.

    Every schema from the wrapped function is deep-copied first, so neither
    the source schemas nor `_SCHEDULING_PROPS` are affected by changes a
    caller makes to the result.

    The properties are added to the schema's top-level `properties` mapping
    and, when the schema nests its parameters under an `input_schema` dict,
    to that mapping's `properties` as well, so they sit next to the tool's
    real parameters in either layout.

    NOTE(review): the top-level injection is kept even for `input_schema`
    style schemas because existing callers read `schema["properties"]`
    directly; confirm whether providers tolerate the extra top-level key.
    """
    result = []
    for registered in _orig_get_tool_schemas():
        schema = _copy.deepcopy(registered)
        _inject_scheduling_props(schema)
        nested = schema.get("input_schema")
        if isinstance(nested, dict):
            _inject_scheduling_props(nested)
        result.append(schema)
    return result


# Wrap execute_tool to strip scheduling params and coerce types
_orig_execute_tool = execute_tool


def execute_tool(name, params, *args, **kwargs):
    """Execute a tool after stripping scheduling params and coercing types.

    Keys named in `_SCHEDULING_PROPS` are dropped, so the handler never sees
    them. If `get_tool(name)` finds a tool, the remaining string values are
    coerced against that tool's schema; otherwise they are forwarded as-is.
    The wrapped `execute_tool` then handles the call.

    Args:
        name: tool name.
        params: parameter mapping for the call; `None` is treated as no
            parameters. The mapping itself is not modified.
        *args, **kwargs: forwarded unchanged to the wrapped `execute_tool`.

    Returns:
        Whatever the wrapped `execute_tool` returns.
    """
    # `params or {}` guards against a tool call that arrives without arguments.
    clean = {k: v for k, v in (params or {}).items() if k not in _SCHEDULING_PROPS}
    tool = get_tool(name)
    if tool is not None:
        clean = _coerce_params(clean, tool.schema)
    return _orig_execute_tool(name, clean, *args, **kwargs)
Loading