Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,47 @@
.env
# Python
*.py[cod]
__pycache__/
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
venv/
.venv/
env/
.env

# Pytest
.pytest_cache/
.coverage
htmlcov/
coverage.xml

# IDEs
.idea/
.vscode/
*.swp
*.swo

# MacOS
.DS_Store

# Logs
logs/
*.log
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "Xmem"
version = "0.1.0"
description = "Universal unified memory system for AI agents"
authors = [
{name = "Vedant Mahajan", email = "xmemlabs@gmail.com"}
{name = "Vedant Mahajan", email = "xmemlabs@gmail.com"},
{name = "Ishaan Gupta", email = "xmemlabs@gmail.com"}
]
readme = "README.md"
Expand Down
7 changes: 7 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Xmem agents — re-export the public agent classes."""

from src.agents.classifier import ClassifierAgent

# Names exported by `from src.agents import *`; only the agent classes
# imported above are listed here.
__all__ = [
"ClassifierAgent",
]
35 changes: 35 additions & 0 deletions src/agents/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict

from langchain_core.language_models import BaseChatModel


class BaseAgent(ABC):
    """Abstract base for Xmem agents that talk to a chat model.

    Subclasses implement :meth:`arun`. This class supplies a blocking
    wrapper (:meth:`run`), message assembly, and a helper that reduces the
    model's reply to a plain string.

    Attributes:
        model: Chat model whose ``ainvoke`` coroutine is awaited.
        name: Agent name; also the suffix of the logger name.
        system_prompt: Optional system message; omitted when empty.
        logger: Logger named ``xmem.agents.<name>``.
    """

    def __init__(self, model: BaseChatModel, name: str, system_prompt: str = ""):
        self.model = model
        self.name = name
        self.system_prompt = system_prompt
        self.logger = logging.getLogger(f"xmem.agents.{name}")

    @abstractmethod
    async def arun(self, state: Dict[str, Any]) -> Any:
        """Run the agent asynchronously on *state* and return its result."""
        ...

    def run(self, state: Dict[str, Any]) -> Any:
        """Blocking wrapper that drives :meth:`arun` to completion.

        This relies on ``asyncio.run``, which raises ``RuntimeError`` when
        the calling thread already has a running event loop. Code that is
        already asynchronous should ``await agent.arun(state)`` instead.
        """
        import asyncio
        return asyncio.run(self.arun(state))

    def _build_messages(self, user_message: str) -> list:
        """Return the role/content message dicts for one model call.

        The system message is included only when ``system_prompt`` is
        non-empty; the user message always comes last.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_message})
        return messages

    @staticmethod
    def _block_to_text(block: Any) -> str:
        """Return the text carried by one element of a list-valued reply.

        Dict blocks that carry a string ``"text"`` field yield that text.
        Anything else falls back to ``str(block)`` so no content is dropped.
        """
        if isinstance(block, str):
            return block
        if isinstance(block, dict) and isinstance(block.get("text"), str):
            return block["text"]
        return str(block)

    async def _call_model(self, messages: list) -> str:
        """Invoke the model with *messages* and return the reply content.

        When the content is a list of blocks, each block is reduced to its
        text and the pieces are joined with newlines. Stringifying a dict
        block directly would leak its Python repr into the text that
        downstream parsers receive.
        """
        response = await self.model.ainvoke(messages)
        content = response.content
        if isinstance(content, list):
            content = "\n".join(self._block_to_text(c) for c in content)
        return content
51 changes: 51 additions & 0 deletions src/agents/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Classifier Agent — the entry-point router for Xmem.

Classifies user input into one or more intent categories (code, profile,
event) so downstream agents only receive the sub-queries relevant to them.
"""

from __future__ import annotations

from typing import Any, Dict

from langchain_core.language_models import BaseChatModel

from src.agents.base import BaseAgent
from src.prompts.classifier import build_system_prompt, pack_classification_query
from src.schemas.classification import ClassificationResult
from src.utils.text import parse_raw_response_to_classifications


class ClassifierAgent(BaseAgent):
    """Entry-point agent that asks the model to classify a user query.

    The agent is registered under the name ``"classifier"`` and uses the
    system prompt produced by ``build_system_prompt()``.
    """

    def __init__(self, model: BaseChatModel) -> None:
        prompt = build_system_prompt()
        super().__init__(model=model, name="classifier", system_prompt=prompt)

    async def arun(self, state: Dict[str, Any]) -> ClassificationResult:
        """Classify ``state["user_query"]`` and wrap the parsed entries.

        A missing or empty query short-circuits to an empty result without
        calling the model. Otherwise the packed query is sent to the model
        and the raw reply is handed to
        ``parse_raw_response_to_classifications``.
        """
        query = state.get("user_query")
        if not query:
            self.logger.debug("Empty query — returning empty classifications.")
            return ClassificationResult(classifications=[])

        packed = pack_classification_query(query)
        reply = await self._call_model(self._build_messages(packed))
        parsed = parse_raw_response_to_classifications(reply)

        # Nothing actionable came back: note it and return early.
        if not parsed:
            self.logger.info("No actionable classifications found (trivial input).")
            return ClassificationResult(classifications=parsed)

        # Log a banner-delimited, numbered summary of what was extracted.
        banner = "=" * 50
        self.logger.info(banner)
        self.logger.info("Extracted Classifications:")
        position = 0
        for entry in parsed:
            position += 1
            self.logger.info(
                "  %d. source=%s  query=%s", position, entry["source"], entry["query"]
            )
        self.logger.info("Total classifications: %d", len(parsed))
        self.logger.info(banner)

        return ClassificationResult(classifications=parsed)
12 changes: 12 additions & 0 deletions src/config/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Shared constants used across Xmem agents and prompt formatting.

These are protocol-level values that all agents rely on for structured
communication with the LLM. Changing them requires updating every
system prompt that references the separator format.
"""

# Delimiter placed between the fields of each structured line exchanged
# between the LLM and the agents. Despite the "TAB" in the name, the value
# is a double colon, not a tab character.
# Format in prompts: `- SOURCE::QUERY`
# Must stay in sync with all system prompts and parsing utilities.
LLM_TAB_SEPARATOR: str = "::"
3 changes: 1 addition & 2 deletions src/config/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional,List
from typing import Optional, List
from pydantic import Field,field_validator
from pydantic_settings import BaseSettings,SettingsConfigDict

Expand Down Expand Up @@ -122,7 +122,6 @@ class Settings(BaseSettings):
@field_validator("fallback_order")
@classmethod
def validate_fallback_order(cls, v: List[str]) -> List[str]:
"""Ensure fallback_order only contains valid provider names."""
valid_providers = {"gemini", "claude", "openai"}
for provider in v:
if provider not in valid_providers:
Expand Down
6 changes: 6 additions & 0 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Xmem models — re-export the public API."""

from src.models.base import Provider
from src.models.registry import get_model

# Names exported by `from src.models import *`: the model factory entry
# point and the provider-name type imported above.
__all__ = ["get_model", "Provider"]
7 changes: 7 additions & 0 deletions src/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Base types for the models module.
"""

from typing import Literal

# Closed set of provider identifiers, expressed as a Literal type so that
# static checkers can reject any other string where a Provider is expected.
Provider = Literal["gemini", "claude", "openai"]
23 changes: 23 additions & 0 deletions src/models/claude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Claude model factory.
"""

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel

from src.config import settings


def build_claude_model(
    model_name: str | None = None,
    temperature: float | None = None,
) -> BaseChatModel:
    """Create a ``ChatAnthropic`` client from settings, with optional overrides.

    Args:
        model_name: Model identifier; a falsy value falls back to
            ``settings.claude_model``.
        temperature: Sampling temperature; ``None`` falls back to
            ``settings.temperature`` (an explicit ``0.0`` is honoured).

    Raises:
        ValueError: If ``settings.claude_api_key`` is empty or unset.
    """
    key = settings.claude_api_key
    if not key:
        raise ValueError("CLAUDE_API_KEY is not set")

    chosen_model = model_name or settings.claude_model
    if temperature is None:
        temperature = settings.temperature

    return ChatAnthropic(model=chosen_model, api_key=key, temperature=temperature)
23 changes: 23 additions & 0 deletions src/models/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Gemini model factory.
"""

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.language_models import BaseChatModel

from src.config import settings


def build_gemini_model(
    model_name: str | None = None,
    temperature: float | None = None,
) -> BaseChatModel:
    """Create a ``ChatGoogleGenerativeAI`` client from settings.

    Args:
        model_name: Model identifier; a falsy value falls back to
            ``settings.gemini_model``.
        temperature: Sampling temperature; ``None`` falls back to
            ``settings.temperature`` (an explicit ``0.0`` is honoured).

    Raises:
        ValueError: If ``settings.gemini_api_key`` is empty or unset.
    """
    key = settings.gemini_api_key
    if not key:
        raise ValueError("GEMINI_API_KEY is not set")

    chosen_model = model_name or settings.gemini_model
    if temperature is None:
        temperature = settings.temperature

    return ChatGoogleGenerativeAI(
        model=chosen_model,
        google_api_key=key,
        temperature=temperature,
    )
23 changes: 23 additions & 0 deletions src/models/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
OpenAI model factory.
"""

from langchain_openai import ChatOpenAI
from langchain_core.language_models import BaseChatModel

from src.config import settings


def build_openai_model(
    model_name: str | None = None,
    temperature: float | None = None,
) -> BaseChatModel:
    """Create a ``ChatOpenAI`` client from settings, with optional overrides.

    Args:
        model_name: Model identifier; a falsy value falls back to
            ``settings.openai_model``.
        temperature: Sampling temperature; ``None`` falls back to
            ``settings.temperature`` (an explicit ``0.0`` is honoured).

    Raises:
        ValueError: If ``settings.openai_api_key`` is empty or unset.
    """
    key = settings.openai_api_key
    if not key:
        raise ValueError("OPENAI_API_KEY is not set")

    chosen_model = model_name or settings.openai_model
    if temperature is None:
        temperature = settings.temperature

    return ChatOpenAI(model=chosen_model, api_key=key, temperature=temperature)
87 changes: 87 additions & 0 deletions src/models/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Model registry — single entry-point for getting an LLM instance.

Usage:
from src.models import get_model
model = get_model() # uses first available provider from fallback_order
model = get_model("gemini") # force a specific provider
model = get_model(temperature=0.0) # override temperature
"""

from __future__ import annotations

import logging
from typing import Optional

from langchain_core.language_models import BaseChatModel

from src.config import settings
from src.models.base import Provider

logger = logging.getLogger("xmem.models")

# Provider name -> factory callable that accepts the keyword overrides
# assembled in get_model(). Each entry is a lambda rather than a direct
# function reference because the _build_* helpers are defined further down
# this module; the lambda postpones the name lookup until the factory is
# actually called.
_BUILDERS = {
"gemini": lambda **kw: _build_gemini(**kw),
"claude": lambda **kw: _build_claude(**kw),
"openai": lambda **kw: _build_openai(**kw),
}

# Provider name -> zero-argument callable returning that provider's API key
# from settings. Wrapping the attribute access in a lambda means the key is
# read each time it is checked, not once when this module is imported.
_KEY_MAP = {
"gemini": lambda: settings.gemini_api_key,
"claude": lambda: settings.claude_api_key,
"openai": lambda: settings.openai_api_key,
}


def _build_gemini(**kw) -> BaseChatModel:
    """Build the Gemini chat model, forwarding *kw* to its factory.

    The factory module is imported inside the function, so it is loaded only
    when this builder runs (presumably so that the Gemini SDK is required
    only when that provider is selected -- confirm).
    """
    from src.models.gemini import build_gemini_model as factory

    built = factory(**kw)
    return built


def _build_claude(**kw) -> BaseChatModel:
    """Build the Claude chat model, forwarding *kw* to its factory.

    The factory module is imported inside the function, so it is loaded only
    when this builder runs (presumably so that the Anthropic SDK is required
    only when that provider is selected -- confirm).
    """
    from src.models.claude import build_claude_model as factory

    built = factory(**kw)
    return built


def _build_openai(**kw) -> BaseChatModel:
    """Build the OpenAI chat model, forwarding *kw* to its factory.

    The factory module is imported inside the function, so it is loaded only
    when this builder runs (presumably so that the OpenAI SDK is required
    only when that provider is selected -- confirm).
    """
    from src.models.openai import build_openai_model as factory

    built = factory(**kw)
    return built


def get_model(
    provider: Optional[Provider] = None,
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
) -> BaseChatModel:
    """Build and return a chat model.

    If *provider* is None the first provider from ``settings.fallback_order``
    whose API key is configured and whose builder succeeds will be used.

    Args:
        provider: Force a specific provider; builder errors then propagate
            to the caller unchanged.
        model_name: Optional model identifier forwarded to the builder.
        temperature: Optional temperature forwarded to the builder; ``0.0``
            is forwarded, only ``None`` is omitted.

    Raises:
        RuntimeError: In auto-select mode, when no provider could be
            initialised. The message lists every provider in the fallback
            order with the reason it was skipped or failed.
        KeyError: When an explicit *provider* has no registered builder.
    """
    # Forward only the overrides the caller supplied so that each builder
    # applies its own defaults for the rest.
    kw: dict = {}
    if model_name is not None:
        kw["model_name"] = model_name
    if temperature is not None:
        kw["temperature"] = temperature

    if provider:
        return _BUILDERS[provider](**kw)

    # Auto-select from fallback order, recording why each candidate was
    # rejected so the final error never reads as an unexplained empty list.
    errors: list[str] = []
    for p in settings.fallback_order:
        key_fn = _KEY_MAP.get(p)
        builder = _BUILDERS.get(p)
        if key_fn is None or builder is None:
            errors.append(f"{p}: unknown provider")
            logger.warning("Provider %s is not registered; skipping", p)
            continue
        if not key_fn():
            errors.append(f"{p}: API key not configured")
            logger.debug("Provider %s skipped: API key not configured", p)
            continue
        try:
            model = builder(**kw)
        except Exception as exc:
            # Best-effort fallback: remember the failure and try the next one.
            errors.append(f"{p}: {exc}")
            logger.warning("Provider %s failed: %s", p, exc)
        else:
            logger.info("Using provider: %s", p)
            return model

    raise RuntimeError(
        f"No LLM provider could be initialised. Tried: {settings.fallback_order}. "
        f"Errors: {errors}"
    )
Loading