Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ docs = "scripts.docs:main"

[project.entry-points.components]
tiny_agents = "tinygent.agents.register"
tiny_checkpointers = "tinygent.agents.checkpointer.register"
tiny_middlewares = "tinygent.agents.middleware.register"
tiny_tools = "tinygent.tools.register"
tiny_memory = "tinygent.memory.register"
Expand Down
37 changes: 34 additions & 3 deletions tinygent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from io import StringIO
import logging
import textwrap
import typing
from typing import Any
from typing import Awaitable
from typing import Callable
Expand All @@ -17,6 +18,8 @@
from tinygent.agents.middleware.tool_limiter import ToolCallBlockedException
from tinygent.core.datamodels.agent import AbstractAgent
from tinygent.core.datamodels.agent import AbstractAgentConfig
from tinygent.core.datamodels.checkpointer import AbstractCheckpointer
from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig
from tinygent.core.datamodels.llm import AbstractLLM
from tinygent.core.datamodels.llm import AbstractLLMConfig
from tinygent.core.datamodels.memory import AbstractMemory
Expand All @@ -34,12 +37,20 @@
from tinygent.core.types.io.llm_io_input import TinyLLMInput
from tinygent.memory.buffer_chat_memory import BufferChatMemoryConfig

T = TypeVar('T', bound='AbstractAgent')
if typing.TYPE_CHECKING:
from tinygent.agents.checkpointer.default_checkpointer import TinyDefaultCheckpointer

T = TypeVar('T', bound='AbstractAgent')

logger = logging.getLogger(__name__)


def _create_default_checkpointer() -> 'TinyDefaultCheckpointer':
from tinygent.agents.checkpointer.default_checkpointer import TinyDefaultCheckpointer

return TinyDefaultCheckpointer({})


class TinyBaseAgentConfig(AbstractAgentConfig[T], Generic[T]):
"""Configuration for BaseAgent."""

Expand All @@ -48,12 +59,14 @@ class TinyBaseAgentConfig(AbstractAgentConfig[T], Generic[T]):
middleware: Sequence[AbstractMiddlewareConfig | AbstractMiddleware] = Field(
default_factory=list
)

llm: AbstractLLMConfig | AbstractLLM = Field(...)
tools: Sequence[AbstractToolConfig | AbstractTool] = Field(default_factory=list)
memory: AbstractMemoryConfig | AbstractMemory = Field(
default_factory=BufferChatMemoryConfig
)
checkpointer: AbstractCheckpointer | AbstractCheckpointerConfig | None = Field(
default=None
)

def build(self) -> T:
"""Build the BaseAgent instance from the configuration."""
Expand Down Expand Up @@ -84,6 +97,18 @@ def build_memory_instance(self) -> AbstractMemory:

return build_memory(self.memory)

def build_checkpointer_instance(self) -> AbstractCheckpointer:
"""Build checkpointer instance from config if checkpointer is set."""
if isinstance(self.checkpointer, AbstractCheckpointer):
return self.checkpointer

if self.checkpointer is None:
return _create_default_checkpointer()

from tinygent.core.factory.checkpointer import build_checkpointer

return build_checkpointer(self.checkpointer)

def build_middleware_list(self) -> list[AbstractMiddleware]:
"""Build list of middleware instances from configs or return existing instances."""
from tinygent.core.factory.middleware import build_middleware
Expand All @@ -99,20 +124,26 @@ def __init__(
self,
llm: AbstractLLM,
memory: AbstractMemory,
tools: Sequence[AbstractTool] = (),
tools: Sequence[AbstractTool] = [],
middleware: Sequence[AbstractMiddleware] = [],
checkpointer: AbstractCheckpointer | None = None,
) -> None:
self.llm = llm
self.middleware = middleware

self._memory = memory
self._tools = tools
self._checkpointer = (
_create_default_checkpointer() if checkpointer is None else checkpointer
)
self._final_answer: str | None = None

def reset(self) -> None:
logger.debug('[BASE AGENT RESET]')

self.memory.clear()
self.checkpointer.clear()

self._final_answer = None

@property
Expand Down
Empty file.
55 changes: 55 additions & 0 deletions tinygent/agents/checkpointer/base_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
from typing import Any
from typing import Generic
from typing import TypeVar

from pydantic import Field

from tinygent.core.datamodels.checkpointer import AbstractCheckpointer
from tinygent.core.datamodels.checkpointer import AbstractCheckpointerConfig

T = TypeVar('T', bound='TinyBaseCheckpointer')


logger = logging.getLogger(__name__)


class TinyBaseCheckpointerConfig(AbstractCheckpointerConfig[T], Generic[T]):
type: Any = Field(default='base')

data: dict[str, Any] = Field(default={})

def build(self) -> T:
"""Build the BaseCheckpointer instance from the configuration."""
raise NotImplementedError('Subclasses must implement this method.')


class TinyBaseCheckpointer(AbstractCheckpointer):
def __init__(self, data: dict[str, Any]) -> None:
self.data = data

def set_data(self, data: dict[str, Any]) -> None:
self.data = data

def setdefault(self, key: str, value: Any) -> Any:
if key not in self.data:
self.data[key] = value
return self.data[key]

def clear(self) -> None:
self.data = {}

def __getitem__(self, key: str, default: Any = None) -> Any:
val = self.data.get(key, default)
if val is None:
logger.warning(
'Key: %s is missing in checkpoint data %s', key, self.__class__.__name__
)

if default is not None:
self.__setitem__('key', default)

return val

def __setitem__(self, key: str, value: Any) -> None:
self.data[key] = value
25 changes: 25 additions & 0 deletions tinygent/agents/checkpointer/default_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import logging
from typing import Any

from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointer

logger = logging.getLogger(__name__)


class TinyDefaultCheckpointer(TinyBaseCheckpointer):
def __init__(self, data: dict[str, Any]) -> None:
super().__init__(data)

def save(self, checkpoint_id: str) -> None:
logger.debug(
'Checkpoint (%s) will not be saved in default checkpointer', checkpoint_id
)

def load(self, checkpoint_id: str) -> None:
logger.debug('Nothing to be loaded for default checkpointer (%s)', checkpoint_id)

def delete(self, checkpoint_id: str) -> None:
logger.debug('Deleting checkpoint %s', checkpoint_id)
self.data = {}
52 changes: 52 additions & 0 deletions tinygent/agents/checkpointer/local_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from copy import deepcopy
import logging
from threading import Lock
from typing import Any
from typing import TypeVar

from pydantic import Field

from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointer
from tinygent.agents.checkpointer.base_checkpointer import TinyBaseCheckpointerConfig

logger = logging.getLogger(__name__)

T = TypeVar('T', bound='TinyLocalCheckpointer')

_GLOBAL_CHECKPOINTS: dict[str, Any] = {}
_GLOBAL_CHECKPOINTS_LOCK = Lock()


class TinyLocalCheckpointerConfig(TinyBaseCheckpointerConfig['TinyLocalCheckpointer']):
type: Any = Field(default='local', frozen=True)

def build(self) -> TinyLocalCheckpointer:
return TinyLocalCheckpointer(self.data)


class TinyLocalCheckpointer(TinyBaseCheckpointer):
def __init__(self, data: dict[str, Any]) -> None:
super().__init__(data)

def save(self, checkpoint_id: str) -> None:
logger.debug('Saving checkpoint %s', checkpoint_id)
with _GLOBAL_CHECKPOINTS_LOCK:
_GLOBAL_CHECKPOINTS[checkpoint_id] = deepcopy(self.data)
logger.debug('Data %s saved: %s', checkpoint_id, str(self.data))

def load(self, checkpoint_id: str) -> None:
logger.debug('Loading checkpoint %s', checkpoint_id)
with _GLOBAL_CHECKPOINTS_LOCK:
tmp = _GLOBAL_CHECKPOINTS.get(checkpoint_id)
if tmp is None:
logger.warning("Couldn't find data for checkpoint id: %s", checkpoint_id)
return
self.data = deepcopy(tmp)
logger.debug('Data %s loaded: %s', checkpoint_id, str(self.data))

def delete(self, checkpoint_id: str) -> None:
logger.debug('Deleting checkpoint %s', checkpoint_id)
with _GLOBAL_CHECKPOINTS_LOCK:
_GLOBAL_CHECKPOINTS.pop(checkpoint_id)
14 changes: 14 additions & 0 deletions tinygent/agents/checkpointer/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from tinygent.agents.checkpointer.local_checkpointer import TinyLocalCheckpointer
from tinygent.agents.checkpointer.local_checkpointer import TinyLocalCheckpointerConfig
from tinygent.core.runtime.global_registry import GlobalRegistry


def _register_checkpointers() -> None:
registry = GlobalRegistry().get_registry()

registry.register_checkpointer(
'local', TinyLocalCheckpointerConfig, TinyLocalCheckpointer
)


_register_checkpointers()
Loading
Loading