diff --git a/pyproject.toml b/pyproject.toml
index 1559420c..579fab24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,10 @@ langgraph = [
     "langchain-openai>=0.3.27",
 ]
 
+smolagents = [
+    "smolagents>=1.22.0",
+]
+
 [project.scripts]
 art = "art.cli:app"
 stop-server = "art.skypilot.stop_server:main"
diff --git a/src/art/smolagents/__init__.py b/src/art/smolagents/__init__.py
new file mode 100644
index 00000000..d1066a30
--- /dev/null
+++ b/src/art/smolagents/__init__.py
@@ -0,0 +1,3 @@
+from .llm_wrapper import init_chat_model, wrap_rollout
+
+__all__ = ["wrap_rollout", "init_chat_model"]
diff --git a/src/art/smolagents/llm_wrapper.py b/src/art/smolagents/llm_wrapper.py
new file mode 100644
index 00000000..f418efb7
--- /dev/null
+++ b/src/art/smolagents/llm_wrapper.py
@@ -0,0 +1,190 @@
+"""LLM wrapper with logging functionality for smolagents."""
+
+import contextvars
+import functools
+import os
+import uuid
+from typing import Any
+
+from smolagents import ChatMessage, MessageRole, Model, Tool
+
+from art.trajectories import History, Trajectory
+
+from .logging import FileLogger
+from .message_utils import (
+    convert_smolagents_input_messages,
+    convert_smolagents_output_to_choice,
+)
+
+CURRENT_CONFIG = contextvars.ContextVar("CURRENT_CONFIG")
+
+
+def add_thread(thread_id, base_url, api_key, model):
+    """Register a logger and inference config for the current rollout."""
+    log_path = f".art/smolagents/{thread_id}"
+    os.makedirs(os.path.dirname(log_path), exist_ok=True)
+    CURRENT_CONFIG.set(
+        {
+            "logger": FileLogger(log_path),
+            "base_url": base_url,
+            "api_key": api_key,
+            "model": model,
+        }
+    )
+    return log_path
+
+
+def create_messages_from_logs(log_path: str, trajectory: Trajectory) -> Trajectory:
+    """Load logged model calls and fold them into the trajectory.
+
+    Calls that share the same non-tool-response message prefix are treated
+    as successive turns of one conversation; distinct prefixes become
+    additional histories on the trajectory.
+    """
+    logs = FileLogger(log_path).load_logs()
+
+    if not logs:
+        return trajectory
+
+    conversations = []
+    tools_list = []
+
+    for log_entry in logs:
+        try:
+            entry_data = log_entry[1]
+            input_msgs = entry_data["input"]  # list[ChatMessage]
+            output_msg = entry_data["output"]  # ChatMessage from the model
+            tools = entry_data["tools"]
+
+            new_conversation = {
+                "input": input_msgs,
+                "output": output_msg,
+            }
+
+            # Try to match this call against an existing conversation by
+            # comparing inputs
+            matched = False
+            for idx, existing in enumerate(conversations):
+                existing_input = existing["input"]
+                # Compare only non-TOOL_RESPONSE messages, since tool
+                # responses are appended between turns of the same run
+                existing_non_tool = [
+                    m for m in existing_input if m.role != MessageRole.TOOL_RESPONSE
+                ]
+                new_non_tool = [
+                    m for m in input_msgs if m.role != MessageRole.TOOL_RESPONSE
+                ]
+
+                if len(existing_non_tool) == len(new_non_tool) and all(
+                    e.content == n.content and e.role == n.role
+                    for e, n in zip(existing_non_tool, new_non_tool)
+                ):
+                    # Later log entries extend earlier ones, so keep the
+                    # newer (longer) conversation
+                    conversations[idx] = new_conversation
+                    tools_list[idx] = tools
+                    matched = True
+                    break
+
+            if not matched:
+                conversations.append(new_conversation)
+                tools_list.append(tools)
+        except Exception as e:
+            print(f"Warning: Failed to parse log entry: {e}")
+            continue
+
+    # Convert conversations to trajectory format
+    for idx, conv in enumerate(conversations):
+        try:
+            # Convert input messages
+            input_converted = convert_smolagents_input_messages(conv["input"])
+            # Convert output to Choice
+            output_choice = convert_smolagents_output_to_choice(conv["output"])
+            # Combine: input messages + output choice
+            converted = input_converted + [output_choice]
+
+            if idx == 0:
+                trajectory.messages_and_choices = converted
+                trajectory.tools = tools_list[idx]
+            else:
+                trajectory.additional_histories.append(
+                    History(messages_and_choices=converted, tools=tools_list[idx])
+                )
+        except Exception as e:
+            print(f"Warning: Failed to convert conversation {idx}: {e}")
+            continue
+
+    return trajectory
+
+
+def wrap_rollout(model, fn):
+    """Wrap an async rollout function so every model interaction is logged
+    and folded into the returned trajectory."""
+
+    @functools.wraps(fn)
+    async def wrapper(*args, **kwargs):
+        thread_id = str(uuid.uuid4())
+        log_path = add_thread(
+            thread_id,
+            model.inference_base_url,
+            model.inference_api_key,
+            model.inference_model_name,
+        )
+        result = await fn(*args, **kwargs)
+        return create_messages_from_logs(log_path, result)
+
+    return wrapper
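+
+# Illustrative usage sketch (not exercised by tests): build the agent's
+# model via `init_chat_model()` inside the wrapped rollout so that every
+# `generate()` call is captured. The `scenario` argument and the zero
+# reward below are hypothetical placeholders.
+#
+#     from smolagents import CodeAgent
+#
+#     async def rollout(scenario) -> Trajectory:
+#         agent = CodeAgent(tools=[], model=init_chat_model())
+#         agent.run(scenario.prompt)
+#         return Trajectory(messages_and_choices=[], reward=0.0)
+#
+#     trajectory = await wrap_rollout(model, rollout)(scenario)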
+
+
+def init_chat_model(**kwargs: Any) -> Model:
+    """Get a LoggingModel instance with the current config.
+
+    Must be called inside a rollout wrapped by `wrap_rollout`, which sets
+    CURRENT_CONFIG; otherwise `CURRENT_CONFIG.get()` raises LookupError.
+    """
+    config = CURRENT_CONFIG.get()
+
+    # Imported lazily: LiteLLMModel requires smolagents' optional litellm
+    # dependency
+    from smolagents import LiteLLMModel
+
+    base_model = LiteLLMModel(
+        model_id=config["model"],
+        api_base=config["base_url"],
+        api_key=config["api_key"],
+        **kwargs,
+    )
+
+    return LoggingModel(base_model, config["logger"])
+
+
+class LoggingModel(Model):
+    """A Model wrapper that logs all generate() calls."""
+
+    def __init__(self, base_model: Model, logger: FileLogger):
+        super().__init__(
+            flatten_messages_as_text=base_model.flatten_messages_as_text,
+            tool_name_key=base_model.tool_name_key,
+            tool_arguments_key=base_model.tool_arguments_key,
+            model_id=base_model.model_id,
+        )
+        self.base_model = base_model
+        self.logger = logger
+
+    def generate(
+        self,
+        messages: list[ChatMessage],
+        stop_sequences: list[str] | None = None,
+        response_format: dict[str, str] | None = None,
+        tools_to_call_from: list[Tool] | None = None,
+        **kwargs,
+    ) -> ChatMessage:
+        """Generate a response and log the full interaction."""
+        completion_id = str(uuid.uuid4())
+
+        # Record the available tools in OpenAI function-calling format
+        tools_for_log = None
+        if tools_to_call_from:
+            tools_for_log = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool.name,
+                        "description": tool.description,
+                        "parameters": tool.inputs,
+                    },
+                }
+                for tool in tools_to_call_from
+            ]
+
+        # Generate the response
+        result = self.base_model.generate(
+            messages,
+            stop_sequences=stop_sequences,
+            response_format=response_format,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        # Log the interaction
+        if self.logger:
+            entry = {
+                "input": messages,
+                "output": result,
+                "tools": tools_for_log,
+            }
+            self.logger.log(completion_id, entry)
+
+        return result
diff --git a/src/art/smolagents/logging.py b/src/art/smolagents/logging.py
new file mode 100644
index 00000000..4b50a953
--- /dev/null
+++ b/src/art/smolagents/logging.py
@@ -0,0 +1,30 @@
+import os
+import pickle
+
+
+class FileLogger:
+    """Appends entries to a human-readable text log and a pickle log.
+
+    The pickle file is the source of truth for `load_logs`; the text file
+    exists only for manual inspection.
+    """
+
+    def __init__(self, filepath):
+        self.text_path = filepath
+        self.pickle_path = filepath + ".pkl"
+
+    def log(self, name, entry):
+        # Log as readable text
+        with open(self.text_path, "a") as f:
+            f.write(f"{name}: {entry}\n")
+
+        # Append to the pickle log
+        with open(self.pickle_path, "ab") as pf:
+            pickle.dump((name, entry), pf)
+
+    def load_logs(self):
+        """Load all logs from the pickle file."""
+        if not os.path.exists(self.pickle_path):
+            return []
+        logs = []
+        with open(self.pickle_path, "rb") as pf:
+            try:
+                # A pickle file may hold multiple back-to-back dumps;
+                # read until the stream is exhausted
+                while True:
+                    logs.append(pickle.load(pf))
+            except EOFError:
+                pass
+        return logs
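+
+# Round-trip sketch (hypothetical path and values, for illustration only):
+# each `log` call appends one `(name, entry)` tuple, and `load_logs`
+# returns them in insertion order, assuming a fresh log file.
+#
+#     logger = FileLogger(".art/smolagents/example-thread")
+#     logger.log("abc123", {"input": [], "output": None, "tools": None})
+#     logger.load_logs()
+#     # [("abc123", {"input": [], "output": None, "tools": None})]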
diff --git a/src/art/smolagents/message_utils.py b/src/art/smolagents/message_utils.py
new file mode 100644
index 00000000..d6e88133
--- /dev/null
+++ b/src/art/smolagents/message_utils.py
@@ -0,0 +1,75 @@
+import json
+from typing import List, Union
+
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+from smolagents import ChatMessage
+from smolagents.models import get_clean_message_list, tool_role_conversions
+
+Message = ChatCompletionMessageParam
+MessagesAndChoices = List[Union[Message, Choice]]
+
+
+def create_choice_from_message(msg: ChatMessage) -> Choice:
+    """Convert a smolagents ChatMessage to OpenAI Choice format."""
+    tool_calls = None
+    if msg.tool_calls:
+        tool_calls = []
+        for tc in msg.tool_calls:
+            tool_calls.append(
+                {
+                    "id": tc.id,
+                    "type": "function",
+                    "function": {
+                        "name": tc.function.name,
+                        # OpenAI expects arguments as a JSON string
+                        "arguments": json.dumps(tc.function.arguments)
+                        if isinstance(tc.function.arguments, dict)
+                        else tc.function.arguments,
+                    },
+                }
+            )
+
+    # Extract content - handle both str and list[dict] formats
+    content = msg.content
+    if isinstance(content, list):
+        # Join the text parts of a structured content list
+        text_parts = [
+            item.get("text", "")
+            for item in content
+            if isinstance(item, dict) and item.get("type") == "text"
+        ]
+        content = "\n".join(text_parts) if text_parts else ""
+    elif content is None:
+        content = ""
+
+    return Choice(
+        message=ChatCompletionMessage(
+            role="assistant",
+            content=content,
+            tool_calls=tool_calls,
+        ),
+        index=0,
+        finish_reason="tool_calls" if tool_calls else "stop",
+        logprobs=None,
+    )
+
+
+def convert_smolagents_input_messages(messages: List[ChatMessage]) -> list:
+    """Convert input messages to OpenAI Message format.
+
+    Returns a list of dicts compatible with the OpenAI message format.
+    """
+    # Use smolagents' utility to convert messages to a clean dict format,
+    # mapping tool roles via tool_role_conversions
+    clean_messages = get_clean_message_list(
+        messages,  # type: ignore[arg-type]
+        role_conversions=tool_role_conversions,
+        convert_images_to_image_urls=False,
+        flatten_messages_as_text=False,
+    )
+
+    return clean_messages
+
+
+def convert_smolagents_output_to_choice(output: ChatMessage) -> Choice:
+    """Convert a model output message to OpenAI Choice format.
+
+    The output should be the result from Model.generate().
+    """
+    return create_choice_from_message(output)
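+
+# Conversion sketch (values illustrative): an assistant ChatMessage with a
+# tool call becomes an OpenAI Choice whose function arguments are
+# JSON-encoded and whose finish_reason reflects the tool call.
+#
+#     msg = ChatMessage(role="assistant", content="", tool_calls=[...])
+#     choice = convert_smolagents_output_to_choice(msg)
+#     choice.finish_reason   # "tool_calls"
+#     choice.message.role    # "assistant"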
+ """ + return create_choice_from_message(output) diff --git a/uv.lock b/uv.lock index cb58d75f..8d791a9b 100644 --- a/uv.lock +++ b/uv.lock @@ -4180,6 +4180,9 @@ skypilot = [ { name = "semver" }, { name = "skypilot", extra = ["cudo", "do", "gcp", "kubernetes", "runpod"] }, ] +smolagents = [ + { name = "smolagents" }, +] [package.dev-dependencies] dev = [ @@ -4219,6 +4222,7 @@ requires-dist = [ { name = "setproctitle", marker = "extra == 'backend'", specifier = ">=1.3.6" }, { name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" }, { name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "kubernetes", "paperspace", "runpod"], marker = "extra == 'skypilot'", specifier = "==0.10.2" }, + { name = "smolagents", marker = "extra == 'smolagents'", specifier = ">=1.22.0" }, { name = "tblib", marker = "extra == 'backend'", specifier = ">=3.0.0" }, { name = "torch", marker = "extra == 'backend'", specifier = ">=2.7.0" }, { name = "torchao", marker = "extra == 'backend'", specifier = ">=0.9.0" }, @@ -4233,7 +4237,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'backend'", specifier = "==0.21.0" }, { name = "weave", specifier = ">=0.51.51" }, ] -provides-extras = ["plotting", "backend", "skypilot", "langgraph"] +provides-extras = ["plotting", "backend", "skypilot", "langgraph", "smolagents"] [package.metadata.requires-dev] dev = [ @@ -6886,6 +6890,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, ] +[[package]] +name = "smolagents" +version = "1.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "jinja2" }, + { name = "pillow" }, + { name = "python-dotenv" }, + { name = "requests" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/bc/ad2f168b82d26597257adb071c51348dd94da0bc29a6cc10ba4c1bee27c8/smolagents-1.22.0.tar.gz", hash = "sha256:5fb66f48e3b3ab5e8defcef577a89d5b6dfa8fcb55fc98a58e156cb3c59eb68f", size = 213047, upload-time = "2025-09-25T08:42:56.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/17/54d7de27b7ac2722ac2c0f452da612bf4af80ae09f90231b44bb7b12b33d/smolagents-1.22.0-py3-none-any.whl", hash = "sha256:5334adb4e7e5814cd814f1d9ad7efa806ef57f53db40635a29d2bd727774c5f5", size = 149836, upload-time = "2025-09-25T08:42:54.205Z" }, +] + [[package]] name = "sniffio" version = "1.3.1"