diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/__init__.py b/evaluation/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/temporal_locomo/models/__init__.py b/evaluation/scripts/temporal_locomo/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/temporal_locomo/locomo_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_eval.py similarity index 78% rename from evaluation/scripts/temporal_locomo/locomo_eval.py rename to evaluation/scripts/temporal_locomo/models/locomo_eval.py index f19e5b68f..f98a481e2 100644 --- a/evaluation/scripts/temporal_locomo/locomo_eval.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_eval.py @@ -9,7 +9,6 @@ from bert_score import score as bert_score from dotenv import load_dotenv -from modules.locomo_eval_module import LocomoEvalModelModules from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu from nltk.translate.meteor_score import meteor_score from openai import AsyncOpenAI @@ -19,6 +18,7 @@ from sentence_transformers import SentenceTransformer from tqdm import tqdm +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules from memos.log import get_logger @@ -281,33 +281,64 @@ def __init__(self, args): api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL") ) - async def run(self): - print( - f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ===" - ) - print(f"Using {self.max_workers} concurrent workers for processing groups") + def _load_response_data(self): + """ + Load response data from the response path file. + Returns: + dict: The loaded response data + """ with open(self.response_path) as file: - locomo_responses = json.load(file) + return json.load(file) - num_users = 10 + def _load_existing_evaluation_results(self): + """ + Attempt to load existing evaluation results from the judged path. + If the file doesn't exist or there's an error loading it, return an empty dict. + + Returns: + dict: Existing evaluation results or empty dict if none available + """ all_grades = {} + try: + if os.path.exists(self.judged_path): + with open(self.judged_path) as f: + all_grades = json.load(f) + print(f"Loaded existing evaluation results from {self.judged_path}") + except Exception as e: + print(f"Error loading existing evaluation results: {e}") - total_responses_count = sum( - len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) - ) - print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") + return all_grades + + def _create_evaluation_tasks(self, locomo_responses, all_grades, num_users): + """ + Create evaluation tasks for groups that haven't been evaluated yet. 
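+        Groups that already have entries in `all_grades` are skipped, so a partially
+        completed evaluation can be resumed without re-judging finished groups.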
+ + Args: + locomo_responses (dict): The loaded response data + all_grades (dict): Existing evaluation results + num_users (int): Number of user groups to process - # Create tasks for processing each group + Returns: + tuple: (tasks list, active users count) + """ tasks = [] active_users = 0 + for group_idx in range(num_users): group_id = f"locomo_exp_user_{group_idx}" group_responses = locomo_responses.get(group_id, []) + if not group_responses: print(f"No responses found for group {group_id}") continue + # Skip groups that already have evaluation results + if all_grades.get(group_id): + print(f"Skipping group {group_id} as it already has evaluation results") + active_users += 1 + continue + active_users += 1 tasks.append( process_single_group( @@ -319,29 +350,50 @@ async def run(self): ) ) - print(f"Starting evaluation of {active_users} user groups with responses") + return tasks, active_users + + async def _process_tasks(self, tasks): + """ + Process evaluation tasks with concurrency control. + + Args: + tasks (list): List of tasks to process + + Returns: + list: Results from processing all tasks + """ + if not tasks: + return [] semaphore = asyncio.Semaphore(self.max_workers) async def limited_task(task): + """Helper function to limit concurrent task execution""" async with semaphore: return await task limited_tasks = [limited_task(task) for task in tasks] - group_results = await asyncio.gather(*limited_tasks) + return await asyncio.gather(*limited_tasks) - for group_id, graded_responses in group_results: - all_grades[group_id] = graded_responses + def _calculate_scores(self, all_grades): + """ + Calculate evaluation scores based on all grades. - print("\n=== Evaluation Complete: Calculating final scores ===") + Args: + all_grades (dict): The complete evaluation results + Returns: + tuple: (run_scores, evaluated_count) + """ run_scores = [] evaluated_count = 0 + if self.num_runs > 0: for i in range(1, self.num_runs + 1): judgment_key = f"judgment_{i}" current_run_correct_count = 0 current_run_total_count = 0 + for group in all_grades.values(): for response in group: if judgment_key in response["llm_judgments"]: @@ -355,6 +407,16 @@ async def limited_task(task): evaluated_count = current_run_total_count + return run_scores, evaluated_count + + def _report_scores(self, run_scores, evaluated_count): + """ + Report evaluation scores to the console. + + Args: + run_scores (list): List of accuracy scores for each run + evaluated_count (int): Number of evaluated responses + """ if evaluated_count > 0: mean_of_scores = np.mean(run_scores) std_of_scores = np.std(run_scores) @@ -368,11 +430,63 @@ async def limited_task(task): print("No responses were evaluated") print("LLM-as-a-Judge score: N/A (0/0)") + def _save_results(self, all_grades): + """ + Save evaluation results to the judged path file. + + Args: + all_grades (dict): The complete evaluation results to save + """ all_grades = convert_numpy_types(all_grades) with open(self.judged_path, "w") as f: json.dump(all_grades, f, indent=2) print(f"Saved detailed evaluation results to {self.judged_path}") + async def run(self): + """ + Main execution method for the LoCoMo evaluation process. + This method orchestrates the entire evaluation workflow: + 1. Loads existing evaluation results if available + 2. Processes only groups that haven't been evaluated yet + 3. 
Calculates and reports final evaluation scores + """ + print( + f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ===" + ) + print(f"Using {self.max_workers} concurrent workers for processing groups") + + # Load response data and existing evaluation results + locomo_responses = self._load_response_data() + all_grades = self._load_existing_evaluation_results() + + # Count total responses for reporting + num_users = 10 + total_responses_count = sum( + len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) + ) + print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") + + # Create tasks only for groups that haven't been evaluated yet + tasks, active_users = self._create_evaluation_tasks(locomo_responses, all_grades, num_users) + print( + f"Starting evaluation of {len(tasks)} user groups with responses (out of {active_users} active users)" + ) + + # Process tasks and update all_grades with results + if tasks: + group_results = await self._process_tasks(tasks) + for group_id, graded_responses in group_results: + all_grades[group_id] = graded_responses + + print("\n=== Evaluation Complete: Calculating final scores ===") + + # Calculate and report scores + run_scores, evaluated_count = self._calculate_scores(all_grades) + self._report_scores(run_scores, evaluated_count) + + # Save results + self._save_results(all_grades) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/evaluation/scripts/temporal_locomo/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py similarity index 98% rename from evaluation/scripts/temporal_locomo/locomo_ingestion.py rename to evaluation/scripts/temporal_locomo/models/locomo_ingestion.py index 321302cf2..b45ec3d61 100644 --- a/evaluation/scripts/temporal_locomo/locomo_ingestion.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py @@ -6,16 +6,16 @@ from datetime import datetime, timezone from pathlib import Path -from modules.constants import ( +from tqdm import tqdm + +from evaluation.scripts.temporal_locomo.modules.constants import ( MEM0_GRAPH_MODEL, MEM0_MODEL, MEMOS_MODEL, MEMOS_SCHEDULER_MODEL, ZEP_MODEL, ) -from modules.locomo_eval_module import LocomoEvalModelModules -from tqdm import tqdm - +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules from memos.log import get_logger diff --git a/evaluation/scripts/temporal_locomo/locomo_metric.py b/evaluation/scripts/temporal_locomo/models/locomo_metric.py similarity index 99% rename from evaluation/scripts/temporal_locomo/locomo_metric.py rename to evaluation/scripts/temporal_locomo/models/locomo_metric.py index 0187c37e7..532fe2e14 100644 --- a/evaluation/scripts/temporal_locomo/locomo_metric.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_metric.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from modules.locomo_eval_module import LocomoEvalModelModules +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules # Category mapping as per your request diff --git a/evaluation/scripts/temporal_locomo/locomo_processor.py b/evaluation/scripts/temporal_locomo/models/locomo_processor.py similarity index 63% rename from evaluation/scripts/temporal_locomo/locomo_processor.py rename to evaluation/scripts/temporal_locomo/models/locomo_processor.py index 4ae9cf915..7cec6f5af 100644 --- 
a/evaluation/scripts/temporal_locomo/locomo_processor.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_processor.py @@ -7,20 +7,19 @@ from time import time from dotenv import load_dotenv -from modules.constants import ( - MEMOS_MODEL, + +from evaluation.scripts.temporal_locomo.modules.constants import ( MEMOS_SCHEDULER_MODEL, ) -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.prompts import ( +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules +from evaluation.scripts.temporal_locomo.modules.prompts import ( SEARCH_PROMPT_MEM0, SEARCH_PROMPT_MEM0_GRAPH, SEARCH_PROMPT_MEMOS, SEARCH_PROMPT_ZEP, ) -from modules.schemas import ContextUpdateMethod, RecordingCase -from modules.utils import save_evaluation_cases - +from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase +from evaluation.scripts.temporal_locomo.modules.utils import save_evaluation_cases from memos.log import get_logger @@ -54,77 +53,22 @@ def __init__(self, args): self.processed_data_dir = self.result_dir / "processed_data" def update_context(self, conv_id, method, **kwargs): - if method == ContextUpdateMethod.DIRECT: + if method == ContextUpdateMethod.CHAT_HISTORY: + if "query" not in kwargs or "answer" not in kwargs: + raise ValueError("query and answer are required for TEMPLATE update method") + new_context = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n" + if self.pre_context_cache[conv_id] is None: + self.pre_context_cache[conv_id] = "" + self.pre_context_cache[conv_id] += new_context + else: if "cur_context" not in kwargs: raise ValueError("cur_context is required for DIRECT update method") cur_context = kwargs["cur_context"] self.pre_context_cache[conv_id] = cur_context - elif method == ContextUpdateMethod.TEMPLATE: - if "query" not in kwargs or "answer" not in kwargs: - raise ValueError("query and answer are required for TEMPLATE update method") - self._update_context_template(conv_id, kwargs["query"], kwargs["answer"]) - else: - raise ValueError(f"Unsupported update method: {method}") - - def _update_context_template(self, conv_id, query, answer): - new_context = f"User: {query}\nAssistant: {answer}\n\n" - if self.pre_context_cache[conv_id] is None: - self.pre_context_cache[conv_id] = "" - self.pre_context_cache[conv_id] += new_context - - def _process_single_qa( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - # Search - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - - # Context answerability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=gold_answer, - ) - return None - - can_answer = False - can_answer_duration_ms = 0.0 + def eval_context(self, context, query, gold_answer, oai_client): 
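+        """
+        Check whether `context` is sufficient to answer `query` via the LLM judge,
+        updating the shared answerability statistics under the stats lock.
+
+        Returns:
+            tuple: (can_answer flag, answerability-check duration in milliseconds)
+        """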
can_answer_start = time() - can_answer = self.analyze_context_answerability( - self.pre_context_cache[conv_id], query, gold_answer, oai_client - ) + can_answer = self.analyze_context_answerability(context, query, gold_answer, oai_client) can_answer_duration_ms = (time() - can_answer_start) * 1000 # Update global stats with self.stats_lock: @@ -143,54 +87,41 @@ def _process_single_qa( can_answer_duration_ms ) self.save_stats() + return can_answer, can_answer_duration_ms - # Generate answer - answer_start = time() - answer = self.locomo_response(frame, oai_client, self.pre_context_cache[conv_id], query) - response_duration_ms = (time() - answer_start) * 1000 - - # Record case for memos_scheduler - if frame == "memos_scheduler": - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - golden_answer=str(qa.get("answer", "")), - memories=[], - pre_memories=[], - history_queries=[], - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - + def _update_stats_and_context( + self, + *, + conv_id, + frame, + version, + conv_stats, + conv_stats_path, + query, + answer, + gold_answer, + cur_context, + can_answer, + ): + """ + Update conversation statistics and context. 
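+        Also refreshes the pre-context cache for `conv_id` according to the configured
+        context update method (chat history vs. current context).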
+ + Args: + conv_id: Conversation ID + frame: Model frame + version: Model version + conv_stats: Conversation statistics dictionary + conv_stats_path: Path to save conversation statistics + query: User query + answer: Generated answer + gold_answer: Golden answer + cur_context: Current context + can_answer: Whether the context can answer the query + """ # Update conversation stats conv_stats["total_queries"] += 1 conv_stats["response_count"] += 1 - if frame == "memos_scheduler": + if frame == MEMOS_SCHEDULER_MODEL: if can_answer: conv_stats["can_answer_count"] += 1 else: @@ -208,22 +139,137 @@ def _process_single_qa( # Update pre-context cache with current context with self.stats_lock: - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: self.update_context( conv_id=conv_id, method=self.context_update_method, - cur_context=cur_context, + query=query, + answer=answer, ) else: self.update_context( conv_id=conv_id, method=self.context_update_method, - query=query, - answer=gold_answer, + cur_context=cur_context, ) self.print_eval_info() + def _process_single_qa( + self, + qa, + *, + client, + reversed_client, + metadata, + frame, + version, + conv_id, + conv_stats_path, + oai_client, + top_k, + conv_stats, + ): + query = qa.get("question") + gold_answer = qa.get("answer") + qa_category = qa.get("category") + if qa_category == 5: + return None + + # Search + cur_context, search_duration_ms = self.search_query( + client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k + ) + if not cur_context: + logger.warning(f"No context found for query: {query[:100]}") + cur_context = "" + + if self.context_update_method == ContextUpdateMethod.CURRENT_CONTEXT: + context = cur_context + else: + # Context answer ability analysis (for memos_scheduler only) + if self.pre_context_cache[conv_id] is None: + # Update pre-context cache with current context and return + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: + answer_from_cur_context = self.locomo_response( + frame, oai_client, cur_context, query + ) + self.update_context( + conv_id=conv_id, + method=self.context_update_method, + query=query, + answer=answer_from_cur_context, + ) + else: + self.update_context( + conv_id=conv_id, + method=self.context_update_method, + cur_context=cur_context, + ) + return None + else: + context = self.pre_context_cache[conv_id] + + # Generate answer + answer_start = time() + answer = self.locomo_response(frame, oai_client, context, query) + response_duration_ms = (time() - answer_start) * 1000 + + can_answer, can_answer_duration_ms = self.eval_context( + context=context, query=query, gold_answer=gold_answer, oai_client=oai_client + ) + + # Record case for memos_scheduler + try: + recording_case = RecordingCase( + conv_id=conv_id, + query=query, + answer=answer, + context=cur_context, + pre_context=self.pre_context_cache[conv_id], + can_answer=can_answer, + can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", + search_duration_ms=search_duration_ms, + can_answer_duration_ms=can_answer_duration_ms, + response_duration_ms=response_duration_ms, + category=int(qa_category) if qa_category is not None else None, + golden_answer=str(qa.get("answer", "")), + ) + if can_answer: + self.can_answer_cases.append(recording_case) + else: + self.cannot_answer_cases.append(recording_case) + except Exception as e: + logger.error(f"Error creating RecordingCase: {e}") + print(f"Error creating 
RecordingCase: {e}") + logger.error(f"QA data: {qa}") + print(f"QA data: {qa}") + logger.error(f"Query: {query}") + logger.error(f"Answer: {answer}") + logger.error( + f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" + ) + logger.error(f"Category: {qa_category} (type: {type(qa_category)})") + logger.error(f"Can answer: {can_answer}") + raise e + + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: + answer_from_cur_context = self.locomo_response(frame, oai_client, cur_context, query) + answer = answer_from_cur_context + # Update conversation stats and context + self._update_stats_and_context( + conv_id=conv_id, + frame=frame, + version=version, + conv_stats=conv_stats, + conv_stats_path=conv_stats_path, + query=query, + answer=answer, + gold_answer=gold_answer, + cur_context=cur_context, + can_answer=can_answer, + ) + return { "question": query, "answer": answer, @@ -233,7 +279,7 @@ def _process_single_qa( "response_duration_ms": response_duration_ms, "search_duration_ms": search_duration_ms, "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == "memos_scheduler" else None, + "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, } def run_locomo_processing(self, num_users=10): diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py index 4ec7d4922..f8db11fbc 100644 --- a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py +++ b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py @@ -16,7 +16,6 @@ from .constants import ( BASE_DIR, - MEMOS_MODEL, MEMOS_SCHEDULER_MODEL, ) from .prompts import ( @@ -42,10 +41,9 @@ def __init__(self, args): self.top_k = self.args.top_k # attributes - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - self.context_update_method = ContextUpdateMethod.DIRECT - else: - self.context_update_method = ContextUpdateMethod.TEMPLATE + self.context_update_method = getattr( + self.args, "context_update_method", ContextUpdateMethod.PRE_CONTEXT + ) self.custom_instructions = CUSTOM_INSTRUCTIONS self.data_dir = Path(f"{BASE_DIR}/data") self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json") @@ -64,7 +62,7 @@ def __init__(self, args): # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation if ( hasattr(self.args, "scheduler_flag") - and self.frame == "memos_scheduler" + and self.frame == MEMOS_SCHEDULER_MODEL and self.args.scheduler_flag is False ): self.result_dir = Path( @@ -74,6 +72,11 @@ def __init__(self, args): self.result_dir = Path( f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}/" ) + + if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT: + self.result_dir = ( + self.result_dir.parent / f"{self.result_dir.name}_{self.context_update_method}" + ) self.result_dir.mkdir(parents=True, exist_ok=True) self.search_path = self.result_dir / f"{self.frame}-{self.version}_search_results.json" @@ -135,10 +138,6 @@ def __init__(self, args): # Statistics tracking with thread safety self.stats = {self.frame: {self.version: defaultdict(dict)}} - self.stats[self.frame][self.version]["response_stats"] = defaultdict(dict) - self.stats[self.frame][self.version]["response_stats"]["response_failure"] = 0 - self.stats[self.frame][self.version]["response_stats"]["response_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict) 
self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0 self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0 diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py index c824fe5f4..b05243a11 100644 --- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py +++ b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py @@ -194,6 +194,10 @@ def memos_scheduler_search( start = time.time() client: MOS = client + if not self.scheduler_flag: + # if not scheduler_flag, search to update working memory + self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client) + # Search for speaker A search_a_results = client.mem_scheduler.search_for_eval( query=query, @@ -527,6 +531,25 @@ def process_qa(qa): json.dump(dict(search_results), fw, indent=2) print(f"Save search results {conv_id}") + search_durations = [] + for result in response_results[conv_id]: + if "search_duration_ms" in result: + search_durations.append(result["search_duration_ms"]) + + if search_durations: + avg_search_duration = sum(search_durations) / len(search_durations) + with self.stats_lock: + if self.stats[self.frame][self.version]["memory_stats"]["avg_search_duration_ms"]: + self.stats[self.frame][self.version]["memory_stats"][ + "avg_search_duration_ms" + ] = ( + self.stats[self.frame][self.version]["memory_stats"][ + "avg_search_duration_ms" + ] + + avg_search_duration + ) / 2 + print(f"Average search duration: {avg_search_duration:.2f} ms") + # Dump stats after processing each user self.save_stats() diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py index e5872c35d..a41b7539d 100644 --- a/evaluation/scripts/temporal_locomo/modules/schemas.py +++ b/evaluation/scripts/temporal_locomo/modules/schemas.py @@ -1,14 +1,23 @@ -from enum import Enum from typing import Any from pydantic import BaseModel, Field -class ContextUpdateMethod(Enum): +class ContextUpdateMethod: """Enumeration for context update methods""" - DIRECT = "direct" # Directly update with current context - TEMPLATE = "chat_history" # Update using template with history queries and answers + PRE_CONTEXT = "pre_context" + CHAT_HISTORY = "chat_history" + CURRENT_CONTEXT = "current_context" + + @classmethod + def values(cls): + """Return a list of all constant values""" + return [ + getattr(cls, attr) + for attr in dir(cls) + if not attr.startswith("_") and isinstance(getattr(cls, attr), str) + ] class RecordingCase(BaseModel): @@ -22,11 +31,6 @@ class RecordingCase(BaseModel): # Conversation identification conv_id: str = Field(description="Conversation identifier for this evaluation case") - # Conversation history and context - history_queries: list[str] = Field( - default_factory=list, description="List of previous queries in the conversation history" - ) - context: str = Field( default="", description="Current search context retrieved from memory systems for answering the query", @@ -42,16 +46,6 @@ class RecordingCase(BaseModel): answer: str = Field(description="The generated answer for the query") - # Memory data - memories: list[Any] = Field( - default_factory=list, - description="Current memories retrieved from the memory system for this query", - ) - - pre_memories: list[Any] | None = Field( - default=None, description="Previous memories from the last query, used for comparison" - ) - # Evaluation metrics can_answer: bool | None 
= Field( default=None, diff --git a/evaluation/scripts/temporal_locomo/modules/thread_race.py b/evaluation/scripts/temporal_locomo/modules/thread_race.py new file mode 100644 index 000000000..66aab4652 --- /dev/null +++ b/evaluation/scripts/temporal_locomo/modules/thread_race.py @@ -0,0 +1,134 @@ +import random +import threading +import time + + +class ThreadRace: + def __init__(self): + # Variable to store the result + self.result = None + # Event to mark if the race is finished + self.race_finished = threading.Event() + # Lock to protect the result variable + self.lock = threading.Lock() + # Store thread objects for termination + self.threads = {} + # Stop flags for each thread + self.stop_flags = {} + + def task1(self, stop_flag): + """First task function, can be modified as needed""" + # Simulate random work time + sleep_time = random.uniform(0.1, 2.0) + + # Break the sleep into smaller chunks to check stop flag + chunks = 20 + chunk_time = sleep_time / chunks + + for _ in range(chunks): + # Check if we should stop + if stop_flag.is_set(): + return None + time.sleep(chunk_time) + + return f"Task 1 completed in: {sleep_time:.2f} seconds" + + def task2(self, stop_flag): + """Second task function, can be modified as needed""" + # Simulate random work time + sleep_time = random.uniform(0.1, 2.0) + + # Break the sleep into smaller chunks to check stop flag + chunks = 20 + chunk_time = sleep_time / chunks + + for _ in range(chunks): + # Check if we should stop + if stop_flag.is_set(): + return None + time.sleep(chunk_time) + + return f"Task 2 completed in: {sleep_time:.2f} seconds" + + def worker(self, task_func, task_name): + """Worker thread function""" + # Create a stop flag for this task + stop_flag = threading.Event() + self.stop_flags[task_name] = stop_flag + + try: + # Execute the task with stop flag + result = task_func(stop_flag) + + # If the race is already finished or we were asked to stop, return immediately + if self.race_finished.is_set() or stop_flag.is_set(): + return None + + # Try to set the result (if no other thread has set it yet) + with self.lock: + if not self.race_finished.is_set(): + self.result = (task_name, result) + # Mark the race as finished + self.race_finished.set() + print(f"{task_name} won the race!") + + # Signal other threads to stop + for name, flag in self.stop_flags.items(): + if name != task_name: + print(f"Signaling {name} to stop") + flag.set() + + return self.result + + except Exception as e: + print(f"{task_name} encountered an error: {e}") + + return None + + def run_race(self): + """Start the competition and return the result of the fastest thread""" + # Reset state + self.race_finished.clear() + self.result = None + self.threads.clear() + self.stop_flags.clear() + + # Create threads + thread1 = threading.Thread(target=self.worker, args=(self.task1, "Thread 1")) + thread2 = threading.Thread(target=self.worker, args=(self.task2, "Thread 2")) + + # Record thread objects for later joining + self.threads["Thread 1"] = thread1 + self.threads["Thread 2"] = thread2 + + # Start threads + thread1.start() + thread2.start() + + # Wait for any thread to complete + while not self.race_finished.is_set(): + time.sleep(0.01) # Small delay to avoid high CPU usage + + # If all threads have ended but no result is set, there's a problem + if ( + not thread1.is_alive() + and not thread2.is_alive() + and not self.race_finished.is_set() + ): + print("All threads have ended, but there's no winner") + return None + + # Wait for all threads to end (with timeout to avoid 
infinite waiting) + thread1.join(timeout=1.0) + thread2.join(timeout=1.0) + + # Return the result + return self.result + + +# Usage example +if __name__ == "__main__": + race = ThreadRace() + result = race.run_race() + print(f"Winner: {result[0] if result else None}") + print(f"Result: {result[1] if result else None}") diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py index 0a2c20a0e..46385626c 100644 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py @@ -5,13 +5,14 @@ from pathlib import Path -from locomo_eval import LocomoEvaluator -from locomo_ingestion import LocomoIngestor -from locomo_metric import LocomoMetric -from locomo_processor import LocomoProcessor from modules.locomo_eval_module import LocomoEvalModelModules +from modules.schemas import ContextUpdateMethod from modules.utils import compute_can_answer_count_by_pre_evidences +from evaluation.scripts.temporal_locomo.models.locomo_eval import LocomoEvaluator +from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor +from evaluation.scripts.temporal_locomo.models.locomo_metric import LocomoMetric +from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor from memos.log import get_logger @@ -29,8 +30,10 @@ def __init__(self, args): self.locomo_ingestor = LocomoIngestor(args=args) self.locomo_processor = LocomoProcessor(args=args) + self.locomo_evaluator = LocomoEvaluator(args=args) + self.locomo_metric = LocomoMetric(args=args) - def run_eval_pipeline(self): + def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False): """ Run the complete evaluation pipeline including dataset conversion, data ingestion, and processing. @@ -50,46 +53,39 @@ def run_eval_pipeline(self): print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") # Step 2: Data ingestion - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - if not self.ingestion_storage_dir.exists() or not any(self.ingestion_storage_dir.iterdir()): - print(f"Directory {self.ingestion_storage_dir} not found, starting data ingestion...") + if not skip_ingestion: + print("\n" + "=" * 50) + print("Step 2: Data Ingestion") + print("=" * 50) self.locomo_ingestor.run_ingestion() - print("Data ingestion completed.") - else: - print( - f"Directory {self.ingestion_storage_dir} already exists and is not empty, skipping ingestion." 
- ) # Step 3: Processing and evaluation - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") + if not skip_processing: + print("\n" + "=" * 50) + print("Step 3: Processing and Evaluation") + print("=" * 50) + print("Running locomo processing to search and answer...") - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") + print("Starting locomo processing to generate search and response results...") + self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) + print("Processing completed successfully.") # Optional: run post-hoc evaluation over generated responses if available try: - evaluator = LocomoEvaluator(args=args) - - if os.path.exists(evaluator.response_path): + if os.path.exists(self.response_path): print("Running LocomoEvaluator over existing response results...") - asyncio.run(evaluator.run()) + asyncio.run(self.locomo_evaluator.run()) else: print( f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}" ) # Run metrics summarization if judged file is produced - metric = LocomoMetric(args=args) - if os.path.exists(metric.judged_path): + + if os.path.exists(self.judged_path): print("Running LocomoMetric over judged results...") - metric.run() + self.locomo_metric.run() else: - print(f"Skipping LocomoMetric: judged file not found at {metric.judged_path}") + print(f"Skipping LocomoMetric: judged file not found at {self.judged_path}") except Exception as e: logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True) @@ -103,6 +99,32 @@ def run_eval_pipeline(self): print(f" - Statistics: {self.stats_path}") print("=" * 80) + def run_inference_eval_pipeline(self, skip_ingestion=True, skip_processing=False): + """ + Run the complete evaluation pipeline including dataset conversion, + data ingestion, and processing. 
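+        Note: this variant currently performs only dataset conversion and (optional)
+        data ingestion; response generation and scoring are run separately.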
+ """ + print("=" * 80) + print("Starting TimeLocomo Evaluation Pipeline") + print("=" * 80) + + # Step 1: Check if temporal_locomo dataset exists, if not convert it + temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" + if not temporal_locomo_file.exists(): + print(f"Temporal locomo dataset not found at {temporal_locomo_file}") + print("Converting locomo dataset to temporal_locomo format...") + self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") + print("Dataset conversion completed.") + else: + print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") + + # Step 2: Data ingestion + if not skip_ingestion: + print("\n" + "=" * 50) + print("Step 2: Data Ingestion") + print("=" * 50) + self.locomo_ingestor.run_ingestion() + def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): """ Compute can-answer statistics per day for each conversation using the @@ -124,7 +146,7 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): parser.add_argument( "--frame", type=str, - default="memos_scheduler", + default="memos", choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"], help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", ) @@ -143,14 +165,17 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): parser.add_argument( "--scheduler-flag", action=argparse.BooleanOptionalAction, - default=True, + default=False, help="Enable or disable memory scheduler features", ) + parser.add_argument( + "--context_update_method", + type=str, + default="chat_history", + choices=ContextUpdateMethod.values(), + help="Method to update context: direct (use current context directly), chat_history (use template with history), current_context (use current context)", + ) args = parser.parse_args() evaluator = TemporalLocomoEval(args=args) - evaluator.run_eval_pipeline() - - # rule-based baselines - evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=float("inf")) - evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=1) + evaluator.run_answer_hit_eval_pipeline() diff --git a/examples/mem_api/pipeline_test.py b/examples/mem_api/pipeline_test.py new file mode 100644 index 000000000..cd7b3bee3 --- /dev/null +++ b/examples/mem_api/pipeline_test.py @@ -0,0 +1,178 @@ +""" +Pipeline test script for MemOS Server API functions. +This script directly tests add and search functionalities without going through the API layer. +If you want to start server_api set .env to MemOS/.env and run: +uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8002 --workers 4 +""" + +from typing import Any + +from dotenv import load_dotenv + +# Import directly from server_router to reuse initialized components +from memos.api.routers.server_router import ( + _create_naive_mem_cube, + mem_reader, +) +from memos.log import get_logger + + +# Load environment variables +load_dotenv() + +logger = get_logger(__name__) + + +def test_add_memories( + messages: list[dict[str, str]], + user_id: str, + mem_cube_id: str, + session_id: str = "default_session", +) -> list[str]: + """ + Test adding memories to the system. 
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + user_id: User identifier + mem_cube_id: Memory cube identifier + session_id: Session identifier + + Returns: + List of memory IDs that were added + """ + logger.info(f"Testing add memories for user: {user_id}, mem_cube: {mem_cube_id}") + + # Create NaiveMemCube using server_router function + naive_mem_cube = _create_naive_mem_cube() + + # Extract memories from messages using server_router's mem_reader + memories = mem_reader.get_memory( + [messages], + type="chat", + info={ + "user_id": user_id, + "session_id": session_id, + }, + ) + + # Flatten memory list + flattened_memories = [mm for m in memories for mm in m] + + # Add memories to the system + mem_id_list: list[str] = naive_mem_cube.text_mem.add( + flattened_memories, + user_name=mem_cube_id, + ) + + logger.info(f"Added {len(mem_id_list)} memories: {mem_id_list}") + + # Print details of added memories + for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False): + logger.info(f" - ID: {memory_id}") + logger.info(f" Memory: {memory.memory}") + logger.info(f" Type: {memory.metadata.memory_type}") + + return mem_id_list + + +def test_search_memories( + query: str, + user_id: str, + mem_cube_id: str, + session_id: str = "default_session", + top_k: int = 5, + mode: str = "fast", + internet_search: bool = False, + moscube: bool = False, + chat_history: list | None = None, +) -> list[Any]: + """ + Test searching memories from the system. + + Args: + query: Search query text + user_id: User identifier + mem_cube_id: Memory cube identifier + session_id: Session identifier + top_k: Number of top results to return + mode: Search mode + internet_search: Whether to enable internet search + moscube: Whether to enable moscube search + chat_history: Chat history for context + + Returns: + List of search results + """ + + # Create NaiveMemCube using server_router function + naive_mem_cube = _create_naive_mem_cube() + + # Prepare search filter + search_filter = {"session_id": session_id} if session_id != "default_session" else None + + search_results = naive_mem_cube.text_mem.search( + query=query, + user_name=mem_cube_id, + top_k=top_k, + mode=mode, + manual_close_internet=not internet_search, + moscube=moscube, + search_filter=search_filter, + info={ + "user_id": user_id, + "session_id": session_id, + "chat_history": chat_history or [], + }, + ) + + # Print search results + for idx, result in enumerate(search_results, 1): + logger.info(f"\n Result {idx}:") + logger.info(f" ID: {result.id}") + logger.info(f" Memory: {result.memory}") + logger.info(f" Score: {getattr(result, 'score', 'N/A')}") + logger.info(f" Type: {result.metadata.memory_type}") + + return search_results + + +def main(): + # Test parameters + user_id = "test_user_123" + mem_cube_id = "test_cube_123" + session_id = "test_session_001" + + test_messages = [ + {"role": "user", "content": "Where should I go for Christmas?"}, + { + "role": "assistant", + "content": "There are many places to visit during Christmas, such as the Bund and Disneyland in Shanghai.", + }, + {"role": "user", "content": "What about New Year's Eve?"}, + { + "role": "assistant", + "content": "For New Year's Eve, you could visit Times Square in New York or watch fireworks at the Sydney Opera House.", + }, + ] + + memory_ids = test_add_memories( + messages=test_messages, user_id=user_id, mem_cube_id=mem_cube_id, session_id=session_id + ) + + logger.info(f"\nSuccessfully added {len(memory_ids)} memories!") + + search_queries 
= [ + "How to enjoy Christmas?", + "Where to celebrate New Year?", + "What are good places to visit during holidays?", + ] + + for query in search_queries: + logger.info("\n" + "-" * 80) + results = test_search_memories(query=query, user_id=user_id, mem_cube_id=mem_cube_id) + print(f"Query: '{query}' returned {len(results)} results") + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index e6830016f..d34f964b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "absl-py" @@ -6310,4 +6310,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" +content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" \ No newline at end of file diff --git a/src/memos/api/client.py b/src/memos/api/client.py index 5e7947ff5..d45276f2c 100644 --- a/src/memos/api/client.py +++ b/src/memos/api/client.py @@ -14,7 +14,6 @@ MAX_RETRY_COUNT = 3 - class MemOSClient: """MemOS API client""" diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py index 681644a0d..709ad74fb 100644 --- a/src/memos/api/product_api.py +++ b/src/memos/api/product_api.py @@ -33,6 +33,6 @@ parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=8001) - parser.add_argument("--workers", type=int, default=32) + parser.add_argument("--workers", type=int, default=1) args = parser.parse_args() uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7e425415b..eb2d7aa6d 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field # Import message types from core types module -from memos.types import MessageDict +from memos.types import MessageDict, PermissionDict T = TypeVar("T") @@ -164,6 +164,39 @@ class SearchRequest(BaseRequest): session_id: str | None = Field(None, description="Session ID for soft-filtering memories") +class APISearchRequest(BaseRequest): + """Request model for searching memories.""" + + query: str = Field(..., description="Search query") + user_id: str = Field(None, description="User ID") + mem_cube_id: str | None = Field(None, description="Cube ID to search in") + mode: str = Field("fast", description="search mode fast or fine") + internet_search: bool = Field(False, description="Whether to use internet search") + moscube: bool = Field(False, description="Whether to use MemOSCube") + top_k: int = Field(10, description="Number of results to return") + chat_history: list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + +class APIADDRequest(BaseRequest): + """Request model for creating memories.""" + + user_id: str = Field(None, description="User ID") + mem_cube_id: str = Field(..., description="Cube ID") + messages: list[MessageDict] | None = Field(None, description="List of messages to store.") + memory_content: str | None = Field(None, description="Memory content to store") + doc_path: str 
| None = Field(None, description="Path to document to store") + source: str | None = Field(None, description="Source of the memory") + chat_history: list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session id") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py new file mode 100644 index 000000000..1d398ff72 --- /dev/null +++ b/src/memos/api/routers/server_router.py @@ -0,0 +1,282 @@ +import os + +from typing import Any + +from fastapi import APIRouter + +from memos.api.config import APIConfig +from memos.api.product_models import ( + APIADDRequest, + APISearchRequest, + MemoryResponse, + SearchResponse, +) +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory +from memos.types import MOSSearchResult, UserContext + + +logger = get_logger(__name__) + +router = APIRouter(prefix="/product", tags=["Server API"]) + + +def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + """Build graph database configuration.""" + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def _build_llm_config() -> dict[str, Any]: + """Build LLM configuration.""" + return LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + +def _build_embedder_config() -> dict[str, Any]: + """Build embedder configuration.""" + return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + +def _build_mem_reader_config() -> dict[str, Any]: + """Build memory reader configuration.""" + return MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] + ) + + +def _build_reranker_config() -> dict[str, Any]: + """Build reranker configuration.""" + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def _build_internet_retriever_config() -> dict[str, Any]: + """Build internet retriever configuration.""" + return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) + + +def _get_default_memory_size(cube_config) -> dict[str, int]: + 
"""Get default memory size configuration.""" + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + +def init_server(): + """Initialize server components and configurations.""" + # Get default cube configuration + default_cube_config = APIConfig.get_default_cube_config() + + # Build component configurations + graph_db_config = _build_graph_db_config() + print(graph_db_config) + llm_config = _build_llm_config() + embedder_config = _build_embedder_config() + mem_reader_config = _build_mem_reader_config() + reranker_config = _build_reranker_config() + internet_retriever_config = _build_internet_retriever_config() + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder + ) + + # Initialize memory manager + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + return ( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, + ) + + +# Initialize global components +( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, +) = init_server() + + +def _create_naive_mem_cube() -> NaiveMemCube: + """Create a NaiveMemCube instance with initialized components.""" + naive_mem_cube = NaiveMemCube( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + internet_retriever=internet_retriever, + memory_manager=memory_manager, + default_cube_config=default_cube_config, + ) + return naive_mem_cube + + +def _format_memory_item(memory_data: Any) -> dict[str, Any]: + """Format a single memory item for API response.""" + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + +@router.post("/search", summary="Search memories", response_model=SearchResponse) +def search_memories(search_req: APISearchRequest): + """Search memories for a specific user.""" + # Create UserContext object - how to assign values + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + logger.info(f"Search user_id is: {user_context.mem_cube_id}") + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + } + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + naive_mem_cube = _create_naive_mem_cube() + search_results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + 
top_k=search_req.top_k, + mode=search_req.mode, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + formatted_memories = [_format_memory_item(data) for data in search_results] + + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": formatted_memories, + } + ) + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + +@router.post("/add", summary="Add memories", response_model=MemoryResponse) +def add_memories(add_req: APIADDRequest): + """Add memories for a specific user.""" + # Create UserContext object - how to assign values + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=add_req.mem_cube_id, + session_id=add_req.session_id or "default_session", + ) + naive_mem_cube = _create_naive_mem_cube() + target_session_id = add_req.session_id + if not target_session_id: + target_session_id = "default_session" + memories = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + + # Flatten memory list + flattened_memories = [mm for m in memories for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + mem_id_list: list[str] = naive_mem_cube.text_mem.add( + flattened_memories, + user_name=user_context.mem_cube_id, + ) + + logger.info( + f"Added {len(mem_id_list)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_id_list}" + ) + response_data = [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) + ] + return MemoryResponse( + message="Memory added successfully", + data=response_data, + ) diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py new file mode 100644 index 000000000..78e05ef85 --- /dev/null +++ b/src/memos/api/server_api.py @@ -0,0 +1,38 @@ +import logging + +from fastapi import FastAPI + +from memos.api.exceptions import APIExceptionHandler +from memos.api.middleware.request_context import RequestContextMiddleware +from memos.api.routers.server_router import router as server_router + + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +app = FastAPI( + title="MemOS Product REST APIs", + description="A REST API for managing multiple users with MemOS Product.", + version="1.0.1", +) + +app.add_middleware(RequestContextMiddleware) +# Include routers +app.include_router(server_router) + +# Exception handlers +app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +if __name__ == "__main__": + import argparse + + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--workers", type=int, default=1) + args = parser.parse_args() + uvicorn.run("memos.api.server_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index a36f3e2f8..82616ac93 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py 
@@ -11,7 +11,7 @@ BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, - DEFAULT_THREAD__POOL_MAX_WORKERS, + DEFAULT_THREAD_POOL_MAX_WORKERS, ) @@ -25,12 +25,12 @@ class BaseSchedulerConfig(BaseConfig): default=True, description="Whether to enable parallel message processing using thread pool" ) thread_pool_max_workers: int = Field( - default=DEFAULT_THREAD__POOL_MAX_WORKERS, + default=DEFAULT_THREAD_POOL_MAX_WORKERS, gt=1, lt=20, - description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD__POOL_MAX_WORKERS})", + description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD_POOL_MAX_WORKERS})", ) - consume_interval_seconds: int = Field( + consume_interval_seconds: float = Field( default=DEFAULT_CONSUME_INTERVAL_SECONDS, gt=0, le=60, diff --git a/src/memos/configs/mem_user.py b/src/memos/configs/mem_user.py index 3ff1066e5..6e1ca4206 100644 --- a/src/memos/configs/mem_user.py +++ b/src/memos/configs/mem_user.py @@ -31,6 +31,17 @@ class MySQLUserManagerConfig(BaseUserManagerConfig): charset: str = Field(default="utf8mb4", description="MySQL charset") +class RedisUserManagerConfig(BaseUserManagerConfig): + """Redis user manager configuration.""" + + host: str = Field(default="localhost", description="Redis server host") + port: int = Field(default=6379, description="Redis server port") + username: str = Field(default="root", description="Redis username") + password: str = Field(default="", description="Redis password") + database: str = Field(default="memos_users", description="Redis database name") + charset: str = Field(default="utf8mb4", description="Redis charset") + + class UserManagerConfigFactory(BaseModel): """Factory for user manager configurations.""" @@ -42,6 +53,7 @@ class UserManagerConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "sqlite": SQLiteUserManagerConfig, "mysql": MySQLUserManagerConfig, + "redis": RedisUserManagerConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 1eea6deaf..237450e15 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -180,6 +180,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ) +class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): + """Simple tree text memory configuration class.""" + + # ─── 3. 
Global Memory Config Factory ────────────────────────────────────────── @@ -192,6 +196,7 @@ class MemoryConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "naive_text": NaiveTextMemoryConfig, "general_text": GeneralTextMemoryConfig, + "simple_tree_text": SimpleTreeTextMemoryConfig, "tree_text": TreeTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index b43298d9b..dd1748714 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -39,6 +39,18 @@ def set_default_path(self): return self +class MilvusVecDBConfig(BaseVecDBConfig): + """Configuration for Milvus vector database.""" + + uri: str = Field(..., description="URI for Milvus connection") + collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") + max_length: int = Field( + default=65535, description="Maximum length for string fields (varChar type)" + ) + user_name: str = Field(default="", description="User name for Milvus connection") + password: str = Field(default="", description="Password for Milvus connection") + + class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" @@ -47,6 +59,7 @@ class VectorDBConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDBConfig, + "milvus": MilvusVecDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..a6f6b82a4 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -129,7 +129,6 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str: "nebula-sync", ",".join(hosts), str(getattr(cfg, "user", "")), - str(getattr(cfg, "use_multi_db", False)), str(getattr(cfg, "space", "")), ] ) @@ -139,7 +138,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " tmp = object.__new__(NebulaGraphDB) tmp.config = cfg tmp.db_name = cfg.space - tmp.user_name = getattr(cfg, "user_name", None) + tmp.user_name = None tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) tmp.default_memory_dimension = 3072 tmp.common_fields = { @@ -169,7 +168,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -417,7 +416,9 @@ def create_index( self._create_basic_property_indexes() @timed - def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. @@ -426,11 +427,12 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: keep_latest (int): Number of latest WorkingMemory entries to keep. 
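Both new backends are registered through the same `backend_to_class` dispatch used by the existing factories; a small sketch of how the backend string resolves to the new config classes (import paths taken from the diff headers, lookup keys from the mappings above):

    from memos.configs.mem_user import UserManagerConfigFactory
    from memos.configs.vec_db import VectorDBConfigFactory

    # The factory validators accept these strings because of the entries added above
    print(UserManagerConfigFactory.backend_to_class["redis"].__name__)   # RedisUserManagerConfig
    print(VectorDBConfigFactory.backend_to_class["milvus"].__name__)     # MilvusVecDBConfig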
""" optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + user_name = user_name if user_name else self.config.user_name + + optional_condition = f"AND n.user_name = '{user_name}'" query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.memory_type = '{memory_type}' {optional_condition} ORDER BY n.updated_at DESC @@ -440,13 +442,13 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: self.execute_query(query) @timed - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: """ Insert or update a Memory node in NebulaGraph. """ - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name if user_name else self.config.user_name now = datetime.utcnow() metadata = metadata.copy() metadata.setdefault("created_at", now) @@ -475,11 +477,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: ) @timed - def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' query = f""" MATCH (n@Memory) WHERE {filter_clause} @@ -495,10 +495,11 @@ def node_not_exist(self, scope: str) -> int: raise @timed - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. """ + user_name = user_name if user_name else self.config.user_name fields = fields.copy() set_clauses = [] for k, v in fields.items(): @@ -509,45 +510,41 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @timed - def delete_node(self, id: str) -> None: + def delete_node(self, id: str, user_name: str | None = None) -> None: """ Delete a node from the graph. Args: id: Node identifier to delete. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" - MATCH (n@Memory {{id: "{id}"}}) + MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} + DETACH DELETE n """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" - query += "\n DETACH DELETE n" self.execute_query(query) @timed - def add_edge(self, source_id: str, target_id: str, type: str): + def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): """ Create an edge from source node to target node. 
Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). + user_name (str, optional): User name for filtering in non-multi-db mode """ if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") - + user_name = user_name if user_name else self.config.user_name props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' - + props = f'{{user_name: "{user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT (a) -[e@{type} {props}]-> (b) @@ -558,35 +555,35 @@ def add_edge(self, source_id: str, target_id: str, type: str): logger.error(f"Failed to insert edge: {e}", exc_info=True) @timed - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: """ Delete a specific edge between two nodes. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type to remove. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (a@Memory) -[r@{type}]-> (b@Memory) WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @timed - def get_memory_count(self, memory_type: str) -> int: + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -597,14 +594,13 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{scope}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -612,7 +608,12 @@ def count_nodes(self, scope: str) -> int: @timed def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, ) -> bool: """ Check if an edge exists between two nodes. @@ -622,10 +623,12 @@ def edge_exists( type: Relationship type. Use "ANY" to match any relationship type. direction: Direction of the edge. Use "OUTGOING" (default), "INCOMING", or "ANY". 
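Every rewritten method above resolves the effective tenant the same way: an explicit per-call user_name wins, otherwise the config-level value is used. The shared pattern, pulled out as a standalone sketch:

    def _resolve_user_name(explicit: str | None, config_default: str) -> str:
        # Per-call user_name wins; otherwise fall back to the config-level tenant.
        return explicit if explicit else config_default

    # e.g. inside delete_edge(): user_name = _resolve_user_name(user_name, self.config.user_name)
    assert _resolve_user_name(None, "memos_default") == "memos_default"
    assert _resolve_user_name("alice", "memos_default") == "alice"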
+ user_name (str, optional): User name for filtering in non-multi-db mode Returns: True if the edge exists, otherwise False. """ # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name rel = "r" if type == "ANY" else f"r@{type}" # Prepare the match pattern with direction @@ -640,9 +643,7 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." ) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -654,22 +655,22 @@ def edge_exists( @timed # Graph Query & Reasoning - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any] | None: """ Retrieve a Memory node by its unique ID. Args: id (str): Node ID (Memory.id) include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: dict: Node properties as key-value pairs, or None if not found. """ - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' - + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -692,13 +693,18 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | @timed def get_nodes( - self, ids: list[str], include_embedding: bool = False, **kwargs + self, + ids: list[str], + include_embedding: bool = False, + user_name: str | None = None, + **kwargs, ) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: ids: List of Node identifier. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. @@ -709,19 +715,14 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" - + user_name = user_name if user_name else self.config.user_name + where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.id IN [{id_list}] {where_user} RETURN {return_fields} """ @@ -738,7 +739,9 @@ def get_nodes( return nodes @timed - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]: + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -746,6 +749,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ id: Node ID to retrieve edges for. 
type: Relationship type to match, or 'ANY' to match all. direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of edges: @@ -756,7 +760,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ """ # Build relationship type filter rel_type = "" if type == "ANY" else f"@{type}" - + user_name = user_name if user_name else self.config.user_name # Build Cypher pattern based on direction if direction == "OUTGOING": pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" @@ -770,8 +774,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query = f""" MATCH {pattern} @@ -799,6 +802,7 @@ def get_neighbors_by_tag( top_k: int = 5, min_overlap: int = 1, include_embedding: bool = False, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Find top-K neighbor nodes with maximum tag overlap. @@ -809,13 +813,14 @@ def get_neighbors_by_tag( top_k: Max number of neighbors to return. min_overlap: Minimum number of overlapping tags required. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of dicts with node details and overlap count. """ if not tags: return [] - + user_name = user_name if user_name else self.config.user_name where_clauses = [ 'n.status = "activated"', 'NOT (n.node_type = "reasoning")', @@ -824,8 +829,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -859,12 +863,11 @@ def get_neighbors_by_tag( return result @timed - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + user_name = user_name if user_name else self.config.user_name + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -884,7 +887,11 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: @timed def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated" + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -892,6 +899,7 @@ def get_subgraph( center_id: The ID of the center node. depth: The hop distance for neighbors. center_status: Required status for center node. 
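get_neighbors_by_tag ranks candidates by how many query tags they share; the same overlap scoring in plain Python, as a reference for what the GQL computes (min_overlap and top_k semantics taken from the docstring):

    def rank_by_tag_overlap(query_tags, candidates, top_k=5, min_overlap=1):
        # candidates: list of (node_id, node_tags) pairs
        scored = []
        for node_id, node_tags in candidates:
            overlap = len(set(query_tags) & set(node_tags))
            if overlap >= min_overlap:
                scored.append((node_id, overlap))
        return sorted(scored, key=lambda x: x[1], reverse=True)[:top_k]

    print(rank_by_tag_overlap(["AI", "travel"], [("n1", ["AI"]), ("n2", ["AI", "travel"]), ("n3", ["food"])]))
    # -> [('n2', 2), ('n1', 1)]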
+ user_name (str, optional): User name for filtering in non-multi-db mode Returns: { "core_node": {...}, @@ -902,7 +910,8 @@ def get_subgraph( if not 1 <= depth <= 5: raise ValueError("depth must be 1-5") - user_name = self.config.user_name + user_name = user_name if user_name else self.config.user_name + gql = f""" MATCH (center@Memory) WHERE center.id = '{center_id}' @@ -954,6 +963,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str | None = None, **kwargs, ) -> list[dict]: """ @@ -968,6 +978,7 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. @@ -981,42 +992,35 @@ def search_by_embedding( - Typical use case: restrict to 'status = activated' to avoid matching archived or merged nodes. """ + user_name = user_name if user_name else self.config.user_name vector = _normalize(vector) dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE - LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + {where_clause} + ORDER BY inner_product(n.{self.dim_field}, a) DESC + LIMIT {top_k} + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1038,7 +1042,9 @@ def search_by_embedding( return [] @timed - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata( + self, filters: list[dict[str, Any]], user_name: str | None = None + ) -> list[str]: """ 1. ADD logic: "AND" vs "OR"(support logic combination); 2. Support nested conditional expressions; @@ -1054,6 +1060,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: {"field": "tags", "op": "contains", "value": "AI"}, ... 
] + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). @@ -1063,7 +1070,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: - Can be used for faceted recall or prefiltering before embedding rerank. """ where_clauses = [] - + user_name = user_name if user_name else self.config.user_name for _i, f in enumerate(filters): field = f["field"] op = f.get("op", "=") @@ -1087,11 +1094,10 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" + gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" ids = [] try: result = self.execute_query(gql) @@ -1106,6 +1112,7 @@ def get_grouped_counts( group_fields: list[str], where_clause: str = "", params: dict[str, Any] | None = None, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Count nodes grouped by any fields. @@ -1115,24 +1122,24 @@ def get_grouped_counts( where_clause (str, optional): Extra WHERE condition. E.g., "WHERE n.status = 'activated'" params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] """ if not group_fields: raise ValueError("group_fields cannot be empty") - - # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + user_name = user_name if user_name else self.config.user_name + # GQL-specific modifications + user_clause = f"n.user_name = '{user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1151,7 +1158,7 @@ def get_grouped_counts( group_by_fields.append(alias) # Full GQL query construction gql = f""" - MATCH (n) + MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count GROUP BY {", ".join(group_by_fields)} @@ -1170,16 +1177,16 @@ def get_grouped_counts( return output @timed - def clear(self) -> None: + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. 
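search_by_embedding normalizes the incoming query vector before ranking by inner_product; when the stored embeddings are unit-length as well, the inner product coincides with cosine similarity. A quick check of that identity (numpy is an assumption here, not a dependency of this patch):

    import numpy as np

    a, b = np.array([3.0, 4.0]), np.array([1.0, 2.0])
    a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
    cosine = float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
    assert abs(float(a_unit @ b_unit) - cosine) < 1e-9  # IP on normalized vectors == cosine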
+ + Args: + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" - + query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1187,11 +1194,14 @@ def clear(self) -> None: logger.error(f"[ERROR] Failed to clear database: {e}") @timed - def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: + def export_graph( + self, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { @@ -1199,13 +1209,11 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] } """ + user_name = user_name if user_name else self.config.user_name node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + node_query += f' WHERE n.user_name = "{user_name}"' + edge_query += f' WHERE r.user_name = "{user_name}"' try: if include_embedding: @@ -1265,20 +1273,19 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: return {"nodes": nodes, "edges": edges} @timed - def import_graph(self, data: dict[str, Any]) -> None: + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """ Import the entire graph from a serialized dictionary. Args: data: A dictionary containing all nodes and edges to be loaded. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name for node in data.get("nodes", []): try: id, memory, metadata = _compose_node(node) - - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) properties = ", ".join( @@ -1293,9 +1300,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1305,29 +1310,31 @@ def import_graph(self, data: dict[str, Any]) -> None: logger.error(f"Fail to load edge: {edge}, error: {e}") @timed - def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]: + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> (list)[dict]: """ Retrieve all memory items of a specific memory_type. 
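With export_graph and import_graph now taking an explicit user_name, a per-tenant copy reduces to an export/import round trip; a minimal sketch assuming `db` is an existing NebulaGraphDB instance and the tenant labels are placeholders:

    snapshot = db.export_graph(include_embedding=False, user_name="alice")
    # snapshot -> {"nodes": [...], "edges": [...]}; re-import under a different tenant label
    db.import_graph(snapshot, user_name="alice_backup")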
Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Full list of memory items under this scope. """ + user_name = user_name if user_name else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{user_name}'" return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {return_fields} LIMIT 100 @@ -1344,20 +1351,19 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( @timed def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False + self, scope: str, include_embedding: bool = False, user_name: str | None = None ) -> list[dict]: """ Find nodes that are likely candidates for structure optimization: - Isolated nodes, nodes with empty background, or nodes with exactly one child. - Plus: the child of any parent node that has exactly one child. """ - + user_name = user_name if user_name else self.config.user_name where_clause = f''' n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" @@ -1386,21 +1392,6 @@ def get_structure_optimization_candidates( logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - @timed def detect_conflicts(self) -> list[tuple[str, str]]: """ @@ -1585,9 +1576,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. 
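Since user_name is now always among the indexed fields, the f"idx_memory_{field}" naming in _create_basic_property_indexes is what the /*+ INDEX(idx_memory_user_name) */ hints elsewhere in this diff refer to; spelled out:

    fields = ["status", "memory_type", "created_at", "updated_at", "user_name"]
    index_names = [f"idx_memory_{field}" for field in fields]
    # -> [..., "idx_memory_user_name"], matching the planner hint used in the MATCH clauses above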
""" - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 96908913d..ccc91c48b 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -38,6 +38,10 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: if embedding and isinstance(embedding, list): metadata["embedding"] = [float(x) for x in embedding] + # serialization + if metadata["sources"]: + for idx in range(len(metadata["sources"])): + metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) return metadata diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 8acab420c..54000a51d 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1,3 +1,4 @@ +import json from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig @@ -49,6 +50,10 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: # Safely process metadata metadata = _prepare_node_metadata(metadata) + # serialization + if metadata["sources"]: + for idx in range(len(metadata["sources"])): + metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) # Extract required fields embedding = metadata.pop("embedding", None) if embedding is None: @@ -298,7 +303,16 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: if time_field in node and hasattr(node[time_field], "isoformat"): node[time_field] = node[time_field].isoformat() node.pop("user_name", None) - + # serialization + if node["sources"]: + for idx in range(len(node["sources"])): + if not ( + isinstance(node["sources"][idx], str) + and node["sources"][idx][0] == "{" + and node["sources"][idx][0] == "}" + ): + break + node["sources"][idx] = json.loads(node["sources"][idx]) new_node = {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node} try: vec_item = self.vec_db.get_by_id(new_node["id"]) diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py new file mode 100644 index 000000000..7ce3ca642 --- /dev/null +++ b/src/memos/mem_cube/navie.py @@ -0,0 +1,166 @@ +import os + +from typing import Literal + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.utils import get_json_file_model_schema +from memos.embedders.base import BaseEmbedder +from memos.exceptions import ConfigurationError, MemCubeError +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube +from memos.mem_reader.base import BaseMemReader +from memos.memories.activation.base import BaseActMemory +from memos.memories.parametric.base import BaseParaMemory +from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.reranker.base import BaseReranker + + +logger = get_logger(__name__) + + +class NaiveMemCube(BaseMemCube): + """MemCube is a box for loading and dumping three types of memories.""" + + def __init__( + self, + llm: BaseLLM, + embedder: BaseEmbedder, + mem_reader: BaseMemReader, + graph_db: BaseGraphDB, + reranker: BaseReranker, + memory_manager: 
MemoryManager, + default_cube_config: GeneralMemCubeConfig, + internet_retriever: None = None, + ): + """Initialize the MemCube with a configuration.""" + self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( + llm, + embedder, + mem_reader, + graph_db, + reranker, + memory_manager, + default_cube_config.text_mem.config, + internet_retriever, + ) + self._act_mem: BaseActMemory | None = None + self._para_mem: BaseParaMemory | None = None + + def load( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Load memories. + Args: + dir (str): The directory containing the memory files. + memory_types (list[str], optional): List of memory types to load. + If None, loads all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) + if loaded_schema != self.config.model_schema: + raise ConfigurationError( + f"Configuration schema mismatch. Expected {self.config.model_schema}, " + f"but found {loaded_schema}." + ) + + # If no specific memory types specified, load all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Load specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.load(dir) + logger.debug(f"Loaded text_mem from {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.load(dir) + logger.info(f"Loaded act_mem from {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.load(dir) + logger.info(f"Loaded para_mem from {dir}") + + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") + + def dump( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Dump memories. + Args: + dir (str): The directory where the memory files will be saved. + memory_types (list[str], optional): List of memory types to dump. + If None, dumps all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + if os.path.exists(dir) and os.listdir(dir): + raise MemCubeError( + f"Directory {dir} is not empty. Please provide an empty directory for dumping." + ) + + # Always dump config + self.config.to_json_file(os.path.join(dir, self.config.config_filename)) + + # If no specific memory types specified, dump all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Dump specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.dump(dir) + logger.info(f"Dumped text_mem to {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.dump(dir) + logger.info(f"Dumped act_mem to {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.dump(dir) + logger.info(f"Dumped para_mem to {dir}") + + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") + + @property + def text_mem(self) -> "BaseTextMemory | None": + """Get the textual memory.""" + if self._text_mem is None: + logger.warning("Textual memory is not initialized. 
Returning None.") + return self._text_mem + + @text_mem.setter + def text_mem(self, value: BaseTextMemory) -> None: + """Set the textual memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._text_mem = value + + @property + def act_mem(self) -> "BaseActMemory | None": + """Get the activation memory.""" + if self._act_mem is None: + logger.warning("Activation memory is not initialized. Returning None.") + return self._act_mem + + @act_mem.setter + def act_mem(self, value: BaseActMemory) -> None: + """Set the activation memory.""" + if not isinstance(value, BaseActMemory): + raise TypeError(f"Expected BaseActMemory, got {type(value).__name__}") + self._act_mem = value + + @property + def para_mem(self) -> "BaseParaMemory | None": + """Get the parametric memory.""" + if self._para_mem is None: + logger.warning("Parametric memory is not initialized. Returning None.") + return self._para_mem + + @para_mem.setter + def para_mem(self, value: BaseParaMemory) -> None: + """Set the parametric memory.""" + if not isinstance(value, BaseParaMemory): + raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") + self._para_mem = value diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index a4ab4ef20..7e0ed9aef 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -179,14 +179,14 @@ def _restore_user_instances( """ try: # Get all user configurations from persistent storage - user_configs = self.user_manager.list_user_configs() + user_configs = self.user_manager.list_user_configs(self.max_user_instances) # Get the raw database records for sorting by updated_at session = self.user_manager._get_session() try: from memos.mem_user.persistent_user_manager import UserConfig - db_configs = session.query(UserConfig).all() + db_configs = session.query(UserConfig).limit(self.max_user_instances).all() # Create a mapping of user_id to updated_at timestamp updated_at_map = {config.user_id: config.updated_at for config in db_configs} @@ -217,6 +217,26 @@ def _restore_user_instances( except Exception as e: logger.error(f"Error during user instance restoration: {e}") + def _initialize_cube_from_default_config( + self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig + ) -> GeneralMemCube | None: + """ + Initialize a cube from default configuration when cube path doesn't exist. + + Args: + cube_id (str): The cube ID to initialize. + user_id (str): The user ID for the cube. + default_config (GeneralMemCubeConfig): The default configuration to use. 
+ """ + cube_config = default_config.model_copy(deep=True) + # Safely modify the graph_db user_name if it exists + if cube_config.text_mem.config.graph_db.config: + cube_config.text_mem.config.graph_db.config.user_name = ( + f"memos{user_id.replace('-', '')}" + ) + mem_cube = GeneralMemCube(config=cube_config) + return mem_cube + def _preload_user_cubes( self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None ) -> None: @@ -286,8 +306,24 @@ def _load_user_cubes( ) else: logger.warning( - f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}" + f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config" ) + cube_obj = self._initialize_cube_from_default_config( + cube_id=cube.cube_id, + user_id=user_id, + default_config=default_cube_config, + ) + if cube_obj: + self.register_mem_cube( + cube_obj, + cube.cube_id, + user_id, + memory_types=[], + ) + else: + raise ValueError( + f"Failed to initialize default cube {cube.cube_id} for user {user_id}" + ) except Exception as e: logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}") logger.info(f"load user {user_id} cubes successfully") @@ -427,6 +463,47 @@ def _build_system_prompt( + mem_block ) + def _build_base_system_prompt( + self, + base_prompt: str | None = None, + tone: str = "friendly", + verbosity: str = "mid", + mode: str = "enhance", + ) -> str: + """ + Build base system prompt without memory references. + """ + now = datetime.now() + formatted_date = now.strftime("%Y-%m-%d (%A)") + sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode) + prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" + return prefix + sys_body + + def _build_memory_context( + self, + memories_all: list[TextualMemoryItem], + mode: str = "enhance", + ) -> str: + """ + Build memory context to be included in user message. + """ + if not memories_all: + return "" + + mem_block_o, mem_block_p = _format_mem_block(memories_all) + + if mode == "enhance": + return ( + "# Memories\n## PersonalMemory (ordered)\n" + + mem_block_p + + "\n## OuterMemory (ordered)\n" + + mem_block_o + + "\n\n" + ) + else: + mem_block = mem_block_o + "\n" + mem_block_p + return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n" + def _build_enhance_system_prompt( self, user_id: str, @@ -436,6 +513,7 @@ def _build_enhance_system_prompt( ) -> str: """ Build enhance prompt for the user with memory references. + [DEPRECATED] Use _build_base_system_prompt and _build_memory_context instead. 
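The default-config fallback above derives the per-user graph namespace by stripping hyphens from the user id; a one-line illustration with an assumed id value:

    user_id = "3f2a-77b1-anon"   # assumed example value
    graph_user_name = f"memos{user_id.replace('-', '')}"
    # -> "memos3f2a77b1anon", the user_name written into cube_config.text_mem.config.graph_db.config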
""" now = datetime.now() formatted_date = now.strftime("%Y-%m-%d (%A)") @@ -966,14 +1044,22 @@ def chat( m.metadata.embedding = [] new_memories_list.append(m) memories_list = new_memories_list - system_prompt = super()._build_system_prompt(memories_list, base_prompt) + # Build base system prompt without memory + system_prompt = self._build_base_system_prompt(base_prompt, mode="base") + + # Build memory context to be included in user message + memory_context = self._build_memory_context(memories_list, mode="base") + + # Combine memory context with user query + user_content = memory_context + query if memory_context else query + history_info = [] if history: history_info = history[-20:] current_messages = [ {"role": "system", "content": system_prompt}, *history_info, - {"role": "user", "content": query}, + {"role": "user", "content": user_content}, ] response = self.chat_llm.generate(current_messages) time_end = time.time() @@ -1043,8 +1129,16 @@ def chat_with_references( reference = prepare_reference_data(memories_list) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - # Build custom system prompt with relevant memories) - system_prompt = self._build_enhance_system_prompt(user_id, memories_list) + + # Build base system prompt without memory + system_prompt = self._build_base_system_prompt(mode="enhance") + + # Build memory context to be included in user message + memory_context = self._build_memory_context(memories_list, mode="enhance") + + # Combine memory context with user query + user_content = memory_context + query if memory_context else query + # Get chat history if user_id not in self.chat_history_manager: self._register_chat_history(user_id, session_id) @@ -1055,7 +1149,7 @@ def chat_with_references( current_messages = [ {"role": "system", "content": system_prompt}, *chat_history.chat_history, - {"role": "user", "content": query}, + {"role": "user", "content": user_content}, ] logger.info( f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}" diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index b6ef00d8d..3e25a0ad7 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -20,7 +20,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, - DEFAULT_THREAD__POOL_MAX_WORKERS, + DEFAULT_THREAD_POOL_MAX_WORKERS, MemCubeID, TreeTextMemory_SEARCH_METHOD, UserID, @@ -60,7 +60,7 @@ def __init__(self, config: BaseSchedulerConfig): self.search_method = TreeTextMemory_SEARCH_METHOD self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False) self.thread_pool_max_workers = self.config.get( - "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS + "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS ) self.retriever: SchedulerRetriever | None = None diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index ce6df4d5d..e45ce4a2b 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -1,11 +1,14 @@ import concurrent +import threading from collections import defaultdict from collections.abc import Callable +from typing import Any from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from 
memos.mem_scheduler.general_modules.task_threads import ThreadRace from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -22,6 +25,7 @@ class SchedulerDispatcher(BaseSchedulerModule): - Batch message processing - Graceful shutdown - Bulk handler registration + - Thread race competition for parallel task execution """ def __init__(self, max_workers=30, enable_parallel_dispatch=False): @@ -49,6 +53,9 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False): # Set to track active futures for monitoring purposes self._futures = set() + # Thread race module for competitive task execution + self.thread_race = ThreadRace() + def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]): """ Register a handler function for a specific message label. @@ -177,6 +184,22 @@ def join(self, timeout: float | None = None) -> bool: return len(not_done) == 0 + def run_competitive_tasks( + self, tasks: dict[str, Callable[[threading.Event], Any]], timeout: float = 10.0 + ) -> tuple[str, Any] | None: + """ + Run multiple tasks in a competitive race, returning the result of the first task to complete. + + Args: + tasks: Dictionary mapping task names to task functions that accept a stop_flag parameter + timeout: Maximum time to wait for any task to complete (in seconds) + + Returns: + Tuple of (task_name, result) from the winning task, or None if no task completes + """ + logger.info(f"Starting competitive execution of {len(tasks)} tasks") + return self.thread_race.run_race(tasks, timeout) + def shutdown(self) -> None: """Gracefully shutdown the dispatcher.""" self._running = False diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py new file mode 100644 index 000000000..9df8ef650 --- /dev/null +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -0,0 +1,139 @@ +import threading + +from collections.abc import Callable +from typing import Any, TypeVar + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule + + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class ThreadRace(BaseSchedulerModule): + """ + Thread race implementation that runs multiple tasks concurrently and returns + the result of the first task to complete successfully. + + Features: + - Cooperative thread termination using stop flags + - Configurable timeout for tasks + - Automatic cleanup of slower threads + - Thread-safe result handling + """ + + def __init__(self): + super().__init__() + # Variable to store the result + self.result: tuple[str, Any] | None = None + # Event to mark if the race is finished + self.race_finished = threading.Event() + # Lock to protect the result variable + self.lock = threading.Lock() + # Store thread objects for termination + self.threads: dict[str, threading.Thread] = {} + # Stop flags for each thread + self.stop_flags: dict[str, threading.Event] = {} + + def worker( + self, task_func: Callable[[threading.Event], T], task_name: str + ) -> tuple[str, T] | None: + """ + Worker thread function that executes a task and handles result reporting. 
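A minimal sketch of racing two cooperative tasks through the new dispatcher API; the task bodies and timings are assumptions, and each task polls its stop_flag so the losing thread can exit early, as the worker expects:

    import threading
    import time

    from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher

    def fast_task(stop_flag: threading.Event) -> str:
        time.sleep(0.1)
        return "fast result"

    def slow_task(stop_flag: threading.Event) -> str:
        for _ in range(100):                  # poll the flag so the loser can stop early
            if stop_flag.is_set():
                return "cancelled"
            time.sleep(0.05)
        return "slow result"

    dispatcher = SchedulerDispatcher(max_workers=4, enable_parallel_dispatch=True)
    winner = dispatcher.run_competitive_tasks({"fast": fast_task, "slow": slow_task}, timeout=5.0)
    # winner -> ("fast", "fast result"); None if no task finishes before the timeout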
+ + Args: + task_func: Function to execute with a stop_flag parameter + task_name: Name identifier for this task/thread + + Returns: + Tuple of (task_name, result) if this thread wins the race, None otherwise + """ + # Create a stop flag for this task + stop_flag = threading.Event() + self.stop_flags[task_name] = stop_flag + + try: + # Execute the task with stop flag + result = task_func(stop_flag) + + # If the race is already finished or we were asked to stop, return immediately + if self.race_finished.is_set() or stop_flag.is_set(): + return None + + # Try to set the result (if no other thread has set it yet) + with self.lock: + if not self.race_finished.is_set(): + self.result = (task_name, result) + # Mark the race as finished + self.race_finished.set() + logger.info(f"Task '{task_name}' won the race") + + # Signal other threads to stop + for name, flag in self.stop_flags.items(): + if name != task_name: + logger.debug(f"Signaling task '{name}' to stop") + flag.set() + + return self.result + + except Exception as e: + logger.error(f"Task '{task_name}' encountered an error: {e}") + + return None + + def run_race( + self, tasks: dict[str, Callable[[threading.Event], T]], timeout: float = 10.0 + ) -> tuple[str, T] | None: + """ + Start a competition between multiple tasks and return the result of the fastest one. + + Args: + tasks: Dictionary mapping task names to task functions + timeout: Maximum time to wait for any task to complete (in seconds) + + Returns: + Tuple of (task_name, result) from the winning task, or None if no task completes + """ + if not tasks: + logger.warning("No tasks provided for the race") + return None + + # Reset state + self.race_finished.clear() + self.result = None + self.threads.clear() + self.stop_flags.clear() + + # Create and start threads for each task + for task_name, task_func in tasks.items(): + thread = threading.Thread( + target=self.worker, args=(task_func, task_name), name=f"race-{task_name}" + ) + self.threads[task_name] = thread + thread.start() + logger.debug(f"Started task '{task_name}'") + + # Wait for any thread to complete or timeout + race_completed = self.race_finished.wait(timeout=timeout) + + if not race_completed: + logger.warning(f"Race timed out after {timeout} seconds") + # Signal all threads to stop + for _name, flag in self.stop_flags.items(): + flag.set() + + # Wait for all threads to end (with timeout to avoid infinite waiting) + for _name, thread in self.threads.items(): + thread.join(timeout=1.0) + if thread.is_alive(): + logger.warning(f"Thread '{_name}' did not terminate within the join timeout") + + # Return the result + if self.result: + logger.info(f"Race completed. 
Winner: {self.result[0]}") + else: + logger.warning("Race completed with no winner") + + return self.result diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index a81caf5a8..b029e38e8 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -17,8 +17,8 @@ DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 30 DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" -DEFAULT_THREAD__POOL_MAX_WORKERS = 5 -DEFAULT_CONSUME_INTERVAL_SECONDS = 3 +DEFAULT_THREAD_POOL_MAX_WORKERS = 10 +DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 NOT_INITIALIZED = -1 diff --git a/src/memos/mem_user/mysql_persistent_user_manager.py b/src/memos/mem_user/mysql_persistent_user_manager.py index f8983c87c..99e49d206 100644 --- a/src/memos/mem_user/mysql_persistent_user_manager.py +++ b/src/memos/mem_user/mysql_persistent_user_manager.py @@ -188,7 +188,7 @@ def delete_user_config(self, user_id: str) -> bool: finally: session.close() - def list_user_configs(self) -> dict[str, MOSConfig]: + def list_user_configs(self, limit: int = 1) -> dict[str, MOSConfig]: """List all user configurations. Returns: @@ -196,7 +196,7 @@ def list_user_configs(self) -> dict[str, MOSConfig]: """ session = self._get_session() try: - user_configs = session.query(UserConfig).all() + user_configs = session.query(UserConfig).limit(limit).all() result = {} for user_config in user_configs: diff --git a/src/memos/mem_user/persistent_factory.py b/src/memos/mem_user/persistent_factory.py index b5ece61b5..6a7b4fa13 100644 --- a/src/memos/mem_user/persistent_factory.py +++ b/src/memos/mem_user/persistent_factory.py @@ -3,6 +3,7 @@ from memos.configs.mem_user import UserManagerConfigFactory from memos.mem_user.mysql_persistent_user_manager import MySQLPersistentUserManager from memos.mem_user.persistent_user_manager import PersistentUserManager +from memos.mem_user.redis_persistent_user_manager import RedisPersistentUserManager class PersistentUserManagerFactory: @@ -11,6 +12,7 @@ class PersistentUserManagerFactory: backend_to_class: ClassVar[dict[str, Any]] = { "sqlite": PersistentUserManager, "mysql": MySQLPersistentUserManager, + "redis": RedisPersistentUserManager, } @classmethod diff --git a/src/memos/mem_user/redis_persistent_user_manager.py b/src/memos/mem_user/redis_persistent_user_manager.py new file mode 100644 index 000000000..48c89c663 --- /dev/null +++ b/src/memos/mem_user/redis_persistent_user_manager.py @@ -0,0 +1,225 @@ +"""Redis-based persistent user management system for MemOS with configuration storage. + +This module provides persistent storage for user configurations using Redis. +""" + +import json + +from memos.configs.mem_os import MOSConfig +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class RedisPersistentUserManager: + """Redis-based user configuration manager with persistence.""" + + @require_python_package( + import_name="redis", + install_command="pip install redis", + install_link="https://redis.readthedocs.io/en/stable/", + ) + def __init__( + self, + host: str = "localhost", + port: int = 6379, + password: str = "", + db: int = 0, + decode_responses: bool = True, + ): + """Initialize the Redis persistent user manager. + + Args: + user_id (str, optional): User ID. Defaults to "root". + host (str): Redis server host. Defaults to "localhost". 
+ port (int): Redis server port. Defaults to 6379. + password (str): Redis password. Defaults to "". + db (int): Redis database number. Defaults to 0. + decode_responses (bool): Whether to decode responses to strings. Defaults to True. + """ + import redis + + self.host = host + self.port = port + self.db = db + + try: + # Create Redis connection + self._redis_client = redis.Redis( + host=host, + port=port, + password=password if password else None, + db=db, + decode_responses=decode_responses, + ) + + # Test connection + if not self._redis_client.ping(): + raise ConnectionError("Redis connection failed") + + logger.info( + f"RedisPersistentUserManager initialized successfully, connected to {host}:{port}/{db}" + ) + + except Exception as e: + logger.error(f"Redis connection error: {e}") + raise + + def _get_config_key(self, user_id: str) -> str: + """Generate Redis key for user configuration. + + Args: + user_id (str): User ID. + + Returns: + str: Redis key name. + """ + return user_id + + def save_user_config(self, user_id: str, config: MOSConfig) -> bool: + """Save user configuration to Redis. + + Args: + user_id (str): User ID. + config (MOSConfig): User's MOS configuration. + + Returns: + bool: True if successful, False otherwise. + """ + try: + # Convert config to JSON string + config_dict = config.model_dump(mode="json") + config_json = json.dumps(config_dict, ensure_ascii=False, indent=2) + + # Save to Redis + key = self._get_config_key(user_id) + self._redis_client.set(key, config_json) + + logger.info(f"Successfully saved configuration for user {user_id} to Redis") + return True + + except Exception as e: + logger.error(f"Error saving configuration for user {user_id}: {e}") + return False + + def get_user_config(self, user_id: str) -> dict | None: + """Get user configuration from Redis (search interface). + + Args: + user_id (str): User ID. + + Returns: + MOSConfig | None: User's configuration object, or None if not found. + """ + try: + # Get configuration from Redis + key = self._get_config_key(user_id) + config_json = self._redis_client.get(key) + + if config_json is None: + logger.info(f"Configuration for user {user_id} does not exist") + return None + + # Parse JSON and create MOSConfig object + config_dict = json.loads(config_json) + + logger.info(f"Successfully retrieved configuration for user {user_id}") + return config_dict + + except json.JSONDecodeError as e: + logger.error(f"Error parsing JSON configuration for user {user_id}: {e}") + return None + except Exception as e: + logger.error(f"Error retrieving configuration for user {user_id}: {e}") + return None + + def delete_user_config(self, user_id: str) -> bool: + """Delete user configuration from Redis. + + Args: + user_id (str): User ID. + + Returns: + bool: True if successful, False otherwise. + """ + try: + key = self._get_config_key(user_id) + result = self._redis_client.delete(key) + + if result > 0: + logger.info(f"Successfully deleted configuration for user {user_id}") + return True + else: + logger.warning(f"Configuration for user {user_id} does not exist, cannot delete") + return False + + except Exception as e: + logger.error(f"Error deleting configuration for user {user_id}: {e}") + return False + + def exists_user_config(self, user_id: str) -> bool: + """Check if user configuration exists. + + Args: + user_id (str): User ID. + + Returns: + bool: True if exists, False otherwise. 
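End-to-end usage of the Redis-backed manager, assuming a reachable local Redis and an existing MOSConfig instance named mos_config (both are assumptions; method names come from this module):

    from memos.mem_user.redis_persistent_user_manager import RedisPersistentUserManager

    manager = RedisPersistentUserManager(host="localhost", port=6379)
    manager.save_user_config("user-42", mos_config)    # stored as a JSON string
    assert manager.exists_user_config("user-42")
    restored = manager.get_user_config("user-42")      # returns a plain dict, not a MOSConfig
    manager.delete_user_config("user-42")
    manager.close()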
+ """ + try: + key = self._get_config_key(user_id) + return self._redis_client.exists(key) > 0 + except Exception as e: + logger.error(f"Error checking if configuration exists for user {user_id}: {e}") + return False + + def list_user_configs( + self, pattern: str = "user_config:*", count: int = 100 + ) -> dict[str, dict]: + """List all user configurations. + + Args: + pattern (str): Redis key matching pattern. Defaults to "user_config:*". + count (int): Number of keys to return per scan. Defaults to 100. + + Returns: + dict[str, dict]: Dictionary mapping user_id to dict objects. + """ + result = {} + try: + # Use SCAN command to iterate through all matching keys + cursor = 0 + while True: + cursor, keys = self._redis_client.scan(cursor, match=pattern, count=count) + + for key in keys: + # Extract user_id (remove "user_config:" prefix) + user_id = key.replace("user_config:", "") + config = self.get_user_config(user_id) + if config: + result[user_id] = config + + if cursor == 0: + break + + logger.info(f"Successfully listed {len(result)} user configurations") + return result + + except Exception as e: + logger.error(f"Error listing user configurations: {e}") + return {} + + def close(self) -> None: + """Close Redis connection. + + This method should be called when the RedisPersistentUserManager is no longer needed + to ensure proper cleanup of Redis connections. + """ + try: + if hasattr(self, "_redis_client") and self._redis_client: + self._redis_client.close() + logger.info("Redis connection closed") + except Exception as e: + logger.error(f"Error closing Redis connection: {e}") diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index 9fdc67c53..bcf7fdd9b 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,6 +10,7 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -20,6 +21,7 @@ class MemoryFactory(BaseMemory): "naive_text": NaiveTextMemory, "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, + "simple_tree_text": SimpleTreeTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 8171fadce..82dad4486 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -24,7 +24,7 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """ @abstractmethod - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: """Add memories. 
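A minimal usage sketch for the RedisPersistentUserManager introduced above, assuming a Redis instance reachable on localhost and a MOSConfig object built elsewhere in the application; the user id and connection values are illustrative.

from memos.configs.mem_os import MOSConfig
from memos.mem_user.redis_persistent_user_manager import RedisPersistentUserManager


def roundtrip_user_config(mos_config: MOSConfig) -> None:
    # Connection values are illustrative; pass a password for secured deployments.
    manager = RedisPersistentUserManager(host="localhost", port=6379, db=0)
    try:
        # The config is serialized to JSON and stored under the user's key.
        manager.save_user_config("alice", mos_config)

        # get_user_config returns the stored dict, or None when the key is missing.
        if manager.exists_user_config("alice"):
            print(manager.get_user_config("alice"))

        manager.delete_user_config("alice")
    finally:
        # Release the underlying Redis connection when done.
        manager.close()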
Args: diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py new file mode 100644 index 000000000..9c67db288 --- /dev/null +++ b/src/memos/memories/textual/simple_tree.py @@ -0,0 +1,295 @@ +import time + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from memos.configs.memory import TreeTextMemoryConfig +from memos.embedders.base import BaseEmbedder +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_reader.base import BaseMemReader +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.base import BaseReranker +from memos.types import MessageList + + +if TYPE_CHECKING: + from memos.embedders.factory import OllamaEmbedder + from memos.graph_dbs.factory import Neo4jGraphDB + from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM + + +logger = get_logger(__name__) + + +class SimpleTreeTextMemory(TreeTextMemory): + """General textual memory implementation for storing and retrieving memories.""" + + def __init__( + self, + llm: BaseLLM, + embedder: BaseEmbedder, + mem_reader: BaseMemReader, + graph_db: BaseGraphDB, + reranker: BaseReranker, + memory_manager: MemoryManager, + config: TreeTextMemoryConfig, + internet_retriever: None = None, + is_reorganize: bool = False, + ): + """Initialize memory with the given configuration.""" + time_start = time.time() + self.config: TreeTextMemoryConfig = config + + self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm + logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") + + time_start_ex = time.time() + self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm + logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}") + + time_start_em = time.time() + self.embedder: OllamaEmbedder = embedder + logger.info(f"time init: embedder time is: {time.time() - time_start_em}") + + time_start_gs = time.time() + self.graph_store: Neo4jGraphDB = graph_db + logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") + + time_start_rr = time.time() + self.reranker = reranker + logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") + + time_start_mm = time.time() + self.memory_manager: MemoryManager = memory_manager + logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}") + time_start_ir = time.time() + # Create internet retriever if configured + self.internet_retriever = None + if config.internet_retriever is not None: + self.internet_retriever = internet_retriever + logger.info( + f"Internet retriever initialized with backend: {config.internet_retriever.backend}" + ) + else: + logger.info("No internet retriever configured") + logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") + + def add( + self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None + ) -> list[str]: + """Add memories. + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. 
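A wiring sketch for the constructor above: unlike the parent TreeTextMemory, SimpleTreeTextMemory receives already-built collaborators, so every argument below is assumed to come from the corresponding factory elsewhere in the application.

from memos.configs.memory import TreeTextMemoryConfig
from memos.memories.textual.simple_tree import SimpleTreeTextMemory


def build_simple_tree_memory(llm, embedder, mem_reader, graph_db, reranker,
                             memory_manager, config: TreeTextMemoryConfig):
    # The class only stores these references (and logs per-component init time);
    # internet retrieval stays disabled unless config.internet_retriever is set
    # and a retriever instance is passed in.
    return SimpleTreeTextMemory(
        llm=llm,
        embedder=embedder,
        mem_reader=mem_reader,
        graph_db=graph_db,
        reranker=reranker,
        memory_manager=memory_manager,
        config=config,
        internet_retriever=None,
        is_reorganize=False,
    )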
+ Later: + memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] + metadata = extract_metadata(memory_items, self.extractor_llm) + plan = plan_memory_operations(memory_items, metadata, self.graph_store) + execute_plan(memory_items, metadata, plan, self.graph_store) + """ + return self.memory_manager.add(memories, user_name=user_name) + + def replace_working_memory( + self, memories: list[TextualMemoryItem], user_name: str | None = None + ) -> None: + self.memory_manager.replace_working_memory(memories, user_name=user_name) + + def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]: + working_memories = self.graph_store.get_all_memory_items( + scope="WorkingMemory", user_name=user_name + ) + items = [TextualMemoryItem.from_dict(record) for record in (working_memories)] + # Sort by updated_at in descending order + sorted_items = sorted( + items, key=lambda x: x.metadata.updated_at or datetime.min, reverse=True + ) + return sorted_items + + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: + """ + Get the current size of each memory type. + This delegates to the MemoryManager. + """ + return self.memory_manager.get_current_memory_size(user_name=user_name) + + def search( + self, + query: str, + top_k: int, + info=None, + mode: str = "fast", + memory_type: str = "All", + manual_close_internet: bool = False, + moscube: bool = False, + search_filter: dict | None = None, + user_name: str | None = None, + ) -> list[TextualMemoryItem]: + """Search for memories based on a query. + User query -> TaskGoalParser -> MemoryPathResolver -> + GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + mode (str, optional): The mode of the search. + - 'fast': Uses a faster search process, sacrificing some precision for speed. + - 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance. + memory_type (str): Type restriction for search. + ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] + manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config. + moscube (bool): whether you use moscube to answer questions + search_filter (dict, optional): Optional metadata filters for search results. + - Keys correspond to memory metadata fields (e.g., "user_id", "session_id"). + - Values are exact-match conditions. + Example: {"user_id": "123", "session_id": "abc"} + If None, no additional filtering is applied. + Returns: + list[TextualMemoryItem]: List of matching memories. 
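The add/search methods above thread an optional user_name down to the graph store; a short sketch of scoping writes and reads to one user, where memory is a SimpleTreeTextMemory instance and the item/metadata values are illustrative.

from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata


def add_and_search_for_user(memory, user_name: str) -> None:
    # Writes are delegated to MemoryManager.add(..., user_name=...).
    item = TextualMemoryItem(
        memory="Alice prefers morning meetings.",
        metadata=TreeNodeTextualMemoryMetadata(memory_type="LongTermMemory"),
    )
    memory.add([item], user_name=user_name)

    # Fast-mode retrieval restricted to this user's long-term memories.
    hits = memory.search(
        query="When does Alice like to meet?",
        top_k=5,
        mode="fast",
        memory_type="LongTermMemory",
        user_name=user_name,
    )
    for hit in hits:
        print(hit.id, hit.memory)

    # Per-user size bookkeeping also flows through the manager.
    print(memory.get_current_memory_size(user_name=user_name))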
+ """ + if (self.internet_retriever is not None) and manual_close_internet: + logger.warning( + "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" + ) + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=None, + moscube=moscube, + ) + else: + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=self.internet_retriever, + moscube=moscube, + ) + return searcher.search( + query, top_k, info, mode, memory_type, search_filter, user_name=user_name + ) + + def get_relevant_subgraph( + self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated" + ) -> dict[str, Any]: + """ + Find and merge the local neighborhood sub-graphs of the top-k + nodes most relevant to the query. + Process: + 1. Embed the user query into a vector representation. + 2. Use vector similarity search to find the top-k similar nodes. + 3. For each similar node: + - Ensure its status matches `center_status` (e.g., 'active'). + - Retrieve its local subgraph up to `depth` hops. + - Collect the center node, its neighbors, and connecting edges. + 4. Merge all retrieved subgraphs into a single unified subgraph. + 5. Return the merged subgraph structure. + + Args: + query (str): The user input or concept to find relevant memories for. + top_k (int, optional): How many top similar nodes to retrieve. Default is 5. + depth (int, optional): The neighborhood depth (number of hops). Default is 2. + center_status (str, optional): Status condition the center node must satisfy (e.g., 'active'). + + Returns: + dict[str, Any]: A subgraph dict with: + - 'core_id': ID of the top matching core node, or None if none found. + - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph. + - 'edges': List of unique edges (as dicts with 'from', 'to', 'type') in the merged subgraph. 
+ """ + # Step 1: Embed query + query_embedding = self.embedder.embed([query])[0] + + # Step 2: Get top-1 similar node + similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) + if not similar_nodes: + logger.info("No similar nodes found for query embedding.") + return {"core_id": None, "nodes": [], "edges": []} + + # Step 3: Fetch neighborhood + all_nodes = {} + all_edges = set() + cores = [] + + for node in similar_nodes: + core_id = node["id"] + score = node["score"] + + subgraph = self.graph_store.get_subgraph( + center_id=core_id, depth=depth, center_status=center_status + ) + + if not subgraph["core_node"]: + logger.info(f"Skipping node {core_id} (inactive or not found).") + continue + + core_node = subgraph["core_node"] + neighbors = subgraph["neighbors"] + edges = subgraph["edges"] + + # Collect nodes + all_nodes[core_node["id"]] = core_node + for n in neighbors: + all_nodes[n["id"]] = n + + # Collect edges + for e in edges: + all_edges.add((e["source"], e["target"], e["type"])) + + cores.append( + {"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors} + ) + + top_core = cores[0] + return { + "core_id": top_core["id"], + "nodes": list(all_nodes.values()), + "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges], + } + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + raise NotImplementedError + + def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: + raise NotImplementedError + + def get(self, memory_id: str) -> TextualMemoryItem: + """Get a memory by its ID.""" + result = self.graph_store.get_node(memory_id) + if result is None: + raise ValueError(f"Memory with ID {memory_id} not found") + metadata_dict = result.get("metadata", {}) + return TextualMemoryItem( + id=result["id"], + memory=result["memory"], + metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), + ) + + def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: + raise NotImplementedError + + def get_all(self) -> dict: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. + """ + all_items = self.graph_store.export_graph() + return all_items + + def delete(self, memory_ids: list[str]) -> None: + raise NotImplementedError + + def delete_all(self) -> None: + """Delete all memories and their relationships from the graph store.""" + try: + self.graph_store.clear() + logger.info("All memories and edges have been deleted from the graph.") + except Exception as e: + logger.error(f"An error occurred while deleting all memories: {e}") + raise diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index c9cd4de8a..680052a9d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -51,14 +51,14 @@ def __init__( ) self._merged_threshold = merged_threshold - def add(self, memories: list[TextualMemoryItem]) -> list[str]: + def add(self, memories: list[TextualMemoryItem], user_name: str | None = None) -> list[str]: """ Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory). 
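The merged structure returned by get_relevant_subgraph above is a plain dict; a small consumption sketch (the edge keys are source/target/type, as emitted by the method body).

def print_relevant_subgraph(memory, query: str) -> None:
    subgraph = memory.get_relevant_subgraph(query, top_k=5, depth=2)

    if subgraph["core_id"] is None:
        print("no matching memories")
        return

    print("core node:", subgraph["core_id"])
    for node in subgraph["nodes"]:
        # Node dicts come straight from the graph store.
        print("node:", node.get("id"))
    for edge in subgraph["edges"]:
        print(f"edge: {edge['source']} -[{edge['type']}]-> {edge['target']}")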
""" added_ids: list[str] = [] with ContextThreadPoolExecutor(max_workers=8) as executor: - futures = {executor.submit(self._process_memory, m): m for m in memories} + futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=60): try: ids = future.result() @@ -66,38 +66,31 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]: except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] - ) - except Exception: - logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"] - ) - except Exception: - logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"] - ) - except Exception: - logger.warning(f"Remove UserMemory error: {traceback.format_exc()}") + for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + try: + self.graph_store.remove_oldest_memory( + memory_type="WorkingMemory", + keep_latest=self.memory_size[mem_type], + user_name=user_name, + ) + except Exception: + logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}") - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return added_ids - def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: + def replace_working_memory( + self, memories: list[TextualMemoryItem], user_name: str | None = None + ) -> None: """ Replace WorkingMemory """ working_memory_top_k = memories[: self.memory_size["WorkingMemory"]] with ContextThreadPoolExecutor(max_workers=8) as executor: futures = [ - executor.submit(self._add_memory_to_db, memory, "WorkingMemory") + executor.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name=user_name + ) for memory in working_memory_top_k ] for future in as_completed(futures, timeout=60): @@ -107,47 +100,51 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: logger.exception("Memory processing error: ", exc_info=e) self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] + memory_type="WorkingMemory", + keep_latest=self.memory_size["WorkingMemory"], + user_name=user_name, ) - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) - def get_current_memory_size(self) -> dict[str, int]: + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: """ Return the cached memory type counts. """ - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return self.current_memory_size - def _refresh_memory_size(self) -> None: + def _refresh_memory_size(self, user_name: str | None = None) -> None: """ Query the latest counts from the graph store and update internal state. 
""" - results = self.graph_store.get_grouped_counts(group_fields=["memory_type"]) + results = self.graph_store.get_grouped_counts( + group_fields=["memory_type"], user_name=user_name + ) self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") - def _process_memory(self, memory: TextualMemoryItem): + def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. """ ids = [] - # Add to WorkingMemory - working_id = self._add_memory_to_db(memory, "WorkingMemory") - ids.append(working_id) + # Add to WorkingMemory do not return working_id + self._add_memory_to_db(memory, "WorkingMemory", user_name) # Add to LongTermMemory and UserMemory if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: added_id = self._add_to_graph_memory( - memory=memory, - memory_type=memory.metadata.memory_type, + memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name ) ids.append(added_id) return ids - def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: + def _add_memory_to_db( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. """ @@ -158,10 +155,12 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) # Insert node into graph - self.graph_store.add_node(working_memory.id, working_memory.memory, metadata) + self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) return working_memory.id - def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): + def _add_to_graph_memory( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). 
@@ -175,7 +174,10 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): node_id = str(uuid.uuid4()) # Step 2: Add new node to graph self.graph_store.add_node( - node_id, memory.memory, memory.metadata.model_dump(exclude_none=True) + node_id, + memory.memory, + memory.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 84cc8ecb3..d4cfcf501 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -30,6 +30,7 @@ def retrieve( memory_scope: str, query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -53,13 +54,13 @@ def retrieve( if memory_scope == "WorkingMemory": # For working memory, retrieve all entries (no filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False + scope="WorkingMemory", include_embedding=False, user_name=user_name ) return [TextualMemoryItem.from_dict(record) for record in working_memories] with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval - future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope) + future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search future_vector = executor.submit( self._vector_recall, @@ -67,6 +68,7 @@ def retrieve( memory_scope, top_k, search_filter=search_filter, + user_name=user_name, ) graph_results = future_graph.result() @@ -92,6 +94,7 @@ def retrieve_from_cube( memory_scope: str, query_embedding: list[list[float]] | None = None, cube_name: str = "memos_cube01", + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -112,7 +115,7 @@ def retrieve_from_cube( raise ValueError(f"Unsupported memory scope: {memory_scope}") graph_results = self._vector_recall( - query_embedding, memory_scope, top_k, cube_name=cube_name + query_embedding, memory_scope, top_k, cube_name=cube_name, user_name=user_name ) for result_i in graph_results: @@ -132,7 +135,7 @@ def retrieve_from_cube( return list(combined.values()) def _graph_recall( - self, parsed_goal: ParsedTaskGoal, memory_scope: str + self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None ) -> list[TextualMemoryItem]: """ Perform structured node-based retrieval from Neo4j. 
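The filters that _graph_recall builds just below are plain lists of condition dicts handed to graph_store.get_by_metadata; the shape looks roughly like this (field and op names mirror the recall code, the values are illustrative).

# OR-branch on parsed_goal.keys, restricted to the requested scope.
key_filters = [
    {"field": "key", "op": "in", "value": ["meeting_preferences"]},
    {"field": "memory_type", "op": "=", "value": "LongTermMemory"},
]

# OR-branch on parsed_goal.tags, same scope restriction.
tag_filters = [
    {"field": "tags", "op": "contains", "value": ["schedule"]},
    {"field": "memory_type", "op": "=", "value": "LongTermMemory"},
]

# The two id sets are unioned, the matching nodes are loaded without embeddings and
# post-filtered, all scoped by the optional user_name.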
@@ -148,7 +151,7 @@ def _graph_recall( {"field": "key", "op": "in", "value": parsed_goal.keys}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - key_ids = self.graph_store.get_by_metadata(key_filters) + key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) candidate_ids.update(key_ids) # 2) tag-based OR branch @@ -157,7 +160,7 @@ def _graph_recall( {"field": "tags", "op": "contains", "value": parsed_goal.tags}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - tag_ids = self.graph_store.get_by_metadata(tag_filters) + tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) candidate_ids.update(tag_ids) # No matches → return empty @@ -165,7 +168,9 @@ def _graph_recall( return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_name + ) final_nodes = [] for node in node_dicts: @@ -194,6 +199,7 @@ def _vector_recall( max_num: int = 3, cube_name: str | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform vector-based similarity retrieval using query embedding. @@ -210,6 +216,7 @@ def search_single(vec, filt=None): scope=memory_scope, cube_name=cube_name, search_filter=filt, + user_name=user_name, ) or [] ) @@ -255,7 +262,7 @@ def search_path_b(): unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name + list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name ) or [] ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index df154f23a..05db56f53 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -12,7 +12,6 @@ from memos.reranker.base import BaseReranker from memos.utils import timed -from .internet_retriever_factory import InternetRetrieverFactory from .reasoner import MemoryReasoner from .recall import GraphMemoryRetriever from .task_goal_parser import TaskGoalParser @@ -28,7 +27,7 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, - internet_retriever: InternetRetrieverFactory | None = None, + internet_retriever: None = None, moscube: bool = False, ): self.graph_store = graph_store @@ -54,6 +53,7 @@ def search( mode="fast", memory_type="All", search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. @@ -85,14 +85,22 @@ def search( logger.debug(f"[SEARCH] Received info dict: {info}") parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter + query, info, mode, search_filter=search_filter, user_name=user_name ) results = self._retrieve_paths( - query, parsed_goal, query_embedding, info, top_k, mode, memory_type, search_filter + query, + parsed_goal, + query_embedding, + info, + top_k, + mode, + memory_type, + search_filter, + user_name, ) deduped = self._deduplicate_results(results) final_results = self._sort_and_trim(deduped, top_k) - self._update_usage_history(final_results, info) + self._update_usage_history(final_results, info, user_name) logger.info(f"[SEARCH] Done. 
Total {len(final_results)} results.") res_results = "" @@ -104,7 +112,15 @@ def search( return final_results @timed - def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = None): + def _parse_task( + self, + query, + info, + mode, + top_k=5, + search_filter: dict | None = None, + user_name: str | None = None, + ): """Parse user query, do embedding search and create context""" context = [] query_embedding = None @@ -118,7 +134,7 @@ def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = N related_nodes = [ self.graph_store.get_node(n["id"]) for n in self.graph_store.search_by_embedding( - query_embedding, top_k=top_k, search_filter=search_filter + query_embedding, top_k=top_k, search_filter=search_filter, user_name=user_name ) ] memories = [] @@ -168,6 +184,7 @@ def _retrieve_paths( mode, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Run A/B/C retrieval paths in parallel""" tasks = [] @@ -181,6 +198,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -192,6 +210,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -204,6 +223,7 @@ def _retrieve_paths( info, mode, memory_type, + user_name, ) ) if self.moscube: @@ -235,6 +255,7 @@ def _retrieve_from_working_memory( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -246,6 +267,7 @@ def _retrieve_from_working_memory( top_k=top_k, memory_scope="WorkingMemory", search_filter=search_filter, + user_name=user_name, ) return self.reranker.rerank( query=query, @@ -266,6 +288,7 @@ def _retrieve_from_long_term_and_user( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -282,6 +305,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, + user_name=user_name, ) ) if memory_type in ["All", "UserMemory"]: @@ -294,6 +318,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, + user_name=user_name, ) ) @@ -320,6 +345,7 @@ def _retrieve_from_memcubes( top_k=top_k * 2, memory_scope="LongTermMemory", cube_name=cube_name, + user_name=cube_name, ) return self.reranker.rerank( query=query, @@ -332,7 +358,15 @@ def _retrieve_from_memcubes( # --- Path C @timed def _retrieve_from_internet( - self, query, parsed_goal, query_embedding, top_k, info, mode, memory_type + self, + query, + parsed_goal, + query_embedding, + top_k, + info, + mode, + memory_type, + user_id: str | None = None, ): """Retrieve and rerank from Internet source""" if not self.internet_retriever or mode == "fast": @@ -380,7 +414,7 @@ def _sort_and_trim(self, results, top_k): return final_items @timed - def _update_usage_history(self, items, info): + def _update_usage_history(self, items, info, user_name: str | None = None): """Update usage history in graph DB""" now_time = datetime.now().isoformat() info_copy = dict(info or {}) @@ -402,11 +436,15 @@ def _update_usage_history(self, items, info): logger.exception("[USAGE] snapshot item failed") if payload: - self._usage_executor.submit(self._update_usage_history_worker, payload, usage_record) + self._usage_executor.submit( + self._update_usage_history_worker, payload, usage_record, user_name + ) - def 
_update_usage_history_worker(self, payload, usage_record: str): + def _update_usage_history_worker( + self, payload, usage_record: str, user_name: str | None = None + ): try: for item_id, usage_list in payload: - self.graph_store.update_node(item_id, {"usage": usage_list}) + self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") diff --git a/src/memos/types.py b/src/memos/types.py index 60d5da8d2..635fabccc 100644 --- a/src/memos/types.py +++ b/src/memos/types.py @@ -56,3 +56,25 @@ class MOSSearchResult(TypedDict): text_mem: list[dict[str, str | list[TextualMemoryItem]]] act_mem: list[dict[str, str | list[ActivationMemoryItem]]] para_mem: list[dict[str, str | list[ParametricMemoryItem]]] + + +# ─── API Types ──────────────────────────────────────────────────────────────────── +# for API Permission +Permission: TypeAlias = Literal["read", "write", "delete", "execute"] + + +# Message structure +class PermissionDict(TypedDict, total=False): + """Typed dictionary for chat message dictionaries.""" + + permissions: list[Permission] + mem_cube_id: str + + +class UserContext(BaseModel): + """Model to represent user context.""" + + user_id: str | None = None + mem_cube_id: str | None = None + session_id: str | None = None + operation: list[PermissionDict] | None = None diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py new file mode 100644 index 000000000..7bb1ceeba --- /dev/null +++ b/src/memos/vec_dbs/milvus.py @@ -0,0 +1,367 @@ +from typing import Any + +from memos.configs.vec_db import MilvusVecDBConfig +from memos.dependency import require_python_package +from memos.log import get_logger +from memos.vec_dbs.base import BaseVecDB +from memos.vec_dbs.item import VecDBItem + + +logger = get_logger(__name__) + + +class MilvusVecDB(BaseVecDB): + """Milvus vector database implementation.""" + + @require_python_package( + import_name="pymilvus", + install_command="pip install -U pymilvus", + install_link="https://milvus.io/docs/install-pymilvus.md", + ) + def __init__(self, config: MilvusVecDBConfig): + """Initialize the Milvus vector database and the collection.""" + from pymilvus import MilvusClient + + self.config = config + + # Create Milvus client + self.client = MilvusClient( + uri=self.config.uri, user=self.config.user_name, password=self.config.password + ) + self.schema = self.create_schema() + self.index_params = self.create_index() + self.create_collection() + + def create_schema(self): + """Create schema for the milvus collection.""" + from pymilvus import DataType + + schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) + schema.add_field( + field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True + ) + schema.add_field( + field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension + ) + schema.add_field(field_name="payload", datatype=DataType.JSON) + + return schema + + def create_index(self): + """Create index for the milvus collection.""" + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", index_type="FLAT", metric_type=self._get_metric_type() + ) + + return index_params + + def create_collection(self) -> None: + """Create a new collection with specified parameters.""" + for collection_name in self.config.collection_name: + if self.collection_exists(collection_name): + logger.warning(f"Collection '{collection_name}' already exists. 
Skipping creation.") + continue + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + schema=self.schema, + index_params=self.index_params, + ) + + logger.info( + f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions." + ) + + def create_collection_by_name(self, collection_name: str) -> None: + """Create a new collection with specified parameters.""" + if self.collection_exists(collection_name): + logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.") + return + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + schema=self.schema, + index_params=self.index_params, + ) + + def list_collections(self) -> list[str]: + """List all collections.""" + return self.client.list_collections() + + def delete_collection(self, name: str) -> None: + """Delete a collection.""" + self.client.drop_collection(name) + + def collection_exists(self, name: str) -> bool: + """Check if a collection exists.""" + return self.client.has_collection(collection_name=name) + + def search( + self, + query_vector: list[float], + collection_name: str, + top_k: int, + filter: dict[str, Any] | None = None, + ) -> list[VecDBItem]: + """ + Search for similar items in the database. + + Args: + query_vector: Single vector to search + collection_name: Name of the collection to search + top_k: Number of results to return + filter: Payload filters + + Returns: + List of search results with distance scores and payloads. + """ + # Convert filter to Milvus expression + expr = self._dict_to_expr(filter) if filter else "" + + results = self.client.search( + collection_name=collection_name, + data=[query_vector], + limit=top_k, + filter=expr, + output_fields=["*"], # Return all fields + ) + + items = [] + for hit in results[0]: + entity = hit.get("entity", {}) + + items.append( + VecDBItem( + id=str(hit["id"]), + vector=entity.get("vector"), + payload=entity.get("payload", {}), + score=1 - float(hit["distance"]), + ) + ) + + logger.info(f"Milvus search completed with {len(items)} results.") + return items + + def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: + """Convert a dictionary filter to a Milvus expression string.""" + if not filter_dict: + return "" + + conditions = [] + for field, value in filter_dict.items(): + # Skip None values as they cause Milvus query syntax errors + if value is None: + continue + # For JSON fields, we need to use payload["field"] syntax + elif isinstance(value, str): + conditions.append(f"payload['{field}'] == '{value}'") + elif isinstance(value, list) and len(value) == 0: + # Skip empty lists as they cause Milvus query syntax errors + continue + elif isinstance(value, list) and len(value) > 0: + conditions.append(f"payload['{field}'] in {value}") + else: + conditions.append(f"payload['{field}'] == '{value}'") + return " and ".join(conditions) + + def _get_metric_type(self) -> str: + """Get the metric type for search.""" + metric_map = { + "cosine": "COSINE", + "euclidean": "L2", + "dot": "IP", + } + return metric_map.get(self.config.distance_metric, "L2") + + def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + """Get a single item by ID.""" + results = self.client.get( + collection_name=collection_name, + ids=[id], + ) + + if not results: + return None + + entity = results[0] + payload = {k: v for k, v in 
entity.items() if k not in ["id", "vector", "score"]} + + return VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + """Get multiple items by their IDs.""" + results = self.client.get( + collection_name=collection_name, + ids=ids, + ) + + if not results: + return [] + + items = [] + for entity in results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} + items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) + + return items + + def get_by_filter( + self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 + ) -> list[VecDBItem]: + """ + Retrieve all items that match the given filter criteria using query_iterator. + + Args: + filter: Payload filters to match against stored items + scroll_limit: Maximum number of items to retrieve per batch (batch_size) + + Returns: + List of items including vectors and payload that match the filter + """ + expr = self._dict_to_expr(filter) if filter else "" + all_items = [] + + # Use query_iterator for efficient pagination + iterator = self.client.query_iterator( + collection_name=collection_name, + filter=expr, + batch_size=scroll_limit, + output_fields=["*"], # Include all fields including payload + ) + + # Iterate through all batches + try: + while True: + batch_results = iterator.next() + + if not batch_results: + break + + # Convert batch results to VecDBItem objects + for entity in batch_results: + # Extract the actual payload from Milvus entity + payload = entity.get("payload", {}) + all_items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) + except Exception as e: + logger.warning( + f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." + ) + finally: + # Close the iterator + iterator.close() + + logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") + return all_items + + def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + """Retrieve all items in the vector database.""" + return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) + + def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> int: + """Count items in the database, optionally with filter.""" + if filter: + # If there's a filter, use query method + expr = self._dict_to_expr(filter) if filter else "" + results = self.client.query( + collection_name=collection_name, + filter=expr, + output_fields=["id"], + ) + return len(results) + else: + # For counting all items, use get_collection_stats for accurate count + stats = self.client.get_collection_stats(collection_name) + # Extract row count from stats - stats is a dict, not a list + return int(stats.get("row_count", 0)) + + def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add data to the vector database. 
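A small end-to-end sketch for the MilvusVecDB wrapper; the MilvusVecDBConfig construction assumes the config exposes exactly the fields the class reads (uri, user_name, password, collection_name, vector_dimension, distance_metric), and all values are illustrative.

from memos.configs.vec_db import MilvusVecDBConfig
from memos.vec_dbs.item import VecDBItem
from memos.vec_dbs.milvus import MilvusVecDB

config = MilvusVecDBConfig(
    uri="http://localhost:19530",
    user_name="root",
    password="",
    collection_name=["demo_collection"],
    vector_dimension=4,
    distance_metric="cosine",
)
db = MilvusVecDB(config)

# Upsert two items; payload fields become filterable through the JSON payload column.
db.add(
    "demo_collection",
    [
        VecDBItem(id="a", vector=[0.1, 0.2, 0.3, 0.4], payload={"user_id": "alice"}),
        VecDBItem(id="b", vector=[0.4, 0.3, 0.2, 0.1], payload={"user_id": "bob"}),
    ],
)

# The filter dict is translated into an expression like payload['user_id'] == 'alice'.
hits = db.search(
    query_vector=[0.1, 0.2, 0.3, 0.4],
    collection_name="demo_collection",
    top_k=1,
    filter={"user_id": "alice"},
)
for hit in hits:
    print(hit.id, hit.score, hit.payload)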
+ + Args: + data: List of VecDBItem objects or dictionaries containing: + - 'id': unique identifier + - 'vector': embedding vector + - 'payload': additional fields for filtering/retrieval + """ + entities = [] + for item in data: + if isinstance(item, dict): + item = item.copy() + item = VecDBItem.from_dict(item) + + # Prepare entity data + entity = { + "id": item.id, + "vector": item.vector, + "payload": item.payload if item.payload else {}, + } + + entities.append(entity) + + # Use upsert to be safe (insert or update) + self.client.upsert( + collection_name=collection_name, + data=entities, + ) + + def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None: + """Update an item in the vector database.""" + if isinstance(data, dict): + data = data.copy() + data = VecDBItem.from_dict(data) + + # Use upsert for updates + self.upsert(collection_name, [data]) + + def ensure_payload_indexes(self, fields: list[str]) -> None: + """ + Create payload indexes for specified fields in the collection. + This is idempotent: it will skip if index already exists. + + Args: + fields (list[str]): List of field names to index (as keyword). + """ + # Note: Milvus doesn't have the same concept of payload indexes as Qdrant + # Field indexes are created automatically for scalar fields + logger.info(f"Milvus automatically indexes scalar fields: {fields}") + + def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add or update data in the vector database. + + If an item with the same ID exists, it will be updated. + Otherwise, it will be added as a new item. + """ + # Reuse add method since it already uses upsert + self.add(collection_name, data) + + def delete(self, collection_name: str, ids: list[str]) -> None: + """Delete items from the vector database.""" + if not ids: + return + self.client.delete( + collection_name=collection_name, + ids=ids, + ) diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py new file mode 100644 index 000000000..f4d0d6b97 --- /dev/null +++ b/tests/mem_scheduler/test_dispatcher.py @@ -0,0 +1,295 @@ +import sys +import time +import unittest + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from memos.configs.mem_scheduler import ( + AuthConfig, + GraphDBAuthConfig, + OpenAIConfig, + RabbitMQConfig, + SchedulerConfigFactory, +) +from memos.llms.base import BaseLLM +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.memories.textual.tree import TreeTextMemory + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +class TestSchedulerDispatcher(unittest.TestCase): + """Test cases for the SchedulerDispatcher class.""" + + def _create_mock_auth_config(self): + """Create a mock AuthConfig for testing purposes.""" + # Create mock configs with valid test values + graph_db_config = GraphDBAuthConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test_password_123", # 8+ characters to pass validation + db_name="neo4j", + auto_create=True, + ) + + rabbitmq_config = RabbitMQConfig( + host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/" + ) + + openai_config = 
OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo") + + return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config) + + def setUp(self): + """Initialize test environment with mock objects.""" + example_scheduler_config_path = ( + f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml" + ) + scheduler_config = SchedulerConfigFactory.from_yaml_file( + yaml_path=example_scheduler_config_path + ) + mem_scheduler = SchedulerFactory.from_config(scheduler_config) + self.scheduler = mem_scheduler + self.llm = MagicMock(spec=BaseLLM) + self.mem_cube = MagicMock(spec=GeneralMemCube) + self.tree_text_memory = MagicMock(spec=TreeTextMemory) + self.mem_cube.text_mem = self.tree_text_memory + self.mem_cube.act_mem = MagicMock() + + # Mock AuthConfig.from_local_env() to return our test config + mock_auth_config = self._create_mock_auth_config() + self.auth_config_patch = patch( + "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config + ) + self.auth_config_patch.start() + + # Initialize general_modules with mock LLM + self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm) + self.scheduler.mem_cube = self.mem_cube + + self.dispatcher = self.scheduler.dispatcher + + # Create mock handlers + self.mock_handler1 = MagicMock() + self.mock_handler2 = MagicMock() + + # Register mock handlers + self.dispatcher.register_handler("label1", self.mock_handler1) + self.dispatcher.register_handler("label2", self.mock_handler2) + + # Create test messages + self.test_messages = [ + ScheduleMessageItem( + item_id="msg1", + user_id="user1", + mem_cube="cube1", + mem_cube_id="msg1", + label="label1", + content="Test content 1", + timestamp=123456789, + ), + ScheduleMessageItem( + item_id="msg2", + user_id="user1", + mem_cube="cube1", + mem_cube_id="msg2", + label="label2", + content="Test content 2", + timestamp=123456790, + ), + ScheduleMessageItem( + item_id="msg3", + user_id="user2", + mem_cube="cube2", + mem_cube_id="msg3", + label="label1", + content="Test content 3", + timestamp=123456791, + ), + ] + + # Mock logging to verify messages + self.logging_warning_patch = patch("logging.warning") + self.mock_logging_warning = self.logging_warning_patch.start() + + # Mock the MemoryFilter logger since that's where the actual logging happens + self.logger_info_patch = patch( + "memos.mem_scheduler.memory_manage_modules.memory_filter.logger.info" + ) + self.mock_logger_info = self.logger_info_patch.start() + + def tearDown(self): + """Clean up patches.""" + self.logging_warning_patch.stop() + self.logger_info_patch.stop() + self.auth_config_patch.stop() + + def test_register_handler(self): + """Test registering a single handler.""" + new_handler = MagicMock() + self.dispatcher.register_handler("new_label", new_handler) + + # Verify handler was registered + self.assertIn("new_label", self.dispatcher.handlers) + self.assertEqual(self.dispatcher.handlers["new_label"], new_handler) + + def test_register_handlers(self): + """Test bulk registration of handlers.""" + new_handlers = { + "bulk1": MagicMock(), + "bulk2": MagicMock(), + } + + self.dispatcher.register_handlers(new_handlers) + + # Verify all handlers were registered + for label, handler in new_handlers.items(): + self.assertIn(label, self.dispatcher.handlers) + self.assertEqual(self.dispatcher.handlers[label], handler) + + def test_dispatch_serial(self): + """Test dispatching messages in serial mode.""" + # Create a new dispatcher with parallel 
dispatch disabled + serial_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=False) + serial_dispatcher.register_handler("label1", self.mock_handler1) + serial_dispatcher.register_handler("label2", self.mock_handler2) + + # Dispatch messages + serial_dispatcher.dispatch(self.test_messages) + + # Verify handlers were called with the correct messages + self.mock_handler1.assert_called_once() + self.mock_handler2.assert_called_once() + + # Check that each handler received the correct messages + label1_messages = [msg for msg in self.test_messages if msg.label == "label1"] + label2_messages = [msg for msg in self.test_messages if msg.label == "label2"] + + # The first argument of the first call + self.assertEqual(self.mock_handler1.call_args[0][0], label1_messages) + self.assertEqual(self.mock_handler2.call_args[0][0], label2_messages) + + def test_dispatch_parallel(self): + """Test dispatching messages in parallel mode.""" + # Dispatch messages + self.dispatcher.dispatch(self.test_messages) + + # Wait for all futures to complete + self.dispatcher.join(timeout=1.0) + + # Verify handlers were called + self.mock_handler1.assert_called_once() + self.mock_handler2.assert_called_once() + + # Check that each handler received the correct messages + label1_messages = [msg for msg in self.test_messages if msg.label == "label1"] + label2_messages = [msg for msg in self.test_messages if msg.label == "label2"] + + # The first argument of the first call + self.assertEqual(self.mock_handler1.call_args[0][0], label1_messages) + self.assertEqual(self.mock_handler2.call_args[0][0], label2_messages) + + def test_group_messages_by_user_and_cube(self): + """Test grouping messages by user and cube.""" + # Check actual grouping logic + with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): + result = self.dispatcher.group_messages_by_user_and_cube(self.test_messages) + + # Adjust expected results based on actual grouping logic + # Note: According to dispatcher.py implementation, grouping is by mem_cube_id not mem_cube + expected = { + "user1": { + "msg1": [self.test_messages[0]], + "msg2": [self.test_messages[1]], + }, + "user2": { + "msg3": [self.test_messages[2]], + }, + } + + # Use more flexible assertion method + self.assertEqual(set(result.keys()), set(expected.keys())) + for user_id in expected: + self.assertEqual(set(result[user_id].keys()), set(expected[user_id].keys())) + for cube_id in expected[user_id]: + self.assertEqual(len(result[user_id][cube_id]), len(expected[user_id][cube_id])) + # Check if each message exists + for msg in expected[user_id][cube_id]: + self.assertIn(msg.item_id, [m.item_id for m in result[user_id][cube_id]]) + + def test_thread_race(self): + """Test the ThreadRace integration.""" + + # Define test tasks + def task1(stop_flag): + time.sleep(0.1) + return "result1" + + def task2(stop_flag): + time.sleep(0.2) + return "result2" + + # Run competitive tasks + tasks = { + "task1": task1, + "task2": task2, + } + + result = self.dispatcher.run_competitive_tasks(tasks, timeout=1.0) + + # Verify the result + self.assertIsNotNone(result) + self.assertEqual(result[0], "task1") # task1 should win + self.assertEqual(result[1], "result1") + + def test_thread_race_timeout(self): + """Test ThreadRace with timeout.""" + + # Define a task that takes longer than the timeout + def slow_task(stop_flag): + time.sleep(0.5) + return "slow_result" + + tasks = {"slow": slow_task} + + # Run with a short timeout + result = self.dispatcher.run_competitive_tasks(tasks, 
timeout=0.1) + + # Verify no result was returned due to timeout + self.assertIsNone(result) + + def test_thread_race_cooperative_termination(self): + """Test that ThreadRace properly terminates slower threads when one completes.""" + + # Create a fast task and a slow task + def fast_task(stop_flag): + return "fast result" + + def slow_task(stop_flag): + # Check stop flag to ensure proper response + for _ in range(10): + if stop_flag.is_set(): + return "stopped early" + time.sleep(0.1) + return "slow result" + + # Run competitive tasks with increased timeout for test stability + result = self.dispatcher.run_competitive_tasks( + {"fast": fast_task, "slow": slow_task}, + timeout=2.0, # Increased timeout + ) + + # Verify the result is from the fast task + self.assertIsNotNone(result) + self.assertEqual(result[0], "fast") + self.assertEqual(result[1], "fast result") + + # Allow enough time for thread cleanup + time.sleep(0.5) diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index c9f42ec38..d99664817 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -73,7 +73,7 @@ def test_searcher_fast_path(mock_searcher): for item in result: assert len(item.metadata.usage) > 0 mock_searcher.graph_store.update_node.assert_any_call( - item.id, {"usage": item.metadata.usage} + item.id, {"usage": item.metadata.usage}, user_name=None )
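A condensed sketch of the run_competitive_tasks contract exercised by the dispatcher tests above: each task receives a stop-flag event, the first task to finish wins, and the result comes back as a (task_name, result) tuple or None on timeout. The dispatcher is constructed directly here for brevity; in the scheduler it is normally obtained as scheduler.dispatcher.

import time

from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher

dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=False)


def fast(stop_flag):
    return "fast result"


def slow(stop_flag):
    # Cooperative tasks poll stop_flag so they can exit once another task has won.
    for _ in range(10):
        if stop_flag.is_set():
            return "stopped early"
        time.sleep(0.1)
    return "slow result"


winner = dispatcher.run_competitive_tasks({"fast": fast, "slow": slow}, timeout=2.0)
if winner is not None:
    name, result = winner
    print(f"winner: {name} -> {result}")
else:
    print("no task finished before the timeout")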