diff --git a/pyproject.toml b/pyproject.toml index 21ee73b..5f12f30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ docs = "scripts.docs:main" [project.entry-points.components] tiny_agents = "tinygent.agents.register" +tiny_checkpointers = "tinygent.agents.checkpointer.register" tiny_middlewares = "tinygent.agents.middleware.register" tiny_tools = "tinygent.tools.register" tiny_memory = "tinygent.memory.register" diff --git a/tinygent/agents/base_agent.py b/tinygent/agents/base_agent.py index 51d7352..6fd9cc6 100644 --- a/tinygent/agents/base_agent.py +++ b/tinygent/agents/base_agent.py @@ -5,6 +5,7 @@ from io import StringIO import logging import textwrap +import typing from typing import Any from typing import Awaitable from typing import Callable @@ -17,6 +18,8 @@ from tinygent.agents.middleware.tool_limiter import ToolCallBlockedException from tinygent.core.datamodels.agent import AbstractAgent from tinygent.core.datamodels.agent import AbstractAgentConfig +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer +from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.llm import AbstractLLMConfig from tinygent.core.datamodels.memory import AbstractMemory @@ -34,12 +37,20 @@ from tinygent.core.types.io.llm_io_input import TinyLLMInput from tinygent.memory.buffer_chat_memory import BufferChatMemoryConfig -T = TypeVar('T', bound='AbstractAgent') +if typing.TYPE_CHECKING: + from tinygent.agents.checkpointer.default_checkpointer import TinyDefaultCheckpointer +T = TypeVar('T', bound='AbstractAgent') logger = logging.getLogger(__name__) +def _create_default_checkpointer() -> 'TinyDefaultCheckpointer': + from tinygent.agents.checkpointer.default_checkpointer import TinyDefaultCheckpointer + + return TinyDefaultCheckpointer({}) + + class TinyBaseAgentConfig(AbstractAgentConfig[T], Generic[T]): """Configuration for BaseAgent.""" @@ -48,12 +59,14 @@ class TinyBaseAgentConfig(AbstractAgentConfig[T], Generic[T]): middleware: Sequence[AbstractMiddlewareConfig | AbstractMiddleware] = Field( default_factory=list ) - llm: AbstractLLMConfig | AbstractLLM = Field(...) tools: Sequence[AbstractToolConfig | AbstractTool] = Field(default_factory=list) memory: AbstractMemoryConfig | AbstractMemory = Field( default_factory=BufferChatMemoryConfig ) + checkpointer: AbstractCheckpointer | AbstractCheckpointerConfig | None = Field( + default=None + ) def build(self) -> T: """Build the BaseAgent instance from the configuration.""" @@ -84,6 +97,18 @@ def build_memory_instance(self) -> AbstractMemory: return build_memory(self.memory) + def build_checkpointer_instance(self) -> AbstractCheckpointer: + """Build checkpointer instance from config if checkpointer is set.""" + if isinstance(self.checkpointer, AbstractCheckpointer): + return self.checkpointer + + if self.checkpointer is None: + return _create_default_checkpointer() + + from tinygent.core.factory.checkpointer import build_checkpointer + + return build_checkpointer(self.checkpointer) + def build_middleware_list(self) -> list[AbstractMiddleware]: """Build list of middleware instances from configs or return existing instances.""" from tinygent.core.factory.middleware import build_middleware @@ -99,20 +124,26 @@ def __init__( self, llm: AbstractLLM, memory: AbstractMemory, - tools: Sequence[AbstractTool] = (), + tools: Sequence[AbstractTool] = [], middleware: Sequence[AbstractMiddleware] = [], + checkpointer: AbstractCheckpointer | None = None, ) -> None: self.llm = llm self.middleware = middleware self._memory = memory self._tools = tools + self._checkpointer = ( + _create_default_checkpointer() if checkpointer is None else checkpointer + ) self._final_answer: str | None = None def reset(self) -> None: logger.debug('[BASE AGENT RESET]') self.memory.clear() + self.checkpointer.clear() + self._final_answer = None @property diff --git a/tinygent/agents/checkpointer/__init__.py b/tinygent/agents/checkpointer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tinygent/agents/checkpointer/base_checkpointer.py b/tinygent/agents/checkpointer/base_checkpointer.py new file mode 100644 index 0000000..c694e13 --- /dev/null +++ b/tinygent/agents/checkpointer/base_checkpointer.py @@ -0,0 +1,55 @@ +import logging +from typing import Any +from typing import Generic +from typing import TypeVar + +from pydantic import Field + +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer +from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig + +T = TypeVar('T', bound='TinyBaseCheckpointer') + + +logger = logging.getLogger(__name__) + + +class TinyBaseCheckpointerConfig(AbstractCheckpointerConfig[T], Generic[T]): + type: Any = Field(default='base') + + data: dict[str, Any] = Field(default={}) + + def build(self) -> T: + """Build the BaseCheckpointer instance from the configuration.""" + raise NotImplementedError('Subclasses must implement this method.') + + +class TinyBaseCheckpointer(AbstractCheckpointer): + def __init__(self, data: dict[str, Any]) -> None: + self.data = data + + def set_data(self, data: dict[str, Any]) -> None: + self.data = data + + def setdefault(self, key: str, value: Any) -> Any: + if key not in self.data: + self.data[key] = value + return self.data[key] + + def clear(self) -> None: + self.data = {} + + def __getitem__(self, key: str, default: Any = None) -> Any: + val = self.data.get(key, default) + if val is None: + logger.warning( + 'Key: %s is missing in checkpoint data %s', key, self.__class__.__name__ + ) + + if default is not None: + self.__setitem__('key', default) + + return val + + def __setitem__(self, key: str, value: Any) -> None: + self.data[key] = value diff --git a/tinygent/agents/checkpointer/default_checkpointer.py b/tinygent/agents/checkpointer/default_checkpointer.py new file mode 100644 index 0000000..dcdbff5 --- /dev/null +++ b/tinygent/agents/checkpointer/default_checkpointer.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import logging +from typing import Any + +from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointer + +logger = logging.getLogger(__name__) + + +class TinyDefaultCheckpointer(TinyBaseCheckpointer): + def __init__(self, data: dict[str, Any]) -> None: + super().__init__(data) + + def save(self, checkpoint_id: str) -> None: + logger.debug( + 'Checkpoint (%s) will not be saved in default checkpointer', checkpoint_id + ) + + def load(self, checkpoint_id: str) -> None: + logger.debug('Nothing to be loaded for default checkpointer (%s)', checkpoint_id) + + def delete(self, checkpoint_id: str) -> None: + logger.debug('Deleting checkpoint %s', checkpoint_id) + self.data = {} diff --git a/tinygent/agents/checkpointer/local_checkpointer.py b/tinygent/agents/checkpointer/local_checkpointer.py new file mode 100644 index 0000000..4fb0c2f --- /dev/null +++ b/tinygent/agents/checkpointer/local_checkpointer.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from copy import deepcopy +import logging +from threading import Lock +from typing import Any +from typing import TypeVar + +from pydantic import Field + +from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointer +from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointerConfig + +logger = logging.getLogger(__name__) + +T = TypeVar('T', bound='TinyLocalCheckpointer') + +_GLOBAL_CHECKPOINTS: dict[str, Any] = {} +_GLOBAL_CHECKPOINTS_LOCK = Lock() + + +class TinyLocalCheckpointerConfig(TinyBaseCheckpointerConfig['TinyLocalCheckpointer']): + type: Any = Field(default='local', frozen=True) + + def build(self) -> TinyLocalCheckpointer: + return TinyLocalCheckpointer(self.data) + + +class TinyLocalCheckpointer(TinyBaseCheckpointer): + def __init__(self, data: dict[str, Any]) -> None: + super().__init__(data) + + def save(self, checkpoint_id: str) -> None: + logger.debug('Saving checkpoint %s', checkpoint_id) + with _GLOBAL_CHECKPOINTS_LOCK: + _GLOBAL_CHECKPOINTS[checkpoint_id] = deepcopy(self.data) + logger.debug('Data %s saved: %s', checkpoint_id, str(self.data)) + + def load(self, checkpoint_id: str) -> None: + logger.debug('Loading checkpoint %s', checkpoint_id) + with _GLOBAL_CHECKPOINTS_LOCK: + tmp = _GLOBAL_CHECKPOINTS.get(checkpoint_id) + if tmp is None: + logger.warning("Couldn't find data for checkpoint id: %s", checkpoint_id) + return + self.data = deepcopy(tmp) + logger.debug('Data %s loaded: %s', checkpoint_id, str(self.data)) + + def delete(self, checkpoint_id: str) -> None: + logger.debug('Deleting checkpoint %s', checkpoint_id) + with _GLOBAL_CHECKPOINTS_LOCK: + _GLOBAL_CHECKPOINTS.pop(checkpoint_id) diff --git a/tinygent/agents/checkpointer/register.py b/tinygent/agents/checkpointer/register.py new file mode 100644 index 0000000..71431b6 --- /dev/null +++ b/tinygent/agents/checkpointer/register.py @@ -0,0 +1,14 @@ +from tinygent.agents.checkpointer.local_checkpointer import TinyLocalCheckpointer +from tinygent.agents.checkpointer.local_checkpointer import TinyLocalCheckpointerConfig +from tinygent.core.runtime.global_registry import GlobalRegistry + + +def _register_checkpointers() -> None: + registry = GlobalRegistry().get_registry() + + registry.register_checkpointer( + 'local', TinyLocalCheckpointerConfig, TinyLocalCheckpointer + ) + + +_register_checkpointers() diff --git a/tinygent/agents/map_agent.py b/tinygent/agents/map_agent.py index 02e8a55..ed548cf 100644 --- a/tinygent/agents/map_agent.py +++ b/tinygent/agents/map_agent.py @@ -8,9 +8,11 @@ import uuid from pydantic import Field +from pydantic import PrivateAttr from tinygent.agents.base_agent import TinyBaseAgent from tinygent.agents.base_agent import TinyBaseAgentConfig +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages @@ -100,6 +102,26 @@ def validation(self) -> str: return 'Valid response' if self.is_valid else 'NOT-Valid response' +class DecomposedTask(TinyModel): + class subgoal(TinyModel): + index: int + question: str + _finished: bool = PrivateAttr(default=False) + + @property + def description(self) -> str: + return f'{self.index}. {self.question}' + + @property + def finished(self) -> bool: + return self._finished + + def marked_finished(self) -> None: + self._finished = True + + subgoals: list[subgoal] + + class TinyMAPAgentConfig(TinyBaseAgentConfig['TinyMAPAgent']): """Configuration for the TinyMAPAgent.""" @@ -180,8 +202,15 @@ def __init__( max_recurrsion: int = 5, tools: list[AbstractTool] = [], middleware: Sequence[AbstractMiddleware] = [], + checkpointer: AbstractCheckpointer | None = None, ) -> None: - super().__init__(llm=llm, tools=tools, memory=memory, middleware=middleware) + super().__init__( + llm=llm, + tools=tools, + memory=memory, + middleware=middleware, + checkpointer=checkpointer, + ) self.max_plan_length = max_plan_length self.max_branches_per_layer = max_branches_per_layer @@ -191,17 +220,14 @@ def __init__( prompt_template if prompt_template else self.default_prompt_template() ) + def _init_state(self) -> None: + self.checkpointer.setdefault('decomposed_task', None) + self.checkpointer.setdefault('final_plan', []) + @tiny_trace('map_agent_task_decomposer') - async def _task_decomposer(self, run_id: str, input_txt: str) -> list[str]: + async def _task_decomposer(self, run_id: str, input_txt: str) -> DecomposedTask: logger.debug('[TASK DECOMPOSER] task: %s', input_txt) - class DecomposedTask(TinyModel): - class subgoal(TinyModel): - index: int - question: str - - subgoals: list[subgoal] - messages = TinyLLMInput(messages=[*self.memory.copy_chat_messages()]) messages.add_at_end( TinyUserMessage( @@ -222,7 +248,7 @@ class subgoal(TinyModel): output_schema=DecomposedTask, ) - all_subgoals = [f'{sq.index}. {sq.question}' for sq in result.subgoals] + all_subgoals = [sb.description for sb in result] set_tiny_attributes( { @@ -233,7 +259,7 @@ class subgoal(TinyModel): logger.debug( '[TASK DECOMPOSER] decomposed task (%s): %s', input_txt, all_subgoals ) - return all_subgoals + return result @tiny_trace('map_agent_actor') async def _actor( @@ -665,14 +691,20 @@ async def _orchestrator( @tiny_trace('map_agent_map') async def _map(self, run_id: str, question: str) -> list[TinyMAPActionProposal]: logger.debug('[MAP] Running MAP module with task: %s', question) - subgoals = await self._task_decomposer(run_id, question) - subgoals.append( - question - ) # INFO: Last and final subgoal is original user question - final_plan: list[TinyMAPActionProposal] = [] + if self.checkpointer['decomposed_task'] is None: + decomposed_task = await self._task_decomposer(run_id, question) + decomposed_task.subgoals.append( + DecomposedTask.subgoal( + index=max(s.index for s in decomposed_task.subgoals) + 1, + question=question, + ) + ) # INFO: Last and final subgoal is original user question + self.checkpointer['decomposed_task'] = decomposed_task - for subgoal in subgoals: + for subgoal in filter( + lambda x: x.finished, self.checkpointer['decomposed_task'].subgoals + ): logger.debug('[MAP] current subgoal: %s', subgoal) current_state = TinyMAPState( is_valid=True, @@ -681,26 +713,39 @@ async def _map(self, run_id: str, question: str) -> list[TinyMAPActionProposal]: metadata='', ) - validity = await self._orchestrator(run_id, current_state, subgoal) + validity = await self._orchestrator(run_id, current_state, subgoal.question) while ( - not validity.fully_satisfies and len(final_plan) < self.max_plan_length + not validity.fully_satisfies + and len(self.checkpointer['final_plan']) < self.max_plan_length ): - search_res = await self._search(run_id, 0, current_state, subgoal) + search_res = await self._search( + run_id, 0, current_state, subgoal.question + ) - final_plan.append(search_res.action) + self.checkpointer['final_plan'].append(search_res.action) current_state = search_res.next_state - validity = await self._orchestrator(run_id, current_state, subgoal) + validity = await self._orchestrator( + run_id, current_state, subgoal.question + ) await self.on_plan(run_id=run_id, plan=search_res.action.sum, kwargs={}) + subgoal.marked_finished() + set_tiny_attributes( - {'agent.map.final_plan': '\n'.join([p.sum for p in final_plan])} + { + 'agent.map.final_plan': '\n'.join( + [p.sum for p in self.checkpointer['final_plan']] + ) + } ) logger.debug( - '[MAP] task: %s final plan: %s', question, [p.sum for p in final_plan] + '[MAP] task: %s final plan: %s', + question, + [p.sum for p in self.checkpointer['final_plan']], ) - return final_plan + return self.checkpointer['final_plan'] @tiny_trace('agent_run') async def _run_agent(self, input_text: str, run_id: str) -> str: @@ -729,25 +774,36 @@ def reset(self) -> None: logger.debug('[AGENT RESET]') - def setup(self, reset: bool, history: list[AllTinyMessages] | None) -> None: + self._init_state() + + def setup( + self, + reset: bool, + history: list[AllTinyMessages] | None, + checkpoint_id: str | None, + ) -> None: if reset: self.reset() if history: self.memory.save_multiple_context(history) + if checkpoint_id: + self.checkpointer.load(checkpoint_id) + def run( self, input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> str: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _run() -> str: plan = await self._run_agent(run_id=run_id, input_text=input_text) @@ -762,13 +818,14 @@ def run_stream( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> AsyncGenerator[str, None]: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _generator(): plan = await self._run_agent(run_id=run_id, input_text=input_text) diff --git a/tinygent/agents/middleware/base_tool_selector.py b/tinygent/agents/middleware/base_tool_selector.py index 7bcb5ed..bf0386a 100644 --- a/tinygent/agents/middleware/base_tool_selector.py +++ b/tinygent/agents/middleware/base_tool_selector.py @@ -11,7 +11,6 @@ from tinygent.agents.middleware.base import TinyBaseMiddlewareConfig from tinygent.core.datamodels.tool import AbstractTool from tinygent.core.datamodels.tool import AbstractToolConfig -from tinygent.core.factory import build_tool logger = logging.getLogger(__name__) @@ -40,6 +39,8 @@ class TinyBaseToolSelectorMiddlewareConfig( ) def build_base_kwargs(self) -> dict: + from tinygent.core.factory import build_tool + always_include: list[AbstractTool] | None = None if self.always_include: diff --git a/tinygent/agents/multi_step_agent.py b/tinygent/agents/multi_step_agent.py index 6504c1a..60ded4f 100644 --- a/tinygent/agents/multi_step_agent.py +++ b/tinygent/agents/multi_step_agent.py @@ -11,6 +11,7 @@ from tinygent.agents.base_agent import TinyBaseAgent from tinygent.agents.base_agent import TinyBaseAgentConfig +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages @@ -56,6 +57,7 @@ def build(self) -> TinyMultiStepAgent: llm=self.build_llm_instance(), tools=self.build_tools_list(), memory=self.build_memory_instance(), + checkpointer=self.build_checkpointer_instance(), prompt_template=self.prompt_template, max_iterations=self.max_iterations, plan_interval=self.plan_interval, @@ -102,12 +104,15 @@ def __init__( max_iterations: int = 15, plan_interval: int = 5, middleware: Sequence[AbstractMiddleware] = [], + checkpointer: AbstractCheckpointer | None = None, ) -> None: - super().__init__(llm=llm, tools=tools, memory=memory, middleware=middleware) - - self._iteration_number: int = 1 - self._planned_steps: list[TinyPlanMessage] = [] - self._tool_calls: list[TinyToolCall] = [] + super().__init__( + llm=llm, + tools=tools, + memory=memory, + middleware=middleware, + checkpointer=checkpointer, + ) self.max_iterations = max_iterations self.plan_interval = plan_interval @@ -116,6 +121,11 @@ def __init__( self.plan_prompt = prompt_template.plan self.fallback_prompt = prompt_template.fallback + def _init_state(self) -> None: + self._checkpointer.setdefault('_iteration_number', 1) + self._checkpointer.setdefault('_planned_steps', []) + self._checkpointer.setdefault('_tool_calls', []) + @tiny_trace('multi_step_agent_steps_creation') async def _stream_steps( self, run_id: str, task: str @@ -127,7 +137,7 @@ class TinyReasonedSteps(TinyModel): variables: dict[str, Any] # Initial plan - if self._iteration_number == 1: + if self._checkpointer['_iteration_number'] == 1: template = self.plan_prompt.init_plan variables = {'task': task, 'tools': self.tools} else: @@ -136,8 +146,10 @@ class TinyReasonedSteps(TinyModel): 'task': task, 'tools': self.tools, 'history': self.memory.load_variables(), - 'steps': self._planned_steps, - 'remaining_steps': self.max_iterations - self._iteration_number + 1, + 'steps': self._checkpointer['_planned_steps'], + 'remaining_steps': self.max_iterations + - self._checkpointer['_iteration_number'] + + 1, } messages = TinyLLMInput( @@ -186,9 +198,9 @@ async def _stream_action( { 'task': task, 'tools': self.tools, - 'tool_calls': self._tool_calls, + 'tool_calls': self._checkpointer['_tool_calls'], 'history': self.memory.load_variables(), - 'steps': self._planned_steps, + 'steps': self._checkpointer['_planned_steps'], }, ) ), @@ -222,7 +234,7 @@ async def _stream_fallback_answer( { 'task': task, 'history': self.memory.load_variables(), - 'steps': self._planned_steps, + 'steps': self._checkpointer['_planned_steps'], }, ) ), @@ -248,43 +260,47 @@ async def _run_agent(self, input_text: str, run_id: str) -> AsyncGenerator[str]: logger.debug('[%s] Running agent with input %s', run_id, input_text) - self._iteration_number = 1 + self._checkpointer['_iteration_number'] = 1 returned_final_answer: bool = False yielded_final_answer: str = '' self.memory.save_context(TinyHumanMessage(content=input_text)) while not returned_final_answer and ( - self._iteration_number <= self.max_iterations + self._checkpointer['_iteration_number'] <= self.max_iterations ): with tiny_trace_span( - 'multi_step_agent_single_iteration', iteration=self._iteration_number + 'multi_step_agent_single_iteration', + iteration=self._checkpointer['_iteration_number'], ): - logger.debug('--- ITERATION %d ---', self._iteration_number) + logger.debug( + '--- ITERATION %d ---', self._checkpointer['_iteration_number'] + ) - if self._iteration_number == 1 or ( - (self._iteration_number - 1) % self.plan_interval == 0 + if self._checkpointer['_iteration_number'] == 1 or ( + (self._checkpointer['_iteration_number'] - 1) % self.plan_interval + == 0 ): # Create new plan plan_generator = self._stream_steps(run_id=run_id, task=input_text) - self._planned_steps = [] + self._checkpointer['_planned_steps'] = [] async for planner_msg in plan_generator: if isinstance(planner_msg, TinyPlanMessage): logger.debug( '[%d. ITERATION - Plan]: %s', - self._iteration_number, + self._checkpointer['_iteration_number'], planner_msg.content, ) await self.on_plan( run_id=run_id, plan=planner_msg.content, kwargs={} ) - self._planned_steps.append(planner_msg) + self._checkpointer['_planned_steps'].append(planner_msg) if isinstance(planner_msg, TinyReasoningMessage): logger.debug( '[%d. ITERATION - Reasoning]: %s', - self._iteration_number, + self._checkpointer['_iteration_number'], planner_msg.content, ) await self.on_reasoning( @@ -316,7 +332,7 @@ async def _run_agent(self, input_text: str, run_id: str) -> AsyncGenerator[str]: run_id=run_id, tool=called_tool, call=tool_call ) ) - self._tool_calls.append(tool_call) + self._checkpointer['_tool_calls'].append(tool_call) else: logger.error( 'Tool %s not found. Skipping tool call.', @@ -327,7 +343,7 @@ async def _run_agent(self, input_text: str, run_id: str) -> AsyncGenerator[str]: reasoning = tool_call.arguments.get('reasoning', '') logger.debug( '[%d. ITERATION - Tool Reasoning]: %s', - self._iteration_number, + self._checkpointer['_iteration_number'], reasoning, ) await self.on_tool_reasoning( @@ -336,7 +352,7 @@ async def _run_agent(self, input_text: str, run_id: str) -> AsyncGenerator[str]: logger.debug( '[%s. ITERATION - Tool Call]: %s(%s) = %s', - self._iteration_number, + self._checkpointer['_iteration_number'], tool_call.tool_name, tool_call.arguments, tool_call.result, @@ -352,7 +368,7 @@ async def _run_agent(self, input_text: str, run_id: str) -> AsyncGenerator[str]: await self.on_error(run_id=run_id, e=e, kwargs={}) raise e finally: - self._iteration_number += 1 + self._checkpointer['_iteration_number'] += 1 if not returned_final_answer: logger.warning( @@ -385,29 +401,37 @@ def reset(self) -> None: logger.debug('[AGENT RESET]') - self._iteration_number = 1 - self._planned_steps = [] - self._tool_calls = [] + self._checkpointer.clear() + self._init_state() - def setup(self, reset: bool, history: list[AllTinyMessages] | None) -> None: + def setup( + self, + reset: bool, + history: list[AllTinyMessages] | None, + checkpoint_id: str | None, + ) -> None: if reset: self.reset() if history: self.memory.save_multiple_context(history) + if checkpoint_id: + self.checkpointer.load(checkpoint_id) + def run( self, input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> str: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _run() -> str: final_answer: str = '' @@ -424,13 +448,14 @@ def run_stream( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> AsyncGenerator[str, None]: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _generator(): idx = 0 diff --git a/tinygent/agents/react_agent.py b/tinygent/agents/react_agent.py index a614f57..439e325 100644 --- a/tinygent/agents/react_agent.py +++ b/tinygent/agents/react_agent.py @@ -10,6 +10,7 @@ from tinygent.agents.base_agent import TinyBaseAgent from tinygent.agents.base_agent import TinyBaseAgentConfig +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages @@ -53,6 +54,7 @@ def build(self) -> TinyReActAgent: llm=self.build_llm_instance(), tools=self.build_tools_list(), memory=self.build_memory_instance(), + checkpointer=self.build_checkpointer_instance(), max_iterations=self.max_iterations, ) @@ -86,6 +88,19 @@ class TinyReActAgent(TinyBaseAgent): middleware: List of middleware to apply during execution """ + class TinyReactIteration(TinyModel): + iteration_number: int + tool_calls: list[TinyToolCall] + reasoning: str + + @property + def summary(self) -> str: + return ( + f'Iteration {self.iteration_number}:\n' + f'Reasoning: {self.reasoning}\n' + f'Tool Calls: {", ".join(call.tool_name for call in self.tool_calls)}\n' + ) + def __init__( self, llm: AbstractLLM, @@ -94,30 +109,25 @@ def __init__( tools: list[AbstractTool] = [], max_iterations: int = 10, middleware: Sequence[AbstractMiddleware] = [], + checkpointer: AbstractCheckpointer | None = None, ) -> None: - super().__init__(llm=llm, tools=tools, memory=memory, middleware=middleware) - - class TinyReactIteration(TinyModel): - iteration_number: int - tool_calls: list[TinyToolCall] - reasoning: str - - @property - def summary(self) -> str: - return ( - f'Iteration {self.iteration_number}:\n' - f'Reasoning: {self.reasoning}\n' - f'Tool Calls: {", ".join(call.tool_name for call in self.tool_calls)}\n' - ) - - self.TinyReactIteration = TinyReactIteration - - self._iteration_number: int = 1 - self._react_iterations: list[TinyReactIteration] = [] + super().__init__( + llm=llm, + tools=tools, + memory=memory, + middleware=middleware, + checkpointer=checkpointer, + ) self.prompt_template = prompt_template self.max_iterations = max_iterations + def _init_state(self) -> None: + self.checkpointer.setdefault('iteration_number', 1) + self.checkpointer.setdefault('react_iterations', []) + self.checkpointer.setdefault('returned_final_answer', False) + self.checkpointer.setdefault('yielded_final_answer', '') + @tiny_trace('react_agent_reasoning') async def _stream_reasoning( self, run_id: str, task: str @@ -126,7 +136,7 @@ class TinyReasoningOutcome(TinyModel): type: Literal['reasoning'] = 'reasoning' content: str - if self._iteration_number == 1: + if self.checkpointer['iteration_number'] == 1: template = self.prompt_template.reason.init variables = {'task': task} else: @@ -134,7 +144,8 @@ class TinyReasoningOutcome(TinyModel): variables = { 'task': task, 'overview': '\n'.join( - iteration.summary for iteration in self._react_iterations + iteration.summary + for iteration in self.checkpointer['react_iterations'] ), } @@ -211,7 +222,8 @@ async def _stream_fallback( { 'task': task, 'overview': '\n'.join( - iteration.summary for iteration in self._react_iterations + iteration.summary + for iteration in self.checkpointer['react_iterations'] ), }, ) @@ -241,19 +253,20 @@ async def _run_agent( ) logger.debug('Running agent with task: %s', input_text) - self._iteration_number = 1 - returned_final_answer: bool = False - yielded_final_answer: str = '' + self._init_state() self.memory.save_context(TinyHumanMessage(content=input_text)) - while not returned_final_answer and ( - self._iteration_number <= self.max_iterations + while not self.checkpointer['returned_final_answer'] and ( + self.checkpointer['iteration_number'] <= self.max_iterations ): with tiny_trace_span( - 'react_agent_single_iteration', iteration=self._iteration_number + 'react_agent_single_iteration', + iteration=self.checkpointer['iteration_number'], ): - logger.debug('--- ITERATION %d ---', self._iteration_number) + logger.debug( + '--- ITERATION %d ---', self.checkpointer['iteration_number'] + ) try: reasoning_result = await self._stream_reasoning( @@ -261,17 +274,17 @@ async def _run_agent( ) logger.debug( '[%d. ITERATION - Reasoning Result]: %s', - self._iteration_number, + self.checkpointer['iteration_number'], reasoning_result.content, ) if isinstance(reasoning_result, TinyChatMessage): logger.debug( '[%d. ITERATION - Reasoning Final Answer]: %s', - self._iteration_number, + self.checkpointer['iteration_number'], reasoning_result.content, ) - returned_final_answer = True + self.checkpointer['returned_final_answer'] = True self.memory.save_context(reasoning_result) @@ -279,7 +292,8 @@ async def _run_agent( else: logger.debug( - '[%d. ITERATION - Streaming Action]', self._iteration_number + '[%d. ITERATION - Streaming Action]', + self.checkpointer['iteration_number'], ) tool_calls: list[TinyToolCall] = [] @@ -289,8 +303,10 @@ async def _run_agent( if msg.is_message and isinstance( msg.message, TinyChatMessageChunk ): - returned_final_answer = True - yielded_final_answer += msg.message.content + self.checkpointer['returned_final_answer'] = True + self.checkpointer['yielded_final_answer'] += ( + msg.message.content + ) yield msg.message.content @@ -315,7 +331,7 @@ async def _run_agent( ) logger.debug( '[%d. ITERATION - Tool Reasoning]: %s', - self._iteration_number, + self.checkpointer['iteration_number'], reasoning, ) await self.on_tool_reasoning( @@ -329,20 +345,22 @@ async def _run_agent( logger.debug( '[%s. ITERATION - Tool Call]: %s(%s) = %s', - self._iteration_number, + self.checkpointer['iteration_number'], full_tc.tool_name, full_tc.arguments, full_tc.result, ) - if yielded_final_answer: + if self.checkpointer['yielded_final_answer']: self.memory.save_context( - TinyChatMessage(content=yielded_final_answer) + TinyChatMessage( + content=self.checkpointer['yielded_final_answer'] + ) ) - self._react_iterations.append( + self.checkpointer['react_iterations'].append( self.TinyReactIteration( - iteration_number=self._iteration_number, + iteration_number=self.checkpointer['iteration_number'], tool_calls=tool_calls, reasoning=reasoning_result.content, ) @@ -352,9 +370,9 @@ async def _run_agent( await self.on_error(run_id=run_id, e=e, kwargs={}) raise e finally: - self._iteration_number += 1 + self.checkpointer['iteration_number'] += 1 - if not returned_final_answer: + if not self.checkpointer['returned_final_answer']: logger.warning( 'Max iterations reached without final answer. Using fallback.' 'Returning fallback answer.' @@ -372,8 +390,9 @@ async def _run_agent( yield fallback_chunk if not yielded_fallback: - final_yielded_answer = 'I have completed my reasoning and tool usage but did not arrive at a final answer.' - yield final_yielded_answer + raise RuntimeError( + 'Something went wrong, cannot return answer from react agent.' + ) self.memory.save_context(TinyChatMessage(content=final_yielded_answer)) @@ -382,28 +401,36 @@ def reset(self) -> None: logger.debug('[AGENT RESET]') - self._iteration_number = 1 - self._react_iterations = [] + self._init_state() - def setup(self, reset: bool, history: list[AllTinyMessages] | None) -> None: + def setup( + self, + reset: bool, + history: list[AllTinyMessages] | None, + checkpoint_id: str | None, + ) -> None: if reset: self.reset() if history: self.memory.save_multiple_context(history) + if checkpoint_id: + self.checkpointer.load(checkpoint_id) + def run( self, input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> str: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _run() -> str: final_answer = '' @@ -420,13 +447,14 @@ def run_stream( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> AsyncGenerator[str, None]: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _generator(): idx = 0 diff --git a/tinygent/agents/squad_agent.py b/tinygent/agents/squad_agent.py index 3f08435..5fec43a 100644 --- a/tinygent/agents/squad_agent.py +++ b/tinygent/agents/squad_agent.py @@ -16,6 +16,7 @@ from tinygent.agents.base_agent import TinyBaseAgentConfig from tinygent.core.datamodels.agent import AbstractAgent from tinygent.core.datamodels.agent import AbstractAgentConfig +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages @@ -90,11 +91,12 @@ class TinySquadAgentConfig(TinyBaseAgentConfig['TinySquadAgent']): def build(self) -> TinySquadAgent: return TinySquadAgent( middleware=self.build_middleware_list(), - prompt_template=self.prompt_template, llm=self.build_llm_instance(), tools=self.build_tools_list(), memory=self.build_memory_instance(), squad=[AgentSquadMember.from_config(agent_cfg) for agent_cfg in self.squad], + checkpointer=self.build_checkpointer_instance(), + prompt_template=self.prompt_template, ) @model_validator(mode='after') @@ -147,13 +149,24 @@ def __init__( tools: list[AbstractTool] = [], squad: list[AgentSquadMember] = [], middleware: Sequence[AbstractMiddleware] = [], + checkpointer: AbstractCheckpointer | None = None, ) -> None: - super().__init__(llm=llm, tools=tools, memory=memory, middleware=middleware) + super().__init__( + llm=llm, + tools=tools, + memory=memory, + middleware=middleware, + checkpointer=checkpointer, + ) self._squad = [self._normalize_squad_member(member) for member in squad] self.prompt_template = prompt_template + @property + def members(self) -> list[AbstractAgent]: + return [m.agent for m in self._squad] + @staticmethod def _normalize_squad_member(member: AgentSquadMember) -> AgentSquadMember: async def _empty(*_args, **_kwargs) -> None: @@ -296,25 +309,34 @@ def reset(self) -> None: for member in self._squad: member.agent.reset() - def setup(self, reset: bool, history: list[AllTinyMessages] | None) -> None: + def setup( + self, + reset: bool, + history: list[AllTinyMessages] | None, + checkpoint_id: str | None, + ) -> None: if reset: self.reset() if history: self.memory.save_multiple_context(history) + if checkpoint_id: + self.checkpointer.load(checkpoint_id) + def run( self, input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> str: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _run() -> str: final_answer = '' @@ -331,13 +353,14 @@ def run_stream( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> AsyncGenerator[str, None]: logger.debug('[USER INPUT] %s', input_text) run_id = run_id or str(uuid.uuid4()) - self.setup(reset=reset, history=history) + self.setup(reset=reset, history=history, checkpoint_id=checkpoint_id) async def _generator(): idx = 0 diff --git a/tinygent/core/datamodels/agent.py b/tinygent/core/datamodels/agent.py index 3b85d5d..3543341 100644 --- a/tinygent/core/datamodels/agent.py +++ b/tinygent/core/datamodels/agent.py @@ -8,6 +8,7 @@ from typing import TypeVar from tinygent.agents.middleware.agent import TinyMiddlewareAgent +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages from tinygent.core.types.builder import TinyModelBuildable @@ -34,6 +35,11 @@ def memory(self) -> AbstractMemory: """Get agents memory instance.""" raise NotImplementedError('Subclasses must implement this method.') + @property + def checkpointer(self) -> AbstractCheckpointer: + """Get agents checkpointer instance.""" + raise NotImplementedError('Subclasses must implement this method.') + @abstractmethod def reset(self) -> None: """Reset the agent's internal state.""" @@ -45,6 +51,7 @@ def run( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> str: @@ -57,6 +64,7 @@ def run_stream( input_text: str, *, run_id: str | None = None, + checkpoint_id: str | None = None, reset: bool = True, history: list[AllTinyMessages] | None = None, ) -> AsyncGenerator[str, None]: diff --git a/tinygent/core/datamodels/checkpointer.py b/tinygent/core/datamodels/checkpointer.py new file mode 100644 index 0000000..4d904c4 --- /dev/null +++ b/tinygent/core/datamodels/checkpointer.py @@ -0,0 +1,66 @@ +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Generic +from typing import TypeVar + +from tinygent.core.types.builder import TinyModelBuildable + +T = TypeVar('T', bound='AbstractCheckpointer') + + +class AbstractCheckpointerConfig(TinyModelBuildable[T], Generic[T]): + """Abstract base class for checkpoint configuration.""" + + def build(self) -> T: + """Build the checkpointer instance from the configuration.""" + raise NotImplementedError('Subclasses must implement this method.') + + +class AbstractCheckpointer(ABC): + """Abstract base class for checkpoint middleware.""" + + @abstractmethod + def save(self, checkpoint_id: str) -> None: + """Save current checkpoint state.""" + pass + + @abstractmethod + def load(self, checkpoint_id: str) -> None: + """Load saved checkpoint state.""" + pass + + @abstractmethod + def delete(self, checkpoint_id: str) -> None: + """Delete desired checkpoint.""" + pass + + @abstractmethod + def set_data(self, data: dict[str, Any]) -> None: + """Set checkpoint data manyally.""" + pass + + @abstractmethod + def setdefault(self, key: str, value: Any) -> Any: + """Initialize a key only if missing, then return the stored value. + + Non-destructive: keys already present (e.g. from a loaded checkpoint) + are left untouched. This is the building block for "set sometimes, + use sometimes" state initialization. + """ + pass + + @abstractmethod + def clear(self) -> None: + """Drop all checkpoint data (explicit fresh-start).""" + pass + + @abstractmethod + def __getitem__(self, key: str, default: Any = None) -> Any: + """Get checkpoint data by key.""" + pass + + @abstractmethod + def __setitem__(self, key: str, value: Any) -> None: + """Set checkpoint data by key.""" + pass diff --git a/tinygent/core/datamodels/tool_info.py b/tinygent/core/datamodels/tool_info.py index f3b6f15..48340ee 100644 --- a/tinygent/core/datamodels/tool_info.py +++ b/tinygent/core/datamodels/tool_info.py @@ -164,7 +164,7 @@ def from_callable(cls, fn: Callable[..., R], *args, **kwargs) -> ToolInfo[R]: return cls( name=name, description=description, - arg_count=1, + arg_count=1, # TODO: its not always 1 right now, tool accepts normal params as well so repair it here is_coroutine=is_coroutine, is_generator=is_generator, is_async_generator=is_async_generator, diff --git a/tinygent/core/factory/checkpointer.py b/tinygent/core/factory/checkpointer.py new file mode 100644 index 0000000..405330b --- /dev/null +++ b/tinygent/core/factory/checkpointer.py @@ -0,0 +1,26 @@ +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer +from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig +from tinygent.core.factory.helper import check_modules +from tinygent.core.factory.helper import parse_config +from tinygent.core.runtime.global_registry import GlobalRegistry + + +def build_checkpointer( + checkpointer: dict | AbstractCheckpointer | AbstractCheckpointerConfig, **kwargs +) -> AbstractCheckpointer: + """Build tiny checkpointer.""" + if isinstance(checkpointer, AbstractCheckpointer): + return checkpointer + + check_modules() + + if isinstance(checkpointer, str): + checkpointer = {'type': checkpointer, **kwargs} + + if isinstance(checkpointer, AbstractCheckpointerConfig): + checkpointer = checkpointer.model_dump() + + checkpointer_config = parse_config( + checkpointer, lambda: GlobalRegistry.get_registry().get_checkpointers() + ) + return checkpointer_config.build() diff --git a/tinygent/core/factory/llm.py b/tinygent/core/factory/llm.py index fa924be..0f584c8 100644 --- a/tinygent/core/factory/llm.py +++ b/tinygent/core/factory/llm.py @@ -7,13 +7,17 @@ def build_llm( - llm: str | dict | AbstractLLMConfig, + llm: str | dict | AbstractLLM | AbstractLLMConfig, *, provider: str | None = None, temperature: float | None = None, **kwargs, ) -> AbstractLLM: """Build tiny llm.""" + + if isinstance(llm, AbstractLLM): + return llm + check_modules() if isinstance(llm, str): diff --git a/tinygent/core/factory/memory.py b/tinygent/core/factory/memory.py index 462e020..e684a71 100644 --- a/tinygent/core/factory/memory.py +++ b/tinygent/core/factory/memory.py @@ -15,8 +15,13 @@ def build_memory(memory: dict | AbstractMemoryConfig) -> AbstractMemory: ... def build_memory(memory: str, **kwargs) -> AbstractMemory: ... -def build_memory(memory: dict | AbstractMemoryConfig | str, **kwargs) -> AbstractMemory: +def build_memory( + memory: dict | AbstractMemory | AbstractMemoryConfig | str, **kwargs +) -> AbstractMemory: """Build tiny memory.""" + if isinstance(memory, AbstractMemory): + return memory + check_modules() if isinstance(memory, str): diff --git a/tinygent/core/factory/tool.py b/tinygent/core/factory/tool.py index 86befc4..3ddf87d 100644 --- a/tinygent/core/factory/tool.py +++ b/tinygent/core/factory/tool.py @@ -8,7 +8,7 @@ @overload -def build_tool(tool: dict | AbstractToolConfig) -> AbstractTool: ... +def build_tool(tool: dict | AbstractTool | AbstractToolConfig) -> AbstractTool: ... @overload @@ -20,9 +20,15 @@ def build_tool(tool: str, *, tool_type: str) -> AbstractTool: ... def build_tool( - tool: dict | AbstractToolConfig | str, *, tool_type: str | None = None, **tool_kargs + tool: dict | AbstractTool | AbstractToolConfig | str, + *, + tool_type: str | None = None, + **tool_kargs, ) -> AbstractTool: """Build tiny tool.""" + if isinstance(tool, AbstractTool): + return tool + check_modules() if isinstance(tool, str): diff --git a/tinygent/core/runtime/executors.py b/tinygent/core/runtime/executors.py index 3900a81..0ebe8f0 100644 --- a/tinygent/core/runtime/executors.py +++ b/tinygent/core/runtime/executors.py @@ -11,8 +11,8 @@ P = typing.ParamSpec('P') T = typing.TypeVar('T') -_bg_loop = None -_bg_thread = None +_bg_loop: asyncio.AbstractEventLoop | None = None +_bg_thread: threading.Thread | None = None _DEFAULT_SEMAPHORE_LIMIT = int(os.getenv('TINY_SEMPATHORE_DEFAULT_LIMIT', 5)) diff --git a/tinygent/core/runtime/global_registry.py b/tinygent/core/runtime/global_registry.py index ed9d269..43a2642 100644 --- a/tinygent/core/runtime/global_registry.py +++ b/tinygent/core/runtime/global_registry.py @@ -3,6 +3,9 @@ import logging import typing +from tinygent.core.datamodels.checkpointer import AbstractCheckpointer +from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig + if typing.TYPE_CHECKING: from tinygent.core.datamodels.agent import AbstractAgent from tinygent.core.datamodels.agent import AbstractAgentConfig @@ -49,6 +52,11 @@ def __init__(self) -> None: str, tuple[type[AbstractMemoryConfig], type[AbstractMemory]] ] = {} + # checkpointers + self._registered_checkpointers: dict[ + str, tuple[type[AbstractCheckpointerConfig], type[AbstractCheckpointer]] + ] = {} + # tools self._registered_tools: dict[ str, tuple[type[AbstractToolConfig], type[AbstractTool]] @@ -68,6 +76,7 @@ def _rebuild_annotations(self) -> None: configs.extend(cfg for cfg, _ in self._registered_embedders.values()) configs.extend(cfg for cfg, _ in self._registered_crossencoders.values()) configs.extend(cfg for cfg, _ in self._registered_memories.values()) + configs.extend(cfg for cfg, _ in self._registered_checkpointers.values()) configs.extend(cfg for cfg, _ in self._registered_tools.values()) for config_cls in configs: @@ -219,6 +228,35 @@ def get_memories( logger.debug('Getting all registered memories') return self._registered_memories + # checkpointers + def register_checkpointer( + self, + name: str, + config_class: type[AbstractCheckpointerConfig], + checkpointer_class: type[AbstractCheckpointer], + ) -> None: + logger.debug('Registering checkpointer %s', name) + if name in self._registered_checkpointers: + raise ValueError(f'Checkpointer {name} already registered.') + + self._registered_checkpointers[name] = (config_class, checkpointer_class) + self._registration_changed() + + def get_checkpointer( + self, name: str + ) -> tuple[type[AbstractCheckpointerConfig], type[AbstractCheckpointer]]: + logger.debug('Getting checkpointer %s', name) + if name not in self._registered_checkpointers: + raise ValueError(f'Checkpointer {name} not registered.') + + return self._registered_checkpointers[name] + + def get_checkpointers( + self, + ) -> dict[str, tuple[type[AbstractCheckpointerConfig], type[AbstractCheckpointer]]]: + logger.debug('Getting all registered checkpointers') + return self._registered_checkpointers + # tools def register_tool( self, diff --git a/tinygent/tools/jit_tool.py b/tinygent/tools/jit_tool.py index 3b828b5..4da409b 100644 --- a/tinygent/tools/jit_tool.py +++ b/tinygent/tools/jit_tool.py @@ -29,6 +29,8 @@ class JITInstructionToolConfig(AbstractToolConfig['JITInstructionTool'], Generic instruction: str = Field(...) + # TODO: add here custom instruction field name + def build(self) -> 'JITInstructionTool': return cast( 'JITInstructionTool',