From 14017633f74fe41c2d4a6b88a1019f0f3714449e Mon Sep 17 00:00:00 2001 From: echobt Date: Tue, 3 Feb 2026 14:21:10 +0000 Subject: [PATCH] feat: Remove OpenRouter support, replace litellm with Chutes API - Remove OPENROUTER from Provider enum, add CHUTES - Update get_api_key() to use CHUTES_API_KEY - Update get_base_url() to use https://api.chutes.ai/v1 - Rewrite LLM client using httpx instead of litellm - Update default model to deepseek/deepseek-chat - Remove litellm from requirements.txt and pyproject.toml - Update all documentation references - Keep LiteLLMClient alias for backward compatibility --- README.md | 18 +- agent.py | 67 +++--- astuces/08-cost-optimization.md | 15 +- pyproject.toml | 1 - requirements.txt | 1 - rules/02-architecture-patterns.md | 4 +- rules/06-llm-usage-guide.md | 131 +++--------- rules/08-error-handling.md | 4 +- src/__init__.py | 4 +- src/api/retry.py | 60 +++--- src/config/__init__.py | 2 +- src/config/defaults.py | 43 +--- src/config/loader.py | 44 ++-- src/config/models.py | 55 ++--- src/core/__init__.py | 29 ++- src/core/agent.py | 109 +++++----- src/core/compaction.py | 121 ++++++----- src/core/executor.py | 167 +++++++-------- src/core/loop.py | 331 ++++++++++++++++-------------- src/core/session.py | 76 +++---- src/exec/__init__.py | 10 +- src/exec/runner.py | 154 +++++++------- src/images/__init__.py | 4 +- src/images/loader.py | 51 ++--- src/llm/__init__.py | 13 +- src/llm/client.py | 290 ++++++++++++++++---------- src/main.py | 36 ++-- src/output/__init__.py | 47 +++-- src/output/events.py | 26 +-- src/output/jsonl.py | 23 ++- src/output/processor.py | 86 ++++---- src/output/streaming.py | 119 ++++++----- src/prompts/system.py | 223 ++++++++++---------- src/tools/__init__.py | 19 +- src/tools/apply_patch.py | 191 +++++++++-------- src/tools/base.py | 34 +-- src/tools/grep_files.py | 75 ++++--- src/tools/list_dir.py | 119 ++++++----- src/tools/read_file.py | 66 +++--- src/tools/registry.py | 259 ++++++++++++----------- src/tools/search_files.py | 102 ++++----- src/tools/shell.py | 61 +++--- src/tools/specs.py | 6 +- src/tools/view_image.py | 87 ++++---- src/tools/write_file.py | 54 ++--- src/utils/__init__.py | 27 ++- src/utils/files.py | 27 ++- src/utils/tokens.py | 8 +- src/utils/truncate.py | 257 ++++++++++++----------- 49 files changed, 1899 insertions(+), 1857 deletions(-) diff --git a/README.md b/README.md index 3b9cf99..580d805 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # BaseAgent - SDK 3.0 -High-performance autonomous agent for [Term Challenge](https://term.challenge). **Does NOT use term_sdk** - fully autonomous with litellm. +High-performance autonomous agent for [Term Challenge](https://term.challenge). **Does NOT use term_sdk** - fully autonomous with Chutes API. ## Installation @@ -36,7 +36,7 @@ my-agent/ │ │ ├── loop.py # Main loop │ │ └── compaction.py # Context management (MANDATORY) │ ├── llm/ -│ │ └── client.py # LLM client (litellm) +│ │ └── client.py # LLM client (Chutes API) │ └── tools/ │ └── ... 
# Available tools ├── requirements.txt # Dependencies @@ -77,13 +77,13 @@ AUTO_COMPACT_THRESHOLD = 0.85 ## Features -### LLM Client (litellm) +### LLM Client (Chutes API) ```python -from src.llm.client import LiteLLMClient +from src.llm.client import LLMClient -llm = LiteLLMClient( - model="openrouter/anthropic/claude-opus-4.5", +llm = LLMClient( + model="deepseek/deepseek-chat", temperature=0.0, max_tokens=16384, ) @@ -129,7 +129,7 @@ See `src/config/defaults.py`: ```python CONFIG = { - "model": "openrouter/anthropic/claude-opus-4.5", + "model": "deepseek/deepseek-chat", "max_tokens": 16384, "max_iterations": 200, "auto_compact_threshold": 0.85, @@ -142,7 +142,7 @@ CONFIG = { | Variable | Description | |----------|-------------| -| `OPENROUTER_API_KEY` | OpenRouter API key | +| `CHUTES_API_KEY` | Chutes API key | ## Documentation @@ -151,7 +151,7 @@ CONFIG = { See [rules/](rules/) for comprehensive guides: - [Architecture Patterns](rules/02-architecture-patterns.md) - **Mandatory project structure** -- [LLM Usage Guide](rules/06-llm-usage-guide.md) - **Using litellm** +- [LLM Usage Guide](rules/06-llm-usage-guide.md) - **Using Chutes API** - [Best Practices](rules/05-best-practices.md) - [Error Handling](rules/08-error-handling.md) diff --git a/agent.py b/agent.py index 710edb1..db76ffb 100644 --- a/agent.py +++ b/agent.py @@ -3,7 +3,7 @@ SuperAgent for Term Challenge - Entry Point (SDK 3.0 Compatible). This agent accepts --instruction from the validator and runs autonomously. -Uses litellm for LLM calls instead of term_sdk. +Uses Chutes API for LLM calls instead of term_sdk. Installation: pip install . # via pyproject.toml @@ -16,20 +16,20 @@ from __future__ import annotations import argparse -import sys -import time import os import subprocess +import sys +import time from pathlib import Path # Add parent to path for imports sys.path.insert(0, str(Path(__file__).parent)) + # Auto-install dependencies if missing def ensure_dependencies(): """Install dependencies if not present.""" try: - import litellm import httpx import pydantic except ImportError: @@ -37,23 +37,28 @@ def ensure_dependencies(): agent_dir = Path(__file__).parent req_file = agent_dir / "requirements.txt" if req_file.exists(): - subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(req_file), "-q"], check=True) + subprocess.run( + [sys.executable, "-m", "pip", "install", "-r", str(req_file), "-q"], check=True + ) else: - subprocess.run([sys.executable, "-m", "pip", "install", str(agent_dir), "-q"], check=True) + subprocess.run( + [sys.executable, "-m", "pip", "install", str(agent_dir), "-q"], check=True + ) print("[setup] Dependencies installed", file=sys.stderr) + ensure_dependencies() from src.config.defaults import CONFIG from src.core.loop import run_agent_loop +from src.llm.client import CostLimitExceeded, LLMClient +from src.output.jsonl import ErrorEvent, emit from src.tools.registry import ToolRegistry -from src.output.jsonl import emit, ErrorEvent -from src.llm.client import LiteLLMClient, CostLimitExceeded class AgentContext: """Minimal context for agent execution (replaces term_sdk.AgentContext).""" - + def __init__(self, instruction: str, cwd: str = None): self.instruction = instruction self.cwd = cwd or os.getcwd() @@ -61,11 +66,11 @@ def __init__(self, instruction: str, cwd: str = None): self.is_done = False self.history = [] self._start_time = time.time() - + @property def elapsed_secs(self) -> float: return time.time() - self._start_time - + def shell(self, cmd: str, timeout: int = 120) -> 
"ShellResult": """Execute a shell command.""" self.step += 1 @@ -86,20 +91,22 @@ def shell(self, cmd: str, timeout: int = 120) -> "ShellResult": except Exception as e: output = f"[ERROR] {e}" exit_code = -1 - + shell_result = ShellResult(output=output, exit_code=exit_code) - self.history.append({ - "step": self.step, - "command": cmd, - "output": output[:1000], - "exit_code": exit_code, - }) + self.history.append( + { + "step": self.step, + "command": cmd, + "output": output[:1000], + "exit_code": exit_code, + } + ) return shell_result - + def done(self): """Mark task as complete.""" self.is_done = True - + def log(self, msg: str): """Log a message.""" timestamp = time.strftime("%H:%M:%S") @@ -108,13 +115,13 @@ def log(self, msg: str): class ShellResult: """Result from shell command.""" - + def __init__(self, output: str, exit_code: int): self.output = output self.stdout = output self.stderr = "" self.exit_code = exit_code - + def has(self, text: str) -> bool: return text in self.output @@ -129,29 +136,29 @@ def main(): parser = argparse.ArgumentParser(description="SuperAgent for Term Challenge SDK 3.0") parser.add_argument("--instruction", required=True, help="Task instruction from validator") args = parser.parse_args() - + _log("=" * 60) - _log("SuperAgent Starting (SDK 3.0 - litellm)") + _log("SuperAgent Starting (SDK 3.0 - Chutes API)") _log("=" * 60) _log(f"Model: {CONFIG['model']}") _log(f"Reasoning effort: {CONFIG.get('reasoning_effort', 'default')}") _log(f"Instruction: {args.instruction[:200]}...") _log("-" * 60) - + # Initialize components start_time = time.time() - - llm = LiteLLMClient( + + llm = LLMClient( model=CONFIG["model"], temperature=CONFIG.get("temperature"), max_tokens=CONFIG.get("max_tokens", 16384), ) - + tools = ToolRegistry() ctx = AgentContext(instruction=args.instruction) - + _log("Components initialized") - + try: run_agent_loop( llm=llm, diff --git a/astuces/08-cost-optimization.md b/astuces/08-cost-optimization.md index 1230292..afc958c 100644 --- a/astuces/08-cost-optimization.md +++ b/astuces/08-cost-optimization.md @@ -2,21 +2,20 @@ ## Cost Breakdown -For Claude Sonnet via OpenRouter: +Typical LLM pricing (varies by model): -| Token Type | Cost per 1M | -|------------|-------------| -| Input tokens | $3.00 | -| Cached input | $0.30 (90% off) | -| Output tokens | $15.00 | +| Token Type | Typical Cost per 1M | +|------------|---------------------| +| Input tokens | $1.00 - $15.00 | +| Cached input | 10-50% of input | +| Output tokens | $2.00 - $60.00 | For a typical task: - 50 turns - 100k context average - 500 output tokens per turn -**Without optimization**: 50 × 100k × $3/1M = **$15 per task** -**With 90% caching**: 50 × 100k × $0.30/1M = **$1.50 per task** +Costs vary significantly by model choice. DeepSeek models are typically more cost-effective than Claude or GPT-4. 
## Optimization Strategies diff --git a/pyproject.toml b/pyproject.toml index 864644a..41d9205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "tomli-w>=1.0", "rich>=13.0", "typer>=0.12.0", - "litellm>=1.50.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 02cebfd..c4242ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,3 @@ tomli>=2.0;python_version<'3.11' tomli-w>=1.0 rich>=13.0 typer>=0.12.0 -litellm>=1.50.0 diff --git a/rules/02-architecture-patterns.md b/rules/02-architecture-patterns.md index 933b156..2ab44f6 100644 --- a/rules/02-architecture-patterns.md +++ b/rules/02-architecture-patterns.md @@ -20,7 +20,7 @@ my-agent/ │ │ ├── loop.py # Main loop │ │ └── compaction.py # Context management (MANDATORY) │ ├── llm/ -│ │ └── client.py # LLM client (litellm) +│ │ └── client.py # LLM client (Chutes API) │ └── tools/ │ └── ... # Tools ├── requirements.txt # Dependencies @@ -275,7 +275,7 @@ flowchart TB ### Implementation ```python -# Définition des outils (format OpenAI/litellm) +# Tool definition (OpenAI-compatible format) TOOLS = [ { "name": "run_command", diff --git a/rules/06-llm-usage-guide.md b/rules/06-llm-usage-guide.md index 417375a..aa32128 100644 --- a/rules/06-llm-usage-guide.md +++ b/rules/06-llm-usage-guide.md @@ -1,6 +1,6 @@ -# 06 - LLM Usage Guide (SDK 3.0 - litellm) +# 06 - LLM Usage Guide (SDK 3.0 - Chutes API) -This guide covers using LLMs with **litellm** (no more term_sdk). +This guide covers using LLMs with **Chutes API** via httpx (no more term_sdk). --- @@ -9,11 +9,11 @@ This guide covers using LLMs with **litellm** (no more term_sdk). ### Initialization ```python -from src.llm.client import LiteLLMClient, LLMError, CostLimitExceeded +from src.llm.client import LLMClient, LLMError, CostLimitExceeded # Create the LLM client -llm = LiteLLMClient( - model="openrouter/anthropic/claude-opus-4.5", +llm = LLMClient( + model="deepseek/deepseek-chat", temperature=0.0, # 0 = deterministic max_tokens=16384, cost_limit=10.0 # Cost limit in $ @@ -342,7 +342,7 @@ def run(self, ctx: Any): ### Defining Tools ```python -# Format OpenAI/litellm pour les outils +# Tool format (OpenAI-compatible) TOOLS = [ Tool( @@ -504,9 +504,8 @@ Community fine-tuned models are **forbidden** because they may: def setup(self): # Any official foundation model works # Examples: claude-3.5-sonnet, gpt-4o, deepseek-v3, llama-3, etc. - self.llm = LLM( - provider="openrouter", # or any supported provider - default_model="anthropic/claude-3.5-sonnet", + self.llm = LLMClient( + model="deepseek/deepseek-chat", # or any supported model temperature=0.3 ) ``` @@ -542,74 +541,24 @@ def run(self, ctx: Any): ## Prompt Caching -Prompt caching significantly reduces costs and latency by reusing previously processed prompts. The Term SDK supports caching via the `cache=True` parameter. +Prompt caching significantly reduces costs and latency by reusing previously processed prompts. 
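Reuse only happens when the provider sees an identical prefix across requests, so the stable content has to come first in the message list. A minimal sketch of that ordering follows; the prompt text and helper name are illustrative and not part of the client API:

```python
# Illustrative only: order messages so the cacheable prefix stays identical
# across turns -- static system prompt first, per-task and per-turn content last.
STATIC_SYSTEM_PROMPT = "You are a task-solving agent."  # assumed placeholder

def build_messages(instruction: str, history: list[dict]) -> list[dict]:
    return [
        # Stable prefix: identical on every request, so it can be cached
        {"role": "system", "content": STATIC_SYSTEM_PROMPT},
        # Dynamic suffix: changes per task / per iteration
        {"role": "user", "content": f"Task: {instruction}"},
        *history,  # prior assistant/tool messages appended in order
    ]

messages = build_messages("Fix the failing unit tests", history=[])
```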
-### Enabling Caching in Term SDK +### Enabling Caching ```python -from src.llm.client import LiteLLMClient - -class MyAgent(Agent): - def setup(self): - self.llm = LLM( - provider="openrouter", - default_model="anthropic/claude-3.5-sonnet", - cache=True # Enable prompt caching - ) -``` - -### How Caching Works by Provider - -| Provider | Caching | Configuration | -|----------|---------|---------------| -| **OpenAI** | Automatic | No config needed, min 1024 tokens | -| **Anthropic** | Manual | Requires `cache_control` breakpoints | -| **DeepSeek** | Automatic | No config needed | -| **Google Gemini** | Automatic | No config needed, min 4096 tokens | -| **Groq** | Automatic | No config needed | - -### Anthropic Cache Control (Important!) - -Anthropic requires explicit `cache_control` breakpoints. This is critical for cost savings: - -**Pricing:** -- **Cache writes**: 1.25x input price (slightly more expensive) -- **Cache reads**: 0.1x input price (90% savings!) +from src.llm.client import LLMClient -**TTL Options:** -- Default: 5 minutes -- Extended: 1 hour with `"ttl": "1h"` +# Caching is handled at the message level +llm = LLMClient( + model="deepseek/deepseek-chat", +) -### Anthropic Caching Example +# The system manages caching automatically through message preparation +``` -```python -# Structure messages with cache_control for large content -messages = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": "You are a task-solving agent." - }, - { - "type": "text", - "text": LARGE_SYSTEM_PROMPT, # Cache this! - "cache_control": { - "type": "ephemeral", - "ttl": "1h" # Optional: extend to 1 hour - } - } - ] - }, - { - "role": "user", - "content": f"Task: {ctx.instruction}" - } -] +### How Caching Works -response = self.llm.chat(messages) -``` +Caching behavior depends on the model and provider. The client handles cache_control markers automatically, stripping them for providers that don't support them. ### What to Cache @@ -625,48 +574,22 @@ response = self.llm.chat(messages) - Changing context - Small prompts (under 1024 tokens) -### Cache Placement Strategy - -```python -# Put static content FIRST, dynamic content LAST -messages = [ - { - "role": "system", - "content": [ - # Static: Cache this large prompt - { - "type": "text", - "text": STATIC_SYSTEM_PROMPT, - "cache_control": {"type": "ephemeral"} - } - ] - }, - # Dynamic: User instruction (changes each task) - {"role": "user", "content": ctx.instruction}, - # Dynamic: Previous outputs (change each iteration) - {"role": "assistant", "content": last_response}, - {"role": "user", "content": command_output} -] -``` - ### Inspecting Cache Usage ```python -response = self.llm.chat(messages, usage=True) +response = llm.chat(messages) -# Check cache statistics -if response.usage: - cached_tokens = response.usage.get("cached_tokens", 0) - cache_discount = response.usage.get("cache_discount", 0) - print(f"Cached: {cached_tokens} tokens, saved: ${cache_discount:.4f}") +# Check cache statistics from response tokens +if response.tokens: + cached_tokens = response.tokens.get("cached", 0) + print(f"Cached: {cached_tokens} tokens") ``` ### Cost Optimization Tips 1. **Keep static content first** - Cache hits require matching prefixes -2. **Use 1-hour TTL for long sessions** - Avoids repeated cache writes -3. **Batch related requests** - Maximize cache hits within TTL window -4. **Monitor cache_discount** - Negative = cache write, positive = savings +2. **Batch related requests** - Maximize cache hits within TTL window +3. 
**Monitor token usage** - Track cached vs uncached tokens --- @@ -681,4 +604,4 @@ if response.usage: | Token awareness | Truncate long outputs | | Clear prompts | Specific format requirements | | Tool definitions | Well-documented parameters | -| **Prompt caching** | Enable `cache=True`, use cache_control for Anthropic | +| **Prompt caching** | Use static prompts first for better cache hits | diff --git a/rules/08-error-handling.md b/rules/08-error-handling.md index 24137c5..24c471d 100644 --- a/rules/08-error-handling.md +++ b/rules/08-error-handling.md @@ -47,12 +47,12 @@ Errors parsing LLM responses: ```python import time -from src.llm.client import LiteLLMClient, LLMError, CostLimitExceeded +from src.llm.client import LLMClient, LLMError, CostLimitExceeded class RobustLLMClient: def __init__(self, ctx: Any): self.ctx = ctx - self.llm = LLM(default_model="anthropic/claude-3.5-sonnet") + self.llm = LLMClient(model="deepseek/deepseek-chat") self.max_retries = 3 self.base_delay = 5 diff --git a/src/__init__.py b/src/__init__.py index 40cb881..2d5c793 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -4,7 +4,7 @@ Inspired by OpenAI Codex CLI, BaseAgent is designed to solve terminal-based coding tasks autonomously using LLMs. -SDK 3.0 Compatible - Uses litellm instead of term_sdk. +SDK 3.0 Compatible - Uses Chutes API via httpx instead of term_sdk. Usage: python agent.py --instruction "Your task here..." @@ -15,8 +15,8 @@ # Import main components for convenience from src.config.defaults import CONFIG -from src.tools.registry import ToolRegistry from src.output.jsonl import emit +from src.tools.registry import ToolRegistry __all__ = [ "CONFIG", diff --git a/src/api/retry.py b/src/api/retry.py index 13f9830..7515f27 100644 --- a/src/api/retry.py +++ b/src/api/retry.py @@ -12,13 +12,13 @@ from src.config.models import RetryConfig - T = TypeVar("T") @dataclass class RetryState: """State of a retry operation.""" + attempt: int last_error: Optional[Exception] last_status_code: Optional[int] @@ -27,57 +27,57 @@ class RetryState: class RetryHandler: """Handles retry logic with exponential backoff.""" - + def __init__(self, config: RetryConfig): self.config = config - + def calculate_delay(self, attempt: int) -> float: """Calculate delay with exponential backoff and jitter. - + Args: attempt: Current attempt number (1-indexed) - + Returns: Delay in seconds with jitter """ # Exponential backoff: base_delay * 2^(attempt-1) exp_delay = self.config.base_delay * (2 ** (attempt - 1)) - + # Cap at max_delay delay = min(exp_delay, self.config.max_delay) - + # Add jitter (0.9 to 1.1 multiplier) jitter = random.uniform(0.9, 1.1) - + return delay * jitter - + def should_retry(self, error: Exception, attempt: int) -> bool: """Determine if we should retry based on the error. - + Args: error: The exception that occurred attempt: Current attempt number - + Returns: True if we should retry """ if attempt >= self.config.max_attempts: return False - + # Check for HTTP status codes if isinstance(error, httpx.HTTPStatusError): return error.response.status_code in self.config.retry_on_status - + # Retry on connection errors if isinstance(error, (httpx.ConnectError, httpx.TimeoutException)): return True - + # Retry on specific exception types if isinstance(error, (ConnectionError, TimeoutError)): return True - + return False - + def execute( self, func: Callable[..., T], @@ -86,62 +86,64 @@ def execute( **kwargs: Any, ) -> T: """Execute a function with retry logic. 
- + Args: func: Function to execute *args: Positional arguments for func on_retry: Optional callback called before each retry **kwargs: Keyword arguments for func - + Returns: Result of func - + Raises: The last exception if all retries fail """ state = RetryState(attempt=0, last_error=None, last_status_code=None, total_delay=0) - + while True: state.attempt += 1 - + try: return func(*args, **kwargs) except Exception as e: state.last_error = e - + # Extract status code if available if isinstance(e, httpx.HTTPStatusError): state.last_status_code = e.response.status_code - + # Check if we should retry if not self.should_retry(e, state.attempt): raise - + # Calculate and apply delay delay = self.calculate_delay(state.attempt) state.total_delay += delay - + # Call retry callback if on_retry: on_retry(state) - + time.sleep(delay) def with_retry(config: RetryConfig) -> Callable[[Callable[..., T]], Callable[..., T]]: """Decorator to add retry logic to a function. - + Args: config: Retry configuration - + Returns: Decorator function """ handler = RetryHandler(config) - + def decorator(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: return handler.execute(func, *args, **kwargs) + return wrapper + return decorator diff --git a/src/config/__init__.py b/src/config/__init__.py index c986030..08c8595 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -1,5 +1,5 @@ """Configuration module.""" -from src.config.defaults import CONFIG, get_config, get +from src.config.defaults import CONFIG, get, get_config __all__ = ["CONFIG", "get_config", "get"] diff --git a/src/config/defaults.py b/src/config/defaults.py index da7615f..a74eda1 100644 --- a/src/config/defaults.py +++ b/src/config/defaults.py @@ -17,91 +17,64 @@ import os from typing import Any, Dict - # Main configuration - simulates Codex exec benchmark mode CONFIG: Dict[str, Any] = { # ========================================================================== # Model Settings (simulates --model gpt-5.2 -c model_reasoning_effort=xhigh) # ========================================================================== - - # Model to use via OpenRouter (prefix with openrouter/ for litellm) - "model": os.environ.get("LLM_MODEL", "openrouter/anthropic/claude-sonnet-4-20250514"), - + # Model to use via Chutes API (OpenAI-compatible) + "model": os.environ.get("LLM_MODEL", "deepseek/deepseek-chat"), # Provider - "provider": "openrouter", - + "provider": "chutes", # Reasoning effort: none, minimal, low, medium, high, xhigh (not used for Claude) "reasoning_effort": "none", - # Token limits "max_tokens": 16384, - # Temperature (0 = deterministic) "temperature": 0.0, - # ========================================================================== # Agent Execution Settings # ========================================================================== - # Maximum iterations before stopping "max_iterations": 200, - # Maximum tokens for tool output truncation (middle-out strategy) "max_output_tokens": 2500, # ~10KB - # Timeout for shell commands (seconds) "shell_timeout": 60, - # ========================================================================== # Context Management (like OpenCode/Codex) # ========================================================================== - # Model context window (Claude Opus 4.5 = 200K) "model_context_limit": 200_000, - # Reserved tokens for output "output_token_max": 32_000, - # Trigger compaction at this % of usable context (85%) "auto_compact_threshold": 0.85, - # Tool 
output pruning constants (from OpenCode) - "prune_protect": 40_000, # Protect this many tokens of recent tool output - "prune_minimum": 20_000, # Only prune if we can recover at least this many - + "prune_protect": 40_000, # Protect this many tokens of recent tool output + "prune_minimum": 20_000, # Only prune if we can recover at least this many # ========================================================================== - # Prompt Caching (Anthropic via OpenRouter/Bedrock) + # Prompt Caching # ========================================================================== - # Enable prompt caching "cache_enabled": True, - - # Note: Anthropic caching requires minimum tokens per breakpoint: - # - Claude Opus 4.5 on Bedrock: 4096 tokens minimum - # - Claude Sonnet/other: 1024 tokens minimum - # System prompt should be large enough to meet this threshold - + # Note: Caching behavior depends on the model/provider + # System prompt should be large enough to meet provider thresholds # ========================================================================== # Simulated Codex Flags (all enabled/bypassed for benchmark) # ========================================================================== - # --dangerously-bypass-approvals-and-sandbox "bypass_approvals": True, "bypass_sandbox": True, - # --skip-git-repo-check "skip_git_check": True, - # --enable unified_exec "unified_exec": True, - # --json (always JSONL output) "json_output": True, - # ========================================================================== # Double Confirmation for Task Completion # ========================================================================== - # Require double confirmation before marking task complete # Disabled for fully autonomous operation in evaluation mode "require_completion_confirmation": False, diff --git a/src/config/loader.py b/src/config/loader.py index 7111850..31213d6 100644 --- a/src/config/loader.py +++ b/src/config/loader.py @@ -29,7 +29,7 @@ def _flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = "_") -> di def _nest_dict(flat: dict[str, Any]) -> dict[str, Any]: """Convert a flat dictionary with underscores to nested structure.""" result: dict[str, Any] = {} - + # Map of flat keys to nested paths mappings = { "agent_model": ["model"], @@ -59,7 +59,7 @@ def _nest_dict(flat: dict[str, Any]) -> dict[str, Any]: "paths_readable_roots": ["paths", "readable_roots"], "paths_writable_roots": ["paths", "writable_roots"], } - + for flat_key, value in flat.items(): if flat_key in mappings: path = mappings[flat_key] @@ -69,34 +69,34 @@ def _nest_dict(flat: dict[str, Any]) -> dict[str, Any]: current[part] = {} current = current[part] current[path[-1]] = value - + return result def load_config_from_file(path: Path) -> AgentConfig: """Load configuration from a TOML file. - + Args: path: Path to the TOML configuration file. - + Returns: AgentConfig instance with loaded configuration. - + Raises: FileNotFoundError: If the config file doesn't exist. ValueError: If the config file is invalid. """ if not path.exists(): raise FileNotFoundError(f"Config file not found: {path}") - + with open(path, "rb") as f: raw_config = tomllib.load(f) - + # TOML structure: [agent], [cache], [retry], etc. 
# We need to transform it to match our Pydantic model structure flat = _flatten_dict(raw_config) nested = _nest_dict(flat) - + # Also handle direct keys from [agent] section if "agent" in raw_config: for key, value in raw_config["agent"].items(): @@ -104,12 +104,12 @@ def load_config_from_file(path: Path) -> AgentConfig: nested[key] = value if key == "reasoning" and isinstance(value, dict): nested["reasoning"] = value - + # Handle other top-level sections directly for section in ["cache", "retry", "tools", "output", "paths"]: if section in raw_config and section not in nested: nested[section] = raw_config[section] - + return AgentConfig(**nested) @@ -118,31 +118,31 @@ def load_config( overrides: Optional[dict[str, Any]] = None, ) -> AgentConfig: """Load configuration with optional overrides. - + Args: config_path: Optional path to a TOML config file. overrides: Optional dictionary of configuration overrides. - + Returns: AgentConfig instance. """ # Start with defaults config_dict: dict[str, Any] = {} - + # Load from file if provided if config_path and config_path.exists(): with open(config_path, "rb") as f: raw_config = tomllib.load(f) - + # Transform TOML structure if "agent" in raw_config: for key, value in raw_config["agent"].items(): config_dict[key] = value - + for section in ["cache", "retry", "tools", "output", "paths"]: if section in raw_config: config_dict[section] = raw_config[section] - + # Apply overrides if overrides: for key, value in overrides.items(): @@ -157,19 +157,19 @@ def load_config( current[parts[-1]] = value else: config_dict[key] = value - + return AgentConfig(**config_dict) def find_config_file() -> Optional[Path]: """Find the configuration file in standard locations. - + Searches in order: 1. ./config.toml 2. ./superagent.toml 3. ~/.config/superagent/config.toml 4. ~/.superagent/config.toml - + Returns: Path to the config file if found, None otherwise. 
""" @@ -179,9 +179,9 @@ def find_config_file() -> Optional[Path]: Path.home() / ".config" / "superagent" / "config.toml", Path.home() / ".superagent" / "config.toml", ] - + for path in search_paths: if path.exists(): return path - + return None diff --git a/src/config/models.py b/src/config/models.py index 172efe4..ef5c889 100644 --- a/src/config/models.py +++ b/src/config/models.py @@ -5,13 +5,13 @@ import os from enum import Enum from pathlib import Path -from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator class ReasoningEffort(str, Enum): """Reasoning effort levels for the model.""" + NONE = "none" MINIMAL = "minimal" LOW = "low" @@ -22,40 +22,47 @@ class ReasoningEffort(str, Enum): class OutputMode(str, Enum): """Output mode for the agent.""" + HUMAN = "human" JSON = "json" class Provider(str, Enum): """LLM provider.""" - OPENROUTER = "openrouter" + + CHUTES = "chutes" OPENAI = "openai" ANTHROPIC = "anthropic" class ReasoningConfig(BaseModel): """Configuration for model reasoning.""" - effort: ReasoningEffort = Field(default=ReasoningEffort.HIGH, description="Reasoning effort level") + + effort: ReasoningEffort = Field( + default=ReasoningEffort.HIGH, description="Reasoning effort level" + ) class CacheConfig(BaseModel): """Configuration for prompt caching.""" + enabled: bool = Field(default=True, description="Enable prompt caching") class RetryConfig(BaseModel): """Configuration for retry logic.""" + max_attempts: int = Field(default=5, description="Maximum retry attempts") base_delay: float = Field(default=1.0, description="Base delay in seconds") max_delay: float = Field(default=60.0, description="Maximum delay in seconds") retry_on_status: list[int] = Field( - default=[429, 500, 502, 503, 504], - description="HTTP status codes to retry on" + default=[429, 500, 502, 503, 504], description="HTTP status codes to retry on" ) class ToolsConfig(BaseModel): """Configuration for available tools.""" + shell_enabled: bool = Field(default=True, description="Enable shell execution") shell_timeout: int = Field(default=30, description="Shell timeout in seconds") file_ops_enabled: bool = Field(default=True, description="Enable file operations") @@ -66,6 +73,7 @@ class ToolsConfig(BaseModel): class OutputConfig(BaseModel): """Configuration for output formatting.""" + mode: OutputMode = Field(default=OutputMode.HUMAN, description="Output mode") streaming: bool = Field(default=True, description="Enable streaming output") colors: bool = Field(default=True, description="Enable colored output") @@ -73,10 +81,11 @@ class OutputConfig(BaseModel): class PathsConfig(BaseModel): """Configuration for file paths.""" + cwd: str = Field(default="", description="Working directory") readable_roots: list[str] = Field(default=[], description="Additional readable directories") writable_roots: list[str] = Field(default=[], description="Additional writable directories") - + @field_validator("cwd", mode="before") @classmethod def resolve_cwd(cls, v: str) -> str: @@ -88,21 +97,15 @@ def resolve_cwd(cls, v: str) -> str: class AgentConfig(BaseModel): """Main configuration for the SuperAgent.""" - + # Model settings - model: str = Field( - default="anthropic/claude-opus-4-20250514", - description="Model to use" - ) - provider: Provider = Field( - default=Provider.OPENROUTER, - description="LLM provider" - ) + model: str = Field(default="anthropic/claude-opus-4-20250514", description="Model to use") + provider: Provider = Field(default=Provider.CHUTES, description="LLM provider") 
max_iterations: int = Field(default=50, description="Maximum iterations") timeout: int = Field(default=120, description="Timeout per LLM call in seconds") temperature: float = Field(default=0.7, description="Generation temperature") max_tokens: int = Field(default=16384, description="Maximum tokens for response") - + # Sub-configurations reasoning: ReasoningConfig = Field(default_factory=ReasoningConfig) cache: CacheConfig = Field(default_factory=CacheConfig) @@ -110,32 +113,34 @@ class AgentConfig(BaseModel): tools: ToolsConfig = Field(default_factory=ToolsConfig) output: OutputConfig = Field(default_factory=OutputConfig) paths: PathsConfig = Field(default_factory=PathsConfig) - + @property def working_directory(self) -> Path: """Get the working directory as a Path object.""" return Path(self.paths.cwd or os.getcwd()) - + def get_api_key(self) -> str: """Get the API key for the configured provider.""" env_vars = { - Provider.OPENROUTER: ["OPENROUTER_API_KEY"], + Provider.CHUTES: ["CHUTES_API_KEY"], Provider.OPENAI: ["OPENAI_API_KEY"], Provider.ANTHROPIC: ["ANTHROPIC_API_KEY"], } - + for var in env_vars.get(self.provider, []): key = os.environ.get(var) if key: return key - - raise ValueError(f"No API key found for provider {self.provider}. " - f"Set one of: {env_vars.get(self.provider, [])}") - + + raise ValueError( + f"No API key found for provider {self.provider}. " + f"Set one of: {env_vars.get(self.provider, [])}" + ) + def get_base_url(self) -> str: """Get the base URL for the configured provider.""" urls = { - Provider.OPENROUTER: "https://openrouter.ai/api/v1", + Provider.CHUTES: "https://api.chutes.ai/v1", Provider.OPENAI: "https://api.openai.com/v1", Provider.ANTHROPIC: "https://api.anthropic.com/v1", } diff --git a/src/core/__init__.py b/src/core/__init__.py index f4f1956..f1e9bdb 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -1,28 +1,27 @@ """Core module - agent loop, session management, and context compaction.""" -from src.core.executor import ( - AgentExecutor, - ExecutionResult, - RiskLevel, - SandboxPolicy, -) - # Compaction module (like OpenCode/Codex context management) from src.core.compaction import ( - manage_context, - estimate_tokens, + AUTO_COMPACT_THRESHOLD, + MODEL_CONTEXT_LIMIT, + OUTPUT_TOKEN_MAX, + PRUNE_MARKER, + PRUNE_MINIMUM, + PRUNE_PROTECT, estimate_message_tokens, + estimate_tokens, estimate_total_tokens, is_overflow, + manage_context, needs_compaction, prune_old_tool_outputs, run_compaction, - MODEL_CONTEXT_LIMIT, - OUTPUT_TOKEN_MAX, - AUTO_COMPACT_THRESHOLD, - PRUNE_PROTECT, - PRUNE_MINIMUM, - PRUNE_MARKER, +) +from src.core.executor import ( + AgentExecutor, + ExecutionResult, + RiskLevel, + SandboxPolicy, ) # Import run_agent_loop diff --git a/src/core/agent.py b/src/core/agent.py index a6725bb..e946829 100644 --- a/src/core/agent.py +++ b/src/core/agent.py @@ -2,11 +2,10 @@ from __future__ import annotations -import sys from pathlib import Path -from typing import Any, Callable, Optional +from typing import Callable, Optional -from src.api.client import LLMClient, LLMResponse, FunctionCall +from src.api.client import FunctionCall, LLMClient, LLMResponse from src.config.models import AgentConfig from src.core.session import Session from src.output.processor import OutputProcessor @@ -16,14 +15,14 @@ class Agent: """Main agent that runs the LLM loop with tool execution. - + This implements the core agent loop similar to Codex CLI: 1. Send messages to LLM 2. If LLM returns tool calls, execute them 3. Feed results back to LLM 4. 
Repeat until no more tool calls (needs_follow_up = False) """ - + def __init__( self, config: Optional[AgentConfig] = None, @@ -31,7 +30,7 @@ def __init__( output_processor: Optional[OutputProcessor] = None, ): """Initialize the agent. - + Args: config: Agent configuration cwd: Working directory (defaults to current) @@ -39,15 +38,15 @@ def __init__( """ self.config = config or AgentConfig() self.cwd = cwd or Path(self.config.paths.cwd or ".").resolve() - + # Initialize components self.client = LLMClient(self.config) self.tools = ToolRegistry(self.cwd) self.output = output_processor or OutputProcessor(self.config) - + # Session state self.session: Optional[Session] = None - + def run( self, prompt: str, @@ -55,59 +54,59 @@ def run( on_tool_call: Optional[Callable[[str, dict], None]] = None, ) -> str: """Run the agent with a user prompt. - + Args: prompt: User's instruction/prompt on_message: Optional callback for assistant messages on_tool_call: Optional callback for tool calls - + Returns: Final assistant message """ # Create session self.session = Session(config=self.config, cwd=self.cwd) - + # Add system prompt system_prompt = get_system_prompt(cwd=self.cwd) self.session.add_system_message(system_prompt) - + # Add user message self.session.add_user_message(prompt) - + # Emit session started self.output.emit_turn_started(self.session) - + # Run the agent loop try: final_message = self._run_loop(on_message, on_tool_call) self.session.mark_done(final_message) self.output.emit_turn_completed(self.session, final_message) return final_message - + except Exception as e: error_msg = f"Agent error: {e}" self.output.emit_error(error_msg) self.session.mark_done(error_msg) raise - + finally: self.client.close() - + def _run_loop( self, on_message: Optional[Callable[[str], None]] = None, on_tool_call: Optional[Callable[[str, dict], None]] = None, ) -> str: """Run the main agent loop. - + Returns: Final assistant message """ if not self.session: raise RuntimeError("No session initialized") - + last_message = "" - + while True: # Check iteration limit if not self.session.increment_iteration(): @@ -115,42 +114,42 @@ def _run_loop( f"Reached maximum iterations ({self.config.max_iterations})" ) break - + # Get tools for the LLM tools = self.tools.get_tools_for_llm() - + # Call the LLM self.output.emit_thinking() - + response = self.client.chat( messages=self.session.get_messages_for_api(), tools=tools, ) - + # Update token usage self.session.update_usage( response.input_tokens, response.output_tokens, response.cached_tokens, ) - + # Process the response needs_follow_up = self._process_response( response, on_message, on_tool_call, ) - + # Store last message if response.text: last_message = response.text - + # If no tool calls, we're done if not needs_follow_up: break - + return last_message - + def _process_response( self, response: LLMResponse, @@ -158,84 +157,86 @@ def _process_response( on_tool_call: Optional[Callable[[str, dict], None]] = None, ) -> bool: """Process an LLM response. 
- + Args: response: The LLM response on_message: Callback for messages on_tool_call: Callback for tool calls - + Returns: True if follow-up is needed (tool calls were made) """ if not self.session: raise RuntimeError("No session initialized") - + # Handle text response if response.text: self.output.emit_assistant_message(response.text) if on_message: on_message(response.text) - + # Check for tool calls if not response.has_function_calls: # No tool calls - add response and we're done self.session.add_assistant_message(response.text) return False - + # Build tool_calls format for the message tool_calls_data = [] for call in response.function_calls: - tool_calls_data.append({ - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": str(call.arguments), - }, - }) - + tool_calls_data.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": str(call.arguments), + }, + } + ) + # Add assistant message with tool calls self.session.add_assistant_message( response.text or "", tool_calls=tool_calls_data, ) - + # Execute each tool call for call in response.function_calls: result = self._execute_tool_call(call, on_tool_call) - + # Add tool result to conversation self.session.add_tool_result( tool_call_id=call.id, name=call.name, content=result.to_message(), ) - + # Need follow-up since we executed tools return True - + def _execute_tool_call( self, call: FunctionCall, on_tool_call: Optional[Callable[[str, dict], None]] = None, ) -> ToolResult: """Execute a single tool call. - + Args: call: The function call to execute on_tool_call: Optional callback - + Returns: ToolResult from execution """ self.output.emit_tool_call_start(call.name, call.arguments) - + if on_tool_call: on_tool_call(call.name, call.arguments) - + # Execute the tool result = self.tools.execute(call.name, call.arguments) - + self.output.emit_tool_call_end(call.name, result) - + return result diff --git a/src/core/compaction.py b/src/core/compaction.py index 6876605..4ed4ae3 100644 --- a/src/core/compaction.py +++ b/src/core/compaction.py @@ -13,10 +13,10 @@ import sys import time -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional if TYPE_CHECKING: - from src.llm.client import LiteLLMClient + from src.llm.client import LLMClient # ============================================================================= # Constants (matching OpenCode) @@ -59,6 +59,7 @@ # Token Estimation # ============================================================================= + def estimate_tokens(text: str) -> int: """Estimate tokens from text length (4 chars per token heuristic).""" return max(0, len(text or "") // APPROX_CHARS_PER_TOKEN) @@ -67,7 +68,7 @@ def estimate_tokens(text: str) -> int: def estimate_message_tokens(msg: Dict[str, Any]) -> int: """Estimate tokens for a single message.""" tokens = 0 - + # Content tokens content = msg.get("content") if isinstance(content, str): @@ -79,17 +80,17 @@ def estimate_message_tokens(msg: Dict[str, Any]) -> int: # Images count as ~1000 tokens roughly if part.get("type") == "image_url": tokens += 1000 - + # Tool calls tokens (function name + arguments) tool_calls = msg.get("tool_calls", []) for tc in tool_calls: func = tc.get("function", {}) tokens += estimate_tokens(func.get("name", "")) tokens += estimate_tokens(func.get("arguments", "")) - + # Role overhead (~4 tokens) tokens += 4 - + return tokens @@ -102,6 +103,7 @@ def estimate_total_tokens(messages: 
List[Dict[str, Any]]) -> int: # Overflow Detection # ============================================================================= + def get_usable_context() -> int: """Get usable context window (total - reserved for output).""" return MODEL_CONTEXT_LIMIT - OUTPUT_TOKEN_MAX @@ -123,6 +125,7 @@ def needs_compaction(messages: List[Dict[str, Any]]) -> bool: # Tool Output Pruning # ============================================================================= + def _log(msg: str) -> None: """Log to stderr.""" timestamp = time.strftime("%H:%M:%S") @@ -135,80 +138,82 @@ def prune_old_tool_outputs( ) -> List[Dict[str, Any]]: """ Prune old tool outputs to save tokens. - + Strategy (exactly like OpenCode compaction.ts lines 49-89): 1. Go backwards through messages 2. Skip first 2 user turns (most recent) 3. Accumulate tool output tokens 4. Once we've accumulated PRUNE_PROTECT (40K) tokens, start marking for prune 5. Only actually prune if we can recover > PRUNE_MINIMUM (20K) tokens - + Args: messages: List of messages protect_last_turns: Number of recent user turns to skip (default: 2) - + Returns: Messages with old tool outputs pruned (content replaced with PRUNE_MARKER) """ if not messages: return messages - + total = 0 # Total tool output tokens seen (going backwards) pruned = 0 # Tokens that will be pruned to_prune: List[int] = [] # Indices to prune turns = 0 # User turn counter - + # Go backwards through messages (like OpenCode) for msg_index in range(len(messages) - 1, -1, -1): msg = messages[msg_index] - + # Count user turns if msg.get("role") == "user": turns += 1 - + # Skip the first N user turns (most recent) if turns < protect_last_turns: continue - + # Process tool messages if msg.get("role") == "tool": content = msg.get("content", "") - + # Skip already pruned if content == PRUNE_MARKER: # Already compacted, stop here (like OpenCode: break loop) break - + estimate = estimate_tokens(content) total += estimate - + # Once we've accumulated more than PRUNE_PROTECT tokens, # start marking older outputs for pruning if total > PRUNE_PROTECT: pruned += estimate to_prune.append(msg_index) - + _log(f"Prune scan: {total} total tokens, {pruned} prunable") - + # Only prune if we can recover enough tokens if pruned <= PRUNE_MINIMUM: _log(f"Prune skipped: only {pruned} tokens recoverable (min: {PRUNE_MINIMUM})") return messages - + _log(f"Pruning {len(to_prune)} tool outputs, recovering ~{pruned} tokens") - + # Create new messages with pruned content indices_to_prune = set(to_prune) result = [] for i, msg in enumerate(messages): if i in indices_to_prune: - result.append({ - **msg, - "content": PRUNE_MARKER, - }) + result.append( + { + **msg, + "content": PRUNE_MARKER, + } + ) else: result.append(msg) - + return result @@ -216,15 +221,16 @@ def prune_old_tool_outputs( # AI Compaction # ============================================================================= + def run_compaction( - llm: "LiteLLMClient", + llm: "LLMClient", messages: List[Dict[str, Any]], system_prompt: str, model: Optional[str] = None, ) -> List[Dict[str, Any]]: """ Compact conversation history using AI summarization. - + Process (like Codex): 1. Send all messages + compaction prompt to LLM 2. 
Get summary response @@ -232,25 +238,27 @@ def run_compaction( - Original system prompt - Summary as user message (with prefix) - Ready for continuation - + Args: llm: LLM client for summarization messages: Current message history system_prompt: Original system prompt to preserve model: Model to use (defaults to current) - + Returns: Compacted message list """ _log("Starting AI compaction...") - + # Build compaction request compaction_messages = messages.copy() - compaction_messages.append({ - "role": "user", - "content": COMPACTION_PROMPT, - }) - + compaction_messages.append( + { + "role": "user", + "content": COMPACTION_PROMPT, + } + ) + try: # Call LLM for summary (no tools, just text) response = llm.chat( @@ -258,24 +266,24 @@ def run_compaction( model=model, max_tokens=4096, # Summary should be concise ) - + summary = response.text or "" - + if not summary: _log("Compaction failed: empty response") return messages - + summary_tokens = estimate_tokens(summary) _log(f"Compaction complete: {summary_tokens} token summary") - + # Build new message list compacted = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": SUMMARY_PREFIX + summary}, ] - + return compacted - + except Exception as e: _log(f"Compaction failed: {e}") # Return original messages if compaction fails @@ -286,57 +294,58 @@ def run_compaction( # Main Context Management # ============================================================================= + def manage_context( messages: List[Dict[str, Any]], system_prompt: str, - llm: "LiteLLMClient", + llm: "LLMClient", force_compaction: bool = False, ) -> List[Dict[str, Any]]: """ Main context management function. - + Called before each LLM request to ensure context fits. - + Strategy: 1. Estimate current token usage 2. If under threshold, return as-is 3. Try pruning old tool outputs first 4. 
If still over threshold, run AI compaction - + Args: messages: Current message history system_prompt: Original system prompt (preserved through compaction) llm: LLM client (for compaction) force_compaction: Force compaction even if under threshold - + Returns: Managed message list (possibly compacted) """ total_tokens = estimate_total_tokens(messages) usable = get_usable_context() usage_pct = (total_tokens / usable) * 100 - + _log(f"Context: {total_tokens} tokens ({usage_pct:.1f}% of {usable})") - + # Check if we need to do anything if not force_compaction and not is_overflow(total_tokens): return messages - - _log(f"Context overflow detected, managing...") - + + _log("Context overflow detected, managing...") + # Step 1: Try pruning old tool outputs pruned = prune_old_tool_outputs(messages) pruned_tokens = estimate_total_tokens(pruned) - + if not is_overflow(pruned_tokens) and not force_compaction: _log(f"Pruning sufficient: {total_tokens} -> {pruned_tokens} tokens") return pruned - + # Step 2: Run AI compaction _log(f"Pruning insufficient ({pruned_tokens} tokens), running AI compaction...") compacted = run_compaction(llm, pruned, system_prompt) compacted_tokens = estimate_total_tokens(compacted) - + _log(f"Compaction result: {total_tokens} -> {compacted_tokens} tokens") - + return compacted diff --git a/src/core/executor.py b/src/core/executor.py index c2b5b68..c9be8f1 100644 --- a/src/core/executor.py +++ b/src/core/executor.py @@ -2,7 +2,7 @@ This module wraps ToolRegistry with: - Timeout enforcement -- Execution tracking +- Execution tracking - Batch execution support - Risk assessment for commands @@ -14,48 +14,51 @@ import concurrent.futures import json import time -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple if TYPE_CHECKING: pass # AgentContext is duck-typed -from src.tools.registry import ToolRegistry, ExecutorConfig, ExecutorStats from src.tools.base import ToolResult +from src.tools.registry import ExecutorConfig, ExecutorStats, ToolRegistry class RiskLevel(Enum): """Risk level for tool operations.""" - SAFE = auto() # Read-only operations - LOW = auto() # Network/environment access - MEDIUM = auto() # File modifications - HIGH = auto() # Destructive operations + + SAFE = auto() # Read-only operations + LOW = auto() # Network/environment access + MEDIUM = auto() # File modifications + HIGH = auto() # Destructive operations CRITICAL = auto() # System destruction potential class SandboxPolicy(Enum): """Policy for sandbox enforcement.""" - STRICT = auto() # Block all risky operations - PROMPT = auto() # Prompt user for risky operations + + STRICT = auto() # Block all risky operations + PROMPT = auto() # Prompt user for risky operations PERMISSIVE = auto() # Allow most operations @dataclass class ExecutionResult: """Result of a tool execution with timing and metadata.""" + tool_name: str result: ToolResult duration_ms: int cached: bool = False risk_level: Optional[RiskLevel] = None - + @property def success(self) -> bool: """Whether the execution was successful.""" return self.result.success - + @property def output(self) -> str: """The output from the tool.""" @@ -65,9 +68,10 @@ def output(self) -> str: @dataclass class CachedExecutionResult: """A cached execution result with timestamp.""" + result: ExecutionResult cached_at: float - + def is_valid(self, ttl: 
float) -> bool: """Check if the cached result is still valid.""" return (time.time() - self.cached_at) < ttl @@ -76,21 +80,21 @@ def is_valid(self, ttl: float) -> bool: class AgentExecutor: """ High-level executor for agent tool calls. - + Wraps ToolRegistry with: - Timeout enforcement - Execution tracking - Batch execution support - Risk assessment - Result caching - + Example: executor = AgentExecutor(cwd=Path("/project")) result = executor.execute(ctx, "read_file", {"file_path": "main.py"}) if result.success: print(result.output) """ - + def __init__( self, cwd: Optional[Path] = None, @@ -98,7 +102,7 @@ def __init__( sandbox_policy: SandboxPolicy = SandboxPolicy.PROMPT, ): """Initialize the executor. - + Args: cwd: Working directory for tool operations config: Executor configuration (timeouts, concurrency, etc.) @@ -109,22 +113,22 @@ def __init__( self.registry._config = config self._sandbox_policy = sandbox_policy self._execution_cache: Dict[str, CachedExecutionResult] = {} - + @property def config(self) -> ExecutorConfig: """Get the executor configuration.""" return self.registry._config - + @property def cwd(self) -> Path: """Get the current working directory.""" return self.registry.cwd - + @cwd.setter def cwd(self, value: Path) -> None: """Set the current working directory.""" self.registry.cwd = value - + def execute( self, ctx: "AgentContext", @@ -134,39 +138,35 @@ def execute( ) -> ExecutionResult: """ Execute a single tool with timeout. - + Args: ctx: Agent context with shell() method tool_name: Name of tool to execute arguments: Tool arguments timeout: Optional timeout override (seconds) - + Returns: ExecutionResult with result and timing """ start = time.time() - + # Assess risk level risk = self.assess_risk(tool_name, arguments) - + # Use config timeout if not specified effective_timeout = timeout or self.registry._config.default_timeout - + # Execute with timeout cached = False try: - result = self._execute_with_timeout( - ctx, tool_name, arguments, effective_timeout - ) + result = self._execute_with_timeout(ctx, tool_name, arguments, effective_timeout) except TimeoutError: - result = ToolResult.fail( - f"Tool {tool_name} timed out after {effective_timeout}s" - ) + result = ToolResult.fail(f"Tool {tool_name} timed out after {effective_timeout}s") except Exception as e: result = ToolResult.fail(f"Tool {tool_name} failed: {e}") - + duration_ms = int((time.time() - start) * 1000) - + return ExecutionResult( tool_name=tool_name, result=result, @@ -174,7 +174,7 @@ def execute( cached=cached, risk_level=risk, ) - + def _execute_with_timeout( self, ctx: "AgentContext", @@ -183,28 +183,26 @@ def _execute_with_timeout( timeout: float, ) -> ToolResult: """Execute with timeout using threading. - + Args: ctx: Agent context tool_name: Name of tool to execute arguments: Tool arguments timeout: Timeout in seconds - + Returns: ToolResult from the tool - + Raises: TimeoutError: If execution exceeds timeout """ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit( - self.registry.execute, ctx, tool_name, arguments - ) + future = executor.submit(self.registry.execute, ctx, tool_name, arguments) try: return future.result(timeout=timeout) except concurrent.futures.TimeoutError: raise TimeoutError(f"Execution timed out after {timeout}s") - + def execute_batch( self, ctx: "AgentContext", @@ -213,27 +211,24 @@ def execute_batch( ) -> List[ExecutionResult]: """ Execute multiple tools. 
- + Args: ctx: Agent context calls: List of (tool_name, arguments) tuples parallel: If True, execute in parallel (up to max_concurrent) - + Returns: List of ExecutionResults in same order as calls """ if not calls: return [] - + if not parallel: - return [ - self.execute(ctx, name, args) - for name, args in calls - ] - + return [self.execute(ctx, name, args) for name, args in calls] + # Parallel execution with ordering preserved results: List[Optional[ExecutionResult]] = [None] * len(calls) - + with concurrent.futures.ThreadPoolExecutor( max_workers=self.registry._config.max_concurrent ) as executor: @@ -242,7 +237,7 @@ def execute_batch( executor.submit(self.execute, ctx, name, args): i for i, (name, args) in enumerate(calls) } - + for future in concurrent.futures.as_completed(future_to_index): index = future_to_index[future] try: @@ -256,9 +251,9 @@ def execute_batch( duration_ms=0, cached=False, ) - + return results # type: ignore - + def execute_sequential( self, ctx: "AgentContext", @@ -266,16 +261,16 @@ def execute_sequential( ) -> List[ExecutionResult]: """ Execute tools sequentially (alias for execute_batch with parallel=False). - + Args: ctx: Agent context calls: List of (tool_name, arguments) tuples - + Returns: List of ExecutionResults in same order as calls """ return self.execute_batch(ctx, calls, parallel=False) - + def assess_risk( self, tool_name: str, @@ -283,11 +278,11 @@ def assess_risk( ) -> RiskLevel: """ Assess risk level of a tool call. - + Args: tool_name: Name of the tool arguments: Tool arguments - + Returns: RiskLevel indicating the risk of this operation """ @@ -295,35 +290,35 @@ def assess_risk( if tool_name == "shell_command": cmd = arguments.get("command", "") return self._assess_command_risk(cmd) - + # Default risk by tool category if tool_name in ("read_file", "list_dir", "grep_files", "view_image"): return RiskLevel.SAFE - + if tool_name == "write_file": return RiskLevel.MEDIUM - + if tool_name == "apply_patch": return RiskLevel.MEDIUM - + if tool_name == "update_plan": return RiskLevel.SAFE - + # Unknown tools get medium risk return RiskLevel.MEDIUM - + def _assess_command_risk(self, command: str) -> RiskLevel: """ Assess risk of a shell command. 
- + Args: command: Shell command string - + Returns: RiskLevel for the command """ cmd = command.lower().strip() - + # Critical: system destruction if ( (cmd.startswith("rm -rf /") and (cmd == "rm -rf /" or cmd.startswith("rm -rf / "))) @@ -332,7 +327,7 @@ def _assess_command_risk(self, command: str) -> RiskLevel: or "mkfs" in cmd ): return RiskLevel.CRITICAL - + # High: destructive operations if ( "rm -rf" in cmd @@ -346,7 +341,7 @@ def _assess_command_risk(self, command: str) -> RiskLevel: or ("wget" in cmd and "| sh" in cmd) ): return RiskLevel.HIGH - + # Medium: file modifications if ( "mv " in cmd @@ -358,17 +353,11 @@ def _assess_command_risk(self, command: str) -> RiskLevel: or "pip install" in cmd ): return RiskLevel.MEDIUM - + # Low: network or environment access - if ( - "curl" in cmd - or "wget" in cmd - or "ssh" in cmd - or "env" in cmd - or "export" in cmd - ): + if "curl" in cmd or "wget" in cmd or "ssh" in cmd or "env" in cmd or "export" in cmd: return RiskLevel.LOW - + # Safe: read-only operations if ( cmd.startswith("ls") @@ -384,9 +373,9 @@ def _assess_command_risk(self, command: str) -> RiskLevel: or cmd.startswith("git diff") ): return RiskLevel.SAFE - + return RiskLevel.MEDIUM - + def can_auto_approve( self, tool_name: str, @@ -394,40 +383,40 @@ def can_auto_approve( ) -> bool: """ Check if a tool call can be auto-approved based on risk and policy. - + Args: tool_name: Name of the tool arguments: Tool arguments - + Returns: True if the call can be auto-approved """ risk = self.assess_risk(tool_name, arguments) - + if self._sandbox_policy == SandboxPolicy.STRICT: return risk == RiskLevel.SAFE elif self._sandbox_policy == SandboxPolicy.PROMPT: return risk in (RiskLevel.SAFE, RiskLevel.LOW) else: # PERMISSIVE return risk != RiskLevel.CRITICAL - + def stats(self) -> ExecutorStats: """Get execution statistics.""" return self.registry.stats() - + def clear_cache(self) -> None: """Clear the result cache.""" self.registry.clear_cache() self._execution_cache.clear() - + def cache_size(self) -> int: """Get the current cache size.""" return len(self.registry._cache) - + def get_tools_for_llm(self) -> list: """Get tool specs for LLM.""" return self.registry.get_tools_for_llm() - + def get_plan(self) -> list: """Get the current execution plan.""" return self.registry.get_plan() diff --git a/src/core/loop.py b/src/core/loop.py index 93715ec..1e136e4 100644 --- a/src/core/loop.py +++ b/src/core/loop.py @@ -3,7 +3,7 @@ Implements the agentic loop that: 1. Receives instruction via --instruction argument -2. Calls LLM with tools (using litellm) +2. Calls LLM with tools (using Chutes API) 3. Executes tool calls 4. Loops until task is complete 5. 
Emits JSONL events throughout @@ -17,38 +17,33 @@ from __future__ import annotations -import time import sys +import time from pathlib import Path -from typing import Any, Dict, List, Optional, TYPE_CHECKING - -from src.llm.client import LLMError, CostLimitExceeded +from typing import TYPE_CHECKING, Any, Dict, List +from src.core.compaction import ( + manage_context, +) +from src.llm.client import CostLimitExceeded, LLMError from src.output.jsonl import ( - emit, - next_item_id, - reset_item_counter, + ItemCompletedEvent, + ItemStartedEvent, ThreadStartedEvent, - TurnStartedEvent, TurnCompletedEvent, TurnFailedEvent, - ItemStartedEvent, - ItemCompletedEvent, - ErrorEvent, + TurnStartedEvent, + emit, make_agent_message_item, make_command_execution_item, - make_file_change_item, + next_item_id, + reset_item_counter, ) from src.prompts.system import get_system_prompt -from src.utils.truncate import middle_out_truncate, APPROX_BYTES_PER_TOKEN -from src.core.compaction import ( - manage_context, - estimate_total_tokens, - needs_compaction, -) +from src.utils.truncate import middle_out_truncate if TYPE_CHECKING: - from src.llm.client import LiteLLMClient + from src.llm.client import LLMClient from src.tools.registry import ToolRegistry @@ -64,15 +59,12 @@ def _add_cache_control_to_message( ) -> Dict[str, Any]: """Add cache_control to a message, converting to multipart if needed.""" content = msg.get("content") - + if isinstance(content, list): - has_cache = any( - isinstance(p, dict) and "cache_control" in p - for p in content - ) + has_cache = any(isinstance(p, dict) and "cache_control" in p for p in content) if has_cache: return msg - + new_content = list(content) for i in range(len(new_content) - 1, -1, -1): part = new_content[i] @@ -80,7 +72,7 @@ def _add_cache_control_to_message( new_content[i] = {**part, "cache_control": cache_control} break return {**msg, "content": new_content} - + if isinstance(content, str): return { **msg, @@ -92,7 +84,7 @@ def _add_cache_control_to_message( } ], } - + return msg @@ -104,48 +96,48 @@ def _apply_caching( Apply prompt caching like OpenCode does: - Cache first 2 system messages (stable prefix) - Cache last 2 non-system messages (extends cache to cover conversation history) - + How Anthropic caching works: - Cache is based on IDENTICAL PREFIX - A cache_control breakpoint tells Anthropic to cache everything BEFORE it - By marking the last messages, we cache the entire conversation history - Each new request only adds new messages after the cached prefix - + Anthropic limits: - Maximum 4 cache_control breakpoints - Minimum tokens per breakpoint: 1024 (Sonnet), 4096 (Opus 4.5 on Bedrock) - + Reference: OpenCode transform.ts applyCaching() """ if not enabled or not messages: return messages - + cache_control = {"type": "ephemeral"} - + # Separate system and non-system message indices system_indices = [] non_system_indices = [] - + for i, msg in enumerate(messages): if msg.get("role") == "system": system_indices.append(i) else: non_system_indices.append(i) - + # Determine which messages to cache: # 1. First 2 system messages (stable system prompt) # 2. 
Last 2 non-system messages (extends cache to conversation history) # Total: up to 4 breakpoints (Anthropic limit) indices_to_cache = set() - + # Add first 2 system messages for idx in system_indices[:2]: indices_to_cache.add(idx) - + # Add last 2 non-system messages for idx in non_system_indices[-2:]: indices_to_cache.add(idx) - + # Build result with cache_control added to selected messages result = [] for i, msg in enumerate(messages): @@ -153,79 +145,83 @@ def _apply_caching( result.append(_add_cache_control_to_message(msg, cache_control)) else: result.append(msg) - + cached_system = len([i for i in indices_to_cache if i in system_indices]) cached_final = len([i for i in indices_to_cache if i in non_system_indices]) - + if indices_to_cache: - _log(f"Prompt caching: {cached_system} system + {cached_final} final messages marked ({len(indices_to_cache)} breakpoints)") - + _log( + f"Prompt caching: {cached_system} system + {cached_final} final messages marked ({len(indices_to_cache)} breakpoints)" + ) + return result def run_agent_loop( - llm: "LiteLLMClient", + llm: "LLMClient", tools: "ToolRegistry", ctx: Any, config: Dict[str, Any], ) -> None: """ Run the main agent loop. - + Args: - llm: LiteLLM client + llm: LLM client tools: Tool registry with available tools ctx: Agent context with instruction, shell(), done() config: Configuration dictionary """ # Reset item counter for fresh session reset_item_counter() - + # Generate session ID session_id = f"sess_{int(time.time() * 1000)}" - + # 1. Emit thread.started emit(ThreadStartedEvent(thread_id=session_id)) - + # 2. Emit turn.started emit(TurnStartedEvent()) - + # 3. Build initial messages cwd = Path(ctx.cwd) system_prompt = get_system_prompt(cwd=cwd) - + messages: List[Dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": ctx.instruction}, ] - + # 4. Get initial terminal state _log("Getting initial state...") initial_result = ctx.shell("pwd && ls -la") max_output_tokens = config.get("max_output_tokens", 2500) initial_state = middle_out_truncate(initial_result.output, max_tokens=max_output_tokens) - - messages.append({ - "role": "user", - "content": f"Current directory and files:\n```\n{initial_state}\n```", - }) - + + messages.append( + { + "role": "user", + "content": f"Current directory and files:\n```\n{initial_state}\n```", + } + ) + # 5. Initialize tracking total_input_tokens = 0 total_output_tokens = 0 total_cached_tokens = 0 pending_completion = False last_agent_message = "" - + max_iterations = config.get("max_iterations", 200) cache_enabled = config.get("cache_enabled", True) - + # 6. 
Main loop iteration = 0 while iteration < max_iterations: iteration += 1 _log(f"Iteration {iteration}/{max_iterations}") - + try: # ================================================================ # Context Management (replaces sliding window) @@ -236,27 +232,27 @@ def run_agent_loop( system_prompt=system_prompt, llm=llm, ) - + # If compaction happened, update our messages reference if len(context_messages) < len(messages): _log(f"Context compacted: {len(messages)} -> {len(context_messages)} messages") messages = context_messages - + # ================================================================ # Apply caching (system prompt only for stability) # ================================================================ cached_messages = _apply_caching(context_messages, enabled=cache_enabled) - + # Get tool specs tool_specs = tools.get_tools_for_llm() - + # ================================================================ # Call LLM with retry logic # ================================================================ max_retries = 5 response = None last_error = None - + for attempt in range(1, max_retries + 1): try: response = llm.chat( @@ -267,7 +263,7 @@ def run_agent_loop( "reasoning": {"effort": config.get("reasoning_effort", "xhigh")}, }, ) - + # Track token usage from response if hasattr(response, "tokens") and response.tokens: tokens = response.tokens @@ -275,84 +271,89 @@ def run_agent_loop( total_input_tokens += tokens.get("input", 0) total_output_tokens += tokens.get("output", 0) total_cached_tokens += tokens.get("cached", 0) - + break # Success, exit retry loop - + except CostLimitExceeded: raise # Don't retry cost limit errors - + except LLMError as e: last_error = e - error_msg = str(e.message) if hasattr(e, 'message') else str(e) + error_msg = str(e.message) if hasattr(e, "message") else str(e) _log(f"LLM error (attempt {attempt}/{max_retries}): {e.code} - {error_msg}") - + # Don't retry authentication errors if e.code in ("authentication_error", "invalid_api_key"): raise - + # Check if it's a retryable error - is_retryable = any(x in error_msg.lower() for x in [ - "504", "timeout", "empty response", "overloaded", "rate_limit" - ]) - + is_retryable = any( + x in error_msg.lower() + for x in ["504", "timeout", "empty response", "overloaded", "rate_limit"] + ) + if attempt < max_retries and is_retryable: wait_time = 10 * attempt # 10s, 20s, 30s, 40s _log(f"Retrying in {wait_time} seconds...") time.sleep(wait_time) else: raise - + except Exception as e: last_error = e error_msg = str(e) - _log(f"Unexpected error (attempt {attempt}/{max_retries}): {type(e).__name__}: {error_msg}") - + _log( + f"Unexpected error (attempt {attempt}/{max_retries}): {type(e).__name__}: {error_msg}" + ) + is_retryable = any(x in error_msg.lower() for x in ["504", "timeout"]) - + if attempt < max_retries and is_retryable: wait_time = 10 * attempt _log(f"Retrying in {wait_time} seconds...") time.sleep(wait_time) else: raise - + except CostLimitExceeded as e: _log(f"Cost limit exceeded: {e}") emit(TurnFailedEvent(error={"message": f"Cost limit exceeded: {e}"})) ctx.done() return - + except LLMError as e: _log(f"LLM error (fatal): {e.code} - {e.message}") emit(TurnFailedEvent(error={"message": str(e)})) ctx.done() return - + except Exception as e: _log(f"Unexpected error (fatal): {type(e).__name__}: {e}") emit(TurnFailedEvent(error={"message": str(e)})) ctx.done() return - + # Process response text response_text = response.text or "" - + if response_text: last_agent_message = response_text - + # Emit agent 
message item_id = next_item_id() - emit(ItemCompletedEvent( - item=make_agent_message_item(item_id, response_text) - )) - + emit(ItemCompletedEvent(item=make_agent_message_item(item_id, response_text))) + # Check for function calls - has_function_calls = response.has_function_calls() if hasattr(response, "has_function_calls") else bool(response.function_calls) - + has_function_calls = ( + response.has_function_calls() + if hasattr(response, "has_function_calls") + else bool(response.function_calls) + ) + if not has_function_calls: # No tool calls - agent thinks it's done _log("No tool calls in response") - + # Always do verification before completing (self-questioning) if pending_completion: # Agent already verified - complete the task @@ -362,7 +363,7 @@ def run_agent_loop( # First time without tool calls - ask for self-verification pending_completion = True messages.append({"role": "assistant", "content": response_text}) - + # Build verification prompt with original instruction verification_prompt = f""" # Self-Verification Required - CRITICAL @@ -405,96 +406,118 @@ def run_agent_loop( Proceed with verification now. """ - - messages.append({ - "role": "user", - "content": verification_prompt, - }) + + messages.append( + { + "role": "user", + "content": verification_prompt, + } + ) _log("Requesting self-verification before completion") continue - + # Reset pending completion flag (agent is still working) pending_completion = False - + # Add assistant message with tool calls assistant_msg: Dict[str, Any] = {"role": "assistant", "content": response_text} - + # Build tool_calls for message history tool_calls_data = [] for call in response.function_calls: - tool_calls_data.append({ - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": str(call.arguments) if isinstance(call.arguments, dict) else call.arguments, - }, - }) - + tool_calls_data.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": ( + str(call.arguments) + if isinstance(call.arguments, dict) + else call.arguments + ), + }, + } + ) + if tool_calls_data: assistant_msg["tool_calls"] = tool_calls_data - + messages.append(assistant_msg) - + # Execute each tool call for call in response.function_calls: tool_name = call.name tool_args = call.arguments if isinstance(call.arguments, dict) else {} - + _log(f"Executing tool: {tool_name}") - + # Emit item.started item_id = next_item_id() - emit(ItemStartedEvent( - item=make_command_execution_item( - item_id=item_id, - command=f"{tool_name}({tool_args})", - status="in_progress", + emit( + ItemStartedEvent( + item=make_command_execution_item( + item_id=item_id, + command=f"{tool_name}({tool_args})", + status="in_progress", + ) ) - )) - + ) + # Execute tool result = tools.execute(ctx, tool_name, tool_args) - + # Truncate output using middle-out (keeps beginning and end) output = middle_out_truncate(result.output, max_tokens=max_output_tokens) - + # Emit item.completed - emit(ItemCompletedEvent( - item=make_command_execution_item( - item_id=item_id, - command=f"{tool_name}", - status="completed" if result.success else "failed", - aggregated_output=output, - exit_code=0 if result.success else 1, + emit( + ItemCompletedEvent( + item=make_command_execution_item( + item_id=item_id, + command=f"{tool_name}", + status="completed" if result.success else "failed", + aggregated_output=output, + exit_code=0 if result.success else 1, + ) ) - )) - + ) + # Handle image injection if result.inject_content: # Add image to next 
user message - messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": f"Image from {tool_name}:"}, - result.inject_content, - ], - }) - + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": f"Image from {tool_name}:"}, + result.inject_content, + ], + } + ) + # Add tool result to messages - messages.append({ - "role": "tool", - "tool_call_id": call.id, - "content": output, - }) - + messages.append( + { + "role": "tool", + "tool_call_id": call.id, + "content": output, + } + ) + # 7. Emit turn.completed - emit(TurnCompletedEvent(usage={ - "input_tokens": total_input_tokens, - "cached_input_tokens": total_cached_tokens, - "output_tokens": total_output_tokens, - })) - + emit( + TurnCompletedEvent( + usage={ + "input_tokens": total_input_tokens, + "cached_input_tokens": total_cached_tokens, + "output_tokens": total_output_tokens, + } + ) + ) + _log(f"Loop complete after {iteration} iterations") - _log(f"Tokens: {total_input_tokens} input, {total_cached_tokens} cached, {total_output_tokens} output") + _log( + f"Tokens: {total_input_tokens} input, {total_cached_tokens} cached, {total_output_tokens} output" + ) ctx.done() diff --git a/src/core/session.py b/src/core/session.py index 6a93734..065c2ef 100644 --- a/src/core/session.py +++ b/src/core/session.py @@ -14,14 +14,15 @@ @dataclass class TokenUsage: """Token usage tracking.""" + input_tokens: int = 0 output_tokens: int = 0 cached_tokens: int = 0 - + @property def total_tokens(self) -> int: return self.input_tokens + self.output_tokens - + def add(self, other: "TokenUsage") -> None: """Add usage from another TokenUsage instance.""" self.input_tokens += other.input_tokens @@ -32,115 +33,120 @@ def add(self, other: "TokenUsage") -> None: @dataclass class Message: """A message in the conversation history.""" + role: str # "system", "user", "assistant", "tool" content: str tool_call_id: Optional[str] = None tool_calls: Optional[list[dict[str, Any]]] = None name: Optional[str] = None # For tool messages - + def to_dict(self) -> dict[str, Any]: """Convert to API format.""" msg: dict[str, Any] = {"role": self.role, "content": self.content} - + if self.tool_call_id: msg["tool_call_id"] = self.tool_call_id - + if self.tool_calls: msg["tool_calls"] = self.tool_calls - + if self.name: msg["name"] = self.name - + return msg @dataclass class Session: """Manages the state of an agent session.""" - + id: str = field(default_factory=lambda: str(uuid.uuid4())) config: AgentConfig = field(default_factory=AgentConfig) cwd: Path = field(default_factory=Path.cwd) - + # Conversation history messages: list[Message] = field(default_factory=list) - + # Token usage usage: TokenUsage = field(default_factory=TokenUsage) - + # Iteration tracking iteration: int = 0 - + # Timestamps started_at: datetime = field(default_factory=datetime.now) last_activity: datetime = field(default_factory=datetime.now) - + # Status is_done: bool = False final_message: Optional[str] = None - + def add_system_message(self, content: str) -> None: """Add a system message.""" self.messages.append(Message(role="system", content=content)) self._update_activity() - + def add_user_message(self, content: str) -> None: """Add a user message.""" self.messages.append(Message(role="user", content=content)) self._update_activity() - + def add_assistant_message( self, content: str, tool_calls: Optional[list[dict[str, Any]]] = None, ) -> None: """Add an assistant message.""" - self.messages.append(Message( - role="assistant", - content=content, - 
tool_calls=tool_calls, - )) + self.messages.append( + Message( + role="assistant", + content=content, + tool_calls=tool_calls, + ) + ) self._update_activity() - + def add_tool_result(self, tool_call_id: str, name: str, content: str) -> None: """Add a tool result message.""" - self.messages.append(Message( - role="tool", - content=content, - tool_call_id=tool_call_id, - name=name, - )) + self.messages.append( + Message( + role="tool", + content=content, + tool_call_id=tool_call_id, + name=name, + ) + ) self._update_activity() - + def get_messages_for_api(self) -> list[dict[str, Any]]: """Get messages formatted for the API.""" return [msg.to_dict() for msg in self.messages] - + def update_usage(self, input_tokens: int, output_tokens: int, cached_tokens: int = 0) -> None: """Update token usage.""" self.usage.input_tokens += input_tokens self.usage.output_tokens += output_tokens self.usage.cached_tokens += cached_tokens - + def increment_iteration(self) -> bool: """Increment iteration and check if we should continue. - + Returns: True if we can continue, False if max iterations reached """ self.iteration += 1 return self.iteration < self.config.max_iterations - + def mark_done(self, final_message: Optional[str] = None) -> None: """Mark the session as done.""" self.is_done = True self.final_message = final_message self._update_activity() - + def _update_activity(self) -> None: """Update last activity timestamp.""" self.last_activity = datetime.now() - + @property def elapsed_time(self) -> float: """Get elapsed time in seconds.""" diff --git a/src/exec/__init__.py b/src/exec/__init__.py index 7929338..9fa1293 100644 --- a/src/exec/__init__.py +++ b/src/exec/__init__.py @@ -5,16 +5,16 @@ """ from .runner import ( - OutputChunk, + DEFAULT_TIMEOUT, + MAX_OUTPUT_SIZE, + SENSITIVE_PATTERNS, ExecOptions, ExecOutput, + OutputChunk, + build_safe_environment, execute_command, execute_command_streaming, - build_safe_environment, truncate_output, - DEFAULT_TIMEOUT, - MAX_OUTPUT_SIZE, - SENSITIVE_PATTERNS, ) __all__ = [ diff --git a/src/exec/runner.py b/src/exec/runner.py index f264a19..8055564 100644 --- a/src/exec/runner.py +++ b/src/exec/runner.py @@ -15,7 +15,7 @@ from dataclasses import dataclass, field from enum import Enum, auto from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional # ============================================================================= # Constants @@ -27,12 +27,12 @@ # Patterns in variable names that indicate sensitive data (case-insensitive). # These will be excluded from the environment passed to child processes. SENSITIVE_PATTERNS: List[str] = [ - "KEY", # API_KEY, SSH_KEY, etc. - "SECRET", # AWS_SECRET, etc. - "TOKEN", # AUTH_TOKEN, etc. - "PASSWORD", # DB_PASSWORD, etc. - "CREDENTIAL", # GOOGLE_CREDENTIALS, etc. - "PRIVATE", # PRIVATE_KEY, etc. + "KEY", # API_KEY, SSH_KEY, etc. + "SECRET", # AWS_SECRET, etc. + "TOKEN", # AUTH_TOKEN, etc. + "PASSWORD", # DB_PASSWORD, etc. + "CREDENTIAL", # GOOGLE_CREDENTIALS, etc. + "PRIVATE", # PRIVATE_KEY, etc. ] @@ -40,8 +40,10 @@ # Output Types # ============================================================================= + class OutputChunkType(Enum): """Type of output chunk.""" + STDOUT = auto() STDERR = auto() @@ -49,26 +51,27 @@ class OutputChunkType(Enum): @dataclass class OutputChunk: """Output chunk from streaming execution. - + Represents a single chunk of output from either stdout or stderr. 
""" + chunk_type: OutputChunkType data: str - + @classmethod def stdout(cls, data: str) -> "OutputChunk": """Create a stdout chunk.""" return cls(chunk_type=OutputChunkType.STDOUT, data=data) - + @classmethod def stderr(cls, data: str) -> "OutputChunk": """Create a stderr chunk.""" return cls(chunk_type=OutputChunkType.STDERR, data=data) - + def is_stdout(self) -> bool: """Check if this is a stdout chunk.""" return self.chunk_type == OutputChunkType.STDOUT - + def is_stderr(self) -> bool: """Check if this is a stderr chunk.""" return self.chunk_type == OutputChunkType.STDERR @@ -78,21 +81,23 @@ def is_stderr(self) -> bool: # Options and Output # ============================================================================= + @dataclass class ExecOptions: """Options for command execution. - + Attributes: cwd: Working directory for command execution. timeout: Maximum execution time in seconds. env: Additional environment variables to set. capture_output: Whether to capture stdout/stderr. """ + cwd: Path = field(default_factory=Path.cwd) timeout: float = DEFAULT_TIMEOUT env: Dict[str, str] = field(default_factory=dict) capture_output: bool = True - + def __post_init__(self): """Ensure cwd is a Path object.""" if isinstance(self.cwd, str): @@ -102,7 +107,7 @@ def __post_init__(self): @dataclass class ExecOutput: """Output from command execution. - + Attributes: stdout: Standard output content. stderr: Standard error content. @@ -111,6 +116,7 @@ class ExecOutput: duration: Execution duration in seconds. timed_out: Whether the command timed out. """ + stdout: str stderr: str aggregated: str @@ -123,30 +129,31 @@ class ExecOutput: # Environment Building # ============================================================================= + def build_safe_environment(overrides: Optional[Dict[str, str]] = None) -> Dict[str, str]: """Build a safe environment for command execution. - + - Inherits ALL environment variables from parent process - Excludes variables containing sensitive patterns (KEY, SECRET, TOKEN, etc.) - Forces non-interactive mode for common tools - Applies any custom overrides - + Args: overrides: Custom environment variables to set (override filtering). - + Returns: Dictionary of safe environment variables. """ # Start with filtered parent environment env: Dict[str, str] = {} - + for key, value in os.environ.items(): # Exclude variables with sensitive patterns (case-insensitive) key_upper = key.upper() is_sensitive = any(pattern in key_upper for pattern in SENSITIVE_PATTERNS) if not is_sensitive: env[key] = value - + # Force non-interactive mode for common tools # This prevents commands from hanging waiting for user input env["CI"] = "true" # npm/yarn/pnpm/create-* use this @@ -155,12 +162,12 @@ def build_safe_environment(overrides: Optional[Dict[str, str]] = None) -> Dict[s env["YARN_ENABLE_IMMUTABLE_INSTALLS"] = "false" # yarn env["NO_COLOR"] = "1" # disable color codes env["TERM"] = "dumb" # simple terminal - + # Apply custom overrides if overrides: for key, value in overrides.items(): env[key] = value - + return env @@ -168,21 +175,22 @@ def build_safe_environment(overrides: Optional[Dict[str, str]] = None) -> Dict[s # Output Truncation # ============================================================================= + def truncate_output(data: bytes) -> str: """Truncate output if it exceeds MAX_OUTPUT_SIZE. - + Args: data: Raw bytes from subprocess output. - + Returns: Decoded string, truncated if necessary with a notice. 
""" # Decode with replacement for invalid UTF-8 text = data.decode("utf-8", errors="replace") - + if len(text) > MAX_OUTPUT_SIZE: return f"{text[:MAX_OUTPUT_SIZE]}...\n[Output truncated, {len(text)} bytes total]" - + return text @@ -190,22 +198,23 @@ def truncate_output(data: bytes) -> str: # Command Execution # ============================================================================= + async def execute_command( command: List[str], options: Optional[ExecOptions] = None, ) -> ExecOutput: """Execute a command with timeout and output capture. - + Args: command: Command and arguments as a list of strings. options: Execution options (uses defaults if not provided). - + Returns: ExecOutput containing stdout, stderr, exit code, duration, etc. """ if options is None: options = ExecOptions() - + # Handle empty command if not command: return ExecOutput( @@ -216,15 +225,15 @@ async def execute_command( duration=0.0, timed_out=False, ) - + program = command[0] args = command[1:] - + start_time = time.monotonic() - + # Build safe environment env = build_safe_environment(options.env) - + try: # Create subprocess process = await asyncio.create_subprocess_exec( @@ -236,21 +245,21 @@ async def execute_command( stderr=asyncio.subprocess.PIPE, env=env, ) - + try: # Wait for completion with timeout stdout_bytes, stderr_bytes = await asyncio.wait_for( process.communicate(), timeout=options.timeout, ) - + duration = time.monotonic() - start_time exit_code = process.returncode if process.returncode is not None else -1 - + # Truncate outputs if necessary stdout = truncate_output(stdout_bytes) stderr = truncate_output(stderr_bytes) - + # Build aggregated output aggregated_parts = [] if stdout: @@ -258,7 +267,7 @@ async def execute_command( if stderr: aggregated_parts.append(stderr) aggregated = "\n".join(aggregated_parts) - + return ExecOutput( stdout=stdout, stderr=stderr, @@ -267,17 +276,17 @@ async def execute_command( duration=duration, timed_out=False, ) - + except asyncio.TimeoutError: # Timeout - kill the process duration = time.monotonic() - start_time - + try: process.kill() await process.wait() except ProcessLookupError: pass # Process already terminated - + return ExecOutput( stdout="", stderr="", @@ -286,7 +295,7 @@ async def execute_command( duration=duration, timed_out=True, ) - + except FileNotFoundError: duration = time.monotonic() - start_time return ExecOutput( @@ -325,20 +334,20 @@ async def execute_command_streaming( callback: Optional[Callable[[OutputChunk], None]] = None, ) -> ExecOutput: """Execute a command with streaming output. - + Reads stdout and stderr line by line, calling the callback for each chunk. - + Args: command: Command and arguments as a list of strings. options: Execution options (uses defaults if not provided). callback: Function called with each OutputChunk as it arrives. - + Returns: ExecOutput containing full stdout, stderr, exit code, duration, etc. 
""" if options is None: options = ExecOptions() - + # Handle empty command if not command: return ExecOutput( @@ -349,19 +358,19 @@ async def execute_command_streaming( duration=0.0, timed_out=False, ) - + program = command[0] args = command[1:] - + start_time = time.monotonic() - + # Build safe environment env = build_safe_environment(options.env) - + # Accumulators stdout_acc: List[str] = [] stderr_acc: List[str] = [] - + try: # Create subprocess process = await asyncio.create_subprocess_exec( @@ -373,7 +382,7 @@ async def execute_command_streaming( stderr=asyncio.subprocess.PIPE, env=env, ) - + async def read_stdout(): """Read stdout line by line.""" if process.stdout is None: @@ -386,7 +395,7 @@ async def read_stdout(): stdout_acc.append(decoded) if callback: callback(OutputChunk.stdout(decoded)) - + async def read_stderr(): """Read stderr line by line.""" if process.stderr is None: @@ -399,7 +408,7 @@ async def read_stderr(): stderr_acc.append(decoded) if callback: callback(OutputChunk.stderr(decoded)) - + try: # Read streams concurrently with timeout await asyncio.wait_for( @@ -410,20 +419,24 @@ async def read_stderr(): ), timeout=options.timeout, ) - + duration = time.monotonic() - start_time exit_code = process.returncode if process.returncode is not None else -1 - + # Join accumulated output stdout = "".join(stdout_acc) stderr = "".join(stderr_acc) - + # Truncate if necessary if len(stdout) > MAX_OUTPUT_SIZE: - stdout = f"{stdout[:MAX_OUTPUT_SIZE]}...\n[Output truncated, {len(stdout)} bytes total]" + stdout = ( + f"{stdout[:MAX_OUTPUT_SIZE]}...\n[Output truncated, {len(stdout)} bytes total]" + ) if len(stderr) > MAX_OUTPUT_SIZE: - stderr = f"{stderr[:MAX_OUTPUT_SIZE]}...\n[Output truncated, {len(stderr)} bytes total]" - + stderr = ( + f"{stderr[:MAX_OUTPUT_SIZE]}...\n[Output truncated, {len(stderr)} bytes total]" + ) + # Build aggregated output aggregated_parts = [] if stdout: @@ -431,7 +444,7 @@ async def read_stderr(): if stderr: aggregated_parts.append(stderr) aggregated = "\n".join(aggregated_parts) - + return ExecOutput( stdout=stdout, stderr=stderr, @@ -440,21 +453,21 @@ async def read_stderr(): duration=duration, timed_out=False, ) - + except asyncio.TimeoutError: # Timeout - kill the process duration = time.monotonic() - start_time - + try: process.kill() await process.wait() except ProcessLookupError: pass - + # Return what we accumulated before timeout stdout = "".join(stdout_acc) stderr = "".join(stderr_acc) - + return ExecOutput( stdout=stdout, stderr=stderr, @@ -463,7 +476,7 @@ async def read_stderr(): duration=duration, timed_out=True, ) - + except FileNotFoundError: duration = time.monotonic() - start_time return ExecOutput( @@ -500,16 +513,17 @@ async def read_stderr(): # Synchronous Wrappers (convenience) # ============================================================================= + def execute_command_sync( command: List[str], options: Optional[ExecOptions] = None, ) -> ExecOutput: """Synchronous wrapper for execute_command. - + Args: command: Command and arguments as a list of strings. options: Execution options. - + Returns: ExecOutput containing stdout, stderr, exit code, etc. """ @@ -522,12 +536,12 @@ def execute_command_streaming_sync( callback: Optional[Callable[[OutputChunk], None]] = None, ) -> ExecOutput: """Synchronous wrapper for execute_command_streaming. - + Args: command: Command and arguments as a list of strings. options: Execution options. callback: Function called with each OutputChunk. 
- + Returns: ExecOutput containing stdout, stderr, exit code, etc. """ diff --git a/src/images/__init__.py b/src/images/__init__.py index 770cc8d..c662b3c 100644 --- a/src/images/__init__.py +++ b/src/images/__init__.py @@ -1,11 +1,11 @@ """Image handling module for SuperAgent.""" from src.images.loader import ( + MAX_HEIGHT, + MAX_WIDTH, load_image_as_data_uri, load_image_bytes, resize_image, - MAX_WIDTH, - MAX_HEIGHT, ) __all__ = [ diff --git a/src/images/loader.py b/src/images/loader.py index cf0179d..407eeef 100644 --- a/src/images/loader.py +++ b/src/images/loader.py @@ -17,7 +17,7 @@ from functools import lru_cache from io import BytesIO from pathlib import Path -from typing import Optional, Tuple +from typing import Tuple # Maximum image dimensions (like Codex) MAX_WIDTH = 2048 @@ -26,6 +26,7 @@ # Try to import PIL for image processing try: from PIL import Image + HAS_PIL = True except ImportError: HAS_PIL = False @@ -55,10 +56,10 @@ def _file_hash(path: Path) -> str: def load_image_bytes(path: Path) -> Tuple[bytes, str]: """ Load image bytes from disk. - + Args: path: Path to image file - + Returns: Tuple of (bytes, mime_type) """ @@ -76,33 +77,33 @@ def resize_image( ) -> Tuple[bytes, str]: """ Resize image if it exceeds max dimensions. - + Args: data: Image bytes mime: MIME type max_width: Maximum width max_height: Maximum height - + Returns: Tuple of (resized_bytes, mime_type) """ if not HAS_PIL: # Can't resize without PIL, return as-is return data, mime - + try: img = Image.open(BytesIO(data)) - + # Check if resize needed if img.width <= max_width and img.height <= max_height: return data, mime - + # Resize maintaining aspect ratio img.thumbnail((max_width, max_height), Image.Resampling.LANCZOS) - + # Encode back to bytes output = BytesIO() - + # Use PNG for transparency, JPEG for photos if img.mode in ("RGBA", "LA") or mime == "image/png": img.save(output, format="PNG", optimize=True) @@ -113,7 +114,7 @@ def resize_image( img = img.convert("RGB") img.save(output, format="JPEG", quality=85, optimize=True) return output.getvalue(), "image/jpeg" - + except Exception: # On any error, return original return data, mime @@ -123,46 +124,46 @@ def resize_image( def _load_cached(path_str: str, file_hash: str) -> str: """Load and cache image as data URI (internal).""" path = Path(path_str) - + # Load raw bytes data, mime = load_image_bytes(path) - + # Resize if needed data, mime = resize_image(data, mime) - + # Encode as base64 b64 = base64.b64encode(data).decode("ascii") - + return f"data:{mime};base64,{b64}" def load_image_as_data_uri(path: Path) -> str: """ Load image, resize if needed, encode as base64 data URI. - + Uses LRU cache based on file path and content hash. - + Args: path: Path to image file - + Returns: Data URI string (data:image/png;base64,...) - + Raises: FileNotFoundError: If image doesn't exist ValueError: If file is not a valid image """ path = Path(path).resolve() - + if not path.exists(): raise FileNotFoundError(f"Image not found: {path}") - + if not path.is_file(): raise ValueError(f"Not a file: {path}") - + # Get file hash for cache key file_hash = _file_hash(path) - + # Load with caching return _load_cached(str(path), file_hash) @@ -170,10 +171,10 @@ def load_image_as_data_uri(path: Path) -> str: def make_image_content(data_uri: str) -> dict: """ Create image content block for LLM API. 
- + Args: data_uri: Base64 data URI - + Returns: Content block dict for API """ diff --git a/src/llm/__init__.py b/src/llm/__init__.py index 8a00106..39affcb 100644 --- a/src/llm/__init__.py +++ b/src/llm/__init__.py @@ -1,5 +1,12 @@ -"""LLM module using litellm.""" +"""LLM module using httpx for Chutes API.""" -from .client import LiteLLMClient, LLMResponse, FunctionCall, CostLimitExceeded, LLMError +from .client import CostLimitExceeded, FunctionCall, LiteLLMClient, LLMClient, LLMError, LLMResponse -__all__ = ["LiteLLMClient", "LLMResponse", "FunctionCall", "CostLimitExceeded", "LLMError"] +__all__ = [ + "LLMClient", + "LiteLLMClient", + "LLMResponse", + "FunctionCall", + "CostLimitExceeded", + "LLMError", +] diff --git a/src/llm/client.py b/src/llm/client.py index 72e048b..2bf0f9d 100644 --- a/src/llm/client.py +++ b/src/llm/client.py @@ -1,17 +1,18 @@ -"""LLM Client using litellm - replaces term_sdk dependency.""" +"""LLM Client using httpx for Chutes API (OpenAI-compatible).""" from __future__ import annotations import json import os -import sys -import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional +import httpx + class CostLimitExceeded(Exception): """Raised when cost limit is exceeded.""" + def __init__(self, message: str, used: float = 0, limit: float = 0): super().__init__(message) self.used = used @@ -20,6 +21,7 @@ def __init__(self, message: str, used: float = 0, limit: float = 0): class LLMError(Exception): """LLM API error.""" + def __init__(self, message: str, code: str = "unknown"): super().__init__(message) self.message = message @@ -29,21 +31,22 @@ def __init__(self, message: str, code: str = "unknown"): @dataclass class FunctionCall: """Represents a function/tool call from the LLM.""" + id: str name: str arguments: Dict[str, Any] - + @classmethod def from_openai(cls, call: Dict[str, Any]) -> "FunctionCall": """Parse from OpenAI tool_calls format.""" func = call.get("function", {}) args_str = func.get("arguments", "{}") - + try: args = json.loads(args_str) except json.JSONDecodeError: args = {"raw": args_str} - + return cls( id=call.get("id", ""), name=func.get("name", ""), @@ -54,49 +57,67 @@ def from_openai(cls, call: Dict[str, Any]) -> "FunctionCall": @dataclass class LLMResponse: """Response from the LLM.""" + text: str = "" function_calls: List[FunctionCall] = field(default_factory=list) tokens: Optional[Dict[str, int]] = None model: str = "" finish_reason: str = "" raw: Optional[Dict[str, Any]] = None - + def has_function_calls(self) -> bool: """Check if response contains function calls.""" return len(self.function_calls) > 0 -class LiteLLMClient: - """LLM Client using litellm.""" - +class LLMClient: + """LLM Client using httpx for Chutes API (OpenAI-compatible format).""" + + # Default Chutes API configuration + DEFAULT_BASE_URL = "https://api.chutes.ai/v1" + DEFAULT_API_KEY_ENV = "CHUTES_API_KEY" + def __init__( self, model: str, temperature: Optional[float] = None, max_tokens: int = 16384, cost_limit: Optional[float] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + timeout: float = 120.0, ): self.model = model self.temperature = temperature self.max_tokens = max_tokens self.cost_limit = cost_limit or float(os.environ.get("LLM_COST_LIMIT", "10.0")) - + self.base_url = base_url or os.environ.get("CHUTES_BASE_URL", self.DEFAULT_BASE_URL) + self.timeout = timeout + + # Get API key + self._api_key = api_key or os.environ.get(self.DEFAULT_API_KEY_ENV) + if not self._api_key: + raise ValueError( + 
f"API key required. Set {self.DEFAULT_API_KEY_ENV} environment variable or pass api_key parameter." + ) + self._total_cost = 0.0 self._total_tokens = 0 self._request_count = 0 self._input_tokens = 0 self._output_tokens = 0 self._cached_tokens = 0 - - # Import litellm - try: - import litellm - self._litellm = litellm - # Configure litellm - litellm.drop_params = True # Drop unsupported params silently - except ImportError: - raise ImportError("litellm not installed. Run: pip install litellm") - + + # Create httpx client with timeout + self._client = httpx.Client( + base_url=self.base_url, + headers={ + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + }, + timeout=httpx.Timeout(timeout, connect=30.0), + ) + def _supports_temperature(self, model: str) -> bool: """Check if model supports temperature parameter.""" model_lower = model.lower() @@ -104,32 +125,35 @@ def _supports_temperature(self, model: str) -> bool: if any(x in model_lower for x in ["o1", "o3", "deepseek-r1"]): return False return True - + def _build_tools(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]: """Build tools in OpenAI format.""" if not tools: return None - + result = [] for tool in tools: - result.append({ - "type": "function", - "function": { - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool.get("parameters", {"type": "object", "properties": {}}), - }, - }) + result.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {"type": "object", "properties": {}}), + }, + } + ) return result - + def chat( self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, max_tokens: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, + model: Optional[str] = None, ) -> LLMResponse: - """Send a chat request.""" + """Send a chat request to Chutes API.""" # Check cost limit if self._total_cost >= self.cost_limit: raise CostLimitExceeded( @@ -137,102 +161,145 @@ def chat( used=self._total_cost, limit=self.cost_limit, ) - - # Build request - kwargs: Dict[str, Any] = { - "model": self.model, - "messages": messages, + + # Build request payload + payload: Dict[str, Any] = { + "model": model or self.model, + "messages": self._prepare_messages(messages), "max_tokens": max_tokens or self.max_tokens, } - - if self._supports_temperature(self.model) and self.temperature is not None: - kwargs["temperature"] = self.temperature - + + if self._supports_temperature(payload["model"]) and self.temperature is not None: + payload["temperature"] = self.temperature + if tools: - kwargs["tools"] = self._build_tools(tools) - kwargs["tool_choice"] = "auto" - - # Add extra body params (like reasoning effort) + payload["tools"] = self._build_tools(tools) + payload["tool_choice"] = "auto" + + # Add extra body params (like reasoning effort) - some may be ignored by API if extra_body: - kwargs.update(extra_body) - + payload.update(extra_body) + try: - response = self._litellm.completion(**kwargs) + response = self._client.post("/chat/completions", json=payload) self._request_count += 1 - except Exception as e: - error_msg = str(e) - if "authentication" in error_msg.lower() or "api_key" in error_msg.lower(): - raise LLMError(error_msg, code="authentication_error") - elif "rate" in error_msg.lower() or "limit" in error_msg.lower(): - raise LLMError(error_msg, code="rate_limit") - else: - raise LLMError(error_msg, 
code="api_error") - + + # Handle HTTP errors + if response.status_code != 200: + error_body = response.text + try: + error_json = response.json() + error_msg = error_json.get("error", {}).get("message", error_body) + except (json.JSONDecodeError, KeyError): + error_msg = error_body + + # Map status codes to error codes + if response.status_code == 401: + raise LLMError(error_msg, code="authentication_error") + elif response.status_code == 429: + raise LLMError(error_msg, code="rate_limit") + elif response.status_code >= 500: + raise LLMError(error_msg, code="server_error") + else: + raise LLMError(f"HTTP {response.status_code}: {error_msg}", code="api_error") + + data = response.json() + + except httpx.TimeoutException as e: + raise LLMError(f"Request timed out: {e}", code="timeout") + except httpx.ConnectError as e: + raise LLMError(f"Connection error: {e}", code="connection_error") + except httpx.HTTPError as e: + raise LLMError(f"HTTP error: {e}", code="api_error") + # Parse response - result = LLMResponse(raw=response.model_dump() if hasattr(response, "model_dump") else None) - + result = LLMResponse(raw=data) + # Extract usage - if hasattr(response, "usage") and response.usage: - usage = response.usage - input_tokens = getattr(usage, "prompt_tokens", 0) or 0 - output_tokens = getattr(usage, "completion_tokens", 0) or 0 + usage = data.get("usage", {}) + if usage: + input_tokens = usage.get("prompt_tokens", 0) or 0 + output_tokens = usage.get("completion_tokens", 0) or 0 cached_tokens = 0 - - # Check for cached tokens - if hasattr(usage, "prompt_tokens_details"): - details = usage.prompt_tokens_details - if details and hasattr(details, "cached_tokens"): - cached_tokens = details.cached_tokens or 0 - + + # Check for cached tokens (OpenAI format) + prompt_details = usage.get("prompt_tokens_details", {}) + if prompt_details: + cached_tokens = prompt_details.get("cached_tokens", 0) or 0 + self._input_tokens += input_tokens self._output_tokens += output_tokens self._cached_tokens += cached_tokens self._total_tokens += input_tokens + output_tokens - + result.tokens = { "input": input_tokens, "output": output_tokens, "cached": cached_tokens, } - - # Calculate cost using litellm - try: - cost = self._litellm.completion_cost(completion_response=response) + + # Estimate cost (generic pricing, adjust per model if needed) + # Using conservative estimates: $3/1M input, $15/1M output + cost = (input_tokens * 3.0 / 1_000_000) + (output_tokens * 15.0 / 1_000_000) self._total_cost += cost - except Exception: - pass # Cost calculation may fail for some models - + # Extract model - result.model = getattr(response, "model", self.model) - + result.model = data.get("model", self.model) + # Extract choices - if hasattr(response, "choices") and response.choices: - choice = response.choices[0] - message = choice.message - - result.finish_reason = getattr(choice, "finish_reason", "") or "" - result.text = getattr(message, "content", "") or "" - + choices = data.get("choices", []) + if choices: + choice = choices[0] + message = choice.get("message", {}) + + result.finish_reason = choice.get("finish_reason", "") or "" + result.text = message.get("content", "") or "" + # Extract function calls - tool_calls = getattr(message, "tool_calls", None) + tool_calls = message.get("tool_calls", []) if tool_calls: for call in tool_calls: - if hasattr(call, "function"): - func = call.function - args_str = getattr(func, "arguments", "{}") - try: - args = json.loads(args_str) if isinstance(args_str, str) else args_str - except 
json.JSONDecodeError: - args = {"raw": args_str} - - result.function_calls.append(FunctionCall( - id=getattr(call, "id", "") or "", - name=getattr(func, "name", "") or "", + func = call.get("function", {}) + args_str = func.get("arguments", "{}") + + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except json.JSONDecodeError: + args = {"raw": args_str} + + result.function_calls.append( + FunctionCall( + id=call.get("id", "") or "", + name=func.get("name", "") or "", arguments=args if isinstance(args, dict) else {}, - )) - + ) + ) + return result - + + def _prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare messages for the API, cleaning up any incompatible fields.""" + prepared = [] + for msg in messages: + new_msg = dict(msg) + + # Handle content with cache_control (Anthropic-specific, strip for OpenAI compat) + content = new_msg.get("content") + if isinstance(content, list): + # Convert multipart format, removing cache_control + cleaned_parts = [] + for part in content: + if isinstance(part, dict): + cleaned_part = {k: v for k, v in part.items() if k != "cache_control"} + cleaned_parts.append(cleaned_part) + else: + cleaned_parts.append(part) + new_msg["content"] = cleaned_parts + + prepared.append(new_msg) + + return prepared + def get_stats(self) -> Dict[str, Any]: """Get usage statistics.""" return { @@ -243,7 +310,18 @@ def get_stats(self) -> Dict[str, Any]: "total_cost": self._total_cost, "request_count": self._request_count, } - + def close(self): - """Close client (no-op for litellm).""" - pass + """Close the HTTP client.""" + self._client.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + +# Alias for backward compatibility +LiteLLMClient = LLMClient diff --git a/src/main.py b/src/main.py index b73454f..43ccf33 100644 --- a/src/main.py +++ b/src/main.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import sys from pathlib import Path from typing import Optional @@ -11,8 +10,8 @@ from rich.console import Console from src import __version__ -from src.config.loader import load_config, find_config_file -from src.config.models import AgentConfig, Provider, OutputMode +from src.config.loader import find_config_file, load_config +from src.config.models import OutputMode, Provider from src.core.agent import Agent from src.output.processor import OutputProcessor @@ -49,11 +48,9 @@ def main( @app.command("exec") def exec_command( prompt: str = typer.Argument(..., help="The task/prompt for the agent"), - # Model/Provider options model: Optional[str] = typer.Option(None, "--model", "-m", help="Model to use"), provider: Optional[Provider] = typer.Option(None, "--provider", "-p", help="LLM provider"), - # Config options config_file: Optional[Path] = typer.Option( None, @@ -61,27 +58,24 @@ def exec_command( "-c", help="Path to config file", ), - # Output options json_mode: bool = typer.Option(False, "--json", help="Output in JSONL format"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), - # Execution options workdir: Optional[Path] = typer.Option(None, "--workdir", "-w", help="Working directory"), max_iterations: Optional[int] = typer.Option(None, help="Maximum iterations"), - # Danger mode (compatibility with Codex CLI) dangerously_bypass_approvals: bool = typer.Option( - False, + False, "--dangerously-bypass-approvals-and-sandbox", help="Run without sandbox/approvals (default behavior in 
SuperAgent)", ), ): """Execute a task with the agent.""" - + # Load configuration config_path = config_file or find_config_file() - + overrides = {} if model: overrides["model"] = model @@ -93,42 +87,42 @@ def exec_command( overrides["max_iterations"] = max_iterations if workdir: overrides["paths.cwd"] = str(workdir) - + try: config = load_config(config_path, overrides) except Exception as e: console.print(f"[red]Error loading configuration: {e}[/red]") raise typer.Exit(1) - + # Setup working directory cwd = Path(config.paths.cwd or os.getcwd()).resolve() if not cwd.exists(): console.print(f"[red]Working directory does not exist: {cwd}[/red]") raise typer.Exit(1) - + # Initialize output processor output = OutputProcessor(config) - + # Run agent try: agent = Agent(config=config, cwd=cwd, output_processor=output) - + # In JSON mode, we don't print "Starting..." messages to stdout if not json_mode: console.print(f"[bold blue]SuperAgent v{__version__}[/bold blue]") console.print(f"Model: [cyan]{config.model}[/cyan] ({config.provider})") console.print(f"Working directory: [cyan]{cwd}[/cyan]") console.print() - + final_message = agent.run(prompt) - + # In human mode, print the final message clearly if not json_mode and final_message: console.print() console.print("[bold green]Final Result:[/bold green]") output.print_final(final_message) - - except Exception as e: + + except Exception: if verbose: console.print_exception() else: @@ -149,7 +143,7 @@ def show_config( else: console.print("No config file found, using defaults") config = load_config() - + console.print(config.model_dump_json(indent=2)) diff --git a/src/output/__init__.py b/src/output/__init__.py index 0b96b30..9dc5ec2 100644 --- a/src/output/__init__.py +++ b/src/output/__init__.py @@ -1,44 +1,43 @@ """Output module - JSONL event emission and streaming.""" from src.output.jsonl import ( - emit, - emit_raw, - next_item_id, - reset_item_counter, + ErrorEvent, + ItemCompletedEvent, + ItemStartedEvent, + ItemUpdatedEvent, ThreadStartedEvent, - TurnStartedEvent, TurnCompletedEvent, TurnFailedEvent, - ItemStartedEvent, - ItemUpdatedEvent, - ItemCompletedEvent, - ErrorEvent, + TurnStartedEvent, + emit, + emit_raw, make_agent_message_item, make_command_execution_item, + make_error_item, make_file_change_item, make_todo_list_item, - make_error_item, + next_item_id, + reset_item_counter, ) - from src.output.streaming import ( - StreamState, - StreamEvent, - StartEvent, - TextDeltaEvent, - ToolCallStartEvent, - ToolCallDeltaEvent, - ToolCallCompleteEvent, - TokenUsageEvent, CompleteEvent, + SentenceBuffer, + StartEvent, + StreamBuffer, + StreamCollector, StreamContent, + StreamEvent, + StreamProcessor, + StreamState, + StreamStats, StreamToolCall, + TextDeltaEvent, TokenCounts, - StreamStats, - StreamProcessor, - StreamBuffer, + TokenUsageEvent, + ToolCallCompleteEvent, + ToolCallDeltaEvent, + ToolCallStartEvent, WordBuffer, - SentenceBuffer, - StreamCollector, ) __all__ = [ diff --git a/src/output/events.py b/src/output/events.py index dac2c7c..5c6805d 100644 --- a/src/output/events.py +++ b/src/output/events.py @@ -10,30 +10,32 @@ class EventType(str, Enum): """Types of events that can be emitted.""" + TURN_STARTED = "turn.started" TURN_COMPLETED = "turn.completed" TURN_FAILED = "turn.failed" - + ITEM_STARTED = "item.started" ITEM_UPDATED = "item.updated" ITEM_COMPLETED = "item.completed" - + MESSAGE = "message" THINKING = "thinking" - + TOOL_CALL_START = "tool.call.start" TOOL_CALL_END = "tool.call.end" - + ERROR = "error" @dataclass class 
Event: """An event from the agent.""" + type: EventType timestamp: datetime = field(default_factory=datetime.now) data: dict[str, Any] = field(default_factory=dict) - + def to_dict(self) -> dict[str, Any]: """Convert to JSON-serializable dict.""" return { @@ -41,7 +43,7 @@ def to_dict(self) -> dict[str, Any]: "timestamp": self.timestamp.isoformat(), **self.data, } - + @classmethod def turn_started(cls, session_id: str) -> "Event": """Create a turn started event.""" @@ -49,7 +51,7 @@ def turn_started(cls, session_id: str) -> "Event": type=EventType.TURN_STARTED, data={"session_id": session_id}, ) - + @classmethod def turn_completed( cls, @@ -72,7 +74,7 @@ def turn_completed( }, }, ) - + @classmethod def message(cls, content: str, role: str = "assistant") -> "Event": """Create a message event.""" @@ -80,12 +82,12 @@ def message(cls, content: str, role: str = "assistant") -> "Event": type=EventType.MESSAGE, data={"content": content, "role": role}, ) - + @classmethod def thinking(cls) -> "Event": """Create a thinking event.""" return cls(type=EventType.THINKING) - + @classmethod def tool_call_start(cls, name: str, arguments: dict[str, Any]) -> "Event": """Create a tool call start event.""" @@ -93,7 +95,7 @@ def tool_call_start(cls, name: str, arguments: dict[str, Any]) -> "Event": type=EventType.TOOL_CALL_START, data={"name": name, "arguments": arguments}, ) - + @classmethod def tool_call_end( cls, @@ -112,7 +114,7 @@ def tool_call_end( "error": error, }, ) - + @classmethod def error(cls, message: str, details: Optional[dict[str, Any]] = None) -> "Event": """Create an error event.""" diff --git a/src/output/jsonl.py b/src/output/jsonl.py index 4e85090..5d164bf 100644 --- a/src/output/jsonl.py +++ b/src/output/jsonl.py @@ -18,19 +18,18 @@ from __future__ import annotations import json -import sys -import time -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional - # ============================================================================= # Thread Events # ============================================================================= + @dataclass class ThreadStartedEvent: """Emitted when a new thread/session is started.""" + thread_id: str type: str = field(default="thread.started", init=False) @@ -39,15 +38,18 @@ class ThreadStartedEvent: # Turn Events # ============================================================================= + @dataclass class TurnStartedEvent: """Emitted when a turn is started (user sends message).""" + type: str = field(default="turn.started", init=False) @dataclass class Usage: """Token usage statistics.""" + input_tokens: int = 0 cached_input_tokens: int = 0 output_tokens: int = 0 @@ -56,6 +58,7 @@ class Usage: @dataclass class TurnCompletedEvent: """Emitted when a turn is completed successfully.""" + usage: Dict[str, int] type: str = field(default="turn.completed", init=False) @@ -63,6 +66,7 @@ class TurnCompletedEvent: @dataclass class TurnFailedEvent: """Emitted when a turn fails.""" + error: Dict[str, str] type: str = field(default="turn.failed", init=False) @@ -71,9 +75,11 @@ class TurnFailedEvent: # Item Events # ============================================================================= + @dataclass class ItemStartedEvent: """Emitted when an item starts processing.""" + item: Dict[str, Any] type: str = field(default="item.started", init=False) @@ -81,6 +87,7 @@ class ItemStartedEvent: @dataclass class ItemUpdatedEvent: """Emitted when an item is updated (e.g., todo 
list).""" + item: Dict[str, Any] type: str = field(default="item.updated", init=False) @@ -88,6 +95,7 @@ class ItemUpdatedEvent: @dataclass class ItemCompletedEvent: """Emitted when an item completes processing.""" + item: Dict[str, Any] type: str = field(default="item.completed", init=False) @@ -96,9 +104,11 @@ class ItemCompletedEvent: # Error Events # ============================================================================= + @dataclass class ErrorEvent: """Emitted for fatal errors.""" + message: str type: str = field(default="error", init=False) @@ -107,6 +117,7 @@ class ErrorEvent: # Item Types (for item.started/completed payloads) # ============================================================================= + def make_agent_message_item(item_id: str, text: str) -> Dict[str, Any]: """Create an agent_message item.""" return { @@ -203,7 +214,7 @@ def reset_item_counter() -> None: def emit(event) -> None: """ Emit a single JSONL event to stdout. - + Args: event: Dataclass event to emit """ @@ -220,7 +231,7 @@ def emit(event) -> None: def emit_raw(data: Dict[str, Any]) -> None: """ Emit a raw dictionary as JSONL. - + Args: data: Dictionary to emit """ diff --git a/src/output/processor.py b/src/output/processor.py index 274f783..00a32dd 100644 --- a/src/output/processor.py +++ b/src/output/processor.py @@ -8,8 +8,6 @@ from rich.console import Console from rich.panel import Panel -from rich.syntax import Syntax -from rich.text import Text from src.config.models import AgentConfig, OutputMode from src.core.session import Session @@ -19,7 +17,7 @@ class OutputProcessor: """Processes and formats agent output.""" - + def __init__( self, config: AgentConfig, @@ -27,7 +25,7 @@ def __init__( stderr: TextIO = sys.stderr, ): """Initialize the output processor. - + Args: config: Agent configuration stdout: Standard output stream @@ -36,20 +34,20 @@ def __init__( self.config = config self.stdout = stdout self.stderr = stderr - + # Rich console for human-readable output self.console = Console( file=stderr, force_terminal=config.output.colors, no_color=not config.output.colors, ) - + # JSON mode outputs to stdout self.json_mode = config.output.mode == OutputMode.JSON - + def emit(self, event: Event) -> None: """Emit an event. 
- + Args: event: Event to emit """ @@ -57,19 +55,19 @@ def emit(self, event: Event) -> None: self._emit_json(event) else: self._emit_human(event) - + def _emit_json(self, event: Event) -> None: """Emit event as JSON line to stdout.""" line = json.dumps(event.to_dict()) print(line, file=self.stdout, flush=True) - + def _emit_human(self, event: Event) -> None: """Emit event in human-readable format to stderr.""" from src.output.events import EventType - + if event.type == EventType.TURN_STARTED: self.console.print("[dim]Session started[/dim]") - + elif event.type == EventType.TURN_COMPLETED: usage = event.data.get("usage", {}) self.console.print() @@ -78,28 +76,28 @@ def _emit_human(self, event: Event) -> None: f"{usage.get('output_tokens', 0)} out " f"(cached: {usage.get('cached_tokens', 0)})[/dim]" ) - + elif event.type == EventType.MESSAGE: content = event.data.get("content", "") if content: self.console.print() self.console.print(Panel(content, border_style="blue")) - + elif event.type == EventType.THINKING: self.console.print("[dim]Thinking...[/dim]", end="\r") - + elif event.type == EventType.TOOL_CALL_START: name = event.data.get("name", "unknown") self.console.print(f"[yellow]> {name}[/yellow]") - + elif event.type == EventType.TOOL_CALL_END: name = event.data.get("name", "unknown") success = event.data.get("success", False) output = event.data.get("output", "") - + status = "[green]OK[/green]" if success else "[red]FAILED[/red]" self.console.print(f"[dim] {status}[/dim]") - + # Show truncated output if output: lines = output.split("\n") @@ -108,56 +106,60 @@ def _emit_human(self, event: Event) -> None: else: display = output self.console.print(f"[dim]{display}[/dim]") - + elif event.type == EventType.ERROR: message = event.data.get("message", "Unknown error") self.console.print(f"[red]Error: {message}[/red]") - + # Convenience methods - + def emit_turn_started(self, session: Session) -> None: """Emit turn started event.""" self.emit(Event.turn_started(session.id)) - + def emit_turn_completed(self, session: Session, final_message: str) -> None: """Emit turn completed event.""" - self.emit(Event.turn_completed( - session_id=session.id, - final_message=final_message, - input_tokens=session.usage.input_tokens, - output_tokens=session.usage.output_tokens, - cached_tokens=session.usage.cached_tokens, - )) - + self.emit( + Event.turn_completed( + session_id=session.id, + final_message=final_message, + input_tokens=session.usage.input_tokens, + output_tokens=session.usage.output_tokens, + cached_tokens=session.usage.cached_tokens, + ) + ) + def emit_message(self, content: str, role: str = "assistant") -> None: """Emit a message event.""" self.emit(Event.message(content, role)) - + def emit_assistant_message(self, content: str) -> None: """Emit an assistant message.""" self.emit(Event.message(content, "assistant")) - + def emit_thinking(self) -> None: """Emit a thinking event.""" self.emit(Event.thinking()) - + def emit_tool_call_start(self, name: str, arguments: dict[str, Any]) -> None: """Emit tool call start event.""" self.emit(Event.tool_call_start(name, arguments)) - + def emit_tool_call_end(self, name: str, result: ToolResult) -> None: """Emit tool call end event.""" - self.emit(Event.tool_call_end( - name=name, - success=result.success, - output=result.output, - error=result.error, - )) - + self.emit( + Event.tool_call_end( + name=name, + success=result.success, + output=result.output, + error=result.error, + ) + ) + def emit_error(self, message: str, details: Optional[dict[str, Any]] 
= None) -> None: """Emit an error event.""" self.emit(Event.error(message, details)) - + def print_final(self, message: str) -> None: """Print the final message to stdout.""" if self.json_mode: diff --git a/src/output/streaming.py b/src/output/streaming.py index 82e092b..d9210d4 100644 --- a/src/output/streaming.py +++ b/src/output/streaming.py @@ -6,16 +6,17 @@ """ from __future__ import annotations -from dataclasses import dataclass, field -from enum import Enum, auto -from typing import List, Optional, Callable, Any + import time -import asyncio from collections import deque +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import List, Optional class StreamState(Enum): """State of the stream processor.""" + IDLE = auto() STREAMING_TEXT = auto() STREAMING_TOOL_CALL = auto() @@ -26,24 +27,28 @@ class StreamState(Enum): @dataclass class StreamEvent: """Base class for stream events.""" + pass @dataclass class StartEvent(StreamEvent): """Event indicating stream start.""" + pass -@dataclass +@dataclass class TextDeltaEvent(StreamEvent): """Event containing a text delta.""" + delta: str @dataclass class ToolCallStartEvent(StreamEvent): """Event indicating start of a tool call.""" + id: str name: str @@ -51,6 +56,7 @@ class ToolCallStartEvent(StreamEvent): @dataclass class ToolCallDeltaEvent(StreamEvent): """Event containing tool call argument delta.""" + id: str arguments: str @@ -58,12 +64,14 @@ class ToolCallDeltaEvent(StreamEvent): @dataclass class ToolCallCompleteEvent(StreamEvent): """Event indicating tool call completion.""" + id: str @dataclass class TokenUsageEvent(StreamEvent): """Event containing token usage information.""" + prompt: int completion: int @@ -71,21 +79,24 @@ class TokenUsageEvent(StreamEvent): @dataclass class CompleteEvent(StreamEvent): """Event indicating stream completion.""" + pass @dataclass class ErrorEvent(StreamEvent): """Event indicating an error occurred.""" + message: str @dataclass class TokenCounts: """Token usage counts for prompt and completion.""" + prompt: int = 0 completion: int = 0 - + def total(self) -> int: """Return total token count.""" return self.prompt + self.completion @@ -94,15 +105,17 @@ def total(self) -> int: @dataclass class StreamToolCall: """Represents a tool call being streamed.""" + id: str name: str arguments: str = "" complete: bool = False - + def parse_arguments(self) -> Optional[dict]: """Parse arguments as JSON if complete.""" if self.complete: import json + try: return json.loads(self.arguments) except: @@ -113,32 +126,33 @@ def parse_arguments(self) -> Optional[dict]: @dataclass class StreamContent: """Accumulated content from a stream.""" + text: str = "" tool_calls: List[StreamToolCall] = field(default_factory=list) tokens: TokenCounts = field(default_factory=TokenCounts) - + def append_text(self, delta: str): """Append text delta to content.""" self.text += delta - + def start_tool_call(self, id: str, name: str): """Start a new tool call.""" self.tool_calls.append(StreamToolCall(id=id, name=name)) - + def append_tool_call(self, id: str, arguments: str): """Append arguments to an existing tool call.""" for tc in self.tool_calls: if tc.id == id: tc.arguments += arguments break - + def complete_tool_call(self, id: str): """Mark a tool call as complete.""" for tc in self.tool_calls: if tc.id == id: tc.complete = True break - + def has_content(self) -> bool: """Check if any content has been accumulated.""" return bool(self.text) or bool(self.tool_calls) @@ -147,6 +161,7 @@ def 
has_content(self) -> bool: @dataclass class StreamStats: """Statistics about a stream.""" + state: StreamState event_count: int text_length: int @@ -158,7 +173,7 @@ class StreamStats: class StreamProcessor: """Process stream events and accumulate content.""" - + def __init__(self): self.state = StreamState.IDLE self.content = StreamContent() @@ -167,70 +182,70 @@ def __init__(self): self.first_token_time: Optional[float] = None self.last_event_time: Optional[float] = None self.event_count = 0 - + def process(self, event: StreamEvent): """Process a stream event.""" now = time.time() - + if self.start_time is None: self.start_time = now - + self.last_event_time = now self.event_count += 1 - + if isinstance(event, StartEvent): self.state = StreamState.STREAMING_TEXT - + elif isinstance(event, TextDeltaEvent): if self.first_token_time is None: self.first_token_time = now self.content.append_text(event.delta) self.state = StreamState.STREAMING_TEXT - + elif isinstance(event, ToolCallStartEvent): self.content.start_tool_call(event.id, event.name) self.state = StreamState.STREAMING_TOOL_CALL - + elif isinstance(event, ToolCallDeltaEvent): self.content.append_tool_call(event.id, event.arguments) - + elif isinstance(event, ToolCallCompleteEvent): self.content.complete_tool_call(event.id) - + elif isinstance(event, TokenUsageEvent): self.content.tokens.prompt = event.prompt self.content.tokens.completion = event.completion - + elif isinstance(event, CompleteEvent): self.state = StreamState.COMPLETE - + elif isinstance(event, ErrorEvent): self.state = StreamState.ERROR - + self.buffer.append(event) - + def time_to_first_token(self) -> Optional[float]: """Get time to first token in seconds.""" if self.start_time and self.first_token_time: return self.first_token_time - self.start_time return None - + def elapsed(self) -> Optional[float]: """Get elapsed time since stream start.""" if self.start_time: return time.time() - self.start_time return None - + def is_complete(self) -> bool: """Check if stream is complete or errored.""" return self.state in (StreamState.COMPLETE, StreamState.ERROR) - + def drain_events(self) -> List[StreamEvent]: """Drain and return all buffered events.""" events = list(self.buffer) self.buffer.clear() return events - + def stats(self) -> StreamStats: """Get current stream statistics.""" return StreamStats( @@ -246,16 +261,16 @@ def stats(self) -> StreamStats: class StreamBuffer: """Buffer for rate limiting output.""" - + def __init__(self, min_interval: float = 0.01): self.buffer = "" self.min_interval = min_interval self.last_flush = time.time() - + def push(self, text: str): """Push text to buffer.""" self.buffer += text - + def flush_if_ready(self) -> Optional[str]: """Flush buffer if minimum interval has passed.""" if time.time() - self.last_flush >= self.min_interval and self.buffer: @@ -264,14 +279,14 @@ def flush_if_ready(self) -> Optional[str]: self.buffer = "" return result return None - + def flush(self) -> str: """Force flush all buffered content.""" self.last_flush = time.time() result = self.buffer self.buffer = "" return result - + def is_empty(self) -> bool: """Check if buffer is empty.""" return not self.buffer @@ -279,27 +294,27 @@ def is_empty(self) -> bool: class WordBuffer: """Buffer for word-boundary aligned output.""" - + def __init__(self, min_words: int = 3): self.buffer = "" self.min_words = min_words - + def push(self, text: str): """Push text to buffer.""" self.buffer += text - + def flush_words(self) -> Optional[str]: """Flush complete words if 
minimum word count reached.""" word_count = len(self.buffer.split()) if word_count >= self.min_words: # Find last whitespace - pos = self.buffer.rfind(' ') + pos = self.buffer.rfind(" ") if pos > 0: - result = self.buffer[:pos+1] - self.buffer = self.buffer[pos+1:] + result = self.buffer[: pos + 1] + self.buffer = self.buffer[pos + 1 :] return result return None - + def flush(self) -> str: """Force flush all buffered content.""" result = self.buffer @@ -309,32 +324,32 @@ def flush(self) -> str: class SentenceBuffer: """Buffer for sentence-boundary aligned output.""" - + def __init__(self): self.buffer = "" - + def push(self, text: str): """Push text to buffer.""" self.buffer += text - + def flush_sentences(self) -> Optional[str]: """Flush complete sentences.""" endings = [". ", "! ", "? ", ".\n", "!\n", "?\n"] last_end = 0 - + for ending in endings: pos = self.buffer.rfind(ending) if pos >= 0: end = pos + len(ending) if end > last_end: last_end = end - + if last_end > 0: result = self.buffer[:last_end] self.buffer = self.buffer[last_end:] return result return None - + def flush(self) -> str: """Force flush all buffered content.""" result = self.buffer @@ -344,18 +359,18 @@ def flush(self) -> str: class StreamCollector: """Collect all stream content.""" - + def __init__(self): self.processor = StreamProcessor() - + def process(self, event: StreamEvent): """Process a stream event.""" self.processor.process(event) - + def is_complete(self) -> bool: """Check if stream is complete.""" return self.processor.is_complete() - + def result(self) -> dict: """Get collected results.""" return { diff --git a/src/prompts/system.py b/src/prompts/system.py index 48b11c0..6bb21c1 100644 --- a/src/prompts/system.py +++ b/src/prompts/system.py @@ -13,7 +13,6 @@ from pathlib import Path from typing import Dict, List, Optional - # ============================================================================= # Context Strings # ============================================================================= @@ -67,15 +66,16 @@ # Token Estimation # ============================================================================= + def estimate_tokens(text: str) -> int: """Estimate token count for text. - + Uses a simple heuristic based on character count. More accurate estimation would require a tokenizer. - + Args: text: Text to estimate tokens for. - + Returns: Estimated token count. """ @@ -89,39 +89,41 @@ def estimate_tokens(text: str) -> int: # Data Classes # ============================================================================= + @dataclass class PromptSection: """A section of the system prompt. - + Attributes: name: Section name (used as header). content: Section content. enabled: Whether this section is enabled. priority: Priority (higher = earlier in prompt). """ + name: str content: str enabled: bool = True priority: int = 0 - + def with_priority(self, priority: int) -> PromptSection: """Set priority and return self for chaining. - + Args: priority: Priority value (higher = earlier). - + Returns: Self for method chaining. """ self.priority = priority return self - + def set_enabled(self, enabled: bool) -> PromptSection: """Set enabled state and return self for chaining. - + Args: enabled: Whether section is enabled. - + Returns: Self for method chaining. """ @@ -132,10 +134,10 @@ def set_enabled(self, enabled: bool) -> PromptSection: @dataclass class SystemPrompt: """System prompt configuration. - + Supports base prompts, sections, variables, capability contexts, custom instructions, and personas. 
- + Attributes: base: Base prompt text. sections: Sections to include. @@ -146,6 +148,7 @@ class SystemPrompt: custom_instructions: Custom instructions. persona: Persona/role. """ + base: Optional[str] = None sections: List[PromptSection] = field(default_factory=list) variables: Dict[str, str] = field(default_factory=dict) @@ -155,133 +158,130 @@ class SystemPrompt: custom_instructions: Optional[str] = None persona: Optional[str] = None _token_count: int = 0 - + @classmethod def new(cls) -> SystemPrompt: """Create a new system prompt. - + Returns: New SystemPrompt instance. """ return cls() - + @classmethod def with_base(cls, base: str) -> SystemPrompt: """Create with base text. - + Args: base: Base prompt text. - + Returns: New SystemPrompt with base set. """ prompt = cls(base=base) prompt._recalculate_tokens() return prompt - + def set_base(self, base: str) -> None: """Set base prompt. - + Args: base: Base prompt text. """ self.base = base self._recalculate_tokens() - + def add_section(self, section: PromptSection) -> None: """Add a section. - + Args: section: Section to add. """ self.sections.append(section) self._recalculate_tokens() - + def remove_section(self, name: str) -> None: """Remove a section by name. - + Args: name: Name of section to remove. """ self.sections = [s for s in self.sections if s.name != name] self._recalculate_tokens() - + def set_variable(self, key: str, value: str) -> None: """Set a variable. - + Args: key: Variable name. value: Variable value. """ self.variables[key] = value self._recalculate_tokens() - + def set_persona(self, persona: str) -> None: """Set persona. - + Args: persona: Persona/role description. """ self.persona = persona self._recalculate_tokens() - + def set_custom_instructions(self, instructions: str) -> None: """Set custom instructions. - + Args: instructions: Custom instructions text. """ self.custom_instructions = instructions self._recalculate_tokens() - + def enable_code_execution(self) -> None: """Enable code execution context.""" self.code_execution = True self._recalculate_tokens() - + def enable_file_operations(self) -> None: """Enable file operations context.""" self.file_operations = True self._recalculate_tokens() - + def enable_web_search(self) -> None: """Enable web search context.""" self.web_search = True self._recalculate_tokens() - + def token_count(self) -> int: """Get token count estimate. - + Returns: Estimated token count. """ return self._token_count - + def render(self) -> Optional[str]: """Render the full system prompt. - + Combines persona, base, sections (sorted by priority), capability contexts, and custom instructions. - + Returns: Rendered prompt string, or None if empty. 
""" parts: List[str] = [] - + # Persona if self.persona: parts.append(self.persona) - + # Base prompt if self.base: rendered = self._render_template(self.base) parts.append(rendered) - + # Sections (sorted by priority, higher first) - sorted_sections = sorted( - self.sections, - key=lambda s: -s.priority - ) + sorted_sections = sorted(self.sections, key=lambda s: -s.priority) for section in sorted_sections: if section.enabled: content = self._render_template(section.content) @@ -289,7 +289,7 @@ def render(self) -> Optional[str]: parts.append(f"## {section.name}\n{content}") else: parts.append(content) - + # Capability contexts if self.code_execution: parts.append(CODE_EXECUTION_CONTEXT) @@ -297,24 +297,24 @@ def render(self) -> Optional[str]: parts.append(FILE_OPERATIONS_CONTEXT) if self.web_search: parts.append(WEB_SEARCH_CONTEXT) - + # Custom instructions if self.custom_instructions: parts.append(f"## Custom Instructions\n{self.custom_instructions}") - + if not parts: return None - + return "\n\n".join(parts) - + def _render_template(self, template: str) -> str: """Render template with variables. - + Supports both {{key}} and ${key} syntax. - + Args: template: Template string. - + Returns: Rendered string with variables substituted. """ @@ -325,7 +325,7 @@ def _render_template(self, template: str) -> str: # Support ${key} syntax result = result.replace(f"${{{key}}}", value) return result - + def _recalculate_tokens(self) -> None: """Recalculate token count estimate.""" rendered = self.render() @@ -339,11 +339,12 @@ def _recalculate_tokens(self) -> None: # Builder Pattern # ============================================================================= + class SystemPromptBuilder: """Builder for system prompts. - + Provides a fluent interface for constructing SystemPrompt instances. - + Example: prompt = (SystemPromptBuilder() .persona("You are a helpful assistant.") @@ -352,118 +353,109 @@ class SystemPromptBuilder: .code_execution() .build()) """ - + def __init__(self) -> None: """Create a new builder.""" self._prompt = SystemPrompt() - + def base(self, base: str) -> SystemPromptBuilder: """Set base prompt. - + Args: base: Base prompt text. - + Returns: Self for method chaining. """ self._prompt.base = base return self - + def persona(self, persona: str) -> SystemPromptBuilder: """Set persona. - + Args: persona: Persona/role description. - + Returns: Self for method chaining. """ self._prompt.persona = persona return self - + def section( - self, - name: str, - content: str, - priority: int = 0, - enabled: bool = True + self, name: str, content: str, priority: int = 0, enabled: bool = True ) -> SystemPromptBuilder: """Add a section. - + Args: name: Section name (used as header). content: Section content. priority: Priority (higher = earlier in prompt). enabled: Whether section is enabled. - + Returns: Self for method chaining. """ self._prompt.sections.append( - PromptSection( - name=name, - content=content, - priority=priority, - enabled=enabled - ) + PromptSection(name=name, content=content, priority=priority, enabled=enabled) ) return self - + def variable(self, key: str, value: str) -> SystemPromptBuilder: """Add a variable. - + Args: key: Variable name. value: Variable value. - + Returns: Self for method chaining. """ self._prompt.variables[key] = value return self - + def custom_instructions(self, instructions: str) -> SystemPromptBuilder: """Set custom instructions. - + Args: instructions: Custom instructions text. - + Returns: Self for method chaining. 
""" self._prompt.custom_instructions = instructions return self - + def code_execution(self) -> SystemPromptBuilder: """Enable code execution context. - + Returns: Self for method chaining. """ self._prompt.code_execution = True return self - + def file_operations(self) -> SystemPromptBuilder: """Enable file operations context. - + Returns: Self for method chaining. """ self._prompt.file_operations = True return self - + def web_search(self) -> SystemPromptBuilder: """Enable web search context. - + Returns: Self for method chaining. """ self._prompt.web_search = True return self - + def build(self) -> SystemPrompt: """Build the system prompt. - + Returns: Configured SystemPrompt instance. """ @@ -475,59 +467,64 @@ def build(self) -> SystemPrompt: # Presets # ============================================================================= + class Presets: """Predefined system prompts for common use cases.""" - + @staticmethod def coding_assistant() -> SystemPrompt: """Default coding assistant prompt. - + Returns: SystemPrompt configured for coding assistance. """ - return (SystemPromptBuilder() + return ( + SystemPromptBuilder() .persona("You are Fabric, an expert AI coding assistant.") .base(CODING_ASSISTANT_BASE) .code_execution() .file_operations() - .build()) - + .build() + ) + @staticmethod def research_assistant() -> SystemPrompt: """Research assistant prompt. - + Returns: SystemPrompt configured for research assistance. """ - return (SystemPromptBuilder() + return ( + SystemPromptBuilder() .persona("You are a helpful research assistant with access to web search.") .base("Help the user find and analyze information. Cite sources when possible.") .web_search() - .build()) - + .build() + ) + @staticmethod def code_reviewer() -> SystemPrompt: """Code review prompt. - + Returns: SystemPrompt configured for code review. """ - return (SystemPromptBuilder() + return ( + SystemPromptBuilder() .persona("You are an expert code reviewer.") .base(CODE_REVIEWER_BASE) .file_operations() - .build()) - + .build() + ) + @staticmethod def minimal() -> SystemPrompt: """Minimal assistant prompt. - + Returns: SystemPrompt with minimal configuration. """ - return (SystemPromptBuilder() - .base("You are a helpful assistant. Be concise.") - .build()) + return SystemPromptBuilder().base("You are a helpful assistant. Be concise.").build() # ============================================================================= @@ -779,28 +776,28 @@ def get_system_prompt( shell: Optional[str] = None, ) -> str: """Get the full system prompt with environment context. - + Uses the SYSTEM_PROMPT constant which includes autonomous behavior and mandatory verification plan instructions. - + Args: cwd: Current working directory. shell: Shell being used. - + Returns: Complete system prompt string. 
""" # Use the SYSTEM_PROMPT constant directly (includes all autonomous behavior instructions) cwd_str = str(cwd) if cwd else "/app" shell_str = shell or "/bin/sh" - + # Add environment section env_lines = [ f"- Working directory: {cwd_str}", f"- Platform: {platform.system()}", f"- Shell: {shell_str}", ] - + return f"{SYSTEM_PROMPT}\n\n# Environment\n" + "\n".join(env_lines) diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 38a99c4..a80d77f 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -1,22 +1,21 @@ """Tools module - registry and tool implementations.""" -from src.tools.base import ToolResult, BaseTool, ToolMetadata +# Individual tools +from src.tools.apply_patch import ApplyPatchTool +from src.tools.base import BaseTool, ToolMetadata, ToolResult +from src.tools.list_dir import ListDirTool +from src.tools.read_file import ReadFileTool from src.tools.registry import ( - ToolRegistry, + CachedResult, ExecutorConfig, ExecutorStats, + ToolRegistry, ToolStats, - CachedResult, ) -from src.tools.specs import get_all_tools, get_tool_spec, TOOL_SPECS - -# Individual tools -from src.tools.apply_patch import ApplyPatchTool -from src.tools.read_file import ReadFileTool -from src.tools.write_file import WriteFileTool -from src.tools.list_dir import ListDirTool from src.tools.search_files import SearchFilesTool +from src.tools.specs import TOOL_SPECS, get_all_tools, get_tool_spec from src.tools.view_image import view_image +from src.tools.write_file import WriteFileTool __all__ = [ # Base diff --git a/src/tools/apply_patch.py b/src/tools/apply_patch.py index cfff32c..593c6a8 100644 --- a/src/tools/apply_patch.py +++ b/src/tools/apply_patch.py @@ -7,23 +7,22 @@ from __future__ import annotations -import os import re from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from typing import Any, List, Optional, Tuple from src.tools.base import BaseTool, ToolResult - # ============================================================================= # Data Structures # ============================================================================= + @dataclass class HunkLine: """A single line in a hunk.""" + type: str # "context", "add", "remove" content: str @@ -31,6 +30,7 @@ class HunkLine: @dataclass class Hunk: """A parsed hunk from a unified diff.""" + old_start: int old_count: int new_start: int @@ -41,6 +41,7 @@ class Hunk: @dataclass class FileChange: """A file change from a unified diff.""" + old_path: Optional[Path] new_path: Optional[Path] hunks: List[Hunk] = field(default_factory=list) @@ -53,40 +54,41 @@ class FileChange: # Unified Diff Parser (matches fabric-core) # ============================================================================= + def parse_file_path(path_str: str) -> Optional[Path]: """Parse a file path from diff header. - + Handles formats: a/path, b/path, or just path """ path = path_str.strip() - + # Remove a/ or b/ prefix if path.startswith("a/"): path = path[2:] elif path.startswith("b/"): path = path[2:] - + # Remove timestamp if present (e.g., "file.txt\t2024-01-01 00:00:00") path = path.split("\t")[0].strip() - + if path == "/dev/null": return Path("/dev/null") - + return Path(path) if path else None def parse_hunk_header(line: str) -> Optional[Hunk]: """Parse a hunk header like '@@ -1,5 +1,6 @@'. - + Returns Hunk with start/count info but empty lines. 
""" # Strip @@ markers line = line.strip("@").strip() parts = line.split() - + if len(parts) < 2: return None - + def parse_range(s: str) -> Tuple[int, int]: """Parse range like '1,5' or '1' into (start, count).""" s = s.lstrip("-+") @@ -94,7 +96,7 @@ def parse_range(s: str) -> Tuple[int, int]: parts = s.split(",") return int(parts[0]), int(parts[1]) return int(s), 1 - + try: old_start, old_count = parse_range(parts[0]) new_start, new_count = parse_range(parts[1]) @@ -110,19 +112,19 @@ def parse_range(s: str) -> Tuple[int, int]: def parse_unified_diff(patch: str) -> List[FileChange]: """Parse a unified diff into file changes. - + Matches fabric-core parse_unified_diff() implementation. """ file_changes: List[FileChange] = [] current_change: Optional[FileChange] = None current_hunk: Optional[Hunk] = None - + lines = patch.splitlines() i = 0 - + while i < len(lines): line = lines[i] - + # Detect file header: --- a/path if line.startswith("--- "): # Save previous change @@ -131,16 +133,16 @@ def parse_unified_diff(patch: str) -> List[FileChange]: current_change.hunks.append(current_hunk) current_hunk = None file_changes.append(current_change) - + old_path = parse_file_path(line[4:]) - + # Look for +++ line if i + 1 < len(lines) and lines[i + 1].startswith("+++ "): new_path = parse_file_path(lines[i + 1][4:]) - + is_new_file = old_path is not None and str(old_path) == "/dev/null" is_deleted = new_path is not None and str(new_path) == "/dev/null" - + current_change = FileChange( old_path=None if is_new_file else old_path, new_path=None if is_deleted else new_path, @@ -149,17 +151,17 @@ def parse_unified_diff(patch: str) -> List[FileChange]: ) i += 2 continue - + # Detect hunk header: @@ -1,5 +1,6 @@ if line.startswith("@@ "): # Save previous hunk if current_change is not None and current_hunk is not None: current_change.hunks.append(current_hunk) - + current_hunk = parse_hunk_header(line) i += 1 continue - + # Parse hunk lines if current_hunk is not None: if line.startswith("+") and not line.startswith("+++"): @@ -172,15 +174,15 @@ def parse_unified_diff(patch: str) -> List[FileChange]: elif line.startswith("\\"): # "\ No newline at end of file" - ignore pass - + i += 1 - + # Save final change and hunk if current_change is not None: if current_hunk is not None: current_change.hunks.append(current_hunk) file_changes.append(current_change) - + return file_changes @@ -188,15 +190,16 @@ def parse_unified_diff(patch: str) -> List[FileChange]: # Hunk Application (matches fabric-core with fuzzy matching) # ============================================================================= + def matches_at_position(lines: List[str], match_lines: List[str], start: int) -> bool: """Check if lines match at a given position (with whitespace tolerance).""" if start + len(match_lines) > len(lines): return False - + for i, expected in enumerate(match_lines): if lines[start + i].strip() != expected.strip(): return False - + return True @@ -206,22 +209,19 @@ def find_hunk_position( suggested_start: int, ) -> int: """Find the best position to apply a hunk, with fuzzy matching. - + Matches fabric-core find_hunk_position() - searches ±50 lines. 
""" # Extract context and remove lines for matching - match_lines = [ - hl.content for hl in hunk.lines - if hl.type in ("context", "remove") - ] - + match_lines = [hl.content for hl in hunk.lines if hl.type in ("context", "remove")] + if not match_lines: return suggested_start - + # Try exact position first if matches_at_position(lines, match_lines, suggested_start): return suggested_start - + # Search nearby positions (within 50 lines) for offset in range(1, 51): # Try before @@ -229,16 +229,16 @@ def find_hunk_position( pos = suggested_start - offset if matches_at_position(lines, match_lines, pos): return pos - + # Try after pos = suggested_start + offset if pos < len(lines) and matches_at_position(lines, match_lines, pos): return pos - + # If we can't find a match but position is valid, use it anyway if suggested_start <= len(lines): return suggested_start - + raise ValueError(f"Could not find matching context for hunk at line {hunk.old_start}") @@ -247,39 +247,33 @@ def apply_hunks_to_lines( hunks: List[Hunk], ) -> str: """Apply hunks to existing lines. - + Applies hunks in reverse order to maintain line numbers. """ result_lines = list(original_lines) - + # Apply in reverse order for hunk in reversed(hunks): start_idx = hunk.old_start - 1 if hunk.old_start > 0 else 0 - + # Find actual position actual_start = find_hunk_position(result_lines, hunk, start_idx) - + # Count lines to remove - lines_to_remove = sum( - 1 for hl in hunk.lines - if hl.type in ("remove", "context") - ) - + lines_to_remove = sum(1 for hl in hunk.lines if hl.type in ("remove", "context")) + # Build replacement - replacement = [ - hl.content for hl in hunk.lines - if hl.type in ("add", "context") - ] - + replacement = [hl.content for hl in hunk.lines if hl.type in ("add", "context")] + # Apply replacement end_idx = min(actual_start + lines_to_remove, len(result_lines)) result_lines = result_lines[:actual_start] + replacement + result_lines[end_idx:] - + # Join with newlines content = "\n".join(result_lines) if content and not content.endswith("\n"): content += "\n" - + return content @@ -290,7 +284,7 @@ def build_new_content(hunks: List[Hunk]) -> str: for hl in hunk.lines: if hl.type in ("add", "context"): lines.append(hl.content) - + content = "\n".join(lines) if content and not content.endswith("\n"): content += "\n" @@ -301,27 +295,28 @@ def build_new_content(hunks: List[Hunk]) -> str: # File Change Application # ============================================================================= + def apply_file_change( change: FileChange, cwd: Path, dry_run: bool = False, ) -> str: """Apply a single file change.""" - + # Handle deletion if change.is_deleted and change.old_path: full_path = cwd / change.old_path if not dry_run: full_path.unlink() return f" D {change.old_path}" - + # Get target path target_path = change.new_path or change.old_path if not target_path: raise ValueError("No file path specified") - + full_path = cwd / target_path - + # Handle new file if change.is_new_file: content = build_new_content(change.hunks) @@ -329,16 +324,16 @@ def apply_file_change( full_path.parent.mkdir(parents=True, exist_ok=True) full_path.write_text(content, encoding="utf-8") return f" A {target_path}" - + # Handle modification original_content = full_path.read_text(encoding="utf-8") original_lines = original_content.splitlines() - + new_content = apply_hunks_to_lines(original_lines, change.hunks) - + if not dry_run: full_path.write_text(new_content, encoding="utf-8") - + return f" M {target_path}" @@ -348,23 +343,23 @@ def 
apply_unified_diff( dry_run: bool = False, ) -> str: """Apply a unified diff to the filesystem. - + Main entry point matching fabric-core apply_unified_diff(). """ file_changes = parse_unified_diff(patch) - + if not file_changes: return "No changes to apply" - + report = [] modified_files = [] - + for change in file_changes: result = apply_file_change(change, cwd, dry_run) report.append(result) if change.new_path: modified_files.append(str(change.new_path)) - + action = "Would apply" if dry_run else "Applied" return f"{action} changes to {len(modified_files)} file(s):\n" + "\n".join(report) @@ -373,30 +368,31 @@ def apply_unified_diff( # Legacy Format Support (*** Begin Patch) # ============================================================================= + def parse_legacy_patch(patch: str) -> List[FileChange]: """Parse legacy *** Begin Patch format.""" file_changes: List[FileChange] = [] - + # Extract content between markers match = re.search(r"\*\*\* Begin Patch\s*\n(.*?)\*\*\* End Patch", patch, re.DOTALL) if not match: return [] - + content = match.group(1) - + # Split into file operations file_pattern = r"\*\*\* (Add|Delete|Update) File: (.+?)(?=\n\*\*\* (?:Add|Delete|Update)|$)" - + for file_match in re.finditer(file_pattern, content, re.DOTALL): op_type = file_match.group(1).lower() file_path = file_match.group(2).strip() - + # Get content after header start = file_match.end() remaining = content[start:] next_file = re.search(r"\*\*\* (?:Add|Delete|Update) File:", remaining) - file_content = remaining[:next_file.start()] if next_file else remaining - + file_content = remaining[: next_file.start()] if next_file else remaining + if op_type == "add": change = FileChange( old_path=None, @@ -413,14 +409,16 @@ def parse_legacy_patch(patch: str) -> List[FileChange]: if hunk.lines: change.hunks.append(hunk) file_changes.append(change) - + elif op_type == "delete": - file_changes.append(FileChange( - old_path=Path(file_path), - new_path=None, - is_deleted=True, - )) - + file_changes.append( + FileChange( + old_path=Path(file_path), + new_path=None, + is_deleted=True, + ) + ) + elif op_type == "update": change = FileChange( old_path=Path(file_path), @@ -443,7 +441,7 @@ def parse_legacy_patch(patch: str) -> List[FileChange]: if current_hunk: change.hunks.append(current_hunk) file_changes.append(change) - + return file_changes @@ -451,33 +449,34 @@ def parse_legacy_patch(patch: str) -> List[FileChange]: # Tool Implementation # ============================================================================= + class ApplyPatchTool(BaseTool): """Tool for applying file patches. - + Supports both standard unified diff format and legacy *** Begin Patch format. """ - + name = "apply_patch" description = "Applies file patches using unified diff or custom format." - + def execute(self, **kwargs: Any) -> ToolResult: """Apply a patch. 
- + Args: **kwargs: Tool arguments - patch: The patch content (unified diff or *** Begin Patch format) - dry_run: If True, don't actually modify files - + Returns: ToolResult with success/failure info """ # Extract parameters from kwargs patch: str = kwargs.get("patch", "") dry_run: bool = kwargs.get("dry_run", False) - + if not patch: return ToolResult.fail("Missing required parameter: patch") - + try: # Detect format and parse if "*** Begin Patch" in patch: @@ -485,26 +484,26 @@ def execute(self, **kwargs: Any) -> ToolResult: file_changes = parse_legacy_patch(patch) if not file_changes: return ToolResult.fail("No valid operations in patch") - + report = [] for change in file_changes: result = apply_file_change(change, self.cwd, dry_run) report.append(result) - + action = "Would apply" if dry_run else "Applied" return ToolResult.ok(f"{action} changes:\n" + "\n".join(report)) - + elif "---" in patch and "+++" in patch: # Standard unified diff result = apply_unified_diff(patch, self.cwd, dry_run) return ToolResult.ok(result) - + else: return ToolResult.fail( "Invalid patch format. Use unified diff (--- / +++) " "or custom format (*** Begin Patch)" ) - + except FileNotFoundError as e: return ToolResult.fail(f"File not found: {e}") except PermissionError as e: diff --git a/src/tools/base.py b/src/tools/base.py index c0729ba..2bc8e64 100644 --- a/src/tools/base.py +++ b/src/tools/base.py @@ -5,12 +5,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, List, Dict +from typing import Any, Dict, List, Optional @dataclass class ToolMetadata: """Metadata about a tool execution.""" + duration_ms: int = 0 exit_code: Optional[int] = None files_modified: List[str] = field(default_factory=list) @@ -20,28 +21,29 @@ class ToolMetadata: @dataclass class ToolResult: """Result of a tool execution.""" + success: bool output: str error: Optional[str] = None data: Optional[dict[str, Any]] = None inject_content: Optional[dict[str, Any]] = None # For injecting images into context metadata: Optional[ToolMetadata] = None - + @classmethod def ok(cls, output: str, data: Optional[dict[str, Any]] = None) -> "ToolResult": """Create a successful result.""" return cls(success=True, output=output, data=data) - + @classmethod def fail(cls, error: str, output: str = "") -> "ToolResult": """Create a failed result.""" return cls(success=False, output=output, error=error) - + def with_metadata(self, metadata: ToolMetadata) -> "ToolResult": """Add metadata to this result.""" self.metadata = metadata return self - + def to_message(self) -> str: """Convert to message format for the LLM.""" if self.success: @@ -52,36 +54,36 @@ def to_message(self) -> str: class BaseTool(ABC): """Base class for all tools.""" - + name: str description: str - + def __init__(self, cwd: Path): """Initialize the tool. - + Args: cwd: Current working directory for the tool """ self.cwd = cwd - + @abstractmethod def execute(self, **kwargs: Any) -> ToolResult: """Execute the tool with the given arguments. - + Args: **kwargs: Tool-specific arguments - + Returns: ToolResult with success status and output """ pass - + def resolve_path(self, path: str) -> Path: """Resolve a path relative to the working directory. 
- + Args: path: Path string (absolute or relative) - + Returns: Resolved absolute Path """ @@ -89,11 +91,11 @@ def resolve_path(self, path: str) -> Path: if p.is_absolute(): return p return (self.cwd / p).resolve() - + @classmethod def get_spec(cls) -> dict[str, Any]: """Get the tool specification for the LLM. - + Returns: Tool specification dict """ diff --git a/src/tools/grep_files.py b/src/tools/grep_files.py index ff02283..07c2db2 100644 --- a/src/tools/grep_files.py +++ b/src/tools/grep_files.py @@ -2,26 +2,25 @@ from __future__ import annotations -import os import re import subprocess from pathlib import Path -from typing import Any, Optional +from typing import Optional from src.tools.base import BaseTool, ToolResult class GrepFilesTool(BaseTool): """Tool for searching file contents using patterns.""" - + name = "grep_files" description = "Finds files whose contents match the pattern." - + # Default limits DEFAULT_LIMIT = 100 MAX_LIMIT = 2000 TIMEOUT_SECONDS = 30 - + def execute( self, pattern: str, @@ -30,33 +29,33 @@ def execute( limit: int = DEFAULT_LIMIT, ) -> ToolResult: """Search for files matching a pattern. - + Args: pattern: Regex pattern to search for include: Glob pattern to filter files path: Directory to search in limit: Maximum number of results - + Returns: ToolResult with matching file paths """ # Resolve search path search_path = self.resolve_path(path) if path else self.cwd - + if not search_path.exists(): return ToolResult.fail(f"Path not found: {search_path}") - + # Cap limit limit = min(limit, self.MAX_LIMIT) - + # Try ripgrep first (fastest) rg_result = self._search_with_ripgrep(pattern, include, search_path, limit) if rg_result is not None: return rg_result - + # Fallback to Python implementation return self._search_with_python(pattern, include, search_path, limit) - + def _search_with_ripgrep( self, pattern: str, @@ -65,17 +64,17 @@ def _search_with_ripgrep( limit: int, ) -> Optional[ToolResult]: """Search using ripgrep (rg). - + Returns None if ripgrep is not available. 
""" cmd = ["rg", "--files-with-matches", "--no-heading"] - + if include: # Convert glob to rg glob format cmd.extend(["--glob", include]) - + cmd.extend([pattern, str(search_path)]) - + try: result = subprocess.run( cmd, @@ -83,34 +82,34 @@ def _search_with_ripgrep( text=True, timeout=self.TIMEOUT_SECONDS, ) - + if result.returncode == 0: files = result.stdout.strip().split("\n") if result.stdout.strip() else [] files = files[:limit] - + if not files: return ToolResult.ok("No matching files found.") - + output = f"Found {len(files)} matching files:\n" + "\n".join(files) return ToolResult.ok(output, data={"count": len(files), "files": files}) - + elif result.returncode == 1: # No matches return ToolResult.ok("No matching files found.") - + elif result.returncode == 2: # Error - might be bad pattern return ToolResult.fail(f"Search error: {result.stderr.strip()}") - + return None # Try fallback - + except FileNotFoundError: return None # rg not installed, use fallback except subprocess.TimeoutExpired: return ToolResult.fail(f"Search timed out after {self.TIMEOUT_SECONDS}s") except Exception: return None # Use fallback - + def _search_with_python( self, pattern: str, @@ -123,10 +122,10 @@ def _search_with_python( regex = re.compile(pattern) except re.error as e: return ToolResult.fail(f"Invalid regex pattern: {e}") - + matching_files: list[str] = [] errors: list[str] = [] - + # Convert include glob to regex if provided include_regex = None if include: @@ -143,21 +142,21 @@ def _search_with_python( include_regex = re.compile(f"^{include_pattern}$", re.IGNORECASE) except re.error: pass - + def should_include(file_path: Path) -> bool: if include_regex is None: return True return include_regex.match(file_path.name) is not None - + def search_dir(dir_path: Path) -> None: if len(matching_files) >= limit: return - + try: for item in dir_path.iterdir(): if len(matching_files) >= limit: return - + if item.is_file() and should_include(item): try: content = item.read_text(encoding="utf-8", errors="ignore") @@ -165,15 +164,15 @@ def search_dir(dir_path: Path) -> None: matching_files.append(str(item)) except (PermissionError, OSError): pass - + elif item.is_dir() and not item.is_symlink(): # Skip hidden directories if not item.name.startswith("."): search_dir(item) - + except PermissionError: errors.append(f"Permission denied: {dir_path}") - + if search_path.is_file(): try: content = search_path.read_text(encoding="utf-8", errors="ignore") @@ -183,13 +182,13 @@ def search_dir(dir_path: Path) -> None: return ToolResult.fail(f"Cannot read file: {e}") else: search_dir(search_path) - + if not matching_files: return ToolResult.ok("No matching files found.") - + output = f"Found {len(matching_files)} matching files:\n" + "\n".join(matching_files) - + if errors: - output += f"\n\nWarnings:\n" + "\n".join(errors[:5]) - + output += "\n\nWarnings:\n" + "\n".join(errors[:5]) + return ToolResult.ok(output, data={"count": len(matching_files), "files": matching_files}) diff --git a/src/tools/list_dir.py b/src/tools/list_dir.py index a3a2c13..5cbaa29 100644 --- a/src/tools/list_dir.py +++ b/src/tools/list_dir.py @@ -2,161 +2,158 @@ from __future__ import annotations -import os from pathlib import Path -from typing import Any, Optional, List +from typing import Any, List, Optional -from .base import BaseTool, ToolResult, ToolMetadata +from .base import BaseTool, ToolMetadata, ToolResult class ListDirTool(BaseTool): """Tool to list directory contents.""" - + name = "list_dir" description = "List the contents of a 
directory" - + def execute( self, directory_path: str = ".", recursive: bool = False, include_hidden: bool = False, ignore_patterns: Optional[List[str]] = None, - **kwargs: Any + **kwargs: Any, ) -> ToolResult: """List directory contents. - + Args: directory_path: Path to the directory to list recursive: Whether to list recursively include_hidden: Whether to include hidden files/directories ignore_patterns: List of patterns to ignore - + Returns: ToolResult with directory listing and metadata """ import time + start_time = time.time() - + resolved_path = self.resolve_path(directory_path) - + if not resolved_path.exists(): return ToolResult.fail(f"Directory not found: {directory_path}") - + if not resolved_path.is_dir(): return ToolResult.fail(f"Not a directory: {directory_path}") - + ignore_patterns = ignore_patterns or [] entries = [] output_lines = [] - + try: if recursive: items = self._list_recursive(resolved_path, include_hidden, ignore_patterns) else: items = self._list_flat(resolved_path, include_hidden, ignore_patterns) - - for item_path, item_type, item_size in sorted(items, key=lambda x: (x[1] != "dir", x[0].lower())): + + for item_path, item_type, item_size in sorted( + items, key=lambda x: (x[1] != "dir", x[0].lower()) + ): if item_type == "dir": output_lines.append(f"dir {item_path}") else: output_lines.append(f"file {item_path}") - - entries.append({ - "name": item_path, - "type": item_type, - "size": item_size, - }) - + + entries.append( + { + "name": item_path, + "type": item_type, + "size": item_size, + } + ) + if not entries: - output = f"Directory '{directory_path}' is empty (no files or subdirectories found)." + output = ( + f"Directory '{directory_path}' is empty (no files or subdirectories found)." + ) else: output = "\n".join(output_lines) - + duration_ms = int((time.time() - start_time) * 1000) metadata = ToolMetadata( duration_ms=duration_ms, data={ "path": str(resolved_path), "entries": entries, - } + }, ) - + result = ToolResult.ok(output) return result.with_metadata(metadata) - + except PermissionError: return ToolResult.fail(f"Permission denied: {directory_path}") except Exception as e: return ToolResult.fail(f"Error listing directory: {e}") - + def _should_ignore(self, name: str, include_hidden: bool, ignore_patterns: List[str]) -> bool: """Check if a file/directory should be ignored.""" # Check hidden files if not include_hidden and name.startswith("."): return True - + # Check ignore patterns (simple glob matching) for pattern in ignore_patterns: if self._match_pattern(name, pattern): return True - + return False - + def _match_pattern(self, name: str, pattern: str) -> bool: """Simple glob pattern matching with * and ?.""" import fnmatch + return fnmatch.fnmatch(name, pattern) - + def _list_flat( - self, - path: Path, - include_hidden: bool, - ignore_patterns: List[str] + self, path: Path, include_hidden: bool, ignore_patterns: List[str] ) -> List[tuple[str, str, int]]: """List directory contents non-recursively.""" items = [] - + for entry in path.iterdir(): if self._should_ignore(entry.name, include_hidden, ignore_patterns): continue - + item_type = "dir" if entry.is_dir() else "file" item_size = 0 if entry.is_dir() else entry.stat().st_size items.append((entry.name, item_type, item_size)) - + return items - + def _list_recursive( - self, - path: Path, - include_hidden: bool, - ignore_patterns: List[str], - prefix: str = "" + self, path: Path, include_hidden: bool, ignore_patterns: List[str], prefix: str = "" ) -> List[tuple[str, str, int]]: """List 
directory contents recursively.""" items = [] - + for entry in path.iterdir(): if self._should_ignore(entry.name, include_hidden, ignore_patterns): continue - + relative_name = f"{prefix}{entry.name}" if prefix else entry.name item_type = "dir" if entry.is_dir() else "file" item_size = 0 if entry.is_dir() else entry.stat().st_size items.append((relative_name, item_type, item_size)) - + if entry.is_dir(): # Recurse into subdirectory sub_items = self._list_recursive( - entry, - include_hidden, - ignore_patterns, - prefix=f"{relative_name}/" + entry, include_hidden, ignore_patterns, prefix=f"{relative_name}/" ) items.extend(sub_items) - + return items - + @classmethod def get_spec(cls) -> dict[str, Any]: """Get the tool specification for the LLM.""" @@ -169,24 +166,24 @@ def get_spec(cls) -> dict[str, Any]: "directory_path": { "type": "string", "description": "Path to the directory to list", - "default": "." + "default": ".", }, "recursive": { "type": "boolean", "description": "Whether to list recursively", - "default": False + "default": False, }, "include_hidden": { "type": "boolean", "description": "Whether to include hidden files/directories", - "default": False + "default": False, }, "ignore_patterns": { "type": "array", "items": {"type": "string"}, - "description": "List of patterns to ignore" - } + "description": "List of patterns to ignore", + }, }, - "required": [] - } + "required": [], + }, } diff --git a/src/tools/read_file.py b/src/tools/read_file.py index 6efd51d..5a446d8 100644 --- a/src/tools/read_file.py +++ b/src/tools/read_file.py @@ -2,47 +2,42 @@ from __future__ import annotations -import os -from pathlib import Path from typing import Any, Optional -from .base import BaseTool, ToolResult, ToolMetadata +from .base import BaseTool, ToolMetadata, ToolResult class ReadFileTool(BaseTool): """Tool to read file contents with line numbers.""" - + name = "read_file" description = "Read the contents of a file with line numbers" - + def execute( - self, - file_path: str, - offset: int = 0, - limit: Optional[int] = None, - **kwargs: Any + self, file_path: str, offset: int = 0, limit: Optional[int] = None, **kwargs: Any ) -> ToolResult: """Read file contents. 
- + Args: file_path: Path to the file to read offset: Line offset to start from (0-based) limit: Maximum number of lines to read (None for all) - + Returns: ToolResult with file contents and metadata """ import time + start_time = time.time() - + resolved_path = self.resolve_path(file_path) - + if not resolved_path.exists(): return ToolResult.fail(f"File not found: {file_path}") - + if not resolved_path.is_file(): return ToolResult.fail(f"Not a file: {file_path}") - + try: content = resolved_path.read_text(encoding="utf-8") except UnicodeDecodeError: @@ -53,10 +48,10 @@ def execute( return ToolResult.fail(f"Cannot read file: {e}") except Exception as e: return ToolResult.fail(f"Error reading file: {e}") - + lines = content.splitlines() total_lines = len(lines) - + # Handle empty file if total_lines == 0 or (total_lines == 1 and lines[0] == ""): duration_ms = int((time.time() - start_time) * 1000) @@ -72,30 +67,30 @@ def execute( "offset": offset, "truncated": False, "empty": True, - } + }, ) result = ToolResult.ok("(empty file)") return result.with_metadata(metadata) - + # Apply offset and limit if offset >= total_lines: return ToolResult.fail(f"Offset {offset} exceeds total lines {total_lines}") - + end_index = total_lines if limit is not None: end_index = min(offset + limit, total_lines) - + selected_lines = lines[offset:end_index] shown_lines = len(selected_lines) truncated = end_index < total_lines - + # Format with line numbers formatted_lines = [] for i, line in enumerate(selected_lines, start=offset + 1): formatted_lines.append(f"L{i}: {line}") - + output = "\n".join(formatted_lines) - + duration_ms = int((time.time() - start_time) * 1000) metadata = ToolMetadata( duration_ms=duration_ms, @@ -109,12 +104,12 @@ def execute( "offset": offset, "truncated": truncated, "empty": False, - } + }, ) - + result = ToolResult.ok(output) return result.with_metadata(metadata) - + @classmethod def get_spec(cls) -> dict[str, Any]: """Get the tool specification for the LLM.""" @@ -124,20 +119,17 @@ def get_spec(cls) -> dict[str, Any]: "parameters": { "type": "object", "properties": { - "file_path": { - "type": "string", - "description": "Path to the file to read" - }, + "file_path": {"type": "string", "description": "Path to the file to read"}, "offset": { "type": "integer", "description": "Line offset to start from (0-based)", - "default": 0 + "default": 0, }, "limit": { "type": "integer", - "description": "Maximum number of lines to read (optional)" - } + "description": "Maximum number of lines to read (optional)", + }, }, - "required": ["file_path"] - } + "required": ["file_path"], + }, } diff --git a/src/tools/registry.py b/src/tools/registry.py index b96409f..deb1ef2 100644 --- a/src/tools/registry.py +++ b/src/tools/registry.py @@ -5,12 +5,11 @@ import hashlib import json import subprocess -import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from src.tools.base import ToolResult from src.tools.specs import get_all_tools @@ -22,6 +21,7 @@ @dataclass class ExecutorConfig: """Configuration for tool execution.""" + max_concurrent: int = 4 default_timeout: float = 120.0 cache_enabled: bool = True @@ -31,9 +31,10 @@ class ExecutorConfig: @dataclass class CachedResult: """A cached tool result with timestamp.""" + result: ToolResult cached_at: float # 
timestamp from time.time() - + def is_valid(self, ttl: float) -> bool: """Check if the cached result is still valid.""" return (time.time() - self.cached_at) < ttl @@ -42,16 +43,17 @@ def is_valid(self, ttl: float) -> bool: @dataclass class ToolStats: """Per-tool execution statistics.""" + executions: int = 0 successes: int = 0 total_ms: int = 0 - + def success_rate(self) -> float: """Get the success rate for this tool.""" if self.executions == 0: return 0.0 return self.successes / self.executions - + def avg_ms(self) -> float: """Get average execution time in milliseconds.""" if self.executions == 0: @@ -62,25 +64,26 @@ def avg_ms(self) -> float: @dataclass class ExecutorStats: """Aggregate execution statistics.""" + total_executions: int = 0 successful_executions: int = 0 failed_executions: int = 0 cache_hits: int = 0 total_duration_ms: int = 0 by_tool: Dict[str, ToolStats] = field(default_factory=dict) - + def success_rate(self) -> float: """Get overall success rate.""" if self.total_executions == 0: return 0.0 return self.successful_executions / self.total_executions - + def cache_hit_rate(self) -> float: """Get cache hit rate.""" if self.total_executions == 0: return 0.0 return self.cache_hits / self.total_executions - + def avg_duration_ms(self) -> float: """Get average execution duration in milliseconds.""" if self.total_executions == 0: @@ -90,18 +93,18 @@ def avg_duration_ms(self) -> float: class ToolRegistry: """Registry for managing and dispatching tool calls. - + Tools receive AgentContext for shell execution. Includes caching and execution statistics. """ - + def __init__( self, cwd: Optional[Path] = None, config: Optional[ExecutorConfig] = None, ): """Initialize the registry. - + Args: cwd: Current working directory for tools (optional, can be set later) config: Executor configuration (optional, uses defaults) @@ -111,7 +114,7 @@ def __init__( self._config = config or ExecutorConfig() self._cache: Dict[str, CachedResult] = {} self._stats = ExecutorStats() - + def execute( self, ctx: "AgentContext", @@ -119,17 +122,17 @@ def execute( arguments: dict[str, Any], ) -> ToolResult: """Execute a tool by name. 
- + Args: ctx: Agent context with shell() method name: Tool name arguments: Tool arguments - + Returns: ToolResult from the tool execution """ start_time = time.time() - + # Check cache first if enabled if self._config.cache_enabled: cache_key = self._cache_key(name, arguments) @@ -138,9 +141,9 @@ def execute( duration_ms = int((time.time() - start_time) * 1000) self._record_execution(name, duration_ms, success=True, cached=True) return cached - - cwd = Path(ctx.cwd) if hasattr(ctx, 'cwd') else self.cwd - + + cwd = Path(ctx.cwd) if hasattr(ctx, "cwd") else self.cwd + try: if name == "shell_command": result = self._execute_shell(ctx, cwd, arguments) @@ -160,21 +163,21 @@ def execute( result = self._execute_update_plan(arguments) else: result = ToolResult.fail(f"Unknown tool: {name}") - + except Exception as e: result = ToolResult.fail(f"Tool {name} failed: {e}") - + # Record execution stats duration_ms = int((time.time() - start_time) * 1000) self._record_execution(name, duration_ms, success=result.success, cached=False) - + # Cache successful results if self._config.cache_enabled and result.success: cache_key = self._cache_key(name, arguments) self._cache_result(cache_key, result) - + return result - + def _execute_shell( self, ctx: "AgentContext", @@ -185,18 +188,18 @@ def _execute_shell( command = args.get("command", "") workdir = args.get("workdir") timeout_ms = args.get("timeout_ms", 60000) - + if not command: return ToolResult.fail("No command provided") - + # Resolve working directory effective_cwd = cwd if workdir: wd = Path(workdir) effective_cwd = wd if wd.is_absolute() else cwd / wd - + timeout_sec = max(1, timeout_ms // 1000) - + try: result = subprocess.run( ["sh", "-c", command], @@ -205,19 +208,19 @@ def _execute_shell( text=True, timeout=timeout_sec, ) - + output = result.stdout if result.stderr: output += f"\n{result.stderr}" - + if result.returncode != 0: output += f"\n[exit code: {result.returncode}]" - + return ToolResult( success=result.returncode == 0, output=output.strip(), ) - + except subprocess.TimeoutExpired: return ToolResult( success=False, @@ -225,106 +228,106 @@ def _execute_shell( ) except Exception as e: return ToolResult.fail(str(e)) - + def _execute_read_file(self, cwd: Path, args: dict[str, Any]) -> ToolResult: """Read file contents.""" file_path = args.get("file_path", "") offset = args.get("offset", 1) limit = args.get("limit", 2000) - + if not file_path: return ToolResult.fail("No file_path provided") - + path = Path(file_path) if not path.is_absolute(): path = cwd / path - + if not path.exists(): return ToolResult.fail(f"File not found: {path}") - + if not path.is_file(): return ToolResult.fail(f"Not a file: {path}") - + try: with open(path, "r", encoding="utf-8", errors="replace") as f: lines = f.readlines() - + # Apply offset and limit (1-indexed) start = max(0, offset - 1) end = start + limit selected = lines[start:end] - + # Format with line numbers output_lines = [] for i, line in enumerate(selected, start=start + 1): output_lines.append(f"L{i}: {line.rstrip()}") - + output = "\n".join(output_lines) - + if len(lines) > end: output += f"\n\n[... 
{len(lines) - end} more lines ...]" - + return ToolResult.ok(output) - + except Exception as e: return ToolResult.fail(f"Failed to read file: {e}") - + def _execute_write_file(self, cwd: Path, args: dict[str, Any]) -> ToolResult: """Write content to a file.""" file_path = args.get("file_path", "") content = args.get("content", "") - + if not file_path: return ToolResult.fail("No file_path provided") - + path = Path(file_path) if not path.is_absolute(): path = cwd / path - + try: # Ensure parent directory exists path.parent.mkdir(parents=True, exist_ok=True) - + with open(path, "w", encoding="utf-8") as f: f.write(content) - + return ToolResult.ok(f"Wrote {len(content)} bytes to {path}") - + except Exception as e: return ToolResult.fail(f"Failed to write file: {e}") - + def _execute_list_dir(self, cwd: Path, args: dict[str, Any]) -> ToolResult: """List directory contents.""" dir_path = args.get("dir_path", ".") depth = args.get("depth", 2) limit = args.get("limit", 50) - + path = Path(dir_path) if not path.is_absolute(): path = cwd / path - + if not path.exists(): return ToolResult.fail(f"Directory not found: {path}") - + if not path.is_dir(): return ToolResult.fail(f"Not a directory: {path}") - + try: entries = [] self._list_recursive(path, path, entries, depth, limit) - + if not entries: return ToolResult.ok("(empty directory)") - + output = "\n".join(entries[:limit]) if len(entries) > limit: output += f"\n\n[... {len(entries) - limit} more entries ...]" - + return ToolResult.ok(output) - + except Exception as e: return ToolResult.fail(f"Failed to list directory: {e}") - + def _list_recursive( self, base: Path, @@ -337,27 +340,29 @@ def _list_recursive( """Recursively list directory contents.""" if current_depth > max_depth or len(entries) >= max_entries: return - + try: items = sorted(current.iterdir(), key=lambda x: (not x.is_dir(), x.name.lower())) - + for item in items: if len(entries) >= max_entries: break - + rel_path = item.relative_to(base) - + if item.is_dir(): entries.append(f"{rel_path}/") - self._list_recursive(base, item, entries, max_depth, max_entries, current_depth + 1) + self._list_recursive( + base, item, entries, max_depth, max_entries, current_depth + 1 + ) elif item.is_symlink(): entries.append(f"{rel_path}@") else: entries.append(str(rel_path)) - + except PermissionError: pass - + def _execute_grep( self, ctx: "AgentContext", @@ -369,21 +374,21 @@ def _execute_grep( include = args.get("include", "") search_path = args.get("path", ".") limit = args.get("limit", 100) - + if not pattern: return ToolResult.fail("No pattern provided") - + # Build ripgrep command cmd_parts = ["rg", "-l", "--color=never"] - + if include: cmd_parts.extend(["-g", include]) - + cmd_parts.append(pattern) cmd_parts.append(search_path) - + cmd = " ".join(f'"{p}"' if " " in p else p for p in cmd_parts) - + try: result = subprocess.run( ["sh", "-c", cmd], @@ -392,52 +397,53 @@ def _execute_grep( text=True, timeout=30, ) - + files = [f for f in result.stdout.strip().split("\n") if f] - + if not files: return ToolResult.ok("No matches found") - + output = "\n".join(files[:limit]) if len(files) > limit: output += f"\n\n[... 
{len(files) - limit} more files ...]" - + return ToolResult.ok(output) - + except subprocess.TimeoutExpired: return ToolResult.fail("Search timed out") except Exception as e: return ToolResult.fail(f"Search failed: {e}") - + def _execute_apply_patch(self, cwd: Path, args: dict[str, Any]) -> ToolResult: """Apply a patch to files.""" patch = args.get("patch", "") - + if not patch: return ToolResult.fail("No patch provided") - + from src.tools.apply_patch import ApplyPatchTool - + tool = ApplyPatchTool(cwd) return tool.execute(patch=patch) - + def _execute_view_image(self, cwd: Path, args: dict[str, Any]) -> ToolResult: """View an image file.""" path = args.get("path", "") - + if not path: return ToolResult.fail("No path provided") - + from src.tools.view_image import view_image + return view_image(path, cwd) - + def _execute_update_plan(self, args: dict[str, Any]) -> ToolResult: """Update the task plan.""" steps = args.get("steps", []) explanation = args.get("explanation") - + self._plan = steps - + # Format plan for output lines = ["Plan updated:"] for i, step in enumerate(steps, 1): @@ -447,55 +453,56 @@ def _execute_update_plan(self, args: dict[str, Any]) -> ToolResult: "completed": "[x]", }.get(step.get("status", "pending"), "[ ]") lines.append(f" {status_icon} {i}. {step.get('description', '')}") - + if explanation: lines.append(f"\nReason: {explanation}") - + return ToolResult.ok("\n".join(lines)) - + # ------------------------------------------------------------------------- # Caching methods # ------------------------------------------------------------------------- - + def _cache_key(self, name: str, arguments: dict[str, Any]) -> str: """Generate a cache key for a tool call.""" args_json = json.dumps(arguments, sort_keys=True, default=str) content = f"{name}:{args_json}" return hashlib.sha256(content.encode()).hexdigest()[:32] - + def _get_cached(self, key: str) -> Optional[ToolResult]: """Get a cached result if valid.""" cached = self._cache.get(key) if cached is not None and cached.is_valid(self._config.cache_ttl): return cached.result return None - + def _cache_result(self, key: str, result: ToolResult) -> None: """Cache a tool result.""" self._cache[key] = CachedResult(result=result, cached_at=time.time()) - + # Evict old entries if cache is too large if len(self._cache) > 1000: self._evict_expired_cache() - + def _evict_expired_cache(self) -> None: """Remove expired entries from cache.""" now = time.time() expired_keys = [ - key for key, cached in self._cache.items() + key + for key, cached in self._cache.items() if not cached.is_valid(self._config.cache_ttl) ] for key in expired_keys: del self._cache[key] - + def clear_cache(self) -> None: """Clear the entire cache.""" self._cache.clear() - + # ------------------------------------------------------------------------- # Statistics methods # ------------------------------------------------------------------------- - + def _record_execution( self, tool_name: str, @@ -506,91 +513,93 @@ def _record_execution( """Record execution statistics.""" self._stats.total_executions += 1 self._stats.total_duration_ms += duration_ms - + if success: self._stats.successful_executions += 1 else: self._stats.failed_executions += 1 - + if cached: self._stats.cache_hits += 1 - + # Per-tool stats if tool_name not in self._stats.by_tool: self._stats.by_tool[tool_name] = ToolStats() - + tool_stats = self._stats.by_tool[tool_name] tool_stats.executions += 1 tool_stats.total_ms += duration_ms if success: tool_stats.successes += 1 - + def stats(self) -> 
ExecutorStats: """Get execution statistics.""" return self._stats - + # ------------------------------------------------------------------------- # Batch execution # ------------------------------------------------------------------------- - + def execute_batch( self, ctx: "AgentContext", calls: List[Tuple[str, dict]], ) -> List[ToolResult]: """Execute multiple tool calls in parallel. - + Args: ctx: Agent context with shell() method calls: List of (tool_name, arguments) tuples - + Returns: List of ToolResults in the same order as input calls """ if not calls: return [] - + # For single call, just execute directly if len(calls) == 1: name, args = calls[0] return [self.execute(ctx, name, args)] - + # Execute in parallel using ThreadPoolExecutor results: List[Optional[ToolResult]] = [None] * len(calls) - + with ThreadPoolExecutor(max_workers=self._config.max_concurrent) as executor: future_to_index = { executor.submit(self.execute, ctx, name, args): i for i, (name, args) in enumerate(calls) } - + for future in as_completed(future_to_index): index = future_to_index[future] try: results[index] = future.result() except Exception as e: results[index] = ToolResult.fail(f"Batch execution failed: {e}") - + # Ensure all results are filled (shouldn't happen, but just in case) return [r if r is not None else ToolResult.fail("No result") for r in results] - + def get_plan(self) -> list[dict[str, str]]: """Get the current plan.""" return self._plan.copy() - + def get_tools_for_llm(self) -> list: """Get tool specifications formatted for the LLM. - - Returns tools in OpenAI-compatible format for litellm. + + Returns tools in OpenAI-compatible format. """ specs = get_all_tools() tools = [] - + for spec in specs: - tools.append({ - "name": spec["name"], - "description": spec.get("description", ""), - "parameters": spec.get("parameters", {}), - }) - + tools.append( + { + "name": spec["name"], + "description": spec.get("description", ""), + "parameters": spec.get("parameters", {}), + } + ) + return tools diff --git a/src/tools/search_files.py b/src/tools/search_files.py index 658181f..e29dea7 100644 --- a/src/tools/search_files.py +++ b/src/tools/search_files.py @@ -5,49 +5,56 @@ import fnmatch import os from pathlib import Path -from typing import Any, Optional, List +from typing import Any, Optional -from .base import BaseTool, ToolResult, ToolMetadata +from .base import BaseTool, ToolMetadata, ToolResult class SearchFilesTool(BaseTool): """Tool to search for files using glob patterns.""" - + name = "search_files" description = "Search for files matching a glob pattern" - + # Default directories to skip - DEFAULT_SKIP_DIRS = {".git", "node_modules", "target", "__pycache__", ".venv", "venv", ".tox", "dist", "build"} - + DEFAULT_SKIP_DIRS = { + ".git", + "node_modules", + "target", + "__pycache__", + ".venv", + "venv", + ".tox", + "dist", + "build", + } + def execute( - self, - pattern: str, - path: str = ".", - content_pattern: Optional[str] = None, - **kwargs: Any + self, pattern: str, path: str = ".", content_pattern: Optional[str] = None, **kwargs: Any ) -> ToolResult: """Search for files matching a pattern. 
- + Args: pattern: Glob pattern to match files (e.g., "*.py", "**/*.js") path: Base path to search from content_pattern: Optional regex pattern to match file contents - + Returns: ToolResult with list of matching file paths """ - import time import re + import time + start_time = time.time() - + resolved_path = self.resolve_path(path) - + if not resolved_path.exists(): return ToolResult.fail(f"Path not found: {path}") - + if not resolved_path.is_dir(): return ToolResult.fail(f"Not a directory: {path}") - + # Compile content pattern if provided content_regex = None if content_pattern: @@ -55,47 +62,46 @@ def execute( content_regex = re.compile(content_pattern) except re.error as e: return ToolResult.fail(f"Invalid content pattern: {e}") - + matches = [] - + try: # Walk the directory tree for root, dirs, files in os.walk(resolved_path): # Skip hidden directories and default skip dirs dirs[:] = [ - d for d in dirs - if not d.startswith(".") and d not in self.DEFAULT_SKIP_DIRS + d for d in dirs if not d.startswith(".") and d not in self.DEFAULT_SKIP_DIRS ] - + root_path = Path(root) - + for filename in files: # Skip hidden files if filename.startswith("."): continue - + file_path = root_path / filename relative_path = file_path.relative_to(resolved_path) - + # Check glob pattern match if not self._match_glob(str(relative_path), pattern): continue - + # Check content pattern if provided if content_regex: if not self._match_content(file_path, content_regex): continue - + matches.append(str(relative_path)) - + # Sort matches matches.sort() - + if not matches: output = f"No files found matching pattern '{pattern}'" else: output = "\n".join(matches) - + duration_ms = int((time.time() - start_time) * 1000) metadata = ToolMetadata( duration_ms=duration_ms, @@ -104,20 +110,20 @@ def execute( "base_path": str(resolved_path), "matches": matches, "count": len(matches), - } + }, ) - + result = ToolResult.ok(output) return result.with_metadata(metadata) - + except PermissionError: return ToolResult.fail(f"Permission denied while searching: {path}") except Exception as e: return ToolResult.fail(f"Error searching files: {e}") - + def _match_glob(self, filepath: str, pattern: str) -> bool: """Match a filepath against a glob pattern. - + Supports: - * matches any characters except path separator - ? matches exactly one character @@ -126,7 +132,7 @@ def _match_glob(self, filepath: str, pattern: str) -> bool: # Normalize path separators filepath = filepath.replace("\\", "/") pattern = pattern.replace("\\", "/") - + # Handle ** pattern (recursive matching) if "**" in pattern: # Split pattern at ** @@ -135,19 +141,19 @@ def _match_glob(self, filepath: str, pattern: str) -> bool: prefix, suffix = parts prefix = prefix.rstrip("/") suffix = suffix.lstrip("/") - + # Check prefix if it exists if prefix and not filepath.startswith(prefix): return False - + # Check suffix against any part of the path if suffix: return fnmatch.fnmatch(filepath, f"*{suffix}") return True - + # Simple glob matching for * and ? 
return fnmatch.fnmatch(filepath, pattern) - + def _match_content(self, file_path: Path, regex: Any) -> bool: """Check if file content matches the regex pattern.""" try: @@ -155,7 +161,7 @@ def _match_content(self, file_path: Path, regex: Any) -> bool: return bool(regex.search(content)) except Exception: return False - + @classmethod def get_spec(cls) -> dict[str, Any]: """Get the tool specification for the LLM.""" @@ -167,18 +173,18 @@ def get_spec(cls) -> dict[str, Any]: "properties": { "pattern": { "type": "string", - "description": "Glob pattern to match files (e.g., '*.py', '**/*.js')" + "description": "Glob pattern to match files (e.g., '*.py', '**/*.js')", }, "path": { "type": "string", "description": "Base path to search from", - "default": "." + "default": ".", }, "content_pattern": { "type": "string", - "description": "Optional regex pattern to match file contents" - } + "description": "Optional regex pattern to match file contents", + }, }, - "required": ["pattern"] - } + "required": ["pattern"], + }, } diff --git a/src/tools/shell.py b/src/tools/shell.py index b1edf3d..8240826 100644 --- a/src/tools/shell.py +++ b/src/tools/shell.py @@ -5,38 +5,37 @@ import os import platform import subprocess -import sys from pathlib import Path -from typing import Any, Optional +from typing import Optional from src.tools.base import BaseTool, ToolResult class ShellCommandTool(BaseTool): """Tool for executing shell commands.""" - + name = "shell_command" description = "Runs a shell command and returns its output." - + # Default timeout in milliseconds DEFAULT_TIMEOUT_MS = 30000 - + # Maximum output size MAX_OUTPUT_SIZE = 100000 # 100KB - + def __init__(self, cwd: Path, timeout_ms: int = DEFAULT_TIMEOUT_MS): """Initialize the shell command tool. - + Args: cwd: Working directory timeout_ms: Default timeout in milliseconds """ super().__init__(cwd) self.default_timeout_ms = timeout_ms - + def _get_shell(self) -> tuple[str, list[str]]: """Get the shell and shell arguments for the current platform. - + Returns: Tuple of (shell executable, shell arguments) """ @@ -47,7 +46,7 @@ def _get_shell(self) -> tuple[str, list[str]]: # Use bash on Unix with login shell shell = os.environ.get("SHELL", "/bin/bash") return shell, ["-lc"] - + def execute( self, command: str, @@ -55,12 +54,12 @@ def execute( timeout_ms: Optional[int] = None, ) -> ToolResult: """Execute a shell command. - + Args: command: The command to execute workdir: Working directory (defaults to cwd) timeout_ms: Timeout in milliseconds - + Returns: ToolResult with command output """ @@ -69,16 +68,16 @@ def execute( work_path = self.resolve_path(workdir) else: work_path = self.cwd - + if not work_path.exists(): return ToolResult.fail(f"Working directory does not exist: {work_path}") - + # Get timeout timeout_s = (timeout_ms or self.default_timeout_ms) / 1000 - + # Build command shell, shell_args = self._get_shell() - + try: # Run the command result = subprocess.run( @@ -89,40 +88,46 @@ def execute( timeout=timeout_s, env={**os.environ, "TERM": "dumb"}, # Disable color codes ) - + # Combine stdout and stderr output_parts = [] - + if result.stdout: stdout = result.stdout if len(stdout) > self.MAX_OUTPUT_SIZE: - stdout = stdout[:self.MAX_OUTPUT_SIZE] + "\n... (output truncated)" + stdout = stdout[: self.MAX_OUTPUT_SIZE] + "\n... (output truncated)" output_parts.append(stdout) - + if result.stderr: stderr = result.stderr if len(stderr) > self.MAX_OUTPUT_SIZE: - stderr = stderr[:self.MAX_OUTPUT_SIZE] + "\n... 
(stderr truncated)" + stderr = stderr[: self.MAX_OUTPUT_SIZE] + "\n... (stderr truncated)" if output_parts: output_parts.append(f"\nstderr:\n{stderr}") else: output_parts.append(stderr) - + output = "".join(output_parts).strip() - + # Add exit code info if non-zero if result.returncode != 0: - output = f"{output}\n\nExit code: {result.returncode}" if output else f"Exit code: {result.returncode}" - + output = ( + f"{output}\n\nExit code: {result.returncode}" + if output + else f"Exit code: {result.returncode}" + ) + if not output: output = "(no output)" - + # Return result based on exit code if result.returncode == 0: return ToolResult.ok(output) else: - return ToolResult.ok(output) # Still "ok" - we return the output even on non-zero exit - + return ToolResult.ok( + output + ) # Still "ok" - we return the output even on non-zero exit + except subprocess.TimeoutExpired: return ToolResult.fail( f"Command timed out after {timeout_s}s", diff --git a/src/tools/specs.py b/src/tools/specs.py index fa4381d..69140e9 100644 --- a/src/tools/specs.py +++ b/src/tools/specs.py @@ -243,7 +243,7 @@ def get_all_tools() -> list[dict[str, Any]]: """Get all tool specifications as a list. - + Returns: List of tool specification dicts """ @@ -252,10 +252,10 @@ def get_all_tools() -> list[dict[str, Any]]: def get_tool_spec(name: str) -> dict[str, Any] | None: """Get a specific tool specification. - + Args: name: Name of the tool - + Returns: Tool specification dict or None if not found """ diff --git a/src/tools/view_image.py b/src/tools/view_image.py index 65398e5..ac65d48 100644 --- a/src/tools/view_image.py +++ b/src/tools/view_image.py @@ -10,41 +10,41 @@ from pathlib import Path from typing import Any, Dict, Optional, Tuple -from src.tools.base import ToolResult from src.images.loader import load_image_as_data_uri, make_image_content +from src.tools.base import ToolResult def get_image_dimensions(data: bytes) -> Optional[Tuple[int, int]]: """Parse image dimensions from raw bytes without PIL.""" if len(data) < 24: return None - + # PNG: signature 0x89 PNG, dimensions at offset 16-23 - if data[:4] == b'\x89PNG' and len(data) >= 24: - width = int.from_bytes(data[16:20], 'big') - height = int.from_bytes(data[20:24], 'big') + if data[:4] == b"\x89PNG" and len(data) >= 24: + width = int.from_bytes(data[16:20], "big") + height = int.from_bytes(data[20:24], "big") return (width, height) - + # JPEG: signature 0xFF 0xD8 0xFF, parse SOF markers - if data[:3] == b'\xff\xd8\xff': + if data[:3] == b"\xff\xd8\xff": return _parse_jpeg_dimensions(data) - + # GIF: signature GIF87a or GIF89a, dimensions at offset 6-9 (little-endian) - if data[:6] in (b'GIF87a', b'GIF89a') and len(data) >= 10: - width = int.from_bytes(data[6:8], 'little') - height = int.from_bytes(data[8:10], 'little') + if data[:6] in (b"GIF87a", b"GIF89a") and len(data) >= 10: + width = int.from_bytes(data[6:8], "little") + height = int.from_bytes(data[8:10], "little") return (width, height) - + # BMP: signature BM, dimensions at offset 18-25 (little-endian, signed) - if data[:2] == b'BM' and len(data) >= 26: - width = abs(int.from_bytes(data[18:22], 'little', signed=True)) - height = abs(int.from_bytes(data[22:26], 'little', signed=True)) + if data[:2] == b"BM" and len(data) >= 26: + width = abs(int.from_bytes(data[18:22], "little", signed=True)) + height = abs(int.from_bytes(data[22:26], "little", signed=True)) return (width, height) - + # WebP: RIFF....WEBP - if len(data) >= 30 and data[:4] == b'RIFF' and data[8:12] == b'WEBP': + if len(data) >= 30 and 
data[:4] == b"RIFF" and data[8:12] == b"WEBP": return _parse_webp_dimensions(data) - + return None @@ -55,44 +55,43 @@ def _parse_jpeg_dimensions(data: bytes) -> Optional[Tuple[int, int]]: if data[i] != 0xFF: i += 1 continue - + marker = data[i + 1] - + # SOF markers: C0, C1, C2, C3, C5, C6, C7, C9, CA, CB, CD, CE, CF - if marker in (0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7, - 0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF): + if marker in (0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7, 0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF): if i + 9 < len(data): - height = int.from_bytes(data[i+5:i+7], 'big') - width = int.from_bytes(data[i+7:i+9], 'big') + height = int.from_bytes(data[i + 5 : i + 7], "big") + width = int.from_bytes(data[i + 7 : i + 9], "big") return (width, height) - + # Skip to next marker if marker in (0xFF, 0x00, 0x01) or 0xD0 <= marker <= 0xD9: i += 2 elif i + 3 < len(data): - length = int.from_bytes(data[i+2:i+4], 'big') + length = int.from_bytes(data[i + 2 : i + 4], "big") i += 2 + length else: break - + return None def _parse_webp_dimensions(data: bytes) -> Optional[Tuple[int, int]]: """Parse WebP dimensions (VP8 and VP8L formats).""" # VP8 format - if data[12:16] == b'VP8 ' and len(data) >= 30: - width = (int.from_bytes(data[26:28], 'little') & 0x3FFF) - height = (int.from_bytes(data[28:30], 'little') & 0x3FFF) + if data[12:16] == b"VP8 " and len(data) >= 30: + width = int.from_bytes(data[26:28], "little") & 0x3FFF + height = int.from_bytes(data[28:30], "little") & 0x3FFF return (width, height) - + # VP8L format - if data[12:16] == b'VP8L' and len(data) >= 25: + if data[12:16] == b"VP8L" and len(data) >= 25: b0, b1, b2, b3 = data[21], data[22], data[23], data[24] width = ((b1 & 0x3F) << 8 | b0) + 1 height = ((b3 & 0x0F) << 10 | b2 << 2 | (b1 >> 6)) + 1 return (width, height) - + return None @@ -102,11 +101,11 @@ def view_image( ) -> ToolResult: """ Load a local image and return it for the model context. 
- + Args: file_path: Path to the image file (relative or absolute) cwd: Current working directory - + Returns: ToolResult with success status and optional image content """ @@ -115,20 +114,20 @@ def view_image( if not path.is_absolute(): path = cwd / path path = path.resolve() - + # Check if file exists if not path.exists(): return ToolResult( success=False, output=f"Image not found: {path}", ) - + if not path.is_file(): return ToolResult( success=False, output=f"Not a file: {path}", ) - + # Check if it's an image file valid_extensions = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} if path.suffix.lower() not in valid_extensions: @@ -136,31 +135,31 @@ def view_image( success=False, output=f"Not a valid image file: {path} (supported: {', '.join(valid_extensions)})", ) - + try: # Read raw bytes first to get dimensions image_data = path.read_bytes() dimensions = get_image_dimensions(image_data) - + # Load and encode the image data_uri = load_image_as_data_uri(path) - + # Create content block for injection image_content = make_image_content(data_uri) - + # Build output message with dimensions if available if dimensions: width, height = dimensions output_msg = f"attached local image: {path.name} ({width}x{height})" else: output_msg = f"attached local image: {path.name}" - + return ToolResult( success=True, output=output_msg, inject_content=image_content, ) - + except FileNotFoundError: return ToolResult( success=False, diff --git a/src/tools/write_file.py b/src/tools/write_file.py index 6a0bb1a..8aa9696 100644 --- a/src/tools/write_file.py +++ b/src/tools/write_file.py @@ -2,54 +2,48 @@ from __future__ import annotations -import os -from pathlib import Path from typing import Any -from .base import BaseTool, ToolResult, ToolMetadata +from .base import BaseTool, ToolMetadata, ToolResult class WriteFileTool(BaseTool): """Tool to write content to a file.""" - + name = "write_file" description = "Write content to a file, creating parent directories if needed" - - def execute( - self, - file_path: str, - content: str, - **kwargs: Any - ) -> ToolResult: + + def execute(self, file_path: str, content: str, **kwargs: Any) -> ToolResult: """Write content to a file. - + Args: file_path: Path to the file to write content: Content to write to the file - + Returns: ToolResult with write status and metadata """ import time + start_time = time.time() - + resolved_path = self.resolve_path(file_path) - + try: # Create parent directories if they don't exist resolved_path.parent.mkdir(parents=True, exist_ok=True) - + # Write the content resolved_path.write_text(content, encoding="utf-8") - + # Get file stats file_size = resolved_path.stat().st_size - + # Create content preview (max 500 chars) content_preview = content[:500] if len(content) > 500: content_preview += "..." 
- + duration_ms = int((time.time() - start_time) * 1000) metadata = ToolMetadata( duration_ms=duration_ms, @@ -60,17 +54,17 @@ def execute( "extension": resolved_path.suffix, "size": file_size, "content_preview": content_preview, - } + }, ) - + result = ToolResult.ok(f"Successfully wrote {file_size} bytes to {file_path}") return result.with_metadata(metadata) - + except PermissionError: return ToolResult.fail(f"Permission denied: {file_path}") except Exception as e: return ToolResult.fail(f"Error writing file: {e}") - + @classmethod def get_spec(cls) -> dict[str, Any]: """Get the tool specification for the LLM.""" @@ -80,15 +74,9 @@ def get_spec(cls) -> dict[str, Any]: "parameters": { "type": "object", "properties": { - "file_path": { - "type": "string", - "description": "Path to the file to write" - }, - "content": { - "type": "string", - "description": "Content to write to the file" - } + "file_path": {"type": "string", "description": "Path to the file to write"}, + "content": {"type": "string", "description": "Content to write to the file"}, }, - "required": ["file_path", "content"] - } + "required": ["file_path", "content"], + }, } diff --git a/src/utils/__init__.py b/src/utils/__init__.py index fe7c9f6..62930e0 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,33 +1,30 @@ """Utility functions.""" # Legacy API (simple interface) +# Full fabric-core API from src.utils.truncate import ( - limit_output, - limit_lines, - smart_truncate, - limit_output_bytes, - truncate_output, - estimate_tokens, APPROX_BYTES_PER_TOKEN, DEFAULT_MAX_TOKENS, -) - -# Full fabric-core API -from src.utils.truncate import ( - TruncateStrategy, - TruncateConfig, - TruncateResult, TokenEstimator, TruncateBuilder, + TruncateConfig, + TruncateResult, + TruncateStrategy, + estimate_tokens, + limit_lines, + limit_output, + limit_output_bytes, + smart_truncate, truncate, - truncate_file, truncate_batch, + truncate_file, + truncate_output, ) __all__ = [ # Legacy "limit_output", - "limit_lines", + "limit_lines", "smart_truncate", "limit_output_bytes", "truncate_output", diff --git a/src/utils/files.py b/src/utils/files.py index 1516d55..141e918 100644 --- a/src/utils/files.py +++ b/src/utils/files.py @@ -2,37 +2,36 @@ from __future__ import annotations -import os from pathlib import Path from typing import Union def resolve_path(path: Union[str, Path], cwd: Optional[Path] = None) -> Path: """Resolve a path relative to CWD. - + Args: path: Path to resolve cwd: Current working directory (defaults to os.getcwd()) - + Returns: Resolved absolute path """ if cwd is None: cwd = Path.cwd() - + p = Path(path) if p.is_absolute(): return p.resolve() - + return (cwd / p).resolve() def is_binary_file(path: Path) -> bool: """Check if a file is binary. - + Args: path: Path to file - + Returns: True if file appears to be binary """ @@ -46,28 +45,28 @@ def is_binary_file(path: Path) -> bool: def read_file_safely(path: Path, max_size: int = 10 * 1024 * 1024) -> str: """Read a file safely with size limit. 
- + Args: path: Path to file max_size: Maximum size in bytes - + Returns: File content - + Raises: ValueError: If file is too large or binary """ if not path.exists(): raise FileNotFoundError(f"File not found: {path}") - + if not path.is_file(): raise ValueError(f"Not a file: {path}") - + size = path.stat().st_size if size > max_size: raise ValueError(f"File too large: {size} bytes (max {max_size})") - + if is_binary_file(path): raise ValueError("File appears to be binary") - + return path.read_text(encoding="utf-8", errors="replace") diff --git a/src/utils/tokens.py b/src/utils/tokens.py index d57ab4d..0132969 100644 --- a/src/utils/tokens.py +++ b/src/utils/tokens.py @@ -5,17 +5,17 @@ def estimate_tokens(text: str) -> int: """Estimate the number of tokens in a string. - + This uses a simple heuristic (4 chars per token) which is commonly used as a rough approximation for English text when a tokenizer isn't available. - + Args: text: Input text - + Returns: Estimated token count """ if not text: return 0 - + return len(text) // 4 diff --git a/src/utils/truncate.py b/src/utils/truncate.py index 717219f..89e6c09 100644 --- a/src/utils/truncate.py +++ b/src/utils/truncate.py @@ -17,7 +17,7 @@ class TruncateStrategy(Enum): """Truncation strategy.""" - + # Truncate from the end (keep beginning) END = "end" # Truncate from the beginning (keep end) @@ -33,7 +33,7 @@ class TruncateStrategy(Enum): @dataclass class TruncateConfig: """Truncation configuration.""" - + # Maximum length in characters max_chars: int = 10000 # Maximum length in tokens (approximate) @@ -59,7 +59,7 @@ class TruncateConfig: @dataclass class TruncateResult: """Truncation result.""" - + # Resulting text text: str # Whether truncation occurred @@ -74,13 +74,13 @@ class TruncateResult: final_tokens: int # Strategy that was used strategy_used: TruncateStrategy - + def reduction_percent(self) -> float: """Get reduction percentage.""" if self.original_chars == 0: return 0.0 return (1.0 - (self.final_chars / self.original_chars)) * 100.0 - + def is_ok(self) -> bool: """Check if truncation was successful.""" return len(self.text) > 0 @@ -89,62 +89,62 @@ def is_ok(self) -> bool: def estimate_tokens(text: str) -> int: """ Estimate token count (rough approximation). - + Rough estimate: ~4 chars per token for English. This is a simplification - real tokenization varies by model. 
""" char_count = len(text) word_count = len(text.split()) - + # Average of character-based and word-based estimates return (char_count // 4 + word_count) // 2 class TokenEstimator: """More accurate token estimation with caching.""" - + def __init__(self, chars_per_token: float = 4.0): """Create a new estimator.""" self._cache: Dict[str, int] = {} self._chars_per_token = chars_per_token - + @classmethod def with_ratio(cls, chars_per_token: float) -> "TokenEstimator": """Create with custom ratio.""" return cls(chars_per_token=chars_per_token) - + def estimate(self, text: str) -> int: """Estimate tokens for text.""" # Create hash for caching text_hash = hashlib.md5(text.encode(), usedforsecurity=False).hexdigest() - + if text_hash in self._cache: return self._cache[text_hash] - + estimate = self._calculate(text) - + # Cache if not too large if len(self._cache) < 10000: self._cache[text_hash] = estimate - + return estimate - + def _calculate(self, text: str) -> int: """Calculate token estimate.""" char_count = len(text) return int((char_count / self._chars_per_token) + 0.5) # ceil-like rounding - + def calibrate(self, samples: List[Tuple[str, int]]) -> None: """Calibrate ratio based on actual token counts.""" if not samples: return - + total_chars = sum(len(text) for text, _ in samples) total_tokens = sum(tokens for _, tokens in samples) - + if total_tokens > 0: self._chars_per_token = total_chars / total_tokens - + def clear_cache(self) -> None: """Clear cache.""" self._cache.clear() @@ -154,12 +154,12 @@ def truncate(text: str, config: TruncateConfig) -> TruncateResult: """Truncate text according to configuration.""" original_len = len(text) original_tokens = estimate_tokens(text) - + # Check if truncation needed needs_truncation = original_len > config.max_chars if config.max_tokens is not None: needs_truncation = needs_truncation or (original_tokens > config.max_tokens) - + if not needs_truncation: return TruncateResult( text=text, @@ -170,7 +170,7 @@ def truncate(text: str, config: TruncateConfig) -> TruncateResult: final_tokens=original_tokens, strategy_used=config.strategy, ) - + # Apply truncation strategy if config.strategy == TruncateStrategy.END: truncated_text = truncate_end(text, config) @@ -184,10 +184,10 @@ def truncate(text: str, config: TruncateConfig) -> TruncateResult: truncated_text = _truncate_summarize(text, config) else: truncated_text = truncate_end(text, config) - + final_len = len(truncated_text) final_tokens = estimate_tokens(truncated_text) - + return TruncateResult( text=truncated_text, truncated=True, @@ -202,12 +202,12 @@ def truncate(text: str, config: TruncateConfig) -> TruncateResult: def truncate_end(text: str, config: TruncateConfig) -> str: """Simple truncation from end.""" target_len = max(0, config.max_chars - len(config.suffix)) - + if len(text) <= target_len: return text - + end = target_len - + # Align to boundary if config.sentence_boundary: end = find_sentence_boundary(text, end, forward=False) @@ -215,19 +215,19 @@ def truncate_end(text: str, config: TruncateConfig) -> str: end = find_line_boundary(text, end, forward=False) elif config.word_boundary: end = find_word_boundary(text, end, forward=False) - + return f"{text[:end]}{config.suffix}" def truncate_start(text: str, config: TruncateConfig) -> str: """Truncation from start (keep end).""" target_len = max(0, config.max_chars - len(config.prefix)) - + if len(text) <= target_len: return text - + start = len(text) - target_len - + # Align to boundary if config.sentence_boundary: start = 
find_sentence_boundary(text, start, forward=True) @@ -235,7 +235,7 @@ def truncate_start(text: str, config: TruncateConfig) -> str: start = find_line_boundary(text, start, forward=True) elif config.word_boundary: start = find_word_boundary(text, start, forward=True) - + return f"{config.prefix}{text[start:]}" @@ -243,14 +243,14 @@ def truncate_middle(text: str, config: TruncateConfig) -> str: """Truncation from middle (keep both ends).""" separator = "\n\n[...content omitted...]\n\n" target_len = max(0, config.max_chars - len(separator)) - + if len(text) <= target_len: return text - + keep_each = target_len // 2 start_end = _find_boundary(text, keep_each, forward=True, config=config) end_start = len(text) - _find_boundary(text, keep_each, forward=False, config=config) - + return f"{text[:start_end]}{separator}{text[end_start:]}" @@ -260,7 +260,7 @@ def truncate_smart(text: str, config: TruncateConfig) -> str: has_code = "```" in text or " " in text has_lists = "\n- " in text or "\n* " in text or "\n1." in text has_headers = "\n#" in text or "\n==" in text - + # Choose strategy based on content if has_code and config.preserve_code: return truncate_preserve_code(text, config) @@ -276,15 +276,15 @@ def truncate_preserve_code(text: str, config: TruncateConfig) -> str: remaining = config.max_chars in_code_block = False code_block_content: List[str] = [] - - for line in text.split('\n'): + + for line in text.split("\n"): if line.startswith("```"): if in_code_block: # End of code block - add it if it fits code_block_content.append(line) - code_block_content.append('') # For the newline - - block_text = '\n'.join(code_block_content) + code_block_content.append("") # For the newline + + block_text = "\n".join(code_block_content) if len(block_text) <= remaining: result.append(block_text) remaining -= len(block_text) @@ -300,15 +300,15 @@ def truncate_preserve_code(text: str, config: TruncateConfig) -> str: line_len = len(line) + 1 # +1 for newline if line_len <= remaining: result.append(line) - result.append('') # For newline + result.append("") # For newline remaining -= line_len else: break - - result_text = '\n'.join(result) + + result_text = "\n".join(result) if len(result_text) < len(text): result_text += config.suffix - + return result_text @@ -318,17 +318,13 @@ def truncate_preserve_structure(text: str, config: TruncateConfig) -> str: remaining = max(0, config.max_chars - len(config.suffix)) current_section: List[str] = [] section_header = "" - - for line in text.split('\n'): - is_header = ( - line.startswith('#') or - line.startswith("==") or - line.startswith("--") - ) - + + for line in text.split("\n"): + is_header = line.startswith("#") or line.startswith("==") or line.startswith("--") + if is_header: # Flush previous section - section_content = '\n'.join(current_section) + section_content = "\n".join(current_section) total_section = section_header + section_content if current_section and len(total_section) <= remaining: result.append(total_section) @@ -337,21 +333,21 @@ def truncate_preserve_structure(text: str, config: TruncateConfig) -> str: current_section.clear() else: current_section.append(line) - + if remaining == 0: break - + # Add last section if it fits if current_section: - section_content = '\n'.join(current_section) + section_content = "\n".join(current_section) total_section = section_header + section_content if len(total_section) <= remaining: result.append(total_section) - - result_text = ''.join(result) + + result_text = "".join(result) if len(result_text) < len(text): 
result_text += config.suffix - + return result_text @@ -378,30 +374,30 @@ def find_word_boundary(text: str, pos: int, forward: bool) -> int: """Find word boundary near position.""" if pos >= len(text): return len(text) - + if forward: # Search forward for space or newline for i in range(pos, min(len(text), pos + 50)): - if text[i] in ' \n': + if text[i] in " \n": return i # Search backward if nothing found for i in range(max(0, pos - 50), pos): idx = pos - 1 - (i - max(0, pos - 50)) - if idx >= 0 and text[idx] in ' \n': + if idx >= 0 and text[idx] in " \n": return idx + 1 else: # Search backward for i in range(min(pos, len(text)) - 1, max(0, pos - 50) - 1, -1): - if text[i] in ' \n': + if text[i] in " \n": return i + 1 - + return pos def find_sentence_boundary(text: str, pos: int, forward: bool) -> int: """Find sentence boundary near position.""" sentence_ends = [". ", "! ", "? ", ".\n", "!\n", "?\n"] - + if forward: for i in range(pos, min(len(text), pos + 200)): for end in sentence_ends: @@ -412,19 +408,19 @@ def find_sentence_boundary(text: str, pos: int, forward: bool) -> int: for end in sentence_ends: if text[i:].startswith(end): return i + len(end) - + return find_word_boundary(text, pos, forward) def find_line_boundary(text: str, pos: int, forward: bool) -> int: """Find line boundary near position.""" if forward: - idx = text.find('\n', pos) + idx = text.find("\n", pos) if idx != -1: return idx + 1 return pos else: - idx = text.rfind('\n', 0, pos) + idx = text.rfind("\n", 0, pos) if idx != -1: return idx + 1 return 0 @@ -434,14 +430,14 @@ def truncate_file(content: str, file_type: str, max_chars: int) -> str: """Truncate file content intelligently.""" code_types = {"rs", "py", "js", "ts", "go", "c", "cpp", "java", "rb", "php"} markdown_types = {"md", "markdown"} - + config = TruncateConfig( max_chars=max_chars, preserve_code=file_type in code_types, preserve_markdown=file_type in markdown_types, strategy=TruncateStrategy.SMART, ) - + return truncate(content, config).text @@ -449,84 +445,84 @@ def truncate_batch(items: List[str], total_chars: int) -> List[str]: """Truncate multiple strings to fit total budget.""" if not items: return [] - + total_len = sum(len(s) for s in items) - + if total_len <= total_chars: return list(items) - + # Proportional allocation ratio = total_chars / total_len - + result: List[str] = [] for item in items: target = int(len(item) * ratio) config = TruncateConfig(max_chars=target) result.append(truncate(item, config).text) - + return result @dataclass class TruncateBuilder: """Builder for truncation configuration.""" - + _config: TruncateConfig = field(default_factory=TruncateConfig) - + def max_chars(self, max_val: int) -> "TruncateBuilder": """Set maximum characters.""" self._config.max_chars = max_val return self - + def max_tokens(self, max_val: int) -> "TruncateBuilder": """Set maximum tokens.""" self._config.max_tokens = max_val return self - + def strategy(self, strategy: TruncateStrategy) -> "TruncateBuilder": """Set strategy.""" self._config.strategy = strategy return self - + def suffix(self, suffix: str) -> "TruncateBuilder": """Set suffix.""" self._config.suffix = suffix return self - + def prefix(self, prefix: str) -> "TruncateBuilder": """Set prefix.""" self._config.prefix = prefix return self - + def word_boundary(self, enabled: bool) -> "TruncateBuilder": """Enable word boundary alignment.""" self._config.word_boundary = enabled return self - + def sentence_boundary(self, enabled: bool) -> "TruncateBuilder": """Enable sentence boundary 
alignment.""" self._config.sentence_boundary = enabled return self - + def line_boundary(self, enabled: bool) -> "TruncateBuilder": """Enable line boundary alignment.""" self._config.line_boundary = enabled return self - + def preserve_code(self, enabled: bool) -> "TruncateBuilder": """Preserve code blocks.""" self._config.preserve_code = enabled return self - + def preserve_markdown(self, enabled: bool) -> "TruncateBuilder": """Preserve markdown structure.""" self._config.preserve_markdown = enabled return self - + def build(self) -> TruncateConfig: """Build configuration.""" return self._config - + def truncate(self, text: str) -> TruncateResult: """Truncate text with built configuration.""" return truncate(text, self._config) @@ -546,6 +542,7 @@ def truncate(self, text: str) -> TruncateResult: @dataclass class LegacyTruncateResult: """Result of truncation operation (legacy).""" + text: str truncated: bool original_bytes: int @@ -560,17 +557,17 @@ def truncate_output( ) -> LegacyTruncateResult: """ Truncate output to max tokens, keeping head and tail. - + Matches Codex behavior: - Uses token-based (not byte-based) limits - Truncates middle, keeping equal head/tail - Format: "{N} tokens truncated" - Prepends "Total output lines: {N}" when truncated - + Args: output: The output string to truncate max_tokens: Maximum tokens to keep (default: 2500 = ~10KB) - + Returns: LegacyTruncateResult with truncated text and metadata """ @@ -583,12 +580,12 @@ def truncate_output( tokens_truncated=0, total_lines=0, ) - + output_bytes = output.encode("utf-8") original_bytes = len(output_bytes) original_tokens = original_bytes // APPROX_BYTES_PER_TOKEN - total_lines = output.count('\n') + (1 if output and not output.endswith('\n') else 0) - + total_lines = output.count("\n") + (1 if output and not output.endswith("\n") else 0) + if original_tokens <= max_tokens: return LegacyTruncateResult( text=output, @@ -598,35 +595,35 @@ def truncate_output( tokens_truncated=0, total_lines=total_lines, ) - + # Calculate bytes to keep (convert tokens back to bytes) max_bytes = max_tokens * APPROX_BYTES_PER_TOKEN - + # Split evenly between head and tail head_bytes = max_bytes // 2 tail_bytes = max_bytes - head_bytes - + # Get head portion head_raw = output_bytes[:head_bytes] - # Get tail portion + # Get tail portion tail_raw = output_bytes[-tail_bytes:] - + # Decode, handling UTF-8 boundary issues head = head_raw.decode("utf-8", errors="ignore") tail = tail_raw.decode("utf-8", errors="ignore") - + # Calculate truncated tokens kept_bytes = len(head.encode()) + len(tail.encode()) tokens_truncated = (original_bytes - kept_bytes) // APPROX_BYTES_PER_TOKEN - + # Build truncation message matching Codex format truncation_msg = f"\n...{tokens_truncated} tokens truncated...\n" - + # Prepend total lines info lines_prefix = f"Total output lines: {total_lines}\n" - + truncated_text = f"{lines_prefix}{head}{truncation_msg}{tail}" - + return LegacyTruncateResult( text=truncated_text, truncated=True, @@ -643,11 +640,11 @@ def limit_output( ) -> str: """ Simple interface: truncate and return just the text. - + Args: output: The output string to truncate max_tokens: Maximum tokens to keep - + Returns: Truncated string """ @@ -661,33 +658,33 @@ def limit_lines( ) -> str: """ Limit output to max lines, keeping first and last portions. 
- + Args: output: The output string to truncate max_lines: Maximum lines to keep head_lines: Number of lines to keep from the start - + Returns: Truncated string with message if truncated """ if not output: return output - + lines = output.splitlines(keepends=True) total_lines = len(lines) - + if total_lines <= max_lines: return output - + tail_lines = max_lines - head_lines omitted = total_lines - max_lines - + head = "".join(lines[:head_lines]) tail = "".join(lines[-tail_lines:]) if tail_lines > 0 else "" - + # Match Codex message format truncation_msg = f"\n...{omitted} lines omitted...\n" - + return f"Total output lines: {total_lines}\n{head}{truncation_msg}{tail}" @@ -698,21 +695,21 @@ def smart_truncate( ) -> str: """ Smart truncation: applies both token and line limits. - + Args: output: The output to truncate max_tokens: Maximum tokens max_lines: Maximum lines - + Returns: Truncated output """ # First limit by lines (faster check) result = limit_lines(output, max_lines) - + # Then limit by tokens result = limit_output(result, max_tokens) - + return result @@ -729,48 +726,48 @@ def middle_out_truncate( ) -> str: """ Middle-out truncation like Codex. - + Keeps beginning and end, removes middle. More useful than head-only because: - Beginning often has context/headers - End often has results/conclusions - + Args: text: Text to truncate max_tokens: Maximum tokens to keep - + Returns: Truncated text with marker in middle """ if not text: return text - + text_bytes = text.encode("utf-8") original_bytes = len(text_bytes) original_tokens = original_bytes // APPROX_BYTES_PER_TOKEN - + if original_tokens <= max_tokens: return text - + # Calculate bytes to keep max_bytes = max_tokens * APPROX_BYTES_PER_TOKEN - + # Split 50/50 between head and tail head_bytes = max_bytes // 2 tail_bytes = max_bytes - head_bytes - + # Extract portions head_raw = text_bytes[:head_bytes] tail_raw = text_bytes[-tail_bytes:] - + # Decode safely (handle UTF-8 boundary issues) head = head_raw.decode("utf-8", errors="ignore") tail = tail_raw.decode("utf-8", errors="ignore") - + # Calculate removed tokens kept_bytes = len(head.encode()) + len(tail.encode()) removed_tokens = (original_bytes - kept_bytes) // APPROX_BYTES_PER_TOKEN - + return f"{head}\n\n...{removed_tokens} tokens truncated...\n\n{tail}"
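
Usage sketch for the truncation helpers above — a minimal example assuming the module layout introduced in this patch (`src/utils/truncate.py`); `raw_output` is a placeholder value, not part of the codebase:

```python
from src.utils.truncate import (
    TruncateConfig,
    TruncateStrategy,
    limit_output,
    truncate,
)

raw_output = "line of build output\n" * 20_000

# Simple interface: token-based cap that keeps the head and tail of the output.
capped = limit_output(raw_output, max_tokens=2500)

# Structured interface: explicit configuration, mirroring truncate_file() above.
config = TruncateConfig(
    max_chars=4_000,
    strategy=TruncateStrategy.SMART,
    preserve_code=True,
)
result = truncate(raw_output, config)
print(result.truncated, f"{result.reduction_percent():.0f}% reduced")
```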