Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chatlas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ._provider_portkey import ChatPortkey
from ._provider_snowflake import ChatSnowflake
from ._tokens import token_usage
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import Turn

try:
Expand Down Expand Up @@ -88,6 +88,7 @@
"Provider",
"token_usage",
"Tool",
"ToolBuiltIn",
"ToolRejectError",
"Turn",
"types",
Expand Down
59 changes: 45 additions & 14 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from ._mcp_manager import MCPSessionManager
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
from ._tokens import compute_cost, get_token_pricing, tokens_log
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict, TypeGuard
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
self.system_prompt = system_prompt
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}

self._tools: dict[str, Tool] = {}
self._tools: dict[str, Tool | ToolBuiltIn] = {}
self._on_tool_request_callbacks = CallbackManager()
self._on_tool_result_callbacks = CallbackManager()
self._current_display: Optional[MarkdownDisplay] = None
Expand Down Expand Up @@ -1866,7 +1866,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):

def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn",
*,
force: bool = False,
name: Optional[str] = None,
Expand Down Expand Up @@ -1960,31 +1960,39 @@ def add(a: int, b: int) -> int:
ValueError
If a tool with the same name already exists and `force` is `False`.
"""
if isinstance(func, Tool):
if isinstance(func, ToolBuiltIn):
# ToolBuiltIn objects are stored directly without conversion
tool = func
tool_name = tool.name
elif isinstance(func, Tool):
name = name or func.name
annotations = annotations or func.annotations
if model is not None:
func = Tool.from_func(
func.func, name=name, model=model, annotations=annotations
)
func = func.func
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
tool_name = tool.name
else:
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
tool_name = tool.name

tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
if tool.name in self._tools and not force:
if tool_name in self._tools and not force:
raise ValueError(
f"Tool with name '{tool.name}' is already registered. "
f"Tool with name '{tool_name}' is already registered. "
"Set `force=True` to overwrite it."
)
self._tools[tool.name] = tool
self._tools[tool_name] = tool

def get_tools(self) -> list[Tool]:
def get_tools(self) -> list[Tool | ToolBuiltIn]:
"""
Get the list of registered tools.

Returns
-------
list[Tool]
A list of `Tool` instances that are currently registered with the chat.
list[Tool | ToolBuiltIn]
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
"""
return list(self._tools.values())

Expand Down Expand Up @@ -2508,7 +2516,7 @@ def _submit_turns(
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self._tools.values()):
if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

def emit(text: str | Content):
Expand Down Expand Up @@ -2661,15 +2669,27 @@ def _collect_all_kwargs(

def _invoke_tool(self, request: ContentToolRequest):
tool = self._tools.get(request.name)
func = tool.func if tool is not None else None

if func is None:
if tool is None:
yield self._handle_tool_error_result(
request,
error=RuntimeError("Unknown tool."),
)
return

if isinstance(tool, ToolBuiltIn):
# Built-in tools are handled by the provider, not invoked directly
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

func = tool.func

# First, invoke the request callbacks. If a ToolRejectError is raised,
# treat it like a tool failure (i.e., gracefully handle it).
result: ContentToolResult | None = None
Expand Down Expand Up @@ -2717,6 +2737,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
)
return

if isinstance(tool, ToolBuiltIn):
# Built-in tools are handled by the provider, not invoked directly
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

if tool._is_async:
func = tool.func
else:
Expand Down
49 changes: 39 additions & 10 deletions chatlas/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._typing_extensions import TypedDict

if TYPE_CHECKING:
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn


class ToolAnnotations(TypedDict, total=False):
Expand Down Expand Up @@ -104,15 +104,28 @@ class ToolInfo(BaseModel):
annotations: Optional[ToolAnnotations] = None

@classmethod
def from_tool(cls, tool: "Tool") -> "ToolInfo":
"""Create a ToolInfo from a Tool instance."""
func_schema = tool.schema["function"]
return cls(
name=tool.name,
description=func_schema.get("description", ""),
parameters=func_schema.get("parameters", {}),
annotations=tool.annotations,
)
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
from ._tools import ToolBuiltIn

if isinstance(tool, ToolBuiltIn):
# For built-in tools, extract info from the definition
defn = tool.definition
return cls(
name=tool.name,
description=defn.get("description", ""),
parameters=defn.get("parameters", {}),
annotations=None,
)
else:
# For regular tools, extract from schema
func_schema = tool.schema["function"]
return cls(
name=tool.name,
description=func_schema.get("description", ""),
parameters=func_schema.get("parameters", {}),
annotations=tool.annotations,
)


ContentTypeEnum = Literal[
Expand Down Expand Up @@ -247,6 +260,22 @@ def __str__(self):
def _repr_markdown_(self):
return self.__str__()

def _repr_png_(self):
"""Display PNG images directly in Jupyter notebooks."""
if self.image_content_type == "image/png" and self.data:
import base64

return base64.b64decode(self.data)
return None

def _repr_jpeg_(self):
"""Display JPEG images directly in Jupyter notebooks."""
if self.image_content_type == "image/jpeg" and self.data:
import base64

return base64.b64decode(self.data)
return None

def __repr__(self, indent: int = 0):
n_bytes = len(self.data) if self.data else 0
return (
Expand Down
18 changes: 9 additions & 9 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel

from ._content import Content
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn
from ._turn import Turn
from ._typing_extensions import NotRequired, TypedDict

Expand Down Expand Up @@ -162,7 +162,7 @@ def chat_perform(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -174,7 +174,7 @@ def chat_perform(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT]: ...
Expand All @@ -185,7 +185,7 @@ def chat_perform(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand All @@ -197,7 +197,7 @@ async def chat_perform_async(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -209,7 +209,7 @@ async def chat_perform_async(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT]: ...
Expand All @@ -220,7 +220,7 @@ async def chat_perform_async(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand Down Expand Up @@ -259,15 +259,15 @@ def value_tokens(
def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

@abstractmethod
async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

Expand Down
36 changes: 29 additions & 7 deletions chatlas/_provider_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,25 @@ def _chat_perform_args(
config.response_mime_type = "application/json"

if tools:
config.tools = [
GoogleTool(
function_declarations=[
from ._tools import ToolBuiltIn

function_declarations = []
for tool in tools.values():
if isinstance(tool, ToolBuiltIn):
# For built-in tools, pass the raw definition through
# This allows provider-specific tools like image generation
# Note: Google's API expects these in a specific format
continue # Built-in tools are not yet fully supported for Google
else:
function_declarations.append(
FunctionDeclaration.from_callable(
client=self._client._api_client,
callable=tool.func,
)
for tool in tools.values()
]
)
]
)

if function_declarations:
config.tools = [GoogleTool(function_declarations=function_declarations)]

kwargs_full["config"] = config

Expand Down Expand Up @@ -552,6 +560,20 @@ def _as_turn(
),
)
)
inline_data = part.get("inlineData") or part.get("inline_data")
if inline_data:
# Handle image generation responses
mime_type = inline_data.get("mimeType") or inline_data.get("mime_type")
data = inline_data.get("data")
if mime_type and data:
# Ensure data is a string (should be base64 encoded)
data_str = data if isinstance(data, str) else str(data)
contents.append(
ContentImageInline(
image_content_type=mime_type, # type: ignore
data=data_str,
)
)

if isinstance(finish_reason, FinishReason):
finish_reason = finish_reason.name
Expand Down
Loading
Loading