diff --git a/.gitignore b/.gitignore index 2eea525..bdb887a 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,47 @@ -.env \ No newline at end of file +# Python +*.py[cod] +__pycache__/ +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +venv/ +.venv/ +env/ +.env + +# Pytest +.pytest_cache/ +.coverage +htmlcov/ +coverage.xml + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo + +# MacOS +.DS_Store + +# Logs +logs/ +*.log \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a1691f4..e26009c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "Xmem" version = "0.1.0" description = "Universal unified memory system for AI agents" authors = [ - {name = "Vedant Mahajan", email = "xmemlabs@gmail.com"} + {name = "Vedant Mahajan", email = "xmemlabs@gmail.com"}, {name = "Ishaan Gupta", email = "xmemlabs@gmail.com"} ] readme = "README.md" diff --git a/src/agents/__init__.py b/src/agents/__init__.py index e69de29..cc9048f 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -0,0 +1,7 @@ +"""Xmem agents — re-export the public agent classes.""" + +from src.agents.classifier import ClassifierAgent + +__all__ = [ + "ClassifierAgent", +] diff --git a/src/agents/base.py b/src/agents/base.py index e69de29..45ceaa9 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -0,0 +1,35 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict + +from langchain_core.language_models import BaseChatModel + + +class BaseAgent(ABC): + def __init__(self, model: BaseChatModel, name: str, system_prompt: str = ""): + self.model = model + self.name = name + self.system_prompt = system_prompt + self.logger = logging.getLogger(f"xmem.agents.{name}") + + @abstractmethod + async def arun(self, state: Dict[str, Any]) -> Any: + ... + + def run(self, state: Dict[str, Any]) -> Any: + import asyncio + return asyncio.run(self.arun(state)) + + def _build_messages(self, user_message: str) -> list: + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": user_message}) + return messages + + async def _call_model(self, messages: list) -> str: + response = await self.model.ainvoke(messages) + content = response.content + if isinstance(content, list): + content = "\n".join(str(c) for c in content) + return content diff --git a/src/agents/classifier.py b/src/agents/classifier.py index e69de29..f53d02f 100644 --- a/src/agents/classifier.py +++ b/src/agents/classifier.py @@ -0,0 +1,51 @@ +""" +Classifier Agent — the entry-point router for Xmem. + +Classifies user input into one or more intent categories (code, profile, +event) so downstream agents only receive the sub-queries relevant to them. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from langchain_core.language_models import BaseChatModel + +from src.agents.base import BaseAgent +from src.prompts.classifier import build_system_prompt, pack_classification_query +from src.schemas.classification import ClassificationResult +from src.utils.text import parse_raw_response_to_classifications + + +class ClassifierAgent(BaseAgent): + def __init__(self, model: BaseChatModel) -> None: + super().__init__( + model=model, + name="classifier", + system_prompt=build_system_prompt(), + ) + + async def arun(self, state: Dict[str, Any]) -> ClassificationResult: + user_input = state.get("user_query") + if not user_input: + self.logger.debug("Empty query — returning empty classifications.") + return ClassificationResult(classifications=[]) + + user_message = pack_classification_query(user_input) + messages = self._build_messages(user_message) + raw_content = await self._call_model(messages) + classifications = parse_raw_response_to_classifications(raw_content) + + if classifications: + self.logger.info("=" * 50) + self.logger.info("Extracted Classifications:") + for idx, cls in enumerate(classifications, 1): + self.logger.info( + " %d. source=%s query=%s", idx, cls["source"], cls["query"] + ) + self.logger.info("Total classifications: %d", len(classifications)) + self.logger.info("=" * 50) + else: + self.logger.info("No actionable classifications found (trivial input).") + + return ClassificationResult(classifications=classifications) diff --git a/src/config/constants.py b/src/config/constants.py index e69de29..711fe1a 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -0,0 +1,12 @@ +""" +Shared constants used across Xmem agents and prompt formatting. + +These are protocol-level values that all agents rely on for structured +communication with the LLM. Changing them requires updating every +system prompt that references the separator format. +""" + +# Delimiter used in the tab-separated format between LLM and agents. +# Format in prompts: `- SOURCE::QUERY` +# Must stay in sync with all system prompts and parsing utilities. +LLM_TAB_SEPARATOR: str = "::" diff --git a/src/config/settings.py b/src/config/settings.py index f6f46f3..b583e99 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -1,4 +1,4 @@ -from typing import Optional,List +from typing import Optional, List from pydantic import Field,field_validator from pydantic_settings import BaseSettings,SettingsConfigDict @@ -122,7 +122,6 @@ class Settings(BaseSettings): @field_validator("fallback_order") @classmethod def validate_fallback_order(cls, v: List[str]) -> List[str]: - """Ensure fallback_order only contains valid provider names.""" valid_providers = {"gemini", "claude", "openai"} for provider in v: if provider not in valid_providers: diff --git a/src/models/__init__.py b/src/models/__init__.py index e69de29..bc2c12e 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -0,0 +1,6 @@ +"""Xmem models — re-export the public API.""" + +from src.models.base import Provider +from src.models.registry import get_model + +__all__ = ["get_model", "Provider"] diff --git a/src/models/base.py b/src/models/base.py index e69de29..5d19dd9 100644 --- a/src/models/base.py +++ b/src/models/base.py @@ -0,0 +1,7 @@ +""" +Base types for the models module. +""" + +from typing import Literal + +Provider = Literal["gemini", "claude", "openai"] diff --git a/src/models/claude.py b/src/models/claude.py new file mode 100644 index 0000000..14e9607 --- /dev/null +++ b/src/models/claude.py @@ -0,0 +1,23 @@ +""" +Claude model factory. +""" + +from langchain_anthropic import ChatAnthropic +from langchain_core.language_models import BaseChatModel + +from src.config import settings + + +def build_claude_model( + model_name: str | None = None, + temperature: float | None = None, +) -> BaseChatModel: + api_key = settings.claude_api_key + if not api_key: + raise ValueError("CLAUDE_API_KEY is not set") + + return ChatAnthropic( + model=model_name or settings.claude_model, + api_key=api_key, + temperature=temperature if temperature is not None else settings.temperature, + ) diff --git a/src/models/gemini.py b/src/models/gemini.py index e69de29..e135560 100644 --- a/src/models/gemini.py +++ b/src/models/gemini.py @@ -0,0 +1,23 @@ +""" +Gemini model factory. +""" + +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_core.language_models import BaseChatModel + +from src.config import settings + + +def build_gemini_model( + model_name: str | None = None, + temperature: float | None = None, +) -> BaseChatModel: + api_key = settings.gemini_api_key + if not api_key: + raise ValueError("GEMINI_API_KEY is not set") + + return ChatGoogleGenerativeAI( + model=model_name or settings.gemini_model, + google_api_key=api_key, + temperature=temperature if temperature is not None else settings.temperature, + ) diff --git a/src/models/openai.py b/src/models/openai.py index e69de29..4f90216 100644 --- a/src/models/openai.py +++ b/src/models/openai.py @@ -0,0 +1,23 @@ +""" +OpenAI model factory. +""" + +from langchain_openai import ChatOpenAI +from langchain_core.language_models import BaseChatModel + +from src.config import settings + + +def build_openai_model( + model_name: str | None = None, + temperature: float | None = None, +) -> BaseChatModel: + api_key = settings.openai_api_key + if not api_key: + raise ValueError("OPENAI_API_KEY is not set") + + return ChatOpenAI( + model=model_name or settings.openai_model, + api_key=api_key, + temperature=temperature if temperature is not None else settings.temperature, + ) diff --git a/src/models/registry.py b/src/models/registry.py index e69de29..544d928 100644 --- a/src/models/registry.py +++ b/src/models/registry.py @@ -0,0 +1,87 @@ +""" +Model registry — single entry-point for getting an LLM instance. + +Usage: + from src.models import get_model + model = get_model() # uses first available provider from fallback_order + model = get_model("gemini") # force a specific provider + model = get_model(temperature=0.0) # override temperature +""" + +from __future__ import annotations + +import logging +from typing import Optional + +from langchain_core.language_models import BaseChatModel + +from src.config import settings +from src.models.base import Provider + +logger = logging.getLogger("xmem.models") + +_BUILDERS = { + "gemini": lambda **kw: _build_gemini(**kw), + "claude": lambda **kw: _build_claude(**kw), + "openai": lambda **kw: _build_openai(**kw), +} + +_KEY_MAP = { + "gemini": lambda: settings.gemini_api_key, + "claude": lambda: settings.claude_api_key, + "openai": lambda: settings.openai_api_key, +} + + +def _build_gemini(**kw) -> BaseChatModel: + from src.models.gemini import build_gemini_model + return build_gemini_model(**kw) + + +def _build_claude(**kw) -> BaseChatModel: + from src.models.claude import build_claude_model + return build_claude_model(**kw) + + +def _build_openai(**kw) -> BaseChatModel: + from src.models.openai import build_openai_model + return build_openai_model(**kw) + + +def get_model( + provider: Optional[Provider] = None, + model_name: Optional[str] = None, + temperature: Optional[float] = None, +) -> BaseChatModel: + """Build and return a chat model. + + If *provider* is None the first provider from ``settings.fallback_order`` + whose API key is configured will be used. Raises ``RuntimeError`` if no + provider can be initialised. + """ + kw: dict = {} + if model_name is not None: + kw["model_name"] = model_name + if temperature is not None: + kw["temperature"] = temperature + + if provider: + return _BUILDERS[provider](**kw) + + # Auto-select from fallback order + errors: list[str] = [] + for p in settings.fallback_order: + key_fn = _KEY_MAP.get(p) + if key_fn and key_fn(): + try: + model = _BUILDERS[p](**kw) + logger.info("Using provider: %s", p) + return model + except Exception as exc: + errors.append(f"{p}: {exc}") + logger.warning("Provider %s failed: %s", p, exc) + + raise RuntimeError( + f"No LLM provider could be initialised. Tried: {settings.fallback_order}. " + f"Errors: {errors}" + ) diff --git a/src/prompts/classifier.py b/src/prompts/classifier.py index e69de29..2d69b1f 100644 --- a/src/prompts/classifier.py +++ b/src/prompts/classifier.py @@ -0,0 +1,110 @@ +""" +System prompt and query formatting for the Classifier agent. + +The classifier routes user input to specialised downstream agents +(code, profile, event) by examining intent and temporal markers. +""" + +from __future__ import annotations + +from typing import List + +from src.config.constants import LLM_TAB_SEPARATOR +from src.prompts.classifier_keywords import ( + CODE_AGENT_KEYWORDS, + EVENT_AGENT_KEYWORDS, + PROFILE_AGENT_KEYWORDS, + get_keywords_string, +) +from src.prompts.examples.classification import CLASSIFICATION_EXAMPLES +from src.utils.text import pack_classifications_into_string + +_SYSTEM_PROMPT_TEMPLATE = """\ +You are an intelligent intent router for a personal memory assistant. +Your task is to accurately route user inputs to the correct specialized agents for MEMORY STORAGE. + +CRITICAL: Your job is to identify WHAT SHOULD BE REMEMBERED about the user. + +--- + +## Available Agents + +### 1. `code` +- **Purpose**: Software engineering and technical tasks (writing, debugging, explaining code) +- **Keywords**: {code_keywords} +- **Route here when**: User wants help with actual coding work, debugging, or technical explanations + +### 2. `profile` +- **Purpose**: Store PERMANENT facts about the user (identity, preferences, traits, background) +- **Keywords**: {profile_keywords} +- **Route here when**: User shares static personal information that doesn't have a specific date +- **Examples**: name, job, hobbies, food preferences, personality traits, where they live + +### 3. `event` +- **Purpose**: Store TIME-BASED events and memories (past, present, or future) +- **Keywords**: {event_keywords} +- **Route here when**: User mentions something that happened/will happen at a SPECIFIC TIME +- **Examples**: birthdays, anniversaries, "last Saturday", "3 years ago", "next month" + +## Logic & Strategy + +### 1. Look for Temporal Markers FIRST +Before classifying, scan for ANY time reference: +- Absolute: dates, years, months, days +- Relative: "ago", "last", "next", "yesterday", "tomorrow" +- Age-based: "when I was X", "at age X", "X years old" +- Ordinal: "first", "18th birthday", "second anniversary" + +If temporal marker found → likely `event` + +### 2. Decomposition (Multi-Intent) +If input contains MULTIPLE distinct pieces of information, split them: +- "I'm John and my birthday is March 15th" → `profile` (name) + `event` (birthday) +- "I moved to NYC last year and now work at Google" → `event` (move) + `profile` (job) + +### 3. Skip Trivial Messages +Pure greetings/acknowledgments with NO factual content → empty list +- "Hi!", "Thanks!", "Great!", "Okay" → [] + +--- + +## Output Format (Strict) + +One classification per line: +- Format: `SOURCE{tab}QUERY` +- `SOURCE` must be: `code`, `profile`, or `event` +- For trivial inputs, output nothing + +--- + +## Examples +{examples} + +--- +""" + + +def build_system_prompt() -> str: + """Render the full system prompt with examples and keywords injected.""" + examples_block = "\n\n".join( + f"\n" + f"{user_input}\n" + f"\n" + f"{pack_classifications_into_string(classifications) if classifications else '(empty - trivial/skip)'}\n" + f"\n" + f"" + for user_input, classifications in CLASSIFICATION_EXAMPLES + ) + + return _SYSTEM_PROMPT_TEMPLATE.format( + tab=LLM_TAB_SEPARATOR, + examples=examples_block, + code_keywords=get_keywords_string(CODE_AGENT_KEYWORDS), + profile_keywords=get_keywords_string(PROFILE_AGENT_KEYWORDS), + event_keywords=get_keywords_string(EVENT_AGENT_KEYWORDS), + ) + + +def pack_classification_query(user_input: str) -> str: + """Wrap the raw user input in the expected user-message format.""" + return f"Analyze this user input:\n\nUser Input: {user_input}" diff --git a/src/prompts/classifier_keywords.py b/src/prompts/classifier_keywords.py new file mode 100644 index 0000000..f143d1e --- /dev/null +++ b/src/prompts/classifier_keywords.py @@ -0,0 +1,112 @@ +CODE_AGENT_KEYWORDS = [ + # Programming languages + "python", "javascript", "typescript", "java", "c++", "c#", "go", "rust", + "ruby", "php", "swift", "kotlin", "scala", "perl", "lua", "haskell", + # Actions + "code", "coding", "program", "script", "implement", "develop", "build", + "refactor", "optimize", "compile", "execute", "run", "deploy", + # Debugging + "debug", "fix", "error", "bug", "issue", "exception", "stack trace", + "segmentation fault", "syntax error", "runtime error", "logic error", + "traceback", "crash", "failing", "broken", + # Concepts + "function", "method", "class", "object", "variable", "loop", + "array", "list", "map", "dictionary", "set", "tuple", "struct", + "algorithm", "data structure", "recursion", "iteration", + # APIs & Web + "api", "endpoint", "request", "response", "rest", "graphql", + "sdk", "library", "framework", "package", "module", "import", + "http", "json", "xml", "websocket", + # DevOps + "git", "github", "docker", "kubernetes", "ci/cd", "pipeline", + "aws", "azure", "gcp", "cloud", "server", "database", + # Testing + "unit test", "test case", "mock", "testing", "pytest", "jest", + # Requests + "how do i code", "write a function", "explain this code", + "why is this failing", "convert this code", "refactor this", + "what does this do", "how does this work" +] + +PROFILE_AGENT_KEYWORDS = [ + # Identity + "my name is", "i am called", "call me", "who am i", "i'm", + "i am a", "i work as", "my job is", "my profession", + # Preferences + "i like", "i love", "i prefer", "my favorite", "what do i prefer", + "i enjoy", "i hate", "i dislike", "can't stand", "i'm into", + # Habits + "i usually", "i often", "i always", "i never", "i sometimes", + "every day", "daily routine", "my habit", + # Location & Background + "i live in", "i am from", "my city is", "my country is", + "i grew up in", "my hometown", "based in", + # Contact & Demographics + "my phone number", "my email", "my age is", "i am X years old", + "my address", "my gender", "my nationality", + # Relationships + "my wife", "my husband", "my partner", "my kids", "my children", + "my mom", "my dad", "my parents", "my family", "my friend", + "my brother", "my sister", "my son", "my daughter", + # Traits & Values + "i believe", "i value", "important to me", "i care about", + "my personality", "i'm the type of person", + # Memory commands + "remember this", "store this", "save this", "add to my profile", + "keep this in mind", "note this down", "store this fact", + "update my", "change my", "that is wrong", "actually i am", + # Interests & Hobbies (without time) + "my hobby", "i'm interested in", "passionate about", + "i volunteer", "i support", "i'm learning" +] + +EVENT_AGENT_KEYWORDS = [ + # Future scheduling + "schedule", "book", "plan", "set up", "arrange", "create event", + "remind me", "set a reminder", "alarm", "notify me", "ping me", + "meeting", "call", "appointment", "interview", "session", + "deadline", "standup", "demo", "reservation", + # Relative time (future) + "today", "tomorrow", "day after tomorrow", + "next week", "next month", "next year", "this weekend", + "upcoming", "soon", "later", "in a few days", + # Days of week + "monday", "tuesday", "wednesday", "thursday", "friday", + "saturday", "sunday", + # Time of day + "in the morning", "in the evening", "at night", "afternoon", + "am", "pm", "o'clock", + # Calendar actions + "calendar", "add to calendar", "reschedule", "cancel meeting", + # PAST temporal expressions (CRITICAL for memory) + "yesterday", "last week", "last month", "last year", + "last saturday", "last sunday", "last monday", "last tuesday", + "last wednesday", "last thursday", "last friday", + "years ago", "months ago", "weeks ago", "days ago", + "a year ago", "a month ago", "a week ago", + "back in", "when i was", "used to", "in the past", + "recently", "just", "the other day", + # Specific years/dates + "in 2020", "in 2021", "in 2022", "in 2023", "in 2024", "in 2025", + "january", "february", "march", "april", "may", "june", + "july", "august", "september", "october", "november", "december", + # Recurring/annual events + "birthday", "anniversary", "graduation", "wedding", "funeral", + "holiday", "christmas", "thanksgiving", "new year", "easter", + "valentine", "halloween", "independence day", + # Age/timeline references + "born on", "born in", "turned", "18th birthday", "21st birthday", + "when i turned", "at age", "years old when", + # Life events with dates (MILESTONES) + "moved", "started", "graduated", "married", "retired", "began", + "finished", "completed", "launched", "opened", "closed", + "joined", "left", "quit", "hired", "fired", "promoted", + "divorced", "engaged", "pregnant", "gave birth", "adopted", + "surgery", "diagnosed", "recovered", "hospitalized", + "traveled", "visited", "went to", "came back from", + "first time", "last time", "this was when" +] + + +def get_keywords_string(keywords: list) -> str: + return ", ".join(keywords) diff --git a/src/prompts/examples/classification.py b/src/prompts/examples/classification.py new file mode 100644 index 0000000..4a82213 --- /dev/null +++ b/src/prompts/examples/classification.py @@ -0,0 +1,252 @@ +from src.schemas.classification import Classification +from typing import List, Tuple + +CLASSIFICATION_EXAMPLES: List[Tuple[str, List[Classification]]] = [ + ( + "Thank you so much!", + [] + ), + ( + "Hi, how are you?", + [] + ), + ( + "Great, thanks!", + [] + ), + ( + "Debug this error: TypeError: 'int' object is not iterable", + [{"source": "code", "query": "Debug this error: TypeError: 'int' object is not iterable"}] + ), + ( + "Explain how the asyncio event loop works in Python", + [{"source": "code", "query": "Explain how the asyncio event loop works in Python"}] + ), + ( + "Help me write a function to reverse a linked list", + [{"source": "code", "query": "Help me write a function to reverse a linked list"}] + ), + ( + "I prefer dark mode in all my applications", + [{"source": "profile", "query": "I prefer dark mode in all my applications"}] + ), + ( + "My name is Alice and I work at Google", + [{"source": "profile", "query": "My name is Alice and I work at Google"}] + ), + ( + "I'm a vegetarian and love Italian food", + [{"source": "profile", "query": "I'm a vegetarian and love Italian food"}] + ), + ( + "My birthday is on March 15th", + [{"source": "event", "query": "My birthday is on March 15th"}] + ), + ( + "Our wedding anniversary is July 22nd, 2019", + [{"source": "event", "query": "Our wedding anniversary is July 22nd, 2019"}] + ), + ( + "I have a dentist appointment on January 10th at 2:30 PM", + [{"source": "event", "query": "I have a dentist appointment on January 10th at 2:30 PM"}] + ), + ( + "My daughter's birthday is December 25th, she was born in 2015", + [{"source": "event", "query": "My daughter's birthday is December 25th, she was born in 2015"}] + ), + ( + "My name is Alice and I want to write a python script to hello world", + [ + {"source": "profile", "query": "My name is Alice"}, + {"source": "code", "query": "I want to write a python script to hello world"} + ] + ), + ( + "I'm learning Rust. How do I print variables in Rust?", + [ + {"source": "profile", "query": "I'm learning Rust"}, + {"source": "code", "query": "how do I print variables in Rust?"} + ] + ), + ( + "My name is John and my birthday is April 5th", + [ + {"source": "profile", "query": "My name is John"}, + {"source": "event", "query": "my birthday is April 5th"} + ] + ), + ( + "I graduated on May 20th 2020 and now I work as a software engineer", + [ + {"source": "event", "query": "I graduated on May 20th 2020"}, + {"source": "profile", "query": "I work as a software engineer"} + ] + ), + ( + "I prefer writing code in TypeScript over JavaScript", + [{"source": "profile", "query": "I prefer writing code in TypeScript over JavaScript"}] + ), + ( + "Mom's birthday is February 14th", + [{"source": "event", "query": "Mom's birthday is February 14th"}] + ), + ( + "I ran a charity race last Saturday", + [{"source": "event", "query": "I ran a charity race last Saturday"}] + ), + ( + "I moved from Sweden 4 years ago", + [{"source": "event", "query": "I moved from Sweden 4 years ago"}] + ), + ( + "I started transitioning about 3 years ago", + [{"source": "event", "query": "I started transitioning about 3 years ago"}] + ), + ( + "I graduated from college in May 2018", + [{"source": "event", "query": "I graduated from college in May 2018"}] + ), + ( + "My 18th birthday was ten years ago when my friend gave me a bowl", + [{"source": "event", "query": "My 18th birthday was ten years ago when my friend gave me a bowl"}] + ), + ( + "I went through a tough breakup last month and now I'm focusing on myself", + [ + {"source": "event", "query": "I went through a tough breakup last month"}, + {"source": "profile", "query": "I'm focusing on myself"} + ] + ), + ( + "How do I set up a Docker container for my Node.js app?", + [{"source": "code", "query": "How do I set up a Docker container for my Node.js app?"}] + ), + ( + "Debug this kubernetes pod crash", + [{"source": "code", "query": "Debug this kubernetes pod crash"}] + ), + ( + "Write unit tests for this function using pytest", + [{"source": "code", "query": "Write unit tests for this function using pytest"}] + ), + ( + "How do I mock API calls in Jest?", + [{"source": "code", "query": "How do I mock API calls in Jest?"}] + ), + ( + "Explain how to make a REST API call with authentication", + [{"source": "code", "query": "Explain how to make a REST API call with authentication"}] + ), + ( + "Convert this JSON response to a Python dictionary", + [{"source": "code", "query": "Convert this JSON response to a Python dictionary"}] + ), + ( + "I usually wake up at 6 AM every day", + [{"source": "profile", "query": "I usually wake up at 6 AM every day"}] + ), + ( + "I never drink coffee after 3 PM", + [{"source": "profile", "query": "I never drink coffee after 3 PM"}] + ), + ( + "I believe in work-life balance", + [{"source": "profile", "query": "I believe in work-life balance"}] + ), + ( + "Privacy is very important to me", + [{"source": "profile", "query": "Privacy is very important to me"}] + ), + ( + "My daughter Sarah is 8 years old", + [{"source": "profile", "query": "My daughter Sarah is 8 years old"}] + ), + ( + "My best friend lives in Seattle", + [{"source": "profile", "query": "My best friend lives in Seattle"}] + ), + ( + "I'm from Tokyo but now I live in San Francisco", + [{"source": "profile", "query": "I'm from Tokyo but now I live in San Francisco"}] + ), + ( + "My email is john@example.com", + [{"source": "profile", "query": "My email is john@example.com"}] + ), + ( + "Schedule a meeting with the team next Tuesday at 10 AM", + [{"source": "event", "query": "Schedule a meeting with the team next Tuesday at 10 AM"}] + ), + ( + "Remind me to call mom tomorrow evening", + [{"source": "event", "query": "Remind me to call mom tomorrow evening"}] + ), + ( + "I have a dentist appointment this Friday at 2:30 PM", + [{"source": "event", "query": "I have a dentist appointment this Friday at 2:30 PM"}] + ), + ( + "I visited Paris in August 2022", + [{"source": "event", "query": "I visited Paris in August 2022"}] + ), + ( + "We got married 5 years ago", + [{"source": "event", "query": "We got married 5 years ago"}] + ), + ( + "I finished my master's degree back in 2019", + [{"source": "event", "query": "I finished my master's degree back in 2019"}] + ), + ( + "Started learning guitar 6 months ago", + [{"source": "event", "query": "Started learning guitar 6 months ago"}] + ), + ( + "I got my first car when I turned 18", + [{"source": "event", "query": "I got my first car when I turned 18"}] + ), + ( + "I was diagnosed with diabetes at age 25", + [{"source": "event", "query": "I was diagnosed with diabetes at age 25"}] + ), + ( + "I joined Google in January 2020", + [{"source": "event", "query": "I joined Google in January 2020"}] + ), + ( + "We adopted our dog Rex last summer", + [{"source": "event", "query": "We adopted our dog Rex last summer"}] + ), + ( + "Launched my startup in March 2023", + [{"source": "event", "query": "Launched my startup in March 2023"}] + ), + ( + "I'm a DevOps engineer and I usually work with Kubernetes. Can you help me debug this pod error?", + [ + {"source": "profile", "query": "I'm a DevOps engineer and I usually work with Kubernetes"}, + {"source": "code", "query": "Can you help me debug this pod error?"} + ] + ), + ( + "I got engaged last Christmas and my fiancé loves hiking", + [ + {"source": "event", "query": "I got engaged last Christmas"}, + {"source": "profile", "query": "my fiancé loves hiking"} + ] + ), + ( + "I prefer using VS Code for development. How do I set up Python debugging in it?", + [ + {"source": "profile", "query": "I prefer using VS Code for development"}, + {"source": "code", "query": "How do I set up Python debugging in it?"} + ] + ), + ( + "My son was born on June 15th 2020 and he loves dinosaurs", + [ + {"source": "event", "query": "My son was born on June 15th 2020"}, + {"source": "profile", "query": "my son loves dinosaurs"} + ] + ), +] diff --git a/src/schemas/classification.py b/src/schemas/classification.py index e69de29..c719bc4 100644 --- a/src/schemas/classification.py +++ b/src/schemas/classification.py @@ -0,0 +1,16 @@ +from typing import List, Literal +from typing_extensions import TypedDict +from pydantic import BaseModel, Field + + +class Classification(TypedDict): + """A classification of a part of the user query.""" + source: Literal["code", "profile", "event"] + query: str + + +class ClassificationResult(BaseModel): + """Result of classifying a user query into agent-specific sub-questions.""" + classifications: List[Classification] = Field( + description="List of agents to invoke with their targeted sub-questions" + ) \ No newline at end of file diff --git a/src/utils/text.py b/src/utils/text.py index e69de29..8d19587 100644 --- a/src/utils/text.py +++ b/src/utils/text.py @@ -0,0 +1,67 @@ +""" +Text parsing and formatting utilities for structured LLM responses. + +All pack_* functions convert structured data → prompt strings. +All parse_* functions convert raw LLM output → structured data. +""" + +from __future__ import annotations + +from typing import List + +from src.config.constants import LLM_TAB_SEPARATOR +from src.schemas.classification import Classification + + +def attribute_unify(value: str) -> str: + return value.lower().replace(" ", "_") + + +# --------------------------------------------------------------------------- +# Classification helpers +# --------------------------------------------------------------------------- + +_VALID_SOURCES = frozenset({"code", "profile", "event"}) + + +def pack_classifications_into_string( + classifications: List[Classification], +) -> str: + """Serialise a list of classifications into the tab-separated prompt format. + + Example output:: + profile::My name is Alice + code::Write me a hello-world script + """ + lines = [ + f"{c['source']}{LLM_TAB_SEPARATOR}{c['query']}" + for c in classifications + ] + return "\n".join(lines) + + +def parse_raw_response_to_classifications(content: str) -> List[Classification]: + """Parse the raw LLM response into a list of Classification dicts. + + Expected line format:: + SOURCE::QUERY + """ + classifications: List[Classification] = [] + + for line in content.strip().splitlines(): + line = line.strip() + if LLM_TAB_SEPARATOR not in line: + continue + + parts = line.split(LLM_TAB_SEPARATOR, maxsplit=1) + + if len(parts) < 2: + continue + + source = parts[0].strip().lower() + query = parts[1].strip() + + if source in _VALID_SOURCES and query: + classifications.append({"source": source, "query": query}) + + return classifications diff --git a/tests/unit/agents/test_classifier.py b/tests/unit/agents/test_classifier.py index e69de29..d07a7c7 100644 --- a/tests/unit/agents/test_classifier.py +++ b/tests/unit/agents/test_classifier.py @@ -0,0 +1,51 @@ +""" +Interactive classifier test — send queries and see live classifications. + +Usage: + PYTHONPATH=. python tests/unit/agents/test_classifier.py + PYTHONPATH=. python tests/unit/agents/test_classifier.py --provider gemini +""" + +import asyncio +import sys + +from src.models import get_model +from src.agents.classifier import ClassifierAgent + + +async def main(): + provider = None + if "--provider" in sys.argv: + idx = sys.argv.index("--provider") + provider = sys.argv[idx + 1] + + model = get_model(provider=provider) + agent = ClassifierAgent(model=model) + + print(f"\n Classifier Agent ready (model: {model.__class__.__name__})") + print(f" Type a query and press Enter. Type 'q' to quit.\n") + + while True: + try: + query = input(">> ").strip() + except (EOFError, KeyboardInterrupt): + print() + break + + if query.lower() in ("q", "quit", "exit"): + break + if not query: + continue + + result = await agent.arun({"user_query": query}) + + if not result.classifications: + print(" (no classifications — trivial/skip)\n") + else: + for i, c in enumerate(result.classifications, 1): + print(f" {i}. [{c['source']}] {c['query']}") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py new file mode 100644 index 0000000..44a8629 --- /dev/null +++ b/tests/unit/utils/test_text.py @@ -0,0 +1,58 @@ +import pytest +from src.utils.text import ( + attribute_unify, + pack_classifications_into_string, + parse_raw_response_to_classifications, +) + +def test_attribute_unify(): + assert attribute_unify("Favorite Food") == "favorite_food" + assert attribute_unify("dark_mode") == "dark_mode" + assert attribute_unify("") == "" + + +def test_pack_multiple_classifications(): + result = pack_classifications_into_string([ + {"source": "profile", "query": "My name is John"}, + {"source": "event", "query": "my birthday is April 5th"}, + ]) + lines = result.split("\n") + assert len(lines) == 2 + assert lines[0] == "profile::My name is John" + assert lines[1] == "event::my birthday is April 5th" + + +def test_pack_empty_list(): + assert pack_classifications_into_string([]) == "" + + + +class TestParseClassifications: + def test_parse_multiple_lines(self): + raw = "event::I graduated in 2020\nprofile::I work as a developer" + result = parse_raw_response_to_classifications(raw) + assert len(result) == 2 + assert result[0]["source"] == "event" + assert result[1]["source"] == "profile" + + def test_ignores_preamble_and_invalid_lines(self): + # Combined edge case test: preamble + invalid source + valid source + raw = ( + "Analysis:\n" + "invalid::junk\n" + "profile::Valid query" + ) + result = parse_raw_response_to_classifications(raw) + assert len(result) == 1 + assert result[0] == {"source": "profile", "query": "Valid query"} + + def test_query_with_separator_in_text(self): + raw = "code::Fix error: TypeError:: 'int' not iterable" + result = parse_raw_response_to_classifications(raw) + assert len(result) == 1 + assert result[0]["query"] == "Fix error: TypeError:: 'int' not iterable" + + def test_empty_and_trivial_input(self): + assert parse_raw_response_to_classifications("") == [] + assert parse_raw_response_to_classifications("(empty)") == [] +