From 87c37fbb13ee75c9166e1926ffccb0bf434fa733 Mon Sep 17 00:00:00 2001
From: Kai
Date: Thu, 11 Sep 2025 14:09:16 +0800
Subject: [PATCH 01/29] fix: #(268) https://github.com/MemTensor/MemOS/issues/286

---
 pyproject.toml | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 270fd712c..7bc02af50 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,11 @@ mem-scheduler = [
     "pika (>=1.3.2,<2.0.0)",  # RabbitMQ client
 ]
 
+# MemUser (MySQL support)
+mem-user = [
+    "pymysql (>=1.1.0,<2.0.0)",  # MySQL client for SQLAlchemy
+]
+
 # MemReader
 mem-reader = [
     "chonkie (>=1.0.7,<2.0.0)",  # Sentence chunking library
@@ -90,6 +95,7 @@ all = [
     "schedule (>=1.2.2,<2.0.0)",
     "redis (>=6.2.0,<7.0.0)",
     "pika (>=1.3.2,<2.0.0)",
+    "pymysql (>=1.1.0,<2.0.0)",
     "chonkie (>=1.0.7,<2.0.0)",
     "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)",

From a6a55584b82cdb08f5e743e0a5dbaeab397bceb3 Mon Sep 17 00:00:00 2001
From: Kai
Date: Thu, 11 Sep 2025 14:16:16 +0800
Subject: [PATCH 02/29] Add pymysql dependency for MySQL user management

---
 poetry.lock    | 23 ++++++++++++++++++++---
 pyproject.toml |  4 ++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index c6b6a0ebf..2517d0b94 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -3773,6 +3773,22 @@ files = [
 [package.extras]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pymysql"
+version = "1.1.2"
+description = "Pure Python MySQL Driver"
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "mem-user"]
+files = [
+    {file = "pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9"},
+    {file = "pymysql-1.1.2.tar.gz", hash = "sha256:4961d3e165614ae65014e361811a724e2044ad3ea3739de9903ae7c21f539f03"},
+]
+
+[package.extras]
+ed25519 = ["PyNaCl (>=1.4.0)"]
+rsa = ["cryptography"]
+
 [[package]]
 name = "pyparsing"
 version = "3.2.3"
@@ -6285,12 +6301,13 @@
 cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""}
 
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["chonkie", "markitdown", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
+all = ["chonkie", "markitdown", "neo4j", "pika", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
 mem-reader = ["chonkie", "markitdown"]
 mem-scheduler = ["pika", "redis"]
+mem-user = ["pymysql"]
 tree-mem = ["neo4j", "schedule"]
 
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "94a3c4f97f0deda4c6ccbfd8ceda194f18dbc7525aa49004ffcc7846a1c40f7e"
+content-hash = "505ab4e6784d0191c3f177fdfc1335038d80c3b03b3a711bcdd954ef89afad42"

diff --git a/pyproject.toml b/pyproject.toml
index 7bc02af50..e2d2e4ffd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -164,6 +164,10 @@
 python-dotenv = "^1.1.1"
 langgraph = "^0.5.1"
 langmem = "^0.0.27"
 
+
+[tool.poetry.group.mem-user.dependencies]
+pymysql = "^1.1.2"
+
 [[tool.poetry.source]]
 name = "mirrors"
 url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/"
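The two patches above make pymysql an opt-in extra rather than a hard dependency. A common guard for such an optional dependency, shown here as a hedged sketch (the actual import site in MemOS may differ, and `MemoryOS` is assumed to be the published distribution name):

    # Sketch: fail fast with an actionable message when the extra is missing.
    try:
        import pymysql  # provided by: pip install "MemoryOS[mem-user]"
    except ImportError as err:
        raise ImportError(
            "MySQL user management requires the mem-user extra: "
            "pip install 'MemoryOS[mem-user]'"
        ) from err

    # Optional: let plain mysql:// SQLAlchemy URLs resolve to PyMySQL.
    pymysql.install_as_MySQLdb()
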
From 4bb4b5c51e678f2e6fd9a9fddc3a79d6cc152b42 Mon Sep 17 00:00:00 2001
From: chunyu li <78344051+fridayL@users.noreply.github.com>
Date: Wed, 24 Sep 2025 20:36:58 +0800
Subject: [PATCH 03/29] add: change default pre_load (#338)

* add: change default pre_load

* fix: code

---------

Co-authored-by: CaralHsi

---
 src/memos/api/product_api.py                        | 2 +-
 src/memos/mem_os/product.py                         | 4 ++--
 src/memos/mem_user/mysql_persistent_user_manager.py | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py
index 681644a0d..709ad74fb 100644
--- a/src/memos/api/product_api.py
+++ b/src/memos/api/product_api.py
@@ -33,6 +33,6 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--workers", type=int, default=32)
+    parser.add_argument("--workers", type=int, default=1)
     args = parser.parse_args()
     uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)

diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index a4ab4ef20..d64643897 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -179,14 +179,14 @@ def _restore_user_instances(
         """
         try:
             # Get all user configurations from persistent storage
-            user_configs = self.user_manager.list_user_configs()
+            user_configs = self.user_manager.list_user_configs(self.max_user_instances)
 
             # Get the raw database records for sorting by updated_at
             session = self.user_manager._get_session()
             try:
                 from memos.mem_user.persistent_user_manager import UserConfig
 
-                db_configs = session.query(UserConfig).all()
+                db_configs = session.query(UserConfig).limit(self.max_user_instances).all()
 
                 # Create a mapping of user_id to updated_at timestamp
                 updated_at_map = {config.user_id: config.updated_at for config in db_configs}

diff --git a/src/memos/mem_user/mysql_persistent_user_manager.py b/src/memos/mem_user/mysql_persistent_user_manager.py
index f8983c87c..99e49d206 100644
--- a/src/memos/mem_user/mysql_persistent_user_manager.py
+++ b/src/memos/mem_user/mysql_persistent_user_manager.py
@@ -188,7 +188,7 @@ def delete_user_config(self, user_id: str) -> bool:
         finally:
             session.close()
 
-    def list_user_configs(self) -> dict[str, MOSConfig]:
+    def list_user_configs(self, limit: int = 1) -> dict[str, MOSConfig]:
        """List all user configurations.
        Returns:
@@ -196,7 +196,7 @@
        """
        session = self._get_session()
        try:
-            user_configs = session.query(UserConfig).all()
+            user_configs = session.query(UserConfig).limit(limit).all()
 
             result = {}
             for user_config in user_configs:
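Two defaults change in the patch above: the API now starts a single uvicorn worker, and only a bounded number of user configurations is restored at startup. Presumably this is because MOSProduct keeps per-user cube instances and chat histories in process memory, where several workers would each hold a different partial copy of that state; the commit message itself gives no rationale. A minimal launch sketch using the same uvicorn call as the patched `product_api.py`:

    # Hedged sketch: explicit single-worker launch. Scale workers only after
    # moving per-user state out of process (e.g. into MySQL/Redis).
    import uvicorn

    uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=8001, workers=1)
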
""" now = datetime.now() formatted_date = now.strftime("%Y-%m-%d (%A)") @@ -916,17 +954,29 @@ def chat( internet_search=internet_search, moscube=moscube, )["text_mem"] + memories_list = [] if memories_result: memories_list = memories_result[0]["memories"] - memories_list = self._filter_memories_by_threshold(memories_list, threshold) - system_prompt = super()._build_system_prompt(memories_list, base_prompt) + memories_list = self._filter_memories_by_threshold( + memories_list, threshold) + + # Build base system prompt without memory + system_prompt = self._build_base_system_prompt(base_prompt, + mode="base") + + # Build memory context to be included in user message + memory_context = self._build_memory_context(memories_list, mode="base") + + # Combine memory context with user query + user_content = memory_context + query if memory_context else query + history_info = [] if history: history_info = history[-20:] current_messages = [ {"role": "system", "content": system_prompt}, *history_info, - {"role": "user", "content": query}, + {"role": "user", "content": user_content}, ] response = self.chat_llm.generate(current_messages) time_end = time.time() @@ -994,8 +1044,17 @@ def chat_with_references( reference = prepare_reference_data(memories_list) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - # Build custom system prompt with relevant memories) - system_prompt = self._build_enhance_system_prompt(user_id, memories_list) + + # Build base system prompt without memory + system_prompt = self._build_base_system_prompt(mode="enhance") + + # Build memory context to be included in user message + memory_context = self._build_memory_context(memories_list, + mode="enhance") + + # Combine memory context with user query + user_content = memory_context + query if memory_context else query + # Get chat history if user_id not in self.chat_history_manager: self._register_chat_history(user_id) @@ -1006,7 +1065,7 @@ def chat_with_references( current_messages = [ {"role": "system", "content": system_prompt}, *chat_history.chat_history, - {"role": "user", "content": query}, + {"role": "user", "content": user_content}, ] logger.info( f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}" From 3734b26ff8294fe07cf9f98d233c351133cdba25 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Fri, 26 Sep 2025 10:59:31 +0800 Subject: [PATCH 05/29] Feat: update load cubes (#350) * feat: update laod cubes * fix: code format --- src/memos/api/client.py | 1 - src/memos/mem_os/product.py | 38 ++++++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/memos/api/client.py b/src/memos/api/client.py index 5e7947ff5..d45276f2c 100644 --- a/src/memos/api/client.py +++ b/src/memos/api/client.py @@ -14,7 +14,6 @@ MAX_RETRY_COUNT = 3 - class MemOSClient: """MemOS API client""" diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index d64643897..65942346f 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -217,6 +217,26 @@ def _restore_user_instances( except Exception as e: logger.error(f"Error during user instance restoration: {e}") + def _initialize_cube_from_default_config( + self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig + ) -> GeneralMemCube | None: + """ + Initialize a cube from default configuration when cube path doesn't exist. + + Args: + cube_id (str): The cube ID to initialize. + user_id (str): The user ID for the cube. 
From 3734b26ff8294fe07cf9f98d233c351133cdba25 Mon Sep 17 00:00:00 2001
From: chunyu li <78344051+fridayL@users.noreply.github.com>
Date: Fri, 26 Sep 2025 10:59:31 +0800
Subject: [PATCH 05/29] Feat: update load cubes (#350)

* feat: update load cubes

* fix: code format

---
 src/memos/api/client.py     |  1 -
 src/memos/mem_os/product.py | 38 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/src/memos/api/client.py b/src/memos/api/client.py
index 5e7947ff5..d45276f2c 100644
--- a/src/memos/api/client.py
+++ b/src/memos/api/client.py
@@ -14,7 +14,6 @@
 
 MAX_RETRY_COUNT = 3
 
-
 class MemOSClient:
     """MemOS API client"""
 
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index d64643897..65942346f 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -217,6 +217,26 @@
         except Exception as e:
             logger.error(f"Error during user instance restoration: {e}")
 
+    def _initialize_cube_from_default_config(
+        self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig
+    ) -> GeneralMemCube | None:
+        """
+        Initialize a cube from default configuration when cube path doesn't exist.
+
+        Args:
+            cube_id (str): The cube ID to initialize.
+            user_id (str): The user ID for the cube.
+            default_config (GeneralMemCubeConfig): The default configuration to use.
+        """
+        cube_config = default_config.model_copy(deep=True)
+        # Safely modify the graph_db user_name if it exists
+        if cube_config.text_mem.config.graph_db.config:
+            cube_config.text_mem.config.graph_db.config.user_name = (
+                f"memos{user_id.replace('-', '')}"
+            )
+        mem_cube = GeneralMemCube(config=cube_config)
+        return mem_cube
+
     def _preload_user_cubes(
         self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
     ) -> None:
@@ -286,8 +306,24 @@ def _load_user_cubes(
                     )
                 else:
                     logger.warning(
-                        f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}"
+                        f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config"
                     )
+                    cube_obj = self._initialize_cube_from_default_config(
+                        cube_id=cube.cube_id,
+                        user_id=user_id,
+                        default_config=default_cube_config,
+                    )
+                    if cube_obj:
+                        self.register_mem_cube(
+                            cube_obj,
+                            cube.cube_id,
+                            user_id,
+                            memory_types=[],
+                        )
+                    else:
+                        raise ValueError(
+                            f"Failed to initialize default cube {cube.cube_id} for user {user_id}"
+                        )
             except Exception as e:
                 logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}")
         logger.info(f"load user {user_id} cubes successfully")

From 7aafbd0a772c0321ae8b465c619d349dd0842287 Mon Sep 17 00:00:00 2001
From: Kai
Date: Fri, 26 Sep 2025 12:06:10 +0800
Subject: [PATCH 06/29] ruff format

---
 src/memos/memories/activation/kv.py                       | 3 ++-
 .../textual/tree_text_memory/organize/handler.py          | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py
index 06cef794f..2fa08590f 100644
--- a/src/memos/memories/activation/kv.py
+++ b/src/memos/memories/activation/kv.py
@@ -1,9 +1,10 @@
 import os
 import pickle
+
 from datetime import datetime
 from importlib.metadata import version
 
-from packaging.version import Version
+from packaging.version import Version
 from transformers import DynamicCache
 
 from memos.configs.memory import KVCacheMemoryConfig

diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py
index a1121fcd2..271902ca0 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/handler.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py
@@ -1,5 +1,6 @@
 import json
 import re
+
 from datetime import datetime
 
 from dateutil import parser
@@ -14,6 +15,7 @@
     MEMORY_RELATION_RESOLVER_PROMPT,
 )
 
+
 logger = get_logger(__name__)

From 04bc4fbe7d71d9e031f4a2b42502a7299b2d29eb Mon Sep 17 00:00:00 2001
From: Kai
Date: Fri, 26 Sep 2025 13:00:01 +0800
Subject: [PATCH 07/29] feat: reorganize prompt with reference in user content
 - reformat

---
 src/memos/mem_os/product.py | 35 ++++++++++---------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index eb5b3a12f..6f8e8b1c1 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -417,9 +417,12 @@ def _build_system_prompt(
         mem_block_o, mem_block_p = _format_mem_block(memories_all)
         mem_block = mem_block_o + "\n" + mem_block_p
         prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
-        return (prefix + sys_body
-                + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n"
-                + mem_block)
+        return (
+            prefix
+            + sys_body
+            + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n"
+            + mem_block
+        )
 
     def _build_base_system_prompt(
         self,
@@ -433,10 +436,7 @@ def _build_base_system_prompt(
         """
         now = datetime.now()
         formatted_date = now.strftime("%Y-%m-%d (%A)")
-        sys_body = get_memos_prompt(date=formatted_date,
-                                    tone=tone,
-                                    verbosity=verbosity,
-                                    mode=mode)
+        sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode)
         prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
         return prefix + sys_body
 
@@ -454,12 +454,16 @@ def _build_memory_context(
         mem_block_o, mem_block_p = _format_mem_block(memories_all)
 
         if mode == "enhance":
-            return ("# Memories\n## PersonalMemory (ordered)\n" + mem_block_p
-                    + "\n## OuterMemory (ordered)\n" + mem_block_o + "\n\n")
+            return (
+                "# Memories\n## PersonalMemory (ordered)\n"
+                + mem_block_p
+                + "\n## OuterMemory (ordered)\n"
+                + mem_block_o
+                + "\n\n"
+            )
         else:
             mem_block = mem_block_o + "\n" + mem_block_p
-            return ("# Memories\n## PersonalMemory & OuterMemory (ordered)\n"
-                    + mem_block + "\n\n")
+            return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n"
 
     def _build_enhance_system_prompt(
         self,
@@ -981,16 +985,14 @@ def chat(
             memories_list = []
             if memories_result:
                 memories_list = memories_result[0]["memories"]
-                memories_list = self._filter_memories_by_threshold(
-                    memories_list, threshold)
+                memories_list = self._filter_memories_by_threshold(memories_list, threshold)
                 new_memories_list = []
                 for m in memories_list:
                     m.metadata.embedding = []
                     new_memories_list.append(m)
                 memories_list = new_memories_list
             # Build base system prompt without memory
-            system_prompt = self._build_base_system_prompt(base_prompt,
-                                                           mode="base")
+            system_prompt = self._build_base_system_prompt(base_prompt, mode="base")
 
             # Build memory context to be included in user message
             memory_context = self._build_memory_context(memories_list, mode="base")
@@ -1077,8 +1079,7 @@ def chat_with_references(
             system_prompt = self._build_base_system_prompt(mode="enhance")
 
             # Build memory context to be included in user message
-            memory_context = self._build_memory_context(memories_list,
-                                                        mode="enhance")
+            memory_context = self._build_memory_context(memories_list, mode="enhance")
 
             # Combine memory context with user query
             user_content = memory_context + query if memory_context else query

From 4cca56a5649ff2b850a5a873dcec3a9d2d04569a Mon Sep 17 00:00:00 2001
From: chentang
Date: Fri, 26 Sep 2025 15:20:21 +0800
Subject: [PATCH 08/29] fix bugs to support eval answer hit with chat history
 only

---
 .../scripts/temporal_locomo/locomo_eval.py    | 148 ++++++++--
 .../temporal_locomo/locomo_processor.py       | 276 ++++++++++--------
 .../modules/base_eval_module.py               |  19 +-
 .../modules/locomo_eval_module.py             |   4 +
 .../temporal_locomo/modules/schemas.py        |  32 +-
 .../temporal_locomo/temporal_locomo_eval.py   |  35 +--
 6 files changed, 336 insertions(+), 178 deletions(-)

diff --git a/evaluation/scripts/temporal_locomo/locomo_eval.py b/evaluation/scripts/temporal_locomo/locomo_eval.py
index f19e5b68f..62ed209b6 100644
--- a/evaluation/scripts/temporal_locomo/locomo_eval.py
+++ b/evaluation/scripts/temporal_locomo/locomo_eval.py
@@ -281,33 +281,64 @@ def __init__(self, args):
             api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")
         )
 
-    async def run(self):
-        print(
-            f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ==="
-        )
-        print(f"Using {self.max_workers} concurrent workers for processing groups")
+    def _load_response_data(self):
+        """
+        Load response data from the response path file.
+ Returns: + dict: The loaded response data + """ with open(self.response_path) as file: - locomo_responses = json.load(file) + return json.load(file) - num_users = 10 + def _load_existing_evaluation_results(self): + """ + Attempt to load existing evaluation results from the judged path. + If the file doesn't exist or there's an error loading it, return an empty dict. + + Returns: + dict: Existing evaluation results or empty dict if none available + """ all_grades = {} + try: + if os.path.exists(self.judged_path): + with open(self.judged_path) as f: + all_grades = json.load(f) + print(f"Loaded existing evaluation results from {self.judged_path}") + except Exception as e: + print(f"Error loading existing evaluation results: {e}") - total_responses_count = sum( - len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) - ) - print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") + return all_grades + + def _create_evaluation_tasks(self, locomo_responses, all_grades, num_users): + """ + Create evaluation tasks for groups that haven't been evaluated yet. + + Args: + locomo_responses (dict): The loaded response data + all_grades (dict): Existing evaluation results + num_users (int): Number of user groups to process - # Create tasks for processing each group + Returns: + tuple: (tasks list, active users count) + """ tasks = [] active_users = 0 + for group_idx in range(num_users): group_id = f"locomo_exp_user_{group_idx}" group_responses = locomo_responses.get(group_id, []) + if not group_responses: print(f"No responses found for group {group_id}") continue + # Skip groups that already have evaluation results + if all_grades.get(group_id): + print(f"Skipping group {group_id} as it already has evaluation results") + active_users += 1 + continue + active_users += 1 tasks.append( process_single_group( @@ -319,29 +350,50 @@ async def run(self): ) ) - print(f"Starting evaluation of {active_users} user groups with responses") + return tasks, active_users + + async def _process_tasks(self, tasks): + """ + Process evaluation tasks with concurrency control. + + Args: + tasks (list): List of tasks to process + + Returns: + list: Results from processing all tasks + """ + if not tasks: + return [] semaphore = asyncio.Semaphore(self.max_workers) async def limited_task(task): + """Helper function to limit concurrent task execution""" async with semaphore: return await task limited_tasks = [limited_task(task) for task in tasks] - group_results = await asyncio.gather(*limited_tasks) + return await asyncio.gather(*limited_tasks) - for group_id, graded_responses in group_results: - all_grades[group_id] = graded_responses + def _calculate_scores(self, all_grades): + """ + Calculate evaluation scores based on all grades. - print("\n=== Evaluation Complete: Calculating final scores ===") + Args: + all_grades (dict): The complete evaluation results + Returns: + tuple: (run_scores, evaluated_count) + """ run_scores = [] evaluated_count = 0 + if self.num_runs > 0: for i in range(1, self.num_runs + 1): judgment_key = f"judgment_{i}" current_run_correct_count = 0 current_run_total_count = 0 + for group in all_grades.values(): for response in group: if judgment_key in response["llm_judgments"]: @@ -355,6 +407,16 @@ async def limited_task(task): evaluated_count = current_run_total_count + return run_scores, evaluated_count + + def _report_scores(self, run_scores, evaluated_count): + """ + Report evaluation scores to the console. 
+ + Args: + run_scores (list): List of accuracy scores for each run + evaluated_count (int): Number of evaluated responses + """ if evaluated_count > 0: mean_of_scores = np.mean(run_scores) std_of_scores = np.std(run_scores) @@ -368,11 +430,63 @@ async def limited_task(task): print("No responses were evaluated") print("LLM-as-a-Judge score: N/A (0/0)") + def _save_results(self, all_grades): + """ + Save evaluation results to the judged path file. + + Args: + all_grades (dict): The complete evaluation results to save + """ all_grades = convert_numpy_types(all_grades) with open(self.judged_path, "w") as f: json.dump(all_grades, f, indent=2) print(f"Saved detailed evaluation results to {self.judged_path}") + async def run(self): + """ + Main execution method for the LoCoMo evaluation process. + This method orchestrates the entire evaluation workflow: + 1. Loads existing evaluation results if available + 2. Processes only groups that haven't been evaluated yet + 3. Calculates and reports final evaluation scores + """ + print( + f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ===" + ) + print(f"Using {self.max_workers} concurrent workers for processing groups") + + # Load response data and existing evaluation results + locomo_responses = self._load_response_data() + all_grades = self._load_existing_evaluation_results() + + # Count total responses for reporting + num_users = 10 + total_responses_count = sum( + len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) + ) + print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") + + # Create tasks only for groups that haven't been evaluated yet + tasks, active_users = self._create_evaluation_tasks(locomo_responses, all_grades, num_users) + print( + f"Starting evaluation of {len(tasks)} user groups with responses (out of {active_users} active users)" + ) + + # Process tasks and update all_grades with results + if tasks: + group_results = await self._process_tasks(tasks) + for group_id, graded_responses in group_results: + all_grades[group_id] = graded_responses + + print("\n=== Evaluation Complete: Calculating final scores ===") + + # Calculate and report scores + run_scores, evaluated_count = self._calculate_scores(all_grades) + self._report_scores(run_scores, evaluated_count) + + # Save results + self._save_results(all_grades) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/evaluation/scripts/temporal_locomo/locomo_processor.py b/evaluation/scripts/temporal_locomo/locomo_processor.py index 4ae9cf915..3fd1ca59c 100644 --- a/evaluation/scripts/temporal_locomo/locomo_processor.py +++ b/evaluation/scripts/temporal_locomo/locomo_processor.py @@ -8,7 +8,6 @@ from dotenv import load_dotenv from modules.constants import ( - MEMOS_MODEL, MEMOS_SCHEDULER_MODEL, ) from modules.locomo_eval_module import LocomoEvalModelModules @@ -54,77 +53,22 @@ def __init__(self, args): self.processed_data_dir = self.result_dir / "processed_data" def update_context(self, conv_id, method, **kwargs): - if method == ContextUpdateMethod.DIRECT: + if method == ContextUpdateMethod.CHAT_HISTORY: + if "query" not in kwargs or "answer" not in kwargs: + raise ValueError("query and answer are required for TEMPLATE update method") + new_context = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n" + if self.pre_context_cache[conv_id] is None: + self.pre_context_cache[conv_id] = "" + self.pre_context_cache[conv_id] += 
new_context + else: if "cur_context" not in kwargs: raise ValueError("cur_context is required for DIRECT update method") cur_context = kwargs["cur_context"] self.pre_context_cache[conv_id] = cur_context - elif method == ContextUpdateMethod.TEMPLATE: - if "query" not in kwargs or "answer" not in kwargs: - raise ValueError("query and answer are required for TEMPLATE update method") - self._update_context_template(conv_id, kwargs["query"], kwargs["answer"]) - else: - raise ValueError(f"Unsupported update method: {method}") - - def _update_context_template(self, conv_id, query, answer): - new_context = f"User: {query}\nAssistant: {answer}\n\n" - if self.pre_context_cache[conv_id] is None: - self.pre_context_cache[conv_id] = "" - self.pre_context_cache[conv_id] += new_context - - def _process_single_qa( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - - # Search - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - # Context answerability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=gold_answer, - ) - return None - - can_answer = False - can_answer_duration_ms = 0.0 + def eval_context(self, context, query, gold_answer, oai_client): can_answer_start = time() - can_answer = self.analyze_context_answerability( - self.pre_context_cache[conv_id], query, gold_answer, oai_client - ) + can_answer = self.analyze_context_answerability(context, query, gold_answer, oai_client) can_answer_duration_ms = (time() - can_answer_start) * 1000 # Update global stats with self.stats_lock: @@ -143,54 +87,41 @@ def _process_single_qa( can_answer_duration_ms ) self.save_stats() + return can_answer, can_answer_duration_ms - # Generate answer - answer_start = time() - answer = self.locomo_response(frame, oai_client, self.pre_context_cache[conv_id], query) - response_duration_ms = (time() - answer_start) * 1000 - - # Record case for memos_scheduler - if frame == "memos_scheduler": - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - golden_answer=str(qa.get("answer", "")), - memories=[], - pre_memories=[], - history_queries=[], - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - 
print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - + def _update_stats_and_context( + self, + *, + conv_id, + frame, + version, + conv_stats, + conv_stats_path, + query, + answer, + gold_answer, + cur_context, + can_answer, + ): + """ + Update conversation statistics and context. + + Args: + conv_id: Conversation ID + frame: Model frame + version: Model version + conv_stats: Conversation statistics dictionary + conv_stats_path: Path to save conversation statistics + query: User query + answer: Generated answer + gold_answer: Golden answer + cur_context: Current context + can_answer: Whether the context can answer the query + """ # Update conversation stats conv_stats["total_queries"] += 1 conv_stats["response_count"] += 1 - if frame == "memos_scheduler": + if frame == MEMOS_SCHEDULER_MODEL: if can_answer: conv_stats["can_answer_count"] += 1 else: @@ -208,22 +139,137 @@ def _process_single_qa( # Update pre-context cache with current context with self.stats_lock: - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: self.update_context( conv_id=conv_id, method=self.context_update_method, - cur_context=cur_context, + query=query, + answer=answer, ) else: self.update_context( conv_id=conv_id, method=self.context_update_method, - query=query, - answer=gold_answer, + cur_context=cur_context, ) self.print_eval_info() + def _process_single_qa( + self, + qa, + *, + client, + reversed_client, + metadata, + frame, + version, + conv_id, + conv_stats_path, + oai_client, + top_k, + conv_stats, + ): + query = qa.get("question") + gold_answer = qa.get("answer") + qa_category = qa.get("category") + if qa_category == 5: + return None + + # Search + cur_context, search_duration_ms = self.search_query( + client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k + ) + if not cur_context: + logger.warning(f"No context found for query: {query[:100]}") + cur_context = "" + + if self.context_update_method == ContextUpdateMethod.CURRENT_CONTEXT: + context = cur_context + else: + # Context answer ability analysis (for memos_scheduler only) + if self.pre_context_cache[conv_id] is None: + # Update pre-context cache with current context and return + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: + answer_from_cur_context = self.locomo_response( + frame, oai_client, cur_context, query + ) + self.update_context( + conv_id=conv_id, + method=self.context_update_method, + query=query, + answer=answer_from_cur_context, + ) + else: + self.update_context( + conv_id=conv_id, + method=self.context_update_method, + cur_context=cur_context, + ) + return None + else: + context = self.pre_context_cache[conv_id] + + # Generate answer + answer_start = time() + answer = self.locomo_response(frame, oai_client, context, query) + response_duration_ms = (time() - answer_start) * 1000 + + can_answer, can_answer_duration_ms = self.eval_context( + context=context, query=query, gold_answer=gold_answer, oai_client=oai_client + ) + + # Record case for memos_scheduler + try: + recording_case = RecordingCase( + conv_id=conv_id, + query=query, + answer=answer, + context=cur_context, + pre_context=self.pre_context_cache[conv_id], + can_answer=can_answer, + 
can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", + search_duration_ms=search_duration_ms, + can_answer_duration_ms=can_answer_duration_ms, + response_duration_ms=response_duration_ms, + category=int(qa_category) if qa_category is not None else None, + golden_answer=str(qa.get("answer", "")), + ) + if can_answer: + self.can_answer_cases.append(recording_case) + else: + self.cannot_answer_cases.append(recording_case) + except Exception as e: + logger.error(f"Error creating RecordingCase: {e}") + print(f"Error creating RecordingCase: {e}") + logger.error(f"QA data: {qa}") + print(f"QA data: {qa}") + logger.error(f"Query: {query}") + logger.error(f"Answer: {answer}") + logger.error( + f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" + ) + logger.error(f"Category: {qa_category} (type: {type(qa_category)})") + logger.error(f"Can answer: {can_answer}") + raise e + + if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: + answer_from_cur_context = self.locomo_response(frame, oai_client, cur_context, query) + answer = answer_from_cur_context + # Update conversation stats and context + self._update_stats_and_context( + conv_id=conv_id, + frame=frame, + version=version, + conv_stats=conv_stats, + conv_stats_path=conv_stats_path, + query=query, + answer=answer, + gold_answer=gold_answer, + cur_context=cur_context, + can_answer=can_answer, + ) + return { "question": query, "answer": answer, @@ -233,7 +279,7 @@ def _process_single_qa( "response_duration_ms": response_duration_ms, "search_duration_ms": search_duration_ms, "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == "memos_scheduler" else None, + "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, } def run_locomo_processing(self, num_users=10): diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py index 4ec7d4922..f8db11fbc 100644 --- a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py +++ b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py @@ -16,7 +16,6 @@ from .constants import ( BASE_DIR, - MEMOS_MODEL, MEMOS_SCHEDULER_MODEL, ) from .prompts import ( @@ -42,10 +41,9 @@ def __init__(self, args): self.top_k = self.args.top_k # attributes - if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - self.context_update_method = ContextUpdateMethod.DIRECT - else: - self.context_update_method = ContextUpdateMethod.TEMPLATE + self.context_update_method = getattr( + self.args, "context_update_method", ContextUpdateMethod.PRE_CONTEXT + ) self.custom_instructions = CUSTOM_INSTRUCTIONS self.data_dir = Path(f"{BASE_DIR}/data") self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json") @@ -64,7 +62,7 @@ def __init__(self, args): # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation if ( hasattr(self.args, "scheduler_flag") - and self.frame == "memos_scheduler" + and self.frame == MEMOS_SCHEDULER_MODEL and self.args.scheduler_flag is False ): self.result_dir = Path( @@ -74,6 +72,11 @@ def __init__(self, args): self.result_dir = Path( f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}/" ) + + if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT: + self.result_dir = ( + self.result_dir.parent / f"{self.result_dir.name}_{self.context_update_method}" + ) self.result_dir.mkdir(parents=True, exist_ok=True) self.search_path = self.result_dir 
/ f"{self.frame}-{self.version}_search_results.json" @@ -135,10 +138,6 @@ def __init__(self, args): # Statistics tracking with thread safety self.stats = {self.frame: {self.version: defaultdict(dict)}} - self.stats[self.frame][self.version]["response_stats"] = defaultdict(dict) - self.stats[self.frame][self.version]["response_stats"]["response_failure"] = 0 - self.stats[self.frame][self.version]["response_stats"]["response_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict) self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0 self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0 diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py index c824fe5f4..4a56b599b 100644 --- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py +++ b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py @@ -194,6 +194,10 @@ def memos_scheduler_search( start = time.time() client: MOS = client + if not self.scheduler_flag: + # if not scheduler_flag, search to update working memory + self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client) + # Search for speaker A search_a_results = client.mem_scheduler.search_for_eval( query=query, diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py index e5872c35d..a41b7539d 100644 --- a/evaluation/scripts/temporal_locomo/modules/schemas.py +++ b/evaluation/scripts/temporal_locomo/modules/schemas.py @@ -1,14 +1,23 @@ -from enum import Enum from typing import Any from pydantic import BaseModel, Field -class ContextUpdateMethod(Enum): +class ContextUpdateMethod: """Enumeration for context update methods""" - DIRECT = "direct" # Directly update with current context - TEMPLATE = "chat_history" # Update using template with history queries and answers + PRE_CONTEXT = "pre_context" + CHAT_HISTORY = "chat_history" + CURRENT_CONTEXT = "current_context" + + @classmethod + def values(cls): + """Return a list of all constant values""" + return [ + getattr(cls, attr) + for attr in dir(cls) + if not attr.startswith("_") and isinstance(getattr(cls, attr), str) + ] class RecordingCase(BaseModel): @@ -22,11 +31,6 @@ class RecordingCase(BaseModel): # Conversation identification conv_id: str = Field(description="Conversation identifier for this evaluation case") - # Conversation history and context - history_queries: list[str] = Field( - default_factory=list, description="List of previous queries in the conversation history" - ) - context: str = Field( default="", description="Current search context retrieved from memory systems for answering the query", @@ -42,16 +46,6 @@ class RecordingCase(BaseModel): answer: str = Field(description="The generated answer for the query") - # Memory data - memories: list[Any] = Field( - default_factory=list, - description="Current memories retrieved from the memory system for this query", - ) - - pre_memories: list[Any] | None = Field( - default=None, description="Previous memories from the last query, used for comparison" - ) - # Evaluation metrics can_answer: bool | None = Field( default=None, diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py index 0a2c20a0e..aab5738fc 100644 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ 
b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py
@@ -10,6 +10,7 @@
 from locomo_metric import LocomoMetric
 from locomo_processor import LocomoProcessor
 from modules.locomo_eval_module import LocomoEvalModelModules
+from modules.schemas import ContextUpdateMethod
 from modules.utils import compute_can_answer_count_by_pre_evidences
 
 from memos.log import get_logger
@@ -29,6 +30,8 @@ def __init__(self, args):
 
         self.locomo_ingestor = LocomoIngestor(args=args)
         self.locomo_processor = LocomoProcessor(args=args)
+        self.locomo_evaluator = LocomoEvaluator(args=args)
+        self.locomo_metric = LocomoMetric(args=args)
 
     def run_eval_pipeline(self):
         """
@@ -53,14 +56,7 @@
         print("\n" + "=" * 50)
         print("Step 2: Data Ingestion")
         print("=" * 50)
-        if not self.ingestion_storage_dir.exists() or not any(self.ingestion_storage_dir.iterdir()):
-            print(f"Directory {self.ingestion_storage_dir} not found, starting data ingestion...")
-            self.locomo_ingestor.run_ingestion()
-            print("Data ingestion completed.")
-        else:
-            print(
-                f"Directory {self.ingestion_storage_dir} already exists and is not empty, skipping ingestion."
-            )
+        self.locomo_ingestor.run_ingestion()
 
         # Step 3: Processing and evaluation
         print("\n" + "=" * 50)
@@ -74,22 +70,20 @@
 
         # Optional: run post-hoc evaluation over generated responses if available
         try:
-            evaluator = LocomoEvaluator(args=args)
-
-            if os.path.exists(evaluator.response_path):
+            if os.path.exists(self.response_path):
                 print("Running LocomoEvaluator over existing response results...")
-                asyncio.run(evaluator.run())
+                asyncio.run(self.locomo_evaluator.run())
             else:
                 print(
                     f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}"
                 )
 
             # Run metrics summarization if judged file is produced
-            metric = LocomoMetric(args=args)
-            if os.path.exists(metric.judged_path):
+
+            if os.path.exists(self.judged_path):
                 print("Running LocomoMetric over judged results...")
-                metric.run()
+                self.locomo_metric.run()
             else:
-                print(f"Skipping LocomoMetric: judged file not found at {metric.judged_path}")
+                print(f"Skipping LocomoMetric: judged file not found at {self.judged_path}")
         except Exception as e:
             logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True)
@@ -143,9 +137,16 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
     parser.add_argument(
         "--scheduler-flag",
         action=argparse.BooleanOptionalAction,
-        default=True,
+        default=False,
         help="Enable or disable memory scheduler features",
     )
+    parser.add_argument(
+        "--context_update_method",
+        type=str,
+        default="chat_history",
+        choices=ContextUpdateMethod.values(),
+        help="Method to update context: pre_context (accumulate previous search contexts), chat_history (append question/answer turns), current_context (use the current search context directly)",
+    )
 
     args = parser.parse_args()
     evaluator = TemporalLocomoEval(args=args)
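One detail worth noting in the schemas change that accompanies this patch: `ContextUpdateMethod` stops being an `Enum`, so its members are plain strings that compare directly against parsed CLI values without `.value` unwrapping. A standalone sketch of that behaviour (mirroring the class in the diff, trimmed to two members; `values()` returns names in `dir()` order, i.e. alphabetical):

    # Sketch: plain-class string constants vs Enum members.
    class ContextUpdateMethod:
        PRE_CONTEXT = "pre_context"
        CHAT_HISTORY = "chat_history"

        @classmethod
        def values(cls):
            return [getattr(cls, a) for a in dir(cls)
                    if not a.startswith("_") and isinstance(getattr(cls, a), str)]

    assert ContextUpdateMethod.CHAT_HISTORY == "chat_history"  # no .value needed
    print(ContextUpdateMethod.values())
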
From b6834d3d6b6717a0d750fb119559496692c3ff2d Mon Sep 17 00:00:00 2001
From: chentang
Date: Fri, 26 Sep 2025 15:58:34 +0800
Subject: [PATCH 09/29] change the consume interval from 3 to 0.5 seconds, and
 refactor the code structure of temporal locomo.

---
 evaluation/__init__.py                          |  0
 evaluation/scripts/__init__.py                  |  0
 .../temporal_locomo/models/__init__.py          |  0
 .../{ => models}/locomo_eval.py                 |  2 +-
 .../{ => models}/locomo_ingestion.py            |  8 ++---
 .../{ => models}/locomo_metric.py               |  2 +-
 .../{ => models}/locomo_processor.py            | 12 +++----
 .../temporal_locomo/temporal_locomo_eval.py     | 36 ++++++++++---------
 src/memos/configs/mem_scheduler.py              |  2 +-
 .../mem_scheduler/schemas/general_schemas.py    |  2 +-
 10 files changed, 33 insertions(+), 31 deletions(-)
 create mode 100644 evaluation/__init__.py
 create mode 100644 evaluation/scripts/__init__.py
 create mode 100644 evaluation/scripts/temporal_locomo/models/__init__.py
 rename evaluation/scripts/temporal_locomo/{ => models}/locomo_eval.py (99%)
 rename evaluation/scripts/temporal_locomo/{ => models}/locomo_ingestion.py (98%)
 rename evaluation/scripts/temporal_locomo/{ => models}/locomo_metric.py (99%)
 rename evaluation/scripts/temporal_locomo/{ => models}/locomo_processor.py (97%)

diff --git a/evaluation/__init__.py b/evaluation/__init__.py
new file mode 100644
index 000000000..e69de29bb

diff --git a/evaluation/scripts/__init__.py b/evaluation/scripts/__init__.py
new file mode 100644
index 000000000..e69de29bb

diff --git a/evaluation/scripts/temporal_locomo/models/__init__.py b/evaluation/scripts/temporal_locomo/models/__init__.py
new file mode 100644
index 000000000..e69de29bb

diff --git a/evaluation/scripts/temporal_locomo/locomo_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_eval.py
similarity index 99%
rename from evaluation/scripts/temporal_locomo/locomo_eval.py
rename to evaluation/scripts/temporal_locomo/models/locomo_eval.py
index 62ed209b6..f98a481e2 100644
--- a/evaluation/scripts/temporal_locomo/locomo_eval.py
+++ b/evaluation/scripts/temporal_locomo/models/locomo_eval.py
@@ -9,7 +9,6 @@
 
 from bert_score import score as bert_score
 from dotenv import load_dotenv
-from modules.locomo_eval_module import LocomoEvalModelModules
 from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
 from nltk.translate.meteor_score import meteor_score
 from openai import AsyncOpenAI
@@ -19,6 +18,7 @@
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 
+from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
 from memos.log import get_logger
 
diff --git a/evaluation/scripts/temporal_locomo/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py
similarity index 98%
rename from evaluation/scripts/temporal_locomo/locomo_ingestion.py
rename to evaluation/scripts/temporal_locomo/models/locomo_ingestion.py
index 321302cf2..b45ec3d61 100644
--- a/evaluation/scripts/temporal_locomo/locomo_ingestion.py
+++ b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py
@@ -6,16 +6,16 @@
 from datetime import datetime, timezone
 from pathlib import Path
 
-from modules.constants import (
+from tqdm import tqdm
+
+from evaluation.scripts.temporal_locomo.modules.constants import (
     MEM0_GRAPH_MODEL,
     MEM0_MODEL,
     MEMOS_MODEL,
     MEMOS_SCHEDULER_MODEL,
     ZEP_MODEL,
 )
-from modules.locomo_eval_module import LocomoEvalModelModules
-from tqdm import tqdm
-
+from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
 from memos.log import get_logger
 
diff --git a/evaluation/scripts/temporal_locomo/locomo_metric.py b/evaluation/scripts/temporal_locomo/models/locomo_metric.py
similarity index 99%
rename from evaluation/scripts/temporal_locomo/locomo_metric.py
rename to
evaluation/scripts/temporal_locomo/models/locomo_metric.py index 0187c37e7..532fe2e14 100644 --- a/evaluation/scripts/temporal_locomo/locomo_metric.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_metric.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from modules.locomo_eval_module import LocomoEvalModelModules +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules # Category mapping as per your request diff --git a/evaluation/scripts/temporal_locomo/locomo_processor.py b/evaluation/scripts/temporal_locomo/models/locomo_processor.py similarity index 97% rename from evaluation/scripts/temporal_locomo/locomo_processor.py rename to evaluation/scripts/temporal_locomo/models/locomo_processor.py index 3fd1ca59c..7cec6f5af 100644 --- a/evaluation/scripts/temporal_locomo/locomo_processor.py +++ b/evaluation/scripts/temporal_locomo/models/locomo_processor.py @@ -7,19 +7,19 @@ from time import time from dotenv import load_dotenv -from modules.constants import ( + +from evaluation.scripts.temporal_locomo.modules.constants import ( MEMOS_SCHEDULER_MODEL, ) -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.prompts import ( +from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules +from evaluation.scripts.temporal_locomo.modules.prompts import ( SEARCH_PROMPT_MEM0, SEARCH_PROMPT_MEM0_GRAPH, SEARCH_PROMPT_MEMOS, SEARCH_PROMPT_ZEP, ) -from modules.schemas import ContextUpdateMethod, RecordingCase -from modules.utils import save_evaluation_cases - +from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase +from evaluation.scripts.temporal_locomo.modules.utils import save_evaluation_cases from memos.log import get_logger diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py index aab5738fc..c21bcfc1c 100644 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py @@ -5,14 +5,14 @@ from pathlib import Path -from locomo_eval import LocomoEvaluator -from locomo_ingestion import LocomoIngestor -from locomo_metric import LocomoMetric -from locomo_processor import LocomoProcessor from modules.locomo_eval_module import LocomoEvalModelModules from modules.schemas import ContextUpdateMethod from modules.utils import compute_can_answer_count_by_pre_evidences +from evaluation.scripts.temporal_locomo.models.locomo_eval import LocomoEvaluator +from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor +from evaluation.scripts.temporal_locomo.models.locomo_metric import LocomoMetric +from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor from memos.log import get_logger @@ -33,7 +33,7 @@ def __init__(self, args): self.locomo_evaluator = LocomoEvaluator(args=args) self.locomo_metric = LocomoMetric(args=args) - def run_eval_pipeline(self): + def run_eval_pipeline(self, skip_ingestion=True, skip_processing=False): """ Run the complete evaluation pipeline including dataset conversion, data ingestion, and processing. 
@@ -53,20 +53,22 @@ def run_eval_pipeline(self): print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") # Step 2: Data ingestion - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - self.locomo_ingestor.run_ingestion() + if not skip_ingestion: + print("\n" + "=" * 50) + print("Step 2: Data Ingestion") + print("=" * 50) + self.locomo_ingestor.run_ingestion() # Step 3: Processing and evaluation - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") - - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") + if not skip_processing: + print("\n" + "=" * 50) + print("Step 3: Processing and Evaluation") + print("=" * 50) + print("Running locomo processing to search and answer...") + + print("Starting locomo processing to generate search and response results...") + self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) + print("Processing completed successfully.") # Optional: run post-hoc evaluation over generated responses if available try: diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index a36f3e2f8..90ed6a272 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -30,7 +30,7 @@ class BaseSchedulerConfig(BaseConfig): lt=20, description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD__POOL_MAX_WORKERS})", ) - consume_interval_seconds: int = Field( + consume_interval_seconds: float = Field( default=DEFAULT_CONSUME_INTERVAL_SECONDS, gt=0, le=60, diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index a81caf5a8..1ac651ca7 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -18,7 +18,7 @@ DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" DEFAULT_THREAD__POOL_MAX_WORKERS = 5 -DEFAULT_CONSUME_INTERVAL_SECONDS = 3 +DEFAULT_CONSUME_INTERVAL_SECONDS = 0.5 NOT_INITIALIZED = -1 From 1ca5eadfee9a7ac9714c7fe5b6764f632b88e184 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Sun, 28 Sep 2025 11:59:05 +0800 Subject: [PATCH 10/29] feat: refactor search --- src/memos/api/config.py | 4 +- src/memos/api/product_models.py | 1 - src/memos/api/routers/product_router.py | 66 +++-- src/memos/configs/graph_db.py | 12 - src/memos/graph_dbs/nebular.py | 238 +++++++----------- .../tree_text_memory/retrieve/recall.py | 18 +- .../tree_text_memory/retrieve/searcher.py | 152 +++-------- 7 files changed, 186 insertions(+), 305 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 355ee0385..0d44b8963 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -251,8 +251,8 @@ def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: "user": os.getenv("NEBULAR_USER", "root"), "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"), "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"), - "user_name": f"memos{user_id.replace('-', '')}", - "use_multi_db": False, + # "user_name": f"memos{user_id.replace('-', '')}", + # "use_multi_db": False, "auto_create": True, "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } diff 
--git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7e425415b..898d67fe1 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -159,7 +159,6 @@ class SearchRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Search query") - mem_cube_id: str | None = Field(None, description="Cube ID to search in") top_k: int = Field(10, description="Number of results to return") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 75b614cf4..d9764eed6 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -213,29 +213,55 @@ def create_memory(memory_req: MemoryCreateRequest): raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + +from memos.configs.embedder import UniversalAPIEmbedderConfig +from memos.configs.graph_db import NebulaGraphDBConfig +from memos.configs.llm import OpenAILLMConfig +from memos.embedders.universal_api import UniversalAPIEmbedder +from memos.graph_dbs.nebular import NebulaGraphDB +from memos.llms.openai import OpenAILLM +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.cosine_local import CosineLocalReranker + +llm = OpenAILLM( + OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o', + temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, + api_key='sk-ZOPtsVgzxqnc8vGAmlTQTmnrpxK8me44fsEoX9bRTXFseh5Y', + api_base='http://123.129.219.111:3000/v1', extra_body=None)) +embedder = UniversalAPIEmbedder( + UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', + model_name_or_path='bge-m3', embedding_dims=None, provider='openai', + api_key='EMPTY', base_url='http://106.75.235.231:8081/v1')) + +reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') + +graph_store = NebulaGraphDB( + NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', + uri=['106.14.142.60:9669', '120.55.160.164:9669', '106.15.38.5:9669'], + user='root', password='NebulaMemOS0724', space='shared-tree-textual-memory-product-preandtest', + auto_create=True, max_client=1000, embedding_dimension=1024)) + +s = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False) + + @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: SearchRequest): """Search memories for a specific user.""" - try: - time_start_search = time.time() - mos_product = get_mos_product_instance() - result = mos_product.search( - query=search_req.query, - user_id=search_req.user_id, - install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None, - top_k=search_req.top_k, - session_id=search_req.session_id, - ) - logger.info( - f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}" - ) - return SearchResponse(message="Search completed successfully", data=result) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to search memories: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) 
from err + # try: + user_id = f"memos{search_req.user_id.replace('-', '')}" + res = s.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k + , mode="fast", search_filter=None, + info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) + res = {"d": res} + print(res) + return SearchResponse(message="Search completed successfully", data=res) + + # except ValueError as err: + # raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + # except Exception as err: + # logger.error(f"Failed to search memories: {traceback.format_exc()}") + # raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err @router.post("/chat", summary="Chat with MemOS") diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 2df917166..49703bb69 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -124,22 +124,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig): space: str = Field( ..., description="The name of the target NebulaGraph space (like a database)" ) - user_name: str | None = Field( - default=None, - description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)", - ) auto_create: bool = Field( default=False, description="Whether to auto-create the space if it does not exist", ) - use_multi_db: bool = Field( - default=True, - description=( - "If True: use Neo4j's multi-database feature for physical isolation; " - "each user typically gets a separate database. " - "If False: use a single shared database with logical isolation by user_name." - ), - ) max_client: int = Field( default=1000, description=("max_client"), diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..952260df4 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -129,7 +129,6 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str: "nebula-sync", ",".join(hosts), str(getattr(cfg, "user", "")), - str(getattr(cfg, "use_multi_db", False)), str(getattr(cfg, "space", "")), ] ) @@ -139,7 +138,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " tmp = object.__new__(NebulaGraphDB) tmp.config = cfg tmp.db_name = cfg.space - tmp.user_name = getattr(cfg, "user_name", None) + tmp.user_name = None tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) tmp.default_memory_dimension = 3072 tmp.common_fields = { @@ -169,7 +168,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -320,7 +319,6 @@ def __init__(self, config: NebulaGraphDBConfig): self.config = config self.db_name = config.space - self.user_name = config.user_name self.embedding_dimension = config.embedding_dimension self.default_memory_dimension = 3072 self.common_fields = { @@ -350,7 +348,7 @@ def __init__(self, config: NebulaGraphDBConfig): if (str(self.embedding_dimension) != str(self.default_memory_dimension)) else "embedding" ) - self.system_db_name = "system" if config.use_multi_db else config.space + self.system_db_name = config.space # ---- NEW: pool acquisition strategy # Get or create a shared pool from the class-level cache @@ -417,18 +415,12 @@ def create_index( self._create_basic_property_indexes() 
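    # NOTE (illustrative sketch, not part of the upstream patch): with `use_multi_db`
    # removed, every accessor below takes an explicit `user_name` and appends an
    # `n.user_name = ...` predicate, so tenant isolation is purely logical inside
    # one shared space. Assumed call pattern ("memos_user_1" is a made-up tenant id):
    #
    #     db = NebulaGraphDB(cfg)  # one shared space for all tenants
    #     db.add_node(node_id, memory, metadata, user_name="memos_user_1")
    #     db.remove_oldest_memory("WorkingMemory", keep_latest=20, user_name="memos_user_1")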
@timed - def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: + def remove_oldest_memory(self, memory_type: str, keep_latest: int, user_name: str = None) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. - - Args: - memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). - keep_latest (int): Number of latest WorkingMemory entries to keep. """ - optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + optional_condition = f"AND n.user_name = '{user_name}'" query = f""" MATCH (n@Memory) WHERE n.memory_type = '{memory_type}' @@ -440,13 +432,12 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: self.execute_query(query) @timed - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: + def add_node(self, id: str, memory: str, metadata: dict[str, Any], user_name: str = None) -> None: """ Insert or update a Memory node in NebulaGraph. """ - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + metadata["user_name"] = user_name now = datetime.utcnow() metadata = metadata.copy() metadata.setdefault("created_at", now) @@ -475,11 +466,8 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: ) @timed - def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + def node_not_exist(self, scope: str, user_name: str = None) -> int: + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' query = f""" MATCH (n@Memory) WHERE {filter_clause} @@ -495,7 +483,7 @@ def node_not_exist(self, scope: str) -> int: raise @timed - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str = None) -> None: """ Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. """ @@ -509,45 +497,40 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @timed - def delete_node(self, id: str) -> None: + def delete_node(self, id: str, user_name: str = None) -> None: """ Delete a node from the graph. Args: id: Node identifier to delete. + user_name (str, optional): User name for filtering in non-multi-db mode """ query = f""" - MATCH (n@Memory {{id: "{id}"}}) + MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} + DETACH DELETE n """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" - query += "\n DETACH DELETE n" self.execute_query(query) @timed - def add_edge(self, source_id: str, target_id: str, type: str): + def add_edge(self, source_id: str, target_id: str, type: str, user_name: str = None): """ Create an edge from source node to target node. Args: source_id: ID of the source node. target_id: ID of the target node. 
type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). + user_name (str, optional): User name for filtering in non-multi-db mode """ if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' - + props = f'{{user_name: "{user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT (a) -[e@{type} {props}]-> (b) @@ -558,35 +541,31 @@ def add_edge(self, source_id: str, target_id: str, type: str): logger.error(f"Failed to insert edge: {e}", exc_info=True) @timed - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + def delete_edge(self, source_id: str, target_id: str, type: str, user_name: str = None) -> None: """ Delete a specific edge between two nodes. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type to remove. + user_name (str, optional): User name for filtering in non-multi-db mode """ query = f""" MATCH (a@Memory) -[r@{type}]-> (b@Memory) WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @timed - def get_memory_count(self, memory_type: str) -> int: + def get_memory_count(self, memory_type: str, user_name: str = None) -> int: query = f""" MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -597,14 +576,12 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: + def count_nodes(self, scope: str, user_name: str = None) -> int: query = f""" MATCH (n@Memory) WHERE n.memory_type = "{scope}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -612,7 +589,7 @@ def count_nodes(self, scope: str) -> int: @timed def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING", user_name: str = None ) -> bool: """ Check if an edge exists between two nodes. @@ -622,6 +599,7 @@ def edge_exists( type: Relationship type. Use "ANY" to match any relationship type. direction: Direction of the edge. Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode Returns: True if the edge exists, otherwise False. """ @@ -640,9 +618,7 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." 
) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -654,22 +630,19 @@ def edge_exists( @timed # Graph Query & Reasoning - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node(self, id: str, include_embedding: bool = False, user_name: str = None) -> dict[str, Any] | None: """ Retrieve a Memory node by its unique ID. Args: id (str): Node ID (Memory.id) include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: dict: Node properties as key-value pairs, or None if not found. """ - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' - + filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -692,13 +665,14 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | @timed def get_nodes( - self, ids: list[str], include_embedding: bool = False, **kwargs + self, ids: list[str], include_embedding: bool = False, user_name: str = None, **kwargs ) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: ids: List of Node identifier. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. @@ -709,13 +683,7 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" - + where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) @@ -738,7 +706,7 @@ def get_nodes( return nodes @timed - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]: + def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str = None) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -746,6 +714,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ id: Node ID to retrieve edges for. type: Relationship type to match, or 'ANY' to match all. direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of edges: @@ -770,8 +739,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. 
Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query = f""" MATCH {pattern} @@ -799,6 +767,7 @@ def get_neighbors_by_tag( top_k: int = 5, min_overlap: int = 1, include_embedding: bool = False, + user_name: str = None, ) -> list[dict[str, Any]]: """ Find top-K neighbor nodes with maximum tag overlap. @@ -809,6 +778,7 @@ def get_neighbors_by_tag( top_k: Max number of neighbors to return. min_overlap: Minimum number of overlapping tags required. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of dicts with node details and overlap count. @@ -824,8 +794,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -859,12 +828,8 @@ def get_neighbors_by_tag( return result @timed - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + def get_children_with_embeddings(self, id: str, user_name: str = None) -> list[dict[str, Any]]: + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -884,7 +849,7 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: @timed def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated" + self, center_id: str, depth: int = 2, center_status: str = "activated", user_name: str = None ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -892,6 +857,7 @@ def get_subgraph( center_id: The ID of the center node. depth: The hop distance for neighbors. center_status: Required status for center node. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { "core_node": {...}, @@ -902,7 +868,9 @@ def get_subgraph( if not 1 <= depth <= 5: raise ValueError("depth must be 1-5") - user_name = self.config.user_name + if not user_name and self.config.user_name: + user_name = self.config.user_name + gql = f""" MATCH (center@Memory) WHERE center.id = '{center_id}' @@ -954,6 +922,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str = None, **kwargs, ) -> list[dict]: """ @@ -968,6 +937,7 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. 
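    # Minimal call sketch (tenant id assumed) mirroring how the recall module
    # invokes this method after the refactor; each hit is {"id": ..., "score": ...}:
    #
    #     hits = graph_store.search_by_embedding(
    #         vector=query_embedding,
    #         top_k=10,
    #         scope="LongTermMemory",
    #         user_name="memos_user_1",
    #     )
    #     ids = [hit["id"] for hit in hits]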
@@ -991,19 +961,15 @@ def search_by_embedding( where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -1038,7 +1004,7 @@ def search_by_embedding( return [] @timed - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata(self, filters: list[dict[str, Any]], user_name: str = None) -> list[str]: """ 1. ADD logic: "AND" vs "OR"(support logic combination); 2. Support nested conditional expressions; @@ -1054,6 +1020,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: {"field": "tags", "op": "contains", "value": "AI"}, ... ] + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). @@ -1087,8 +1054,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_str = " AND ".join(where_clauses) gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" @@ -1106,6 +1072,7 @@ def get_grouped_counts( group_fields: list[str], where_clause: str = "", params: dict[str, Any] | None = None, + user_name: str = None, ) -> list[dict[str, Any]]: """ Count nodes grouped by any fields. @@ -1115,6 +1082,7 @@ def get_grouped_counts( where_clause (str, optional): Extra WHERE condition. E.g., "WHERE n.status = 'activated'" params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] 
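    # Minimal call sketch under the new signature (tenant id assumed):
    #
    #     rows = graph_store.get_grouped_counts(
    #         group_fields=["memory_type", "status"],
    #         user_name="memos_user_1",
    #     )
    #     # -> [{'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10}, ...]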
@@ -1122,17 +1090,16 @@ def get_grouped_counts( if not group_fields: raise ValueError("group_fields cannot be empty") - # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + # GQL-specific modifications + user_clause = f"n.user_name = '{user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1170,16 +1137,15 @@ def get_grouped_counts( return output @timed - def clear(self) -> None: + def clear(self, user_name: str = None) -> None: """ Clear the entire graph if the target database exists. + + Args: + user_name (str, optional): User name for filtering in non-multi-db mode """ try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" - + query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1187,11 +1153,12 @@ def clear(self) -> None: logger.error(f"[ERROR] Failed to clear database: {e}") @timed - def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: + def export_graph(self, include_embedding: bool = False, user_name: str = None) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { @@ -1201,11 +1168,8 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: """ node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + node_query += f' WHERE n.user_name = "{user_name}"' + edge_query += f' WHERE r.user_name = "{user_name}"' try: if include_embedding: @@ -1265,20 +1229,18 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: return {"nodes": nodes, "edges": edges} @timed - def import_graph(self, data: dict[str, Any]) -> None: + def import_graph(self, data: dict[str, Any], user_name: str = None) -> None: """ Import the entire graph from a serialized dictionary. Args: data: A dictionary containing all nodes and edges to be loaded. 
+ user_name (str, optional): User name for filtering in non-multi-db mode """ for node in data.get("nodes", []): try: id, memory, metadata = _compose_node(node) - - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) properties = ", ".join( @@ -1293,9 +1255,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1305,13 +1265,14 @@ def import_graph(self, data: dict[str, Any]) -> None: logger.error(f"Fail to load edge: {edge}, error: {e}") @timed - def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]: + def get_all_memory_items(self, scope: str, include_embedding: bool = False, user_name: str = None) -> (list)[dict]: """ Retrieve all memory items of a specific memory_type. Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Full list of memory items under this scope. @@ -1320,9 +1281,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{user_name}'" return_fields = self._build_return_fields(include_embedding) @@ -1344,7 +1303,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( @timed def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False + self, scope: str, include_embedding: bool = False, user_name: str = None ) -> list[dict]: """ Find nodes that are likely candidates for structure optimization: @@ -1356,8 +1315,7 @@ def get_structure_optimization_candidates( n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" @@ -1386,20 +1344,6 @@ def get_structure_optimization_candidates( logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. 
- """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) @timed def detect_conflicts(self) -> list[tuple[str, str]]: @@ -1585,9 +1529,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. """ - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 84cc8ecb3..6e1af2f03 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -25,6 +25,7 @@ def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder): def retrieve( self, query: str, + user_id: str, parsed_goal: ParsedTaskGoal, top_k: int, memory_scope: str, @@ -53,13 +54,13 @@ def retrieve( if memory_scope == "WorkingMemory": # For working memory, retrieve all entries (no filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False + scope="WorkingMemory", include_embedding=False, user_name=user_id ) return [TextualMemoryItem.from_dict(record) for record in working_memories] with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval - future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope) + future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_id) # Vector similarity search future_vector = executor.submit( self._vector_recall, @@ -67,6 +68,7 @@ def retrieve( memory_scope, top_k, search_filter=search_filter, + user_id=user_id, ) graph_results = future_graph.result() @@ -132,7 +134,7 @@ def retrieve_from_cube( return list(combined.values()) def _graph_recall( - self, parsed_goal: ParsedTaskGoal, memory_scope: str + self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_id: str ) -> list[TextualMemoryItem]: """ Perform structured node-based retrieval from Neo4j. 
@@ -148,7 +150,7 @@ def _graph_recall(
                 {"field": "key", "op": "in", "value": parsed_goal.keys},
                 {"field": "memory_type", "op": "=", "value": memory_scope},
             ]
-            key_ids = self.graph_store.get_by_metadata(key_filters)
+            key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_id)
             candidate_ids.update(key_ids)

         # 2) tag-based OR branch
@@ -157,7 +159,7 @@
                 {"field": "tags", "op": "contains", "value": parsed_goal.tags},
                 {"field": "memory_type", "op": "=", "value": memory_scope},
             ]
-            tag_ids = self.graph_store.get_by_metadata(tag_filters)
+            tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_id)
             candidate_ids.update(tag_ids)

         # No matches → return empty
@@ -165,7 +167,7 @@
             return []

         # Load nodes and post-filter
-        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
+        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False, user_name=user_id)

         final_nodes = []
         for node in node_dicts:
@@ -194,6 +196,7 @@ def _vector_recall(
         max_num: int = 3,
         cube_name: str | None = None,
         search_filter: dict | None = None,
+        user_id: str | None = None,
     ) -> list[TextualMemoryItem]:
         """
         Perform vector-based similarity retrieval using query embedding.
@@ -210,6 +213,7 @@ def search_single(vec, filt=None):
                     scope=memory_scope,
                     cube_name=cube_name,
                     search_filter=filt,
+                    user_name=user_id,
                 )
                 or []
             )
@@ -255,7 +259,7 @@ def search_path_b():
         unique_ids = {r["id"] for r in all_hits if r.get("id")}
         node_dicts = (
             self.graph_store.get_nodes(
-                list(unique_ids), include_embedding=False, cube_name=cube_name
+                list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_id
             )
             or []
         )
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index df154f23a..7071e76fa 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -50,6 +50,7 @@ def search(
         self,
         query: str,
         top_k: int,
+        user_id: str,
         info=None,
         mode="fast",
         memory_type="All",
@@ -63,6 +64,7 @@
             query (str): The query to search for.
             top_k (int): The number of top results to return.
             info (dict): Leave a record of memory consumption.
+            user_id (str): ID of the user whose memories are searched; scopes both graph and vector retrieval.
             mode (str, optional): The mode of the search.
             - 'fast': Uses a faster search process, sacrificing some precision for speed.
             - 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance.
@@ -72,9 +74,7 @@
         Returns:
             list[TextualMemoryItem]: List of matching memories.
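        Example (illustrative sketch; user id format assumed):
            searcher.search(
                query="what did the user eat yesterday",
                top_k=5,
                user_id="memos_user_1",
                info={"user_id": "memos_user_1", "session_id": "root_session", "chat_history": []},
            )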
""" - logger.info( - f"[SEARCH] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" - ) + logger.info(f"[SEARCH] Start query='{query}', user_id='{user_id}', top_k={top_k}, mode={mode}, memory_type={memory_type}") if not info: logger.warning( "Please input 'info' when use tree.search so that " @@ -87,7 +87,7 @@ def search( parsed_goal, query_embedding, context, query = self._parse_task( query, info, mode, search_filter=search_filter ) - results = self._retrieve_paths( + results = self.parallel_retrieve( query, parsed_goal, query_embedding, info, top_k, mode, memory_type, search_filter ) deduped = self._deduplicate_results(results) @@ -158,7 +158,7 @@ def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = N return parsed_goal, query_embedding, context, query @timed - def _retrieve_paths( + def parallel_retrieve( self, query, parsed_goal, @@ -171,41 +171,23 @@ def _retrieve_paths( ): """Run A/B/C retrieval paths in parallel""" tasks = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: - tasks.append( - executor.submit( - self._retrieve_from_working_memory, - query, - parsed_goal, - query_embedding, - top_k, - memory_type, - search_filter, - ) - ) - tasks.append( - executor.submit( - self._retrieve_from_long_term_and_user, - query, - parsed_goal, - query_embedding, - top_k, - memory_type, - search_filter, - ) - ) - tasks.append( - executor.submit( - self._retrieve_from_internet, - query, - parsed_goal, - query_embedding, - top_k, - info, - mode, - memory_type, - ) - ) + if memory_type == "All": + memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory"] + else: + memory_type = [memory_type] + with ContextThreadPoolExecutor(max_workers=5) as executor: + for m_type in memory_type: + tasks.append( + executor.submit( + self._retrieve_memory, + query, info, parsed_goal, + query_embedding, top_k, + m_type, search_filter, )) + if self.internet_retriever and mode == "fine": + tasks.append( + executor.submit( + self._retrieve_from_internet, + query, parsed_goal, query_embedding, top_k, info, mode, memory_type, )) if self.moscube: tasks.append( executor.submit( @@ -225,26 +207,25 @@ def _retrieve_paths( logger.info(f"[SEARCH] Total raw results: {len(results)}") return results - # --- Path A + @timed - def _retrieve_from_working_memory( - self, - query, - parsed_goal, - query_embedding, - top_k, - memory_type, - search_filter: dict | None = None, - ): - """Retrieve and rerank from WorkingMemory""" - if memory_type not in ["All", "WorkingMemory"]: - logger.info(f"[PATH-A] '{query}'Skipped (memory_type does not match)") - return [] + def _retrieve_memory( + self, + query, + info, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter: dict | None = None): + if memory_type in ["LongTermMemory", "UserMemory"]: + top_k = 2 * top_k items = self.graph_retriever.retrieve( query=query, + user_id=info['user_id'], parsed_goal=parsed_goal, top_k=top_k, - memory_scope="WorkingMemory", + memory_scope=memory_type, search_filter=search_filter, ) return self.reranker.rerank( @@ -256,60 +237,6 @@ def _retrieve_from_working_memory( search_filter=search_filter, ) - # --- Path B - @timed - def _retrieve_from_long_term_and_user( - self, - query, - parsed_goal, - query_embedding, - top_k, - memory_type, - search_filter: dict | None = None, - ): - """Retrieve and rerank from LongTermMemory and UserMemory""" - results = [] - tasks = [] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - if memory_type in ["All", 
"LongTermMemory"]: - tasks.append( - executor.submit( - self.graph_retriever.retrieve, - query=query, - parsed_goal=parsed_goal, - query_embedding=query_embedding, - top_k=top_k * 2, - memory_scope="LongTermMemory", - search_filter=search_filter, - ) - ) - if memory_type in ["All", "UserMemory"]: - tasks.append( - executor.submit( - self.graph_retriever.retrieve, - query=query, - parsed_goal=parsed_goal, - query_embedding=query_embedding, - top_k=top_k * 2, - memory_scope="UserMemory", - search_filter=search_filter, - ) - ) - - # Collect results from all tasks - for task in tasks: - results.extend(task.result()) - - return self.reranker.rerank( - query=query, - query_embedding=query_embedding[0], - graph_results=results, - top_k=top_k, - parsed_goal=parsed_goal, - search_filter=search_filter, - ) - @timed def _retrieve_from_memcubes( self, query, parsed_goal, query_embedding, top_k, cube_name="memos_cube01" @@ -335,11 +262,6 @@ def _retrieve_from_internet( self, query, parsed_goal, query_embedding, top_k, info, mode, memory_type ): """Retrieve and rerank from Internet source""" - if not self.internet_retriever or mode == "fast": - logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)") - return [] - if memory_type not in ["All"]: - return [] logger.info(f"[PATH-C] '{query}' Retrieving from internet...") items = self.internet_retriever.retrieve_from_internet( query=query, top_k=top_k, parsed_goal=parsed_goal, info=info @@ -407,6 +329,6 @@ def _update_usage_history(self, items, info): def _update_usage_history_worker(self, payload, usage_record: str): try: for item_id, usage_list in payload: - self.graph_store.update_node(item_id, {"usage": usage_list}) + self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=json.loads(usage_list[0])['info']['user_id']) except Exception: logger.exception("[USAGE] update usage failed") From 405a1625ec6428d2f6b0d2959f44ea4e2c3511ea Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Sun, 28 Sep 2025 14:44:07 +0800 Subject: [PATCH 11/29] feat: refactor search --- src/memos/api/routers/product_router.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index d9764eed6..64c6e3346 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -223,23 +223,23 @@ def create_memory(memory_req: MemoryCreateRequest): from memos.llms.openai import OpenAILLM from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.cosine_local import CosineLocalReranker - +import os llm = OpenAILLM( OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o', temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, - api_key='sk-ZOPtsVgzxqnc8vGAmlTQTmnrpxK8me44fsEoX9bRTXFseh5Y', - api_base='http://123.129.219.111:3000/v1', extra_body=None)) + api_key=os.getenv('OPENAI_API_KEY'), + api_base=os.getenv('OPENAI_API_BASE'), extra_body=None)) embedder = UniversalAPIEmbedder( UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', model_name_or_path='bge-m3', embedding_dims=None, provider='openai', - api_key='EMPTY', base_url='http://106.75.235.231:8081/v1')) + api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') graph_store = NebulaGraphDB( 
NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', - uri=['106.14.142.60:9669', '120.55.160.164:9669', '106.15.38.5:9669'], - user='root', password='NebulaMemOS0724', space='shared-tree-textual-memory-product-preandtest', + uri=json.loads(os.getenv('NEBULAR_HOSTS')), + user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), space=os.getenv('NEBULAR_SPACE'), auto_create=True, max_client=1000, embedding_dimension=1024)) s = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False) @@ -249,12 +249,13 @@ def create_memory(memory_req: MemoryCreateRequest): def search_memories(search_req: SearchRequest): """Search memories for a specific user.""" # try: - user_id = f"memos{search_req.user_id.replace('-', '')}" + # user_id = f"memos{search_req.user_id.replace('-', '')}" + user_id = search_req.user_id res = s.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k , mode="fast", search_filter=None, info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) res = {"d": res} - print(res) + # print(res) return SearchResponse(message="Search completed successfully", data=res) # except ValueError as err: From ccef65166dd4aca882043bc1dbb31b72ee4362a9 Mon Sep 17 00:00:00 2001 From: chentang Date: Sun, 28 Sep 2025 16:08:27 +0800 Subject: [PATCH 12/29] add new feat of thread race, and add a new test case for scheduler dispatcher --- .../modules/locomo_eval_module.py | 19 ++ .../temporal_locomo/modules/thread_race.py | 134 ++++++++ .../temporal_locomo/temporal_locomo_eval.py | 36 ++- src/memos/configs/mem_scheduler.py | 6 +- src/memos/mem_scheduler/base_scheduler.py | 4 +- .../general_modules/dispatcher.py | 23 ++ .../general_modules/task_threads.py | 139 +++++++++ .../mem_scheduler/schemas/general_schemas.py | 4 +- tests/mem_scheduler/test_dispatcher.py | 295 ++++++++++++++++++ 9 files changed, 646 insertions(+), 14 deletions(-) create mode 100644 evaluation/scripts/temporal_locomo/modules/thread_race.py create mode 100644 src/memos/mem_scheduler/general_modules/task_threads.py create mode 100644 tests/mem_scheduler/test_dispatcher.py diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py index 4a56b599b..b05243a11 100644 --- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py +++ b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py @@ -531,6 +531,25 @@ def process_qa(qa): json.dump(dict(search_results), fw, indent=2) print(f"Save search results {conv_id}") + search_durations = [] + for result in response_results[conv_id]: + if "search_duration_ms" in result: + search_durations.append(result["search_duration_ms"]) + + if search_durations: + avg_search_duration = sum(search_durations) / len(search_durations) + with self.stats_lock: + if self.stats[self.frame][self.version]["memory_stats"]["avg_search_duration_ms"]: + self.stats[self.frame][self.version]["memory_stats"][ + "avg_search_duration_ms" + ] = ( + self.stats[self.frame][self.version]["memory_stats"][ + "avg_search_duration_ms" + ] + + avg_search_duration + ) / 2 + print(f"Average search duration: {avg_search_duration:.2f} ms") + # Dump stats after processing each user self.save_stats() diff --git a/evaluation/scripts/temporal_locomo/modules/thread_race.py b/evaluation/scripts/temporal_locomo/modules/thread_race.py new file mode 100644 index 000000000..66aab4652 --- /dev/null +++ 
b/evaluation/scripts/temporal_locomo/modules/thread_race.py @@ -0,0 +1,134 @@ +import random +import threading +import time + + +class ThreadRace: + def __init__(self): + # Variable to store the result + self.result = None + # Event to mark if the race is finished + self.race_finished = threading.Event() + # Lock to protect the result variable + self.lock = threading.Lock() + # Store thread objects for termination + self.threads = {} + # Stop flags for each thread + self.stop_flags = {} + + def task1(self, stop_flag): + """First task function, can be modified as needed""" + # Simulate random work time + sleep_time = random.uniform(0.1, 2.0) + + # Break the sleep into smaller chunks to check stop flag + chunks = 20 + chunk_time = sleep_time / chunks + + for _ in range(chunks): + # Check if we should stop + if stop_flag.is_set(): + return None + time.sleep(chunk_time) + + return f"Task 1 completed in: {sleep_time:.2f} seconds" + + def task2(self, stop_flag): + """Second task function, can be modified as needed""" + # Simulate random work time + sleep_time = random.uniform(0.1, 2.0) + + # Break the sleep into smaller chunks to check stop flag + chunks = 20 + chunk_time = sleep_time / chunks + + for _ in range(chunks): + # Check if we should stop + if stop_flag.is_set(): + return None + time.sleep(chunk_time) + + return f"Task 2 completed in: {sleep_time:.2f} seconds" + + def worker(self, task_func, task_name): + """Worker thread function""" + # Create a stop flag for this task + stop_flag = threading.Event() + self.stop_flags[task_name] = stop_flag + + try: + # Execute the task with stop flag + result = task_func(stop_flag) + + # If the race is already finished or we were asked to stop, return immediately + if self.race_finished.is_set() or stop_flag.is_set(): + return None + + # Try to set the result (if no other thread has set it yet) + with self.lock: + if not self.race_finished.is_set(): + self.result = (task_name, result) + # Mark the race as finished + self.race_finished.set() + print(f"{task_name} won the race!") + + # Signal other threads to stop + for name, flag in self.stop_flags.items(): + if name != task_name: + print(f"Signaling {name} to stop") + flag.set() + + return self.result + + except Exception as e: + print(f"{task_name} encountered an error: {e}") + + return None + + def run_race(self): + """Start the competition and return the result of the fastest thread""" + # Reset state + self.race_finished.clear() + self.result = None + self.threads.clear() + self.stop_flags.clear() + + # Create threads + thread1 = threading.Thread(target=self.worker, args=(self.task1, "Thread 1")) + thread2 = threading.Thread(target=self.worker, args=(self.task2, "Thread 2")) + + # Record thread objects for later joining + self.threads["Thread 1"] = thread1 + self.threads["Thread 2"] = thread2 + + # Start threads + thread1.start() + thread2.start() + + # Wait for any thread to complete + while not self.race_finished.is_set(): + time.sleep(0.01) # Small delay to avoid high CPU usage + + # If all threads have ended but no result is set, there's a problem + if ( + not thread1.is_alive() + and not thread2.is_alive() + and not self.race_finished.is_set() + ): + print("All threads have ended, but there's no winner") + return None + + # Wait for all threads to end (with timeout to avoid infinite waiting) + thread1.join(timeout=1.0) + thread2.join(timeout=1.0) + + # Return the result + return self.result + + +# Usage example +if __name__ == "__main__": + race = ThreadRace() + result = race.run_race() + 
print(f"Winner: {result[0] if result else None}") + print(f"Result: {result[1] if result else None}") diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py index c21bcfc1c..46385626c 100644 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py @@ -33,7 +33,7 @@ def __init__(self, args): self.locomo_evaluator = LocomoEvaluator(args=args) self.locomo_metric = LocomoMetric(args=args) - def run_eval_pipeline(self, skip_ingestion=True, skip_processing=False): + def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False): """ Run the complete evaluation pipeline including dataset conversion, data ingestion, and processing. @@ -99,6 +99,32 @@ def run_eval_pipeline(self, skip_ingestion=True, skip_processing=False): print(f" - Statistics: {self.stats_path}") print("=" * 80) + def run_inference_eval_pipeline(self, skip_ingestion=True, skip_processing=False): + """ + Run the complete evaluation pipeline including dataset conversion, + data ingestion, and processing. + """ + print("=" * 80) + print("Starting TimeLocomo Evaluation Pipeline") + print("=" * 80) + + # Step 1: Check if temporal_locomo dataset exists, if not convert it + temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" + if not temporal_locomo_file.exists(): + print(f"Temporal locomo dataset not found at {temporal_locomo_file}") + print("Converting locomo dataset to temporal_locomo format...") + self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") + print("Dataset conversion completed.") + else: + print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") + + # Step 2: Data ingestion + if not skip_ingestion: + print("\n" + "=" * 50) + print("Step 2: Data Ingestion") + print("=" * 50) + self.locomo_ingestor.run_ingestion() + def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): """ Compute can-answer statistics per day for each conversation using the @@ -120,7 +146,7 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): parser.add_argument( "--frame", type=str, - default="memos_scheduler", + default="memos", choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"], help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", ) @@ -152,8 +178,4 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): args = parser.parse_args() evaluator = TemporalLocomoEval(args=args) - evaluator.run_eval_pipeline() - - # rule-based baselines - evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=float("inf")) - evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=1) + evaluator.run_answer_hit_eval_pipeline() diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 90ed6a272..82616ac93 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -11,7 +11,7 @@ BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, - DEFAULT_THREAD__POOL_MAX_WORKERS, + DEFAULT_THREAD_POOL_MAX_WORKERS, ) @@ -25,10 +25,10 @@ class BaseSchedulerConfig(BaseConfig): default=True, description="Whether to enable parallel message processing using thread pool" ) thread_pool_max_workers: int = Field( - default=DEFAULT_THREAD__POOL_MAX_WORKERS, + default=DEFAULT_THREAD_POOL_MAX_WORKERS, gt=1, lt=20, - description=f"Maximum 
worker threads in pool (default: {DEFAULT_THREAD__POOL_MAX_WORKERS})", + description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD_POOL_MAX_WORKERS})", ) consume_interval_seconds: float = Field( default=DEFAULT_CONSUME_INTERVAL_SECONDS, diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index b6ef00d8d..3e25a0ad7 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -20,7 +20,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, - DEFAULT_THREAD__POOL_MAX_WORKERS, + DEFAULT_THREAD_POOL_MAX_WORKERS, MemCubeID, TreeTextMemory_SEARCH_METHOD, UserID, @@ -60,7 +60,7 @@ def __init__(self, config: BaseSchedulerConfig): self.search_method = TreeTextMemory_SEARCH_METHOD self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False) self.thread_pool_max_workers = self.config.get( - "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS + "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS ) self.retriever: SchedulerRetriever | None = None diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index ce6df4d5d..e45ce4a2b 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -1,11 +1,14 @@ import concurrent +import threading from collections import defaultdict from collections.abc import Callable +from typing import Any from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.task_threads import ThreadRace from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -22,6 +25,7 @@ class SchedulerDispatcher(BaseSchedulerModule): - Batch message processing - Graceful shutdown - Bulk handler registration + - Thread race competition for parallel task execution """ def __init__(self, max_workers=30, enable_parallel_dispatch=False): @@ -49,6 +53,9 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False): # Set to track active futures for monitoring purposes self._futures = set() + # Thread race module for competitive task execution + self.thread_race = ThreadRace() + def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]): """ Register a handler function for a specific message label. @@ -177,6 +184,22 @@ def join(self, timeout: float | None = None) -> bool: return len(not_done) == 0 + def run_competitive_tasks( + self, tasks: dict[str, Callable[[threading.Event], Any]], timeout: float = 10.0 + ) -> tuple[str, Any] | None: + """ + Run multiple tasks in a competitive race, returning the result of the first task to complete. 
+ + Args: + tasks: Dictionary mapping task names to task functions that accept a stop_flag parameter + timeout: Maximum time to wait for any task to complete (in seconds) + + Returns: + Tuple of (task_name, result) from the winning task, or None if no task completes + """ + logger.info(f"Starting competitive execution of {len(tasks)} tasks") + return self.thread_race.run_race(tasks, timeout) + def shutdown(self) -> None: """Gracefully shutdown the dispatcher.""" self._running = False diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py new file mode 100644 index 000000000..9df8ef650 --- /dev/null +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -0,0 +1,139 @@ +import threading + +from collections.abc import Callable +from typing import Any, TypeVar + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule + + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class ThreadRace(BaseSchedulerModule): + """ + Thread race implementation that runs multiple tasks concurrently and returns + the result of the first task to complete successfully. + + Features: + - Cooperative thread termination using stop flags + - Configurable timeout for tasks + - Automatic cleanup of slower threads + - Thread-safe result handling + """ + + def __init__(self): + super().__init__() + # Variable to store the result + self.result: tuple[str, Any] | None = None + # Event to mark if the race is finished + self.race_finished = threading.Event() + # Lock to protect the result variable + self.lock = threading.Lock() + # Store thread objects for termination + self.threads: dict[str, threading.Thread] = {} + # Stop flags for each thread + self.stop_flags: dict[str, threading.Event] = {} + + def worker( + self, task_func: Callable[[threading.Event], T], task_name: str + ) -> tuple[str, T] | None: + """ + Worker thread function that executes a task and handles result reporting. + + Args: + task_func: Function to execute with a stop_flag parameter + task_name: Name identifier for this task/thread + + Returns: + Tuple of (task_name, result) if this thread wins the race, None otherwise + """ + # Create a stop flag for this task + stop_flag = threading.Event() + self.stop_flags[task_name] = stop_flag + + try: + # Execute the task with stop flag + result = task_func(stop_flag) + + # If the race is already finished or we were asked to stop, return immediately + if self.race_finished.is_set() or stop_flag.is_set(): + return None + + # Try to set the result (if no other thread has set it yet) + with self.lock: + if not self.race_finished.is_set(): + self.result = (task_name, result) + # Mark the race as finished + self.race_finished.set() + logger.info(f"Task '{task_name}' won the race") + + # Signal other threads to stop + for name, flag in self.stop_flags.items(): + if name != task_name: + logger.debug(f"Signaling task '{name}' to stop") + flag.set() + + return self.result + + except Exception as e: + logger.error(f"Task '{task_name}' encountered an error: {e}") + + return None + + def run_race( + self, tasks: dict[str, Callable[[threading.Event], T]], timeout: float = 10.0 + ) -> tuple[str, T] | None: + """ + Start a competition between multiple tasks and return the result of the fastest one. 
+ + Args: + tasks: Dictionary mapping task names to task functions + timeout: Maximum time to wait for any task to complete (in seconds) + + Returns: + Tuple of (task_name, result) from the winning task, or None if no task completes + """ + if not tasks: + logger.warning("No tasks provided for the race") + return None + + # Reset state + self.race_finished.clear() + self.result = None + self.threads.clear() + self.stop_flags.clear() + + # Create and start threads for each task + for task_name, task_func in tasks.items(): + thread = threading.Thread( + target=self.worker, args=(task_func, task_name), name=f"race-{task_name}" + ) + self.threads[task_name] = thread + thread.start() + logger.debug(f"Started task '{task_name}'") + + # Wait for any thread to complete or timeout + race_completed = self.race_finished.wait(timeout=timeout) + + if not race_completed: + logger.warning(f"Race timed out after {timeout} seconds") + # Signal all threads to stop + for _name, flag in self.stop_flags.items(): + flag.set() + + # Wait for all threads to end (with timeout to avoid infinite waiting) + for _name, thread in self.threads.items(): + thread.join(timeout=1.0) + if thread.is_alive(): + logger.warning(f"Thread '{_name}' did not terminate within the join timeout") + + # Return the result + if self.result: + logger.info(f"Race completed. Winner: {self.result[0]}") + else: + logger.warning("Race completed with no winner") + + return self.result diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 1ac651ca7..b029e38e8 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -17,8 +17,8 @@ DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 30 DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" -DEFAULT_THREAD__POOL_MAX_WORKERS = 5 -DEFAULT_CONSUME_INTERVAL_SECONDS = 0.5 +DEFAULT_THREAD_POOL_MAX_WORKERS = 10 +DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 NOT_INITIALIZED = -1 diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py new file mode 100644 index 000000000..f4d0d6b97 --- /dev/null +++ b/tests/mem_scheduler/test_dispatcher.py @@ -0,0 +1,295 @@ +import sys +import time +import unittest + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from memos.configs.mem_scheduler import ( + AuthConfig, + GraphDBAuthConfig, + OpenAIConfig, + RabbitMQConfig, + SchedulerConfigFactory, +) +from memos.llms.base import BaseLLM +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.memories.textual.tree import TreeTextMemory + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +class TestSchedulerDispatcher(unittest.TestCase): + """Test cases for the SchedulerDispatcher class.""" + + def _create_mock_auth_config(self): + """Create a mock AuthConfig for testing purposes.""" + # Create mock configs with valid test values + graph_db_config = GraphDBAuthConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test_password_123", # 8+ characters to pass validation + db_name="neo4j", + 
auto_create=True, + ) + + rabbitmq_config = RabbitMQConfig( + host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/" + ) + + openai_config = OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo") + + return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config) + + def setUp(self): + """Initialize test environment with mock objects.""" + example_scheduler_config_path = ( + f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml" + ) + scheduler_config = SchedulerConfigFactory.from_yaml_file( + yaml_path=example_scheduler_config_path + ) + mem_scheduler = SchedulerFactory.from_config(scheduler_config) + self.scheduler = mem_scheduler + self.llm = MagicMock(spec=BaseLLM) + self.mem_cube = MagicMock(spec=GeneralMemCube) + self.tree_text_memory = MagicMock(spec=TreeTextMemory) + self.mem_cube.text_mem = self.tree_text_memory + self.mem_cube.act_mem = MagicMock() + + # Mock AuthConfig.from_local_env() to return our test config + mock_auth_config = self._create_mock_auth_config() + self.auth_config_patch = patch( + "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config + ) + self.auth_config_patch.start() + + # Initialize general_modules with mock LLM + self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm) + self.scheduler.mem_cube = self.mem_cube + + self.dispatcher = self.scheduler.dispatcher + + # Create mock handlers + self.mock_handler1 = MagicMock() + self.mock_handler2 = MagicMock() + + # Register mock handlers + self.dispatcher.register_handler("label1", self.mock_handler1) + self.dispatcher.register_handler("label2", self.mock_handler2) + + # Create test messages + self.test_messages = [ + ScheduleMessageItem( + item_id="msg1", + user_id="user1", + mem_cube="cube1", + mem_cube_id="msg1", + label="label1", + content="Test content 1", + timestamp=123456789, + ), + ScheduleMessageItem( + item_id="msg2", + user_id="user1", + mem_cube="cube1", + mem_cube_id="msg2", + label="label2", + content="Test content 2", + timestamp=123456790, + ), + ScheduleMessageItem( + item_id="msg3", + user_id="user2", + mem_cube="cube2", + mem_cube_id="msg3", + label="label1", + content="Test content 3", + timestamp=123456791, + ), + ] + + # Mock logging to verify messages + self.logging_warning_patch = patch("logging.warning") + self.mock_logging_warning = self.logging_warning_patch.start() + + # Mock the MemoryFilter logger since that's where the actual logging happens + self.logger_info_patch = patch( + "memos.mem_scheduler.memory_manage_modules.memory_filter.logger.info" + ) + self.mock_logger_info = self.logger_info_patch.start() + + def tearDown(self): + """Clean up patches.""" + self.logging_warning_patch.stop() + self.logger_info_patch.stop() + self.auth_config_patch.stop() + + def test_register_handler(self): + """Test registering a single handler.""" + new_handler = MagicMock() + self.dispatcher.register_handler("new_label", new_handler) + + # Verify handler was registered + self.assertIn("new_label", self.dispatcher.handlers) + self.assertEqual(self.dispatcher.handlers["new_label"], new_handler) + + def test_register_handlers(self): + """Test bulk registration of handlers.""" + new_handlers = { + "bulk1": MagicMock(), + "bulk2": MagicMock(), + } + + self.dispatcher.register_handlers(new_handlers) + + # Verify all handlers were registered + for label, handler in new_handlers.items(): + self.assertIn(label, self.dispatcher.handlers) + 
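            # Identity matters here: dispatch() routes each message by looking
            # its label up in this mapping, so the exact handler object that
            # was registered must come back, not a copy.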
self.assertEqual(self.dispatcher.handlers[label], handler) + + def test_dispatch_serial(self): + """Test dispatching messages in serial mode.""" + # Create a new dispatcher with parallel dispatch disabled + serial_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=False) + serial_dispatcher.register_handler("label1", self.mock_handler1) + serial_dispatcher.register_handler("label2", self.mock_handler2) + + # Dispatch messages + serial_dispatcher.dispatch(self.test_messages) + + # Verify handlers were called with the correct messages + self.mock_handler1.assert_called_once() + self.mock_handler2.assert_called_once() + + # Check that each handler received the correct messages + label1_messages = [msg for msg in self.test_messages if msg.label == "label1"] + label2_messages = [msg for msg in self.test_messages if msg.label == "label2"] + + # The first argument of the first call + self.assertEqual(self.mock_handler1.call_args[0][0], label1_messages) + self.assertEqual(self.mock_handler2.call_args[0][0], label2_messages) + + def test_dispatch_parallel(self): + """Test dispatching messages in parallel mode.""" + # Dispatch messages + self.dispatcher.dispatch(self.test_messages) + + # Wait for all futures to complete + self.dispatcher.join(timeout=1.0) + + # Verify handlers were called + self.mock_handler1.assert_called_once() + self.mock_handler2.assert_called_once() + + # Check that each handler received the correct messages + label1_messages = [msg for msg in self.test_messages if msg.label == "label1"] + label2_messages = [msg for msg in self.test_messages if msg.label == "label2"] + + # The first argument of the first call + self.assertEqual(self.mock_handler1.call_args[0][0], label1_messages) + self.assertEqual(self.mock_handler2.call_args[0][0], label2_messages) + + def test_group_messages_by_user_and_cube(self): + """Test grouping messages by user and cube.""" + # Check actual grouping logic + with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): + result = self.dispatcher.group_messages_by_user_and_cube(self.test_messages) + + # Adjust expected results based on actual grouping logic + # Note: According to dispatcher.py implementation, grouping is by mem_cube_id not mem_cube + expected = { + "user1": { + "msg1": [self.test_messages[0]], + "msg2": [self.test_messages[1]], + }, + "user2": { + "msg3": [self.test_messages[2]], + }, + } + + # Use more flexible assertion method + self.assertEqual(set(result.keys()), set(expected.keys())) + for user_id in expected: + self.assertEqual(set(result[user_id].keys()), set(expected[user_id].keys())) + for cube_id in expected[user_id]: + self.assertEqual(len(result[user_id][cube_id]), len(expected[user_id][cube_id])) + # Check if each message exists + for msg in expected[user_id][cube_id]: + self.assertIn(msg.item_id, [m.item_id for m in result[user_id][cube_id]]) + + def test_thread_race(self): + """Test the ThreadRace integration.""" + + # Define test tasks + def task1(stop_flag): + time.sleep(0.1) + return "result1" + + def task2(stop_flag): + time.sleep(0.2) + return "result2" + + # Run competitive tasks + tasks = { + "task1": task1, + "task2": task2, + } + + result = self.dispatcher.run_competitive_tasks(tasks, timeout=1.0) + + # Verify the result + self.assertIsNotNone(result) + self.assertEqual(result[0], "task1") # task1 should win + self.assertEqual(result[1], "result1") + + def test_thread_race_timeout(self): + """Test ThreadRace with timeout.""" + + # Define a task that takes longer than the timeout 
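        # (0.5 s of work against the 0.1 s race timeout used below), so
        # run_competitive_tasks is expected to signal the stop flags and
        # return None.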
+ def slow_task(stop_flag): + time.sleep(0.5) + return "slow_result" + + tasks = {"slow": slow_task} + + # Run with a short timeout + result = self.dispatcher.run_competitive_tasks(tasks, timeout=0.1) + + # Verify no result was returned due to timeout + self.assertIsNone(result) + + def test_thread_race_cooperative_termination(self): + """Test that ThreadRace properly terminates slower threads when one completes.""" + + # Create a fast task and a slow task + def fast_task(stop_flag): + return "fast result" + + def slow_task(stop_flag): + # Check stop flag to ensure proper response + for _ in range(10): + if stop_flag.is_set(): + return "stopped early" + time.sleep(0.1) + return "slow result" + + # Run competitive tasks with increased timeout for test stability + result = self.dispatcher.run_competitive_tasks( + {"fast": fast_task, "slow": slow_task}, + timeout=2.0, # Increased timeout + ) + + # Verify the result is from the fast task + self.assertIsNotNone(result) + self.assertEqual(result[0], "fast") + self.assertEqual(result[1], "fast result") + + # Allow enough time for thread cleanup + time.sleep(0.5) From 7628da8235b781da3d7097a647bb99a790d7be81 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sun, 28 Sep 2025 16:24:12 +0800 Subject: [PATCH 13/29] update: add api for memory --- src/memos/api/product_models.py | 2 +- src/memos/api/routers/server_router.py | 126 ++++++++++++++++++ src/memos/api/server_api.py | 38 ++++++ src/memos/mem_reader/simple_struct.py | 17 ++- .../tree_text_memory/organize/manager.py | 17 +-- 5 files changed, 184 insertions(+), 16 deletions(-) create mode 100644 src/memos/api/routers/server_router.py create mode 100644 src/memos/api/server_api.py diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 898d67fe1..3ad47dbeb 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -151,7 +151,7 @@ class MemoryCreateRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID") source: str | None = Field(None, description="Source of the memory") user_profile: bool = Field(False, description="User profile memory") - session_id: str | None = Field(None, description="Session id") + session_id: str | None = Field(default_factory=lambda: str(uuid.uuid4()), description="Session id") class SearchRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py new file mode 100644 index 000000000..9b1056678 --- /dev/null +++ b/src/memos/api/routers/server_router.py @@ -0,0 +1,126 @@ +import os +import json +import time +from fastapi import APIRouter +from memos import log +from memos.api.product_models import ( + BaseResponse, + ChatCompleteRequest, + ChatRequest, + GetMemoryRequest, + MemoryCreateRequest, + MemoryResponse, + SearchRequest, + SearchResponse, + SimpleResponse, + SuggestionRequest, + SuggestionResponse, + UserRegisterRequest, + UserRegisterResponse, +) +from memos.configs.embedder import UniversalAPIEmbedderConfig +from memos.configs.graph_db import NebulaGraphDBConfig +from memos.configs.llm import OpenAILLMConfig +from memos.embedders.universal_api import UniversalAPIEmbedder +from memos.graph_dbs.nebular import NebulaGraphDB +from memos.llms.openai import OpenAILLM +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.reranker.cosine_local import CosineLocalReranker +from memos.mem_reader.simple_struct import 
SimpleStructMemReader +from memos.configs.mem_reader import SimpleStructMemReaderConfig +from memos.configs.chunker import ChunkerConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.chunker import SentenceChunkerConfig +from memos.chunkers.sentence_chunker import SentenceChunker + +logger = log.get_logger(__name__) +router = APIRouter() + +def init_model(): + llm = OpenAILLM( + OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o', + temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, + api_key=os.getenv('OPENAI_API_KEY'), + api_base=os.getenv('OPENAI_API_BASE'), extra_body=None)) + embedder = UniversalAPIEmbedder( + UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', + model_name_or_path='bge-m3', embedding_dims=None, provider='openai', + api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) + + reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') + + graph_store = NebulaGraphDB( + NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', + uri=json.loads(os.getenv('NEBULAR_HOSTS')), + user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), space=os.getenv('NEBULAR_SPACE'), + auto_create=True, max_client=1000, embedding_dimension=1024)) + search_obj = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False) + chunker = SentenceChunker( + SentenceChunkerConfig( + model_schema='memos.configs.chunker.SentenceChunkerConfig', + tokenizer_or_token_counter="gpt2", + chunk_size=512, + chunk_overlap=128, + min_sentences_per_chunk=1, + ) + ) + mem_reader = SimpleStructMemReader( + llm, + embedder, + chunker + ) + memory_add_obj = MemoryManager( + graph_store, + embedder, + llm, + memory_size={ + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + }, + is_reorganize=False + ) + + return search_obj, memory_add_obj, mem_reader + +search_obj, memory_add_obj, mem_reader = init_model() + + +@router.post("/search", summary="Search memories", response_model=SearchResponse) +def search_memories(search_req: SearchRequest): + """Search memories for a specific user.""" + # try: + # user_id = f"memos{search_req.user_id.replace('-', '')}" + user_id = search_req.user_id + res = search_obj.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k + , mode="fast", search_filter=None, + info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) + res = {"d": res} + # print(res) + return SearchResponse(message="Search completed successfully", data=res) + + +@router.post("/add", summary="add memories", response_model=SearchResponse) +def add_memories(add_req: MemoryCreateRequest): + """Add memories for a specific user.""" + time_start = time.time() + + memories = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={"user_id": add_req.user_id, "session_id": add_req.session_id}, + ) + logger.info( + f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s" + ) + mem_ids = [] + for mem in memories: + mem_id_list: list[str] = memory_add_obj.add(mem, user_name=add_req.user_id) + mem_ids.extend(mem_id_list) + logger.info( + f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}" + ) + data = {"mem_ids": mem_ids} + return SearchResponse(message="Memory 
added successfully", data=data) \ No newline at end of file diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py new file mode 100644 index 000000000..78e05ef85 --- /dev/null +++ b/src/memos/api/server_api.py @@ -0,0 +1,38 @@ +import logging + +from fastapi import FastAPI + +from memos.api.exceptions import APIExceptionHandler +from memos.api.middleware.request_context import RequestContextMiddleware +from memos.api.routers.server_router import router as server_router + + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +app = FastAPI( + title="MemOS Product REST APIs", + description="A REST API for managing multiple users with MemOS Product.", + version="1.0.1", +) + +app.add_middleware(RequestContextMiddleware) +# Include routers +app.include_router(server_router) + +# Exception handlers +app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +if __name__ == "__main__": + import argparse + + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--workers", type=int, default=1) + args = parser.parse_args() + uvicorn.run("memos.api.server_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b439cb2b2..6bf399650 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -115,17 +115,20 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" - def __init__(self, config: SimpleStructMemReaderConfig): + def __init__(self, llm, embedder, chunker): """ Initialize the NaiveMemReader with configuration. 
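        A minimal construction sketch with the injected components (test
        doubles stand in for the real LLM, embedder, and chunker; this mirrors
        how the tests exercise the class, not a required setup):

            from unittest.mock import MagicMock

            from memos.mem_reader.simple_struct import SimpleStructMemReader

            reader = SimpleStructMemReader(
                llm=MagicMock(),       # e.g. an OpenAILLM instance
                embedder=MagicMock(),  # e.g. a UniversalAPIEmbedder instance
                chunker=MagicMock(),   # e.g. a SentenceChunker instance
            )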
Args: config: Configuration object for the reader """ - self.config = config - self.llm = LLMFactory.from_config(config.llm) - self.embedder = EmbedderFactory.from_config(config.embedder) - self.chunker = ChunkerFactory.from_config(config.chunker) + # self.config = config + # self.llm = LLMFactory.from_config(config.llm) + # self.embedder = EmbedderFactory.from_config(config.embedder) + # self.chunker = ChunkerFactory.from_config(config.chunker) + self.llm = llm + self.embedder = embedder + self.chunker = chunker @timed def _process_chat_data(self, scene_data_info, info): @@ -142,8 +145,8 @@ def _process_chat_data(self, scene_data_info, info): examples = PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", "\n".join(mem_list)) - if self.config.remove_prompt_example: - prompt = prompt.replace(examples, "") + # if self.config.remove_prompt_example: + # prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index c9cd4de8a..1ddc485e0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -51,14 +51,14 @@ def __init__( ) self._merged_threshold = merged_threshold - def add(self, memories: list[TextualMemoryItem]) -> list[str]: + def add(self, memories: list[TextualMemoryItem], user_name: str = None) -> list[str]: """ Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory). """ added_ids: list[str] = [] with ContextThreadPoolExecutor(max_workers=8) as executor: - futures = {executor.submit(self._process_memory, m): m for m in memories} + futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=60): try: ids = future.result() @@ -126,7 +126,7 @@ def _refresh_memory_size(self) -> None: self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") - def _process_memory(self, memory: TextualMemoryItem): + def _process_memory(self, memory: TextualMemoryItem, user_name: str = None): """ Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. @@ -134,7 +134,7 @@ def _process_memory(self, memory: TextualMemoryItem): ids = [] # Add to WorkingMemory - working_id = self._add_memory_to_db(memory, "WorkingMemory") + working_id = self._add_memory_to_db(memory, "WorkingMemory", user_name) ids.append(working_id) # Add to LongTermMemory and UserMemory @@ -142,12 +142,13 @@ def _process_memory(self, memory: TextualMemoryItem): added_id = self._add_to_graph_memory( memory=memory, memory_type=memory.metadata.memory_type, + user_name=user_name ) ids.append(added_id) return ids - def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: + def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str, user_name: str = None) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. 
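        (The trimming itself happens in add(), which calls
        graph_store.remove_oldest_memory() with keep_latest taken from
        self.memory_size, so WorkingMemory behaves as a bounded FIFO queue.)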
""" @@ -158,10 +159,10 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) # Insert node into graph - self.graph_store.add_node(working_memory.id, working_memory.memory, metadata) + self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) return working_memory.id - def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): + def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str, user_name: str = None): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). @@ -175,7 +176,7 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): node_id = str(uuid.uuid4()) # Step 2: Add new node to graph self.graph_store.add_node( - node_id, memory.memory, memory.metadata.model_dump(exclude_none=True) + node_id, memory.memory, memory.metadata.model_dump(exclude_none=True), user_name=user_name ) self.reorganizer.add_message( QueueMessage( From 235125a1bda543ef7944c1e2655b217d93d5d652 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Sun, 28 Sep 2025 17:17:10 +0800 Subject: [PATCH 14/29] feat: add memory api return memory and memory type --- src/memos/api/routers/server_router.py | 20 +++++++++---------- .../tree_text_memory/organize/manager.py | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 9b1056678..fbcb0cff5 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -110,17 +110,15 @@ def add_memories(add_req: MemoryCreateRequest): memories = mem_reader.get_memory( [add_req.messages], type="chat", - info={"user_id": add_req.user_id, "session_id": add_req.session_id}, - ) + info={"user_id": add_req.user_id, "session_id": add_req.session_id},)[0] logger.info( f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s" ) - mem_ids = [] - for mem in memories: - mem_id_list: list[str] = memory_add_obj.add(mem, user_name=add_req.user_id) - mem_ids.extend(mem_id_list) - logger.info( - f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}" - ) - data = {"mem_ids": mem_ids} - return SearchResponse(message="Memory added successfully", data=data) \ No newline at end of file + data = [] + + mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id) + logger.info(f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}") + + for m_id, m in zip(mem_id_list, memories): + data.append({'memory': m.memory, 'mem_ids': m_id, 'memory_type': m.metadata.memory_type}) + return SearchResponse(message="Memory added successfully", data=data) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 1ddc485e0..94dad104e 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -135,7 +135,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str = None): # Add to WorkingMemory working_id = self._add_memory_to_db(memory, "WorkingMemory", user_name) - ids.append(working_id) + # ids.append(working_id) # Add to LongTermMemory and UserMemory if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: From 
81bc1f62a19e2f25b84c77178c6319cbd93471a1 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Sun, 28 Sep 2025 17:54:28 +0800 Subject: [PATCH 15/29] =?UTF-8?q?refactor(server):=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8=E8=B7=AF=E7=94=B1=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E4=BB=A5=E4=BC=98=E5=8C=96=E5=86=85=E5=AD=98=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/api/routers/server_router.py | 89 ++++++++------------------ 1 file changed, 28 insertions(+), 61 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index fbcb0cff5..5b3297a5a 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,90 +1,62 @@ -import os import json +import os import time from fastapi import APIRouter from memos import log from memos.api.product_models import ( - BaseResponse, - ChatCompleteRequest, - ChatRequest, - GetMemoryRequest, MemoryCreateRequest, - MemoryResponse, SearchRequest, - SearchResponse, - SimpleResponse, - SuggestionRequest, - SuggestionResponse, - UserRegisterRequest, - UserRegisterResponse, + SearchResponse, MemoryResponse, ) +from memos.chunkers.sentence_chunker import SentenceChunker +from memos.configs.chunker import SentenceChunkerConfig from memos.configs.embedder import UniversalAPIEmbedderConfig from memos.configs.graph_db import NebulaGraphDBConfig from memos.configs.llm import OpenAILLMConfig from memos.embedders.universal_api import UniversalAPIEmbedder from memos.graph_dbs.nebular import NebulaGraphDB from memos.llms.openai import OpenAILLM -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.cosine_local import CosineLocalReranker -from memos.mem_reader.simple_struct import SimpleStructMemReader -from memos.configs.mem_reader import SimpleStructMemReaderConfig -from memos.configs.chunker import ChunkerConfigFactory -from memos.configs.llm import LLMConfigFactory -from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.chunker import SentenceChunkerConfig -from memos.chunkers.sentence_chunker import SentenceChunker logger = log.get_logger(__name__) router = APIRouter() + def init_model(): llm = OpenAILLM( - OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o', + OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o-mini', temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, api_key=os.getenv('OPENAI_API_KEY'), api_base=os.getenv('OPENAI_API_BASE'), extra_body=None)) embedder = UniversalAPIEmbedder( UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', - model_name_or_path='bge-m3', embedding_dims=None, provider='openai', - api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) + model_name_or_path='bge-m3', embedding_dims=None, provider='openai', + api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') graph_store = NebulaGraphDB( NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', 
uri=json.loads(os.getenv('NEBULAR_HOSTS')), - user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), space=os.getenv('NEBULAR_SPACE'), + user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), + space=os.getenv('NEBULAR_SPACE'), auto_create=True, max_client=1000, embedding_dimension=1024)) search_obj = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False) chunker = SentenceChunker( - SentenceChunkerConfig( - model_schema='memos.configs.chunker.SentenceChunkerConfig', - tokenizer_or_token_counter="gpt2", - chunk_size=512, - chunk_overlap=128, - min_sentences_per_chunk=1, - ) - ) - mem_reader = SimpleStructMemReader( - llm, - embedder, - chunker - ) - memory_add_obj = MemoryManager( - graph_store, - embedder, - llm, - memory_size={ - "WorkingMemory": 20, - "LongTermMemory": 1500, - "UserMemory": 480, - }, - is_reorganize=False - ) + SentenceChunkerConfig(model_schema='memos.configs.chunker.SentenceChunkerConfig', + tokenizer_or_token_counter="gpt2", chunk_size=512, chunk_overlap=128, + min_sentences_per_chunk=1)) + mem_reader = SimpleStructMemReader(llm, embedder, chunker) + memory_add_obj = MemoryManager(graph_store, embedder, llm, + memory_size={"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480}, + is_reorganize=False) return search_obj, memory_add_obj, mem_reader + search_obj, memory_add_obj, mem_reader = init_model() @@ -95,30 +67,25 @@ def search_memories(search_req: SearchRequest): # user_id = f"memos{search_req.user_id.replace('-', '')}" user_id = search_req.user_id res = search_obj.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k - , mode="fast", search_filter=None, - info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) + , mode="fast", search_filter=None, + info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) res = {"d": res} # print(res) return SearchResponse(message="Search completed successfully", data=res) -@router.post("/add", summary="add memories", response_model=SearchResponse) +@router.post("/add", summary="add memories", response_model=MemoryResponse) def add_memories(add_req: MemoryCreateRequest): """Add memories for a specific user.""" time_start = time.time() - memories = mem_reader.get_memory( [add_req.messages], type="chat", - info={"user_id": add_req.user_id, "session_id": add_req.session_id},)[0] - logger.info( - f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s" - ) - data = [] - + info={"user_id": add_req.user_id, "session_id": add_req.session_id})[0] + logger.info(f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s") mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id) logger.info(f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}") - + data = [] for m_id, m in zip(mem_id_list, memories): - data.append({'memory': m.memory, 'mem_ids': m_id, 'memory_type': m.metadata.memory_type}) - return SearchResponse(message="Memory added successfully", data=data) + data.append({'memory': m.memory, 'memory_id': m_id, 'memory_type': m.metadata.memory_type}) + return MemoryResponse(message="Memory added successfully", data=data) From 71f357a150087bbe5bc43d5c44333eb7b7e431ac Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Mon, 29 Sep 2025 10:59:46 +0800 Subject: [PATCH 16/29] format: ruff format code --- src/memos/api/product_models.py | 4 +- 
src/memos/api/routers/product_router.py | 67 ++++-------- src/memos/api/routers/server_router.py | 101 +++++++++++++----- src/memos/graph_dbs/nebular.py | 38 +++++-- src/memos/mem_reader/simple_struct.py | 4 - .../tree_text_memory/organize/manager.py | 17 +-- .../tree_text_memory/retrieve/recall.py | 4 +- .../tree_text_memory/retrieve/searcher.py | 52 ++++++--- 8 files changed, 175 insertions(+), 112 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 3ad47dbeb..ea6dd97cc 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -151,7 +151,9 @@ class MemoryCreateRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID") source: str | None = Field(None, description="Source of the memory") user_profile: bool = Field(False, description="User profile memory") - session_id: str | None = Field(default_factory=lambda: str(uuid.uuid4()), description="Session id") + session_id: str | None = Field( + default_factory=lambda: str(uuid.uuid4()), description="Session id" + ) class SearchRequest(BaseRequest): diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 64c6e3346..75b614cf4 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -213,56 +213,29 @@ def create_memory(memory_req: MemoryCreateRequest): raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -from memos.configs.embedder import UniversalAPIEmbedderConfig -from memos.configs.graph_db import NebulaGraphDBConfig -from memos.configs.llm import OpenAILLMConfig -from memos.embedders.universal_api import UniversalAPIEmbedder -from memos.graph_dbs.nebular import NebulaGraphDB -from memos.llms.openai import OpenAILLM -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher -from memos.reranker.cosine_local import CosineLocalReranker -import os -llm = OpenAILLM( - OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o', - temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, - api_key=os.getenv('OPENAI_API_KEY'), - api_base=os.getenv('OPENAI_API_BASE'), extra_body=None)) -embedder = UniversalAPIEmbedder( - UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', - model_name_or_path='bge-m3', embedding_dims=None, provider='openai', - api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) - -reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') - -graph_store = NebulaGraphDB( - NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', - uri=json.loads(os.getenv('NEBULAR_HOSTS')), - user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), space=os.getenv('NEBULAR_SPACE'), - auto_create=True, max_client=1000, embedding_dimension=1024)) - -s = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False) - - @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: SearchRequest): """Search memories for a specific user.""" - # try: - # user_id = f"memos{search_req.user_id.replace('-', '')}" - user_id = search_req.user_id - res = s.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k - , mode="fast", search_filter=None, - info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) - res = 
{"d": res} - # print(res) - return SearchResponse(message="Search completed successfully", data=res) - - # except ValueError as err: - # raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - # except Exception as err: - # logger.error(f"Failed to search memories: {traceback.format_exc()}") - # raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + try: + time_start_search = time.time() + mos_product = get_mos_product_instance() + result = mos_product.search( + query=search_req.query, + user_id=search_req.user_id, + install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None, + top_k=search_req.top_k, + session_id=search_req.session_id, + ) + logger.info( + f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}" + ) + return SearchResponse(message="Search completed successfully", data=result) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + logger.error(f"Failed to search memories: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err @router.post("/chat", summary="Chat with MemOS") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 5b3297a5a..71a34ff3c 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -6,7 +6,8 @@ from memos.api.product_models import ( MemoryCreateRequest, SearchRequest, - SearchResponse, MemoryResponse, + SearchResponse, + MemoryResponse, ) from memos.chunkers.sentence_chunker import SentenceChunker from memos.configs.chunker import SentenceChunkerConfig @@ -27,32 +28,66 @@ def init_model(): llm = OpenAILLM( - OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o-mini', - temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True, - api_key=os.getenv('OPENAI_API_KEY'), - api_base=os.getenv('OPENAI_API_BASE'), extra_body=None)) + OpenAILLMConfig( + model_schema="memos.configs.llm.OpenAILLMConfig", + model_name_or_path="gpt-4o-mini", + temperature=0.8, + max_tokens=1024, + top_p=0.9, + top_k=50, + remove_think_prefix=True, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=os.getenv("OPENAI_API_BASE"), + extra_body=None, + ) + ) embedder = UniversalAPIEmbedder( - UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig', - model_name_or_path='bge-m3', embedding_dims=None, provider='openai', - api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE'))) + UniversalAPIEmbedderConfig( + model_schema="memos.configs.embedder.UniversalAPIEmbedderConfig", + model_name_or_path="bge-m3", + embedding_dims=None, + provider="openai", + api_key="EMPTY", + base_url=os.getenv("MOS_EMBEDDER_API_BASE"), + ) + ) - reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background') + reranker = CosineLocalReranker( + level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field="background" + ) graph_store = NebulaGraphDB( - NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig', - uri=json.loads(os.getenv('NEBULAR_HOSTS')), - user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), - space=os.getenv('NEBULAR_SPACE'), - auto_create=True, max_client=1000, embedding_dimension=1024)) - search_obj = Searcher(llm, graph_store, embedder, 
reranker, internet_retriever=None, moscube=False) + NebulaGraphDBConfig( + model_schema="memos.configs.graph_db.NebulaGraphDBConfig", + uri=json.loads(os.getenv("NEBULAR_HOSTS")), + user=os.getenv("NEBULAR_USER"), + password=os.getenv("NEBULAR_PASSWORD"), + space=os.getenv("NEBULAR_SPACE"), + auto_create=True, + max_client=1000, + embedding_dimension=1024, + ) + ) + search_obj = Searcher( + llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False + ) chunker = SentenceChunker( - SentenceChunkerConfig(model_schema='memos.configs.chunker.SentenceChunkerConfig', - tokenizer_or_token_counter="gpt2", chunk_size=512, chunk_overlap=128, - min_sentences_per_chunk=1)) + SentenceChunkerConfig( + model_schema="memos.configs.chunker.SentenceChunkerConfig", + tokenizer_or_token_counter="gpt2", + chunk_size=512, + chunk_overlap=128, + min_sentences_per_chunk=1, + ) + ) mem_reader = SimpleStructMemReader(llm, embedder, chunker) - memory_add_obj = MemoryManager(graph_store, embedder, llm, - memory_size={"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480}, - is_reorganize=False) + memory_add_obj = MemoryManager( + graph_store, + embedder, + llm, + memory_size={"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480}, + is_reorganize=False, + ) return search_obj, memory_add_obj, mem_reader @@ -66,9 +101,14 @@ def search_memories(search_req: SearchRequest): # try: # user_id = f"memos{search_req.user_id.replace('-', '')}" user_id = search_req.user_id - res = search_obj.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k - , mode="fast", search_filter=None, - info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []}) + res = search_obj.search( + query=search_req.query, + user_id=user_id, + top_k=search_req.top_k, + mode="fast", + search_filter=None, + info={"user_id": user_id, "session_id": "root_session", "chat_history": []}, + ) res = {"d": res} # print(res) return SearchResponse(message="Search completed successfully", data=res) @@ -81,11 +121,16 @@ def add_memories(add_req: MemoryCreateRequest): memories = mem_reader.get_memory( [add_req.messages], type="chat", - info={"user_id": add_req.user_id, "session_id": add_req.session_id})[0] - logger.info(f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s") + info={"user_id": add_req.user_id, "session_id": add_req.session_id}, + )[0] + logger.info( + f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s" + ) mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id) - logger.info(f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}") + logger.info( + f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}" + ) data = [] for m_id, m in zip(mem_id_list, memories): - data.append({'memory': m.memory, 'memory_id': m_id, 'memory_type': m.metadata.memory_type}) + data.append({"memory": m.memory, "memory_id": m_id, "memory_type": m.metadata.memory_type}) return MemoryResponse(message="Memory added successfully", data=data) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 952260df4..43a7538fa 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -415,7 +415,9 @@ def create_index( self._create_basic_property_indexes() @timed - def remove_oldest_memory(self, memory_type: str, keep_latest: int, user_name: str = None) -> None: + def remove_oldest_memory( + 
self, memory_type: str, keep_latest: int, user_name: str = None + ) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. """ @@ -432,7 +434,9 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int, user_name: st self.execute_query(query) @timed - def add_node(self, id: str, memory: str, metadata: dict[str, Any], user_name: str = None) -> None: + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str = None + ) -> None: """ Insert or update a Memory node in NebulaGraph. """ @@ -589,7 +593,12 @@ def count_nodes(self, scope: str, user_name: str = None) -> int: @timed def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING", user_name: str = None + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str = None, ) -> bool: """ Check if an edge exists between two nodes. @@ -630,7 +639,9 @@ def edge_exists( @timed # Graph Query & Reasoning - def get_node(self, id: str, include_embedding: bool = False, user_name: str = None) -> dict[str, Any] | None: + def get_node( + self, id: str, include_embedding: bool = False, user_name: str = None + ) -> dict[str, Any] | None: """ Retrieve a Memory node by its unique ID. @@ -706,7 +717,9 @@ def get_nodes( return nodes @timed - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str = None) -> list[dict[str, str]]: + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str = None + ) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -849,7 +862,11 @@ def get_children_with_embeddings(self, id: str, user_name: str = None) -> list[d @timed def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated", user_name: str = None + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str = None, ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -1153,7 +1170,9 @@ def clear(self, user_name: str = None) -> None: logger.error(f"[ERROR] Failed to clear database: {e}") @timed - def export_graph(self, include_embedding: bool = False, user_name: str = None) -> dict[str, Any]: + def export_graph( + self, include_embedding: bool = False, user_name: str = None + ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: @@ -1265,7 +1284,9 @@ def import_graph(self, data: dict[str, Any], user_name: str = None) -> None: logger.error(f"Fail to load edge: {edge}, error: {e}") @timed - def get_all_memory_items(self, scope: str, include_embedding: bool = False, user_name: str = None) -> (list)[dict]: + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, user_name: str = None + ) -> (list)[dict]: """ Retrieve all memory items of a specific memory_type. 
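A sketch of the per-tenant call shapes after this reformat (the IDs, memory
text, and user name are invented; `config` is a NebulaGraphDBConfig as built
in server_router.init_model):

    db = NebulaGraphDB(config)
    metadata = {"memory_type": "WorkingMemory"}  # minimal example metadata
    db.add_node("node-1", "prefers tea", metadata, user_name="alice")
    if db.edge_exists("node-1", "node-2", type="ANY",
                      direction="OUTGOING", user_name="alice"):
        edges = db.get_edges("node-1", type="ANY", direction="ANY",
                             user_name="alice")
    items = db.get_all_memory_items("WorkingMemory",
                                    include_embedding=False,
                                    user_name="alice")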
@@ -1344,7 +1365,6 @@ def get_structure_optimization_candidates( logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates - @timed def detect_conflicts(self) -> list[tuple[str, str]]: """ diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 6bf399650..a4d38ccca 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -122,10 +122,6 @@ def __init__(self, llm, embedder, chunker): Args: config: Configuration object for the reader """ - # self.config = config - # self.llm = LLMFactory.from_config(config.llm) - # self.embedder = EmbedderFactory.from_config(config.embedder) - # self.chunker = ChunkerFactory.from_config(config.chunker) self.llm = llm self.embedder = embedder self.chunker = chunker diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 94dad104e..2048daac5 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -140,15 +140,15 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str = None): # Add to LongTermMemory and UserMemory if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: added_id = self._add_to_graph_memory( - memory=memory, - memory_type=memory.metadata.memory_type, - user_name=user_name + memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name ) ids.append(added_id) return ids - def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str, user_name: str = None) -> str: + def _add_memory_to_db( + self, memory: TextualMemoryItem, memory_type: str, user_name: str = None + ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. """ @@ -162,7 +162,9 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str, user_na self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) return working_memory.id - def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str, user_name: str = None): + def _add_to_graph_memory( + self, memory: TextualMemoryItem, memory_type: str, user_name: str = None + ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). 
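Caller-side, the manager change reads as follows (a sketch; `manager` and
`item` are constructed as elsewhere in this series):

    # Only LongTermMemory/UserMemory items get a graph node here; the
    # WorkingMemory copy is written separately by _add_memory_to_db.
    ids = manager.add([item], user_name="alice")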
@@ -176,7 +178,10 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str, user node_id = str(uuid.uuid4()) # Step 2: Add new node to graph self.graph_store.add_node( - node_id, memory.memory, memory.metadata.model_dump(exclude_none=True), user_name=user_name + node_id, + memory.memory, + memory.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 6e1af2f03..7c75fe3f3 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -167,7 +167,9 @@ def _graph_recall( return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False, user_name=user_id) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_id + ) final_nodes = [] for node in node_dicts: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 7071e76fa..08c0772e5 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -74,7 +74,9 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ - logger.info(f"[SEARCH] Start query='{query}', user_id='{user_id}', top_k={top_k}, mode={mode}, memory_type={memory_type}") + logger.info( + f"[SEARCH] Start query='{query}', user_id='{user_id}', top_k={top_k}, mode={mode}, memory_type={memory_type}" + ) if not info: logger.warning( "Please input 'info' when use tree.search so that " @@ -180,14 +182,28 @@ def parallel_retrieve( tasks.append( executor.submit( self._retrieve_memory, - query, info, parsed_goal, - query_embedding, top_k, - m_type, search_filter, )) + query, + info, + parsed_goal, + query_embedding, + top_k, + m_type, + search_filter, + ) + ) if self.internet_retriever and mode == "fine": tasks.append( executor.submit( self._retrieve_from_internet, - query, parsed_goal, query_embedding, top_k, info, mode, memory_type, )) + query, + parsed_goal, + query_embedding, + top_k, + info, + mode, + memory_type, + ) + ) if self.moscube: tasks.append( executor.submit( @@ -207,22 +223,22 @@ def parallel_retrieve( logger.info(f"[SEARCH] Total raw results: {len(results)}") return results - @timed def _retrieve_memory( - self, - query, - info, - parsed_goal, - query_embedding, - top_k, - memory_type, - search_filter: dict | None = None): + self, + query, + info, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter: dict | None = None, + ): if memory_type in ["LongTermMemory", "UserMemory"]: top_k = 2 * top_k items = self.graph_retriever.retrieve( query=query, - user_id=info['user_id'], + user_id=info["user_id"], parsed_goal=parsed_goal, top_k=top_k, memory_scope=memory_type, @@ -329,6 +345,10 @@ def _update_usage_history(self, items, info): def _update_usage_history_worker(self, payload, usage_record: str): try: for item_id, usage_list in payload: - self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=json.loads(usage_list[0])['info']['user_id']) + self.graph_store.update_node( + item_id, + {"usage": usage_list}, + user_name=json.loads(usage_list[0])["info"]["user_id"], + ) except Exception: logger.exception("[USAGE] update usage 
failed") From c04ed797c9bdebd469aa528ca472cfb21dc30d58 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Mon, 29 Sep 2025 11:28:56 +0800 Subject: [PATCH 17/29] =?UTF-8?q?feat(server):=20=E5=A2=9E=E5=8A=A0LLM?= =?UTF-8?q?=E6=9C=80=E5=A4=A7=E4=BB=A4=E7=89=8C=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/api/routers/server_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 71a34ff3c..8cbfd9716 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -32,7 +32,7 @@ def init_model(): model_schema="memos.configs.llm.OpenAILLMConfig", model_name_or_path="gpt-4o-mini", temperature=0.8, - max_tokens=1024, + max_tokens=4096, top_p=0.9, top_k=50, remove_think_prefix=True, From 880f60c4d5f1995b45d6fb997f10ced4faf58d7c Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Mon, 29 Sep 2025 15:11:51 +0800 Subject: [PATCH 18/29] fix: user query embedding for search --- .../memories/textual/tree_text_memory/retrieve/searcher.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 08c0772e5..6fec648e9 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -236,14 +236,19 @@ def _retrieve_memory( ): if memory_type in ["LongTermMemory", "UserMemory"]: top_k = 2 * top_k + if memory_type == "WorkingMemory": + query_embedding = None items = self.graph_retriever.retrieve( query=query, user_id=info["user_id"], parsed_goal=parsed_goal, top_k=top_k, memory_scope=memory_type, + query_embedding=query_embedding, search_filter=search_filter, ) + if memory_type == "WorkingMemory": + query_embedding = [None] return self.reranker.rerank( query=query, query_embedding=query_embedding[0], From d01c8cf96b3b02866f8a3c1a1c8e577eb77c11d4 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Mon, 29 Sep 2025 19:32:56 +0800 Subject: [PATCH 19/29] hotfix:noe4j community dataformat (#353) --- src/memos/graph_dbs/neo4j.py | 4 ++++ src/memos/graph_dbs/neo4j_community.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 96908913d..ccc91c48b 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -38,6 +38,10 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: if embedding and isinstance(embedding, list): metadata["embedding"] = [float(x) for x in embedding] + # serialization + if metadata["sources"]: + for idx in range(len(metadata["sources"])): + metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) return metadata diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 8acab420c..54000a51d 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1,3 +1,4 @@ +import json from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig @@ -49,6 +50,10 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: # Safely process metadata metadata = _prepare_node_metadata(metadata) + # serialization + if metadata["sources"]: + for idx in 
range(len(metadata["sources"])): + metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) # Extract required fields embedding = metadata.pop("embedding", None) if embedding is None: @@ -298,7 +303,16 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: if time_field in node and hasattr(node[time_field], "isoformat"): node[time_field] = node[time_field].isoformat() node.pop("user_name", None) - + # serialization + if node["sources"]: + for idx in range(len(node["sources"])): + if not ( + isinstance(node["sources"][idx], str) + and node["sources"][idx][0] == "{" + and node["sources"][idx][0] == "}" + ): + break + node["sources"][idx] = json.loads(node["sources"][idx]) new_node = {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node} try: vec_item = self.vec_db.get_by_id(new_node["id"]) From cce9f6cbab430bec5f6dd8bef15a0ffa09e3367f Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Mon, 29 Sep 2025 19:46:35 +0800 Subject: [PATCH 20/29] count memory_size by user --- .../tree_text_memory/organize/manager.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 2048daac5..d856fa0a8 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -66,28 +66,15 @@ def add(self, memories: list[TextualMemoryItem], user_name: str = None) -> list[ except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] - ) - except Exception: - logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"] - ) - except Exception: - logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"] - ) - except Exception: - logger.warning(f"Remove UserMemory error: {traceback.format_exc()}") + for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + try: + self.graph_store.remove_oldest_memory( + memory_type="WorkingMemory", keep_latest=self.memory_size[mem_type], user_name=user_name + ) + except Exception: + logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}") - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return added_ids def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: @@ -118,11 +105,11 @@ def get_current_memory_size(self) -> dict[str, int]: self._refresh_memory_size() return self.current_memory_size - def _refresh_memory_size(self) -> None: + def _refresh_memory_size(self, user_name: str = None) -> None: """ Query the latest counts from the graph store and update internal state. 
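        For illustration, get_grouped_counts(group_fields=["memory_type"],
        user_name="alice") returns records shaped like
        [{"memory_type": "WorkingMemory", "count": 20}, ...], which the
        comprehension below folds into current_memory_size.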
""" - results = self.graph_store.get_grouped_counts(group_fields=["memory_type"]) + results = self.graph_store.get_grouped_counts(group_fields=["memory_type"], user_name=user_name) self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") From 9b465896d70d019bb48789487ab45cf17018e841 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Mon, 29 Sep 2025 20:27:37 +0800 Subject: [PATCH 21/29] =?UTF-8?q?fix(server):=E4=BF=AE=E5=A4=8D=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E8=AF=BB=E5=8F=96=E9=80=BB=E8=BE=91=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E5=88=97=E8=A1=A8=E5=B1=95=E5=BC=80=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/api/routers/server_router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 8cbfd9716..36ecd82f6 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -122,7 +122,8 @@ def add_memories(add_req: MemoryCreateRequest): [add_req.messages], type="chat", info={"user_id": add_req.user_id, "session_id": add_req.session_id}, - )[0] + ) + memories = [mm for m in memories for mm in m] logger.info( f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s" ) From 2da62c89f1414e5e131a5d901fd25a25ac487552 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Tue, 30 Sep 2025 14:26:24 +0800 Subject: [PATCH 22/29] milvus implement (#354) * milvus implement * milvus implement * milvus implement --------- Co-authored-by: yuan.wang --- src/memos/configs/vec_db.py | 13 ++ src/memos/vec_dbs/milvus.py | 365 ++++++++++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 src/memos/vec_dbs/milvus.py diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index b43298d9b..dd1748714 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -39,6 +39,18 @@ def set_default_path(self): return self +class MilvusVecDBConfig(BaseVecDBConfig): + """Configuration for Milvus vector database.""" + + uri: str = Field(..., description="URI for Milvus connection") + collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") + max_length: int = Field( + default=65535, description="Maximum length for string fields (varChar type)" + ) + user_name: str = Field(default="", description="User name for Milvus connection") + password: str = Field(default="", description="Password for Milvus connection") + + class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" @@ -47,6 +59,7 @@ class VectorDBConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDBConfig, + "milvus": MilvusVecDBConfig, } @field_validator("backend") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py new file mode 100644 index 000000000..fca6a59c2 --- /dev/null +++ b/src/memos/vec_dbs/milvus.py @@ -0,0 +1,365 @@ +from typing import Any + +from memos.configs.vec_db import MilvusVecDBConfig +from memos.dependency import require_python_package +from memos.log import get_logger +from memos.vec_dbs.base import BaseVecDB +from memos.vec_dbs.item import VecDBItem + + +logger = get_logger(__name__) + + +class MilvusVecDB(BaseVecDB): + """Milvus vector 
database implementation.""" + + @require_python_package( + import_name="pymilvus", + install_command="pip install -U pymilvus", + install_link="https://milvus.io/docs/install-pymilvus.md", + ) + def __init__(self, config: MilvusVecDBConfig): + """Initialize the Milvus vector database and the collection.""" + from pymilvus import MilvusClient + self.config = config + + # Create Milvus client + self.client = MilvusClient( + uri=self.config.uri, user=self.config.user_name, password=self.config.password + ) + self.schema = self.create_schema() + self.index_params = self.create_index() + self.create_collection() + + def create_schema(self): + """Create schema for the milvus collection.""" + from pymilvus import DataType + schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) + schema.add_field( + field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True + ) + schema.add_field( + field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension + ) + schema.add_field(field_name="payload", datatype=DataType.JSON) + + return schema + + def create_index(self): + """Create index for the milvus collection.""" + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", index_type="FLAT", metric_type=self._get_metric_type() + ) + + return index_params + + def create_collection(self) -> None: + """Create a new collection with specified parameters.""" + for collection_name in self.config.collection_name: + if self.collection_exists(collection_name): + logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.") + continue + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + schema=self.schema, + index_params=self.index_params, + ) + + logger.info( + f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions." + ) + + def create_collection_by_name(self, collection_name: str) -> None: + """Create a new collection with specified parameters.""" + if self.collection_exists(collection_name): + logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.") + return + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + schema=self.schema, + index_params=self.index_params, + ) + + def list_collections(self) -> list[str]: + """List all collections.""" + return self.client.list_collections() + + def delete_collection(self, name: str) -> None: + """Delete a collection.""" + self.client.drop_collection(name) + + def collection_exists(self, name: str) -> bool: + """Check if a collection exists.""" + return self.client.has_collection(collection_name=name) + + def search( + self, + query_vector: list[float], + collection_name: str, + top_k: int, + filter: dict[str, Any] | None = None, + ) -> list[VecDBItem]: + """ + Search for similar items in the database. + + Args: + query_vector: Single vector to search + collection_name: Name of the collection to search + top_k: Number of results to return + filter: Payload filters + + Returns: + List of search results with distance scores and payloads. 
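+
+        Example (illustrative sketch; assumes `db` is a MilvusVecDB instance
+        and that a collection named "memories" exists with this config's
+        vector dimension):
+
+            hits = db.search(
+                query_vector=[0.1] * db.config.vector_dimension,
+                collection_name="memories",
+                top_k=3,
+                filter={"user_id": "alice"},
+            )
+            # each hit is a VecDBItem carrying id, vector, payload and score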
+ """ + # Convert filter to Milvus expression + expr = self._dict_to_expr(filter) if filter else "" + + results = self.client.search( + collection_name=collection_name, + data=[query_vector], + limit=top_k, + filter=expr, + output_fields=["*"], # Return all fields + ) + + items = [] + for hit in results[0]: + entity = hit.get("entity", {}) + + items.append( + VecDBItem( + id=str(hit["id"]), + vector=entity.get("vector"), + payload=entity.get("payload", {}), + score=1 - float(hit["distance"]), + ) + ) + + logger.info(f"Milvus search completed with {len(items)} results.") + return items + + def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: + """Convert a dictionary filter to a Milvus expression string.""" + if not filter_dict: + return "" + + conditions = [] + for field, value in filter_dict.items(): + # Skip None values as they cause Milvus query syntax errors + if value is None: + continue + # For JSON fields, we need to use payload["field"] syntax + elif isinstance(value, str): + conditions.append(f"payload['{field}'] == '{value}'") + elif isinstance(value, list) and len(value) == 0: + # Skip empty lists as they cause Milvus query syntax errors + continue + elif isinstance(value, list) and len(value) > 0: + conditions.append(f"payload['{field}'] in {value}") + else: + conditions.append(f"payload['{field}'] == '{value}'") + return " and ".join(conditions) + + def _get_metric_type(self) -> str: + """Get the metric type for search.""" + metric_map = { + "cosine": "COSINE", + "euclidean": "L2", + "dot": "IP", + } + return metric_map.get(self.config.distance_metric, "L2") + + def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + """Get a single item by ID.""" + results = self.client.get( + collection_name=collection_name, + ids=[id], + ) + + if not results: + return None + + entity = results[0] + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} + + return VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + """Get multiple items by their IDs.""" + results = self.client.get( + collection_name=collection_name, + ids=ids, + ) + + if not results: + return [] + + items = [] + for entity in results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} + items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) + + return items + + def get_by_filter( + self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 + ) -> list[VecDBItem]: + """ + Retrieve all items that match the given filter criteria using query_iterator. 
+ + Args: + filter: Payload filters to match against stored items + scroll_limit: Maximum number of items to retrieve per batch (batch_size) + + Returns: + List of items including vectors and payload that match the filter + """ + expr = self._dict_to_expr(filter) if filter else "" + all_items = [] + + # Use query_iterator for efficient pagination + iterator = self.client.query_iterator( + collection_name=collection_name, + filter=expr, + batch_size=scroll_limit, + output_fields=["*"], # Include all fields including payload + ) + + # Iterate through all batches + try: + while True: + batch_results = iterator.next() + + if not batch_results: + break + + # Convert batch results to VecDBItem objects + for entity in batch_results: + # Extract the actual payload from Milvus entity + payload = entity.get("payload", {}) + all_items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) + except Exception as e: + logger.warning( + f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." + ) + finally: + # Close the iterator + iterator.close() + + logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") + return all_items + + def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + """Retrieve all items in the vector database.""" + return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) + + def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> int: + """Count items in the database, optionally with filter.""" + if filter: + # If there's a filter, use query method + expr = self._dict_to_expr(filter) if filter else "" + results = self.client.query( + collection_name=collection_name, + filter=expr, + output_fields=["id"], + ) + return len(results) + else: + # For counting all items, use get_collection_stats for accurate count + stats = self.client.get_collection_stats(collection_name) + # Extract row count from stats - stats is a dict, not a list + return int(stats.get("row_count", 0)) + + def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add data to the vector database. + + Args: + data: List of VecDBItem objects or dictionaries containing: + - 'id': unique identifier + - 'vector': embedding vector + - 'payload': additional fields for filtering/retrieval + """ + entities = [] + for item in data: + if isinstance(item, dict): + item = item.copy() + item = VecDBItem.from_dict(item) + + # Prepare entity data + entity = { + "id": item.id, + "vector": item.vector, + "payload": item.payload if item.payload else {}, + } + + entities.append(entity) + + # Use upsert to be safe (insert or update) + self.client.upsert( + collection_name=collection_name, + data=entities, + ) + + def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None: + """Update an item in the vector database.""" + if isinstance(data, dict): + data = data.copy() + data = VecDBItem.from_dict(data) + + # Use upsert for updates + self.upsert(collection_name, [data]) + + def ensure_payload_indexes(self, fields: list[str]) -> None: + """ + Create payload indexes for specified fields in the collection. + This is idempotent: it will skip if index already exists. + + Args: + fields (list[str]): List of field names to index (as keyword). 
+ """ + # Note: Milvus doesn't have the same concept of payload indexes as Qdrant + # Field indexes are created automatically for scalar fields + logger.info(f"Milvus automatically indexes scalar fields: {fields}") + + def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add or update data in the vector database. + + If an item with the same ID exists, it will be updated. + Otherwise, it will be added as a new item. + """ + # Reuse add method since it already uses upsert + self.add(collection_name, data) + + def delete(self, collection_name: str, ids: list[str]) -> None: + """Delete items from the vector database.""" + if not ids: + return + self.client.delete( + collection_name=collection_name, + ids=ids, + ) From 15cdbac864b61b90735f984c54d9bbcbbb10bb40 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Tue, 30 Sep 2025 14:38:57 +0800 Subject: [PATCH 23/29] fix: code ruff format (#355) --- src/memos/vec_dbs/milvus.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index fca6a59c2..7bb1ceeba 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -21,6 +21,7 @@ class MilvusVecDB(BaseVecDB): def __init__(self, config: MilvusVecDBConfig): """Initialize the Milvus vector database and the collection.""" from pymilvus import MilvusClient + self.config = config # Create Milvus client @@ -34,6 +35,7 @@ def __init__(self, config: MilvusVecDBConfig): def create_schema(self): """Create schema for the milvus collection.""" from pymilvus import DataType + schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True From 9a45f6052ce1892128dd20a249b55eec35f09197 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Tue, 14 Oct 2025 11:55:04 +0800 Subject: [PATCH 24/29] =?UTF-8?q?feat(nebular):=E4=BC=98=E5=8C=96=E5=9B=BE?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=9F=A5=E8=AF=A2=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/nebular.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 43a7538fa..21455a182 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -700,7 +700,7 @@ def get_nodes( return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.id IN [{id_list}] {where_user} RETURN {return_fields} """ @@ -972,8 +972,7 @@ def search_by_embedding( dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: @@ -991,15 +990,12 @@ def search_by_embedding( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE - LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + {where_clause} + 
+        ORDER BY inner_product(n.{self.dim_field}, a) DESC
+        LIMIT {top_k}
+        RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
         try:
             result = self.execute_query(gql)
         except Exception as e:
@@ -1074,7 +1070,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]], user_name: str = None)
             where_clauses.append(f'n.user_name = "{user_name}"')
 
         where_str = " AND ".join(where_clauses)
-        gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id"
+        gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
         ids = []
         try:
             result = self.execute_query(gql)
@@ -1303,11 +1299,12 @@ def get_all_memory_items(
 
         where_clause = f"WHERE n.memory_type = '{scope}'"
         where_clause += f" AND n.user_name = '{user_name}'"
+        # where_clause = f"WHERE n.user_name = '{user_name}'"
 
         return_fields = self._build_return_fields(include_embedding)
 
         query = f"""
-            MATCH (n@Memory)
+            MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
             {where_clause}
             RETURN {return_fields}
             LIMIT 100

From 49369074d94c6442cf87bdc858af28ac4d056882 Mon Sep 17 00:00:00 2001
From: Hao <120852460@qq.com>
Date: Wed, 15 Oct 2025 16:12:05 +0800
Subject: [PATCH 25/29] feat: use the idx_memory_user_name hint in
 remove_oldest_memory

---
 src/memos/graph_dbs/nebular.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index 21455a182..7fcf817ba 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -424,7 +424,7 @@ def remove_oldest_memory(
             optional_condition = f"AND n.user_name = '{user_name}'"
 
         query = f"""
-        MATCH (n@Memory)
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
         WHERE n.memory_type = '{memory_type}' {optional_condition}
         ORDER BY n.updated_at DESC

From d51885b49f68bf25a15df7ab968d440a6f3c6a4e Mon Sep 17 00:00:00 2001
From: Hao <120852460@qq.com>
Date: Wed, 15 Oct 2025 16:32:19 +0800
Subject: [PATCH 26/29] feat(graph): optimize Nebula graph database query
 performance
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/memos/graph_dbs/nebular.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index 7fcf817ba..b40706769 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -1131,7 +1131,7 @@ def get_grouped_counts(
             group_by_fields.append(alias)
 
         # Full GQL query construction
         gql = f"""
-        MATCH (n)
+        MATCH (n /*+ INDEX(idx_memory_user_name) */)
         {where_clause}
         RETURN {", ".join(return_fields)}, COUNT(n) AS count
         GROUP BY {", ".join(group_by_fields)}

From a2715f56a27cee8b29ce2833c15a74927c2f53dc Mon Sep 17 00:00:00 2001
From: chunyu li <78344051+fridayL@users.noreply.github.com>
Date: Wed, 15 Oct 2025 18:49:21 +0800
Subject: [PATCH 27/29] feat: add server api prd (#362)

* feat: add server api prd

* feat: update memcube for api

* feat: add run server api md and change user_id to user_id

* fix: code format

* fix:code

* fix: fix code format

* feat: remove ids

* fix: working ids

---
 examples/mem_api/pipeline_test.py             | 178 ++++++++++
 src/memos/api/product_models.py               |  35 +-
 src/memos/api/routers/server_router.py        | 282 ++++++++++++++++
 src/memos/api/server_api.py                   |  38 +++
 src/memos/configs/mem_user.py                 |  12 +
 src/memos/configs/memory.py                   |   5 +
 src/memos/graph_dbs/nebular.py                | 315 +++++++++---------
 src/memos/mem_cube/navie.py                   | 166 +++++++++
 src/memos/mem_user/persistent_factory.py      |   2 +
 .../mem_user/redis_persistent_user_manager.py | 225 ++++++++++++++
 src/memos/memories/factory.py                 |   2 +
 src/memos/memories/textual/base.py            |   2 +-
 src/memos/memories/textual/simple_tree.py     | 295 ++++++++++++++++
 .../tree_text_memory/organize/manager.py      |  84 ++---
 .../tree_text_memory/retrieve/recall.py       |  23 +-
 .../tree_text_memory/retrieve/searcher.py     |  62 +++-
 src/memos/types.py                            |  22 ++
 tests/memories/textual/test_tree_searcher.py  |   2 +-
 18 files changed, 1523 insertions(+), 227 deletions(-)
 create mode 100644 examples/mem_api/pipeline_test.py
 create mode 100644 src/memos/api/routers/server_router.py
 create mode 100644 src/memos/api/server_api.py
 create mode 100644 src/memos/mem_cube/navie.py
 create mode 100644 src/memos/mem_user/redis_persistent_user_manager.py
 create mode 100644 src/memos/memories/textual/simple_tree.py

diff --git a/examples/mem_api/pipeline_test.py b/examples/mem_api/pipeline_test.py
new file mode 100644
index 000000000..cd7b3bee3
--- /dev/null
+++ b/examples/mem_api/pipeline_test.py
@@ -0,0 +1,178 @@
+"""
+Pipeline test script for MemOS Server API functions.
+This script directly tests add and search functionalities without going through the API layer.
+To start the server API itself, place your .env file at MemOS/.env and run:
+uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8002 --workers 4
+"""
+
+from typing import Any
+
+from dotenv import load_dotenv
+
+# Import directly from server_router to reuse initialized components
+from memos.api.routers.server_router import (
+    _create_naive_mem_cube,
+    mem_reader,
+)
+from memos.log import get_logger
+
+
+# Load environment variables
+load_dotenv()
+
+logger = get_logger(__name__)
+
+
+def test_add_memories(
+    messages: list[dict[str, str]],
+    user_id: str,
+    mem_cube_id: str,
+    session_id: str = "default_session",
+) -> list[str]:
+    """
+    Test adding memories to the system.
+
+    Args:
+        messages: List of message dictionaries with 'role' and 'content'
+        user_id: User identifier
+        mem_cube_id: Memory cube identifier
+        session_id: Session identifier
+
+    Returns:
+        List of memory IDs that were added
+    """
+    logger.info(f"Testing add memories for user: {user_id}, mem_cube: {mem_cube_id}")
+
+    # Create NaiveMemCube using server_router function
+    naive_mem_cube = _create_naive_mem_cube()
+
+    # Extract memories from messages using server_router's mem_reader
+    memories = mem_reader.get_memory(
+        [messages],
+        type="chat",
+        info={
+            "user_id": user_id,
+            "session_id": session_id,
+        },
+    )
+
+    # Flatten memory list
+    flattened_memories = [mm for m in memories for mm in m]
+
+    # Add memories to the system
+    mem_id_list: list[str] = naive_mem_cube.text_mem.add(
+        flattened_memories,
+        user_name=mem_cube_id,
+    )
+
+    logger.info(f"Added {len(mem_id_list)} memories: {mem_id_list}")
+
+    # Print details of added memories
+    for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False):
+        logger.info(f"  - ID: {memory_id}")
+        logger.info(f"    Memory: {memory.memory}")
+        logger.info(f"    Type: {memory.metadata.memory_type}")
+
+    return mem_id_list
+
+
+def test_search_memories(
+    query: str,
+    user_id: str,
+    mem_cube_id: str,
+    session_id: str = "default_session",
+    top_k: int = 5,
+    mode: str = "fast",
+    internet_search: bool = False,
+    moscube: bool = False,
+    chat_history: list | None = None,
+) -> list[Any]:
+    """
+    Test searching memories from the system.
+ + Args: + query: Search query text + user_id: User identifier + mem_cube_id: Memory cube identifier + session_id: Session identifier + top_k: Number of top results to return + mode: Search mode + internet_search: Whether to enable internet search + moscube: Whether to enable moscube search + chat_history: Chat history for context + + Returns: + List of search results + """ + + # Create NaiveMemCube using server_router function + naive_mem_cube = _create_naive_mem_cube() + + # Prepare search filter + search_filter = {"session_id": session_id} if session_id != "default_session" else None + + search_results = naive_mem_cube.text_mem.search( + query=query, + user_name=mem_cube_id, + top_k=top_k, + mode=mode, + manual_close_internet=not internet_search, + moscube=moscube, + search_filter=search_filter, + info={ + "user_id": user_id, + "session_id": session_id, + "chat_history": chat_history or [], + }, + ) + + # Print search results + for idx, result in enumerate(search_results, 1): + logger.info(f"\n Result {idx}:") + logger.info(f" ID: {result.id}") + logger.info(f" Memory: {result.memory}") + logger.info(f" Score: {getattr(result, 'score', 'N/A')}") + logger.info(f" Type: {result.metadata.memory_type}") + + return search_results + + +def main(): + # Test parameters + user_id = "test_user_123" + mem_cube_id = "test_cube_123" + session_id = "test_session_001" + + test_messages = [ + {"role": "user", "content": "Where should I go for Christmas?"}, + { + "role": "assistant", + "content": "There are many places to visit during Christmas, such as the Bund and Disneyland in Shanghai.", + }, + {"role": "user", "content": "What about New Year's Eve?"}, + { + "role": "assistant", + "content": "For New Year's Eve, you could visit Times Square in New York or watch fireworks at the Sydney Opera House.", + }, + ] + + memory_ids = test_add_memories( + messages=test_messages, user_id=user_id, mem_cube_id=mem_cube_id, session_id=session_id + ) + + logger.info(f"\nSuccessfully added {len(memory_ids)} memories!") + + search_queries = [ + "How to enjoy Christmas?", + "Where to celebrate New Year?", + "What are good places to visit during holidays?", + ] + + for query in search_queries: + logger.info("\n" + "-" * 80) + results = test_search_memories(query=query, user_id=user_id, mem_cube_id=mem_cube_id) + print(f"Query: '{query}' returned {len(results)} results") + + +if __name__ == "__main__": + main() diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7e425415b..eb2d7aa6d 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field # Import message types from core types module -from memos.types import MessageDict +from memos.types import MessageDict, PermissionDict T = TypeVar("T") @@ -164,6 +164,39 @@ class SearchRequest(BaseRequest): session_id: str | None = Field(None, description="Session ID for soft-filtering memories") +class APISearchRequest(BaseRequest): + """Request model for searching memories.""" + + query: str = Field(..., description="Search query") + user_id: str = Field(None, description="User ID") + mem_cube_id: str | None = Field(None, description="Cube ID to search in") + mode: str = Field("fast", description="search mode fast or fine") + internet_search: bool = Field(False, description="Whether to use internet search") + moscube: bool = Field(False, description="Whether to use MemOSCube") + top_k: int = Field(10, description="Number of results to return") + chat_history: 
list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + +class APIADDRequest(BaseRequest): + """Request model for creating memories.""" + + user_id: str = Field(None, description="User ID") + mem_cube_id: str = Field(..., description="Cube ID") + messages: list[MessageDict] | None = Field(None, description="List of messages to store.") + memory_content: str | None = Field(None, description="Memory content to store") + doc_path: str | None = Field(None, description="Path to document to store") + source: str | None = Field(None, description="Source of the memory") + chat_history: list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session id") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py new file mode 100644 index 000000000..1d398ff72 --- /dev/null +++ b/src/memos/api/routers/server_router.py @@ -0,0 +1,282 @@ +import os + +from typing import Any + +from fastapi import APIRouter + +from memos.api.config import APIConfig +from memos.api.product_models import ( + APIADDRequest, + APISearchRequest, + MemoryResponse, + SearchResponse, +) +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory +from memos.types import MOSSearchResult, UserContext + + +logger = get_logger(__name__) + +router = APIRouter(prefix="/product", tags=["Server API"]) + + +def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + """Build graph database configuration.""" + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def _build_llm_config() -> dict[str, Any]: + """Build LLM configuration.""" + return LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + +def _build_embedder_config() -> dict[str, Any]: + """Build embedder configuration.""" + return 
EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
+
+
+def _build_mem_reader_config() -> dict[str, Any]:
+    """Build memory reader configuration."""
+    return MemReaderConfigFactory.model_validate(
+        APIConfig.get_product_default_config()["mem_reader"]
+    )
+
+
+def _build_reranker_config() -> dict[str, Any]:
+    """Build reranker configuration."""
+    return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
+
+
+def _build_internet_retriever_config() -> dict[str, Any]:
+    """Build internet retriever configuration."""
+    return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
+
+
+def _get_default_memory_size(cube_config) -> dict[str, int]:
+    """Get default memory size configuration."""
+    return getattr(cube_config.text_mem.config, "memory_size", None) or {
+        "WorkingMemory": 20,
+        "LongTermMemory": 1500,
+        "UserMemory": 480,
+    }
+
+
+def init_server():
+    """Initialize server components and configurations."""
+    # Get default cube configuration
+    default_cube_config = APIConfig.get_default_cube_config()
+
+    # Build component configurations
+    graph_db_config = _build_graph_db_config()
+    logger.info(f"Graph DB config: {graph_db_config}")
+    llm_config = _build_llm_config()
+    embedder_config = _build_embedder_config()
+    mem_reader_config = _build_mem_reader_config()
+    reranker_config = _build_reranker_config()
+    internet_retriever_config = _build_internet_retriever_config()
+
+    # Create component instances
+    graph_db = GraphStoreFactory.from_config(graph_db_config)
+    llm = LLMFactory.from_config(llm_config)
+    embedder = EmbedderFactory.from_config(embedder_config)
+    mem_reader = MemReaderFactory.from_config(mem_reader_config)
+    reranker = RerankerFactory.from_config(reranker_config)
+    internet_retriever = InternetRetrieverFactory.from_config(
+        internet_retriever_config, embedder=embedder
+    )
+
+    # Initialize memory manager
+    memory_manager = MemoryManager(
+        graph_db,
+        embedder,
+        llm,
+        memory_size=_get_default_memory_size(default_cube_config),
+        is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
+    )
+
+    return (
+        graph_db,
+        mem_reader,
+        llm,
+        embedder,
+        reranker,
+        internet_retriever,
+        memory_manager,
+        default_cube_config,
+    )
+
+
+# Initialize global components
+(
+    graph_db,
+    mem_reader,
+    llm,
+    embedder,
+    reranker,
+    internet_retriever,
+    memory_manager,
+    default_cube_config,
+) = init_server()
+
+
+def _create_naive_mem_cube() -> NaiveMemCube:
+    """Create a NaiveMemCube instance with initialized components."""
+    naive_mem_cube = NaiveMemCube(
+        llm=llm,
+        embedder=embedder,
+        mem_reader=mem_reader,
+        graph_db=graph_db,
+        reranker=reranker,
+        internet_retriever=internet_retriever,
+        memory_manager=memory_manager,
+        default_cube_config=default_cube_config,
+    )
+    return naive_mem_cube
+
+
+def _format_memory_item(memory_data: Any) -> dict[str, Any]:
+    """Format a single memory item for API response."""
+    memory = memory_data.model_dump()
+    memory_id = memory["id"]
+    ref_id = f"[{memory_id.split('-')[0]}]"
+
+    memory["ref_id"] = ref_id
+    memory["metadata"]["embedding"] = []
+    memory["metadata"]["sources"] = []
+    memory["metadata"]["ref_id"] = ref_id
+    memory["metadata"]["id"] = memory_id
+    memory["metadata"]["memory"] = memory["memory"]
+
+    return memory
+
+
+@router.post("/search", summary="Search memories", response_model=SearchResponse)
+def search_memories(search_req: APISearchRequest):
+    """Search memories for a specific user."""
+    # Build the per-request user context
+    user_context = UserContext(
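+        # Note: downstream storage calls are scoped by mem_cube_id (passed on
+        # as user_name), not by user_id; see the text_mem.search call below.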
+        user_id=search_req.user_id,
+        mem_cube_id=search_req.mem_cube_id,
+        session_id=search_req.session_id or "default_session",
+    )
+    logger.info(f"Search mem_cube_id is: {user_context.mem_cube_id}")
+    memories_result: MOSSearchResult = {
+        "text_mem": [],
+        "act_mem": [],
+        "para_mem": [],
+    }
+    target_session_id = search_req.session_id
+    if not target_session_id:
+        target_session_id = "default_session"
+    search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+    # Create MemCube and perform search
+    naive_mem_cube = _create_naive_mem_cube()
+    search_results = naive_mem_cube.text_mem.search(
+        query=search_req.query,
+        user_name=user_context.mem_cube_id,
+        top_k=search_req.top_k,
+        mode=search_req.mode,
+        manual_close_internet=not search_req.internet_search,
+        moscube=search_req.moscube,
+        search_filter=search_filter,
+        info={
+            "user_id": search_req.user_id,
+            "session_id": target_session_id,
+            "chat_history": search_req.chat_history,
+        },
+    )
+    formatted_memories = [_format_memory_item(data) for data in search_results]
+
+    memories_result["text_mem"].append(
+        {
+            "cube_id": search_req.mem_cube_id,
+            "memories": formatted_memories,
+        }
+    )
+
+    return SearchResponse(
+        message="Search completed successfully",
+        data=memories_result,
+    )
+
+
+@router.post("/add", summary="Add memories", response_model=MemoryResponse)
+def add_memories(add_req: APIADDRequest):
+    """Add memories for a specific user."""
+    # Build the per-request user context
+    user_context = UserContext(
+        user_id=add_req.user_id,
+        mem_cube_id=add_req.mem_cube_id,
+        session_id=add_req.session_id or "default_session",
+    )
+    naive_mem_cube = _create_naive_mem_cube()
+    target_session_id = add_req.session_id
+    if not target_session_id:
+        target_session_id = "default_session"
+    memories = mem_reader.get_memory(
+        [add_req.messages],
+        type="chat",
+        info={
+            "user_id": add_req.user_id,
+            "session_id": target_session_id,
+        },
+    )
+
+    # Flatten memory list
+    flattened_memories = [mm for m in memories for mm in m]
+    logger.info(f"Memory extraction completed for user {add_req.user_id}")
+    mem_id_list: list[str] = naive_mem_cube.text_mem.add(
+        flattened_memories,
+        user_name=user_context.mem_cube_id,
+    )
+
+    logger.info(
+        f"Added {len(mem_id_list)} memories for user {add_req.user_id} "
+        f"in session {add_req.session_id}: {mem_id_list}"
+    )
+    response_data = [
+        {
+            "memory": memory.memory,
+            "memory_id": memory_id,
+            "memory_type": memory.metadata.memory_type,
+        }
+        for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False)
+    ]
+    return MemoryResponse(
+        message="Memory added successfully",
+        data=response_data,
+    )

diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py
new file mode 100644
index 000000000..78e05ef85
--- /dev/null
+++ b/src/memos/api/server_api.py
@@ -0,0 +1,38 @@
+import logging
+
+from fastapi import FastAPI
+
+from memos.api.exceptions import APIExceptionHandler
+from memos.api.middleware.request_context import RequestContextMiddleware
+from memos.api.routers.server_router import router as server_router
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="MemOS Product REST APIs",
+    description="A REST API for managing multiple users with MemOS Product.",
+    version="1.0.1",
+)
+
+app.add_middleware(RequestContextMiddleware)
+# Include routers
+app.include_router(server_router)
+
+# Exception handlers
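+# Both handlers are implemented by memos.api.exceptions.APIExceptionHandler (imported above).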
+app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +if __name__ == "__main__": + import argparse + + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--workers", type=int, default=1) + args = parser.parse_args() + uvicorn.run("memos.api.server_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/configs/mem_user.py b/src/memos/configs/mem_user.py index 3ff1066e5..6e1ca4206 100644 --- a/src/memos/configs/mem_user.py +++ b/src/memos/configs/mem_user.py @@ -31,6 +31,17 @@ class MySQLUserManagerConfig(BaseUserManagerConfig): charset: str = Field(default="utf8mb4", description="MySQL charset") +class RedisUserManagerConfig(BaseUserManagerConfig): + """Redis user manager configuration.""" + + host: str = Field(default="localhost", description="Redis server host") + port: int = Field(default=6379, description="Redis server port") + username: str = Field(default="root", description="Redis username") + password: str = Field(default="", description="Redis password") + database: str = Field(default="memos_users", description="Redis database name") + charset: str = Field(default="utf8mb4", description="Redis charset") + + class UserManagerConfigFactory(BaseModel): """Factory for user manager configurations.""" @@ -42,6 +53,7 @@ class UserManagerConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "sqlite": SQLiteUserManagerConfig, "mysql": MySQLUserManagerConfig, + "redis": RedisUserManagerConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 1eea6deaf..237450e15 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -180,6 +180,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ) +class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): + """Simple tree text memory configuration class.""" + + # ─── 3. 
Global Memory Config Factory ────────────────────────────────────────── @@ -192,6 +196,7 @@ class MemoryConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "naive_text": NaiveTextMemoryConfig, "general_text": GeneralTextMemoryConfig, + "simple_tree_text": SimpleTreeTextMemoryConfig, "tree_text": TreeTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..10c3c75d0 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -129,7 +129,6 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str: "nebula-sync", ",".join(hosts), str(getattr(cfg, "user", "")), - str(getattr(cfg, "use_multi_db", False)), str(getattr(cfg, "space", "")), ] ) @@ -139,7 +138,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " tmp = object.__new__(NebulaGraphDB) tmp.config = cfg tmp.db_name = cfg.space - tmp.user_name = getattr(cfg, "user_name", None) + tmp.user_name = None tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) tmp.default_memory_dimension = 3072 tmp.common_fields = { @@ -169,7 +168,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -417,7 +416,9 @@ def create_index( self._create_basic_property_indexes() @timed - def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. @@ -426,9 +427,10 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: keep_latest (int): Number of latest WorkingMemory entries to keep. """ optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + user_name = user_name if user_name else self.config.user_name + + optional_condition = f"AND n.user_name = '{user_name}'" query = f""" MATCH (n@Memory) WHERE n.memory_type = '{memory_type}' @@ -440,13 +442,13 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: self.execute_query(query) @timed - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: """ Insert or update a Memory node in NebulaGraph. 
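
        Example (illustrative):
            db.add_node(
                id="mem-123",
                memory="User likes jazz",
                metadata={"memory_type": "LongTermMemory", "status": "activated"},
                user_name="alice",
            )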
""" - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name if user_name else self.config.user_name now = datetime.utcnow() metadata = metadata.copy() metadata.setdefault("created_at", now) @@ -475,11 +477,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: ) @timed - def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' query = f""" MATCH (n@Memory) WHERE {filter_clause} @@ -495,10 +495,11 @@ def node_not_exist(self, scope: str) -> int: raise @timed - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. """ + user_name = user_name if user_name else self.config.user_name fields = fields.copy() set_clauses = [] for k, v in fields.items(): @@ -509,45 +510,41 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @timed - def delete_node(self, id: str) -> None: + def delete_node(self, id: str, user_name: str | None = None) -> None: """ Delete a node from the graph. Args: id: Node identifier to delete. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" - MATCH (n@Memory {{id: "{id}"}}) + MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} + DETACH DELETE n """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" - query += "\n DETACH DELETE n" self.execute_query(query) @timed - def add_edge(self, source_id: str, target_id: str, type: str): + def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): """ Create an edge from source node to target node. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). 
+ user_name (str, optional): User name for filtering in non-multi-db mode """ if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") - + user_name = user_name if user_name else self.config.user_name props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' - + props = f'{{user_name: "{user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT (a) -[e@{type} {props}]-> (b) @@ -558,35 +555,35 @@ def add_edge(self, source_id: str, target_id: str, type: str): logger.error(f"Failed to insert edge: {e}", exc_info=True) @timed - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: """ Delete a specific edge between two nodes. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type to remove. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (a@Memory) -[r@{type}]-> (b@Memory) WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @timed - def get_memory_count(self, memory_type: str) -> int: + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -597,14 +594,13 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{scope}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -612,7 +608,12 @@ def count_nodes(self, scope: str) -> int: @timed def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, ) -> bool: """ Check if an edge exists between two nodes. @@ -622,10 +623,12 @@ def edge_exists( type: Relationship type. Use "ANY" to match any relationship type. direction: Direction of the edge. Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode Returns: True if the edge exists, otherwise False. 
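
        Example (illustrative):
            db.edge_exists("mem-1", "mem-2", type="PARENT", user_name="alice")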
""" # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name rel = "r" if type == "ANY" else f"r@{type}" # Prepare the match pattern with direction @@ -640,9 +643,7 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." ) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -654,22 +655,22 @@ def edge_exists( @timed # Graph Query & Reasoning - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any] | None: """ Retrieve a Memory node by its unique ID. Args: id (str): Node ID (Memory.id) include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: dict: Node properties as key-value pairs, or None if not found. """ - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' - + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -692,13 +693,18 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | @timed def get_nodes( - self, ids: list[str], include_embedding: bool = False, **kwargs + self, + ids: list[str], + include_embedding: bool = False, + user_name: str | None = None, + **kwargs, ) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: ids: List of Node identifier. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. @@ -709,19 +715,14 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" - + user_name = user_name if user_name else self.config.user_name + where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.id IN [{id_list}] {where_user} RETURN {return_fields} """ @@ -738,7 +739,9 @@ def get_nodes( return nodes @timed - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]: + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -746,6 +749,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ id: Node ID to retrieve edges for. type: Relationship type to match, or 'ANY' to match all. direction: 'OUTGOING', 'INCOMING', or 'ANY'. 
+ user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of edges: @@ -756,7 +760,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ """ # Build relationship type filter rel_type = "" if type == "ANY" else f"@{type}" - + user_name = user_name if user_name else self.config.user_name # Build Cypher pattern based on direction if direction == "OUTGOING": pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" @@ -770,8 +774,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query = f""" MATCH {pattern} @@ -799,6 +802,7 @@ def get_neighbors_by_tag( top_k: int = 5, min_overlap: int = 1, include_embedding: bool = False, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Find top-K neighbor nodes with maximum tag overlap. @@ -809,13 +813,14 @@ def get_neighbors_by_tag( top_k: Max number of neighbors to return. min_overlap: Minimum number of overlapping tags required. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of dicts with node details and overlap count. """ if not tags: return [] - + user_name = user_name if user_name else self.config.user_name where_clauses = [ 'n.status = "activated"', 'NOT (n.node_type = "reasoning")', @@ -824,8 +829,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -859,12 +863,11 @@ def get_neighbors_by_tag( return result @timed - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + user_name = user_name if user_name else self.config.user_name + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -884,7 +887,11 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: @timed def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated" + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -892,6 +899,7 @@ def get_subgraph( center_id: The ID of the center node. depth: The hop distance for neighbors. center_status: Required status for center node. 
+ user_name (str, optional): User name for filtering in non-multi-db mode Returns: { "core_node": {...}, @@ -902,7 +910,8 @@ def get_subgraph( if not 1 <= depth <= 5: raise ValueError("depth must be 1-5") - user_name = self.config.user_name + user_name = user_name if user_name else self.config.user_name + gql = f""" MATCH (center@Memory) WHERE center.id = '{center_id}' @@ -954,6 +963,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str | None = None, **kwargs, ) -> list[dict]: """ @@ -968,6 +978,7 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. @@ -981,42 +992,35 @@ def search_by_embedding( - Typical use case: restrict to 'status = activated' to avoid matching archived or merged nodes. """ + user_name = user_name if user_name else self.config.user_name vector = _normalize(vector) dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE - LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + {where_clause} + ORDER BY inner_product(n.{self.dim_field}, a) DESC + LIMIT {top_k} + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1038,7 +1042,9 @@ def search_by_embedding( return [] @timed - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata( + self, filters: list[dict[str, Any]], user_name: str | None = None + ) -> list[str]: """ 1. ADD logic: "AND" vs "OR"(support logic combination); 2. Support nested conditional expressions; @@ -1054,6 +1060,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: {"field": "tags", "op": "contains", "value": "AI"}, ... 
] + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). @@ -1063,7 +1070,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: - Can be used for faceted recall or prefiltering before embedding rerank. """ where_clauses = [] - + user_name = user_name if user_name else self.config.user_name for _i, f in enumerate(filters): field = f["field"] op = f.get("op", "=") @@ -1087,11 +1094,10 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" + gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" ids = [] try: result = self.execute_query(gql) @@ -1106,6 +1112,7 @@ def get_grouped_counts( group_fields: list[str], where_clause: str = "", params: dict[str, Any] | None = None, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Count nodes grouped by any fields. @@ -1115,24 +1122,24 @@ def get_grouped_counts( where_clause (str, optional): Extra WHERE condition. E.g., "WHERE n.status = 'activated'" params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] """ if not group_fields: raise ValueError("group_fields cannot be empty") - - # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + user_name = user_name if user_name else self.config.user_name + # GQL-specific modifications + user_clause = f"n.user_name = '{user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1170,16 +1177,16 @@ def get_grouped_counts( return output @timed - def clear(self) -> None: + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. 
+ + Args: + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" - + query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1187,11 +1194,14 @@ def clear(self) -> None: logger.error(f"[ERROR] Failed to clear database: {e}") @timed - def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: + def export_graph( + self, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { @@ -1199,13 +1209,11 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] } """ + user_name = user_name if user_name else self.config.user_name node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + node_query += f' WHERE n.user_name = "{user_name}"' + edge_query += f' WHERE r.user_name = "{user_name}"' try: if include_embedding: @@ -1265,20 +1273,19 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: return {"nodes": nodes, "edges": edges} @timed - def import_graph(self, data: dict[str, Any]) -> None: + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """ Import the entire graph from a serialized dictionary. Args: data: A dictionary containing all nodes and edges to be loaded. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name for node in data.get("nodes", []): try: id, memory, metadata = _compose_node(node) - - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) properties = ", ".join( @@ -1293,9 +1300,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1305,29 +1310,31 @@ def import_graph(self, data: dict[str, Any]) -> None: logger.error(f"Fail to load edge: {edge}, error: {e}") @timed - def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]: + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> (list)[dict]: """ Retrieve all memory items of a specific memory_type. 
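
        Example (illustrative sketch; assumes an initialized NebulaGraphDB ``db``):

            working = db.get_all_memory_items("WorkingMemory", user_name="memos_alice")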
Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Full list of memory items under this scope. """ + user_name = user_name if user_name else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{user_name}'" return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {return_fields} LIMIT 100 @@ -1344,20 +1351,19 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( @timed def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False + self, scope: str, include_embedding: bool = False, user_name: str | None = None ) -> list[dict]: """ Find nodes that are likely candidates for structure optimization: - Isolated nodes, nodes with empty background, or nodes with exactly one child. - Plus: the child of any parent node that has exactly one child. """ - + user_name = user_name if user_name else self.config.user_name where_clause = f''' n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" @@ -1386,21 +1392,6 @@ def get_structure_optimization_candidates( logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - @timed def detect_conflicts(self) -> list[tuple[str, str]]: """ @@ -1585,9 +1576,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. 
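        For example, the user_name field yields an index named idx_memory_user_name,
        which queries elsewhere in this patch reference via the
        /*+ INDEX(idx_memory_user_name) */ hint.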
""" - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py new file mode 100644 index 000000000..7ce3ca642 --- /dev/null +++ b/src/memos/mem_cube/navie.py @@ -0,0 +1,166 @@ +import os + +from typing import Literal + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.utils import get_json_file_model_schema +from memos.embedders.base import BaseEmbedder +from memos.exceptions import ConfigurationError, MemCubeError +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube +from memos.mem_reader.base import BaseMemReader +from memos.memories.activation.base import BaseActMemory +from memos.memories.parametric.base import BaseParaMemory +from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.reranker.base import BaseReranker + + +logger = get_logger(__name__) + + +class NaiveMemCube(BaseMemCube): + """MemCube is a box for loading and dumping three types of memories.""" + + def __init__( + self, + llm: BaseLLM, + embedder: BaseEmbedder, + mem_reader: BaseMemReader, + graph_db: BaseGraphDB, + reranker: BaseReranker, + memory_manager: MemoryManager, + default_cube_config: GeneralMemCubeConfig, + internet_retriever: None = None, + ): + """Initialize the MemCube with a configuration.""" + self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( + llm, + embedder, + mem_reader, + graph_db, + reranker, + memory_manager, + default_cube_config.text_mem.config, + internet_retriever, + ) + self._act_mem: BaseActMemory | None = None + self._para_mem: BaseParaMemory | None = None + + def load( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Load memories. + Args: + dir (str): The directory containing the memory files. + memory_types (list[str], optional): List of memory types to load. + If None, loads all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) + if loaded_schema != self.config.model_schema: + raise ConfigurationError( + f"Configuration schema mismatch. Expected {self.config.model_schema}, " + f"but found {loaded_schema}." + ) + + # If no specific memory types specified, load all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Load specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.load(dir) + logger.debug(f"Loaded text_mem from {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.load(dir) + logger.info(f"Loaded act_mem from {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.load(dir) + logger.info(f"Loaded para_mem from {dir}") + + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") + + def dump( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Dump memories. + Args: + dir (str): The directory where the memory files will be saved. 
+ memory_types (list[str], optional): List of memory types to dump. + If None, dumps all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + if os.path.exists(dir) and os.listdir(dir): + raise MemCubeError( + f"Directory {dir} is not empty. Please provide an empty directory for dumping." + ) + + # Always dump config + self.config.to_json_file(os.path.join(dir, self.config.config_filename)) + + # If no specific memory types specified, dump all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Dump specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.dump(dir) + logger.info(f"Dumped text_mem to {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.dump(dir) + logger.info(f"Dumped act_mem to {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.dump(dir) + logger.info(f"Dumped para_mem to {dir}") + + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") + + @property + def text_mem(self) -> "BaseTextMemory | None": + """Get the textual memory.""" + if self._text_mem is None: + logger.warning("Textual memory is not initialized. Returning None.") + return self._text_mem + + @text_mem.setter + def text_mem(self, value: BaseTextMemory) -> None: + """Set the textual memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._text_mem = value + + @property + def act_mem(self) -> "BaseActMemory | None": + """Get the activation memory.""" + if self._act_mem is None: + logger.warning("Activation memory is not initialized. Returning None.") + return self._act_mem + + @act_mem.setter + def act_mem(self, value: BaseActMemory) -> None: + """Set the activation memory.""" + if not isinstance(value, BaseActMemory): + raise TypeError(f"Expected BaseActMemory, got {type(value).__name__}") + self._act_mem = value + + @property + def para_mem(self) -> "BaseParaMemory | None": + """Get the parametric memory.""" + if self._para_mem is None: + logger.warning("Parametric memory is not initialized. 
Returning None.")
+        return self._para_mem
+
+    @para_mem.setter
+    def para_mem(self, value: BaseParaMemory) -> None:
+        """Set the parametric memory."""
+        if not isinstance(value, BaseParaMemory):
+            raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}")
+        self._para_mem = value
diff --git a/src/memos/mem_user/persistent_factory.py b/src/memos/mem_user/persistent_factory.py
index b5ece61b5..6a7b4fa13 100644
--- a/src/memos/mem_user/persistent_factory.py
+++ b/src/memos/mem_user/persistent_factory.py
@@ -3,6 +3,7 @@
 from memos.configs.mem_user import UserManagerConfigFactory
 from memos.mem_user.mysql_persistent_user_manager import MySQLPersistentUserManager
 from memos.mem_user.persistent_user_manager import PersistentUserManager
+from memos.mem_user.redis_persistent_user_manager import RedisPersistentUserManager


 class PersistentUserManagerFactory:
@@ -11,6 +12,7 @@ class PersistentUserManagerFactory:
     backend_to_class: ClassVar[dict[str, Any]] = {
         "sqlite": PersistentUserManager,
         "mysql": MySQLPersistentUserManager,
+        "redis": RedisPersistentUserManager,
     }

     @classmethod
diff --git a/src/memos/mem_user/redis_persistent_user_manager.py b/src/memos/mem_user/redis_persistent_user_manager.py
new file mode 100644
index 000000000..48c89c663
--- /dev/null
+++ b/src/memos/mem_user/redis_persistent_user_manager.py
@@ -0,0 +1,225 @@
+"""Redis-based persistent user management system for MemOS with configuration storage.
+
+This module provides persistent storage for user configurations using Redis.
+"""
+
+import json
+
+from memos.configs.mem_os import MOSConfig
+from memos.dependency import require_python_package
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RedisPersistentUserManager:
+    """Redis-based user configuration manager with persistence."""
+
+    @require_python_package(
+        import_name="redis",
+        install_command="pip install redis",
+        install_link="https://redis.readthedocs.io/en/stable/",
+    )
+    def __init__(
+        self,
+        host: str = "localhost",
+        port: int = 6379,
+        password: str = "",
+        db: int = 0,
+        decode_responses: bool = True,
+    ):
+        """Initialize the Redis persistent user manager.
+
+        Args:
+            host (str): Redis server host. Defaults to "localhost".
+            port (int): Redis server port. Defaults to 6379.
+            password (str): Redis password. Defaults to "".
+            db (int): Redis database number. Defaults to 0.
+            decode_responses (bool): Whether to decode responses to strings. Defaults to True.
+        """
+        import redis
+
+        self.host = host
+        self.port = port
+        self.db = db
+
+        try:
+            # Create Redis connection
+            self._redis_client = redis.Redis(
+                host=host,
+                port=port,
+                password=password if password else None,
+                db=db,
+                decode_responses=decode_responses,
+            )
+
+            # Test connection
+            if not self._redis_client.ping():
+                raise ConnectionError("Redis connection failed")
+
+            logger.info(
+                f"RedisPersistentUserManager initialized successfully, connected to {host}:{port}/{db}"
+            )
+
+        except Exception as e:
+            logger.error(f"Redis connection error: {e}")
+            raise
+
+    def _get_config_key(self, user_id: str) -> str:
+        """Generate Redis key for user configuration.
+
+        Args:
+            user_id (str): User ID.
+
+        Returns:
+            str: Redis key name.
+        """
+        # Prefix the key so list_user_configs() can discover it via "user_config:*"
+        return f"user_config:{user_id}"
+
+    def save_user_config(self, user_id: str, config: MOSConfig) -> bool:
+        """Save user configuration to Redis.
+
+        Args:
+            user_id (str): User ID.
+            config (MOSConfig): User's MOS configuration.
+
+        Returns:
+            bool: True if successful, False otherwise.
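+
+        Example (illustrative sketch; assumes a reachable Redis and a valid ``MOSConfig``):
+
+            manager = RedisPersistentUserManager(host="localhost", port=6379)
+            ok = manager.save_user_config("alice", mos_config)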
+ """ + try: + # Convert config to JSON string + config_dict = config.model_dump(mode="json") + config_json = json.dumps(config_dict, ensure_ascii=False, indent=2) + + # Save to Redis + key = self._get_config_key(user_id) + self._redis_client.set(key, config_json) + + logger.info(f"Successfully saved configuration for user {user_id} to Redis") + return True + + except Exception as e: + logger.error(f"Error saving configuration for user {user_id}: {e}") + return False + + def get_user_config(self, user_id: str) -> dict | None: + """Get user configuration from Redis (search interface). + + Args: + user_id (str): User ID. + + Returns: + MOSConfig | None: User's configuration object, or None if not found. + """ + try: + # Get configuration from Redis + key = self._get_config_key(user_id) + config_json = self._redis_client.get(key) + + if config_json is None: + logger.info(f"Configuration for user {user_id} does not exist") + return None + + # Parse JSON and create MOSConfig object + config_dict = json.loads(config_json) + + logger.info(f"Successfully retrieved configuration for user {user_id}") + return config_dict + + except json.JSONDecodeError as e: + logger.error(f"Error parsing JSON configuration for user {user_id}: {e}") + return None + except Exception as e: + logger.error(f"Error retrieving configuration for user {user_id}: {e}") + return None + + def delete_user_config(self, user_id: str) -> bool: + """Delete user configuration from Redis. + + Args: + user_id (str): User ID. + + Returns: + bool: True if successful, False otherwise. + """ + try: + key = self._get_config_key(user_id) + result = self._redis_client.delete(key) + + if result > 0: + logger.info(f"Successfully deleted configuration for user {user_id}") + return True + else: + logger.warning(f"Configuration for user {user_id} does not exist, cannot delete") + return False + + except Exception as e: + logger.error(f"Error deleting configuration for user {user_id}: {e}") + return False + + def exists_user_config(self, user_id: str) -> bool: + """Check if user configuration exists. + + Args: + user_id (str): User ID. + + Returns: + bool: True if exists, False otherwise. + """ + try: + key = self._get_config_key(user_id) + return self._redis_client.exists(key) > 0 + except Exception as e: + logger.error(f"Error checking if configuration exists for user {user_id}: {e}") + return False + + def list_user_configs( + self, pattern: str = "user_config:*", count: int = 100 + ) -> dict[str, dict]: + """List all user configurations. + + Args: + pattern (str): Redis key matching pattern. Defaults to "user_config:*". + count (int): Number of keys to return per scan. Defaults to 100. + + Returns: + dict[str, dict]: Dictionary mapping user_id to dict objects. + """ + result = {} + try: + # Use SCAN command to iterate through all matching keys + cursor = 0 + while True: + cursor, keys = self._redis_client.scan(cursor, match=pattern, count=count) + + for key in keys: + # Extract user_id (remove "user_config:" prefix) + user_id = key.replace("user_config:", "") + config = self.get_user_config(user_id) + if config: + result[user_id] = config + + if cursor == 0: + break + + logger.info(f"Successfully listed {len(result)} user configurations") + return result + + except Exception as e: + logger.error(f"Error listing user configurations: {e}") + return {} + + def close(self) -> None: + """Close Redis connection. + + This method should be called when the RedisPersistentUserManager is no longer needed + to ensure proper cleanup of Redis connections. 
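+
+        Example (illustrative sketch; ``mos_config`` is an assumed MOSConfig instance):
+
+            manager = RedisPersistentUserManager()
+            try:
+                manager.save_user_config("alice", mos_config)
+            finally:
+                manager.close()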
+ """ + try: + if hasattr(self, "_redis_client") and self._redis_client: + self._redis_client.close() + logger.info("Redis connection closed") + except Exception as e: + logger.error(f"Error closing Redis connection: {e}") diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index 9fdc67c53..bcf7fdd9b 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,6 +10,7 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -20,6 +21,7 @@ class MemoryFactory(BaseMemory): "naive_text": NaiveTextMemory, "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, + "simple_tree_text": SimpleTreeTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 8171fadce..82dad4486 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -24,7 +24,7 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """ @abstractmethod - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: """Add memories. Args: diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py new file mode 100644 index 000000000..9c67db288 --- /dev/null +++ b/src/memos/memories/textual/simple_tree.py @@ -0,0 +1,295 @@ +import time + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from memos.configs.memory import TreeTextMemoryConfig +from memos.embedders.base import BaseEmbedder +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_reader.base import BaseMemReader +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.base import BaseReranker +from memos.types import MessageList + + +if TYPE_CHECKING: + from memos.embedders.factory import OllamaEmbedder + from memos.graph_dbs.factory import Neo4jGraphDB + from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM + + +logger = get_logger(__name__) + + +class SimpleTreeTextMemory(TreeTextMemory): + """General textual memory implementation for storing and retrieving memories.""" + + def __init__( + self, + llm: BaseLLM, + embedder: BaseEmbedder, + mem_reader: BaseMemReader, + graph_db: BaseGraphDB, + reranker: BaseReranker, + memory_manager: MemoryManager, + config: TreeTextMemoryConfig, + internet_retriever: None = None, + is_reorganize: bool = False, + ): + """Initialize memory with the given configuration.""" + time_start = time.time() + self.config: TreeTextMemoryConfig = config + + self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm + logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") + + time_start_ex = time.time() + self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm + logger.info(f"time init: dispatcher_llm time is: {time.time() - 
time_start_ex}") + + time_start_em = time.time() + self.embedder: OllamaEmbedder = embedder + logger.info(f"time init: embedder time is: {time.time() - time_start_em}") + + time_start_gs = time.time() + self.graph_store: Neo4jGraphDB = graph_db + logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") + + time_start_rr = time.time() + self.reranker = reranker + logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") + + time_start_mm = time.time() + self.memory_manager: MemoryManager = memory_manager + logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}") + time_start_ir = time.time() + # Create internet retriever if configured + self.internet_retriever = None + if config.internet_retriever is not None: + self.internet_retriever = internet_retriever + logger.info( + f"Internet retriever initialized with backend: {config.internet_retriever.backend}" + ) + else: + logger.info("No internet retriever configured") + logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") + + def add( + self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None + ) -> list[str]: + """Add memories. + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. + Later: + memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] + metadata = extract_metadata(memory_items, self.extractor_llm) + plan = plan_memory_operations(memory_items, metadata, self.graph_store) + execute_plan(memory_items, metadata, plan, self.graph_store) + """ + return self.memory_manager.add(memories, user_name=user_name) + + def replace_working_memory( + self, memories: list[TextualMemoryItem], user_name: str | None = None + ) -> None: + self.memory_manager.replace_working_memory(memories, user_name=user_name) + + def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]: + working_memories = self.graph_store.get_all_memory_items( + scope="WorkingMemory", user_name=user_name + ) + items = [TextualMemoryItem.from_dict(record) for record in (working_memories)] + # Sort by updated_at in descending order + sorted_items = sorted( + items, key=lambda x: x.metadata.updated_at or datetime.min, reverse=True + ) + return sorted_items + + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: + """ + Get the current size of each memory type. + This delegates to the MemoryManager. + """ + return self.memory_manager.get_current_memory_size(user_name=user_name) + + def search( + self, + query: str, + top_k: int, + info=None, + mode: str = "fast", + memory_type: str = "All", + manual_close_internet: bool = False, + moscube: bool = False, + search_filter: dict | None = None, + user_name: str | None = None, + ) -> list[TextualMemoryItem]: + """Search for memories based on a query. + User query -> TaskGoalParser -> MemoryPathResolver -> + GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + mode (str, optional): The mode of the search. + - 'fast': Uses a faster search process, sacrificing some precision for speed. + - 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance. + memory_type (str): Type restriction for search. 
+                ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
+            manual_close_internet (bool): If True, the internet retriever is disabled for this search; this takes priority over the config.
+            moscube (bool): Whether to also consult the MemOS cube when answering questions.
+            search_filter (dict, optional): Optional metadata filters for search results.
+                - Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
+                - Values are exact-match conditions.
+                Example: {"user_id": "123", "session_id": "abc"}
+                If None, no additional filtering is applied.
+        Returns:
+            list[TextualMemoryItem]: List of matching memories.
+        """
+        if (self.internet_retriever is not None) and manual_close_internet:
+            logger.warning(
+                "Internet retriever was initialized from config, but manual_close_internet=True for this search, so it will be closed"
+            )
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=None,
+                moscube=moscube,
+            )
+        else:
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=self.internet_retriever,
+                moscube=moscube,
+            )
+        return searcher.search(
+            query, top_k, info, mode, memory_type, search_filter, user_name=user_name
+        )
+
+    def get_relevant_subgraph(
+        self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
+    ) -> dict[str, Any]:
+        """
+        Find and merge the local neighborhood sub-graphs of the top-k
+        nodes most relevant to the query.
+        Process:
+            1. Embed the user query into a vector representation.
+            2. Use vector similarity search to find the top-k similar nodes.
+            3. For each similar node:
+                - Ensure its status matches `center_status` (e.g., 'active').
+                - Retrieve its local subgraph up to `depth` hops.
+                - Collect the center node, its neighbors, and connecting edges.
+            4. Merge all retrieved subgraphs into a single unified subgraph.
+            5. Return the merged subgraph structure.
+
+        Args:
+            query (str): The user input or concept to find relevant memories for.
+            top_k (int, optional): How many top similar nodes to retrieve. Default is 5.
+            depth (int, optional): The neighborhood depth (number of hops). Default is 2.
+            center_status (str, optional): Status condition the center node must satisfy (e.g., 'active').
+
+        Returns:
+            dict[str, Any]: A subgraph dict with:
+                - 'core_id': ID of the top matching core node, or None if none found.
+                - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph.
+                - 'edges': List of unique edges (as dicts with 'source', 'target', 'type') in the merged subgraph.
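+
+        Example (illustrative sketch; assumes an initialized SimpleTreeTextMemory ``mem``):
+
+            sub = mem.get_relevant_subgraph("user dietary habits", top_k=3, depth=2)
+            print(sub["core_id"], len(sub["nodes"]), len(sub["edges"]))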
+ """ + # Step 1: Embed query + query_embedding = self.embedder.embed([query])[0] + + # Step 2: Get top-1 similar node + similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) + if not similar_nodes: + logger.info("No similar nodes found for query embedding.") + return {"core_id": None, "nodes": [], "edges": []} + + # Step 3: Fetch neighborhood + all_nodes = {} + all_edges = set() + cores = [] + + for node in similar_nodes: + core_id = node["id"] + score = node["score"] + + subgraph = self.graph_store.get_subgraph( + center_id=core_id, depth=depth, center_status=center_status + ) + + if not subgraph["core_node"]: + logger.info(f"Skipping node {core_id} (inactive or not found).") + continue + + core_node = subgraph["core_node"] + neighbors = subgraph["neighbors"] + edges = subgraph["edges"] + + # Collect nodes + all_nodes[core_node["id"]] = core_node + for n in neighbors: + all_nodes[n["id"]] = n + + # Collect edges + for e in edges: + all_edges.add((e["source"], e["target"], e["type"])) + + cores.append( + {"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors} + ) + + top_core = cores[0] + return { + "core_id": top_core["id"], + "nodes": list(all_nodes.values()), + "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges], + } + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + raise NotImplementedError + + def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: + raise NotImplementedError + + def get(self, memory_id: str) -> TextualMemoryItem: + """Get a memory by its ID.""" + result = self.graph_store.get_node(memory_id) + if result is None: + raise ValueError(f"Memory with ID {memory_id} not found") + metadata_dict = result.get("metadata", {}) + return TextualMemoryItem( + id=result["id"], + memory=result["memory"], + metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), + ) + + def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: + raise NotImplementedError + + def get_all(self) -> dict: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. + """ + all_items = self.graph_store.export_graph() + return all_items + + def delete(self, memory_ids: list[str]) -> None: + raise NotImplementedError + + def delete_all(self) -> None: + """Delete all memories and their relationships from the graph store.""" + try: + self.graph_store.clear() + logger.info("All memories and edges have been deleted from the graph.") + except Exception as e: + logger.error(f"An error occurred while deleting all memories: {e}") + raise diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index c9cd4de8a..680052a9d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -51,14 +51,14 @@ def __init__( ) self._merged_threshold = merged_threshold - def add(self, memories: list[TextualMemoryItem]) -> list[str]: + def add(self, memories: list[TextualMemoryItem], user_name: str | None = None) -> list[str]: """ Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory). 
""" added_ids: list[str] = [] with ContextThreadPoolExecutor(max_workers=8) as executor: - futures = {executor.submit(self._process_memory, m): m for m in memories} + futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=60): try: ids = future.result() @@ -66,38 +66,31 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]: except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] - ) - except Exception: - logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"] - ) - except Exception: - logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"] - ) - except Exception: - logger.warning(f"Remove UserMemory error: {traceback.format_exc()}") + for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + try: + self.graph_store.remove_oldest_memory( + memory_type="WorkingMemory", + keep_latest=self.memory_size[mem_type], + user_name=user_name, + ) + except Exception: + logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}") - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return added_ids - def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: + def replace_working_memory( + self, memories: list[TextualMemoryItem], user_name: str | None = None + ) -> None: """ Replace WorkingMemory """ working_memory_top_k = memories[: self.memory_size["WorkingMemory"]] with ContextThreadPoolExecutor(max_workers=8) as executor: futures = [ - executor.submit(self._add_memory_to_db, memory, "WorkingMemory") + executor.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name=user_name + ) for memory in working_memory_top_k ] for future in as_completed(futures, timeout=60): @@ -107,47 +100,51 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: logger.exception("Memory processing error: ", exc_info=e) self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] + memory_type="WorkingMemory", + keep_latest=self.memory_size["WorkingMemory"], + user_name=user_name, ) - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) - def get_current_memory_size(self) -> dict[str, int]: + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: """ Return the cached memory type counts. """ - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return self.current_memory_size - def _refresh_memory_size(self) -> None: + def _refresh_memory_size(self, user_name: str | None = None) -> None: """ Query the latest counts from the graph store and update internal state. 
""" - results = self.graph_store.get_grouped_counts(group_fields=["memory_type"]) + results = self.graph_store.get_grouped_counts( + group_fields=["memory_type"], user_name=user_name + ) self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") - def _process_memory(self, memory: TextualMemoryItem): + def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. """ ids = [] - # Add to WorkingMemory - working_id = self._add_memory_to_db(memory, "WorkingMemory") - ids.append(working_id) + # Add to WorkingMemory do not return working_id + self._add_memory_to_db(memory, "WorkingMemory", user_name) # Add to LongTermMemory and UserMemory if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: added_id = self._add_to_graph_memory( - memory=memory, - memory_type=memory.metadata.memory_type, + memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name ) ids.append(added_id) return ids - def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: + def _add_memory_to_db( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. """ @@ -158,10 +155,12 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) # Insert node into graph - self.graph_store.add_node(working_memory.id, working_memory.memory, metadata) + self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) return working_memory.id - def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): + def _add_to_graph_memory( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). 
@@ -175,7 +174,10 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): node_id = str(uuid.uuid4()) # Step 2: Add new node to graph self.graph_store.add_node( - node_id, memory.memory, memory.metadata.model_dump(exclude_none=True) + node_id, + memory.memory, + memory.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 84cc8ecb3..d4cfcf501 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -30,6 +30,7 @@ def retrieve( memory_scope: str, query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -53,13 +54,13 @@ def retrieve( if memory_scope == "WorkingMemory": # For working memory, retrieve all entries (no filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False + scope="WorkingMemory", include_embedding=False, user_name=user_name ) return [TextualMemoryItem.from_dict(record) for record in working_memories] with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval - future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope) + future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search future_vector = executor.submit( self._vector_recall, @@ -67,6 +68,7 @@ def retrieve( memory_scope, top_k, search_filter=search_filter, + user_name=user_name, ) graph_results = future_graph.result() @@ -92,6 +94,7 @@ def retrieve_from_cube( memory_scope: str, query_embedding: list[list[float]] | None = None, cube_name: str = "memos_cube01", + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -112,7 +115,7 @@ def retrieve_from_cube( raise ValueError(f"Unsupported memory scope: {memory_scope}") graph_results = self._vector_recall( - query_embedding, memory_scope, top_k, cube_name=cube_name + query_embedding, memory_scope, top_k, cube_name=cube_name, user_name=user_name ) for result_i in graph_results: @@ -132,7 +135,7 @@ def retrieve_from_cube( return list(combined.values()) def _graph_recall( - self, parsed_goal: ParsedTaskGoal, memory_scope: str + self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None ) -> list[TextualMemoryItem]: """ Perform structured node-based retrieval from Neo4j. 
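For orientation, a hedged sketch of driving the hybrid retriever with the new parameter
(the leading arguments sit outside these hunks and are assumed from the docstring):

    items = retriever.retrieve(
        query="what does the user eat?",      # assumed leading parameter
        parsed_goal=parsed_goal,              # assumed ParsedTaskGoal from the task parser
        top_k=10,
        memory_scope="LongTermMemory",
        query_embedding=[query_vec],          # list[list[float]], per the signature above
        search_filter={"session_id": "abc"},
        user_name="memos_alice",
    )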
@@ -148,7 +151,7 @@ def _graph_recall( {"field": "key", "op": "in", "value": parsed_goal.keys}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - key_ids = self.graph_store.get_by_metadata(key_filters) + key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) candidate_ids.update(key_ids) # 2) tag-based OR branch @@ -157,7 +160,7 @@ def _graph_recall( {"field": "tags", "op": "contains", "value": parsed_goal.tags}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - tag_ids = self.graph_store.get_by_metadata(tag_filters) + tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) candidate_ids.update(tag_ids) # No matches → return empty @@ -165,7 +168,9 @@ def _graph_recall( return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_name + ) final_nodes = [] for node in node_dicts: @@ -194,6 +199,7 @@ def _vector_recall( max_num: int = 3, cube_name: str | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform vector-based similarity retrieval using query embedding. @@ -210,6 +216,7 @@ def search_single(vec, filt=None): scope=memory_scope, cube_name=cube_name, search_filter=filt, + user_name=user_name, ) or [] ) @@ -255,7 +262,7 @@ def search_path_b(): unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name + list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name ) or [] ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index df154f23a..05db56f53 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -12,7 +12,6 @@ from memos.reranker.base import BaseReranker from memos.utils import timed -from .internet_retriever_factory import InternetRetrieverFactory from .reasoner import MemoryReasoner from .recall import GraphMemoryRetriever from .task_goal_parser import TaskGoalParser @@ -28,7 +27,7 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, - internet_retriever: InternetRetrieverFactory | None = None, + internet_retriever: None = None, moscube: bool = False, ): self.graph_store = graph_store @@ -54,6 +53,7 @@ def search( mode="fast", memory_type="All", search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. @@ -85,14 +85,22 @@ def search( logger.debug(f"[SEARCH] Received info dict: {info}") parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter + query, info, mode, search_filter=search_filter, user_name=user_name ) results = self._retrieve_paths( - query, parsed_goal, query_embedding, info, top_k, mode, memory_type, search_filter + query, + parsed_goal, + query_embedding, + info, + top_k, + mode, + memory_type, + search_filter, + user_name, ) deduped = self._deduplicate_results(results) final_results = self._sort_and_trim(deduped, top_k) - self._update_usage_history(final_results, info) + self._update_usage_history(final_results, info, user_name) logger.info(f"[SEARCH] Done. 
Total {len(final_results)} results.") res_results = "" @@ -104,7 +112,15 @@ def search( return final_results @timed - def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = None): + def _parse_task( + self, + query, + info, + mode, + top_k=5, + search_filter: dict | None = None, + user_name: str | None = None, + ): """Parse user query, do embedding search and create context""" context = [] query_embedding = None @@ -118,7 +134,7 @@ def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = N related_nodes = [ self.graph_store.get_node(n["id"]) for n in self.graph_store.search_by_embedding( - query_embedding, top_k=top_k, search_filter=search_filter + query_embedding, top_k=top_k, search_filter=search_filter, user_name=user_name ) ] memories = [] @@ -168,6 +184,7 @@ def _retrieve_paths( mode, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Run A/B/C retrieval paths in parallel""" tasks = [] @@ -181,6 +198,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -192,6 +210,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -204,6 +223,7 @@ def _retrieve_paths( info, mode, memory_type, + user_name, ) ) if self.moscube: @@ -235,6 +255,7 @@ def _retrieve_from_working_memory( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -246,6 +267,7 @@ def _retrieve_from_working_memory( top_k=top_k, memory_scope="WorkingMemory", search_filter=search_filter, + user_name=user_name, ) return self.reranker.rerank( query=query, @@ -266,6 +288,7 @@ def _retrieve_from_long_term_and_user( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -282,6 +305,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, + user_name=user_name, ) ) if memory_type in ["All", "UserMemory"]: @@ -294,6 +318,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, + user_name=user_name, ) ) @@ -320,6 +345,7 @@ def _retrieve_from_memcubes( top_k=top_k * 2, memory_scope="LongTermMemory", cube_name=cube_name, + user_name=cube_name, ) return self.reranker.rerank( query=query, @@ -332,7 +358,15 @@ def _retrieve_from_memcubes( # --- Path C @timed def _retrieve_from_internet( - self, query, parsed_goal, query_embedding, top_k, info, mode, memory_type + self, + query, + parsed_goal, + query_embedding, + top_k, + info, + mode, + memory_type, + user_id: str | None = None, ): """Retrieve and rerank from Internet source""" if not self.internet_retriever or mode == "fast": @@ -380,7 +414,7 @@ def _sort_and_trim(self, results, top_k): return final_items @timed - def _update_usage_history(self, items, info): + def _update_usage_history(self, items, info, user_name: str | None = None): """Update usage history in graph DB""" now_time = datetime.now().isoformat() info_copy = dict(info or {}) @@ -402,11 +436,15 @@ def _update_usage_history(self, items, info): logger.exception("[USAGE] snapshot item failed") if payload: - self._usage_executor.submit(self._update_usage_history_worker, payload, usage_record) + self._usage_executor.submit( + self._update_usage_history_worker, payload, usage_record, user_name + ) - def 
_update_usage_history_worker(self, payload, usage_record: str): + def _update_usage_history_worker( + self, payload, usage_record: str, user_name: str | None = None + ): try: for item_id, usage_list in payload: - self.graph_store.update_node(item_id, {"usage": usage_list}) + self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") diff --git a/src/memos/types.py b/src/memos/types.py index 60d5da8d2..635fabccc 100644 --- a/src/memos/types.py +++ b/src/memos/types.py @@ -56,3 +56,25 @@ class MOSSearchResult(TypedDict): text_mem: list[dict[str, str | list[TextualMemoryItem]]] act_mem: list[dict[str, str | list[ActivationMemoryItem]]] para_mem: list[dict[str, str | list[ParametricMemoryItem]]] + + +# ─── API Types ──────────────────────────────────────────────────────────────────── +# for API Permission +Permission: TypeAlias = Literal["read", "write", "delete", "execute"] + + +# Message structure +class PermissionDict(TypedDict, total=False): + """Typed dictionary for chat message dictionaries.""" + + permissions: list[Permission] + mem_cube_id: str + + +class UserContext(BaseModel): + """Model to represent user context.""" + + user_id: str | None = None + mem_cube_id: str | None = None + session_id: str | None = None + operation: list[PermissionDict] | None = None diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index c9f42ec38..d99664817 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -73,7 +73,7 @@ def test_searcher_fast_path(mock_searcher): for item in result: assert len(item.metadata.usage) > 0 mock_searcher.graph_store.update_node.assert_any_call( - item.id, {"usage": item.metadata.usage} + item.id, {"usage": item.metadata.usage}, user_name=None ) From 2bfde7d9d2d54163a76d374c11a83880288a3261 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Wed, 15 Oct 2025 19:51:44 +0800 Subject: [PATCH 28/29] feat:nebula gql add index --- src/memos/graph_dbs/nebular.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 10c3c75d0..a6f6b82a4 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -432,7 +432,7 @@ def remove_oldest_memory( optional_condition = f"AND n.user_name = '{user_name}'" query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.memory_type = '{memory_type}' {optional_condition} ORDER BY n.updated_at DESC @@ -1158,7 +1158,7 @@ def get_grouped_counts( group_by_fields.append(alias) # Full GQL query construction gql = f""" - MATCH (n) + MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count GROUP BY {", ".join(group_by_fields)} From 07751cdb17821516767e4a30220821237da494e9 Mon Sep 17 00:00:00 2001 From: Hao <120852460@qq.com> Date: Wed, 15 Oct 2025 20:00:53 +0800 Subject: [PATCH 29/29] feat: align code --- src/memos/api/config.py | 4 ++-- src/memos/api/product_models.py | 5 ++--- src/memos/configs/graph_db.py | 12 ++++++++++++ src/memos/mem_reader/simple_struct.py | 13 +++++++------ 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 0d44b8963..355ee0385 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -251,8 +251,8 @@ def get_nebular_config(user_id: str | None = None) -> 
dict[str, Any]:
         "user": os.getenv("NEBULAR_USER", "root"),
         "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
         "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"),
-        # "user_name": f"memos{user_id.replace('-', '')}",
-        # "use_multi_db": False,
+        "user_name": f"memos{user_id.replace('-', '')}",
+        "use_multi_db": False,
         "auto_create": True,
         "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
     }
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index f0b6e2487..eb2d7aa6d 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -151,9 +151,7 @@ class MemoryCreateRequest(BaseRequest):
     mem_cube_id: str | None = Field(None, description="Cube ID")
     source: str | None = Field(None, description="Source of the memory")
     user_profile: bool = Field(False, description="User profile memory")
-    session_id: str | None = Field(
-        default_factory=lambda: str(uuid.uuid4()), description="Session id"
-    )
+    session_id: str | None = Field(None, description="Session id")


 class SearchRequest(BaseRequest):
@@ -161,6 +159,7 @@

     user_id: str = Field(..., description="User ID")
     query: str = Field(..., description="Search query")
+    mem_cube_id: str | None = Field(None, description="Cube ID to search in")
     top_k: int = Field(10, description="Number of results to return")
     session_id: str | None = Field(None, description="Session ID for soft-filtering memories")

diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 49703bb69..2df917166 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -124,10 +124,22 @@ class NebulaGraphDBConfig(BaseGraphDBConfig):
     space: str = Field(
         ..., description="The name of the target NebulaGraph space (like a database)"
     )
+    user_name: str | None = Field(
+        default=None,
+        description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)",
+    )
     auto_create: bool = Field(
         default=False,
         description="Whether to auto-create the space if it does not exist",
     )
+    use_multi_db: bool = Field(
+        default=True,
+        description=(
+            "If True: use NebulaGraph's multi-space feature for physical isolation; "
+            "each user typically gets a separate space. "
+            "If False: use a single shared space with logical isolation by user_name."
+        ),
+    )
     max_client: int = Field(
         default=1000,
         description=("max_client"),
     )
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index a4d38ccca..b439cb2b2 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -115,16 +115,17 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder
 class SimpleStructMemReader(BaseMemReader, ABC):
     """Naive implementation of MemReader."""

-    def __init__(self, llm, embedder, chunker):
+    def __init__(self, config: SimpleStructMemReaderConfig):
         """
         Initialize the NaiveMemReader with configuration.
Args: config: Configuration object for the reader """ - self.llm = llm - self.embedder = embedder - self.chunker = chunker + self.config = config + self.llm = LLMFactory.from_config(config.llm) + self.embedder = EmbedderFactory.from_config(config.embedder) + self.chunker = ChunkerFactory.from_config(config.chunker) @timed def _process_chat_data(self, scene_data_info, info): @@ -141,8 +142,8 @@ def _process_chat_data(self, scene_data_info, info): examples = PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", "\n".join(mem_list)) - # if self.config.remove_prompt_example: - # prompt = prompt.replace(examples, "") + if self.config.remove_prompt_example: + prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}]
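The prompt assembly in this final hunk reduces to a small, self-contained pattern; a
sketch with stand-in values (the real template and few-shot examples come from
PROMPT_DICT, whose exact contents are not shown here):

    PROMPT_DICT = {
        "chat": {
            "en": "Summarize the dialogue:\n${conversation}\n\nExample:\n...",
            "en_example": "Example:\n...",
        }
    }
    mem_list = ["user: I run every morning", "assistant: noted"]
    remove_prompt_example = True  # mirrors config.remove_prompt_example

    template = PROMPT_DICT["chat"]["en"]
    examples = PROMPT_DICT["chat"]["en_example"]
    prompt = template.replace("${conversation}", "\n".join(mem_list))
    if remove_prompt_example:
        prompt = prompt.replace(examples, "")
    messages = [{"role": "user", "content": prompt}]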