From 33448dd005a56aad7fc9f23f76a2ed226b2f68b6 Mon Sep 17 00:00:00 2001 From: cocoli Date: Thu, 14 Aug 2025 08:29:58 +0800 Subject: [PATCH 01/30] add web_search and logger_manager --- example.env | 3 + prometheus/app/db.py | 4 +- prometheus/app/main.py | 29 +- .../app/services/service_coordinator.py | 23 +- prometheus/docker/base_container.py | 5 +- prometheus/git/git_repository.py | 5 +- prometheus/graph/knowledge_graph.py | 4 +- .../bug_fix_verification_subgraph_node.py | 6 +- .../lang_graph/nodes/bug_fix_verify_node.py | 4 +- .../nodes/bug_fix_verify_structured_node.py | 6 +- .../nodes/bug_reproducing_execute_node.py | 4 +- .../nodes/bug_reproducing_file_node.py | 4 +- .../nodes/bug_reproducing_structured_node.py | 6 +- .../bug_reproducing_write_message_node.py | 6 +- .../nodes/bug_reproducing_write_node.py | 4 +- .../nodes/bug_reproduction_subgraph_node.py | 6 +- .../nodes/build_and_test_subgraph_node.py | 4 +- .../nodes/context_extraction_node.py | 4 +- .../lang_graph/nodes/context_provider_node.py | 4 +- .../nodes/context_query_message_node.py | 4 +- .../lang_graph/nodes/context_refine_node.py | 4 +- .../nodes/context_retrieval_subgraph_node.py | 6 +- .../lang_graph/nodes/edit_message_node.py | 4 +- prometheus/lang_graph/nodes/edit_node.py | 4 +- .../nodes/final_patch_selection_node.py | 4 +- .../lang_graph/nodes/general_build_node.py | 4 +- .../nodes/general_build_structured_node.py | 6 +- .../lang_graph/nodes/general_test_node.py | 4 +- .../nodes/general_test_structured_node.py | 4 +- prometheus/lang_graph/nodes/git_diff_node.py | 4 +- prometheus/lang_graph/nodes/git_reset_node.py | 4 +- .../nodes/issue_bug_analyzer_message_node.py | 6 +- .../nodes/issue_bug_analyzer_node.py | 79 ++++- .../nodes/issue_bug_context_message_node.py | 6 +- ...e_bug_reproduction_context_message_node.py | 6 +- .../nodes/issue_bug_responder_node.py | 4 +- .../nodes/issue_bug_subgraph_node.py | 4 +- ...sue_classification_context_message_node.py | 6 +- .../issue_classification_subgraph_node.py | 6 +- .../lang_graph/nodes/issue_classifier_node.py | 4 +- .../issue_not_verified_bug_subgraph_node.py | 7 +- .../nodes/issue_verified_bug_subgraph_node.py | 6 +- prometheus/lang_graph/nodes/noop_node.py | 4 +- .../lang_graph/nodes/reset_messages_node.py | 4 +- .../lang_graph/nodes/update_container_node.py | 4 +- .../nodes/user_defined_build_node.py | 4 +- .../nodes/user_defined_test_node.py | 4 +- prometheus/neo4j/knowledge_graph_handler.py | 4 +- prometheus/tools/file_operation.py | 4 +- prometheus/tools/web_search.py | 147 +++++++++ prometheus/utils/logger_manager.py | 282 ++++++++++++++++++ pyproject.toml | 4 +- .../nodes/test_issue_bug_analyzer_node.py | 114 ++++++- 53 files changed, 730 insertions(+), 162 deletions(-) create mode 100644 prometheus/tools/web_search.py create mode 100644 prometheus/utils/logger_manager.py diff --git a/example.env b/example.env index d2b2c063..8ecaf619 100644 --- a/example.env +++ b/example.env @@ -32,5 +32,8 @@ PROMETHEUS_MAX_OUTPUT_TOKENS=15000 # GitHub settings PROMETHEUS_GITHUB_ACCESS_TOKEN=github_access_token +# Tavily API settings +PROMETHEUS_TAVILY_API_KEY=your_tavily_api_key + # Database settings PROMETHEUS_DATABASE_URL=postgresql://postgres:password@localhost:5432/postgres?sslmode=disable diff --git a/prometheus/app/db.py b/prometheus/app/db.py index a95fa8f1..0d48c2f3 100644 --- a/prometheus/app/db.py +++ b/prometheus/app/db.py @@ -1,4 +1,3 @@ -import logging from typing import Optional from passlib.hash import bcrypt @@ -6,9 +5,10 @@ from prometheus.app.entity.user import User from prometheus.configuration.config import settings +from prometheus.utils.logger_manager import get_logger engine = create_engine(settings.DATABASE_URL, echo=True) -_logger = logging.getLogger("prometheus.app.db") +_logger = get_logger(__name__) # Create the database and tables diff --git a/prometheus/app/main.py b/prometheus/app/main.py index 2eeca317..0729c4c1 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -1,4 +1,3 @@ -import logging from contextlib import asynccontextmanager from datetime import datetime, timezone @@ -6,30 +5,10 @@ from prometheus.app import dependencies from prometheus.app.api import issue, repository -from prometheus.configuration.config import settings - -# Create a logger for the application's namespace -logger = logging.getLogger("prometheus") -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) -logger.addHandler(console_handler) -logger.setLevel(getattr(logging, settings.LOGGING_LEVEL)) -logger.propagate = False - -# Log the configuration settings -logger.info(f"LOGGING_LEVEL={settings.LOGGING_LEVEL}") -logger.info(f"ADVANCED_MODEL={settings.ADVANCED_MODEL}") -logger.info(f"BASE_MODEL={settings.BASE_MODEL}") -logger.info(f"NEO4J_BATCH_SIZE={settings.NEO4J_BATCH_SIZE}") -logger.info(f"WORKING_DIRECTORY={settings.WORKING_DIRECTORY}") -logger.info(f"KNOWLEDGE_GRAPH_MAX_AST_DEPTH={settings.KNOWLEDGE_GRAPH_MAX_AST_DEPTH}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_SIZE={settings.KNOWLEDGE_GRAPH_CHUNK_SIZE}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_OVERLAP={settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP}") -logger.info(f"MAX_TOKEN_PER_NEO4J_RESULT={settings.MAX_TOKEN_PER_NEO4J_RESULT}") -logger.info(f"TEMPERATURE={settings.TEMPERATURE}") -logger.info(f"MAX_INPUT_TOKENS={settings.MAX_INPUT_TOKENS}") -logger.info(f"MAX_OUTPUT_TOKENS={settings.MAX_OUTPUT_TOKENS}") +from prometheus.utils.logger_manager import get_logger + +# 获取应用程序的logger +logger = get_logger(__name__) @asynccontextmanager diff --git a/prometheus/app/services/service_coordinator.py b/prometheus/app/services/service_coordinator.py index 05955eb2..202660f7 100644 --- a/prometheus/app/services/service_coordinator.py +++ b/prometheus/app/services/service_coordinator.py @@ -6,7 +6,6 @@ issue handling, and conversation management. """ -import logging import traceback from datetime import datetime from pathlib import Path @@ -18,6 +17,7 @@ from prometheus.app.services.neo4j_service import Neo4jService from prometheus.app.services.repository_service import RepositoryService from prometheus.lang_graph.graphs.issue_state import IssueType +from prometheus.utils.logger_manager import get_logger, create_timestamped_file_handler class ServiceCoordinator: @@ -62,7 +62,7 @@ def __init__( self.working_directory = working_directory self.answer_issue_log_dir = self.working_directory / "answer_issue_logs" self.answer_issue_log_dir.mkdir(parents=True, exist_ok=True) - self._logger = logging.getLogger("prometheus.app.services.service_coordinator") + self._logger = get_logger(__name__) if ( self.knowledge_graph_service.get_local_path() @@ -118,13 +118,12 @@ def answer_issue( - passed_existing_test: Whether existing tests passed. - issue_response: Response from the issue service after processing. """ - logger = logging.getLogger("prometheus") - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = self.answer_issue_log_dir / f"{timestamp}.log" - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # 为这个特定的issue处理创建带时间戳的日志文件处理器 + file_handler = create_timestamped_file_handler( + self.answer_issue_log_dir, + f"issue_{issue_number}", + "prometheus" + ) try: # fix issue @@ -159,11 +158,11 @@ def answer_issue( issue_response, ) except Exception as e: - logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") + self._logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") return None, None, False, False, False, None finally: - logger.removeHandler(file_handler) - file_handler.close() + # 移除文件处理器并关闭文件 + self._logger.remove_file_handler(file_handler, "prometheus") def exists_knowledge_graph(self) -> bool: return self.knowledge_graph_service.exists() diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index c09f5187..b1d6332d 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -1,4 +1,3 @@ -import logging import shutil import tarfile import tempfile @@ -7,6 +6,7 @@ from typing import Optional, Sequence import docker +from prometheus.utils.logger_manager import get_logger class BaseContainer(ABC): @@ -24,7 +24,6 @@ class BaseContainer(ABC): container: docker.models.containers.Container project_path: Path timeout: int = 120 - logger: logging.Logger def __init__(self, project_path: Path, workdir: Optional[str] = None): """Initialize the container with a project directory. @@ -34,7 +33,7 @@ def __init__(self, project_path: Path, workdir: Optional[str] = None): Args: project_path: Path to the project directory to be containerized. """ - self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}") + self._logger = get_logger(f"{self.__class__.__module__}.{self.__class__.__name__}") temp_dir = Path(tempfile.mkdtemp()) temp_project_path = temp_dir / project_path.name shutil.copytree(project_path, temp_project_path) diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index aa5412da..2852ab74 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -8,6 +8,7 @@ from typing import Optional, Sequence from git import Git, InvalidGitRepositoryError, Repo +from prometheus.utils.logger_manager import get_logger class GitRepository: @@ -38,12 +39,12 @@ def __init__( github_access_token: GitHub access token for authentication with remote repositories. Required if address is an HTTPS URL. Defaults to None. """ - self._logger = logging.getLogger("prometheus.git.git_repository") + self._logger = get_logger(__name__) # Configure git command to use our logger g = Git() type(g).GIT_PYTHON_TRACE = "full" - git_cmd_logger = logging.getLogger("git.cmd") + git_cmd_logger = get_logger("git.cmd") # Ensure git command output goes to our logger for handler in git_cmd_logger.handlers: diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index d00c35ea..59ee443a 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -18,7 +18,6 @@ """ import itertools -import logging from collections import defaultdict, deque from pathlib import Path from typing import Mapping, Optional, Sequence @@ -45,6 +44,7 @@ Neo4jTextNode, TextNode, ) +from prometheus.utils.logger_manager import get_logger class KnowledgeGraph: @@ -82,7 +82,7 @@ def __init__( knowledge_graph_edges if knowledge_graph_edges is not None else [] ) self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) - self._logger = logging.getLogger("prometheus.graph.knowledge_graph") + self._logger = get_logger(__name__) def build_graph( self, root_dir: Path, https_url: Optional[str] = None, commit_id: Optional[str] = None diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 25620677..46a0ab5e 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -1,10 +1,10 @@ -import logging from langchain_core.language_models.chat_models import BaseChatModel from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_subgraph import BugFixVerificationSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState +from prometheus.utils.logger_manager import get_logger class BugFixVerificationSubgraphNode: @@ -13,9 +13,7 @@ def __init__( model: BaseChatModel, container: BaseContainer, ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" - ) + self._logger = get_logger(__name__) self.subgraph = BugFixVerificationSubgraph( model=model, container=container, diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 27094be6..1184753e 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -1,5 +1,4 @@ import functools -import logging from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +7,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerficationState from prometheus.tools import container_command +from prometheus.utils.logger_manager import get_logger class BugFixVerifyNode: @@ -53,7 +53,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer): self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_verify_node") + self._logger = get_logger(__name__) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index ac4fb533..ab9f1b9a 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -1,4 +1,3 @@ -import logging from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -6,6 +5,7 @@ from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerficationState from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_logger class BugFixVerifyStructureOutput(BaseModel): @@ -90,9 +90,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_fix_verify_structured_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: BugFixVerficationState): bug_fix_verify_message = get_last_message_content(state["bug_fix_verify_messages"]) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index e7121965..5c351750 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -1,5 +1,4 @@ import functools -import logging from pathlib import Path from typing import Optional, Sequence @@ -12,6 +11,7 @@ from prometheus.tools import container_command from prometheus.utils.issue_util import format_test_commands from prometheus.utils.patch_util import get_updated_files +from prometheus.utils.logger_manager import get_logger class BugReproducingExecuteNode: @@ -55,7 +55,7 @@ def __init__( self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_execute_node") + self._logger = get_logger(__name__) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 25275b84..d381c87b 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -1,5 +1,4 @@ import functools -import logging from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -9,6 +8,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools import file_operation from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_logger class BugReproducingFileNode: @@ -44,7 +44,7 @@ def __init__( self.tools = self._init_tools(str(kg.get_local_path())) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_file_node") + self._logger = get_logger(__name__) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 9fa204aa..b8ec36ae 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -1,4 +1,3 @@ -import logging from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -11,6 +10,7 @@ format_agent_tool_message_history, get_last_message_content, ) +from prometheus.utils.logger_manager import get_logger class BugReproducingStructuredOutput(BaseModel): @@ -134,9 +134,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugReproducingStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproducing_structured_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_log = format_agent_tool_message_history( diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index f1c85d97..f19d4759 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -1,9 +1,9 @@ -import logging from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class BugReproducingWriteMessageNode: @@ -24,9 +24,7 @@ class BugReproducingWriteMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproducing_write_message_node" - ) + self._logger = get_logger(__name__) def format_human_message(self, state: BugReproductionState): if "reproduced_bug_failure_log" in state and state["reproduced_bug_failure_log"]: diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index fad2d6a7..619dd936 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -1,5 +1,4 @@ import functools -import logging from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +7,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools import file_operation +from prometheus.utils.logger_manager import get_logger class BugReproducingWriteNode: @@ -115,7 +115,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): self.tools = self._init_tools(str(kg.get_local_path())) self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_write_node") + self._logger = get_logger(__name__) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 39b0d0a5..13048240 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence import neo4j @@ -10,6 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_subgraph import BugReproductionSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState +from prometheus.utils.logger_manager import get_logger class BugReproductionSubgraphNode: @@ -24,9 +24,7 @@ def __init__( max_token_per_neo4j_result: int, test_commands: Optional[Sequence[str]], ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproduction_subgraph_node" - ) + self._logger = get_logger(__name__) self.git_repo = git_repo self.bug_reproduction_subgraph = BugReproductionSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index 7f84c90b..853b8340 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -7,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState +from prometheus.utils.logger_manager import get_logger class BuildAndTestSubgraphNode: @@ -25,7 +25,7 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.build_and_test_subgraph_node") + self._logger = get_logger(__name__) def __call__(self, state: IssueBugState): exist_build = None diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index e1f8db73..99263b44 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -1,4 +1,3 @@ -import logging from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -9,6 +8,7 @@ from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState from prometheus.models.context import Context from prometheus.utils.file_utils import read_file_with_line_numbers +from prometheus.utils.logger_manager import get_logger SYS_PROMPT = """\ You are a context summary agent that summarizes code context that is relevant to a given query from history messages. @@ -88,7 +88,7 @@ def __init__(self, model: BaseChatModel, root_path: str): self.model = structured_llm self.system_prompt = SystemMessage(SYS_PROMPT) self.root_path = root_path - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_extraction_node") + self._logger = get_logger(__name__) def __call__(self, state: ContextRetrievalState): self._logger.info("Starting context extraction process") diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index 30fd8345..39313699 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -6,7 +6,6 @@ """ import functools -import logging from typing import Dict import neo4j @@ -16,6 +15,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.tools import graph_traversal +from prometheus.utils.logger_manager import get_logger class ContextProviderNode: @@ -116,7 +116,7 @@ def __init__( self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_provider_node") + self._logger = get_logger(__name__) def _init_tools(self): """ diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index ba43a5b1..f34da15c 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -1,13 +1,13 @@ -import logging from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.logger_manager import get_logger class ContextQueryMessageNode: def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_query_message_node") + self._logger = get_logger(__name__) def __call__(self, state: ContextRetrievalState): human_message = HumanMessage(state["query"]) diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 05861bcf..f24e2fe8 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -1,4 +1,3 @@ -import logging from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -7,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.logger_manager import get_logger class ContextRefineStructuredOutput(BaseModel): @@ -79,7 +79,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): ) structured_llm = model.with_structured_output(ContextRefineStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_refine_node") + self._logger = get_logger(__name__) def format_refine_message(self, state: ContextRetrievalState): original_query = state["query"] diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 03583315..81c83253 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Dict, Sequence import neo4j @@ -7,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_subgraph import ContextRetrievalSubgraph from prometheus.models.context import Context +from prometheus.utils.logger_manager import get_logger class ContextRetrievalSubgraphNode: @@ -19,9 +19,7 @@ def __init__( query_key_name: str, context_key_name: str, ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.context_retrieval_subgraph_node" - ) + self._logger = get_logger(__name__) self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index bbbb93a8..07da4c1a 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -1,10 +1,10 @@ -import logging from typing import Dict from langchain_core.messages import HumanMessage from prometheus.utils.issue_util import format_issue_info from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_logger class EditMessageNode: @@ -32,7 +32,7 @@ class EditMessageNode: """ def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_message_node") + self._logger = get_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index 06c8ff8e..ad5b75e1 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -7,7 +7,6 @@ """ import functools -import logging from typing import Dict from langchain.tools import StructuredTool @@ -16,6 +15,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.tools import file_operation +from prometheus.utils.logger_manager import get_logger class EditNode: @@ -120,7 +120,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.tools = self._init_tools(kg.get_local_path()) self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_node") + self._logger = get_logger(__name__) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 5637a3c5..8702b62b 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,4 +1,3 @@ -import logging from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -6,6 +5,7 @@ from pydantic import BaseModel, Field from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class FinalPatchSelectionStructuredOutput(BaseModel): @@ -126,7 +126,7 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): ) structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.final_patch_selection_node") + self._logger = get_logger(__name__) def format_human_message(self, state: Dict): patches = "" diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index 1d3454f5..b176eaf2 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -1,5 +1,4 @@ import functools -import logging from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -9,6 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.tools import container_command +from prometheus.utils.logger_manager import get_logger class GeneralBuildNode: @@ -46,7 +46,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_build_node") + self._logger = get_logger(__name__) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index af9058bc..a5ffa23c 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -6,7 +6,6 @@ identify any failures. """ -import logging from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -14,6 +13,7 @@ from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history +from prometheus.utils.logger_manager import get_logger class BuildStructuredOutput(BaseModel): @@ -237,9 +237,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BuildStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.general_build_structured_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes build state to generate structured build analysis. diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index d35c1ee8..faac21cf 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -1,5 +1,4 @@ import functools -import logging from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -9,6 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.tools import container_command +from prometheus.utils.logger_manager import get_logger class GeneralTestNode: @@ -63,7 +63,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_node") + self._logger = get_logger(__name__) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 78758070..55727abf 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -6,7 +6,6 @@ identify any failures. """ -import logging from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -14,6 +13,7 @@ from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history +from prometheus.utils.logger_manager import get_logger class TestStructuredOutput(BaseModel): @@ -286,7 +286,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(TestStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_structured_node") + self._logger = get_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes test state to generate structured test analysis. diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index 4a8be50c..c0951040 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -6,10 +6,10 @@ output. """ -import logging from typing import Dict, Optional from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_logger class GitDiffNode: @@ -32,7 +32,7 @@ def __init__( self.state_patch_name = state_patch_name self.state_excluded_files_key = state_excluded_files_key self.return_list = return_list - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_diff_node") + self._logger = get_logger(__name__) def __call__(self, state: Dict): """Generates a Git diff for the current project state. diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index b9d63b30..1847654d 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,6 +1,6 @@ -import logging from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_logger class GitResetNode: @@ -9,7 +9,7 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_reset_node") + self._logger = get_logger(__name__) def __call__(self, _): self._logger.debug("Resetting the git repository") diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index 89845ced..dbbd461f 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -1,9 +1,9 @@ -import logging from typing import Dict from langchain_core.messages import HumanMessage from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueBugAnalyzerMessageNode: @@ -64,9 +64,7 @@ class IssueBugAnalyzerMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_analyzer_message_node" - ) + self._logger = get_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index 74586fb9..3301156e 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,11 +1,52 @@ -import logging from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage +from langchain.tools import StructuredTool +import functools +from prometheus.tools import web_search +from prometheus.utils.logger_manager import get_logger class IssueBugAnalyzerNode: +# SYS_PROMPT = """\ +# You are an expert software engineer specializing in bug analysis and fixes. Your role is to: + +# 1. Carefully analyze reported software issues and bugs by: +# - Understanding issue descriptions and symptoms +# - Identifying affected code components +# - Tracing problematic execution paths + +# 2. Determine root causes through systematic investigation: +# - Analyze why the current behavior deviates from expected +# - Identify which specific code elements are responsible +# - Understand the context and interactions causing the issue + +# 3. Provide high-level fix suggestions by describing: +# - Which specific files need modification +# - Which functions or code blocks need changes +# - What logical changes are needed (e.g., "variable x needs to be renamed to y", "need to add validation for parameter z") +# - Why these changes would resolve the issue + +# 4. For patch failures, analyze by: +# - Understanding error messages and test failures +# - Identifying what went wrong with the previous attempt +# - Suggesting revised high-level changes that avoid the previous issues + +# Tools available: +# - web_search: Searches the web for technical information to aid in bug analysis and resolution. + +# Important: +# - Do NOT provide actual code snippets or diffs +# - DO provide clear file paths and function names where changes are needed +# - Focus on describing WHAT needs to change and WHY, not HOW to change it +# - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes + +# Communicate in a clear, technical manner focused on accurate analysis and practical suggestions +# rather than implementation details. +# """ + + SYS_PROMPT = """\ You are an expert software engineer specializing in bug analysis and fixes. Your role is to: @@ -30,24 +71,54 @@ class IssueBugAnalyzerNode: - Identifying what went wrong with the previous attempt - Suggesting revised high-level changes that avoid the previous issues +MANDATORY TOOL USAGE: +- You MUST use the web_search tool for EVERY bug analysis +- Before providing any analysis, search for: + * Similar error messages or exceptions + * Known issues with the specific libraries/frameworks involved + * Best practices for the type of bug you're analyzing + * Official documentation for error resolution +- Only proceed with analysis after gathering relevant web information + +Tools available: +- web_search: Searches the web for technical information to aid in bug analysis and resolution. + Important: - Do NOT provide actual code snippets or diffs - DO provide clear file paths and function names where changes are needed - Focus on describing WHAT needs to change and WHY, not HOW to change it - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes +- ALWAYS start your analysis with web search results Communicate in a clear, technical manner focused on accurate analysis and practical suggestions rather than implementation details. """ def __init__(self, model: BaseChatModel): - self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_analyzer_node") + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.tools = self._init_tools() + self.model_with_tools = model.bind_tools(self.tools) + self._logger = get_logger(__name__) + + def _init_tools(self): + """Initializes tools for the node.""" + tools = [] + + web_search_fn = functools.partial(web_search.web_search) + web_search_tool = StructuredTool.from_function( + func=web_search_fn, + name=web_search.web_search.__name__, + description=web_search.WEB_SEARCH_DESCRIPTION, + args_schema=web_search.WebSearchInput, + ) + tools.append(web_search_tool) + + return tools def __call__(self, state: Dict): message_history = [self.system_prompt] + state["issue_bug_analyzer_messages"] - response = self.model.invoke(message_history) + response = self.model_with_tools.invoke(message_history) self._logger.debug(response) return {"issue_bug_analyzer_messages": [response]} diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index acf6ccd8..f7de19cd 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -1,7 +1,7 @@ -import logging from typing import Dict from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueBugContextMessageNode: @@ -19,9 +19,7 @@ class IssueBugContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_context_message_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: Dict): bug_fix_query = self.BUG_FIX_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index e951e2d4..ab34b7e6 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -1,7 +1,7 @@ -import logging from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueBugReproductionContextMessageNode: @@ -108,9 +108,7 @@ def test_file_permission_denied(self, mock_open, mock_access): """ def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_query = self.BUG_REPRODUCING_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index 1be5f22e..43fa841c 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -1,10 +1,10 @@ -import logging from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueBugResponderNode: @@ -48,7 +48,7 @@ def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_responder_node") + self._logger = get_logger(__name__) def format_human_message(self, state: Dict) -> HumanMessage: verification_messages = [] diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index 89415443..0583a0a1 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence import neo4j @@ -10,6 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_bug_subgraph import IssueBugSubgraph +from prometheus.utils.logger_manager import get_logger class IssueBugSubgraphNode: @@ -29,7 +29,7 @@ def __init__( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_subgraph_node") + self._logger = get_logger(__name__) self.container = container self.issue_bug_subgraph = IssueBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index a90e484a..8d93648e 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -1,7 +1,7 @@ -import logging from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueClassificationContextMessageNode: @@ -72,9 +72,7 @@ class IssueClassificationContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_context_message_node" - ) + self._logger = get_logger(__name__) def __call__(self, state: IssueClassificationState): issue_classification_query = self.ISSUE_CLASSIFICATION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 07229cd3..ed02fe47 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import neo4j from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +7,7 @@ from prometheus.lang_graph.subgraphs.issue_classification_subgraph import ( IssueClassificationSubgraph, ) +from prometheus.utils.logger_manager import get_logger class IssueClassificationSubgraphNode: @@ -18,9 +18,7 @@ def __init__( neo4j_driver: neo4j.Driver, max_token_per_neo4j_result: int, ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_subgraph_node" - ) + self._logger = get_logger(__name__) self.issue_classification_subgraph = IssueClassificationSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index eb8a60df..3ecb5874 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -1,4 +1,3 @@ -import logging from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -7,6 +6,7 @@ from prometheus.lang_graph.graphs.issue_state import IssueType from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_logger class IssueClassifierOutput(BaseModel): @@ -124,7 +124,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(IssueClassifierOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_classifier_node") + self._logger = get_logger(__name__) def format_context_info(self, state: IssueClassificationState) -> str: context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index c0ce236c..d909f53e 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Dict import neo4j @@ -9,6 +8,7 @@ from prometheus.lang_graph.subgraphs.issue_not_verified_bug_subgraph import ( IssueNotVerifiedBugSubgraph, ) +from prometheus.utils.logger_manager import get_logger class IssueNotVerifiedBugSubgraphNode: @@ -21,9 +21,8 @@ def __init__( neo4j_driver: neo4j.Driver, max_token_per_neo4j_result: int, ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" - ) + self._logger = get_logger(__name__) + self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 43ace6f2..4c5a8d71 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging from typing import Dict, Optional, Sequence import neo4j @@ -9,6 +8,7 @@ from prometheus.git.git_repository import GitRepository from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.issue_verified_bug_subgraph import IssueVerifiedBugSubgraph +from prometheus.utils.logger_manager import get_logger class IssueVerifiedBugSubgraphNode: @@ -28,9 +28,7 @@ def __init__( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node" - ) + self._logger = get_logger(__name__) self.git_repo = git_repo self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index e21b3f70..c57634f7 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -5,7 +5,7 @@ node graphs where a connection is needed but no processing is required. """ -import logging +from prometheus.utils.logger_manager import get_logger from typing import Dict @@ -19,7 +19,7 @@ class NoopNode: """ def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.noop_node") + self._logger = get_logger(__name__) def __call__(self, state: Dict) -> None: """Routes the workflow without performing any operations. diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index bf87c134..3fdffdd5 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -10,7 +10,7 @@ - The same state attribute name is reused """ -import logging +from prometheus.utils.logger_manager import get_logger from typing import Dict @@ -35,7 +35,7 @@ def __init__(self, message_state_key: str): be reset during node execution. """ self.message_state_key = message_state_key - self._logger = logging.getLogger("prometheus.lang_graph.nodes.reset_messages_node") + self._logger = get_logger(__name__) def __call__(self, state: Dict): """Resets the specified message state for the next iteration. diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index 5c473dc6..351b2851 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -6,12 +6,12 @@ between the agent's workspace and the container environment. """ -import logging from typing import Dict from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository from prometheus.utils.patch_util import get_updated_files +from prometheus.utils.logger_manager import get_logger class UpdateContainerNode: @@ -33,7 +33,7 @@ def __init__(self, container: BaseContainer, git_repo: GitRepository): """ self.container = container self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.update_container_node") + self._logger = get_logger(__name__) def __call__(self, _: Dict): """Synchronizes the current project state with the container.""" diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index aa8dc40c..e49ff4c9 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -1,16 +1,16 @@ -import logging import uuid from typing import Any from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer +from prometheus.utils.logger_manager import get_logger class UserDefinedBuildNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_build_node") + self._logger = get_logger(__name__) def __call__(self, _: Any): build_output = self.container.run_build() diff --git a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index 7f3e6e61..e04cefed 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -1,16 +1,16 @@ -import logging import uuid from typing import Any from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer +from prometheus.utils.logger_manager import get_logger class UserDefinedTestNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_test_node") + self._logger = get_logger(__name__) def __call__(self, _: Any): test_output = self.container.run_test() diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index 70869c4c..d0cac311 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -1,6 +1,5 @@ """The neo4j handler for writing the knowledge graph to neo4j.""" -import logging from typing import Mapping, Sequence from neo4j import GraphDatabase, ManagedTransaction @@ -18,6 +17,7 @@ Neo4jTextNode, ) from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.utils.logger_manager import get_logger class KnowledgeGraphHandler: @@ -32,7 +32,7 @@ def __init__(self, driver: GraphDatabase.driver, batch_size: int): self.driver = driver self.batch_size = batch_size - self._logger = logging.getLogger("prometheus.neo4j.knowledge_graph_handler") + self._logger = get_logger(__name__) def _init_database(self, tx: ManagedTransaction): """Initialization of the neo4j database.""" diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 4122d0c2..330e149f 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -1,4 +1,3 @@ -import logging import os import shutil from pathlib import Path @@ -6,8 +5,9 @@ from pydantic import BaseModel, Field from prometheus.utils.str_util import pre_append_line_numbers +from prometheus.utils.logger_manager import get_logger -logger = logging.getLogger("prometheus.tools.file_operation") +logger = get_logger(__name__) class ReadFileInput(BaseModel): diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py new file mode 100644 index 00000000..816ab522 --- /dev/null +++ b/prometheus/tools/web_search.py @@ -0,0 +1,147 @@ +import os +import shutil +from pathlib import Path +from typing import Annotated +import json +import asyncio +from dynaconf.vendor.dotenv import load_dotenv +from pydantic import BaseModel, Field, field_validator +from mcp.server import Server +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData +from mcp.server.stdio import stdio_server +from mcp.types import ( + GetPromptResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, + Tool, + INVALID_PARAMS, + INTERNAL_ERROR, +) +from tavily import TavilyClient, InvalidAPIKeyError, UsageLimitExceededError +from prometheus.configuration.config import settings +from prometheus.utils.logger_manager import get_logger + +logger = get_logger(__name__) + + +tavily_api_key = settings.get("TAVILY_API_KEY", None) +if tavily_api_key is None: + logger.warning("Tavily API key is not set") + tavily_client = None +else: + tavily_client = TavilyClient(api_key=tavily_api_key) + + +class WebSearchInput(BaseModel): + """Base parameters for Tavily search.""" + query: Annotated[str, Field(description="Search query")] + +WEB_SEARCH_DESCRIPTION = """\ + Searches the web for technical information to aid in bug analysis and resolution. + Use this when you need external context, such as: + 1. Looking up unfamiliar error messages, exceptions, or stack traces. + 2. Finding official documentation or usage examples for a specific library, framework, or API. + 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. + 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). + + Queries should be specific and include relevant keywords like library names, version numbers, and error codes. +""" + +def format_results(response: dict) -> str: + """Format Tavily search results into a readable string.""" + output = [] + + # Add domain filter information if present + if response.get("included_domains") or response.get("excluded_domains"): + filters = [] + if response.get("included_domains"): + filters.append(f"Including domains: {', '.join(response['included_domains'])}") + if response.get("excluded_domains"): + filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") + output.append("Search Filters:") + output.extend(filters) + output.append("") # Empty line for separation + + if response.get("answer"): + output.append(f"Answer: {response['answer']}") + output.append("\nSources:") + # Add immediate source references for the answer + for result in response["results"]: + output.append(f"- {result['title']}: {result['url']}") + output.append("") # Empty line for separation + + output.append("Detailed Results:") + for result in response["results"]: + output.append(f"\nTitle: {result['title']}") + output.append(f"URL: {result['url']}") + output.append(f"Content: {result['content']}") + if result.get("published_date"): + output.append(f"Published: {result['published_date']}") + + return "\n".join(output) + + + +def web_search(query: str, + max_results: int = 5, + include_domains: list[str] = [ + 'stackoverflow.com', + 'github.com', + 'developer.mozilla.org', + 'learn.microsoft.com', + 'docs.python.org', + 'pydantic.dev', + 'pypi.org', + 'readthedocs.org', + ], + exclude_domains: list[str] = None) -> str: + """ + Search the web for technical information to aid in bug analysis and resolution. + """ + if tavily_client is None: + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message="Tavily API key is not set" + )) + try: + response = tavily_client.search( + query=query, + max_results=max_results, + search_depth="advanced", + include_answer=True, + include_domains=include_domains or [], # Convert None to empty list + exclude_domains=exclude_domains or [], # Convert None to empty list + ) + return format_results(response) + except InvalidAPIKeyError: + raise McpError(ErrorData( + code=INVALID_PARAMS, + message="Invalid Tavily API key" + )) + except UsageLimitExceededError: + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message="Usage limit exceeded" + )) + except Exception as e: + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message=f"An error occurred: {str(e)}" + )) + + + + +if __name__ == "__main__": + load_dotenv() + tavily_api_key = os.getenv("PROMETHEUS_TAVILY_API_KEY") + if tavily_api_key is None: + logger.warning("Tavily API key is not set") + tavily_client = None + else: + tavily_client = TavilyClient(api_key=tavily_api_key) + + print(web_search("What is the capital of France?")) \ No newline at end of file diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py new file mode 100644 index 00000000..d76566e6 --- /dev/null +++ b/prometheus/utils/logger_manager.py @@ -0,0 +1,282 @@ +""" +Unified Log Manager + +This module provides a centralized logging management solution for the entire Prometheus project. +All logger configuration and retrieval should be done through this module. +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Optional +from datetime import datetime + +from prometheus.configuration.config import settings + + +class ColoredFormatter(logging.Formatter): + """Colored log formatter""" + + # ANSI color codes + COLORS = { + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green + 'WARNING': '\033[33m', # Yellow + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Purple + 'RESET': '\033[0m' # Reset color + } + + # Colored level names + COLORED_LEVELNAMES = { + 'DEBUG': f'{COLORS["DEBUG"]}DEBUG{COLORS["RESET"]}', + 'INFO': f'{COLORS["INFO"]}INFO{COLORS["RESET"]}', + 'WARNING': f'{COLORS["WARNING"]}WARNING{COLORS["RESET"]}', + 'ERROR': f'{COLORS["ERROR"]}ERROR{COLORS["RESET"]}', + 'CRITICAL': f'{COLORS["CRITICAL"]}CRITICAL{COLORS["RESET"]}', + } + + def __init__(self, fmt=None, datefmt=None, use_colors=True): + """ + Initialize colored formatter + + Args: + fmt: Log format string + datefmt: Date format string + use_colors: Whether to use colors + """ + super().__init__(fmt, datefmt) + self.use_colors = use_colors and self._supports_color() + + def _supports_color(self) -> bool: + """Check if terminal supports colors""" + # Check if running in a color-supporting terminal + return ( + hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() and + sys.platform != 'win32' # Windows may need special handling + ) or 'FORCE_COLOR' in os.environ + + def format(self, record): + """Format log record""" + if self.use_colors and record.levelname in self.COLORED_LEVELNAMES: + # Save original level name + original_levelname = record.levelname + # Use colored level name + record.levelname = self.COLORED_LEVELNAMES[record.levelname] + + # Format message + formatted = super().format(record) + + # Restore original level name + record.levelname = original_levelname + + return formatted + else: + return super().format(record) + + +class LoggerManager: + """Logger manager class, responsible for creating and configuring all loggers""" + + _instance: Optional['LoggerManager'] = None + _initialized: bool = False + + def __new__(cls) -> 'LoggerManager': + """Singleton pattern implementation""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize logger manager""" + if not self._initialized: + self._setup_root_logger() + self._initialized = True + + def _setup_root_logger(self): + """Setup root logger""" + # Get root logger + self.root_logger = logging.getLogger("prometheus") + + # Clear existing handlers to avoid duplication + self.root_logger.handlers.clear() + + # Set log level + log_level = getattr(settings, 'LOGGING_LEVEL', 'INFO') + self.root_logger.setLevel(getattr(logging, log_level)) + + # Create colored formatter for console output + self.colored_formatter = ColoredFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Create plain formatter for file output + self.file_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Create console handler (using colored formatter) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(self.colored_formatter) + self.root_logger.addHandler(console_handler) + + # Prevent log propagation to parent logger + self.root_logger.propagate = False + + # Log configuration information + self._log_configuration() + + def _log_configuration(self): + """Log configuration information""" + config_attrs = [ + 'LOGGING_LEVEL', 'ADVANCED_MODEL', 'BASE_MODEL', 'NEO4J_BATCH_SIZE', + 'WORKING_DIRECTORY', 'KNOWLEDGE_GRAPH_MAX_AST_DEPTH', + 'KNOWLEDGE_GRAPH_CHUNK_SIZE', 'KNOWLEDGE_GRAPH_CHUNK_OVERLAP', + 'MAX_TOKEN_PER_NEO4J_RESULT', 'TEMPERATURE', 'MAX_INPUT_TOKENS', + 'MAX_OUTPUT_TOKENS' + ] + + for attr in config_attrs: + value = getattr(settings, attr, 'Not Set') + self.root_logger.info(f"{attr}={value}") + + def get_logger(self, name: str) -> logging.Logger: + """ + Get logger with specified name + + Args: + name: Logger name, recommended to use full module path + + Returns: + Configured logger instance + """ + # Ensure logger name starts with prometheus + if not name.startswith("prometheus"): + name = f"prometheus.{name}" + + logger = logging.getLogger(name) + + # If it's a child logger, inherit root logger configuration + if name != "prometheus": + logger.parent = self.root_logger + logger.propagate = True + + return logger + + def create_file_handler(self, log_file_path: Path, logger_name: str = "prometheus") -> logging.FileHandler: + """ + Create file handler for specified logger + + Args: + log_file_path: Log file path + logger_name: Logger name + + Returns: + Configured file handler + """ + # Ensure log directory exists + log_file_path.parent.mkdir(parents=True, exist_ok=True) + + # Create file handler (using plain formatter, without colors) + file_handler = logging.FileHandler(log_file_path) + file_handler.setFormatter(self.file_formatter) + + # Get logger and add handler + logger = self.get_logger(logger_name) + logger.addHandler(file_handler) + + return file_handler + + def create_timestamped_file_handler(self, log_dir: Path, prefix: str = "prometheus", + logger_name: str = "prometheus") -> logging.FileHandler: + """ + Create file handler with timestamp + + Args: + log_dir: Log directory + prefix: Log file prefix + logger_name: Logger name + + Returns: + Configured file handler + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = log_dir / f"{prefix}_{timestamp}.log" + return self.create_file_handler(log_file, logger_name) + + def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = "prometheus"): + """ + Remove file handler + + Args: + handler: Handler to remove + logger_name: Logger name + """ + logger = self.get_logger(logger_name) + logger.removeHandler(handler) + handler.close() + + def enable_colors(self): + """Enable colored log output""" + self.colored_formatter.use_colors = True and self.colored_formatter._supports_color() + + def disable_colors(self): + """Disable colored log output""" + self.colored_formatter.use_colors = False + + def is_colors_enabled(self) -> bool: + """Check if colored output is enabled""" + return self.colored_formatter.use_colors + + +# Create global logger manager instance +logger_manager = LoggerManager() + + +def get_logger(name: str) -> logging.Logger: + """ + Convenience function to get logger + + Args: + name: Logger name, recommended to use __name__ or module path + + Returns: + Configured logger instance + + Examples: + >>> logger = get_logger(__name__) + >>> logger = get_logger("prometheus.tools.web_search") + """ + return logger_manager.get_logger(name) + + +def create_file_handler(log_file_path: Path, logger_name: str = "prometheus") -> logging.FileHandler: + """ + Convenience function to create file handler + + Args: + log_file_path: Log file path + logger_name: Logger name + + Returns: + Configured file handler + """ + return logger_manager.create_file_handler(log_file_path, logger_name) + + +def create_timestamped_file_handler(log_dir: Path, prefix: str = "prometheus", + logger_name: str = "prometheus") -> logging.FileHandler: + """ + Convenience function to create timestamped file handler + + Args: + log_dir: Log directory + prefix: Log file prefix + logger_name: Logger name + + Returns: + Configured file handler + """ + return logger_manager.create_timestamped_file_handler(log_dir, prefix, logger_name) + diff --git a/pyproject.toml b/pyproject.toml index 43acfb9b..8fb5e4c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,9 @@ dependencies = [ "unidiff>=0.7.5", "passlib[bcrypt]>=1.7.4", "sqlmodel==0.0.24", - "psycopg2-binary" + "psycopg2-binary", + "mcp>=1.4.1", + "tavily-python>=0.5.1" ] requires-python = ">= 3.11" diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index 38842e7e..b64c0790 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -1,5 +1,7 @@ import pytest -from langchain_core.messages import HumanMessage +from unittest.mock import Mock, patch +from langchain_core.messages import HumanMessage, AIMessage, ToolMessage +from langchain_core.messages.tool import ToolCall from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode from tests.test_utils.util import FakeListChatWithToolsModel @@ -10,6 +12,22 @@ def fake_llm(): return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"]) +@pytest.fixture +def fake_llm_with_tool_call(): + """LLM that simulates making a web_search tool call.""" + return FakeListChatWithToolsModel(responses=["I need to search for information about this error."]) + + +def test_init_issue_bug_analyzer_node(fake_llm): + """Test IssueBugAnalyzerNode initialization.""" + node = IssueBugAnalyzerNode(fake_llm) + + assert node.system_prompt is not None + assert len(node.tools) == 1 # Should have web_search tool + assert node.tools[0].name == "web_search" + assert node.model_with_tools is not None + + def test_call_method_basic(fake_llm): """Test basic call functionality.""" node = IssueBugAnalyzerNode(fake_llm) @@ -20,3 +38,97 @@ def test_call_method_basic(fake_llm): assert "issue_bug_analyzer_messages" in result assert len(result["issue_bug_analyzer_messages"]) == 1 assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" + + +def test_web_search_tool_integration(fake_llm_with_tool_call): + """Test that the web_search tool is properly integrated and can be called.""" + node = IssueBugAnalyzerNode(fake_llm_with_tool_call) + state = { + "issue_bug_analyzer_messages": [ + HumanMessage(content="I'm getting a ValueError in my Python code. Can you help analyze it?") + ] + } + + result = node(state) + + # Verify the result contains the response message + assert "issue_bug_analyzer_messages" in result + assert len(result["issue_bug_analyzer_messages"]) == 1 + assert result["issue_bug_analyzer_messages"][0].content == "I need to search for information about this error." + + +def test_web_search_tool_call_with_correct_parameters(fake_llm): + """Test that web_search tool has correct configuration and can be called.""" + node = IssueBugAnalyzerNode(fake_llm) + + # Test that the tool exists and has correct configuration + web_search_tool = node.tools[0] + assert web_search_tool.name == "web_search" + assert "technical information" in web_search_tool.description.lower() + + # Test that the tool has the correct args schema + assert hasattr(web_search_tool, 'args_schema') + assert web_search_tool.args_schema is not None + + +@patch('prometheus.tools.web_search.tavily_client') +def test_web_search_tool_without_api_key(mock_tavily_client, fake_llm): + """Test web_search tool behavior when API key is not available.""" + # Simulate no API key scenario + mock_tavily_client = None + + node = IssueBugAnalyzerNode(fake_llm) + web_search_tool = node.tools[0] + + # The tool should still be created but may handle missing API key gracefully + assert web_search_tool.name == "web_search" + + +def test_system_prompt_contains_web_search_info(fake_llm): + """Test that the system prompt mentions web_search tool.""" + node = IssueBugAnalyzerNode(fake_llm) + + system_prompt_content = node.system_prompt.content.lower() + assert "web_search" in system_prompt_content + assert "technical information" in system_prompt_content + + +def test_web_search_tool_schema_validation(fake_llm): + """Test that the web_search tool has proper input validation.""" + node = IssueBugAnalyzerNode(fake_llm) + web_search_tool = node.tools[0] + + # Check that the tool has an args_schema + assert hasattr(web_search_tool, 'args_schema') + assert web_search_tool.args_schema is not None + + # Test with valid input + valid_input = {"query": "Python debugging techniques"} + # This should not raise an exception + validated_input = web_search_tool.args_schema(**valid_input) + assert validated_input.query == "Python debugging techniques" + + +def test_multiple_tool_calls_in_conversation(fake_llm): + """Test handling multiple web_search calls in a conversation.""" + node = IssueBugAnalyzerNode(fake_llm) + + # Simulate a conversation with tool calls + state = { + "issue_bug_analyzer_messages": [ + HumanMessage(content="Analyze this bug: ImportError in my application"), + AIMessage( + content="Let me search for information about this error.", + tool_calls=[ToolCall(name="web_search", args={"query": "Python ImportError debugging"}, id="call_1")] + ), + ToolMessage(content="Search results: ImportError occurs when...", tool_call_id="call_1"), + HumanMessage(content="The error still persists after trying the suggested fixes") + ] + } + + result = node(state) + + assert "issue_bug_analyzer_messages" in result + assert len(result["issue_bug_analyzer_messages"]) == 1 + # The new response should be added to the conversation + assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" From a436658b415662f5a3a52227a927b12ef7a7b894 Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 16 Aug 2025 17:50:35 +0800 Subject: [PATCH 02/30] tools formatting --- .../app/services/service_coordinator.py | 8 +- prometheus/graph_config/node_tools.yml | 162 ++++ .../lang_graph/nodes/bug_fix_verify_node.py | 15 +- .../nodes/bug_reproducing_execute_node.py | 17 +- .../nodes/bug_reproducing_file_node.py | 33 +- .../nodes/bug_reproducing_write_node.py | 15 +- .../lang_graph/nodes/context_provider_node.py | 135 ++- prometheus/lang_graph/nodes/edit_node.py | 47 +- .../lang_graph/nodes/general_build_node.py | 15 +- .../lang_graph/nodes/general_test_node.py | 15 +- .../nodes/issue_bug_analyzer_node.py | 11 +- prometheus/tools/__init__.py | 5 + prometheus/tools/container_command.py | 41 +- prometheus/tools/file_operation.py | 220 ++++- prometheus/tools/graph_traversal.py | 779 +++++++++--------- prometheus/tools/web_search.py | 145 ++-- pyproject.toml | 3 +- tests/tools/test_mcp_client.py | 83 ++ tests/tools/test_mcp_client_config.py | 97 +++ tests/tools/test_mcp_server.py | 153 ++++ tests/tools/test_mcp_tools.py | 64 ++ 21 files changed, 1389 insertions(+), 674 deletions(-) create mode 100644 prometheus/graph_config/node_tools.yml create mode 100644 tests/tools/test_mcp_client.py create mode 100644 tests/tools/test_mcp_client_config.py create mode 100644 tests/tools/test_mcp_server.py create mode 100644 tests/tools/test_mcp_tools.py diff --git a/prometheus/app/services/service_coordinator.py b/prometheus/app/services/service_coordinator.py index 202660f7..00904e3d 100644 --- a/prometheus/app/services/service_coordinator.py +++ b/prometheus/app/services/service_coordinator.py @@ -17,7 +17,7 @@ from prometheus.app.services.neo4j_service import Neo4jService from prometheus.app.services.repository_service import RepositoryService from prometheus.lang_graph.graphs.issue_state import IssueType -from prometheus.utils.logger_manager import get_logger, create_timestamped_file_handler +from prometheus.utils.logger_manager import get_logger, create_timestamped_file_handler, logger_manager class ServiceCoordinator: @@ -160,9 +160,9 @@ def answer_issue( except Exception as e: self._logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") return None, None, False, False, False, None - finally: - # 移除文件处理器并关闭文件 - self._logger.remove_file_handler(file_handler, "prometheus") + # finally: + # # 移除文件处理器并关闭文件 + # self._logger.remove_file_handler(file_handler, "prometheus") def exists_knowledge_graph(self) -> bool: return self.knowledge_graph_service.exists() diff --git a/prometheus/graph_config/node_tools.yml b/prometheus/graph_config/node_tools.yml new file mode 100644 index 00000000..bca51965 --- /dev/null +++ b/prometheus/graph_config/node_tools.yml @@ -0,0 +1,162 @@ +nodes: + - name: BugFixVerificationSubgraphNode + tools: [] + + - name: BugFixVerifyNode + class: + - ContainerCommandTool: + tools: + - run_command + + - name: BugFixVerifyStructuredNode + tools: [] + + - name: BugReproducingExecuteNode + class: + - ContainerCommandTool: + tools: + - run_command + + - name: BugReproducingFileNode + class: + - FileOperationTool: + tools: + - read_file + - create_file + + - name: BugReproducingStructuredNode + tools: [] + + - name: BugReproducingWriteMessageNode + tools: [] + + - name: BugReproducingWriteNode + class: + - FileOperationTool: + tools: + - read_file + + - name: BugReproductionSubgraphNode + tools: [] + + - name: BuildAndTestSubgraphNode + tools: [] + + - name: ContextExtractionNode + tools: [] + + - name: ContextProviderNode + class: + - GraphTraversalTool: + tools: + - find_file_node_with_basename + - find_file_node_with_relative_path + - find_ast_node_with_text_in_file_with_basename + - find_ast_node_with_text_in_file_with_relative_path + - find_text_node_with_text + - find_text_node_with_text_in_file + - get_next_text_node_with_node_id + - preview_file_content_with_basename + - preview_file_content_with_relative_path + - read_code_with_basename + - read_code_with_relative_path + + - name: ContextQueryMessageNode + tools: [] + + - name: ContextRefineNode + tools: [] + + - name: ContextRetrievalSubgraphNode + tools: [] + + - name: EditMessageNode + tools: [] + + - name: EditNode + class: + - FileOperationTool: + tools: + - read_file + - read_file_with_line_numbers + - create_file + - delete + - edit_file + + - name: FinalPatchSelectionNode + tools: [] + + - name: GeneralBuildNode + class: + - ContainerCommandTool: + tools: + - run_command + + - name: GeneralBuildStructuredNode + tools: [] + + - name: GeneralTestNode + class: + - ContainerCommandTool: + tools: + - run_command + + - name: GeneralTestStructuredNode + tools: [] + + - name: GitDiffNode + tools: [] + + - name: GitResetNode + tools: [] + + - name: IssueBugAnalyzerMessageNode + tools: [] + + - name: IssueBugAnalyzerNode + class: + - WebSearchTool: + tools: + - web_search + + - name: IssueBugContextMessageNode + tools: [] + + - name: IssueBugReproductionContextMessageNode + tools: [] + + - name: IssueBugResponderNode + tools: [] + + - name: IssueBugSubgraphNode + tools: [] + + - name: IssueClassificationContextMessageNode + tools: [] + + - name: IssueClassificationSubgraphNode + tools: [] + + - name: IssueClassifierNode + tools: [] + + - name: IssueNotVerifiedBugSubgraphNode + tools: [] + + - name: IssueVerifiedBugSubgraphNode + tools: [] + + - name: NoopNode + tools: [] + + - name: ResetMessagesNode + tools: [] + + - name: UpdateContainerNode + tools: [] + + - name: UserDefinedBuildNode + tools: [] + + - name: UserDefinedTestNode + tools: [] diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 1184753e..612bd9af 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -6,7 +6,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerficationState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.logger_manager import get_logger @@ -50,20 +50,21 @@ class BugFixVerifyNode: """ def __init__(self, model: BaseChatModel, container: BaseContainer): - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) self._logger = get_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index 5c351750..5346aa9c 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -8,10 +8,10 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.issue_util import format_test_commands from prometheus.utils.patch_util import get_updated_files -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_logger class BugReproducingExecuteNode: @@ -52,20 +52,21 @@ def __init__( test_commands: Optional[Sequence[str]] = None, ): self.test_commands = test_commands - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) self._logger = get_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index d381c87b..149c9b5f 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -6,7 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import file_operation +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.lang_graph_util import get_last_message_content from prometheus.utils.logger_manager import get_logger @@ -41,37 +41,32 @@ def __init__( kg: KnowledgeGraph, ): self.kg = kg - self.tools = self._init_tools(str(kg.get_local_path())) + self.file_operation_tool = FileOperationTool(str(kg.get_local_path())) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) self._logger = get_logger(__name__) + - def _init_tools(self, root_path: str): - """Initializes file operation tools with the given root path. - - Args: - root_path: Base directory path for all file operations. - - Returns: - List of StructuredTool instances configured for file operations. - """ + def _init_tools(self): + """Initializes file operation tools.""" tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=FileOperationTool.read_file.__name__, + description=FileOperationTool.read_file_spec.description, + args_schema=FileOperationTool.read_file_spec.input_schema, ) tools.append(read_file_tool) - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_fn = functools.partial(self.file_operation_tool.create_file) create_file_tool = StructuredTool.from_function( func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, + name=FileOperationTool.create_file.__name__, + description=FileOperationTool.create_file_spec.description, + args_schema=FileOperationTool.create_file_spec.input_schema, ) tools.append(create_file_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 619dd936..3fc2eed7 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -6,7 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import file_operation +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.logger_manager import get_logger @@ -112,12 +112,13 @@ def test_empty_array_parsing(parser): ''' def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): - self.tools = self._init_tools(str(kg.get_local_path())) + self.file_operation_tool = FileOperationTool(str(kg.get_local_path())) + self.tools = self._init_tools() self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) self._logger = get_logger(__name__) - def _init_tools(self, root_path: str): + def _init_tools(self): """Initializes file operation tools with the given root path. Args: @@ -128,12 +129,12 @@ def _init_tools(self, root_path: str): """ tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=self.file_operation_tool.read_file.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, ) tools.append(read_file_tool) diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index 39313699..7718e667 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -14,7 +14,7 @@ from langchain_core.messages import SystemMessage from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools import graph_traversal +from prometheus.tools.graph_traversal import GraphTraversalTool from prometheus.utils.logger_manager import get_logger @@ -108,6 +108,9 @@ def __init__( """ self.neo4j_driver = neo4j_driver self.max_token_per_result = max_token_per_result + # Initialize GraphTraversalTool with the driver and token limit + self.graph_traversal_tool = GraphTraversalTool(neo4j_driver, max_token_per_result) + ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) self.system_prompt = SystemMessage( @@ -132,15 +135,13 @@ def _init_tools(self): # Tool: Find file node by filename (basename) # Used when only the filename (not full path) is known find_file_node_with_basename_fn = functools.partial( - graph_traversal.find_file_node_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_file_node_with_basename, ) find_file_node_with_basename_tool = StructuredTool.from_function( func=find_file_node_with_basename_fn, - name=graph_traversal.find_file_node_with_basename.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithBasenameInput, + name=self.graph_traversal_tool.find_file_node_with_basename.__name__, + description=self.graph_traversal_tool.find_file_node_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_file_node_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_file_node_with_basename_tool) @@ -148,15 +149,13 @@ def _init_tools(self): # Tool: Find file node by relative path # Preferred method when the exact file path is known find_file_node_with_relative_path_fn = functools.partial( - graph_traversal.find_file_node_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_file_node_with_relative_path ) find_file_node_with_relative_path_tool = StructuredTool.from_function( func=find_file_node_with_relative_path_fn, - name=graph_traversal.find_file_node_with_relative_path.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithRelativePathInput, + name=self.graph_traversal_tool.find_file_node_with_relative_path.__name__, + description=self.graph_traversal_tool.find_file_node_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_file_node_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_file_node_with_relative_path_tool) @@ -166,30 +165,26 @@ def _init_tools(self): # Tool: Find AST node by text match in file (by basename) # Useful for searching specific snippets or patterns in unknown locations find_ast_node_with_text_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename, ) find_ast_node_with_text_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_basename_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithBasenameInput, + name=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename.__name__, + description=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_text_in_file_with_basename_tool) # Tool: Find AST node by text match in file (by relative path) find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path, ) find_ast_node_with_text_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithRelativePathInput, + name=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path.__name__, + description=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_text_in_file_with_relative_path_tool) @@ -197,30 +192,26 @@ def _init_tools(self): # Tool: Find AST node by type in file (by basename) # Example types: FunctionDef, ClassDef, Assign, etc. find_ast_node_with_type_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename, ) find_ast_node_with_type_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_basename_fn, - name=graph_traversal.find_ast_node_with_type_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithBasenameInput, + name=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename.__name__, + description=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_type_in_file_with_basename_tool) # Tool: Find AST node by type in file (by relative path) find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path, ) find_ast_node_with_type_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_type_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithRelativePathInput, + name=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path.__name__, + description=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_type_in_file_with_relative_path_tool) @@ -229,45 +220,39 @@ def _init_tools(self): # Tool: Find text node globally by keyword find_text_node_with_text_fn = functools.partial( - graph_traversal.find_text_node_with_text, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_text_node_with_text, ) find_text_node_with_text_tool = StructuredTool.from_function( func=find_text_node_with_text_fn, - name=graph_traversal.find_text_node_with_text.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInput, + name=self.graph_traversal_tool.find_text_node_with_text.__name__, + description=self.graph_traversal_tool.find_text_node_with_text_spec.description, + args_schema=self.graph_traversal_tool.find_text_node_with_text_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_text_node_with_text_tool) # Tool: Find text node by keyword in specific file find_text_node_with_text_in_file_fn = functools.partial( - graph_traversal.find_text_node_with_text_in_file, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.find_text_node_with_text_in_file, ) find_text_node_with_text_in_file_tool = StructuredTool.from_function( func=find_text_node_with_text_in_file_fn, - name=graph_traversal.find_text_node_with_text_in_file.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInFileInput, + name=self.graph_traversal_tool.find_text_node_with_text_in_file.__name__, + description=self.graph_traversal_tool.find_text_node_with_text_in_file_spec.description, + args_schema=self.graph_traversal_tool.find_text_node_with_text_in_file_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_text_node_with_text_in_file_tool) # Tool: Fetch the next text node chunk in a chain (used for long docs/comments) get_next_text_node_with_node_id_fn = functools.partial( - graph_traversal.get_next_text_node_with_node_id, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.get_next_text_node_with_node_id, ) get_next_text_node_with_node_id_tool = StructuredTool.from_function( func=get_next_text_node_with_node_id_fn, - name=graph_traversal.get_next_text_node_with_node_id.__name__, - description=graph_traversal.GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION, - args_schema=graph_traversal.GetNextTextNodeWithNodeIdInput, + name=self.graph_traversal_tool.get_next_text_node_with_node_id.__name__, + description=self.graph_traversal_tool.get_next_text_node_with_node_id_spec.description, + args_schema=self.graph_traversal_tool.get_next_text_node_with_node_id_spec.input_schema, response_format="content_and_artifact", ) tools.append(get_next_text_node_with_node_id_tool) @@ -276,60 +261,52 @@ def _init_tools(self): # Tool: Preview contents of file by basename preview_file_content_with_basename_fn = functools.partial( - graph_traversal.preview_file_content_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.preview_file_content_with_basename, ) preview_file_content_with_basename_tool = StructuredTool.from_function( func=preview_file_content_with_basename_fn, - name=graph_traversal.preview_file_content_with_basename.__name__, - description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.PreviewFileContentWithBasenameInput, + name=self.graph_traversal_tool.preview_file_content_with_basename.__name__, + description=self.graph_traversal_tool.preview_file_content_with_basename_spec.description, + args_schema=self.graph_traversal_tool.preview_file_content_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(preview_file_content_with_basename_tool) # Tool: Preview contents of file by relative path preview_file_content_with_relative_path_fn = functools.partial( - graph_traversal.preview_file_content_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.preview_file_content_with_relative_path, ) preview_file_content_with_relative_path_tool = StructuredTool.from_function( func=preview_file_content_with_relative_path_fn, - name=graph_traversal.preview_file_content_with_relative_path.__name__, - description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.PreviewFileContentWithRelativePathInput, + name=self.graph_traversal_tool.preview_file_content_with_relative_path.__name__, + description=self.graph_traversal_tool.preview_file_content_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.preview_file_content_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(preview_file_content_with_relative_path_tool) # Tool: Read entire code file by basename read_code_with_basename_fn = functools.partial( - graph_traversal.read_code_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.read_code_with_basename, ) read_code_with_basename_tool = StructuredTool.from_function( func=read_code_with_basename_fn, - name=graph_traversal.read_code_with_basename.__name__, - description=graph_traversal.READ_CODE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.ReadCodeWithBasenameInput, + name=self.graph_traversal_tool.read_code_with_basename.__name__, + description=self.graph_traversal_tool.read_code_with_basename_spec.description, + args_schema=self.graph_traversal_tool.read_code_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(read_code_with_basename_tool) # Tool: Read entire code file by relative path read_code_with_relative_path_fn = functools.partial( - graph_traversal.read_code_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, + self.graph_traversal_tool.read_code_with_relative_path, ) read_code_with_relative_path_tool = StructuredTool.from_function( func=read_code_with_relative_path_fn, - name=graph_traversal.read_code_with_relative_path.__name__, - description=graph_traversal.READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.ReadCodeWithRelativePathInput, + name=self.graph_traversal_tool.read_code_with_relative_path.__name__, + description=self.graph_traversal_tool.read_code_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.read_code_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(read_code_with_relative_path_tool) diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index ad5b75e1..c519d347 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -14,7 +14,7 @@ from langchain_core.messages import SystemMessage from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools import file_operation +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.logger_manager import get_logger @@ -118,11 +118,12 @@ def other_method(): def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.tools = self._init_tools(kg.get_local_path()) + self.file_operation_tool = FileOperationTool(str(kg.get_local_path())) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self._logger = get_logger(__name__) - def _init_tools(self, root_path: str): + def _init_tools(self): """Initializes file operation tools with the given root path. Args: @@ -133,50 +134,50 @@ def _init_tools(self, root_path: str): """ tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=self.file_operation_tool.read_file.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, ) tools.append(read_file_tool) read_file_with_line_numbers_fn = functools.partial( - file_operation.read_file_with_line_numbers, root_path=root_path + self.file_operation_tool.read_file_with_line_numbers ) read_file_with_line_numbers_tool = StructuredTool.from_function( func=read_file_with_line_numbers_fn, - name=file_operation.read_file_with_line_numbers.__name__, - description=file_operation.READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION, - args_schema=file_operation.ReadFileWithLineNumbersInput, + name=self.file_operation_tool.read_file_with_line_numbers.__name__, + description=self.file_operation_tool.read_file_with_line_numbers_spec.description, + args_schema=self.file_operation_tool.read_file_with_line_numbers_spec.input_schema, ) tools.append(read_file_with_line_numbers_tool) - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_fn = functools.partial(self.file_operation_tool.create_file) create_file_tool = StructuredTool.from_function( func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, + name=self.file_operation_tool.create_file.__name__, + description=self.file_operation_tool.create_file_spec.description, + args_schema=self.file_operation_tool.create_file_spec.input_schema, ) tools.append(create_file_tool) - delete_fn = functools.partial(file_operation.delete, root_path=root_path) + delete_fn = functools.partial(self.file_operation_tool.delete) delete_tool = StructuredTool.from_function( func=delete_fn, - name=file_operation.delete.__name__, - description=file_operation.DELETE_DESCRIPTION, - args_schema=file_operation.DeleteInput, + name=self.file_operation_tool.delete.__name__, + description=self.file_operation_tool.delete_spec.description, + args_schema=self.file_operation_tool.delete_spec.input_schema, ) tools.append(delete_tool) - edit_file_fn = functools.partial(file_operation.edit_file, root_path=root_path) + edit_file_fn = functools.partial(self.file_operation_tool.edit_file) edit_file_tool = StructuredTool.from_function( func=edit_file_fn, - name=file_operation.edit_file.__name__, - description=file_operation.EDIT_FILE_DESCRIPTION, - args_schema=file_operation.EditFileInput, + name=self.file_operation_tool.edit_file.__name__, + description=self.file_operation_tool.edit_file_spec.description, + args_schema=self.file_operation_tool.edit_file_spec.input_schema, ) tools.append(edit_file_tool) diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index b176eaf2..2024e90b 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -7,7 +7,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.logger_manager import get_logger @@ -43,20 +43,21 @@ class GeneralBuildNode: def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): self.kg = kg - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) self._logger = get_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index faac21cf..7ac86ad7 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -7,7 +7,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.logger_manager import get_logger @@ -60,20 +60,21 @@ class GeneralTestNode: def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): self.kg = kg - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) self._logger = get_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index 3301156e..d91a11c3 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -4,7 +4,7 @@ from langchain_core.messages import SystemMessage from langchain.tools import StructuredTool import functools -from prometheus.tools import web_search +from prometheus.tools.web_search import WebSearchTool from prometheus.utils.logger_manager import get_logger @@ -95,6 +95,7 @@ class IssueBugAnalyzerNode: """ def __init__(self, model: BaseChatModel): + self.web_search_tool = WebSearchTool() self.model = model self.system_prompt = SystemMessage(self.SYS_PROMPT) self.tools = self._init_tools() @@ -105,12 +106,12 @@ def _init_tools(self): """Initializes tools for the node.""" tools = [] - web_search_fn = functools.partial(web_search.web_search) + web_search_fn = functools.partial(self.web_search_tool.web_search) web_search_tool = StructuredTool.from_function( func=web_search_fn, - name=web_search.web_search.__name__, - description=web_search.WEB_SEARCH_DESCRIPTION, - args_schema=web_search.WebSearchInput, + name=self.web_search_tool.web_search.__name__, + description=self.web_search_tool.web_search_spec.description, + args_schema=self.web_search_tool.web_search_spec.input_schema, ) tools.append(web_search_tool) diff --git a/prometheus/tools/__init__.py b/prometheus/tools/__init__.py index e69de29b..15e3cbf6 100644 --- a/prometheus/tools/__init__.py +++ b/prometheus/tools/__init__.py @@ -0,0 +1,5 @@ + +# Ensure MCP tools are registered when this package is imported by the MCP server +# Importing the module executes the @mcp.tool decorators +from prometheus.mcp_tools import web_search as _mcp_web_search # noqa: F401 + diff --git a/prometheus/tools/container_command.py b/prometheus/tools/container_command.py index 42459ad8..d558c481 100644 --- a/prometheus/tools/container_command.py +++ b/prometheus/tools/container_command.py @@ -1,17 +1,38 @@ from pydantic import BaseModel, Field - from prometheus.docker.general_container import GeneralContainer +from dataclasses import dataclass +@dataclass +class ToolSpec: + description: str + input_schema: type class RunCommandInput(BaseModel): command: str = Field("The shell command to be run in the container") - -RUN_COMMAND_DESCRIPTION = """\ -Run a shell command in the container and return the result of the command. You are always at the root -of the codebase. -""" - - -def run_command(command: str, container: GeneralContainer) -> str: - return container.execute_command(command) +class ContainerCommandTool: + """Tool class for executing shell commands in containers.""" + + run_command_spec = ToolSpec( + description="""\ + Run a shell command in the container and return the result of the command. You are always at the root + of the codebase. + """, + input_schema=RunCommandInput + ) + + def __init__(self, container: GeneralContainer): + """Initialize the container command tool. + Args: + container: The GeneralContainer instance to execute commands in. + """ + self.container = container + + def run_command(self, command: str) -> str: + """Run a shell command in the container and return the result. + Args: + command: The shell command to be run in the container. + Returns: + The output of the command execution. + """ + return self.container.execute_command(command) diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 330e149f..3db40e53 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -1,7 +1,7 @@ import os import shutil from pathlib import Path - +from dataclasses import dataclass from pydantic import BaseModel, Field from prometheus.utils.str_util import pre_append_line_numbers @@ -9,10 +9,200 @@ logger = get_logger(__name__) +@dataclass +class ToolSpec: + description: str + input_schema: type class ReadFileInput(BaseModel): relative_path: str = Field("The relative path of the file to read") +class ReadFileWithLineNumbersInput(BaseModel): + relative_path: str = Field( + description="The relative path of the file to read, eg. foo/bar/test.py, not absolute path" + ) + start_line: int = Field(description="The start line number to read, 1-indexed and inclusive") + end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") + +class CreateFileInput(BaseModel): + relative_path: str = Field( + description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" + ) + content: str = Field(description="The content of the file to create") + +class DeleteInput(BaseModel): + relative_path: str = Field( + description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" + ) + +class EditFileInput(BaseModel): + relative_path: str = Field( + description="The relative path of the file to edit, eg. foo/bar/test.py, not absolute path" + ) + old_content: str = Field( + description="The exact string content to be replaced in the file. Must match exactly one occurrence in the file" + ) + new_content: str = Field( + description="The new content that will replace the old_content in the file" + ) + + +class FileOperationTool: + """Tool class for file operations including reading, creating, editing, and deleting files.""" + + # File operation tools + read_file_spec = ToolSpec( + description="""\ + Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. + Returns up to the first 1000 lines by default to prevent context issues with large files. + Returns an error message if the file doesn't exist. + """, + input_schema=ReadFileInput + ) + + read_file_with_line_numbers_spec = ToolSpec( + description="""\ + Read a specific range of lines from a file and return the content with line numbers prepended. + The line numbers are 1-indexed where start_line is inclusive and end_line is exclusive. + For best results when analyzing code or text files, consider reading chunks of 500-1000 lines at a time. + """, + input_schema=ReadFileWithLineNumbersInput + ) + + create_file_spec = ToolSpec( + description="""\ + Create a new file at the specified path with the given content. + If the parent directories don't exist, they will be created automatically. + Returns an error message if the file already exists. + """, + input_schema=CreateFileInput + ) + + delete_spec = ToolSpec( + description="""\ + Delete a file or directory at the specified path. + For directories, it will recursively delete all contents. + Returns an error message if the path doesn't exist. + """, + input_schema=DeleteInput + ) + + edit_file_spec = ToolSpec( + description="""\ + Edit a file by replacing specific content with new content. + Performs an exact string replacement of old_content with new_content. + Returns an error message if: + - The file doesn't exist + - The old_content is not found in the file + - The old_content matches multiple locations (in which case more context is needed) + - The provided path is absolute instead of relative + + Example usage: + edit_file( + relative_path="src/calculator.py", + old_content="return a * b", + new_content="return a / b" + ) + """, + input_schema=EditFileInput + ) + + def __init__(self, root_path: str): + """Initialize the file operation tool. + + Args: + root_path: The root path of the codebase for relative path operations. + """ + self.root_path = root_path + + def read_file(self, relative_path: str, n_lines: int = 1000) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + with file_path.open() as f: + lines = f.readlines() + + return pre_append_line_numbers("".join(lines[:n_lines]), 1) + + def read_file_with_line_numbers(self, relative_path: str, start_line: int, end_line: int) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + if end_line < start_line: + return f"The end line number {end_line} must be greater than the start line number {start_line}." + + zero_based_start_line = start_line - 1 + zero_based_end_line = end_line - 1 + + with file_path.open() as f: + lines = f.readlines() + + return pre_append_line_numbers( + "".join(lines[zero_based_start_line:zero_based_end_line]), start_line + ) + + def create_file(self, relative_path: str, content: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if file_path.exists(): + return f"The file {relative_path} already exists." + + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + return f"The file {relative_path} has been created." + + + def delete(self, relative_path: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + if file_path.is_dir(): + shutil.rmtree(file_path) + return f"The directory {relative_path} has been deleted." + + file_path.unlink() + return f"The file {relative_path} has been deleted." + + def edit_file(self, relative_path: str, old_content: str, new_content: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + content = file_path.read_text() + + occurrences = content.count(old_content) + + if occurrences == 0: + return f"No match found for the specified content in {relative_path}. Please verify the content to replace." + + if occurrences > 1: + return ( + f"Found {occurrences} occurrences of the specified content in {relative_path}. " + "Please provide more context to ensure a unique match." + ) + + new_content_full = content.replace(old_content, new_content) + file_path.write_text(new_content_full) + + return f"Successfully edited {relative_path}." + READ_FILE_DESCRIPTION = """\ Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. @@ -35,12 +225,6 @@ def read_file(relative_path: str, root_path: str, n_lines: int = 1000) -> str: return pre_append_line_numbers("".join(lines[:n_lines]), 1) -class ReadFileWithLineNumbersInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to read, eg. foo/bar/test.py, not absolute path" - ) - start_line: int = Field(description="The start line number to read, 1-indexed and inclusive") - end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION = """\ @@ -74,12 +258,6 @@ def read_file_with_line_numbers( ) -class CreateFileInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" - ) - content: str = Field(description="The content of the file to create") - CREATE_FILE_DESCRIPTION = """\ Create a new file at the specified path with the given content. @@ -101,11 +279,6 @@ def create_file(relative_path: str, root_path: str, content: str) -> str: return f"The file {relative_path} has been created." -class DeleteInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" - ) - DELETE_DESCRIPTION = """\ Delete a file or directory at the specified path. @@ -130,17 +303,6 @@ def delete(relative_path: str, root_path: str) -> str: return f"The file {relative_path} has been deleted." -class EditFileInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to edit, eg. foo/bar/test.py, not absolute path" - ) - old_content: str = Field( - description="The exact string content to be replaced in the file. Must match exactly one occurrence in the file" - ) - new_content: str = Field( - description="The new content that will replace the old_content in the file" - ) - EDIT_FILE_DESCRIPTION = """\ Edit a file by replacing specific content with new content. diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index d2d6d790..6235cbfb 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Any, Mapping, Sequence, Union +from dataclasses import dataclass from neo4j import GraphDatabase from pydantic import BaseModel, Field @@ -10,7 +11,6 @@ MAX_RESULT = 30 - """ Tools for retrieving nodes from the Neo4j graph database. These tools allow you to search for FileNode, ASTNode, and TextNode based on various attributes @@ -20,409 +20,54 @@ The content is a string representation of the node(s) found, and the artifact is a list of dictionaries """ - -############################################################################### -# FileNode retrieval # -############################################################################### +@dataclass +class ToolSpec: + description: str + input_schema: type class FindFileNodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to search for") - -FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION = """\ -Find all FileNode in the graph with this basename of a file/dir. The basename must -include the extension, like 'bar.py', 'baz.java' or 'foo' -(in this case foo is a directory or a file without extension). - -You can use this tool to check if a file/dir with this basename exists or get all -attributes related to the file/dir.""" - - -def find_file_node_with_basename( - basename: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f""" - MATCH (f:FileNode {{ basename: '{basename}' }}) - RETURN f AS FileNode - ORDER BY f.node_id - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class FindFileNodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative_path of FileNode to search for") - -FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Search FileNode in the graph with this relative_path of a file/dir. The relative_path is -the relative path from the root path of codebase. The relative_path must include the extension, -like 'foo/bar/baz.java'. - -You can use this tool to check if a file/dir with this relative_path exists or get all -attributes related to the file/dir.""" - - -def find_file_node_with_relative_path( - relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) - RETURN f AS FileNode - ORDER BY f.node_id - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - -############################################################################### -# ASTNode retrieval # -############################################################################### - - class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") basename: str = Field("The basename of file/directory to search under for ASTNodes.") - -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file under -a file/directory with this basename. For reliable results, search for longer, distinct text -sequences rather than short common words or fragments. The contains is same as python's check -`'foo' in text`, ie. it is case sensitive and is looking for exact matches. For best results, -use unique text segments of at least several words. The basename can be either a file (like -'bar.py', 'baz.java') or a directory (like 'src' or 'test').""" - - -def find_ast_node_with_text_in_file_with_basename( - text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.basename = '{basename}' AND a.text CONTAINS '{text}' - RETURN c as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") - -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file under -a file/directory with this relative path. For reliable results, search for longer, distinct text -sequences rather than short common words or fragments. The contains is same as python's check `'foo' in text`, -ie. it is case sensitive and is looking for exact matches. Therefore the search text should -be exact as well. The relative path should be the path from the root of codebase -(like 'src/core/parser.py' or 'test/unit').""" - - -def find_ast_node_with_text_in_file_with_relative_path( - text: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.relative_path = '{relative_path}' AND a.text CONTAINS '{text}' - RETURN c as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): type: str = Field("Search ASTNode with this tree-sitter node type.") basename: str = Field("The basename of file/directory to search under for ASTNodes.") - -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file under -a file/directory with this basename. This tool is useful for searching class/function/method -under a file/directory. The basename can be either a file (like 'bar.py', -'baz.java') or a directory (like 'core' or 'test').""" - - -def find_ast_node_with_type_in_file_with_basename( - type: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.basename = '{basename}' AND a.type = '{type}' - RETURN c as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): type: str = Field("Search ASTNode with this tree-sitter node type.") relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") - -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file under -a file/directory with this relative path. This tool is useful for searching class/function/method -under a file/directory. The relative path should be the path from the root -of codebase (like 'src/core/parser.py' or 'test/unit').""" - - -def find_ast_node_with_type_in_file_with_relative_path( - type: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.relative_path = '{relative_path}' AND a.type = '{type}' - RETURN c as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - -############################################################################### -# TextNode retrieval # -############################################################################### - - class FindTextNodeWithTextInput(BaseModel): text: str = Field("Search TextNode that exactly contains this text.") - -FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text. The contains is -same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. - -You can use this tool to find all text/documentation in codebase that contains this text.""" - - -def find_text_node_with_text( - text: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) - WHERE t.text CONTAINS '{text}' - RETURN f as FileNode, t AS TextNode - ORDER BY t.node_id - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class FindTextNodeWithTextInFileInput(BaseModel): text: str = Field("Search TextNode that exactly contains this text.") basename: str = Field("The basename of FileNode to search TextNode.") - -FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text in a file with this basename. -The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. -The basename must include the extension, like 'bar.py', 'baz.java' or 'foo' -(in this case foo is a directory or a file without extension). - -You can use this tool to find text/documentation in a specific file that contains this text.""" - - -def find_text_node_with_text_in_file( - text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) - WHERE f.basename = '{basename}' AND t.text CONTAINS '{text}' - RETURN f as FileNode, t AS TextNode - ORDER BY t.node_id - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - class GetNextTextNodeWithNodeIdInput(BaseModel): node_id: int = Field("Get the next TextNode of this given node_id.") - -GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION = """\ -Get the next TextNode of this given node_id. - -You can use this tool to read the next section of text that you are interested in.""" - - -def get_next_text_node_with_node_id( - node_id: int, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (a:TextNode {{ node_id: {node_id} }}) -[:NEXT_CHUNK]-> (b:TextNode) - RETURN f as FileNode, b AS TextNode - """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) - - -############################################################################### -# Other # -############################################################################### - - class PreviewFileContentWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to preview.") - -PREVIEW_FILE_CONTENT_WITH_BASENAME_DESCRIPTION = """\ -Preview the content of a file with this basename. The basename must include -the extension, like 'bar.py', 'baz.java' or 'foo' (in this case foo is a -directory or a file without extension). - -You can use this tool to preview the content of a specific file to see what it contains -in the first 1000 lines or the first section. If the file is interesting, use other tools -to look at the file.""" - - -def preview_file_content_with_basename( - basename: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - source_code_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) - WITH f, apoc.text.split(a.text, '\\R') AS lines - RETURN - f AS FileNode, - {{ - text: apoc.text.join(lines[0..1000], '\\n'), - start_line: 1, - end_line: 1000 - }} AS preview - ORDER BY f.node_id - """ - - text_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_TEXT]-> (t:TextNode) - WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) - RETURN f as FileNode, t.text AS preview - ORDER BY f.node_id - """ - - if tree_sitter_parser.supports_file(Path(basename)): - data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, driver) - else: - data = neo4j_util.run_neo4j_query_without_formatting(text_query, driver) - for result in data: - if isinstance(result["preview"], dict): - result["preview"]["text"] = pre_append_line_numbers( - result["preview"]["text"], result["preview"]["start_line"] - ) - result["preview"]["end_line"] = ( - result["preview"]["start_line"] + len(result["preview"]["text"].splitlines()) - 1 - ) - return neo4j_util.format_neo4j_data(data, max_token_per_result), data - - class PreviewFileContentWithRelativePathInput(BaseModel): relative_path: str = Field("The relative path of FileNode to preview.") - -PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Preview the content of a file with this relative path from the root of codebase. -The relative path must include the extension and full path from root, like 'src/core/parser.py', -'test/unit/test_parser.java' or 'docs/README.md'. - -You can use this tool to preview the content of a specific file to see what it contains -in the first 1000 lines or the first section. If the file is interesting, use other tools -to look at the file.""" - - -def preview_file_content_with_relative_path( - relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int -) -> tuple[str, Sequence[Mapping[str, Any]]]: - source_code_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) - WITH f, apoc.text.split(a.text, '\\R') AS lines - RETURN - f as FileNode, - {{ - text: apoc.text.join(lines[0..1000], '\\n'), - start_line: 1, - end_line: 1000 - }} AS preview - ORDER BY f.node_id - """ - - text_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_TEXT]-> (t:TextNode) - WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) - RETURN f as FileNode, t.text AS preview - ORDER BY f.node_id - """ - - if tree_sitter_parser.supports_file(Path(relative_path)): - data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, driver) - else: - data = neo4j_util.run_neo4j_query_without_formatting(text_query, driver) - for result in data: - if isinstance(result["preview"], dict): - result["preview"]["text"] = pre_append_line_numbers( - result["preview"]["text"], result["preview"]["start_line"] - ) - result["preview"]["end_line"] = ( - result["preview"]["start_line"] + len(result["preview"]["text"].splitlines()) - 1 - ) - return neo4j_util.format_neo4j_data(data, max_token_per_result), data - - class ReadCodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to read.") start_line: int = Field("The starting line number, 1-indexed and inclusive.") - end_line: int = Field("The ending line number, 1-indexed and exclusive.") - - -READ_CODE_WITH_BASENAME_DESCRIPTION = """\ -Read a specific section of a source code file's content by specifying its basename and line range. -The basename must include the extension, like 'bar.py' or 'baz.java' - -This tool ONLY works with source code files (not text files or documentation). It is designed -to read large sections of code at once - you should request substantial chunks (hundreds of lines) -rather than making multiple small requests of 10-20 lines each, which would be inefficient. - -Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. - -This tool is useful for examining specific sections of source code files when you know -the exact line range you want to analyze. The function will return an error message if -end_line is less than start_line. -""" - - -def read_code_with_basename( - basename: str, - start_line: int, - end_line: int, - driver: GraphDatabase.driver, - max_token_per_result: int, -) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: - if end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}", None - - source_code_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) - WITH f, apoc.text.split(a.text, '\\R') AS lines - RETURN - f as FileNode, - {{ - text: apoc.text.join(lines[{start_line - 1}..{end_line - 1}], '\\n'), - start_line: {start_line}, - end_line: {end_line} - }} AS SelectedLines - ORDER BY f.node_id - """ - data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, driver) - for result in data: - result["SelectedLines"]["text"] = pre_append_line_numbers( - result["SelectedLines"]["text"], result["SelectedLines"]["start_line"] - ) - return neo4j_util.format_neo4j_data(data, max_token_per_result), data - + end_line: int = Field("The ending line number, 1-indexed and exclusive.") class ReadCodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative path of FileNode to read from root of codebase.") @@ -430,35 +75,354 @@ class ReadCodeWithRelativePathInput(BaseModel): end_line: int = Field("The ending line number, 1-indexed and exclusive.") -READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Read a specific section of a source code file's content by specifying its relative path and line range. -The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or -'test/unit/test_parser.java'. - -This tool ONLY works with source code files (not text files or documentation). It is designed -to read large sections of code at once - you should request substantial chunks (hundreds of lines) -rather than making multiple small requests of 10-20 lines each, which would be inefficient. - -Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. +class GraphTraversalTool: + + # FileNode retrieval tools + find_file_node_with_basename_spec = ToolSpec( + description="""Find all FileNode in the graph with this basename of a file/dir. The basename must + include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a directory or a file without extension). + + You can use this tool to check if a file/dir with this basename exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithBasenameInput + ) + + find_file_node_with_relative_path_spec = ToolSpec( + description="""Search FileNode in the graph with this relative_path of a file/dir. The relative_path is + the relative path from the root path of codebase. The relative_path must include the extension, + like 'foo/bar/baz.java'. + + You can use this tool to check if a file/dir with this relative_path exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithRelativePathInput + ) + + # ASTNode retrieval tools + find_ast_node_with_text_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file under + a file/directory with this basename. For reliable results, search for longer, distinct text + sequences rather than short common words or fragments. The contains is same as python's check + `'foo' in text`, ie. it is case sensitive and is looking for exact matches. For best results, + use unique text segments of at least several words. The basename can be either a file (like + 'bar.py', 'baz.java') or a directory (like 'src' or 'test').""", + input_schema=FindASTNodeWithTextInFileWithBasenameInput + ) + + find_ast_node_with_text_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file under + a file/directory with this relative path. For reliable results, search for longer, distinct text + sequences rather than short common words or fragments. The contains is same as python's check `'foo' in text`, + ie. it is case sensitive and is looking for exact matches. Therefore the search text should + be exact as well. The relative path should be the path from the root of codebase + (like 'src/core/parser.py' or 'test/unit').""", + input_schema=FindASTNodeWithTextInFileWithRelativePathInput + ) + + find_ast_node_with_type_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file under + a file/directory with this basename. This tool is useful for searching class/function/method + under a file/directory. The basename can be either a file (like 'bar.py', + 'baz.java') or a directory (like 'core' or 'test').""", + input_schema=FindASTNodeWithTypeInFileWithBasenameInput + ) + + find_ast_node_with_type_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file under + a file/directory with this relative path. This tool is useful for searching class/function/method + under a file/directory. The relative path should be the path from the root + of codebase (like 'src/core/parser.py' or 'test/unit').""", + input_schema=FindASTNodeWithTypeInFileWithRelativePathInput + ) + + # TextNode retrieval tools + find_text_node_with_text_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text. The contains is + same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + + You can use this tool to find all text/documentation in codebase that contains this text.""", + input_schema=FindTextNodeWithTextInput + ) + + find_text_node_with_text_in_file_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text in a file with this basename. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + The basename must include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a directory or a file without extension). + + You can use this tool to find text/documentation in a specific file that contains this text.""", + input_schema=FindTextNodeWithTextInFileInput + ) + + get_next_text_node_with_node_id_spec = ToolSpec( + description="""Get the next TextNode of this given node_id. + + You can use this tool to read the next section of text that you are interested in.""", + input_schema=GetNextTextNodeWithNodeIdInput + ) + + # Other tools + preview_file_content_with_basename_spec = ToolSpec( + description="""Preview the content of a file with this basename. The basename must include + the extension, like 'bar.py', 'baz.java' or 'foo' (in this case foo is a + directory or a file without extension). + + You can use this tool to preview the content of a specific file to see what it contains + in the first 1000 lines or the first section. If the file is interesting, use other tools + to look at the file.""", + input_schema=PreviewFileContentWithBasenameInput + ) + + preview_file_content_with_relative_path_spec = ToolSpec( + description="""Preview the content of a file with this relative path from the root of codebase. + The relative path must include the extension and full path from root, like 'src/core/parser.py', + 'test/unit/test_parser.java' or 'docs/README.md'. + + You can use this tool to preview the content of a specific file to see what it contains + in the first 1000 lines or the first section. If the file is interesting, use other tools + to look at the file.""", + input_schema=PreviewFileContentWithRelativePathInput + ) + + read_code_with_basename_spec = ToolSpec( + description="""Read a specific section of a source code file's content by specifying its basename and line range. + The basename must include the extension, like 'bar.py' or 'baz.java' + + This tool ONLY works with source code files (not text files or documentation). It is designed + to read large sections of code at once - you should request substantial chunks (hundreds of lines) + rather than making multiple small requests of 10-20 lines each, which would be inefficient. + + Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. + + This tool is useful for examining specific sections of source code files when you know + the exact line range you want to analyze. The function will return an error message if + end_line is less than start_line.""", + input_schema=ReadCodeWithBasenameInput + ) + + read_code_with_relative_path_spec = ToolSpec( + description="""Read a specific section of a source code file's content by specifying its relative path and line range. + The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or + 'test/unit/test_parser.java'. + + This tool ONLY works with source code files (not text files or documentation). It is designed + to read large sections of code at once - you should request substantial chunks (hundreds of lines) + rather than making multiple small requests of 10-20 lines each, which would be inefficient. + + Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. + + This tool is useful for examining specific sections of source code files when you know + the exact line range you want to analyze. The function will return an error message if + end_line is less than start_line.""", + input_schema=ReadCodeWithRelativePathInput + ) + + def __init__(self, driver: GraphDatabase.driver, max_token_per_result: int): + self.driver = driver + self.max_token_per_result = max_token_per_result + + ############################################################################### + # FileNode retrieval # + ############################################################################### + + def find_file_node_with_basename(self, basename: str) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f""" + MATCH (f:FileNode {{ basename: '{basename}' }}) + RETURN f AS FileNode + ORDER BY f.node_id + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) -This tool is useful for examining specific sections of source code files when you know -the exact line range you want to analyze. The function will return an error message if -end_line is less than start_line. -""" + def find_file_node_with_relative_path(self, relative_path: str) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f""" + MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) + RETURN f AS FileNode + ORDER BY f.node_id + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + ############################################################################### + # ASTNode retrieval # + ############################################################################### -def read_code_with_relative_path( - relative_path: str, - start_line: int, - end_line: int, - driver: GraphDatabase.driver, - max_token_per_result: int, -) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: - if end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}", None + def find_ast_node_with_text_in_file_with_basename( + self, text: str, basename: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE f.basename = '{basename}' AND a.text CONTAINS '{text}' + RETURN c as FileNode, a AS ASTNode + ORDER BY SIZE(a.text) + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + def find_ast_node_with_text_in_file_with_relative_path( + self, text: str, relative_path: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE f.relative_path = '{relative_path}' AND a.text CONTAINS '{text}' + RETURN c as FileNode, a AS ASTNode + ORDER BY SIZE(a.text) + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + def find_ast_node_with_type_in_file_with_basename( + self, type: str, basename: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE f.basename = '{basename}' AND a.type = '{type}' + RETURN c as FileNode, a AS ASTNode + ORDER BY SIZE(a.text) + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + def find_ast_node_with_type_in_file_with_relative_path( + self, type: str, relative_path: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE f.relative_path = '{relative_path}' AND a.type = '{type}' + RETURN c as FileNode, a AS ASTNode + ORDER BY SIZE(a.text) + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + ############################################################################### + # TextNode retrieval # + ############################################################################### + + def find_text_node_with_text( + self, text: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE t.text CONTAINS '{text}' + RETURN f as FileNode, t AS TextNode + ORDER BY t.node_id + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + def find_text_node_with_text_in_file( + self, text: str, basename: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE f.basename = '{basename}' AND t.text CONTAINS '{text}' + RETURN f as FileNode, t AS TextNode + ORDER BY t.node_id + LIMIT {MAX_RESULT} + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + def get_next_text_node_with_node_id( + self, node_id: int + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + query = f"""\ + MATCH (f:FileNode) -[:HAS_TEXT]-> (a:TextNode {{ node_id: {node_id} }}) -[:NEXT_CHUNK]-> (b:TextNode) + RETURN f as FileNode, b AS TextNode + """ + return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + + ############################################################################### + # Other # + ############################################################################### + + def preview_file_content_with_basename( + self, basename: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + source_code_query = f"""\ + MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) + WITH f, apoc.text.split(a.text, '\\R') AS lines + RETURN + f AS FileNode, + {{ + text: apoc.text.join(lines[0..1000], '\\n'), + start_line: 1, + end_line: 1000 + }} AS preview + ORDER BY f.node_id + """ - source_code_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) + text_query = f"""\ + MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_TEXT]-> (t:TextNode) + WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) + RETURN f as FileNode, t.text AS preview + ORDER BY f.node_id + """ + + if tree_sitter_parser.supports_file(Path(basename)): + data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) + else: + data = neo4j_util.run_neo4j_query_without_formatting(text_query, self.driver) + for result in data: + if isinstance(result["preview"], dict): + result["preview"]["text"] = pre_append_line_numbers( + result["preview"]["text"], result["preview"]["start_line"] + ) + result["preview"]["end_line"] = ( + result["preview"]["start_line"] + len(result["preview"]["text"].splitlines()) - 1 + ) + return neo4j_util.format_neo4j_data(data, self.max_token_per_result), data + + def preview_file_content_with_relative_path( + self, relative_path: str + ) -> tuple[str, Sequence[Mapping[str, Any]]]: + source_code_query = f"""\ + MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) + WITH f, apoc.text.split(a.text, '\\R') AS lines + RETURN + f as FileNode, + {{ + text: apoc.text.join(lines[0..1000], '\\n'), + start_line: 1, + end_line: 1000 + }} AS preview + ORDER BY f.node_id + """ + + text_query = f"""\ + MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_TEXT]-> (t:TextNode) + WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) + RETURN f as FileNode, t.text AS preview + ORDER BY f.node_id + """ + + if tree_sitter_parser.supports_file(Path(relative_path)): + data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) + else: + data = neo4j_util.run_neo4j_query_without_formatting(text_query, self.driver) + for result in data: + if isinstance(result["preview"], dict): + result["preview"]["text"] = pre_append_line_numbers( + result["preview"]["text"], result["preview"]["start_line"] + ) + result["preview"]["end_line"] = ( + result["preview"]["start_line"] + len(result["preview"]["text"].splitlines()) - 1 + ) + return neo4j_util.format_neo4j_data(data, self.max_token_per_result), data + + def read_code_with_basename( + self, + basename: str, + start_line: int, + end_line: int, + ) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}", None + + source_code_query = f"""\ + MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN f as FileNode, @@ -468,12 +432,41 @@ def read_code_with_relative_path( end_line: {end_line} }} AS SelectedLines ORDER BY f.node_id - """ + """ + data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) + for result in data: + result["SelectedLines"]["text"] = pre_append_line_numbers( + result["SelectedLines"]["text"], result["SelectedLines"]["start_line"] + ) + return neo4j_util.format_neo4j_data(data, self.max_token_per_result), data + + def read_code_with_relative_path( + self, + relative_path: str, + start_line: int, + end_line: int, + ) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}", None + + source_code_query = f"""\ + MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) + WITH f, apoc.text.split(a.text, '\\R') AS lines + RETURN + f as FileNode, + {{ + text: apoc.text.join(lines[{start_line - 1}..{end_line - 1}], '\\n'), + start_line: {start_line}, + end_line: {end_line} + }} AS SelectedLines + ORDER BY f.node_id + """ + + data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) + for result in data: + result["SelectedLines"]["text"] = pre_append_line_numbers( + result["SelectedLines"]["text"], result["SelectedLines"]["start_line"] + ) - data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, driver) - for result in data: - result["SelectedLines"]["text"] = pre_append_line_numbers( - result["SelectedLines"]["text"], result["SelectedLines"]["start_line"] - ) + return neo4j_util.format_neo4j_data(data, self.max_token_per_result), data - return neo4j_util.format_neo4j_data(data, max_token_per_result), data diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 816ab522..829cf9cf 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -4,28 +4,20 @@ from typing import Annotated import json import asyncio +from dataclasses import dataclass from dynaconf.vendor.dotenv import load_dotenv from pydantic import BaseModel, Field, field_validator -from mcp.server import Server -from mcp.shared.exceptions import McpError -from mcp.types import ErrorData -from mcp.server.stdio import stdio_server -from mcp.types import ( - GetPromptResult, - Prompt, - PromptArgument, - PromptMessage, - TextContent, - Tool, - INVALID_PARAMS, - INTERNAL_ERROR, -) from tavily import TavilyClient, InvalidAPIKeyError, UsageLimitExceededError from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger logger = get_logger(__name__) +@dataclass +class ToolSpec: + description: str + input_schema: type + tavily_api_key = settings.get("TAVILY_API_KEY", None) if tavily_api_key is None: @@ -39,17 +31,6 @@ class WebSearchInput(BaseModel): """Base parameters for Tavily search.""" query: Annotated[str, Field(description="Search query")] -WEB_SEARCH_DESCRIPTION = """\ - Searches the web for technical information to aid in bug analysis and resolution. - Use this when you need external context, such as: - 1. Looking up unfamiliar error messages, exceptions, or stack traces. - 2. Finding official documentation or usage examples for a specific library, framework, or API. - 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. - 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). - - Queries should be specific and include relevant keywords like library names, version numbers, and error codes. -""" - def format_results(response: dict) -> str: """Format Tavily search results into a readable string.""" output = [] @@ -84,55 +65,69 @@ def format_results(response: dict) -> str: return "\n".join(output) - -def web_search(query: str, - max_results: int = 5, - include_domains: list[str] = [ - 'stackoverflow.com', - 'github.com', - 'developer.mozilla.org', - 'learn.microsoft.com', - 'docs.python.org', - 'pydantic.dev', - 'pypi.org', - 'readthedocs.org', - ], - exclude_domains: list[str] = None) -> str: - """ - Search the web for technical information to aid in bug analysis and resolution. - """ - if tavily_client is None: - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message="Tavily API key is not set" - )) - try: - response = tavily_client.search( - query=query, - max_results=max_results, - search_depth="advanced", - include_answer=True, - include_domains=include_domains or [], # Convert None to empty list - exclude_domains=exclude_domains or [], # Convert None to empty list - ) - return format_results(response) - except InvalidAPIKeyError: - raise McpError(ErrorData( - code=INVALID_PARAMS, - message="Invalid Tavily API key" - )) - except UsageLimitExceededError: - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message="Usage limit exceeded" - )) - except Exception as e: - raise McpError(ErrorData( - code=INTERNAL_ERROR, - message=f"An error occurred: {str(e)}" - )) - - +class WebSearchTool: + """Tool class for web search functionality.""" + + web_search_spec = ToolSpec( + description="""\ + Searches the web for technical information to aid in bug analysis and resolution. + Use this when you need external context, such as: + 1. Looking up unfamiliar error messages, exceptions, or stack traces. + 2. Finding official documentation or usage examples for a specific library, framework, or API. + 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. + 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). + + Queries should be specific and include relevant keywords like library names, version numbers, and error codes. + """, + input_schema=WebSearchInput + ) + + def __init__(self): + """Initialize the web search tool.""" + self.tavily_client = tavily_client + + def web_search(self, query: str, max_results: int = 5, + include_domains: list[str] = [ + 'stackoverflow.com', + 'github.com', + 'developer.mozilla.org', + 'learn.microsoft.com', + 'docs.python.org', + 'pydantic.dev', + 'pypi.org', + 'readthedocs.org', + ], + exclude_domains: list[str] = None) -> str: + """Search the web for technical information to aid in bug analysis and resolution. + + Args: + query: Search query string. + max_results: Maximum number of results to return (default: 5). + include_domains: List of domains to include in search. + exclude_domains: List of domains to exclude from search. + + Returns: + Formatted search results as a string. + """ + + if tavily_client is None: + raise RuntimeError("Tavily API key is not set") + try: + response = tavily_client.search( + query=query, + max_results=max_results, + search_depth="advanced", + include_answer=True, + include_domains=include_domains or [], # Convert None to empty list + exclude_domains=exclude_domains or [], # Convert None to empty list + ) + return format_results(response) + except InvalidAPIKeyError: + raise ValueError("Invalid Tavily API key") + except UsageLimitExceededError: + raise RuntimeError("Usage limit exceeded") + except Exception as e: + raise RuntimeError(f"An error occurred: {str(e)}") if __name__ == "__main__": @@ -144,4 +139,4 @@ def web_search(query: str, else: tavily_client = TavilyClient(api_key=tavily_api_key) - print(web_search("What is the capital of France?")) \ No newline at end of file + print(web_search("What is the capital of France?")) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8fb5e4c1..d34cab63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "sqlmodel==0.0.24", "psycopg2-binary", "mcp>=1.4.1", - "tavily-python>=0.5.1" + "tavily-python>=0.5.1", + "langchain-mcp-adapters>=0.1.9" ] requires-python = ">= 3.11" diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py new file mode 100644 index 00000000..968dfe34 --- /dev/null +++ b/tests/tools/test_mcp_client.py @@ -0,0 +1,83 @@ +import asyncio +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.graph import StateGraph, MessagesState, START +from langgraph.prebuilt import ToolNode, tools_condition +from langchain_core.messages import AIMessage, ToolMessage + +# 使用项目中的自定义模拟模型,支持工具调用 +import sys +sys.path.append("/root/lix/Prometheus/") +from tests.test_utils.util import FakeListChatWithToolsModel + +async def main(): + # 可以动态设置多个配置参数 + config = { + "driver": "111111111111111111111", + "timeout": 60, + "max_retries": 5 + } + # 将配置转换为命令行参数 + args = ["/root/lix/Prometheus/tests/tools/test_mcp_tools.py"] + for key, value in config.items(): + args.extend([f"--{key}", str(value)]) + + client = MultiServerMCPClient( + { + "weather": { + "command": "python", + "args": args, + "transport": "stdio", + } + } + ) + + # 异步获取工具 + tools = await client.get_tools() + print(f"获取到的工具: {[tool.name for tool in tools]}") + + # 使用支持工具的模拟模型 + model = FakeListChatWithToolsModel(responses=["I need to check the weather for NYC"]) + + # 创建工具节点 + tool_node = ToolNode(tools) + + def call_model(state: MessagesState): + messages = state["messages"] + + # 检查是否已经有工具消息,如果有就结束 + if any(isinstance(msg, ToolMessage) for msg in messages): + return {"messages": [AIMessage(content="Weather check completed!")]} + + # 第一次调用时创建工具调用响应 + response = AIMessage( + content="Let me check the weather for you", + tool_calls=[{ + "name": "get_weather", + "args": {"location": "nyc"}, + "id": "call_1" + }] + ) + return {"messages": [response]} + + # 构建图 + builder = StateGraph(MessagesState) + builder.add_node("call_model", call_model) + builder.add_node("tools", tool_node) + builder.add_edge(START, "call_model") + builder.add_conditional_edges( + "call_model", + tools_condition, + ) + builder.add_edge("tools", "call_model") + + graph = builder.compile() + + # 执行测试 + weather_response = await graph.ainvoke({"messages": "what is the weather in nyc?"}) + print("Response:", weather_response) + + return weather_response + +# 运行异步主函数 +if __name__ == "__main__": + result = asyncio.run(main()) \ No newline at end of file diff --git a/tests/tools/test_mcp_client_config.py b/tests/tools/test_mcp_client_config.py new file mode 100644 index 00000000..43082af6 --- /dev/null +++ b/tests/tools/test_mcp_client_config.py @@ -0,0 +1,97 @@ +import asyncio +import json +import tempfile +import os +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.graph import StateGraph, MessagesState, START +from langgraph.prebuilt import ToolNode, tools_condition +from langchain_core.messages import AIMessage, ToolMessage + +# 使用项目中的自定义模拟模型,支持工具调用 +import sys +sys.path.append("/root/lix/Prometheus/") +from tests.test_utils.util import FakeListChatWithToolsModel + +async def main(): + # 可以动态设置多个配置参数 + config = { + "driver": "neo4j://enterprise-cluster:7687", + "timeout": 120, + "max_retries": 10 + } + + # 创建临时配置文件 + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config, f, indent=2) + config_file_path = f.name + + try: + client = MultiServerMCPClient( + { + "weather": { + "command": "python", + "args": ["/root/lix/Prometheus/tests/tools/config_based_mcp_tools.py"], + "transport": "stdio", + "env": { + "MCP_WEATHER_CONFIG": config_file_path # 通过环境变量传递配置文件路径 + } + } + } + ) + + # 异步获取工具 + tools = await client.get_tools() + print(f"获取到的工具: {[tool.name for tool in tools]}") + + # 使用支持工具的模拟模型 + model = FakeListChatWithToolsModel(responses=["I need to check the weather for NYC"]) + + # 创建工具节点 + tool_node = ToolNode(tools) + + def call_model(state: MessagesState): + messages = state["messages"] + + # 检查是否已经有工具消息,如果有就结束 + if any(isinstance(msg, ToolMessage) for msg in messages): + return {"messages": [AIMessage(content="Weather check completed!")]} + + # 第一次调用时创建工具调用响应 + response = AIMessage( + content="Let me check the weather for you", + tool_calls=[{ + "name": "get_weather", + "args": {"location": "nyc"}, + "id": "call_1" + }] + ) + return {"messages": [response]} + + # 构建图 + builder = StateGraph(MessagesState) + builder.add_node("call_model", call_model) + builder.add_node("tools", tool_node) + builder.add_edge(START, "call_model") + builder.add_conditional_edges( + "call_model", + tools_condition, + ) + builder.add_edge("tools", "call_model") + + graph = builder.compile() + + # 执行测试 + weather_response = await graph.ainvoke({"messages": "what is the weather in nyc?"}) + print("Response:", weather_response) + + return weather_response + + finally: + # 清理临时配置文件 + if os.path.exists(config_file_path): + os.unlink(config_file_path) + print(f"🗑️ 清理临时配置文件: {config_file_path}") + +# 运行异步主函数 +if __name__ == "__main__": + result = asyncio.run(main()) diff --git a/tests/tools/test_mcp_server.py b/tests/tools/test_mcp_server.py new file mode 100644 index 00000000..d6deb0b7 --- /dev/null +++ b/tests/tools/test_mcp_server.py @@ -0,0 +1,153 @@ +from mcp.server.fastmcp import FastMCP +import os +import sys +import json +import asyncio +from pathlib import Path +from typing import Dict, List, Any, Optional, Set + +import yaml +from langchain_mcp_adapters.client import MultiServerMCPClient + + +# ========================================== +# MCP Server (existing behavior preserved) +# ========================================== +# Create unified MCP server instance +mcp = FastMCP("PrometheusTools") + + +# ========================================== +# Dynamic MCP client based on node_tools.yml +# ========================================== +_NODE_TOOLS_CACHE: Optional[Dict[str, List[str]]] = None +_CLIENT_CACHE: Optional[MultiServerMCPClient] = None + + +def _load_node_tools_map(config_path: Optional[Path] = None) -> Dict[str, List[str]]: + """Load node->tools mapping from prometheus/graph_config/node_tools.yml. + + Returns a dict: { node_name: [tool_name, ...] } + """ + global _NODE_TOOLS_CACHE + if _NODE_TOOLS_CACHE is not None: + return _NODE_TOOLS_CACHE + + if config_path is None: + # mcp_server.py is in prometheus/tools/, go up one to prometheus/ + project_root = Path(__file__).resolve().parents[1] + config_path = project_root / "graph_config" / "node_tools.yml" + + with config_path.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + + nodes = data.get("nodes", []) + node_to_tools: Dict[str, List[str]] = {} + for item in nodes: + name = item.get("name") + tools = item.get("tools", []) or [] + if name is None: + continue + if isinstance(tools, list): + node_to_tools[name] = tools + else: + # In case of malformed YAML (non-list), coerce to list + node_to_tools[name] = [tools] + + _NODE_TOOLS_CACHE = node_to_tools + return node_to_tools + + +def _load_server_configs() -> Dict[str, Dict[str, Any]]: + """Load MCP server configurations from environment. + + Expected env var PROMETHEUS_MCP_SERVERS as JSON, e.g.: + { + "math": { + "command": "python", + "args": ["/abs/path/to/examples/math_server.py"], + "transport": "stdio" + }, + "weather": { + "url": "http://localhost:8000/mcp/", + "transport": "streamable_http" + } + } + + If not provided, default to spawning this file as a stdio MCP server under id "PrometheusTools". + """ + raw = os.getenv("PROMETHEUS_MCP_SERVERS") + if raw: + try: + cfg = json.loads(raw) + if isinstance(cfg, dict): + return cfg + except json.JSONDecodeError: + pass + + # Fallback: local stdio server using this file + this_file = Path(__file__).resolve() + return { + "PrometheusTools": { + "command": "python", + "args": [str(this_file)], + "transport": "stdio", + } + } + + +def _build_client(server_configs: Optional[Dict[str, Dict[str, Any]]] = None) -> MultiServerMCPClient: + global _CLIENT_CACHE + if _CLIENT_CACHE is not None: + return _CLIENT_CACHE + + if server_configs is None: + server_configs = _load_server_configs() + + client = MultiServerMCPClient(server_configs) + _CLIENT_CACHE = client + return client + + +async def get_all_tools() -> List[Any]: + """Fetch all tools from all configured MCP servers.""" + client = _build_client() + tools = await client.get_tools() + return tools + + +def get_required_tool_names_for_node(node_name: str) -> List[str]: + mapping = _load_node_tools_map() + return mapping.get(node_name, []) + + +async def get_tools_for_node(node_name: str) -> List[Any]: + """Return the list of MCP tools required by the given node name. + + This will connect to all configured MCP servers, fetch their tools, and filter + by the names listed for the node in node_tools.yml. + """ + required: Set[str] = set(get_required_tool_names_for_node(node_name)) + if not required: + return [] + all_tools = await get_all_tools() + return [t for t in all_tools if getattr(t, "name", None) in required] + + +def build_default_node_tool_client() -> MultiServerMCPClient: + """Expose a builder for external callers if needed.""" + return _build_client() + + +if __name__ == "__main__": + # 确保在启动前注册所有工具 + sys.path.append("~/lix/Prometheus") + import prometheus.tools # noqa: F401 + mcp.run(transport="stdio") + + async def main(): + tools = await get_all_tools() + for t in sorted(tools, key=lambda x: getattr(x, "name", "")): + print(getattr(t, "name", str(t))) + + asyncio.run(main()) diff --git a/tests/tools/test_mcp_tools.py b/tests/tools/test_mcp_tools.py new file mode 100644 index 00000000..9a9e9bf6 --- /dev/null +++ b/tests/tools/test_mcp_tools.py @@ -0,0 +1,64 @@ +# weather_server.py +import sys +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("Weather") + +class WeatherTools: + driver: str + timeout: int + max_retries: int + + @classmethod + def configure(cls, **kwargs): + """动态配置工具参数""" + for key, value in kwargs.items(): + if hasattr(cls, key): + setattr(cls, key, value) + print(f"设置 {key} = {value}") + else: + print(f"警告: 未知参数 {key} = {value}") + + @staticmethod + @mcp.tool() + async def get_weather(location: str) -> str: + """Get weather for a location.""" + return f"[Driver={WeatherTools.driver}, Timeout={WeatherTools.timeout}s] It's always sunny in {location}" + + @staticmethod + @mcp.tool() + async def get_temperature(location: str) -> str: + """Get temperature for a location.""" + return f"[Driver={WeatherTools.driver}, Retries={WeatherTools.max_retries}] Temperature in {location} is 25°C" + +def parse_args(args): + """解析命令行参数为 kwargs""" + kwargs = {} + i = 1 + while i < len(args): + if args[i].startswith("--"): + key = args[i][2:] # 移除 "--" 前缀 + if i + 1 < len(args) and not args[i + 1].startswith("--"): + value = args[i + 1] + # 尝试转换数据类型 + if value.isdigit(): + value = int(value) + elif value.lower() in ['true', 'false']: + value = value.lower() == 'true' + kwargs[key] = value + i += 2 + else: + kwargs[key] = True # 布尔标志 + i += 1 + else: + i += 1 + return kwargs + +if __name__ == "__main__": + # 动态解析命令行参数 + # 支持格式:python test_mcp_tools.py --driver neo4j://server --timeout 60 --max_retries 5 + config = parse_args(sys.argv) + if config: + WeatherTools.configure(**config) + # 启动 MCP + mcp.run(transport="stdio") \ No newline at end of file From 37c2d0507049d6f22802b9a0f0786c5a6f4f6d8e Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 16 Aug 2025 18:00:56 +0800 Subject: [PATCH 03/30] fix --- prometheus/app/services/service_coordinator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prometheus/app/services/service_coordinator.py b/prometheus/app/services/service_coordinator.py index 00904e3d..e88da969 100644 --- a/prometheus/app/services/service_coordinator.py +++ b/prometheus/app/services/service_coordinator.py @@ -17,7 +17,7 @@ from prometheus.app.services.neo4j_service import Neo4jService from prometheus.app.services.repository_service import RepositoryService from prometheus.lang_graph.graphs.issue_state import IssueType -from prometheus.utils.logger_manager import get_logger, create_timestamped_file_handler, logger_manager +from prometheus.utils.logger_manager import get_logger, create_timestamped_file_handler class ServiceCoordinator: From 2ad9e19e941bc976ecf645346c43846b4871de7d Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 16 Aug 2025 18:07:26 +0800 Subject: [PATCH 04/30] remove broken mcp --- prometheus/tools/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/prometheus/tools/__init__.py b/prometheus/tools/__init__.py index 15e3cbf6..e69de29b 100644 --- a/prometheus/tools/__init__.py +++ b/prometheus/tools/__init__.py @@ -1,5 +0,0 @@ - -# Ensure MCP tools are registered when this package is imported by the MCP server -# Importing the module executes the @mcp.tool decorators -from prometheus.mcp_tools import web_search as _mcp_web_search # noqa: F401 - From 76917b69c63dd4f2f859f49761edf06326727b84 Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 16 Aug 2025 21:38:06 +0800 Subject: [PATCH 05/30] fix: tool node --- .../subgraphs/issue_not_verified_bug_subgraph.py | 16 +++++++++++++++- .../subgraphs/issue_verified_bug_subgraph.py | 16 +++++++++++++++- prometheus/tools/__init__.py | 1 + 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index f7c4417a..4f5f4c9f 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -43,6 +43,11 @@ def __init__( issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + issue_bug_analyzer_tools = ToolNode( + tools=issue_bug_analyzer_node.tools, + name="issue_bug_analyzer_tools", + messages_key="issue_bug_analyzer_messages", + ) edit_message_node = EditMessageNode() edit_node = EditNode(advanced_model, kg) @@ -66,6 +71,7 @@ def __init__( workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + workflow.add_node("issue_bug_analyzer_tools", issue_bug_analyzer_tools) workflow.add_node("edit_message_node", edit_message_node) workflow.add_node("edit_node", edit_node) @@ -84,7 +90,15 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + # Conditionally invoke tools or continue to edit message + workflow.add_conditional_edges( + "issue_bug_analyzer_node", + functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), + {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, + ) + + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") workflow.add_conditional_edges( diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index cbcc4cce..fd115676 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -83,6 +83,11 @@ def __init__( # Phase 2: Analyze the bug and generate hypotheses issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + issue_bug_analyzer_tools = ToolNode( + tools=issue_bug_analyzer_node.tools, + name="issue_bug_analyzer_tools", + messages_key="issue_bug_analyzer_messages", + ) # Phase 3: Generate code edits and optionally apply toolchains edit_message_node = EditMessageNode() @@ -122,6 +127,7 @@ def __init__( workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + workflow.add_node("issue_bug_analyzer_tools", issue_bug_analyzer_tools) workflow.add_node("edit_message_node", edit_message_node) workflow.add_node("edit_node", edit_node) @@ -138,7 +144,15 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + # Conditionally invoke tools or continue to edit message + workflow.add_conditional_edges( + "issue_bug_analyzer_node", + functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), + {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, + ) + + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") # Conditionally invoke tools or continue to diffing diff --git a/prometheus/tools/__init__.py b/prometheus/tools/__init__.py index e69de29b..8b137891 100644 --- a/prometheus/tools/__init__.py +++ b/prometheus/tools/__init__.py @@ -0,0 +1 @@ + From ca58c286bcf0d7992217c35b491d4dfdcd932dd1 Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 16 Aug 2025 23:03:19 +0800 Subject: [PATCH 06/30] add log for web_search tool --- prometheus/tools/web_search.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 829cf9cf..84bd317a 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -121,7 +121,9 @@ def web_search(self, query: str, max_results: int = 5, include_domains=include_domains or [], # Convert None to empty list exclude_domains=exclude_domains or [], # Convert None to empty list ) - return format_results(response) + format_response = format_results(response) + self._logger.info(f"web_search format_response: {format_response}") + return format_response except InvalidAPIKeyError: raise ValueError("Invalid Tavily API key") except UsageLimitExceededError: From 369436a0fcef790c1e710c62b688706ed4a014e3 Mon Sep 17 00:00:00 2001 From: cocoli Date: Tue, 26 Aug 2025 23:31:00 +0800 Subject: [PATCH 07/30] add web_search mcp --- .../nodes/issue_bug_analyzer_node.py | 11 +- prometheus/tools/web_search.py | 13 + tests/tools/test_mcp_client.py | 253 ++++++++++++++---- tests/tools/test_mcp_web_search.py | 208 ++++++++++++++ 4 files changed, 434 insertions(+), 51 deletions(-) create mode 100644 tests/tools/test_mcp_web_search.py diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index d91a11c3..a8312b52 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,10 +1,11 @@ from typing import Dict +import asyncio from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage from langchain.tools import StructuredTool import functools -from prometheus.tools.web_search import WebSearchTool +from prometheus.tools.web_search import WebSearchTool, mcp_web_search from prometheus.utils.logger_manager import get_logger @@ -82,6 +83,13 @@ class IssueBugAnalyzerNode: Tools available: - web_search: Searches the web for technical information to aid in bug analysis and resolution. +When using the web_search tool, ALWAYS include these parameters: + - exclude_domains: ["*swe-bench*"] + - include_domains: ['stackoverflow.com', 'github.com', 'developer.mozilla.org', 'learn.microsoft.com', 'fastapi.tiangolo.com' + 'docs.python.org', 'pydantic.dev', 'pypi.org', 'readthedocs.org', 'docs.djangoproject.com','flask.palletsprojects.com'] + - search_depth: "advanced" + + Make sure to explicitly pass these parameters in your tool call. Important: - Do NOT provide actual code snippets or diffs @@ -98,6 +106,7 @@ def __init__(self, model: BaseChatModel): self.web_search_tool = WebSearchTool() self.model = model self.system_prompt = SystemMessage(self.SYS_PROMPT) + # self.tools = asyncio.run(mcp_web_search()) # mcp mode self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self._logger = get_logger(__name__) diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 84bd317a..96e0fb63 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -10,6 +10,7 @@ from tavily import TavilyClient, InvalidAPIKeyError, UsageLimitExceededError from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger +from langchain_mcp_adapters.client import MultiServerMCPClient logger = get_logger(__name__) @@ -131,6 +132,18 @@ def web_search(self, query: str, max_results: int = 5, except Exception as e: raise RuntimeError(f"An error occurred: {str(e)}") +async def mcp_web_search(): + client = MultiServerMCPClient( + { + "tavily_web_search": { + "transport": "streamable_http", + "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", + } + } + ) + # 异步获取工具 + tools = await client.get_tools() + return tools if __name__ == "__main__": load_dotenv() diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py index 968dfe34..0e8d10c4 100644 --- a/tests/tools/test_mcp_client.py +++ b/tests/tools/test_mcp_client.py @@ -3,66 +3,198 @@ from langgraph.graph import StateGraph, MessagesState, START from langgraph.prebuilt import ToolNode, tools_condition from langchain_core.messages import AIMessage, ToolMessage +from prometheus.app.services.llm_service import LLMService, get_model +from langchain.tools import StructuredTool +import functools +# 使用真实模型进行工具调用 +from prometheus.configuration.config import settings +import json +import re -# 使用项目中的自定义模拟模型,支持工具调用 -import sys -sys.path.append("/root/lix/Prometheus/") -from tests.test_utils.util import FakeListChatWithToolsModel +import asyncio +import inspect +import json +from copy import copy +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + get_type_hints, +) -async def main(): - # 可以动态设置多个配置参数 - config = { - "driver": "111111111111111111111", - "timeout": 60, - "max_retries": 5 +from langchain_core.messages import ( + AIMessage, + AnyMessage, + ToolCall, + ToolMessage, +) +from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.config import ( + get_config_list, + get_executor_for_config, +) +from langchain_core.runnables.utils import Input +from langchain_core.tools import BaseTool, InjectedToolArg +from langchain_core.tools import tool as create_tool +from langchain_core.tools.base import get_all_basemodel_annotations +from typing_extensions import Annotated, get_args, get_origin + +from langgraph.errors import GraphInterrupt +from langgraph.store.base import BaseStore +from langgraph.utils.runnable import RunnableCallable +from langgraph.prebuilt.tool_node import msg_content_output, _infer_handled_types, _handle_tool_error + + + # 创建自定义 ToolNode +preset_params = { + "tavily-search": { + "include_domains": ["pypi.org", "docs.python.org"], + "exclude_domains": ["stackoverflow.com", "*huggingface*", "discourse.slicer.org","ask.csdn.net", + "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"], } - # 将配置转换为命令行参数 - args = ["/root/lix/Prometheus/tests/tools/test_mcp_tools.py"] - for key, value in config.items(): - args.extend([f"--{key}", str(value)]) +} + +class CustomToolNode(ToolNode): + """自定义 ToolNode,支持为特定工具添加预设参数""" - client = MultiServerMCPClient( - { - "weather": { - "command": "python", - "args": args, - "transport": "stdio", - } - } - ) + def __init__(self, tools, preset_params=None, **kwargs): + super().__init__(tools, **kwargs) + self.preset_params = preset_params or {} - # 异步获取工具 - tools = await client.get_tools() - print(f"获取到的工具: {[tool.name for tool in tools]}") + async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + try: + # 构建基础输入 + input = {**call, **{"type": "tool_call"}} + + # 如果这个工具有预设参数,则添加到输入中 + if call["name"] in self.preset_params: + preset_for_tool = self.preset_params[call["name"]] + # 预设参数优先级较低,不会覆盖用户传递的参数 + merged_args = {**preset_for_tool, **call.get("args", {})} + input["args"] = merged_args + print(f"🔧 为工具 {call['name']} 添加预设参数: {preset_for_tool}") + print(f"🔧 最终参数: {merged_args}") + + tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( + input, config + ) + tool_message.content = cast( + Union[str, list], msg_content_output(tool_message.content) + ) + return tool_message + except GraphInterrupt as e: + raise e + except Exception as e: + # 使用父类的错误处理逻辑 + if isinstance(self.handle_tool_errors, tuple): + handled_types: tuple = self.handle_tool_errors + elif callable(self.handle_tool_errors): + handled_types = _infer_handled_types(self.handle_tool_errors) + else: + handled_types = (Exception,) + + if not self.handle_tool_errors or not isinstance(e, handled_types): + raise e + else: + content = _handle_tool_error(e, flag=self.handle_tool_errors) + return ToolMessage( + content=content, name=call["name"], tool_call_id=call["id"], status="error" + ) + + + +async def main(): + # 获取 Tavily API key + tavily_api_key = settings.get("TAVILY_API_KEY", None) + if tavily_api_key is None: + print("错误: 未设置 TAVILY_API_KEY") + return - # 使用支持工具的模拟模型 - model = FakeListChatWithToolsModel(responses=["I need to check the weather for NYC"]) + model = get_model("gpt-4o-mini", + openai_format_api_key=settings.get("OPENAI_FORMAT_API_KEY", None), + openai_format_base_url=settings.get("OPENAI_FORMAT_BASE_URL", None), + anthropic_api_key=None, + gemini_api_key=None, + temperature=0.0, + max_output_tokens=15000, + ) + - # 创建工具节点 - tool_node = ToolNode(tools) - def call_model(state: MessagesState): + async def init_tool(): + # 使用 HTTP 传输直接连接到 Tavily MCP 服务器 + client = MultiServerMCPClient( + { + "tavily_web_search": { + "transport": "streamable_http", + "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", + } + } + ) + + + # 异步获取工具 + tools = await client.get_tools() + print(f"获取到的工具: {[tool.name for tool in tools]}") + # for tool in tools: + # print(f"\n工具名称: {tool.name}") + # if hasattr(tool, 'args_schema') and tool.args_schema: + # properties = tool.args_schema.get('properties', {}) + # # 简单的正则匹配设置默认值 + # for param_name in properties.keys(): + # param_lower = param_name.lower() + # if re.search(r'include.*domain', param_lower): + # properties[param_name]['default'] = ["pypi.org", "docs.python.org"] + # print(f" ✅ 设置 {param_name} 默认值: include domains") + # elif re.search(r'exclude.*domain', param_lower): + # properties[param_name]['default'] = ["stackoverflow.com", "*huggingface", "discourse.slicer.org","ask.csdn.net", + # "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"] + # print(f" ✅ 设置 {param_name} 默认值: exclude domains") + + # elif re.search(r'search.*depth', param_lower): + # properties[param_name]['default'] = "advanced" + # print(f" ✅ 设置 {param_name} 默认值: advanced") + return tools + + + tools = await init_tool() + + async def call_model(state: MessagesState): messages = state["messages"] + print(f"\n=== call_model 被调用 ===") + print(f"输入消息数量: {len(messages)}") + + print(f"可用工具: {[tool.name for tool in tools]}") + + # 使用真实模型调用,绑定预设参数的工具 + model_with_tools = model.bind_tools(tools) + print("开始调用模型...") + + response = await model_with_tools.ainvoke(messages) + print(f"模型响应类型: {type(response)}") - # 检查是否已经有工具消息,如果有就结束 - if any(isinstance(msg, ToolMessage) for msg in messages): - return {"messages": [AIMessage(content="Weather check completed!")]} - # 第一次调用时创建工具调用响应 - response = AIMessage( - content="Let me check the weather for you", - tool_calls=[{ - "name": "get_weather", - "args": {"location": "nyc"}, - "id": "call_1" - }] - ) return {"messages": [response]} - # 构建图 + # 创建工具节点 builder = StateGraph(MessagesState) builder.add_node("call_model", call_model) - builder.add_node("tools", tool_node) + # builder.add_node("tools", CustomToolNode(tools, preset_params=preset_params)) + builder.add_node("tools", ToolNode(tools)) + + # 构建图 builder.add_edge(START, "call_model") builder.add_conditional_edges( "call_model", @@ -72,12 +204,33 @@ def call_model(state: MessagesState): graph = builder.compile() - # 执行测试 - weather_response = await graph.ainvoke({"messages": "what is the weather in nyc?"}) - print("Response:", weather_response) + # 执行测试 - 演示如何传递 include_domains 等参数 + # 注意:参数会在工具调用时由 LLM 自动传递,这里展示一个需要特定域名搜索的查询 + test_query = """ + ERROR: Could not find a version that satisfies the requirement opencv (from versions: none) + ERROR: No matching distribution found for opencv + 报错 + """ + + system_prompt = """\ + You are a web search assistant. When using the tavily_search tool, ALWAYS include these parameters: + - exclude_domains: ["stackoverflow.com", "*huggingface*", "discourse.slicer.org","ask.csdn.net", "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"] + - include_domains: ['pypi.org', 'docs.python.org'] + - search_depth: "advanced" + + Make sure to explicitly pass these parameters in your tool call. + """ + # system_prompt = """\ + # You are a web search assistant. help the human to find the answer to the question. + # """ + + response = await graph.ainvoke({"messages": system_prompt + "\n" + test_query}) + # print("Response:", response) - return weather_response + return response + # 运行异步主函数 if __name__ == "__main__": - result = asyncio.run(main()) \ No newline at end of file + result = asyncio.run(main()) + print(result['messages'][-1].content) \ No newline at end of file diff --git a/tests/tools/test_mcp_web_search.py b/tests/tools/test_mcp_web_search.py new file mode 100644 index 00000000..7ff1de5b --- /dev/null +++ b/tests/tools/test_mcp_web_search.py @@ -0,0 +1,208 @@ +import os +import aiohttp +import asyncio +from pathlib import Path +from typing import Annotated, Optional +import json +from dataclasses import dataclass +from dynaconf.vendor.dotenv import load_dotenv +from pydantic import BaseModel, Field, field_validator +from prometheus.configuration.config import settings +from mcp.server.fastmcp import FastMCP +from prometheus.utils.logger_manager import get_logger +from prometheus.app.services.llm_service import LLMService, get_model + +logger = get_logger(__name__) +@dataclass +class MCPToolSpec: + description: str + input_schema: type + +model = get_model("gpt-4o-mini", + openai_format_api_key=settings.get("OPENAI_API_KEY", None), + openai_format_base_url=settings.get("OPENAI_BASE_URL", None), + anthropic_api_key=None, + gemini_api_key=None, + temperature=0.0, + max_output_tokens=15000, +) + + +# Initialize MCP server +mcp = FastMCP("WebSearchTool") + +# Get Tavily API key +tavily_api_key = settings.get("TAVILY_API_KEY", None) +if tavily_api_key is None: + logger.warning("Tavily API key is not set") + +# MCP server URL +TAVILY_SERVER_URL = f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}" + + +class WebSearchInput(BaseModel): + """Input parameters for web search.""" + query: Annotated[str, Field(description="Search query string")] + max_results: Annotated[int, Field(description="Maximum number of results", default=5)] + include_domains: Annotated[Optional[list[str]], Field(description="List of domains to include", default=None)] + exclude_domains: Annotated[Optional[list[str]], Field(description="List of domains to exclude", default=None)] + + +def format_results(response: dict) -> str: + """Format Tavily search results into a readable string.""" + output = [] + + # Add domain filter information if present + if response.get("included_domains") or response.get("excluded_domains"): + filters = [] + if response.get("included_domains"): + filters.append(f"Including domains: {', '.join(response['included_domains'])}") + if response.get("excluded_domains"): + filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") + output.append("Search Filters:") + output.extend(filters) + output.append("") # Empty line for separation + + # Add answer if present + if response.get("answer"): + output.append(f"Answer: {response['answer']}") + output.append("\nSources:") + # Add immediate source references for the answer + for result in response.get("results", []): + output.append(f"- {result.get('title', 'No title')}: {result.get('url', 'No URL')}") + output.append("") # Empty line for separation + + # Add detailed results + output.append("Detailed Results:") + for result in response.get("results", []): + output.append(f"\nTitle: {result.get('title', 'No title')}") + output.append(f"URL: {result.get('url', 'No URL')}") + output.append(f"Content: {result.get('content', 'No content')}") + if result.get("published_date"): + output.append(f"Published: {result['published_date']}") + + return "\n".join(output) + + +class MCPWebSearchTool: + """Web search tool class.""" + + web_search_spec = MCPToolSpec( + description="""\ + Searches the web for technical information to aid in bug analysis and resolution. + Use this when you need external context, such as: + 1. Looking up unfamiliar error messages, exceptions, or stack traces. + 2. Finding official documentation or usage examples for a specific library, framework, or API. + 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. + 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). + + Queries should be specific and include relevant keywords like library names, version numbers, and error codes. + """, + input_schema=WebSearchInput + ) + + +@mcp.tool() +async def web_search( + query: str, + max_results: int = 5, + include_domains: Optional[list[str]] = None, + exclude_domains: Optional[list[str]] = None, +) -> str: + """\ + Searches the web for technical information to aid in bug analysis and resolution. + Use this when you need external context, such as: + 1. Looking up unfamiliar error messages, exceptions, or stack traces. + 2. Finding official documentation or usage examples for a specific library, framework, or API. + 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. + 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). + + Queries should be specific and include relevant keywords like library names, version numbers, and error codes. + + + Args: + query: Search query string + max_results: Maximum number of results (default: 5) + include_domains: List of domains to include (default: technical documentation sites) + exclude_domains: List of domains to exclude + """ + + # Check if API key is available + if tavily_api_key is None: + return "Error: Tavily API key is not set" + + # Default technical search domains + if include_domains is None: + include_domains = [ + 'stackoverflow.com', + 'github.com', + 'developer.mozilla.org', + 'learn.microsoft.com', + 'docs.python.org', + 'pydantic.dev', + 'pypi.org', + 'readthedocs.org', + 'docs.djangoproject.com', + 'flask.palletsprojects.com', + 'fastapi.tiangolo.com' + ] + + # Build request payload + payload = { + "query": query, + "max_results": max_results, + "include_domains": include_domains or [], + "exclude_domains": exclude_domains or [], + } + + try: + logger.info(f"Executing web search, query: {query}") + + # Use aiohttp to send HTTP request to MCP server + async with aiohttp.ClientSession() as session: + async with session.post(TAVILY_SERVER_URL, json=payload) as resp: + if resp.status != 200: + error_msg = f"HTTP error {resp.status}: {await resp.text()}" + logger.error(error_msg) + return error_msg + + data = await resp.json() + + # Format response + formatted_response = format_results(data) + logger.info(f"Web search completed, returned {len(data.get('results', []))} results") + + return formatted_response + + except aiohttp.ClientError as e: + error_msg = f"Network request error: {str(e)}" + logger.error(error_msg) + return error_msg + except json.JSONDecodeError as e: + error_msg = f"JSON parsing error: {str(e)}" + logger.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error occurred during search: {str(e)}" + logger.error(error_msg) + return error_msg + + +def run_mcp_server(): + """Run MCP server.""" + logger.info("Starting MCP Web search server...") + mcp.run() + + +if __name__ == "__main__": + # # Load environment variables + # load_dotenv() + + # # Get Tavily API key + # tavily_api_key = settings.get("TAVILY_API_KEY", None) + # if tavily_api_key is None: + # logger.warning("Tavily API key is not set") + # TAVILY_SERVER_URL = f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}" + + # Run server + run_mcp_server() From 21f381230dd052621516d432b86bacabe0961804 Mon Sep 17 00:00:00 2001 From: cocoli Date: Wed, 27 Aug 2025 17:39:20 +0800 Subject: [PATCH 08/30] fix --- .../lang_graph/nodes/context_provider_node.py | 2 +- prometheus/tools/graph_traversal.py | 168 ++++++++++-------- 2 files changed, 97 insertions(+), 73 deletions(-) diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index d78fb32b..ea599db0 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -155,7 +155,7 @@ def _init_tools(self): # Tool: Find file node by relative path # Preferred method when the exact file path is known find_file_node_with_relative_path_fn = functools.partial( - self.graph_traversal_tool.find_file_node_with_relative_path + self.graph_traversal_tool.find_file_node_with_relative_path, root_node_id=self.root_node_id, ) find_file_node_with_relative_path_tool = StructuredTool.from_function( diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index 3287db06..a94a0104 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -228,18 +228,20 @@ def __init__(self, driver: GraphDatabase.driver, max_token_per_result: int): # FileNode retrieval # ############################################################################### - def find_file_node_with_basename(self, basename: str) -> tuple[str, Sequence[Mapping[str, Any]]]: + def find_file_node_with_basename(self, basename: str, root_node_id: int) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f""" - MATCH (f:FileNode {{ basename: '{basename}' }}) + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode {{ basename: '{basename}' }}) + WHERE root.node_id = {root_node_id} RETURN f AS FileNode ORDER BY f.node_id LIMIT {MAX_RESULT} """ return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) - def find_file_node_with_relative_path(self, relative_path: str) -> tuple[str, Sequence[Mapping[str, Any]]]: + def find_file_node_with_relative_path(self, relative_path: str, root_node_id: int) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f""" - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode {{ relative_path: '{relative_path}' }}) + WHERE root.node_id = {root_node_id} RETURN f AS FileNode ORDER BY f.node_id LIMIT {MAX_RESULT} @@ -251,48 +253,48 @@ def find_file_node_with_relative_path(self, relative_path: str) -> tuple[str, Se ############################################################################### def find_ast_node_with_text_in_file_with_basename( - self, text: str, basename: str + self, text: str, basename: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.basename = '{basename}' AND a.text CONTAINS '{text}' - RETURN c as FileNode, a AS ASTNode + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' AND a.text CONTAINS '{text}' + RETURN f as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) def find_ast_node_with_text_in_file_with_relative_path( - self, text: str, relative_path: str + self, text: str, relative_path: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.relative_path = '{relative_path}' AND a.text CONTAINS '{text}' - RETURN c as FileNode, a AS ASTNode + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' AND a.text CONTAINS '{text}' + RETURN f as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) def find_ast_node_with_type_in_file_with_basename( - self, type: str, basename: str + self, type: str, basename: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.basename = '{basename}' AND a.type = '{type}' - RETURN c as FileNode, a AS ASTNode + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' AND a.type = '{type}' + RETURN f as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) def find_ast_node_with_type_in_file_with_relative_path( - self, type: str, relative_path: str + self, type: str, relative_path: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE f.relative_path = '{relative_path}' AND a.type = '{type}' - RETURN c as FileNode, a AS ASTNode + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' AND a.type = '{type}' + RETURN f as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ @@ -303,11 +305,11 @@ def find_ast_node_with_type_in_file_with_relative_path( ############################################################################### def find_text_node_with_text( - self, text: str + self, text: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) - WHERE t.text CONTAINS '{text}' + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE root.node_id = {root_node_id} AND t.text CONTAINS '{text}' RETURN f as FileNode, t AS TextNode ORDER BY t.node_id LIMIT {MAX_RESULT} @@ -315,11 +317,11 @@ def find_text_node_with_text( return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) def find_text_node_with_text_in_file( - self, text: str, basename: str + self, text: str, basename: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) - WHERE f.basename = '{basename}' AND t.text CONTAINS '{text}' + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' AND t.text CONTAINS '{text}' RETURN f as FileNode, t AS TextNode ORDER BY t.node_id LIMIT {MAX_RESULT} @@ -327,10 +329,11 @@ def find_text_node_with_text_in_file( return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) def get_next_text_node_with_node_id( - self, node_id: int + self, node_id: int, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: query = f"""\ - MATCH (f:FileNode) -[:HAS_TEXT]-> (a:TextNode {{ node_id: {node_id} }}) -[:NEXT_CHUNK]-> (b:TextNode) + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_TEXT]-> (a:TextNode {{ node_id: {node_id} }}) -[:NEXT_CHUNK]-> (b:TextNode) + WHERE root.node_id = {root_node_id} RETURN f as FileNode, b AS TextNode """ return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) @@ -340,27 +343,35 @@ def get_next_text_node_with_node_id( ############################################################################### def preview_file_content_with_basename( - self, basename: str + self, basename: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: source_code_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN - f AS FileNode, - {{ - text: apoc.text.join(lines[0..1000], '\\n'), - start_line: 1, - end_line: 1000 - }} AS preview + f AS FileNode, + {{ + text: apoc.text.join(lines[0..1000], '\\n'), + start_line: 1, + end_line: 1000 + }} AS preview ORDER BY f.node_id - """ + """ text_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_TEXT]-> (t:TextNode) - WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) - RETURN f as FileNode, t.text AS preview + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' + AND NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) + RETURN + f AS FileNode, + {{ + text: t.text, + start_line: 1, + end_line: 1000 + }} AS preview ORDER BY f.node_id - """ + """ if tree_sitter_parser.supports_file(Path(basename)): data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) @@ -377,27 +388,36 @@ def preview_file_content_with_basename( return neo4j_util.format_neo4j_data(data, self.max_token_per_result), data def preview_file_content_with_relative_path( - self, relative_path: str + self, relative_path: str, root_node_id: int ) -> tuple[str, Sequence[Mapping[str, Any]]]: source_code_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) - WITH f, apoc.text.split(a.text, '\\R') AS lines - RETURN - f as FileNode, + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' + WITH f, apoc.text.split(a.text, '\\R') AS lines + RETURN + f AS FileNode, + {{ + text: apoc.text.join(lines[0..1000], '\\n'), + start_line: 1, + end_line: 1000 + }} AS preview + ORDER BY f.node_id + """ + + text_query = f"""\ + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_TEXT]-> (t:TextNode) + WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' + AND NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) + RETURN + f AS FileNode, {{ - text: apoc.text.join(lines[0..1000], '\\n'), - start_line: 1, - end_line: 1000 + text: t.text, + start_line: 1, + end_line: 1000 }} AS preview - ORDER BY f.node_id - """ + ORDER BY f.node_id + """ - text_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_TEXT]-> (t:TextNode) - WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) - RETURN f as FileNode, t.text AS preview - ORDER BY f.node_id - """ if tree_sitter_parser.supports_file(Path(relative_path)): data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) @@ -418,22 +438,24 @@ def read_code_with_basename( basename: str, start_line: int, end_line: int, + root_node_id: int, ) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: if end_line < start_line: return f"end_line {end_line} must be greater than start_line {start_line}", None source_code_query = f"""\ - MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN - f as FileNode, - {{ - text: apoc.text.join(lines[{start_line - 1}..{end_line - 1}], '\\n'), - start_line: {start_line}, - end_line: {end_line} - }} AS SelectedLines + f as FileNode, + {{ + text: apoc.text.join(lines[{start_line - 1}..{end_line - 1}], '\\n'), + start_line: {start_line}, + end_line: {end_line} + }} AS SelectedLines ORDER BY f.node_id - """ + """ data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) for result in data: result["SelectedLines"]["text"] = pre_append_line_numbers( @@ -446,21 +468,23 @@ def read_code_with_relative_path( relative_path: str, start_line: int, end_line: int, + root_node_id: int, ) -> tuple[str, Union[Sequence[Mapping[str, Any]], None]]: if end_line < start_line: return f"end_line {end_line} must be greater than start_line {start_line}", None source_code_query = f"""\ - MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) - WITH f, apoc.text.split(a.text, '\\R') AS lines - RETURN - f as FileNode, - {{ + MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (a:ASTNode) + WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' + WITH f, apoc.text.split(a.text, '\\R') AS lines + RETURN + f as FileNode, + {{ text: apoc.text.join(lines[{start_line - 1}..{end_line - 1}], '\\n'), start_line: {start_line}, end_line: {end_line} - }} AS SelectedLines - ORDER BY f.node_id + }} AS SelectedLines + ORDER BY f.node_id """ data = neo4j_util.run_neo4j_query_without_formatting(source_code_query, self.driver) From bcd2b827fd6767de361c120cf537bf0295702f27 Mon Sep 17 00:00:00 2001 From: cocoli Date: Mon, 1 Sep 2025 11:31:21 +0800 Subject: [PATCH 09/30] =?UTF-8?q?adapt=20=E2=80=99main=E2=80=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- prometheus/app/api/routes/repository.py | 4 +- prometheus/app/services/database_service.py | 3 +- prometheus/configuration/config.py | 3 ++ .../lang_graph/nodes/bug_fix_verify_node.py | 2 +- .../nodes/bug_reproducing_file_node.py | 2 +- .../nodes/run_regression_tests_node.py | 21 +++++------ prometheus/tools/container_command.py | 37 ++++++++++++++----- prometheus/tools/web_search.py | 2 +- prometheus/utils/logger_manager.py | 21 +++++++---- 9 files changed, 61 insertions(+), 34 deletions(-) diff --git a/prometheus/app/api/routes/repository.py b/prometheus/app/api/routes/repository.py index 898f7153..b1d23e93 100644 --- a/prometheus/app/api/routes/repository.py +++ b/prometheus/app/api/routes/repository.py @@ -170,7 +170,7 @@ def list_repositories(request: Request): response_model=Response, ) @requireLogin -def delete(repository_id: int, request: Request): +def delete(repository_id: int, request: Request, force: bool = False, ): knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ "knowledge_graph_service" ] @@ -180,7 +180,7 @@ def delete(repository_id: int, request: Request): if not repository: raise ServerException(code=404, message="Repository not found") # Check if the repository is being processed - if repository.is_working: + if repository.is_working and not force: raise ServerException( code=400, message="Repository is currently being processed, please try again later" ) diff --git a/prometheus/app/services/database_service.py b/prometheus/app/services/database_service.py index ae4cfc7b..4c0a2048 100644 --- a/prometheus/app/services/database_service.py +++ b/prometheus/app/services/database_service.py @@ -3,12 +3,13 @@ from sqlmodel import SQLModel, create_engine from prometheus.app.services.base_service import BaseService +from prometheus.utils.logger_manager import get_logger class DatabaseService(BaseService): def __init__(self, DATABASE_URL: str): self.engine = create_engine(DATABASE_URL, echo=True) - self._logger = logging.getLogger("prometheus.app.services.database_service") + self._logger = get_logger(__name__) # Create the database and tables def create_db_and_tables(self): diff --git a/prometheus/configuration/config.py b/prometheus/configuration/config.py index c754f47c..cfdb9849 100644 --- a/prometheus/configuration/config.py +++ b/prometheus/configuration/config.py @@ -67,5 +67,8 @@ class Settings(BaseSettings): # Default normal user repository number DEFAULT_USER_REPOSITORY_LIMIT: int = 5 + # tool for Websearch + TAVILY_API_KEY: str + settings = Settings() diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 3e79b1db..3028ea8e 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -8,7 +8,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.logger_manager import get_logger diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 859b8f56..72872d4a 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -39,7 +39,7 @@ class BugReproducingFileNode: def __init__(self, model: BaseChatModel, kg: KnowledgeGraph, local_path: str): self.kg = kg - self.file_operation_tool = FileOperationTool(str(kg.get_local_path())) + self.file_operation_tool = FileOperationTool(local_path) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) diff --git a/prometheus/lang_graph/nodes/run_regression_tests_node.py b/prometheus/lang_graph/nodes/run_regression_tests_node.py index 0a0768ae..e791f77f 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from langchain.tools import StructuredTool @@ -8,7 +7,8 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_regression_tests_state import RunRegressionTestsState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool +from prometheus.utils.logger_manager import get_logger class RunRegressionTestsNode: @@ -55,22 +55,21 @@ class RunRegressionTestsNode: """ def __init__(self, model: BaseChatModel, container: BaseContainer): - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_node" - ) + self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/tools/container_command.py b/prometheus/tools/container_command.py index d92fa65e..3f3bd92e 100644 --- a/prometheus/tools/container_command.py +++ b/prometheus/tools/container_command.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field -from prometheus.docker.base_container import BaseContainer +from prometheus.docker.general_container import GeneralContainer from dataclasses import dataclass +from prometheus.docker.base_container import BaseContainer @dataclass class ToolSpec: @@ -11,11 +12,29 @@ class RunCommandInput(BaseModel): command: str = Field("The shell command to be run in the container") -RUN_COMMAND_DESCRIPTION = """\ -Run a shell command in the container and return the result of the command. You are always at the root -of the codebase. -""" - - -def run_command(command: str, container: BaseContainer) -> str: - return container.execute_command(command) +class ContainerCommandTool: + """Tool class for executing shell commands in containers.""" + + run_command_spec = ToolSpec( + description="""\ + Run a shell command in the container and return the result of the command. You are always at the root + of the codebase. + """, + input_schema=RunCommandInput + ) + + def __init__(self, container: BaseContainer): + """Initialize the container command tool. + Args: + container: The GeneralContainer instance to execute commands in. + """ + self.container = container + + def run_command(self, command: str) -> str: + """Run a shell command in the container and return the result. + Args: + command: The shell command to be run in the container. + Returns: + The output of the command execution. + """ + return self.container.execute_command(command) diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 96e0fb63..918ea934 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -20,7 +20,7 @@ class ToolSpec: input_schema: type -tavily_api_key = settings.get("TAVILY_API_KEY", None) +tavily_api_key = settings.TAVILY_API_KEY if tavily_api_key is None: logger.warning("Tavily API key is not set") tavily_client = None diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index d76566e6..a6a3413d 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -129,17 +129,22 @@ def _setup_root_logger(self): def _log_configuration(self): """Log configuration information""" - config_attrs = [ - 'LOGGING_LEVEL', 'ADVANCED_MODEL', 'BASE_MODEL', 'NEO4J_BATCH_SIZE', - 'WORKING_DIRECTORY', 'KNOWLEDGE_GRAPH_MAX_AST_DEPTH', - 'KNOWLEDGE_GRAPH_CHUNK_SIZE', 'KNOWLEDGE_GRAPH_CHUNK_OVERLAP', - 'MAX_TOKEN_PER_NEO4J_RESULT', 'TEMPERATURE', 'MAX_INPUT_TOKENS', - 'MAX_OUTPUT_TOKENS' - ] + # 动态获取settings中所有可用的配置属性 + config_attrs = [attr for attr in dir(settings) + if attr.isupper() and not attr.startswith('_')] for attr in config_attrs: value = getattr(settings, attr, 'Not Set') - self.root_logger.info(f"{attr}={value}") + + # 使用通配符匹配敏感配置项(包含KEY、API、PASSWORD的) + is_sensitive = any(keyword in attr.upper() for keyword in ['KEY', 'API', 'PASSWORD', "SECRET"]) + + # 如果是敏感配置项,用星号代替 + if is_sensitive and value and value != 'Not Set': + masked_value = '*' * min(len(str(value)), 8) # 最多显示8个星号 + self.root_logger.info(f"{attr}={masked_value}") + else: + self.root_logger.info(f"{attr}={value}") def get_logger(self, name: str) -> logging.Logger: """ From 16881b61ba50d8406bda52369d3bff06c7433ab8 Mon Sep 17 00:00:00 2001 From: cocoli Date: Wed, 3 Sep 2025 00:28:47 +0800 Subject: [PATCH 10/30] fix tools --- .../lang_graph/nodes/context_provider_node.py | 57 +- prometheus/tools/file_operation.py | 59 +- prometheus/tools/graph_traversal.py | 987 ++++++++---------- 3 files changed, 517 insertions(+), 586 deletions(-) diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index ab9211c0..a8d9c163 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -16,6 +16,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.tools.graph_traversal import GraphTraversalTool +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.logger_manager import get_logger @@ -89,6 +90,7 @@ def __init__( self, model: BaseChatModel, kg: KnowledgeGraph, + local_path: str, ): """Initializes the ContextProviderNode with model, knowledge graph, and database connection. @@ -103,11 +105,13 @@ def __init__( kg: Knowledge graph instance containing the processed codebase structure. Used to obtain the file tree for system prompts. """ - self.neo4j_driver = neo4j_driver + # self.neo4j_driver = neo4j_driver self.root_node_id = kg.root_node_id - self.max_token_per_result = max_token_per_result + self.kg = kg + # self.max_token_per_result = max_token_per_result # Initialize GraphTraversalTool with the driver and token limit - self.graph_traversal_tool = GraphTraversalTool(neo4j_driver, max_token_per_result) + self.graph_traversal_tool = GraphTraversalTool(kg) + self.file_operation_tool = FileOperationTool(local_path, kg) ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) @@ -134,8 +138,7 @@ def _init_tools(self): # Tool: Find file node by filename (basename) # Used when only the filename (not full path) is known find_file_node_with_basename_fn = functools.partial( - graph_traversal.find_file_node_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_file_node_with_basename ) find_file_node_with_basename_tool = StructuredTool.from_function( func=find_file_node_with_basename_fn, @@ -149,8 +152,7 @@ def _init_tools(self): # Tool: Find file node by relative path # Preferred method when the exact file path is known find_file_node_with_relative_path_fn = functools.partial( - graph_traversal.find_file_node_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_file_node_with_relative_path ) find_file_node_with_relative_path_tool = StructuredTool.from_function( func=find_file_node_with_relative_path_fn, @@ -166,8 +168,7 @@ def _init_tools(self): # Tool: Find AST node by text match in file (by basename) # Useful for searching specific snippets or patterns in unknown locations find_ast_node_with_text_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename ) find_ast_node_with_text_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_basename_fn, @@ -180,8 +181,7 @@ def _init_tools(self): # Tool: Find AST node by text match in file (by relative path) find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path ) find_ast_node_with_text_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_relative_path_fn, @@ -195,8 +195,7 @@ def _init_tools(self): # Tool: Find AST node by type in file (by basename) # Example types: FunctionDef, ClassDef, Assign, etc. find_ast_node_with_type_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename ) find_ast_node_with_type_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_basename_fn, @@ -209,8 +208,7 @@ def _init_tools(self): # Tool: Find AST node by type in file (by relative path) find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path ) find_ast_node_with_type_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_relative_path_fn, @@ -225,8 +223,7 @@ def _init_tools(self): # Tool: Find text node globally by keyword find_text_node_with_text_fn = functools.partial( - graph_traversal.find_text_node_with_text, - kg=self.kg, + self.graph_traversal_tool.find_text_node_with_text ) find_text_node_with_text_tool = StructuredTool.from_function( func=find_text_node_with_text_fn, @@ -239,8 +236,7 @@ def _init_tools(self): # Tool: Find text node by keyword in specific file find_text_node_with_text_in_file_fn = functools.partial( - graph_traversal.find_text_node_with_text_in_file, - kg=self.kg, + self.graph_traversal_tool.find_text_node_with_text_in_file ) find_text_node_with_text_in_file_tool = StructuredTool.from_function( func=find_text_node_with_text_in_file_fn, @@ -253,8 +249,7 @@ def _init_tools(self): # Tool: Fetch the next text node chunk in a chain (used for long docs/comments) get_next_text_node_with_node_id_fn = functools.partial( - graph_traversal.get_next_text_node_with_node_id, - kg=self.kg, + self.graph_traversal_tool.get_next_text_node_with_node_id ) get_next_text_node_with_node_id_tool = StructuredTool.from_function( func=get_next_text_node_with_node_id_fn, @@ -268,23 +263,21 @@ def _init_tools(self): # === FILE PREVIEW & READING TOOLS === # Tool: Preview contents of file by relative path - preview_file_content_with_relative_path_fn = functools.partial( - graph_traversal.preview_file_content_with_relative_path, - kg=self.kg, + read_file_fn = functools.partial( + self.file_operation_tool.read_file_with_knowledge_graph_data ) - preview_file_content_with_relative_path_tool = StructuredTool.from_function( - func=preview_file_content_with_relative_path_fn, - name=self.graph_traversal_tool.preview_file_content_with_relative_path.__name__, - description=self.graph_traversal_tool.preview_file_content_with_relative_path_spec.description, - args_schema=self.graph_traversal_tool.preview_file_content_with_relative_path_spec.input_schema, + read_file_tool = StructuredTool.from_function( + func=read_file_fn, + name=self.file_operation_tool.read_file_with_knowledge_graph_data.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, response_format="content_and_artifact", ) - tools.append(preview_file_content_with_relative_path_tool) + tools.append(read_file_tool) # Tool: Read entire code file by relative path read_code_with_relative_path_fn = functools.partial( - graph_traversal.read_code_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.read_code_with_relative_path ) read_code_with_relative_path_tool = StructuredTool.from_function( func=read_code_with_relative_path_fn, diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 1861a9a8..83efb463 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -3,9 +3,11 @@ from pathlib import Path from dataclasses import dataclass from pydantic import BaseModel, Field - +from typing import Any, Dict, List, Tuple, Union from prometheus.utils.str_util import pre_append_line_numbers from prometheus.utils.logger_manager import get_logger +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.utils.knowledge_graph_utils import format_knowledge_graph_data logger = get_logger(__name__) @@ -107,14 +109,18 @@ class FileOperationTool: input_schema=EditFileInput ) - def __init__(self, root_path: str): + def __init__(self, root_path: str, kg: KnowledgeGraph): """Initialize the file operation tool. - Args: root_path: The root path of the codebase for relative path operations. + kg: The knowledge graph for context provider node. + Args: + root_path: The root path of the codebase for relative path operations. + kg: The knowledge graph for context provider node. """ self.root_path = root_path - + self.kg = kg + def read_file(self, relative_path: str, n_lines: int = 1000) -> str: if os.path.isabs(relative_path): return f"relative_path: {relative_path} is a abolsute path, not relative path." @@ -203,6 +209,51 @@ def edit_file(self, relative_path: str, old_content: str, new_content: str) -> s return f"Successfully edited {relative_path}." + def read_file_with_knowledge_graph_data(self, relative_path: str) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + """ + Read the content of a file and return it along with structured knowledge graph data. + Used for context provider node + """ + + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path.", None + + file_node = None + for file_node_ in self.kg.get_file_nodes(): + if file_node_.node.relative_path == relative_path: + file_node = file_node_ + break + + # Check if file node exists in the knowledge graph + if not file_node: + return f"The file {relative_path} does not exist.", None + + file_path = Path(os.path.join(self.root_path, file_node.node.relative_path)) + + # Read the file content + with file_path.open() as f: + lines = f.readlines() + # Limit to first 1000 lines to avoid context issues + selected_text_with_line_numbers = pre_append_line_numbers("".join(lines[:1000]), 1) + + result_data = [ + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "preview": { + "text": selected_text_with_line_numbers, + "start_line": 1, + "end_line": len(selected_text_with_line_numbers), + }, + } + ] + return format_knowledge_graph_data(result_data), result_data + + + READ_FILE_DESCRIPTION = """\ Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index 2ac82608..e3cb0bfd 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -20,568 +20,526 @@ Returns a list of dictionaries containing the found nodes and their attributes. """ - -############################################################################### -# FileNode retrieval # -############################################################################### +@dataclass +class ToolSpec: + description: str + input_schema: type class FindFileNodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to search for") - -FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION = """\ -Find all FileNode in the graph with this basename of a file/dir. The basename must -include the extension, like 'bar.py', 'baz.java' or 'foo' -(in this case foo is a directory or a file without extension). - -You can use this tool to check if a file/dir with this basename exists or get all -attributes related to the file/dir.""" - - -def find_file_node_with_basename( - basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all FileNodes with the given basename.""" - results = [] - for kg_node in kg.get_file_nodes(): - if kg_node.node.basename == basename: - results.append( - { - "FileNode": { - "node_id": kg_node.node_id, - "basename": kg_node.node.basename, - "relative_path": kg_node.node.relative_path, - } - } - ) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - class FindFileNodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative_path of FileNode to search for") +class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): + text: str = Field("Search ASTNode that exactly contains this text.") + basename: str = Field("The basename of file/directory to search under for ASTNodes.") -FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Search FileNode in the graph with this relative_path of a file/dir. The relative_path is -the relative path from the root path of codebase. The relative_path must include the extension, -like 'foo/bar/baz.java'. +class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): + text: str = Field("Search ASTNode that exactly contains this text.") + relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") -You can use this tool to check if a file/dir with this relative_path exists or get all -attributes related to the file/dir.""" +class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): + type: str = Field("Search ASTNode with this tree-sitter node type.") + basename: str = Field("The basename of file/directory to search under for ASTNodes.") +class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): + type: str = Field("Search ASTNode with this tree-sitter node type.") + relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") -def find_file_node_with_relative_path( - relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all FileNodes with the given relative path.""" - results = [] - for kg_node in kg.get_file_nodes(): - if kg_node.node.relative_path == relative_path: - results.append( - { - "FileNode": { - "node_id": kg_node.node_id, - "basename": kg_node.node.basename, - "relative_path": kg_node.node.relative_path, - } - } - ) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] +class FindTextNodeWithTextInput(BaseModel): + text: str = Field("Search TextNode that exactly contains this text.") +class FindTextNodeWithTextInFileInput(BaseModel): + text: str = Field("Search TextNode that exactly contains this text.") + basename: str = Field("The basename of FileNode to search TextNode.") -############################################################################### -# ASTNode retrieval # -############################################################################### +class GetNextTextNodeWithNodeIdInput(BaseModel): + node_id: int = Field("Get the next TextNode of this given node_id.") +class PreviewFileContentWithBasenameInput(BaseModel): + basename: str = Field("The basename of FileNode to preview.") -def find_ast_node_with_text_in_file( - text: str, target_files_nodes: List[KnowledgeGraphNode], kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - results = [] +class PreviewFileContentWithRelativePathInput(BaseModel): + relative_path: str = Field("The relative path of FileNode to preview.") - # Get HAS_AST edges to find which AST nodes belong to these files - has_ast_edges = kg.get_has_ast_edges() - file_to_ast_map = { - edge.source.node_id: edge.target - for edge in has_ast_edges - if edge.source.node_id in [n.node_id for n in target_files_nodes] - } +class ReadCodeWithBasenameInput(BaseModel): + basename: str = Field("The basename of FileNode to read.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") - # Get PARENT_OF edges to traverse AST tree - parent_of_edges = kg.get_parent_of_edges() +class ReadCodeWithRelativePathInput(BaseModel): + relative_path: str = Field("The relative path of FileNode to read from root of codebase.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") - for file_node in target_files_nodes: - # Start with root AST node for this file - root_ast = file_to_ast_map[file_node.node_id] - # Add all descendant AST nodes - stack = [root_ast] - while stack: - current_node = stack.pop() - # Check if current node contains the text - if text in current_node.node.text: +class GraphTraversalTool: + + # FileNode retrieval tools + find_file_node_with_basename_spec = ToolSpec( + description="""Find all FileNode in the graph with this basename of a file/dir. The basename must + include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a directory or a file without extension). + + You can use this tool to check if a file/dir with this basename exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithBasenameInput + ) + + find_file_node_with_relative_path_spec = ToolSpec( + description="""Search FileNode in the graph with this relative_path of a file/dir. The relative_path is + the relative path from the root path of codebase. The relative_path must include the extension, + like 'foo/bar/baz.java'. + + You can use this tool to check if a file/dir with this relative_path exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithRelativePathInput + ) + + # ASTNode retrieval tools + find_ast_node_with_text_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file with this basename. + For reliable results, search for longer, distinct text sequences rather than short common words or fragments. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. + For best results, use unique text segments of at least several words. The basename can be either a file (like + 'bar.py', 'baz.java').""", + input_schema=FindASTNodeWithTextInFileWithBasenameInput + ) + + find_ast_node_with_text_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file with this relative path. + For reliable results, search for longer, distinct text sequences rather than short common words or fragments. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. + Therefore the search text should be exact as well. The relative path should be the path from the root of codebase + (like 'src/core/parser.py').""", + input_schema=FindASTNodeWithTextInFileWithRelativePathInput + ) + + find_ast_node_with_type_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this basename. + This tool is useful for searching class/function/method under files.""", + input_schema=FindASTNodeWithTypeInFileWithBasenameInput + ) + + find_ast_node_with_type_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this relative path. + This tool is useful for searching class/function/method under a file.""", + input_schema=FindASTNodeWithTypeInFileWithRelativePathInput + ) + + # TextNode retrieval tools + find_text_node_with_text_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text. The contains is + same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + + You can use this tool to find all text/documentation in codebase that contains this text.""", + input_schema=FindTextNodeWithTextInput + ) + + find_text_node_with_text_in_file_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text in a file with this basename. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + The basename must include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a file without extension). + + You can use this tool to find text/documentation in a specific file that contains this text.""", + input_schema=FindTextNodeWithTextInFileInput + ) + + get_next_text_node_with_node_id_spec = ToolSpec( + description="""Get the next TextNode of this given node_id. + + You can use this tool to read the next section of text that you are interested in.""", + input_schema=GetNextTextNodeWithNodeIdInput + ) + + read_code_with_relative_path_spec = ToolSpec( + description="""Read a specific section of a source code file's content by specifying its relative path and line range. + The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or + 'test/unit/test_parser.java'. + + This tool ONLY works with source code files (not text files or documentation). It is designed + to read large sections of code at once - you should request substantial chunks (hundreds of lines) + rather than making multiple small requests of 10-20 lines each, which would be inefficient. + + Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. + + This tool is useful for examining specific sections of source code files when you know + the exact line range you want to analyze. The function will return an error message if + end_line is less than start_line.""", + input_schema=ReadCodeWithRelativePathInput + ) + + def __init__(self, kg: KnowledgeGraph): + self.kg = kg + + + ############################################################################### + # FileNode retrieval # + ############################################################################### + + def find_file_node_with_basename(self, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all FileNodes with the given basename.""" + results = [] + for kg_node in self.kg.get_file_nodes(): + if kg_node.node.basename == basename: results.append( { "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "ASTNode": { - "node_id": current_node.node_id, - "type": current_node.node.type, - "start_line": current_node.node.start_line, - "end_line": current_node.node.end_line, - "text": current_node.node.text, - }, + "node_id": kg_node.node_id, + "basename": kg_node.node.basename, + "relative_path": kg_node.node.relative_path, + } } ) - - # Add children to stack - stack += [ - edge.target - for edge in parent_of_edges - if edge.source.node_id == current_node.node_id - ] - - # Sort by text length (smaller first) - results.sort(key=lambda x: len(x["ASTNode"]["text"])) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - -class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): - text: str = Field("Search ASTNode that exactly contains this text.") - basename: str = Field("The basename of file/directory to search under for ASTNodes.") - - -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file under -a file/directory with this basename. For reliable results, search for longer, distinct text -sequences rather than short common words or fragments. The contains is same as python's check -`'foo' in text`, ie. it is case sensitive and is looking for exact matches. For best results, -use unique text segments of at least several words. The basename can be either a file (like -'bar.py', 'baz.java') or a directory (like 'src' or 'test').""" - - -def find_ast_node_with_text_in_file_with_basename( - text: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.basename == basename - ] - return find_ast_node_with_text_in_file(text, target_files_nodes, kg) - - -class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): - text: str = Field("Search ASTNode that exactly contains this text.") - relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") - - -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file under -a file/directory with this relative path. For reliable results, search for longer, distinct text -sequences rather than short common words or fragments. The contains is same as python's check `'foo' in text`, -ie. it is case sensitive and is looking for exact matches. Therefore the search text should -be exact as well. The relative path should be the path from the root of codebase -(like 'src/core/parser.py' or 'test/unit').""" - - -def find_ast_node_with_text_in_file_with_relative_path( - text: str, relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given relative path.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.relative_path == relative_path - ] - return find_ast_node_with_text_in_file(text, target_files_nodes, kg) - - -def find_ast_node_with_type_in_file( - type: str, target_files_nodes: List[KnowledgeGraphNode], kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - results = [] - - # Get HAS_AST edges to find which AST nodes belong to these files - has_ast_edges = kg.get_has_ast_edges() - file_to_ast_map = { - edge.source.node_id: edge.target - for edge in has_ast_edges - if edge.source.node_id in [n.node_id for n in target_files_nodes] - } - - # Get PARENT_OF edges to traverse AST tree - parent_of_edges = kg.get_parent_of_edges() - - for file_node in target_files_nodes: - # Start with root AST node for this file - root_ast = file_to_ast_map[file_node.node_id] - - # Add all descendant AST nodes - stack = [root_ast] - while stack: - current_node = stack.pop() - - # Check if current node contains the text - if current_node.node.type == type: + results.sort(key=lambda x: x["FileNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + + def find_file_node_with_relative_path(self, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all FileNodes with the given relative path.""" + results = [] + for kg_node in self.kg.get_file_nodes(): + if kg_node.node.relative_path == relative_path: results.append( { "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "ASTNode": { - "node_id": current_node.node_id, - "type": current_node.node.type, - "start_line": current_node.node.start_line, - "end_line": current_node.node.end_line, - "text": current_node.node.text, - }, + "node_id": kg_node.node_id, + "basename": kg_node.node.basename, + "relative_path": kg_node.node.relative_path, + } } ) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - # Add children to stack - stack += [ - edge.target - for edge in parent_of_edges - if edge.source.node_id == current_node.node_id - ] - - # Sort by text length (smaller first) - results.sort(key=lambda x: len(x["ASTNode"]["text"])) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - -class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): - type: str = Field("Search ASTNode with this tree-sitter node type.") - basename: str = Field("The basename of file/directory to search under for ASTNodes.") - - -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file under -a file/directory with this basename. This tool is useful for searching class/function/method -under a file/directory. The basename can be either a file (like 'bar.py', -'baz.java') or a directory (like 'core' or 'test').""" + ############################################################################### + # ASTNode retrieval # + ############################################################################### -def find_ast_node_with_type_in_file_with_basename( - type: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes with the given type in files with the given basename.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.basename == basename - ] - return find_ast_node_with_type_in_file(type, target_files_nodes, kg) + def find_ast_node_with_text_in_file( + self, text: str, target_files_nodes: List[KnowledgeGraphNode] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + results = [] + # Get HAS_AST edges to find which AST nodes belong to these files + has_ast_edges = self.kg.get_has_ast_edges() + file_to_ast_map = { + edge.source.node_id: edge.target + for edge in has_ast_edges + if edge.source.node_id in [n.node_id for n in target_files_nodes] + } -class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): - type: str = Field("Search ASTNode with this tree-sitter node type.") - relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") - + # Construct parent to children map for AST traversal + parent_to_children = self.kg.get_parent_to_children_map() -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file under -a file/directory with this relative path. This tool is useful for searching class/function/method -under a file/directory. The relative path should be the path from the root -of codebase (like 'src/core/parser.py' or 'test/unit').""" - - -def find_ast_node_with_type_in_file_with_relative_path( - type: str, relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes with the given type in files with the given relative path.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.relative_path == relative_path - ] - return find_ast_node_with_type_in_file(type, target_files_nodes, kg) - - def find_ast_node_with_type_in_file_with_basename( - self, type: str, basename: str, root_node_id: int - ) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE root.node_id = {root_node_id} AND f.basename = '{basename}' AND a.type = '{type}' - RETURN f as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) - - def find_ast_node_with_type_in_file_with_relative_path( - self, type: str, relative_path: str, root_node_id: int - ) -> tuple[str, Sequence[Mapping[str, Any]]]: - query = f"""\ - MATCH (root:FileNode)-[:HAS_FILE*]->(f:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) - WHERE root.node_id = {root_node_id} AND f.relative_path = '{relative_path}' AND a.type = '{type}' - RETURN f as FileNode, a AS ASTNode - ORDER BY SIZE(a.text) - LIMIT {MAX_RESULT} - """ - return neo4j_util.run_neo4j_query(query, self.driver, self.max_token_per_result) + # Get root AstNode id list + root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) -############################################################################### -# TextNode retrieval # -############################################################################### + for file_node in target_files_nodes: + # Start with root AST node for this file + root_ast = file_to_ast_map[file_node.node_id] + # Add all descendant AST nodes + stack = [root_ast] + while stack: + current_node = stack.pop() -class FindTextNodeWithTextInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") + # Check if the current node contains the text + # Don't include the root AST node itself + if text in current_node.node.text and current_node.node_id not in root_ast_node_ids: + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "ASTNode": { + "node_id": current_node.node_id, + "type": current_node.node.type, + "start_line": current_node.node.start_line, + "end_line": current_node.node.end_line, + "text": current_node.node.text, + }, + } + ) + # Add children to stack + stack += parent_to_children.get(current_node.node_id, []) -FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text. The contains is -same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. + # Sort by text length (smaller first) + results.sort(key=lambda x: len(x["ASTNode"]["text"])) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] -You can use this tool to find all text/documentation in codebase that contains this text.""" + def find_ast_node_with_text_in_file_with_basename(self, text: str, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.basename == basename + ] + return self.find_ast_node_with_text_in_file(text, target_files_nodes) + + + def find_ast_node_with_text_in_file_with_relative_path(self, text: str, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given relative path.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.relative_path == relative_path + ] + return self.find_ast_node_with_text_in_file(text, target_files_nodes) + + def find_ast_node_with_type_in_file( + self, type: str, target_files_nodes: List[KnowledgeGraphNode] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + results = [] + + # Get HAS_AST edges to find which AST nodes belong to these files + has_ast_edges = self.kg.get_has_ast_edges() + file_to_ast_map = { + edge.source.node_id: edge.target + for edge in has_ast_edges + if edge.source.node_id in [n.node_id for n in target_files_nodes] + } + # Construct parent to children map for AST traversal + parent_to_children = self.kg.get_parent_to_children_map() -def find_text_node_with_text(text: str, kg: KnowledgeGraph) -> Tuple[str, List[Dict[str, Any]]]: - """Find all TextNodes containing the given text.""" - results = [] + # Get root AstNode id list + root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) - # Get HAS_TEXT edges to find which text nodes belong to which files - has_text_edges = kg.get_has_text_edges() + for file_node in target_files_nodes: + # Start with root AST node for this file + root_ast = file_to_ast_map[file_node.node_id] - for edge in has_text_edges: - root_text_node = edge.target - stack = [root_text_node] + # Add all descendant AST nodes + stack = [root_ast] + while stack: + current_node = stack.pop() - while stack: - current_node = stack.pop() - if text in current_node.node.text: - results.append( - { - "FileNode": { - "node_id": edge.source.node_id, - "basename": edge.source.node.basename, - "relative_path": edge.source.node.relative_path, - }, - "TextNode": { - "node_id": current_node.node_id, - "text": current_node.node.text, - "metadata": current_node.node.metadata, - }, - } - ) - # Get next chunk nodes - stack += [ - e.target - for e in kg.get_next_chunk_edges() - if e.source.node_id == current_node.node_id - ] + # Check if current node contains the text + # Don't include the root AST node itself + if current_node.node.type == type and current_node.node_id not in root_ast_node_ids: + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "ASTNode": { + "node_id": current_node.node_id, + "type": current_node.node.type, + "start_line": current_node.node.start_line, + "end_line": current_node.node.end_line, + "text": current_node.node.text, + }, + } + ) - # Sort by node_id - results.sort(key=lambda x: x["TextNode"]["node_id"]) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + # Add children to stack + stack += parent_to_children.get(current_node.node_id, []) + # Sort by text length (smaller first) + results.sort(key=lambda x: len(x["ASTNode"]["text"])) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] -class FindTextNodeWithTextInFileInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") - basename: str = Field("The basename of FileNode to search TextNode.") + def find_ast_node_with_type_in_file_with_basename(self, type: str, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes with the given type in files with the given basename.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.basename == basename + ] + return self.find_ast_node_with_type_in_file(type, target_files_nodes) + + + def find_ast_node_with_type_in_file_with_relative_path(self, type: str, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes with the given type in files with the given relative path.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.relative_path == relative_path + ] + return self.find_ast_node_with_type_in_file(type, target_files_nodes) -FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text in a file with this basename. -The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. -The basename must include the extension, like 'bar.py', 'baz.java' or 'foo' -(in this case foo is a directory or a file without extension). + ############################################################################### + # TextNode retrieval # + ############################################################################### -You can use this tool to find text/documentation in a specific file that contains this text.""" + def find_file_node_of_a_text_node(self, text_node: KnowledgeGraphNode) -> KnowledgeGraphNode: + """ + Find a file node that contains the given text node. + """ + next_chunk_reverse_map = { + edge.target.node_id: edge.source for edge in self.kg.get_next_chunk_edges() + } + has_file_node_map = {edge.target.node_id: edge.source for edge in self.kg.get_has_text_edges()} + + # Find the root text node + current_text_node = text_node + while next_chunk_reverse_map.get(current_text_node.node_id, None) is not None: + current_text_node = next_chunk_reverse_map[current_text_node.node_id] + + # Now current_text_node is the root text node + file_node = has_file_node_map[current_text_node.node_id] + return file_node + + def find_text_node_with_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all TextNodes containing the given text.""" + results = [] + # Find text nodes that contain the given text + text_nodes_with_text = [node for node in kg.get_text_nodes() if text in node.node.text] + + # If no text nodes found, return early + if not text_nodes_with_text: + return format_knowledge_graph_data([]), [] + for text_node in text_nodes_with_text: + # Find the file node that contains this text node + file_node = self.find_file_node_of_a_text_node(text_node) + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "TextNode": { + "node_id": text_node.node_id, + "text": text_node.node.text, + "metadata": text_node.node.metadata, + }, + } + ) + # Sort by node_id + results.sort(key=lambda x: x["TextNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] -def find_text_node_with_text_in_file( - text: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all TextNodes containing the given text in files with the given basename.""" - results = [] - # Get HAS_TEXT edges to find which text nodes belong to which files - has_text_edges = kg.get_has_text_edges() + def find_text_node_with_text_in_file( + self, text: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all TextNodes containing the given text in files with the given basename.""" + results = [] + # Find text nodes that contain the given text + text_nodes_with_text = [node for node in self.kg.get_text_nodes() if text in node.node.text] - for edge in has_text_edges: - root_text_node = edge.target - if edge.source.node.basename != basename: - continue + # If no text nodes found, return early + if not text_nodes_with_text: + return format_knowledge_graph_data([]), [] - stack = [root_text_node] + for text_node in text_nodes_with_text: + # Now current_text_node is the root text node + file_node = self.find_file_node_of_a_text_node(text_node) - while stack: - current_node = stack.pop() - if text in current_node.node.text: + # If the file node matches the given basename, add to results + if file_node.node.basename == basename: results.append( { "FileNode": { - "node_id": edge.source.node_id, - "basename": edge.source.node.basename, - "relative_path": edge.source.node.relative_path, + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, }, "TextNode": { - "node_id": current_node.node_id, - "text": current_node.node.text, - "metadata": current_node.node.metadata, + "node_id": text_node.node_id, + "text": text_node.node.text, + "metadata": text_node.node.metadata, }, } ) - # Get next chunk nodes - stack += [ - e.target - for e in kg.get_next_chunk_edges() - if e.source.node_id == current_node.node_id - ] - - # Sort by node_id - results.sort(key=lambda x: x["TextNode"]["node_id"]) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - -class GetNextTextNodeWithNodeIdInput(BaseModel): - node_id: int = Field("Get the next TextNode of this given node_id.") + # Sort by node_id + results.sort(key=lambda x: x["TextNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] -GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION = """\ -Get the next TextNode of this given node_id. + + def get_next_text_node_with_node_id(self, node_id: int) -> Tuple[str, List[Dict[str, Any]]]: + """Get the next TextNode for the given node_id.""" -You can use this tool to read the next section of text that you are interested in.""" + results = [] + # Find the current text node + current_text_node = None + for node in self.kg.get_text_nodes(): + if node.node_id == node_id: + current_text_node = node + break -def get_next_text_node_with_node_id( - node_id: int, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Get the next TextNode for the given node_id.""" + # If the current text node does not exist, return empty result + if not current_text_node: + return format_knowledge_graph_data([]), [] - results = [] + # Get next chunk map + next_chunk_map = {edge.source.node_id: edge.target for edge in self.kg.get_next_chunk_edges()} - # Get HAS_TEXT edges to find which text nodes belong to which files - has_text_edges = kg.get_has_text_edges() - next_chunk_edges_map = {edge.source.node_id: edge.target for edge in kg.get_next_chunk_edges()} + # Get the next text node + next_text_node = next_chunk_map.get(current_text_node.node_id, None) - for edge in has_text_edges: - root_text_node = edge.target - stack = [root_text_node] + # if the next text node does not exist, return empty result + if not next_text_node: + return format_knowledge_graph_data([]), [] - while stack: - current_node = stack.pop() - if current_node.node_id == node_id: - next_text_node = next_chunk_edges_map.get(current_node.node_id, None) - if next_text_node: - results.append( - { - "FileNode": { - "node_id": edge.source.node_id, - "basename": edge.source.node.basename, - "relative_path": edge.source.node.relative_path, - }, - "TextNode": { - "node_id": next_text_node.node_id, - "text": next_text_node.node.text, - "metadata": next_text_node.node.metadata, - }, - } - ) - break + # Find the file node that contains this text node + file_node = self.find_file_node_of_a_text_node(next_text_node) + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "TextNode": { + "node_id": next_text_node.node_id, + "text": next_text_node.node.text, + "metadata": next_text_node.node.metadata, + }, + } + ) + return format_knowledge_graph_data(results), results - # Get next chunk nodes - stack += [ - e for e in kg.get_next_chunk_edges() if e.source.node_id == current_node.node_id - ] - # Sort by node_id - return format_knowledge_graph_data(results), results + ############################################################################### + # Other # + ############################################################################### -############################################################################### -# Other # -############################################################################### + def read_code_with_relative_path(self, relative_path: str, start_line: int, end_line: int) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + """Read a specific section of a source code file by relative path and line range.""" + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}!", None + # Find file nodes with the given relative path + target_file = None + for node in self.kg.get_file_nodes(): + if node.node.relative_path == relative_path: + target_file = node + break -class PreviewFileContentWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative path of FileNode to preview.") + # Check if the file exists + if not target_file: + return format_knowledge_graph_data([]), [] + # Check if it is a source code file + if not tree_sitter_parser.supports_file(Path(target_file.node.relative_path)): + return f"The file {relative_path} is not a source code file!", None -PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Preview the content of a file with this relative path from the root of codebase. -The relative path must include the extension and full path from root, like 'src/core/parser.py', -'test/unit/test_parser.java' or 'docs/README.md'. - -You can use this tool to preview the content of a specific file to see what it contains -in the first 1000 lines or the first section. If the file is interesting, use other tools -to look at the file.""" - - -def preview_file_content_with_relative_path( - relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Preview the content of a file with the given relative path.""" - # Find file nodes with the given relative path - target_file = None - for node in kg.get_file_nodes(): - if node.node.relative_path == relative_path: - target_file = node - break - # Check if the file exists - if not target_file: - return format_knowledge_graph_data([]), [] - - # Handle source code files - if tree_sitter_parser.supports_file(Path(target_file.node.relative_path)): + # Get the first ast node for this file first_ast_node = [ - edge.target - for edge in kg.get_has_ast_edges() - if edge.source.node_id == target_file.node_id + edge.target for edge in self.kg.get_has_ast_edges() if edge.source.node_id == target_file.node_id ][0] - text = first_ast_node.node.text[0:1000] - preview_text_with_line_numbers = pre_append_line_numbers(text, 1) - result_data = [ - { - "FileNode": { - "node_id": target_file.node_id, - "basename": target_file.node.basename, - "relative_path": target_file.node.relative_path, - }, - "preview": { - "text": preview_text_with_line_numbers, - "start_line": 1, - "end_line": len(text.split("\n")), - }, - } - ] + text = first_ast_node.node.text + lines = text.split("\n") + selected_lines = lines[start_line - 1 : end_line - 1] # Convert to 0-indexed + selected_text = "\n".join(selected_lines) + selected_text_with_line_numbers = pre_append_line_numbers(selected_text, start_line) - # Handle text files - else: - root_text_node = [ - edge.target - for edge in kg.get_has_text_edges() - if edge.source.node_id == target_file.node_id - ][0] - stack = [root_text_node] - all_text = "" - while stack: - current_node = stack.pop() - all_text += current_node.node.text - if len(all_text.splitlines()) >= 1000: - break - - # Get next chunk nodes - stack += [ - e.target - for e in kg.get_next_chunk_edges() - if e.source.node_id == current_node.node_id - ] - - # Collect 1000 lines - text = all_text[:1000] - preview_text_with_line_numbers = pre_append_line_numbers(text, 1) result_data = [ { "FileNode": { @@ -589,83 +547,12 @@ def preview_file_content_with_relative_path( "basename": target_file.node.basename, "relative_path": target_file.node.relative_path, }, - "preview": { - "text": preview_text_with_line_numbers, - "start_line": 1, - "end_line": len(text.split("\n")), + "SelectedLines": { + "text": selected_text_with_line_numbers, + "start_line": start_line, + "end_line": end_line, }, } ] - return format_knowledge_graph_data(result_data), result_data - - -class ReadCodeWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative path of FileNode to read from root of codebase.") - start_line: int = Field("The starting line number, 1-indexed and inclusive.") - end_line: int = Field("The ending line number, 1-indexed and exclusive.") - - -READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Read a specific section of a source code file's content by specifying its relative path and line range. -The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or -'test/unit/test_parser.java'. - -This tool ONLY works with source code files (not text files or documentation). It is designed -to read large sections of code at once - you should request substantial chunks (hundreds of lines) -rather than making multiple small requests of 10-20 lines each, which would be inefficient. + return format_knowledge_graph_data(result_data), result_data -Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. - -This tool is useful for examining specific sections of source code files when you know -the exact line range you want to analyze. The function will return an error message if -end_line is less than start_line. -""" - - -def read_code_with_relative_path( - relative_path: str, start_line: int, end_line: int, kg: KnowledgeGraph -) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: - """Read a specific section of a source code file by relative path and line range.""" - if end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}!", None - - # Find file nodes with the given relative path - target_file = None - for node in kg.get_file_nodes(): - if node.node.relative_path == relative_path: - target_file = node - break - - # Check if the file exists - if not target_file: - return format_knowledge_graph_data([]), [] - - # Check if it is a source code file - if not tree_sitter_parser.supports_file(Path(target_file.node.relative_path)): - return f"The file {relative_path} is not a source code file!", None - - # Get the first ast node for this file - first_ast_node = [ - edge.target for edge in kg.get_has_ast_edges() if edge.source.node_id == target_file.node_id - ][0] - text = first_ast_node.node.text - lines = text.split("\n") - selected_lines = lines[start_line - 1 : end_line - 1] # Convert to 0-indexed - selected_text = "\n".join(selected_lines) - selected_text_with_line_numbers = pre_append_line_numbers(selected_text, start_line) - - result_data = [ - { - "FileNode": { - "node_id": target_file.node_id, - "basename": target_file.node.basename, - "relative_path": target_file.node.relative_path, - }, - "SelectedLines": { - "text": selected_text_with_line_numbers, - "start_line": start_line, - "end_line": end_line, - }, - } - ] - return format_knowledge_graph_data(result_data), result_data From 2730625c1608ef8d084ecf66f5b104000663f92b Mon Sep 17 00:00:00 2001 From: cocoli Date: Wed, 3 Sep 2025 00:35:19 +0800 Subject: [PATCH 11/30] fix --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7fee9497..52d59bc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,10 +28,10 @@ dependencies = [ "sqlmodel==0.0.24", "psycopg2-binary", "asyncpg", - "pyjwt==2.6.0",, + "pyjwt==2.6.0", "mcp>=1.4.1", "tavily-python>=0.5.1", - "langchain-mcp-adapters>=0.1.9" + "langchain-mcp-adapters>=0.1.9", "httpx==0.28.1", ] requires-python = ">= 3.11" From c1bf9a03f865ea1b869cbf47bb03df3f45497e22 Mon Sep 17 00:00:00 2001 From: cocoli Date: Wed, 3 Sep 2025 16:58:00 +0800 Subject: [PATCH 12/30] fix merge --- prometheus/app/main.py | 11 -- .../nodes/bug_reproducing_file_node.py | 2 +- .../nodes/bug_reproducing_write_node.py | 5 +- .../lang_graph/nodes/context_provider_node.py | 4 +- prometheus/lang_graph/nodes/edit_node.py | 5 +- .../subgraphs/bug_reproduction_subgraph.py | 2 +- .../issue_not_verified_bug_subgraph.py | 2 +- .../subgraphs/issue_verified_bug_subgraph.py | 2 +- prometheus/tools/file_operation.py | 148 ------------------ 9 files changed, 11 insertions(+), 170 deletions(-) diff --git a/prometheus/app/main.py b/prometheus/app/main.py index 20a40498..4ed095c6 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -17,17 +17,6 @@ from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger logger = get_logger(__name__) -# Log the configuration settings -logger.info(f"LOGGING_LEVEL={settings.LOGGING_LEVEL}") -logger.info(f"ENVIRONMENT={settings.ENVIRONMENT}") -logger.info(f"BACKEND_CORS_ORIGINS={settings.BACKEND_CORS_ORIGINS}") -logger.info(f"ADVANCED_MODEL={settings.ADVANCED_MODEL}") -logger.info(f"BASE_MODEL={settings.BASE_MODEL}") -logger.info(f"NEO4J_BATCH_SIZE={settings.NEO4J_BATCH_SIZE}") -logger.info(f"WORKING_DIRECTORY={settings.WORKING_DIRECTORY}") -logger.info(f"KNOWLEDGE_GRAPH_MAX_AST_DEPTH={settings.KNOWLEDGE_GRAPH_MAX_AST_DEPTH}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_SIZE={settings.KNOWLEDGE_GRAPH_CHUNK_SIZE}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_OVERLAP={settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP}") @asynccontextmanager diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 72872d4a..0cb9d01e 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -39,7 +39,7 @@ class BugReproducingFileNode: def __init__(self, model: BaseChatModel, kg: KnowledgeGraph, local_path: str): self.kg = kg - self.file_operation_tool = FileOperationTool(local_path) + self.file_operation_tool = FileOperationTool(local_path, kg) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 93be8e69..f4d2fc27 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -8,6 +8,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.file_operation import FileOperationTool +from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.logger_manager import get_logger @@ -112,8 +113,8 @@ def test_empty_array_parsing(parser): ''' - def __init__(self, model: BaseChatModel, local_path: str): - self.file_operation_tool = FileOperationTool(local_path) + def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): + self.file_operation_tool = FileOperationTool(local_path, kg) self.tools = self._init_tools() self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index c5b0e304..81c284f0 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -91,7 +91,6 @@ def __init__( model: BaseChatModel, kg: KnowledgeGraph, local_path: str, - local_path: str, ): """Initializes the ContextProviderNode with model, knowledge graph, and database connection. @@ -109,6 +108,7 @@ def __init__( # self.neo4j_driver = neo4j_driver self.root_node_id = kg.root_node_id self.kg = kg + self.root_path = local_path # self.max_token_per_result = max_token_per_result # Initialize GraphTraversalTool with the driver and token limit self.graph_traversal_tool = GraphTraversalTool(kg) @@ -116,8 +116,6 @@ def __init__( ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) - self.kg = kg - self.root_path = local_path self.system_prompt = SystemMessage( self.SYS_PROMPT.format(file_tree=kg.get_file_tree(), ast_node_types=ast_node_types_str) ) diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index f7c95a3d..3d10f4d4 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -16,6 +16,7 @@ from langchain_core.messages import SystemMessage from prometheus.tools.file_operation import FileOperationTool +from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.logger_manager import get_logger @@ -119,9 +120,9 @@ def other_method(): 7. NEVER write tests, your change will be tested by reproduction tests and regression tests later """ - def __init__(self, model: BaseChatModel, local_path: str): + def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.file_operation_tool = FileOperationTool(local_path) + self.file_operation_tool = FileOperationTool(local_path, kg) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") diff --git a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py index dff4c0e7..ea678434 100644 --- a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py @@ -70,7 +70,7 @@ def __init__( # Step 3: Write a patch to reproduce the bug bug_reproducing_write_message_node = BugReproducingWriteMessageNode() bug_reproducing_write_node = BugReproducingWriteNode( - advanced_model, git_repo.playground_path + advanced_model, git_repo.playground_path, kg ) bug_reproducing_write_tools = ToolNode( tools=bug_reproducing_write_node.tools, diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 71af812c..50e10190 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -52,7 +52,7 @@ def __init__( ) edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, git_repo.playground_path) + edit_node = EditNode(advanced_model, git_repo.playground_path, kg) edit_tools = ToolNode( tools=edit_node.tools, name="edit_tools", diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 8a763161..896a3706 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -89,7 +89,7 @@ def __init__( # Phase 3: Generate code edits and optionally apply toolchains edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, git_repo.playground_path) + edit_node = EditNode(advanced_model, git_repo.playground_path, kg) edit_tools = ToolNode( tools=edit_node.tools, name="edit_tools", diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 32edef8b..02127875 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -253,151 +253,3 @@ def read_file_with_knowledge_graph_data(self, relative_path: str) -> Union[Tuple } ] return format_knowledge_graph_data(result_data), result_data - - - - -READ_FILE_DESCRIPTION = """\ -Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. -Returns up to the first 1000 lines by default to prevent context issues with large files. -Returns an error message if the file doesn't exist. -""" - - -def read_file(relative_path: str, root_path: str, n_lines: int = 1000) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - with file_path.open() as f: - lines = f.readlines() - - return pre_append_line_numbers("".join(lines[:n_lines]), 1) - - - - -READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION = """\ -Read a specific range of lines from a file and return the content with line numbers prepended. -The line numbers are 1-indexed where start_line is inclusive and end_line is exclusive. -For best results when analyzing code or text files, consider reading chunks of 500-1000 lines at a time. -""" - - -def read_file_with_line_numbers( - relative_path: str, root_path: str, start_line: int, end_line: int -) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - if end_line < start_line: - return f"The end line number {end_line} must be greater than the start line number {start_line}." - - zero_based_start_line = start_line - 1 - zero_based_end_line = end_line - 1 - - with file_path.open() as f: - lines = f.readlines() - final_content = "".join(lines[zero_based_start_line:zero_based_end_line]) - if not final_content: - return f"No content found between lines {start_line} and {end_line} in {relative_path}!" - - return pre_append_line_numbers(final_content, start_line) - - - -CREATE_FILE_DESCRIPTION = """\ -Create a new file at the specified path with the given content. -If the parent directories don't exist, they will be created automatically. -Returns an error message if the file already exists. -""" - - -def create_file(relative_path: str, root_path: str, content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if file_path.exists(): - return f"The file {relative_path} already exists." - - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(content) - return f"The file {relative_path} has been created." - - - -DELETE_DESCRIPTION = """\ -Delete a file or directory at the specified path. -For directories, it will recursively delete all contents. -Returns an error message if the path doesn't exist. -""" - - -def delete(relative_path: str, root_path: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - if file_path.is_dir(): - shutil.rmtree(file_path) - return f"The directory {relative_path} has been deleted." - - file_path.unlink() - return f"The file {relative_path} has been deleted." - - - -EDIT_FILE_DESCRIPTION = """\ -Edit a file by replacing specific content with new content. -Performs an exact string replacement of old_content with new_content. -Returns an error message if: -- The file doesn't exist -- The old_content is not found in the file -- The old_content matches multiple locations (in which case more context is needed) -- The provided path is absolute instead of relative - -Example usage: -edit_file( - relative_path="src/calculator.py", - old_content="return a * b", - new_content="return a / b" -) -""" - - -def edit_file(relative_path: str, root_path: str, old_content: str, new_content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - content = file_path.read_text() - - occurrences = content.count(old_content) - - if occurrences == 0: - return f"No match found for the specified content in {relative_path}. Please verify the content to replace." - - if occurrences > 1: - return ( - f"Found {occurrences} occurrences of the specified content in {relative_path}. " - "Please provide more context to ensure a unique match." - ) - - new_content_full = content.replace(old_content, new_content) - file_path.write_text(new_content_full) - - return f"Successfully edited {relative_path}." From 03b87e61c1afc4e446d2a08e51679da8bd02c0d9 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:48:48 +0800 Subject: [PATCH 13/30] refactor: Remove unused logging imports across multiple files --- prometheus/app/api/routes/repository.py | 6 +- prometheus/app/main.py | 1 + prometheus/app/services/database_service.py | 1 - prometheus/docker/base_container.py | 4 +- prometheus/git/git_repository.py | 2 +- .../bug_fix_verification_subgraph_node.py | 1 - .../lang_graph/nodes/bug_fix_verify_node.py | 1 - .../nodes/bug_fix_verify_structured_node.py | 1 - .../nodes/bug_reproducing_execute_node.py | 3 +- .../nodes/bug_reproducing_file_node.py | 1 - .../nodes/bug_reproducing_structured_node.py | 1 - .../bug_reproducing_write_message_node.py | 1 - .../nodes/bug_reproducing_write_node.py | 3 +- .../nodes/bug_reproduction_subgraph_node.py | 1 - .../nodes/build_and_test_subgraph_node.py | 1 - .../nodes/context_extraction_node.py | 1 - .../lang_graph/nodes/context_provider_node.py | 9 +- .../nodes/context_query_message_node.py | 1 - .../lang_graph/nodes/context_refine_node.py | 1 - .../nodes/context_retrieval_subgraph_node.py | 1 - .../lang_graph/nodes/edit_message_node.py | 1 - prometheus/lang_graph/nodes/edit_node.py | 3 +- .../nodes/final_patch_selection_node.py | 1 - .../lang_graph/nodes/general_build_node.py | 1 - .../nodes/general_build_structured_node.py | 1 - .../lang_graph/nodes/general_test_node.py | 3 +- .../nodes/general_test_structured_node.py | 1 - prometheus/lang_graph/nodes/git_diff_node.py | 1 - prometheus/lang_graph/nodes/git_reset_node.py | 1 - .../nodes/issue_bug_analyzer_message_node.py | 1 - .../nodes/issue_bug_analyzer_node.py | 82 ++++---- .../nodes/issue_bug_context_message_node.py | 1 - ...e_bug_reproduction_context_message_node.py | 1 - .../nodes/issue_bug_responder_node.py | 1 - .../nodes/issue_bug_subgraph_node.py | 1 - ...sue_classification_context_message_node.py | 1 - .../issue_classification_subgraph_node.py | 1 - .../lang_graph/nodes/issue_classifier_node.py | 1 - .../issue_not_verified_bug_subgraph_node.py | 1 - .../nodes/issue_verified_bug_subgraph_node.py | 1 - prometheus/lang_graph/nodes/noop_node.py | 3 +- .../lang_graph/nodes/reset_messages_node.py | 3 +- .../lang_graph/nodes/update_container_node.py | 3 +- .../nodes/user_defined_build_node.py | 1 - .../issue_not_verified_bug_subgraph.py | 4 +- .../subgraphs/issue_verified_bug_subgraph.py | 4 +- prometheus/tools/container_command.py | 17 +- prometheus/tools/file_operation.py | 31 ++-- prometheus/tools/graph_traversal.py | 82 ++++---- prometheus/tools/web_search.py | 139 +++++++------- prometheus/utils/logger_manager.py | 175 +++++++++--------- .../nodes/test_issue_bug_analyzer_node.py | 56 ++++-- tests/tools/test_mcp_client.py | 122 +++++------- tests/tools/test_mcp_client_config.py | 58 +++--- tests/tools/test_mcp_server.py | 14 +- tests/tools/test_mcp_tools.py | 12 +- tests/tools/test_mcp_web_search.py | 87 +++++---- 57 files changed, 485 insertions(+), 472 deletions(-) diff --git a/prometheus/app/api/routes/repository.py b/prometheus/app/api/routes/repository.py index ca346b4f..dfd48a23 100644 --- a/prometheus/app/api/routes/repository.py +++ b/prometheus/app/api/routes/repository.py @@ -177,7 +177,11 @@ async def list_repositories(request: Request): response_model=Response, ) @requireLogin -async def delete(repository_id: int, request: Request, force: bool = False, ): +async def delete( + repository_id: int, + request: Request, + force: bool = False, +): knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ "knowledge_graph_service" ] diff --git a/prometheus/app/main.py b/prometheus/app/main.py index 4ed095c6..a25877dd 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -16,6 +16,7 @@ ) from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger + logger = get_logger(__name__) diff --git a/prometheus/app/services/database_service.py b/prometheus/app/services/database_service.py index c69a59d4..4d889a4e 100644 --- a/prometheus/app/services/database_service.py +++ b/prometheus/app/services/database_service.py @@ -1,4 +1,3 @@ -import logging from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index 0b80e605..03c9809d 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -9,6 +9,8 @@ import docker +from prometheus.utils.logger_manager import get_logger + class BaseContainer(ABC): """An abstract base class for managing Docker containers with file synchronization capabilities. @@ -41,7 +43,7 @@ def __init__( Args: project_path: Path to the project directory to be containerized. """ - self._logger = logging.getLogger( + self._logger = get_logger( f"thread-{threading.get_ident()}.{self.__class__.__module__}.{self.__class__.__name__}" ) temp_dir = Path(tempfile.mkdtemp()) diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index 46cf13c8..5a96413e 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -1,13 +1,13 @@ """Git repository management module.""" import asyncio -import logging import shutil import tempfile from pathlib import Path from typing import Optional, Sequence from git import Git, GitCommandError, InvalidGitRepositoryError, Repo + from prometheus.utils.logger_manager import get_logger diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index cc3d869b..13a72ad2 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 3028ea8e..fcff00b5 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index 354481b6..80f8052f 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index e0cf7654..38727ec5 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from pathlib import Path from typing import Optional, Sequence @@ -12,8 +11,8 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.issue_util import format_test_commands +from prometheus.utils.logger_manager import get_logger from prometheus.utils.patch_util import get_updated_files -from prometheus.utils.logger_manager import get_logger class BugReproducingExecuteNode: diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 0cb9d01e..684f077a 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 43726a20..cb2ca68f 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Sequence diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index b1db4d02..f2dd20cd 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.messages import HumanMessage diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index f4d2fc27..9cd66b92 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -1,14 +1,13 @@ import functools -import logging import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage +from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.file_operation import FileOperationTool -from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.logger_manager import get_logger diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 586523d0..20519e0f 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Optional, Sequence diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index 206a38bd..f6f1f380 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Optional, Sequence diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index 6c94459f..0d8bd0d9 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Sequence diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index d9183ccc..b90cb798 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -6,7 +6,6 @@ """ import functools -import logging import threading from typing import Dict @@ -15,8 +14,9 @@ from langchain_core.messages import SystemMessage from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools.graph_traversal import GraphTraversalTool from prometheus.tools.file_operation import FileOperationTool +from prometheus.tools.graph_traversal import GraphTraversalTool +from prometheus.utils.logger_manager import get_logger class ContextProviderNode: @@ -113,16 +113,13 @@ def __init__( self.graph_traversal_tool = GraphTraversalTool(kg) self.file_operation_tool = FileOperationTool(local_path, kg) - ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) self.system_prompt = SystemMessage( self.SYS_PROMPT.format(file_tree=kg.get_file_tree(), ast_node_types=ast_node_types_str) ) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_provider_node" - ) + self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") def _init_tools(self): """ diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index 5d6b13bb..e6bc476a 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.messages import HumanMessage diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index ea2caaad..923f8b34 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 45c622a1..c9044f19 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Dict, Sequence diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index 76553988..e7182501 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Dict diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index ebb5f656..9d70ad77 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -7,7 +7,6 @@ """ import functools -import logging import threading from typing import Dict @@ -15,8 +14,8 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage -from prometheus.tools.file_operation import FileOperationTool from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.logger_manager import get_logger diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 469de7f6..3d361e53 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Sequence diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index cd9a0441..30ff5405 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index 525ee0d4..e1b10f2d 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -6,7 +6,6 @@ identify any failures. """ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index 81ce97a3..9b207ab5 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -1,5 +1,4 @@ import functools -import logging import threading from langchain.tools import StructuredTool @@ -9,7 +8,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState -from prometheus.tools.container_command import ContainerCommandTool +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.logger_manager import get_logger diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index c884866c..6b4dc55e 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -6,7 +6,6 @@ identify any failures. """ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index eae4c7b4..6f9b0da0 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -6,7 +6,6 @@ output. """ -import logging import threading from typing import Dict, Optional diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index 40bf78f0..5bda5375 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,4 +1,3 @@ -import logging import threading from prometheus.git.git_repository import GitRepository diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index 12b82d84..d6ecae23 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Dict diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index ac299539..899471d9 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,54 +1,52 @@ -import logging +import functools import threading from typing import Dict -import asyncio +from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage -from langchain.tools import StructuredTool -import functools -from prometheus.tools.web_search import WebSearchTool, mcp_web_search + +from prometheus.tools.web_search import WebSearchTool from prometheus.utils.logger_manager import get_logger class IssueBugAnalyzerNode: -# SYS_PROMPT = """\ -# You are an expert software engineer specializing in bug analysis and fixes. Your role is to: - -# 1. Carefully analyze reported software issues and bugs by: -# - Understanding issue descriptions and symptoms -# - Identifying affected code components -# - Tracing problematic execution paths - -# 2. Determine root causes through systematic investigation: -# - Analyze why the current behavior deviates from expected -# - Identify which specific code elements are responsible -# - Understand the context and interactions causing the issue - -# 3. Provide high-level fix suggestions by describing: -# - Which specific files need modification -# - Which functions or code blocks need changes -# - What logical changes are needed (e.g., "variable x needs to be renamed to y", "need to add validation for parameter z") -# - Why these changes would resolve the issue - -# 4. For patch failures, analyze by: -# - Understanding error messages and test failures -# - Identifying what went wrong with the previous attempt -# - Suggesting revised high-level changes that avoid the previous issues - -# Tools available: -# - web_search: Searches the web for technical information to aid in bug analysis and resolution. - -# Important: -# - Do NOT provide actual code snippets or diffs -# - DO provide clear file paths and function names where changes are needed -# - Focus on describing WHAT needs to change and WHY, not HOW to change it -# - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes - -# Communicate in a clear, technical manner focused on accurate analysis and practical suggestions -# rather than implementation details. -# """ - + # SYS_PROMPT = """\ + # You are an expert software engineer specializing in bug analysis and fixes. Your role is to: + + # 1. Carefully analyze reported software issues and bugs by: + # - Understanding issue descriptions and symptoms + # - Identifying affected code components + # - Tracing problematic execution paths + + # 2. Determine root causes through systematic investigation: + # - Analyze why the current behavior deviates from expected + # - Identify which specific code elements are responsible + # - Understand the context and interactions causing the issue + + # 3. Provide high-level fix suggestions by describing: + # - Which specific files need modification + # - Which functions or code blocks need changes + # - What logical changes are needed (e.g., "variable x needs to be renamed to y", "need to add validation for parameter z") + # - Why these changes would resolve the issue + + # 4. For patch failures, analyze by: + # - Understanding error messages and test failures + # - Identifying what went wrong with the previous attempt + # - Suggesting revised high-level changes that avoid the previous issues + + # Tools available: + # - web_search: Searches the web for technical information to aid in bug analysis and resolution. + + # Important: + # - Do NOT provide actual code snippets or diffs + # - DO provide clear file paths and function names where changes are needed + # - Focus on describing WHAT needs to change and WHY, not HOW to change it + # - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes + + # Communicate in a clear, technical manner focused on accurate analysis and practical suggestions + # rather than implementation details. + # """ SYS_PROMPT = """\ You are an expert software engineer specializing in bug analysis and fixes. Your role is to: diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index 13b060a8..f7d09aee 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Dict diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index a83efccf..3d53d250 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index 78f1e38b..d980772a 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index e69ae34a..4b550bdd 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Optional, Sequence diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index ba96092c..5a6aeee3 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -1,4 +1,3 @@ -import logging import threading from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 4eef63f6..3f2cdc06 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index 3d8195da..6127a034 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index b725e6b1..4fa5806c 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from typing import Dict diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 4ec962b1..fccf8135 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import logging import threading from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index 56acb33e..c5878235 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -5,10 +5,11 @@ node graphs where a connection is needed but no processing is required. """ -from prometheus.utils.logger_manager import get_logger import threading from typing import Dict +from prometheus.utils.logger_manager import get_logger + class NoopNode: """No-operation node that routes workflow without processing. diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index 9b6bfbd8..a58d8422 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -10,10 +10,11 @@ - The same state attribute name is reused """ -from prometheus.utils.logger_manager import get_logger import threading from typing import Dict +from prometheus.utils.logger_manager import get_logger + class ResetMessagesNode: """Resets message states for workflow loop iterations. diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index f347fdff..c1dacb6d 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -6,14 +6,13 @@ between the agent's workspace and the container environment. """ -import logging import threading from typing import Dict from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository -from prometheus.utils.patch_util import get_updated_files from prometheus.utils.logger_manager import get_logger +from prometheus.utils.patch_util import get_updated_files class UpdateContainerNode: diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index eee7f09d..660b5373 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -1,4 +1,3 @@ -import logging import threading import uuid from typing import Any diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 50e10190..43a1753c 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -112,14 +112,14 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - + # Conditionally invoke tools or continue to edit message workflow.add_conditional_edges( "issue_bug_analyzer_node", functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, ) - + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 66ccb6af..de1a1c0d 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -156,14 +156,14 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - + # Conditionally invoke tools or continue to edit message workflow.add_conditional_edges( "issue_bug_analyzer_node", functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, ) - + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") diff --git a/prometheus/tools/container_command.py b/prometheus/tools/container_command.py index 3f3bd92e..e906528f 100644 --- a/prometheus/tools/container_command.py +++ b/prometheus/tools/container_command.py @@ -1,39 +1,42 @@ -from pydantic import BaseModel, Field -from prometheus.docker.general_container import GeneralContainer from dataclasses import dataclass + +from pydantic import BaseModel, Field + from prometheus.docker.base_container import BaseContainer + @dataclass class ToolSpec: description: str input_schema: type + class RunCommandInput(BaseModel): command: str = Field("The shell command to be run in the container") class ContainerCommandTool: """Tool class for executing shell commands in containers.""" - + run_command_spec = ToolSpec( description="""\ Run a shell command in the container and return the result of the command. You are always at the root of the codebase. """, - input_schema=RunCommandInput + input_schema=RunCommandInput, ) - + def __init__(self, container: BaseContainer): """Initialize the container command tool. Args: container: The GeneralContainer instance to execute commands in. """ self.container = container - + def run_command(self, command: str) -> str: """Run a shell command in the container and return the result. Args: - command: The shell command to be run in the container. + command: The shell command to be run in the container. Returns: The output of the command execution. """ diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 8098fcce..179087e6 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -1,26 +1,29 @@ -import logging import os import shutil -from pathlib import Path from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, List, Tuple, Union from pydantic import BaseModel, Field from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.knowledge_graph_utils import format_knowledge_graph_data +from prometheus.utils.logger_manager import get_logger from prometheus.utils.str_util import pre_append_line_numbers -logger = logging.getLogger("prometheus.tools.file_operation") +logger = get_logger(__name__) + @dataclass class ToolSpec: description: str input_schema: type + class ReadFileInput(BaseModel): relative_path: str = Field("The relative path of the file to read") + class ReadFileWithLineNumbersInput(BaseModel): relative_path: str = Field( description="The relative path of the file to read, eg. foo/bar/test.py, not absolute path" @@ -28,17 +31,20 @@ class ReadFileWithLineNumbersInput(BaseModel): start_line: int = Field(description="The start line number to read, 1-indexed and inclusive") end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") + class CreateFileInput(BaseModel): relative_path: str = Field( description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" ) content: str = Field(description="The content of the file to create") + class DeleteInput(BaseModel): relative_path: str = Field( description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" ) + class EditFileInput(BaseModel): relative_path: str = Field( description="The relative path of the file to edit, eg. foo/bar/test.py, not absolute path" @@ -61,7 +67,7 @@ class FileOperationTool: Returns up to the first 1000 lines by default to prevent context issues with large files. Returns an error message if the file doesn't exist. """, - input_schema=ReadFileInput + input_schema=ReadFileInput, ) read_file_with_line_numbers_spec = ToolSpec( @@ -70,7 +76,7 @@ class FileOperationTool: The line numbers are 1-indexed where start_line is inclusive and end_line is exclusive. For best results when analyzing code or text files, consider reading chunks of 500-1000 lines at a time. """, - input_schema=ReadFileWithLineNumbersInput + input_schema=ReadFileWithLineNumbersInput, ) create_file_spec = ToolSpec( @@ -79,7 +85,7 @@ class FileOperationTool: If the parent directories don't exist, they will be created automatically. Returns an error message if the file already exists. """, - input_schema=CreateFileInput + input_schema=CreateFileInput, ) delete_spec = ToolSpec( @@ -88,7 +94,7 @@ class FileOperationTool: For directories, it will recursively delete all contents. Returns an error message if the path doesn't exist. """, - input_schema=DeleteInput + input_schema=DeleteInput, ) edit_file_spec = ToolSpec( @@ -108,7 +114,7 @@ class FileOperationTool: new_content="return a / b" ) """, - input_schema=EditFileInput + input_schema=EditFileInput, ) def __init__(self, root_path: str, kg: KnowledgeGraph): @@ -136,7 +142,9 @@ def read_file(self, relative_path: str, n_lines: int = 1000) -> str: return pre_append_line_numbers("".join(lines[:n_lines]), 1) - def read_file_with_line_numbers(self, relative_path: str, start_line: int, end_line: int) -> str: + def read_file_with_line_numbers( + self, relative_path: str, start_line: int, end_line: int + ) -> str: if os.path.isabs(relative_path): return f"relative_path: {relative_path} is a absolute path, not relative path." @@ -169,7 +177,6 @@ def create_file(self, relative_path: str, content: str) -> str: file_path.write_text(content) return f"The file {relative_path} has been created." - def delete(self, relative_path: str) -> str: if os.path.isabs(relative_path): return f"relative_path: {relative_path} is a abolsute path, not relative path." @@ -211,7 +218,9 @@ def edit_file(self, relative_path: str, old_content: str, new_content: str) -> s return f"Successfully edited {relative_path}." - def read_file_with_knowledge_graph_data(self, relative_path: str) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + def read_file_with_knowledge_graph_data( + self, relative_path: str + ) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: """ Read the content of a file and return it along with structured knowledge graph data. Used for context provider node diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index c8307985..2cec41ad 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -1,6 +1,6 @@ +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Tuple, Union -from dataclasses import dataclass from pydantic import BaseModel, Field @@ -20,6 +20,7 @@ Returns a list of dictionaries containing the found nodes and their attributes. """ + @dataclass class ToolSpec: description: str @@ -29,55 +30,65 @@ class ToolSpec: class FindFileNodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to search for") + class FindFileNodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative_path of FileNode to search for") + class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") basename: str = Field("The basename of file/directory to search under for ASTNodes.") + class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") + class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): type: str = Field("Search ASTNode with this tree-sitter node type.") basename: str = Field("The basename of file/directory to search under for ASTNodes.") + class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): type: str = Field("Search ASTNode with this tree-sitter node type.") relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") + class FindTextNodeWithTextInput(BaseModel): text: str = Field("Search TextNode that exactly contains this text.") + class FindTextNodeWithTextInFileInput(BaseModel): text: str = Field("Search TextNode that exactly contains this text.") basename: str = Field("The basename of FileNode to search TextNode.") + class GetNextTextNodeWithNodeIdInput(BaseModel): node_id: int = Field("Get the next TextNode of this given node_id.") + class PreviewFileContentWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to preview.") + class PreviewFileContentWithRelativePathInput(BaseModel): relative_path: str = Field("The relative path of FileNode to preview.") + class ReadCodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to read.") start_line: int = Field("The starting line number, 1-indexed and inclusive.") end_line: int = Field("The ending line number, 1-indexed and exclusive.") + class ReadCodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative path of FileNode to read from root of codebase.") start_line: int = Field("The starting line number, 1-indexed and inclusive.") end_line: int = Field("The ending line number, 1-indexed and exclusive.") - class GraphTraversalTool: - # FileNode retrieval tools find_file_node_with_basename_spec = ToolSpec( description="""Find all FileNode in the graph with this basename of a file/dir. The basename must @@ -86,7 +97,7 @@ class GraphTraversalTool: You can use this tool to check if a file/dir with this basename exists or get all attributes related to the file/dir.""", - input_schema=FindFileNodeWithBasenameInput + input_schema=FindFileNodeWithBasenameInput, ) find_file_node_with_relative_path_spec = ToolSpec( @@ -96,7 +107,7 @@ class GraphTraversalTool: You can use this tool to check if a file/dir with this relative_path exists or get all attributes related to the file/dir.""", - input_schema=FindFileNodeWithRelativePathInput + input_schema=FindFileNodeWithRelativePathInput, ) # ASTNode retrieval tools @@ -106,7 +117,7 @@ class GraphTraversalTool: The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. For best results, use unique text segments of at least several words. The basename can be either a file (like 'bar.py', 'baz.java').""", - input_schema=FindASTNodeWithTextInFileWithBasenameInput + input_schema=FindASTNodeWithTextInFileWithBasenameInput, ) find_ast_node_with_text_in_file_with_relative_path_spec = ToolSpec( @@ -115,19 +126,19 @@ class GraphTraversalTool: The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. Therefore the search text should be exact as well. The relative path should be the path from the root of codebase (like 'src/core/parser.py').""", - input_schema=FindASTNodeWithTextInFileWithRelativePathInput + input_schema=FindASTNodeWithTextInFileWithRelativePathInput, ) find_ast_node_with_type_in_file_with_basename_spec = ToolSpec( description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this basename. This tool is useful for searching class/function/method under files.""", - input_schema=FindASTNodeWithTypeInFileWithBasenameInput + input_schema=FindASTNodeWithTypeInFileWithBasenameInput, ) find_ast_node_with_type_in_file_with_relative_path_spec = ToolSpec( description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this relative path. This tool is useful for searching class/function/method under a file.""", - input_schema=FindASTNodeWithTypeInFileWithRelativePathInput + input_schema=FindASTNodeWithTypeInFileWithRelativePathInput, ) # TextNode retrieval tools @@ -137,7 +148,7 @@ class GraphTraversalTool: looking for exact matches. Therefore the search text should be exact as well. You can use this tool to find all text/documentation in codebase that contains this text.""", - input_schema=FindTextNodeWithTextInput + input_schema=FindTextNodeWithTextInput, ) find_text_node_with_text_in_file_spec = ToolSpec( @@ -148,14 +159,14 @@ class GraphTraversalTool: (in this case foo is a file without extension). You can use this tool to find text/documentation in a specific file that contains this text.""", - input_schema=FindTextNodeWithTextInFileInput + input_schema=FindTextNodeWithTextInFileInput, ) get_next_text_node_with_node_id_spec = ToolSpec( description="""Get the next TextNode of this given node_id. You can use this tool to read the next section of text that you are interested in.""", - input_schema=GetNextTextNodeWithNodeIdInput + input_schema=GetNextTextNodeWithNodeIdInput, ) read_code_with_relative_path_spec = ToolSpec( @@ -172,13 +183,12 @@ class GraphTraversalTool: This tool is useful for examining specific sections of source code files when you know the exact line range you want to analyze. The function will return an error message if end_line is less than start_line.""", - input_schema=ReadCodeWithRelativePathInput + input_schema=ReadCodeWithRelativePathInput, ) def __init__(self, kg: KnowledgeGraph): self.kg = kg - ############################################################################### # FileNode retrieval # ############################################################################### @@ -200,7 +210,9 @@ def find_file_node_with_basename(self, basename: str) -> Tuple[str, List[Dict[st results.sort(key=lambda x: x["FileNode"]["node_id"]) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - def find_file_node_with_relative_path(self, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + def find_file_node_with_relative_path( + self, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: """Find all FileNodes with the given relative path.""" results = [] for kg_node in self.kg.get_file_nodes(): @@ -216,7 +228,6 @@ def find_file_node_with_relative_path(self, relative_path: str) -> Tuple[str, Li ) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - ############################################################################### # ASTNode retrieval # ############################################################################### @@ -277,7 +288,9 @@ def find_ast_node_with_text_in_file( results.sort(key=lambda x: len(x["ASTNode"]["text"])) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - def find_ast_node_with_text_in_file_with_basename(self, text: str, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + def find_ast_node_with_text_in_file_with_basename( + self, text: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: """Find all ASTNodes containing the given text in files with the given basename.""" # Get file nodes with the given basename target_files_nodes: List[KnowledgeGraphNode] = [ @@ -285,8 +298,9 @@ def find_ast_node_with_text_in_file_with_basename(self, text: str, basename: str ] return self.find_ast_node_with_text_in_file(text, target_files_nodes) - - def find_ast_node_with_text_in_file_with_relative_path(self, text: str, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + def find_ast_node_with_text_in_file_with_relative_path( + self, text: str, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: """Find all ASTNodes containing the given text in files with the given relative path.""" # Get file nodes with the given basename target_files_nodes: List[KnowledgeGraphNode] = [ @@ -350,7 +364,9 @@ def find_ast_node_with_type_in_file( results.sort(key=lambda x: len(x["ASTNode"]["text"])) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - def find_ast_node_with_type_in_file_with_basename(self, type: str, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + def find_ast_node_with_type_in_file_with_basename( + self, type: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: """Find all ASTNodes with the given type in files with the given basename.""" # Get file nodes with the given basename target_files_nodes: List[KnowledgeGraphNode] = [ @@ -358,8 +374,9 @@ def find_ast_node_with_type_in_file_with_basename(self, type: str, basename: str ] return self.find_ast_node_with_type_in_file(type, target_files_nodes) - - def find_ast_node_with_type_in_file_with_relative_path(self, type: str, relative_path: str) -> Tuple[str, List[Dict[str, Any]]]: + def find_ast_node_with_type_in_file_with_relative_path( + self, type: str, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: """Find all ASTNodes with the given type in files with the given relative path.""" # Get file nodes with the given basename target_files_nodes: List[KnowledgeGraphNode] = [ @@ -367,7 +384,6 @@ def find_ast_node_with_type_in_file_with_relative_path(self, type: str, relative ] return self.find_ast_node_with_type_in_file(type, target_files_nodes) - ############################################################################### # TextNode retrieval # ############################################################################### @@ -379,7 +395,9 @@ def find_file_node_of_a_text_node(self, text_node: KnowledgeGraphNode) -> Knowle next_chunk_reverse_map = { edge.target.node_id: edge.source for edge in self.kg.get_next_chunk_edges() } - has_file_node_map = {edge.target.node_id: edge.source for edge in self.kg.get_has_text_edges()} + has_file_node_map = { + edge.target.node_id: edge.source for edge in self.kg.get_has_text_edges() + } # Find the root text node current_text_node = text_node @@ -422,7 +440,6 @@ def find_text_node_with_text(self, text: str) -> Tuple[str, List[Dict[str, Any]] results.sort(key=lambda x: x["TextNode"]["node_id"]) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - def find_text_node_with_text_in_file( self, text: str, basename: str ) -> Tuple[str, List[Dict[str, Any]]]: @@ -461,7 +478,6 @@ def find_text_node_with_text_in_file( results.sort(key=lambda x: x["TextNode"]["node_id"]) return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - def get_next_text_node_with_node_id(self, node_id: int) -> Tuple[str, List[Dict[str, Any]]]: """Get the next TextNode for the given node_id.""" @@ -479,7 +495,9 @@ def get_next_text_node_with_node_id(self, node_id: int) -> Tuple[str, List[Dict[ return format_knowledge_graph_data([]), [] # Get next chunk map - next_chunk_map = {edge.source.node_id: edge.target for edge in self.kg.get_next_chunk_edges()} + next_chunk_map = { + edge.source.node_id: edge.target for edge in self.kg.get_next_chunk_edges() + } # Get the next text node next_text_node = next_chunk_map.get(current_text_node.node_id, None) @@ -507,13 +525,13 @@ def get_next_text_node_with_node_id(self, node_id: int) -> Tuple[str, List[Dict[ ) return format_knowledge_graph_data(results), results - ############################################################################### # Other # ############################################################################### - - def read_code_with_relative_path(self, relative_path: str, start_line: int, end_line: int) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + def read_code_with_relative_path( + self, relative_path: str, start_line: int, end_line: int + ) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: """Read a specific section of a source code file by relative path and line range.""" if end_line < start_line: return f"end_line {end_line} must be greater than start_line {start_line}!", None @@ -535,7 +553,9 @@ def read_code_with_relative_path(self, relative_path: str, start_line: int, end_ # Get the first ast node for this file first_ast_node = [ - edge.target for edge in self.kg.get_has_ast_edges() if edge.source.node_id == target_file.node_id + edge.target + for edge in self.kg.get_has_ast_edges() + if edge.source.node_id == target_file.node_id ][0] text = first_ast_node.node.text lines = text.split("\n") diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 918ea934..55a6266f 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -1,19 +1,18 @@ import os -import shutil -from pathlib import Path -from typing import Annotated -import json -import asyncio from dataclasses import dataclass +from typing import Annotated + from dynaconf.vendor.dotenv import load_dotenv -from pydantic import BaseModel, Field, field_validator -from tavily import TavilyClient, InvalidAPIKeyError, UsageLimitExceededError +from langchain_mcp_adapters.client import MultiServerMCPClient +from pydantic import BaseModel, Field +from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError + from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger -from langchain_mcp_adapters.client import MultiServerMCPClient logger = get_logger(__name__) + @dataclass class ToolSpec: description: str @@ -30,45 +29,47 @@ class ToolSpec: class WebSearchInput(BaseModel): """Base parameters for Tavily search.""" + query: Annotated[str, Field(description="Search query")] + def format_results(response: dict) -> str: - """Format Tavily search results into a readable string.""" - output = [] - - # Add domain filter information if present - if response.get("included_domains") or response.get("excluded_domains"): - filters = [] - if response.get("included_domains"): - filters.append(f"Including domains: {', '.join(response['included_domains'])}") - if response.get("excluded_domains"): - filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") - output.append("Search Filters:") - output.extend(filters) - output.append("") # Empty line for separation - - if response.get("answer"): - output.append(f"Answer: {response['answer']}") - output.append("\nSources:") - # Add immediate source references for the answer - for result in response["results"]: - output.append(f"- {result['title']}: {result['url']}") - output.append("") # Empty line for separation - - output.append("Detailed Results:") + """Format Tavily search results into a readable string.""" + output = [] + + # Add domain filter information if present + if response.get("included_domains") or response.get("excluded_domains"): + filters = [] + if response.get("included_domains"): + filters.append(f"Including domains: {', '.join(response['included_domains'])}") + if response.get("excluded_domains"): + filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") + output.append("Search Filters:") + output.extend(filters) + output.append("") # Empty line for separation + + if response.get("answer"): + output.append(f"Answer: {response['answer']}") + output.append("\nSources:") + # Add immediate source references for the answer for result in response["results"]: - output.append(f"\nTitle: {result['title']}") - output.append(f"URL: {result['url']}") - output.append(f"Content: {result['content']}") - if result.get("published_date"): - output.append(f"Published: {result['published_date']}") - - return "\n".join(output) + output.append(f"- {result['title']}: {result['url']}") + output.append("") # Empty line for separation + + output.append("Detailed Results:") + for result in response["results"]: + output.append(f"\nTitle: {result['title']}") + output.append(f"URL: {result['url']}") + output.append(f"Content: {result['content']}") + if result.get("published_date"): + output.append(f"Published: {result['published_date']}") + + return "\n".join(output) class WebSearchTool: """Tool class for web search functionality.""" - + web_search_spec = ToolSpec( description="""\ Searches the web for technical information to aid in bug analysis and resolution. @@ -80,37 +81,41 @@ class WebSearchTool: Queries should be specific and include relevant keywords like library names, version numbers, and error codes. """, - input_schema=WebSearchInput + input_schema=WebSearchInput, ) - + def __init__(self): """Initialize the web search tool.""" self.tavily_client = tavily_client - - def web_search(self, query: str, max_results: int = 5, - include_domains: list[str] = [ - 'stackoverflow.com', - 'github.com', - 'developer.mozilla.org', - 'learn.microsoft.com', - 'docs.python.org', - 'pydantic.dev', - 'pypi.org', - 'readthedocs.org', - ], - exclude_domains: list[str] = None) -> str: + + def web_search( + self, + query: str, + max_results: int = 5, + include_domains: list[str] = [ + "stackoverflow.com", + "github.com", + "developer.mozilla.org", + "learn.microsoft.com", + "docs.python.org", + "pydantic.dev", + "pypi.org", + "readthedocs.org", + ], + exclude_domains: list[str] = None, + ) -> str: """Search the web for technical information to aid in bug analysis and resolution. - + Args: query: Search query string. max_results: Maximum number of results to return (default: 5). include_domains: List of domains to include in search. exclude_domains: List of domains to exclude from search. - + Returns: Formatted search results as a string. """ - + if tavily_client is None: raise RuntimeError("Tavily API key is not set") try: @@ -125,26 +130,28 @@ def web_search(self, query: str, max_results: int = 5, format_response = format_results(response) self._logger.info(f"web_search format_response: {format_response}") return format_response - except InvalidAPIKeyError: + except InvalidAPIKeyError: raise ValueError("Invalid Tavily API key") except UsageLimitExceededError: raise RuntimeError("Usage limit exceeded") except Exception as e: raise RuntimeError(f"An error occurred: {str(e)}") + async def mcp_web_search(): client = MultiServerMCPClient( - { - "tavily_web_search": { - "transport": "streamable_http", - "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", - } + { + "tavily_web_search": { + "transport": "streamable_http", + "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", } - ) - # 异步获取工具 + } + ) + # 异步获取工具 tools = await client.get_tools() return tools + if __name__ == "__main__": load_dotenv() tavily_api_key = os.getenv("PROMETHEUS_TAVILY_API_KEY") @@ -154,4 +161,4 @@ async def mcp_web_search(): else: tavily_client = TavilyClient(api_key=tavily_api_key) - print(web_search("What is the capital of France?")) \ No newline at end of file + print(web_search("What is the capital of France?")) diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index a6a3413d..3565912b 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -8,39 +8,39 @@ import logging import os import sys +from datetime import datetime from pathlib import Path from typing import Optional -from datetime import datetime from prometheus.configuration.config import settings class ColoredFormatter(logging.Formatter): """Colored log formatter""" - + # ANSI color codes COLORS = { - 'DEBUG': '\033[36m', # Cyan - 'INFO': '\033[32m', # Green - 'WARNING': '\033[33m', # Yellow - 'ERROR': '\033[31m', # Red - 'CRITICAL': '\033[35m', # Purple - 'RESET': '\033[0m' # Reset color + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Purple + "RESET": "\033[0m", # Reset color } - + # Colored level names COLORED_LEVELNAMES = { - 'DEBUG': f'{COLORS["DEBUG"]}DEBUG{COLORS["RESET"]}', - 'INFO': f'{COLORS["INFO"]}INFO{COLORS["RESET"]}', - 'WARNING': f'{COLORS["WARNING"]}WARNING{COLORS["RESET"]}', - 'ERROR': f'{COLORS["ERROR"]}ERROR{COLORS["RESET"]}', - 'CRITICAL': f'{COLORS["CRITICAL"]}CRITICAL{COLORS["RESET"]}', + "DEBUG": f"{COLORS['DEBUG']}DEBUG{COLORS['RESET']}", + "INFO": f"{COLORS['INFO']}INFO{COLORS['RESET']}", + "WARNING": f"{COLORS['WARNING']}WARNING{COLORS['RESET']}", + "ERROR": f"{COLORS['ERROR']}ERROR{COLORS['RESET']}", + "CRITICAL": f"{COLORS['CRITICAL']}CRITICAL{COLORS['RESET']}", } - + def __init__(self, fmt=None, datefmt=None, use_colors=True): """ Initialize colored formatter - + Args: fmt: Log format string datefmt: Date format string @@ -48,15 +48,16 @@ def __init__(self, fmt=None, datefmt=None, use_colors=True): """ super().__init__(fmt, datefmt) self.use_colors = use_colors and self._supports_color() - + def _supports_color(self) -> bool: """Check if terminal supports colors""" # Check if running in a color-supporting terminal return ( - hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() and - sys.platform != 'win32' # Windows may need special handling - ) or 'FORCE_COLOR' in os.environ - + hasattr(sys.stdout, "isatty") + and sys.stdout.isatty() + and sys.platform != "win32" # Windows may need special handling + ) or "FORCE_COLOR" in os.environ + def format(self, record): """Format log record""" if self.use_colors and record.levelname in self.COLORED_LEVELNAMES: @@ -64,13 +65,13 @@ def format(self, record): original_levelname = record.levelname # Use colored level name record.levelname = self.COLORED_LEVELNAMES[record.levelname] - + # Format message formatted = super().format(record) - + # Restore original level name record.levelname = original_levelname - + return formatted else: return super().format(record) @@ -78,142 +79,148 @@ def format(self, record): class LoggerManager: """Logger manager class, responsible for creating and configuring all loggers""" - - _instance: Optional['LoggerManager'] = None + + _instance: Optional["LoggerManager"] = None _initialized: bool = False - - def __new__(cls) -> 'LoggerManager': + + def __new__(cls) -> "LoggerManager": """Singleton pattern implementation""" if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance - + def __init__(self): """Initialize logger manager""" if not self._initialized: self._setup_root_logger() self._initialized = True - + def _setup_root_logger(self): """Setup root logger""" # Get root logger self.root_logger = logging.getLogger("prometheus") - + # Clear existing handlers to avoid duplication self.root_logger.handlers.clear() - + # Set log level - log_level = getattr(settings, 'LOGGING_LEVEL', 'INFO') + log_level = getattr(settings, "LOGGING_LEVEL", "INFO") self.root_logger.setLevel(getattr(logging, log_level)) - + # Create colored formatter for console output self.colored_formatter = ColoredFormatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - + # Create plain formatter for file output self.file_formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - + # Create console handler (using colored formatter) console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(self.colored_formatter) self.root_logger.addHandler(console_handler) - + # Prevent log propagation to parent logger self.root_logger.propagate = False - + # Log configuration information self._log_configuration() - + def _log_configuration(self): """Log configuration information""" # 动态获取settings中所有可用的配置属性 - config_attrs = [attr for attr in dir(settings) - if attr.isupper() and not attr.startswith('_')] - + config_attrs = [ + attr for attr in dir(settings) if attr.isupper() and not attr.startswith("_") + ] + for attr in config_attrs: - value = getattr(settings, attr, 'Not Set') - + value = getattr(settings, attr, "Not Set") + # 使用通配符匹配敏感配置项(包含KEY、API、PASSWORD的) - is_sensitive = any(keyword in attr.upper() for keyword in ['KEY', 'API', 'PASSWORD', "SECRET"]) - + is_sensitive = any( + keyword in attr.upper() for keyword in ["KEY", "API", "PASSWORD", "SECRET"] + ) + # 如果是敏感配置项,用星号代替 - if is_sensitive and value and value != 'Not Set': - masked_value = '*' * min(len(str(value)), 8) # 最多显示8个星号 + if is_sensitive and value and value != "Not Set": + masked_value = "*" * min(len(str(value)), 8) # 最多显示8个星号 self.root_logger.info(f"{attr}={masked_value}") else: self.root_logger.info(f"{attr}={value}") - + def get_logger(self, name: str) -> logging.Logger: """ Get logger with specified name - + Args: name: Logger name, recommended to use full module path - + Returns: Configured logger instance """ # Ensure logger name starts with prometheus if not name.startswith("prometheus"): name = f"prometheus.{name}" - + logger = logging.getLogger(name) - + # If it's a child logger, inherit root logger configuration if name != "prometheus": logger.parent = self.root_logger logger.propagate = True - + return logger - - def create_file_handler(self, log_file_path: Path, logger_name: str = "prometheus") -> logging.FileHandler: + + def create_file_handler( + self, log_file_path: Path, logger_name: str = "prometheus" + ) -> logging.FileHandler: """ Create file handler for specified logger - + Args: log_file_path: Log file path logger_name: Logger name - + Returns: Configured file handler """ # Ensure log directory exists log_file_path.parent.mkdir(parents=True, exist_ok=True) - + # Create file handler (using plain formatter, without colors) file_handler = logging.FileHandler(log_file_path) file_handler.setFormatter(self.file_formatter) - + # Get logger and add handler logger = self.get_logger(logger_name) logger.addHandler(file_handler) - + return file_handler - - def create_timestamped_file_handler(self, log_dir: Path, prefix: str = "prometheus", - logger_name: str = "prometheus") -> logging.FileHandler: + + def create_timestamped_file_handler( + self, log_dir: Path, prefix: str = "prometheus", logger_name: str = "prometheus" + ) -> logging.FileHandler: """ Create file handler with timestamp - + Args: log_dir: Log directory prefix: Log file prefix logger_name: Logger name - + Returns: Configured file handler """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = log_dir / f"{prefix}_{timestamp}.log" return self.create_file_handler(log_file, logger_name) - + def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = "prometheus"): """ Remove file handler - + Args: handler: Handler to remove logger_name: Logger name @@ -221,15 +228,15 @@ def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = " logger = self.get_logger(logger_name) logger.removeHandler(handler) handler.close() - + def enable_colors(self): """Enable colored log output""" self.colored_formatter.use_colors = True and self.colored_formatter._supports_color() - + def disable_colors(self): """Disable colored log output""" self.colored_formatter.use_colors = False - + def is_colors_enabled(self) -> bool: """Check if colored output is enabled""" return self.colored_formatter.use_colors @@ -242,13 +249,13 @@ def is_colors_enabled(self) -> bool: def get_logger(name: str) -> logging.Logger: """ Convenience function to get logger - + Args: name: Logger name, recommended to use __name__ or module path - + Returns: Configured logger instance - + Examples: >>> logger = get_logger(__name__) >>> logger = get_logger("prometheus.tools.web_search") @@ -256,32 +263,34 @@ def get_logger(name: str) -> logging.Logger: return logger_manager.get_logger(name) -def create_file_handler(log_file_path: Path, logger_name: str = "prometheus") -> logging.FileHandler: +def create_file_handler( + log_file_path: Path, logger_name: str = "prometheus" +) -> logging.FileHandler: """ Convenience function to create file handler - + Args: log_file_path: Log file path logger_name: Logger name - + Returns: Configured file handler """ return logger_manager.create_file_handler(log_file_path, logger_name) -def create_timestamped_file_handler(log_dir: Path, prefix: str = "prometheus", - logger_name: str = "prometheus") -> logging.FileHandler: +def create_timestamped_file_handler( + log_dir: Path, prefix: str = "prometheus", logger_name: str = "prometheus" +) -> logging.FileHandler: """ Convenience function to create timestamped file handler - + Args: log_dir: Log directory - prefix: Log file prefix + prefix: Log file prefix logger_name: Logger name - + Returns: Configured file handler """ return logger_manager.create_timestamped_file_handler(log_dir, prefix, logger_name) - diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index b64c0790..a446fa66 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -1,6 +1,7 @@ +from unittest.mock import patch + import pytest -from unittest.mock import Mock, patch -from langchain_core.messages import HumanMessage, AIMessage, ToolMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages.tool import ToolCall from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode @@ -15,13 +16,15 @@ def fake_llm(): @pytest.fixture def fake_llm_with_tool_call(): """LLM that simulates making a web_search tool call.""" - return FakeListChatWithToolsModel(responses=["I need to search for information about this error."]) + return FakeListChatWithToolsModel( + responses=["I need to search for information about this error."] + ) def test_init_issue_bug_analyzer_node(fake_llm): """Test IssueBugAnalyzerNode initialization.""" node = IssueBugAnalyzerNode(fake_llm) - + assert node.system_prompt is not None assert len(node.tools) == 1 # Should have web_search tool assert node.tools[0].name == "web_search" @@ -45,7 +48,9 @@ def test_web_search_tool_integration(fake_llm_with_tool_call): node = IssueBugAnalyzerNode(fake_llm_with_tool_call) state = { "issue_bug_analyzer_messages": [ - HumanMessage(content="I'm getting a ValueError in my Python code. Can you help analyze it?") + HumanMessage( + content="I'm getting a ValueError in my Python code. Can you help analyze it?" + ) ] } @@ -54,32 +59,35 @@ def test_web_search_tool_integration(fake_llm_with_tool_call): # Verify the result contains the response message assert "issue_bug_analyzer_messages" in result assert len(result["issue_bug_analyzer_messages"]) == 1 - assert result["issue_bug_analyzer_messages"][0].content == "I need to search for information about this error." + assert ( + result["issue_bug_analyzer_messages"][0].content + == "I need to search for information about this error." + ) def test_web_search_tool_call_with_correct_parameters(fake_llm): """Test that web_search tool has correct configuration and can be called.""" node = IssueBugAnalyzerNode(fake_llm) - + # Test that the tool exists and has correct configuration web_search_tool = node.tools[0] assert web_search_tool.name == "web_search" assert "technical information" in web_search_tool.description.lower() - + # Test that the tool has the correct args schema - assert hasattr(web_search_tool, 'args_schema') + assert hasattr(web_search_tool, "args_schema") assert web_search_tool.args_schema is not None -@patch('prometheus.tools.web_search.tavily_client') +@patch("prometheus.tools.web_search.tavily_client") def test_web_search_tool_without_api_key(mock_tavily_client, fake_llm): """Test web_search tool behavior when API key is not available.""" # Simulate no API key scenario mock_tavily_client = None - + node = IssueBugAnalyzerNode(fake_llm) web_search_tool = node.tools[0] - + # The tool should still be created but may handle missing API key gracefully assert web_search_tool.name == "web_search" @@ -87,7 +95,7 @@ def test_web_search_tool_without_api_key(mock_tavily_client, fake_llm): def test_system_prompt_contains_web_search_info(fake_llm): """Test that the system prompt mentions web_search tool.""" node = IssueBugAnalyzerNode(fake_llm) - + system_prompt_content = node.system_prompt.content.lower() assert "web_search" in system_prompt_content assert "technical information" in system_prompt_content @@ -97,11 +105,11 @@ def test_web_search_tool_schema_validation(fake_llm): """Test that the web_search tool has proper input validation.""" node = IssueBugAnalyzerNode(fake_llm) web_search_tool = node.tools[0] - + # Check that the tool has an args_schema - assert hasattr(web_search_tool, 'args_schema') + assert hasattr(web_search_tool, "args_schema") assert web_search_tool.args_schema is not None - + # Test with valid input valid_input = {"query": "Python debugging techniques"} # This should not raise an exception @@ -112,17 +120,25 @@ def test_web_search_tool_schema_validation(fake_llm): def test_multiple_tool_calls_in_conversation(fake_llm): """Test handling multiple web_search calls in a conversation.""" node = IssueBugAnalyzerNode(fake_llm) - + # Simulate a conversation with tool calls state = { "issue_bug_analyzer_messages": [ HumanMessage(content="Analyze this bug: ImportError in my application"), AIMessage( content="Let me search for information about this error.", - tool_calls=[ToolCall(name="web_search", args={"query": "Python ImportError debugging"}, id="call_1")] + tool_calls=[ + ToolCall( + name="web_search", + args={"query": "Python ImportError debugging"}, + id="call_1", + ) + ], + ), + ToolMessage( + content="Search results: ImportError occurs when...", tool_call_id="call_1" ), - ToolMessage(content="Search results: ImportError occurs when...", tool_call_id="call_1"), - HumanMessage(content="The error still persists after trying the suggested fixes") + HumanMessage(content="The error still persists after trying the suggested fixes"), ] } diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py index 0e8d10c4..0ddfc70e 100644 --- a/tests/tools/test_mcp_client.py +++ b/tests/tools/test_mcp_client.py @@ -1,75 +1,54 @@ import asyncio -from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.graph import StateGraph, MessagesState, START -from langgraph.prebuilt import ToolNode, tools_condition -from langchain_core.messages import AIMessage, ToolMessage -from prometheus.app.services.llm_service import LLMService, get_model -from langchain.tools import StructuredTool -import functools -# 使用真实模型进行工具调用 -from prometheus.configuration.config import settings -import json -import re - -import asyncio -import inspect -import json -from copy import copy from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Literal, - Optional, - Sequence, - Tuple, - Type, Union, cast, - get_type_hints, ) from langchain_core.messages import ( - AIMessage, - AnyMessage, ToolCall, ToolMessage, ) from langchain_core.runnables import RunnableConfig -from langchain_core.runnables.config import ( - get_config_list, - get_executor_for_config, +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.errors import GraphInterrupt +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition +from langgraph.prebuilt.tool_node import ( + _handle_tool_error, + _infer_handled_types, + msg_content_output, ) -from langchain_core.runnables.utils import Input -from langchain_core.tools import BaseTool, InjectedToolArg -from langchain_core.tools import tool as create_tool -from langchain_core.tools.base import get_all_basemodel_annotations -from typing_extensions import Annotated, get_args, get_origin -from langgraph.errors import GraphInterrupt -from langgraph.store.base import BaseStore -from langgraph.utils.runnable import RunnableCallable -from langgraph.prebuilt.tool_node import msg_content_output, _infer_handled_types, _handle_tool_error +from prometheus.app.services.llm_service import get_model +# 使用真实模型进行工具调用 +from prometheus.configuration.config import settings - # 创建自定义 ToolNode +# 创建自定义 ToolNode preset_params = { "tavily-search": { "include_domains": ["pypi.org", "docs.python.org"], - "exclude_domains": ["stackoverflow.com", "*huggingface*", "discourse.slicer.org","ask.csdn.net", - "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"], + "exclude_domains": [ + "stackoverflow.com", + "*huggingface*", + "discourse.slicer.org", + "ask.csdn.net", + "codepudding.com", + "*geeksforgeeks*", + "*github*", + "forum.developer.parrot.com", + ], } } + class CustomToolNode(ToolNode): """自定义 ToolNode,支持为特定工具添加预设参数""" - + def __init__(self, tools, preset_params=None, **kwargs): super().__init__(tools, **kwargs) self.preset_params = preset_params or {} - + async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message @@ -77,7 +56,7 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage try: # 构建基础输入 input = {**call, **{"type": "tool_call"}} - + # 如果这个工具有预设参数,则添加到输入中 if call["name"] in self.preset_params: preset_for_tool = self.preset_params[call["name"]] @@ -86,13 +65,11 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage input["args"] = merged_args print(f"🔧 为工具 {call['name']} 添加预设参数: {preset_for_tool}") print(f"🔧 最终参数: {merged_args}") - + tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) - tool_message.content = cast( - Union[str, list], msg_content_output(tool_message.content) - ) + tool_message.content = cast(Union[str, list], msg_content_output(tool_message.content)) return tool_message except GraphInterrupt as e: raise e @@ -114,15 +91,15 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage ) - async def main(): # 获取 Tavily API key tavily_api_key = settings.get("TAVILY_API_KEY", None) if tavily_api_key is None: print("错误: 未设置 TAVILY_API_KEY") return - - model = get_model("gpt-4o-mini", + + model = get_model( + "gpt-4o-mini", openai_format_api_key=settings.get("OPENAI_FORMAT_API_KEY", None), openai_format_base_url=settings.get("OPENAI_FORMAT_BASE_URL", None), anthropic_api_key=None, @@ -131,20 +108,17 @@ async def main(): max_output_tokens=15000, ) - - - async def init_tool(): + async def init_tool(): # 使用 HTTP 传输直接连接到 Tavily MCP 服务器 client = MultiServerMCPClient( - { + { "tavily_web_search": { "transport": "streamable_http", "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", } } ) - - + # 异步获取工具 tools = await client.get_tools() print(f"获取到的工具: {[tool.name for tool in tools]}") @@ -159,41 +133,39 @@ async def init_tool(): # properties[param_name]['default'] = ["pypi.org", "docs.python.org"] # print(f" ✅ 设置 {param_name} 默认值: include domains") # elif re.search(r'exclude.*domain', param_lower): - # properties[param_name]['default'] = ["stackoverflow.com", "*huggingface", "discourse.slicer.org","ask.csdn.net", + # properties[param_name]['default'] = ["stackoverflow.com", "*huggingface", "discourse.slicer.org","ask.csdn.net", # "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"] # print(f" ✅ 设置 {param_name} 默认值: exclude domains") - + # elif re.search(r'search.*depth', param_lower): # properties[param_name]['default'] = "advanced" # print(f" ✅ 设置 {param_name} 默认值: advanced") return tools - tools = await init_tool() async def call_model(state: MessagesState): messages = state["messages"] - print(f"\n=== call_model 被调用 ===") + print("\n=== call_model 被调用 ===") print(f"输入消息数量: {len(messages)}") - + print(f"可用工具: {[tool.name for tool in tools]}") - + # 使用真实模型调用,绑定预设参数的工具 model_with_tools = model.bind_tools(tools) print("开始调用模型...") - + response = await model_with_tools.ainvoke(messages) print(f"模型响应类型: {type(response)}") - - + return {"messages": [response]} - + # 创建工具节点 builder = StateGraph(MessagesState) builder.add_node("call_model", call_model) # builder.add_node("tools", CustomToolNode(tools, preset_params=preset_params)) builder.add_node("tools", ToolNode(tools)) - + # 构建图 builder.add_edge(START, "call_model") builder.add_conditional_edges( @@ -201,9 +173,9 @@ async def call_model(state: MessagesState): tools_condition, ) builder.add_edge("tools", "call_model") - + graph = builder.compile() - + # 执行测试 - 演示如何传递 include_domains 等参数 # 注意:参数会在工具调用时由 LLM 自动传递,这里展示一个需要特定域名搜索的查询 test_query = """ @@ -226,11 +198,11 @@ async def call_model(state: MessagesState): response = await graph.ainvoke({"messages": system_prompt + "\n" + test_query}) # print("Response:", response) - + return response # 运行异步主函数 if __name__ == "__main__": result = asyncio.run(main()) - print(result['messages'][-1].content) \ No newline at end of file + print(result["messages"][-1].content) diff --git a/tests/tools/test_mcp_client_config.py b/tests/tools/test_mcp_client_config.py index 43082af6..bfa10e40 100644 --- a/tests/tools/test_mcp_client_config.py +++ b/tests/tools/test_mcp_client_config.py @@ -1,72 +1,67 @@ import asyncio import json -import tempfile import os -from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.graph import StateGraph, MessagesState, START -from langgraph.prebuilt import ToolNode, tools_condition -from langchain_core.messages import AIMessage, ToolMessage # 使用项目中的自定义模拟模型,支持工具调用 import sys +import tempfile + +from langchain_core.messages import AIMessage, ToolMessage +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + sys.path.append("/root/lix/Prometheus/") from tests.test_utils.util import FakeListChatWithToolsModel + async def main(): # 可以动态设置多个配置参数 - config = { - "driver": "neo4j://enterprise-cluster:7687", - "timeout": 120, - "max_retries": 10 - } - + config = {"driver": "neo4j://enterprise-cluster:7687", "timeout": 120, "max_retries": 10} + # 创建临时配置文件 - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(config, f, indent=2) config_file_path = f.name - + try: client = MultiServerMCPClient( - { + { "weather": { "command": "python", "args": ["/root/lix/Prometheus/tests/tools/config_based_mcp_tools.py"], "transport": "stdio", "env": { "MCP_WEATHER_CONFIG": config_file_path # 通过环境变量传递配置文件路径 - } + }, } } ) - + # 异步获取工具 tools = await client.get_tools() print(f"获取到的工具: {[tool.name for tool in tools]}") - + # 使用支持工具的模拟模型 model = FakeListChatWithToolsModel(responses=["I need to check the weather for NYC"]) - + # 创建工具节点 tool_node = ToolNode(tools) - + def call_model(state: MessagesState): messages = state["messages"] - + # 检查是否已经有工具消息,如果有就结束 if any(isinstance(msg, ToolMessage) for msg in messages): return {"messages": [AIMessage(content="Weather check completed!")]} - + # 第一次调用时创建工具调用响应 response = AIMessage( content="Let me check the weather for you", - tool_calls=[{ - "name": "get_weather", - "args": {"location": "nyc"}, - "id": "call_1" - }] + tool_calls=[{"name": "get_weather", "args": {"location": "nyc"}, "id": "call_1"}], ) return {"messages": [response]} - + # 构建图 builder = StateGraph(MessagesState) builder.add_node("call_model", call_model) @@ -77,21 +72,22 @@ def call_model(state: MessagesState): tools_condition, ) builder.add_edge("tools", "call_model") - + graph = builder.compile() - + # 执行测试 weather_response = await graph.ainvoke({"messages": "what is the weather in nyc?"}) print("Response:", weather_response) - + return weather_response - + finally: # 清理临时配置文件 if os.path.exists(config_file_path): os.unlink(config_file_path) print(f"🗑️ 清理临时配置文件: {config_file_path}") + # 运行异步主函数 if __name__ == "__main__": result = asyncio.run(main()) diff --git a/tests/tools/test_mcp_server.py b/tests/tools/test_mcp_server.py index d6deb0b7..5009cf5b 100644 --- a/tests/tools/test_mcp_server.py +++ b/tests/tools/test_mcp_server.py @@ -1,14 +1,13 @@ -from mcp.server.fastmcp import FastMCP +import asyncio +import json import os import sys -import json -import asyncio from pathlib import Path -from typing import Dict, List, Any, Optional, Set +from typing import Any, Dict, List, Optional, Set import yaml from langchain_mcp_adapters.client import MultiServerMCPClient - +from mcp.server.fastmcp import FastMCP # ========================================== # MCP Server (existing behavior preserved) @@ -96,7 +95,9 @@ def _load_server_configs() -> Dict[str, Dict[str, Any]]: } -def _build_client(server_configs: Optional[Dict[str, Dict[str, Any]]] = None) -> MultiServerMCPClient: +def _build_client( + server_configs: Optional[Dict[str, Dict[str, Any]]] = None, +) -> MultiServerMCPClient: global _CLIENT_CACHE if _CLIENT_CACHE is not None: return _CLIENT_CACHE @@ -143,6 +144,7 @@ def build_default_node_tool_client() -> MultiServerMCPClient: # 确保在启动前注册所有工具 sys.path.append("~/lix/Prometheus") import prometheus.tools # noqa: F401 + mcp.run(transport="stdio") async def main(): diff --git a/tests/tools/test_mcp_tools.py b/tests/tools/test_mcp_tools.py index 9a9e9bf6..1c3f39a7 100644 --- a/tests/tools/test_mcp_tools.py +++ b/tests/tools/test_mcp_tools.py @@ -1,14 +1,16 @@ # weather_server.py import sys + from mcp.server.fastmcp import FastMCP mcp = FastMCP("Weather") + class WeatherTools: driver: str timeout: int max_retries: int - + @classmethod def configure(cls, **kwargs): """动态配置工具参数""" @@ -31,6 +33,7 @@ async def get_temperature(location: str) -> str: """Get temperature for a location.""" return f"[Driver={WeatherTools.driver}, Retries={WeatherTools.max_retries}] Temperature in {location} is 25°C" + def parse_args(args): """解析命令行参数为 kwargs""" kwargs = {} @@ -43,8 +46,8 @@ def parse_args(args): # 尝试转换数据类型 if value.isdigit(): value = int(value) - elif value.lower() in ['true', 'false']: - value = value.lower() == 'true' + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" kwargs[key] = value i += 2 else: @@ -54,6 +57,7 @@ def parse_args(args): i += 1 return kwargs + if __name__ == "__main__": # 动态解析命令行参数 # 支持格式:python test_mcp_tools.py --driver neo4j://server --timeout 60 --max_retries 5 @@ -61,4 +65,4 @@ def parse_args(args): if config: WeatherTools.configure(**config) # 启动 MCP - mcp.run(transport="stdio") \ No newline at end of file + mcp.run(transport="stdio") diff --git a/tests/tools/test_mcp_web_search.py b/tests/tools/test_mcp_web_search.py index 7ff1de5b..f3b104c0 100644 --- a/tests/tools/test_mcp_web_search.py +++ b/tests/tools/test_mcp_web_search.py @@ -1,24 +1,26 @@ -import os -import aiohttp -import asyncio -from pathlib import Path -from typing import Annotated, Optional import json from dataclasses import dataclass -from dynaconf.vendor.dotenv import load_dotenv -from pydantic import BaseModel, Field, field_validator -from prometheus.configuration.config import settings +from typing import Annotated, Optional + +import aiohttp from mcp.server.fastmcp import FastMCP +from pydantic import BaseModel, Field + +from prometheus.app.services.llm_service import get_model +from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger -from prometheus.app.services.llm_service import LLMService, get_model logger = get_logger(__name__) + + @dataclass class MCPToolSpec: description: str input_schema: type -model = get_model("gpt-4o-mini", + +model = get_model( + "gpt-4o-mini", openai_format_api_key=settings.get("OPENAI_API_KEY", None), openai_format_base_url=settings.get("OPENAI_BASE_URL", None), anthropic_api_key=None, @@ -42,16 +44,21 @@ class MCPToolSpec: class WebSearchInput(BaseModel): """Input parameters for web search.""" + query: Annotated[str, Field(description="Search query string")] max_results: Annotated[int, Field(description="Maximum number of results", default=5)] - include_domains: Annotated[Optional[list[str]], Field(description="List of domains to include", default=None)] - exclude_domains: Annotated[Optional[list[str]], Field(description="List of domains to exclude", default=None)] + include_domains: Annotated[ + Optional[list[str]], Field(description="List of domains to include", default=None) + ] + exclude_domains: Annotated[ + Optional[list[str]], Field(description="List of domains to exclude", default=None) + ] def format_results(response: dict) -> str: """Format Tavily search results into a readable string.""" output = [] - + # Add domain filter information if present if response.get("included_domains") or response.get("excluded_domains"): filters = [] @@ -62,7 +69,7 @@ def format_results(response: dict) -> str: output.append("Search Filters:") output.extend(filters) output.append("") # Empty line for separation - + # Add answer if present if response.get("answer"): output.append(f"Answer: {response['answer']}") @@ -71,7 +78,7 @@ def format_results(response: dict) -> str: for result in response.get("results", []): output.append(f"- {result.get('title', 'No title')}: {result.get('url', 'No URL')}") output.append("") # Empty line for separation - + # Add detailed results output.append("Detailed Results:") for result in response.get("results", []): @@ -80,13 +87,13 @@ def format_results(response: dict) -> str: output.append(f"Content: {result.get('content', 'No content')}") if result.get("published_date"): output.append(f"Published: {result['published_date']}") - + return "\n".join(output) class MCPWebSearchTool: """Web search tool class.""" - + web_search_spec = MCPToolSpec( description="""\ Searches the web for technical information to aid in bug analysis and resolution. @@ -98,7 +105,7 @@ class MCPWebSearchTool: Queries should be specific and include relevant keywords like library names, version numbers, and error codes. """, - input_schema=WebSearchInput + input_schema=WebSearchInput, ) @@ -126,27 +133,27 @@ async def web_search( include_domains: List of domains to include (default: technical documentation sites) exclude_domains: List of domains to exclude """ - + # Check if API key is available if tavily_api_key is None: return "Error: Tavily API key is not set" - + # Default technical search domains if include_domains is None: include_domains = [ - 'stackoverflow.com', - 'github.com', - 'developer.mozilla.org', - 'learn.microsoft.com', - 'docs.python.org', - 'pydantic.dev', - 'pypi.org', - 'readthedocs.org', - 'docs.djangoproject.com', - 'flask.palletsprojects.com', - 'fastapi.tiangolo.com' + "stackoverflow.com", + "github.com", + "developer.mozilla.org", + "learn.microsoft.com", + "docs.python.org", + "pydantic.dev", + "pypi.org", + "readthedocs.org", + "docs.djangoproject.com", + "flask.palletsprojects.com", + "fastapi.tiangolo.com", ] - + # Build request payload payload = { "query": query, @@ -154,10 +161,10 @@ async def web_search( "include_domains": include_domains or [], "exclude_domains": exclude_domains or [], } - + try: logger.info(f"Executing web search, query: {query}") - + # Use aiohttp to send HTTP request to MCP server async with aiohttp.ClientSession() as session: async with session.post(TAVILY_SERVER_URL, json=payload) as resp: @@ -165,15 +172,15 @@ async def web_search( error_msg = f"HTTP error {resp.status}: {await resp.text()}" logger.error(error_msg) return error_msg - + data = await resp.json() - + # Format response formatted_response = format_results(data) logger.info(f"Web search completed, returned {len(data.get('results', []))} results") - + return formatted_response - + except aiohttp.ClientError as e: error_msg = f"Network request error: {str(e)}" logger.error(error_msg) @@ -197,12 +204,12 @@ def run_mcp_server(): if __name__ == "__main__": # # Load environment variables # load_dotenv() - + # # Get Tavily API key # tavily_api_key = settings.get("TAVILY_API_KEY", None) # if tavily_api_key is None: # logger.warning("Tavily API key is not set") # TAVILY_SERVER_URL = f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}" - + # Run server run_mcp_server() From 1aa0c89e23371ac290ea551fbfd6a006abf2efa9 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 4 Sep 2025 22:12:02 +0800 Subject: [PATCH 14/30] feat: Add Tavily API key configuration to Docker Compose files and update web search tool initialization --- .github/workflows/pytest_and_coverage.yml | 3 + docker-compose.win_mac.yml | 3 + docker-compose.yml | 3 + prometheus/app/services/database_service.py | 1 - prometheus/graph_config/node_tools.yml | 162 ------------- .../nodes/issue_bug_analyzer_node.py | 38 ---- prometheus/tools/web_search.py | 78 ++----- .../nodes/test_issue_bug_analyzer_node.py | 4 +- tests/tools/test_mcp_client.py | 208 ----------------- tests/tools/test_mcp_client_config.py | 93 -------- tests/tools/test_mcp_server.py | 155 ------------- tests/tools/test_mcp_tools.py | 68 ------ tests/tools/test_mcp_web_search.py | 215 ------------------ 13 files changed, 35 insertions(+), 996 deletions(-) delete mode 100644 prometheus/graph_config/node_tools.yml delete mode 100644 tests/tools/test_mcp_client.py delete mode 100644 tests/tools/test_mcp_client_config.py delete mode 100644 tests/tools/test_mcp_server.py delete mode 100644 tests/tools/test_mcp_tools.py delete mode 100644 tests/tools/test_mcp_web_search.py diff --git a/.github/workflows/pytest_and_coverage.yml b/.github/workflows/pytest_and_coverage.yml index 7cb16920..eb7adf12 100644 --- a/.github/workflows/pytest_and_coverage.yml +++ b/.github/workflows/pytest_and_coverage.yml @@ -53,6 +53,9 @@ jobs: # GitHub settings PROMETHEUS_GITHUB_ACCESS_TOKEN: github_access_token + # Tavily API key + PROMETHEUS_TAVILY_API_KEY: tavily_api_key + # DATABASE settings PROMETHEUS_DATABASE_URL: postgresql://postgres:password@localhost:5432/postgres?sslmode=disable diff --git a/docker-compose.win_mac.yml b/docker-compose.win_mac.yml index 0f3e2087..9be33dd6 100644 --- a/docker-compose.win_mac.yml +++ b/docker-compose.win_mac.yml @@ -72,6 +72,9 @@ services: - PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS=${PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS} - PROMETHEUS_BASE_MODEL_TEMPERATURE=${PROMETHEUS_BASE_MODEL_TEMPERATURE} + # Tavily API key + - PROMETHEUS_TAVILY_API_KEY=${PROMETHEUS_TAVILY_API_KEY} + # Database settings - PROMETHEUS_DATABASE_URL=${PROMETHEUS_DATABASE_URL} diff --git a/docker-compose.yml b/docker-compose.yml index eebe3e81..a16d220b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -93,6 +93,9 @@ services: - PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS=${PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS} - PROMETHEUS_BASE_MODEL_TEMPERATURE=${PROMETHEUS_BASE_MODEL_TEMPERATURE} + # Tavily API key + - PROMETHEUS_TAVILY_API_KEY=${PROMETHEUS_TAVILY_API_KEY} + # Database settings - PROMETHEUS_DATABASE_URL=${PROMETHEUS_DATABASE_URL} diff --git a/prometheus/app/services/database_service.py b/prometheus/app/services/database_service.py index 4d889a4e..1280f665 100644 --- a/prometheus/app/services/database_service.py +++ b/prometheus/app/services/database_service.py @@ -1,4 +1,3 @@ - from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel diff --git a/prometheus/graph_config/node_tools.yml b/prometheus/graph_config/node_tools.yml deleted file mode 100644 index bca51965..00000000 --- a/prometheus/graph_config/node_tools.yml +++ /dev/null @@ -1,162 +0,0 @@ -nodes: - - name: BugFixVerificationSubgraphNode - tools: [] - - - name: BugFixVerifyNode - class: - - ContainerCommandTool: - tools: - - run_command - - - name: BugFixVerifyStructuredNode - tools: [] - - - name: BugReproducingExecuteNode - class: - - ContainerCommandTool: - tools: - - run_command - - - name: BugReproducingFileNode - class: - - FileOperationTool: - tools: - - read_file - - create_file - - - name: BugReproducingStructuredNode - tools: [] - - - name: BugReproducingWriteMessageNode - tools: [] - - - name: BugReproducingWriteNode - class: - - FileOperationTool: - tools: - - read_file - - - name: BugReproductionSubgraphNode - tools: [] - - - name: BuildAndTestSubgraphNode - tools: [] - - - name: ContextExtractionNode - tools: [] - - - name: ContextProviderNode - class: - - GraphTraversalTool: - tools: - - find_file_node_with_basename - - find_file_node_with_relative_path - - find_ast_node_with_text_in_file_with_basename - - find_ast_node_with_text_in_file_with_relative_path - - find_text_node_with_text - - find_text_node_with_text_in_file - - get_next_text_node_with_node_id - - preview_file_content_with_basename - - preview_file_content_with_relative_path - - read_code_with_basename - - read_code_with_relative_path - - - name: ContextQueryMessageNode - tools: [] - - - name: ContextRefineNode - tools: [] - - - name: ContextRetrievalSubgraphNode - tools: [] - - - name: EditMessageNode - tools: [] - - - name: EditNode - class: - - FileOperationTool: - tools: - - read_file - - read_file_with_line_numbers - - create_file - - delete - - edit_file - - - name: FinalPatchSelectionNode - tools: [] - - - name: GeneralBuildNode - class: - - ContainerCommandTool: - tools: - - run_command - - - name: GeneralBuildStructuredNode - tools: [] - - - name: GeneralTestNode - class: - - ContainerCommandTool: - tools: - - run_command - - - name: GeneralTestStructuredNode - tools: [] - - - name: GitDiffNode - tools: [] - - - name: GitResetNode - tools: [] - - - name: IssueBugAnalyzerMessageNode - tools: [] - - - name: IssueBugAnalyzerNode - class: - - WebSearchTool: - tools: - - web_search - - - name: IssueBugContextMessageNode - tools: [] - - - name: IssueBugReproductionContextMessageNode - tools: [] - - - name: IssueBugResponderNode - tools: [] - - - name: IssueBugSubgraphNode - tools: [] - - - name: IssueClassificationContextMessageNode - tools: [] - - - name: IssueClassificationSubgraphNode - tools: [] - - - name: IssueClassifierNode - tools: [] - - - name: IssueNotVerifiedBugSubgraphNode - tools: [] - - - name: IssueVerifiedBugSubgraphNode - tools: [] - - - name: NoopNode - tools: [] - - - name: ResetMessagesNode - tools: [] - - - name: UpdateContainerNode - tools: [] - - - name: UserDefinedBuildNode - tools: [] - - - name: UserDefinedTestNode - tools: [] diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index 899471d9..ae568bc0 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -11,43 +11,6 @@ class IssueBugAnalyzerNode: - # SYS_PROMPT = """\ - # You are an expert software engineer specializing in bug analysis and fixes. Your role is to: - - # 1. Carefully analyze reported software issues and bugs by: - # - Understanding issue descriptions and symptoms - # - Identifying affected code components - # - Tracing problematic execution paths - - # 2. Determine root causes through systematic investigation: - # - Analyze why the current behavior deviates from expected - # - Identify which specific code elements are responsible - # - Understand the context and interactions causing the issue - - # 3. Provide high-level fix suggestions by describing: - # - Which specific files need modification - # - Which functions or code blocks need changes - # - What logical changes are needed (e.g., "variable x needs to be renamed to y", "need to add validation for parameter z") - # - Why these changes would resolve the issue - - # 4. For patch failures, analyze by: - # - Understanding error messages and test failures - # - Identifying what went wrong with the previous attempt - # - Suggesting revised high-level changes that avoid the previous issues - - # Tools available: - # - web_search: Searches the web for technical information to aid in bug analysis and resolution. - - # Important: - # - Do NOT provide actual code snippets or diffs - # - DO provide clear file paths and function names where changes are needed - # - Focus on describing WHAT needs to change and WHY, not HOW to change it - # - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes - - # Communicate in a clear, technical manner focused on accurate analysis and practical suggestions - # rather than implementation details. - # """ - SYS_PROMPT = """\ You are an expert software engineer specializing in bug analysis and fixes. Your role is to: @@ -106,7 +69,6 @@ def __init__(self, model: BaseChatModel): self.web_search_tool = WebSearchTool() self.model = model self.system_prompt = SystemMessage(self.SYS_PROMPT) - # self.tools = asyncio.run(mcp_web_search()) # mcp mode self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 55a6266f..5e59f8c8 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -1,17 +1,12 @@ -import os from dataclasses import dataclass from typing import Annotated -from dynaconf.vendor.dotenv import load_dotenv -from langchain_mcp_adapters.client import MultiServerMCPClient from pydantic import BaseModel, Field from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError from prometheus.configuration.config import settings from prometheus.utils.logger_manager import get_logger -logger = get_logger(__name__) - @dataclass class ToolSpec: @@ -19,14 +14,6 @@ class ToolSpec: input_schema: type -tavily_api_key = settings.TAVILY_API_KEY -if tavily_api_key is None: - logger.warning("Tavily API key is not set") - tavily_client = None -else: - tavily_client = TavilyClient(api_key=tavily_api_key) - - class WebSearchInput(BaseModel): """Base parameters for Tavily search.""" @@ -86,22 +73,22 @@ class WebSearchTool: def __init__(self): """Initialize the web search tool.""" + # Load environment variables from .env file + self._logger = get_logger(__name__) + + tavily_api_key = settings.TAVILY_API_KEY + if tavily_api_key is None: + self._logger.warning("Tavily API key is not set") + tavily_client = None + else: + tavily_client = TavilyClient(api_key=tavily_api_key) self.tavily_client = tavily_client def web_search( self, query: str, max_results: int = 5, - include_domains: list[str] = [ - "stackoverflow.com", - "github.com", - "developer.mozilla.org", - "learn.microsoft.com", - "docs.python.org", - "pydantic.dev", - "pypi.org", - "readthedocs.org", - ], + include_domains=None, exclude_domains: list[str] = None, ) -> str: """Search the web for technical information to aid in bug analysis and resolution. @@ -116,16 +103,27 @@ def web_search( Formatted search results as a string. """ - if tavily_client is None: + if include_domains is None: + include_domains = [ + "stackoverflow.com", + "github.com", + "developer.mozilla.org", + "learn.microsoft.com", + "docs.python.org", + "pydantic.dev", + "pypi.org", + "readthedocs.org", + ] + if self.tavily_client is None: raise RuntimeError("Tavily API key is not set") try: - response = tavily_client.search( + response = self.tavily_client.search( query=query, max_results=max_results, search_depth="advanced", include_answer=True, - include_domains=include_domains or [], # Convert None to empty list - exclude_domains=exclude_domains or [], # Convert None to empty list + include_domains=include_domains or [], # Convert None to an empty list + exclude_domains=exclude_domains or [], # Convert None to an empty list ) format_response = format_results(response) self._logger.info(f"web_search format_response: {format_response}") @@ -136,29 +134,3 @@ def web_search( raise RuntimeError("Usage limit exceeded") except Exception as e: raise RuntimeError(f"An error occurred: {str(e)}") - - -async def mcp_web_search(): - client = MultiServerMCPClient( - { - "tavily_web_search": { - "transport": "streamable_http", - "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", - } - } - ) - # 异步获取工具 - tools = await client.get_tools() - return tools - - -if __name__ == "__main__": - load_dotenv() - tavily_api_key = os.getenv("PROMETHEUS_TAVILY_API_KEY") - if tavily_api_key is None: - logger.warning("Tavily API key is not set") - tavily_client = None - else: - tavily_client = TavilyClient(api_key=tavily_api_key) - - print(web_search("What is the capital of France?")) diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index a446fa66..70951f3b 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -80,11 +80,9 @@ def test_web_search_tool_call_with_correct_parameters(fake_llm): @patch("prometheus.tools.web_search.tavily_client") -def test_web_search_tool_without_api_key(mock_tavily_client, fake_llm): +def test_web_search_tool_without_api_key(fake_llm): """Test web_search tool behavior when API key is not available.""" # Simulate no API key scenario - mock_tavily_client = None - node = IssueBugAnalyzerNode(fake_llm) web_search_tool = node.tools[0] diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py deleted file mode 100644 index 0ddfc70e..00000000 --- a/tests/tools/test_mcp_client.py +++ /dev/null @@ -1,208 +0,0 @@ -import asyncio -from typing import ( - Union, - cast, -) - -from langchain_core.messages import ( - ToolCall, - ToolMessage, -) -from langchain_core.runnables import RunnableConfig -from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.errors import GraphInterrupt -from langgraph.graph import START, MessagesState, StateGraph -from langgraph.prebuilt import ToolNode, tools_condition -from langgraph.prebuilt.tool_node import ( - _handle_tool_error, - _infer_handled_types, - msg_content_output, -) - -from prometheus.app.services.llm_service import get_model - -# 使用真实模型进行工具调用 -from prometheus.configuration.config import settings - -# 创建自定义 ToolNode -preset_params = { - "tavily-search": { - "include_domains": ["pypi.org", "docs.python.org"], - "exclude_domains": [ - "stackoverflow.com", - "*huggingface*", - "discourse.slicer.org", - "ask.csdn.net", - "codepudding.com", - "*geeksforgeeks*", - "*github*", - "forum.developer.parrot.com", - ], - } -} - - -class CustomToolNode(ToolNode): - """自定义 ToolNode,支持为特定工具添加预设参数""" - - def __init__(self, tools, preset_params=None, **kwargs): - super().__init__(tools, **kwargs) - self.preset_params = preset_params or {} - - async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message - - try: - # 构建基础输入 - input = {**call, **{"type": "tool_call"}} - - # 如果这个工具有预设参数,则添加到输入中 - if call["name"] in self.preset_params: - preset_for_tool = self.preset_params[call["name"]] - # 预设参数优先级较低,不会覆盖用户传递的参数 - merged_args = {**preset_for_tool, **call.get("args", {})} - input["args"] = merged_args - print(f"🔧 为工具 {call['name']} 添加预设参数: {preset_for_tool}") - print(f"🔧 最终参数: {merged_args}") - - tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( - input, config - ) - tool_message.content = cast(Union[str, list], msg_content_output(tool_message.content)) - return tool_message - except GraphInterrupt as e: - raise e - except Exception as e: - # 使用父类的错误处理逻辑 - if isinstance(self.handle_tool_errors, tuple): - handled_types: tuple = self.handle_tool_errors - elif callable(self.handle_tool_errors): - handled_types = _infer_handled_types(self.handle_tool_errors) - else: - handled_types = (Exception,) - - if not self.handle_tool_errors or not isinstance(e, handled_types): - raise e - else: - content = _handle_tool_error(e, flag=self.handle_tool_errors) - return ToolMessage( - content=content, name=call["name"], tool_call_id=call["id"], status="error" - ) - - -async def main(): - # 获取 Tavily API key - tavily_api_key = settings.get("TAVILY_API_KEY", None) - if tavily_api_key is None: - print("错误: 未设置 TAVILY_API_KEY") - return - - model = get_model( - "gpt-4o-mini", - openai_format_api_key=settings.get("OPENAI_FORMAT_API_KEY", None), - openai_format_base_url=settings.get("OPENAI_FORMAT_BASE_URL", None), - anthropic_api_key=None, - gemini_api_key=None, - temperature=0.0, - max_output_tokens=15000, - ) - - async def init_tool(): - # 使用 HTTP 传输直接连接到 Tavily MCP 服务器 - client = MultiServerMCPClient( - { - "tavily_web_search": { - "transport": "streamable_http", - "url": f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}", - } - } - ) - - # 异步获取工具 - tools = await client.get_tools() - print(f"获取到的工具: {[tool.name for tool in tools]}") - # for tool in tools: - # print(f"\n工具名称: {tool.name}") - # if hasattr(tool, 'args_schema') and tool.args_schema: - # properties = tool.args_schema.get('properties', {}) - # # 简单的正则匹配设置默认值 - # for param_name in properties.keys(): - # param_lower = param_name.lower() - # if re.search(r'include.*domain', param_lower): - # properties[param_name]['default'] = ["pypi.org", "docs.python.org"] - # print(f" ✅ 设置 {param_name} 默认值: include domains") - # elif re.search(r'exclude.*domain', param_lower): - # properties[param_name]['default'] = ["stackoverflow.com", "*huggingface", "discourse.slicer.org","ask.csdn.net", - # "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"] - # print(f" ✅ 设置 {param_name} 默认值: exclude domains") - - # elif re.search(r'search.*depth', param_lower): - # properties[param_name]['default'] = "advanced" - # print(f" ✅ 设置 {param_name} 默认值: advanced") - return tools - - tools = await init_tool() - - async def call_model(state: MessagesState): - messages = state["messages"] - print("\n=== call_model 被调用 ===") - print(f"输入消息数量: {len(messages)}") - - print(f"可用工具: {[tool.name for tool in tools]}") - - # 使用真实模型调用,绑定预设参数的工具 - model_with_tools = model.bind_tools(tools) - print("开始调用模型...") - - response = await model_with_tools.ainvoke(messages) - print(f"模型响应类型: {type(response)}") - - return {"messages": [response]} - - # 创建工具节点 - builder = StateGraph(MessagesState) - builder.add_node("call_model", call_model) - # builder.add_node("tools", CustomToolNode(tools, preset_params=preset_params)) - builder.add_node("tools", ToolNode(tools)) - - # 构建图 - builder.add_edge(START, "call_model") - builder.add_conditional_edges( - "call_model", - tools_condition, - ) - builder.add_edge("tools", "call_model") - - graph = builder.compile() - - # 执行测试 - 演示如何传递 include_domains 等参数 - # 注意:参数会在工具调用时由 LLM 自动传递,这里展示一个需要特定域名搜索的查询 - test_query = """ - ERROR: Could not find a version that satisfies the requirement opencv (from versions: none) - ERROR: No matching distribution found for opencv - 报错 - """ - - system_prompt = """\ - You are a web search assistant. When using the tavily_search tool, ALWAYS include these parameters: - - exclude_domains: ["stackoverflow.com", "*huggingface*", "discourse.slicer.org","ask.csdn.net", "codepudding.com", "*geeksforgeeks*", "*github*", "forum.developer.parrot.com"] - - include_domains: ['pypi.org', 'docs.python.org'] - - search_depth: "advanced" - - Make sure to explicitly pass these parameters in your tool call. - """ - # system_prompt = """\ - # You are a web search assistant. help the human to find the answer to the question. - # """ - - response = await graph.ainvoke({"messages": system_prompt + "\n" + test_query}) - # print("Response:", response) - - return response - - -# 运行异步主函数 -if __name__ == "__main__": - result = asyncio.run(main()) - print(result["messages"][-1].content) diff --git a/tests/tools/test_mcp_client_config.py b/tests/tools/test_mcp_client_config.py deleted file mode 100644 index bfa10e40..00000000 --- a/tests/tools/test_mcp_client_config.py +++ /dev/null @@ -1,93 +0,0 @@ -import asyncio -import json -import os - -# 使用项目中的自定义模拟模型,支持工具调用 -import sys -import tempfile - -from langchain_core.messages import AIMessage, ToolMessage -from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.graph import START, MessagesState, StateGraph -from langgraph.prebuilt import ToolNode, tools_condition - -sys.path.append("/root/lix/Prometheus/") -from tests.test_utils.util import FakeListChatWithToolsModel - - -async def main(): - # 可以动态设置多个配置参数 - config = {"driver": "neo4j://enterprise-cluster:7687", "timeout": 120, "max_retries": 10} - - # 创建临时配置文件 - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - json.dump(config, f, indent=2) - config_file_path = f.name - - try: - client = MultiServerMCPClient( - { - "weather": { - "command": "python", - "args": ["/root/lix/Prometheus/tests/tools/config_based_mcp_tools.py"], - "transport": "stdio", - "env": { - "MCP_WEATHER_CONFIG": config_file_path # 通过环境变量传递配置文件路径 - }, - } - } - ) - - # 异步获取工具 - tools = await client.get_tools() - print(f"获取到的工具: {[tool.name for tool in tools]}") - - # 使用支持工具的模拟模型 - model = FakeListChatWithToolsModel(responses=["I need to check the weather for NYC"]) - - # 创建工具节点 - tool_node = ToolNode(tools) - - def call_model(state: MessagesState): - messages = state["messages"] - - # 检查是否已经有工具消息,如果有就结束 - if any(isinstance(msg, ToolMessage) for msg in messages): - return {"messages": [AIMessage(content="Weather check completed!")]} - - # 第一次调用时创建工具调用响应 - response = AIMessage( - content="Let me check the weather for you", - tool_calls=[{"name": "get_weather", "args": {"location": "nyc"}, "id": "call_1"}], - ) - return {"messages": [response]} - - # 构建图 - builder = StateGraph(MessagesState) - builder.add_node("call_model", call_model) - builder.add_node("tools", tool_node) - builder.add_edge(START, "call_model") - builder.add_conditional_edges( - "call_model", - tools_condition, - ) - builder.add_edge("tools", "call_model") - - graph = builder.compile() - - # 执行测试 - weather_response = await graph.ainvoke({"messages": "what is the weather in nyc?"}) - print("Response:", weather_response) - - return weather_response - - finally: - # 清理临时配置文件 - if os.path.exists(config_file_path): - os.unlink(config_file_path) - print(f"🗑️ 清理临时配置文件: {config_file_path}") - - -# 运行异步主函数 -if __name__ == "__main__": - result = asyncio.run(main()) diff --git a/tests/tools/test_mcp_server.py b/tests/tools/test_mcp_server.py deleted file mode 100644 index 5009cf5b..00000000 --- a/tests/tools/test_mcp_server.py +++ /dev/null @@ -1,155 +0,0 @@ -import asyncio -import json -import os -import sys -from pathlib import Path -from typing import Any, Dict, List, Optional, Set - -import yaml -from langchain_mcp_adapters.client import MultiServerMCPClient -from mcp.server.fastmcp import FastMCP - -# ========================================== -# MCP Server (existing behavior preserved) -# ========================================== -# Create unified MCP server instance -mcp = FastMCP("PrometheusTools") - - -# ========================================== -# Dynamic MCP client based on node_tools.yml -# ========================================== -_NODE_TOOLS_CACHE: Optional[Dict[str, List[str]]] = None -_CLIENT_CACHE: Optional[MultiServerMCPClient] = None - - -def _load_node_tools_map(config_path: Optional[Path] = None) -> Dict[str, List[str]]: - """Load node->tools mapping from prometheus/graph_config/node_tools.yml. - - Returns a dict: { node_name: [tool_name, ...] } - """ - global _NODE_TOOLS_CACHE - if _NODE_TOOLS_CACHE is not None: - return _NODE_TOOLS_CACHE - - if config_path is None: - # mcp_server.py is in prometheus/tools/, go up one to prometheus/ - project_root = Path(__file__).resolve().parents[1] - config_path = project_root / "graph_config" / "node_tools.yml" - - with config_path.open("r", encoding="utf-8") as f: - data = yaml.safe_load(f) or {} - - nodes = data.get("nodes", []) - node_to_tools: Dict[str, List[str]] = {} - for item in nodes: - name = item.get("name") - tools = item.get("tools", []) or [] - if name is None: - continue - if isinstance(tools, list): - node_to_tools[name] = tools - else: - # In case of malformed YAML (non-list), coerce to list - node_to_tools[name] = [tools] - - _NODE_TOOLS_CACHE = node_to_tools - return node_to_tools - - -def _load_server_configs() -> Dict[str, Dict[str, Any]]: - """Load MCP server configurations from environment. - - Expected env var PROMETHEUS_MCP_SERVERS as JSON, e.g.: - { - "math": { - "command": "python", - "args": ["/abs/path/to/examples/math_server.py"], - "transport": "stdio" - }, - "weather": { - "url": "http://localhost:8000/mcp/", - "transport": "streamable_http" - } - } - - If not provided, default to spawning this file as a stdio MCP server under id "PrometheusTools". - """ - raw = os.getenv("PROMETHEUS_MCP_SERVERS") - if raw: - try: - cfg = json.loads(raw) - if isinstance(cfg, dict): - return cfg - except json.JSONDecodeError: - pass - - # Fallback: local stdio server using this file - this_file = Path(__file__).resolve() - return { - "PrometheusTools": { - "command": "python", - "args": [str(this_file)], - "transport": "stdio", - } - } - - -def _build_client( - server_configs: Optional[Dict[str, Dict[str, Any]]] = None, -) -> MultiServerMCPClient: - global _CLIENT_CACHE - if _CLIENT_CACHE is not None: - return _CLIENT_CACHE - - if server_configs is None: - server_configs = _load_server_configs() - - client = MultiServerMCPClient(server_configs) - _CLIENT_CACHE = client - return client - - -async def get_all_tools() -> List[Any]: - """Fetch all tools from all configured MCP servers.""" - client = _build_client() - tools = await client.get_tools() - return tools - - -def get_required_tool_names_for_node(node_name: str) -> List[str]: - mapping = _load_node_tools_map() - return mapping.get(node_name, []) - - -async def get_tools_for_node(node_name: str) -> List[Any]: - """Return the list of MCP tools required by the given node name. - - This will connect to all configured MCP servers, fetch their tools, and filter - by the names listed for the node in node_tools.yml. - """ - required: Set[str] = set(get_required_tool_names_for_node(node_name)) - if not required: - return [] - all_tools = await get_all_tools() - return [t for t in all_tools if getattr(t, "name", None) in required] - - -def build_default_node_tool_client() -> MultiServerMCPClient: - """Expose a builder for external callers if needed.""" - return _build_client() - - -if __name__ == "__main__": - # 确保在启动前注册所有工具 - sys.path.append("~/lix/Prometheus") - import prometheus.tools # noqa: F401 - - mcp.run(transport="stdio") - - async def main(): - tools = await get_all_tools() - for t in sorted(tools, key=lambda x: getattr(x, "name", "")): - print(getattr(t, "name", str(t))) - - asyncio.run(main()) diff --git a/tests/tools/test_mcp_tools.py b/tests/tools/test_mcp_tools.py deleted file mode 100644 index 1c3f39a7..00000000 --- a/tests/tools/test_mcp_tools.py +++ /dev/null @@ -1,68 +0,0 @@ -# weather_server.py -import sys - -from mcp.server.fastmcp import FastMCP - -mcp = FastMCP("Weather") - - -class WeatherTools: - driver: str - timeout: int - max_retries: int - - @classmethod - def configure(cls, **kwargs): - """动态配置工具参数""" - for key, value in kwargs.items(): - if hasattr(cls, key): - setattr(cls, key, value) - print(f"设置 {key} = {value}") - else: - print(f"警告: 未知参数 {key} = {value}") - - @staticmethod - @mcp.tool() - async def get_weather(location: str) -> str: - """Get weather for a location.""" - return f"[Driver={WeatherTools.driver}, Timeout={WeatherTools.timeout}s] It's always sunny in {location}" - - @staticmethod - @mcp.tool() - async def get_temperature(location: str) -> str: - """Get temperature for a location.""" - return f"[Driver={WeatherTools.driver}, Retries={WeatherTools.max_retries}] Temperature in {location} is 25°C" - - -def parse_args(args): - """解析命令行参数为 kwargs""" - kwargs = {} - i = 1 - while i < len(args): - if args[i].startswith("--"): - key = args[i][2:] # 移除 "--" 前缀 - if i + 1 < len(args) and not args[i + 1].startswith("--"): - value = args[i + 1] - # 尝试转换数据类型 - if value.isdigit(): - value = int(value) - elif value.lower() in ["true", "false"]: - value = value.lower() == "true" - kwargs[key] = value - i += 2 - else: - kwargs[key] = True # 布尔标志 - i += 1 - else: - i += 1 - return kwargs - - -if __name__ == "__main__": - # 动态解析命令行参数 - # 支持格式:python test_mcp_tools.py --driver neo4j://server --timeout 60 --max_retries 5 - config = parse_args(sys.argv) - if config: - WeatherTools.configure(**config) - # 启动 MCP - mcp.run(transport="stdio") diff --git a/tests/tools/test_mcp_web_search.py b/tests/tools/test_mcp_web_search.py deleted file mode 100644 index f3b104c0..00000000 --- a/tests/tools/test_mcp_web_search.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Annotated, Optional - -import aiohttp -from mcp.server.fastmcp import FastMCP -from pydantic import BaseModel, Field - -from prometheus.app.services.llm_service import get_model -from prometheus.configuration.config import settings -from prometheus.utils.logger_manager import get_logger - -logger = get_logger(__name__) - - -@dataclass -class MCPToolSpec: - description: str - input_schema: type - - -model = get_model( - "gpt-4o-mini", - openai_format_api_key=settings.get("OPENAI_API_KEY", None), - openai_format_base_url=settings.get("OPENAI_BASE_URL", None), - anthropic_api_key=None, - gemini_api_key=None, - temperature=0.0, - max_output_tokens=15000, -) - - -# Initialize MCP server -mcp = FastMCP("WebSearchTool") - -# Get Tavily API key -tavily_api_key = settings.get("TAVILY_API_KEY", None) -if tavily_api_key is None: - logger.warning("Tavily API key is not set") - -# MCP server URL -TAVILY_SERVER_URL = f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}" - - -class WebSearchInput(BaseModel): - """Input parameters for web search.""" - - query: Annotated[str, Field(description="Search query string")] - max_results: Annotated[int, Field(description="Maximum number of results", default=5)] - include_domains: Annotated[ - Optional[list[str]], Field(description="List of domains to include", default=None) - ] - exclude_domains: Annotated[ - Optional[list[str]], Field(description="List of domains to exclude", default=None) - ] - - -def format_results(response: dict) -> str: - """Format Tavily search results into a readable string.""" - output = [] - - # Add domain filter information if present - if response.get("included_domains") or response.get("excluded_domains"): - filters = [] - if response.get("included_domains"): - filters.append(f"Including domains: {', '.join(response['included_domains'])}") - if response.get("excluded_domains"): - filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") - output.append("Search Filters:") - output.extend(filters) - output.append("") # Empty line for separation - - # Add answer if present - if response.get("answer"): - output.append(f"Answer: {response['answer']}") - output.append("\nSources:") - # Add immediate source references for the answer - for result in response.get("results", []): - output.append(f"- {result.get('title', 'No title')}: {result.get('url', 'No URL')}") - output.append("") # Empty line for separation - - # Add detailed results - output.append("Detailed Results:") - for result in response.get("results", []): - output.append(f"\nTitle: {result.get('title', 'No title')}") - output.append(f"URL: {result.get('url', 'No URL')}") - output.append(f"Content: {result.get('content', 'No content')}") - if result.get("published_date"): - output.append(f"Published: {result['published_date']}") - - return "\n".join(output) - - -class MCPWebSearchTool: - """Web search tool class.""" - - web_search_spec = MCPToolSpec( - description="""\ - Searches the web for technical information to aid in bug analysis and resolution. - Use this when you need external context, such as: - 1. Looking up unfamiliar error messages, exceptions, or stack traces. - 2. Finding official documentation or usage examples for a specific library, framework, or API. - 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. - 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). - - Queries should be specific and include relevant keywords like library names, version numbers, and error codes. - """, - input_schema=WebSearchInput, - ) - - -@mcp.tool() -async def web_search( - query: str, - max_results: int = 5, - include_domains: Optional[list[str]] = None, - exclude_domains: Optional[list[str]] = None, -) -> str: - """\ - Searches the web for technical information to aid in bug analysis and resolution. - Use this when you need external context, such as: - 1. Looking up unfamiliar error messages, exceptions, or stack traces. - 2. Finding official documentation or usage examples for a specific library, framework, or API. - 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. - 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). - - Queries should be specific and include relevant keywords like library names, version numbers, and error codes. - - - Args: - query: Search query string - max_results: Maximum number of results (default: 5) - include_domains: List of domains to include (default: technical documentation sites) - exclude_domains: List of domains to exclude - """ - - # Check if API key is available - if tavily_api_key is None: - return "Error: Tavily API key is not set" - - # Default technical search domains - if include_domains is None: - include_domains = [ - "stackoverflow.com", - "github.com", - "developer.mozilla.org", - "learn.microsoft.com", - "docs.python.org", - "pydantic.dev", - "pypi.org", - "readthedocs.org", - "docs.djangoproject.com", - "flask.palletsprojects.com", - "fastapi.tiangolo.com", - ] - - # Build request payload - payload = { - "query": query, - "max_results": max_results, - "include_domains": include_domains or [], - "exclude_domains": exclude_domains or [], - } - - try: - logger.info(f"Executing web search, query: {query}") - - # Use aiohttp to send HTTP request to MCP server - async with aiohttp.ClientSession() as session: - async with session.post(TAVILY_SERVER_URL, json=payload) as resp: - if resp.status != 200: - error_msg = f"HTTP error {resp.status}: {await resp.text()}" - logger.error(error_msg) - return error_msg - - data = await resp.json() - - # Format response - formatted_response = format_results(data) - logger.info(f"Web search completed, returned {len(data.get('results', []))} results") - - return formatted_response - - except aiohttp.ClientError as e: - error_msg = f"Network request error: {str(e)}" - logger.error(error_msg) - return error_msg - except json.JSONDecodeError as e: - error_msg = f"JSON parsing error: {str(e)}" - logger.error(error_msg) - return error_msg - except Exception as e: - error_msg = f"Error occurred during search: {str(e)}" - logger.error(error_msg) - return error_msg - - -def run_mcp_server(): - """Run MCP server.""" - logger.info("Starting MCP Web search server...") - mcp.run() - - -if __name__ == "__main__": - # # Load environment variables - # load_dotenv() - - # # Get Tavily API key - # tavily_api_key = settings.get("TAVILY_API_KEY", None) - # if tavily_api_key is None: - # logger.warning("Tavily API key is not set") - # TAVILY_SERVER_URL = f"https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}" - - # Run server - run_mcp_server() From 4a840c0654cce0f84533659ef1eec743f66a5615 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 4 Sep 2025 22:17:13 +0800 Subject: [PATCH 15/30] refactor: Remove unused logger initialization in bug_fix_verification_subgraph_node.py --- .../lang_graph/nodes/bug_fix_verification_subgraph_node.py | 3 --- prometheus/tools/__init__.py | 1 - 2 files changed, 4 deletions(-) diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 13a72ad2..507fd4d7 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -17,9 +17,6 @@ def __init__( container: BaseContainer, git_repo: GitRepository, ): - # self._logger = logging.getLogger( - # f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" - # ) self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") self.git_repo = git_repo self.subgraph = BugFixVerificationSubgraph( diff --git a/prometheus/tools/__init__.py b/prometheus/tools/__init__.py index 8b137891..e69de29b 100644 --- a/prometheus/tools/__init__.py +++ b/prometheus/tools/__init__.py @@ -1 +0,0 @@ - From 3f8c4a363b5e4786cf36b50c8dfba4b63b42a13b Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 4 Sep 2025 23:55:45 +0800 Subject: [PATCH 16/30] refactor: Correct spelling errors in file operation messages and update test cases to use FileOperationTool --- prometheus/tools/file_operation.py | 6 +- .../nodes/test_bug_reproducing_write_node.py | 5 +- tests/lang_graph/nodes/test_edit_node.py | 9 ++- tests/tools/test_file_operation.py | 75 +++++++++---------- tests/tools/test_graph_traversal.py | 64 ++++++++-------- tests/utils/test_file_utils.py | 6 +- 6 files changed, 78 insertions(+), 87 deletions(-) diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 179087e6..f9cee6df 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -167,7 +167,7 @@ def read_file_with_line_numbers( def create_file(self, relative_path: str, content: str) -> str: if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + return f"relative_path: {relative_path} is a absolute path, not relative path." file_path = Path(os.path.join(self.root_path, relative_path)) if file_path.exists(): @@ -179,7 +179,7 @@ def create_file(self, relative_path: str, content: str) -> str: def delete(self, relative_path: str) -> str: if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + return f"relative_path: {relative_path} is a absolute path, not relative path." file_path = Path(os.path.join(self.root_path, relative_path)) if not file_path.exists(): @@ -194,7 +194,7 @@ def delete(self, relative_path: str) -> str: def edit_file(self, relative_path: str, old_content: str, new_content: str) -> str: if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + return f"relative_path: {relative_path} is a absolute path, not relative path." file_path = Path(os.path.join(self.root_path, relative_path)) if not file_path.exists(): diff --git a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py index 30cfb974..720a3b93 100644 --- a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py +++ b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py @@ -6,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.bug_reproducing_write_node import BugReproducingWriteNode from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState +from tests.test_utils.fixtures import temp_test_dir # noqa: F401 from tests.test_utils.util import FakeListChatWithToolsModel @@ -35,11 +36,11 @@ def test_state(): ) -def test_call_method(mock_kg, test_state): +def test_call_method(mock_kg, test_state, temp_test_dir): # noqa: F811 """Test the __call__ method execution.""" fake_response = "Created test file" fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) - node = BugReproducingWriteNode(fake_llm, mock_kg) + node = BugReproducingWriteNode(fake_llm, temp_test_dir, mock_kg) result = node(test_state) diff --git a/tests/lang_graph/nodes/test_edit_node.py b/tests/lang_graph/nodes/test_edit_node.py index d7de2632..39a49b40 100644 --- a/tests/lang_graph/nodes/test_edit_node.py +++ b/tests/lang_graph/nodes/test_edit_node.py @@ -5,6 +5,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.edit_node import EditNode +from tests.test_utils.fixtures import temp_test_dir # noqa: F401 from tests.test_utils.util import FakeListChatWithToolsModel @@ -19,18 +20,18 @@ def fake_llm(): return FakeListChatWithToolsModel(responses=["File edit completed successfully"]) -def test_init_edit_node(mock_kg, fake_llm): +def test_init_edit_node(mock_kg, fake_llm, temp_test_dir): # noqa: F811 """Test EditNode initialization.""" - node = EditNode(fake_llm, mock_kg) + node = EditNode(fake_llm, temp_test_dir, mock_kg) assert isinstance(node.system_prompt, SystemMessage) assert len(node.tools) == 5 # Should have 5 file operation tools assert node.model_with_tools is not None -def test_call_method_basic(mock_kg, fake_llm): +def test_call_method_basic(mock_kg, fake_llm, temp_test_dir): # noqa: F811 """Test basic call functionality without tool execution.""" - node = EditNode(fake_llm, mock_kg) + node = EditNode(fake_llm, temp_test_dir, mock_kg) state = {"edit_messages": [HumanMessage(content="Make the following changes: ...")]} result = node(state) diff --git a/tests/tools/test_file_operation.py b/tests/tools/test_file_operation.py index 62fb8c81..c1a762d2 100644 --- a/tests/tools/test_file_operation.py +++ b/tests/tools/test_file_operation.py @@ -1,14 +1,7 @@ import pytest from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools.file_operation import ( - create_file, - delete, - edit_file, - read_file, - read_file_with_knowledge_graph_data, - read_file_with_line_numbers, -) +from prometheus.tools.file_operation import FileOperationTool from tests.test_utils import test_project_paths from tests.test_utils.fixtures import temp_test_dir # noqa: F401 @@ -20,13 +13,17 @@ async def knowledge_graph_fixture(): return kg -def test_read_file_with_knowledge_graph_data(temp_test_dir, knowledge_graph_fixture): # noqa: F811 +@pytest.fixture(scope="function") +def file_operation_tool(temp_test_dir, knowledge_graph_fixture): # noqa: F811 + file_operation = FileOperationTool(temp_test_dir, knowledge_graph_fixture) + return file_operation + + +def test_read_file_with_knowledge_graph_data(file_operation_tool): relative_path = str( test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = read_file_with_knowledge_graph_data( - relative_path, test_project_paths.TEST_PROJECT_PATH, knowledge_graph_fixture - ) + result = file_operation_tool.read_file_with_knowledge_graph_data(relative_path) result_data = result[1] assert len(result_data) > 0 for result_row in result_data: @@ -36,87 +33,87 @@ def test_read_file_with_knowledge_graph_data(temp_test_dir, knowledge_graph_fixt assert result_row["FileNode"].get("relative_path", "") == relative_path -def test_create_and_read_file(temp_test_dir): # noqa: F811 +def test_create_and_read_file(temp_test_dir, file_operation_tool): # noqa: F811 """Test creating a file and reading its contents.""" test_file = temp_test_dir / "test.txt" content = "line 1\nline 2\nline 3" # Test create_file - result = create_file("test.txt", str(temp_test_dir), content) + result = file_operation_tool.create_file("test.txt", content) assert test_file.exists() assert test_file.read_text() == content assert result == "The file test.txt has been created." # Test read_file - result = read_file("test.txt", str(temp_test_dir)) + result = file_operation_tool.read_file("test.txt") expected = "1. line 1\n2. line 2\n3. line 3" assert result == expected -def test_read_file_nonexistent(temp_test_dir): # noqa: F811 +def test_read_file_nonexistent(file_operation_tool): """Test reading a nonexistent file.""" - result = read_file("nonexistent_file.txt", str(temp_test_dir)) + result = file_operation_tool.read_file("nonexistent_file.txt") assert result == "The file nonexistent_file.txt does not exist." -def test_read_file_with_line_numbers(temp_test_dir): # noqa: F811 +def test_read_file_with_line_numbers(file_operation_tool): """Test reading specific line ranges from a file.""" content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("test_lines.txt", str(temp_test_dir), content) + file_operation_tool.create_file("test_lines.txt", content) # Test reading specific lines - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) + result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 2, 4) expected = "2. line 2\n3. line 3" assert result == expected # Test invalid range - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) + result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 4, 2) assert result == "The end line number 2 must be greater than the start line number 4." -def test_delete(temp_test_dir): # noqa: F811 +def test_delete(file_operation_tool, temp_test_dir): # noqa: F811 """Test file and directory deletion.""" # Test file deletion test_file = temp_test_dir / "to_delete.txt" - create_file("to_delete.txt", str(temp_test_dir), "content") + file_operation_tool.create_file("to_delete.txt", "content") assert test_file.exists() - result = delete("to_delete.txt", str(temp_test_dir)) + result = file_operation_tool.delete("to_delete.txt") assert result == "The file to_delete.txt has been deleted." assert not test_file.exists() # Test directory deletion test_subdir = temp_test_dir / "subdir" test_subdir.mkdir() - create_file("subdir/file.txt", str(temp_test_dir), "content") - result = delete("subdir", str(temp_test_dir)) + file_operation_tool.create_file("subdir/file.txt", "content") + result = file_operation_tool.delete("subdir") assert result == "The directory subdir has been deleted." assert not test_subdir.exists() -def test_delete_nonexistent(temp_test_dir): # noqa: F811 +def test_delete_nonexistent(file_operation_tool): """Test deleting a nonexistent path.""" - result = delete("nonexistent_path", str(temp_test_dir)) + result = file_operation_tool.delete("nonexistent_path") assert result == "The file nonexistent_path does not exist." -def test_edit_file(temp_test_dir): # noqa: F811 +def test_edit_file(file_operation_tool): """Test editing specific lines in a file.""" # Test case 1: Successfully edit a single occurrence initial_content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("edit_test.txt", str(temp_test_dir), initial_content) - result = edit_file("edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + file_operation_tool.create_file("edit_test.txt", initial_content) + result = file_operation_tool.edit_file("edit_test.txt", "line 2", "new line 2") assert result == "Successfully edited edit_test.txt." # Test case 2: Absolute path error - result = edit_file("/edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + result = file_operation_tool.edit_file("/edit_test.txt", "line 2", "new line 2") assert result == "relative_path: /edit_test.txt is a absolute path, not relative path." # Test case 3: File doesn't exist - result = edit_file("nonexistent.txt", str(temp_test_dir), "line 2", "new line 2") + result = file_operation_tool.edit_file("nonexistent.txt", "line 2", "new line 2") assert result == "The file nonexistent.txt does not exist." # Test case 4: No matches found - result = edit_file("edit_test.txt", str(temp_test_dir), "nonexistent line", "new content") + result = file_operation_tool.edit_file("edit_test.txt", "nonexistent line", "new content") assert ( result == "No match found for the specified content in edit_test.txt. Please verify the content to replace." @@ -124,16 +121,16 @@ def test_edit_file(temp_test_dir): # noqa: F811 # Test case 5: Multiple occurrences duplicate_content = "line 1\nline 2\nline 2\nline 3" - create_file("duplicate_test.txt", str(temp_test_dir), duplicate_content) - result = edit_file("duplicate_test.txt", str(temp_test_dir), "line 2", "new line 2") + file_operation_tool.create_file("duplicate_test.txt", duplicate_content) + result = file_operation_tool.edit_file("duplicate_test.txt", "line 2", "new line 2") assert ( result == "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match." ) -def test_create_file_already_exists(temp_test_dir): # noqa: F811 +def test_create_file_already_exists(file_operation_tool): """Test creating a file that already exists.""" - create_file("existing.txt", str(temp_test_dir), "content") - result = create_file("existing.txt", str(temp_test_dir), "new content") + file_operation_tool.create_file("existing.txt", "content") + result = file_operation_tool.create_file("existing.txt", "new content") assert result == "The file existing.txt already exists." diff --git a/tests/tools/test_graph_traversal.py b/tests/tools/test_graph_traversal.py index aa8f2a93..f3c30579 100644 --- a/tests/tools/test_graph_traversal.py +++ b/tests/tools/test_graph_traversal.py @@ -1,7 +1,7 @@ import pytest from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools import graph_traversal +from prometheus.tools.graph_traversal import GraphTraversalTool from tests.test_utils import test_project_paths @@ -12,11 +12,15 @@ async def knowledge_graph_fixture(): return kg +@pytest.fixture(scope="function") +def graph_traversal_tool(knowledge_graph_fixture): + graph_traversal_tool = GraphTraversalTool(knowledge_graph_fixture) + return graph_traversal_tool + + @pytest.mark.slow -async def test_find_file_node_with_basename(knowledge_graph_fixture): - result = graph_traversal.find_file_node_with_basename( - test_project_paths.PYTHON_FILE.name, knowledge_graph_fixture - ) +async def test_find_file_node_with_basename(graph_traversal_tool): + result = graph_traversal_tool.find_file_node_with_basename(test_project_paths.PYTHON_FILE.name) basename = test_project_paths.PYTHON_FILE.name relative_path = str( @@ -31,13 +35,11 @@ async def test_find_file_node_with_basename(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_file_node_with_relative_path(knowledge_graph_fixture): +async def test_find_file_node_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.find_file_node_with_relative_path( - relative_path, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_file_node_with_relative_path(relative_path) basename = test_project_paths.MD_FILE.name @@ -49,10 +51,10 @@ async def test_find_file_node_with_relative_path(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_ast_node_with_text_in_file_with_basename(knowledge_graph_fixture): # noqa: F811 +async def test_find_ast_node_with_text_in_file_with_basename(graph_traversal_tool): basename = test_project_paths.PYTHON_FILE.name - result = graph_traversal.find_ast_node_with_text_in_file_with_basename( - "Hello world!", basename, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_text_in_file_with_basename( + "Hello world!", basename ) result_data = result[1] @@ -65,12 +67,12 @@ async def test_find_ast_node_with_text_in_file_with_basename(knowledge_graph_fix @pytest.mark.slow -async def test_find_ast_node_with_text_in_file_with_relative_path(knowledge_graph_fixture): +async def test_find_ast_node_with_text_in_file_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.find_ast_node_with_text_in_file_with_relative_path( - "Hello world!", relative_path, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path( + "Hello world!", relative_path ) result_data = result[1] @@ -83,12 +85,10 @@ async def test_find_ast_node_with_text_in_file_with_relative_path(knowledge_grap @pytest.mark.slow -async def test_find_ast_node_with_type_in_file_with_basename(knowledge_graph_fixture): +async def test_find_ast_node_with_type_in_file_with_basename(graph_traversal_tool): basename = test_project_paths.C_FILE.name node_type = "function_definition" - result = graph_traversal.find_ast_node_with_type_in_file_with_basename( - node_type, basename, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_ast_node_with_type_in_file_with_basename(node_type, basename) result_data = result[1] assert len(result_data) > 0 @@ -100,13 +100,13 @@ async def test_find_ast_node_with_type_in_file_with_basename(knowledge_graph_fix @pytest.mark.slow -async def test_find_ast_node_with_type_in_file_with_relative_path(knowledge_graph_fixture): # noqa: F811 +async def test_find_ast_node_with_type_in_file_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) node_type = "string_literal" - result = graph_traversal.find_ast_node_with_type_in_file_with_relative_path( - node_type, relative_path, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path( + node_type, relative_path ) result_data = result[1] @@ -119,9 +119,9 @@ async def test_find_ast_node_with_type_in_file_with_relative_path(knowledge_grap @pytest.mark.slow -async def test_find_text_node_with_text(knowledge_graph_fixture): +async def test_find_text_node_with_text(graph_traversal_tool): text = "Text under header C" - result = graph_traversal.find_text_node_with_text(text, knowledge_graph_fixture) + result = graph_traversal_tool.find_text_node_with_text(text) result_data = result[1] assert len(result_data) > 0 @@ -137,12 +137,10 @@ async def test_find_text_node_with_text(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_text_node_with_text_in_file(knowledge_graph_fixture): +async def test_find_text_node_with_text_in_file(graph_traversal_tool): basename = test_project_paths.MD_FILE.name text = "Text under header B" - result = graph_traversal.find_text_node_with_text_in_file( - text, basename, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_text_node_with_text_in_file(text, basename) result_data = result[1] assert len(result_data) > 0 @@ -158,9 +156,9 @@ async def test_find_text_node_with_text_in_file(knowledge_graph_fixture): @pytest.mark.slow -async def test_get_next_text_node_with_node_id(knowledge_graph_fixture): +async def test_get_next_text_node_with_node_id(graph_traversal_tool): node_id = 34 - result = graph_traversal.get_next_text_node_with_node_id(node_id, knowledge_graph_fixture) + result = graph_traversal_tool.get_next_text_node_with_node_id(node_id) result_data = result[1] assert len(result_data) > 0 @@ -176,13 +174,11 @@ async def test_get_next_text_node_with_node_id(knowledge_graph_fixture): @pytest.mark.slow -async def test_read_code_with_relative_path(knowledge_graph_fixture): # noqa: F811 +async def test_read_code_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.read_code_with_relative_path( - relative_path, 5, 6, knowledge_graph_fixture - ) + result = graph_traversal_tool.read_code_with_relative_path(relative_path, 5, 6) result_data = result[1] assert len(result_data) > 0 diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index 2c46a699..d5bd0925 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -1,7 +1,6 @@ import pytest from prometheus.exceptions.file_operation_exception import FileOperationException -from prometheus.tools.file_operation import create_file from prometheus.utils.file_utils import ( read_file_with_line_numbers, ) @@ -10,11 +9,8 @@ def test_read_file_with_line_numbers(temp_test_dir): # noqa: F811 """Test reading specific line ranges from a file.""" - content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("test_lines.txt", str(temp_test_dir), content) - # Test reading specific lines - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) + result = read_file_with_line_numbers("foo/test.md", str(temp_test_dir), 1, 15) expected = "2. line 2\n3. line 3\n4. line 4" assert result == expected From f221867bb9db7eeb73a2d896c87253266395d568 Mon Sep 17 00:00:00 2001 From: cocoli Date: Fri, 5 Sep 2025 00:20:33 +0800 Subject: [PATCH 17/30] fix log --- prometheus/utils/logger_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index 3565912b..8053162f 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -104,7 +104,7 @@ def _setup_root_logger(self): self.root_logger.handlers.clear() # Set log level - log_level = getattr(settings, "LOGGING_LEVEL", "INFO") + log_level = getattr(settings, "LOGGING_LEVEL", "DEBUG") self.root_logger.setLevel(getattr(logging, log_level)) # Create colored formatter for console output @@ -120,6 +120,7 @@ def _setup_root_logger(self): # Create console handler (using colored formatter) console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(self.colored_formatter) + console_handler.setLevel(getattr(logging, log_level)) # Ensure console handler uses same level self.root_logger.addHandler(console_handler) # Prevent log propagation to parent logger @@ -192,6 +193,10 @@ def create_file_handler( # Create file handler (using plain formatter, without colors) file_handler = logging.FileHandler(log_file_path) file_handler.setFormatter(self.file_formatter) + + # Ensure file handler uses the same log level as console handler + log_level = getattr(settings, "LOGGING_LEVEL", "DEBUG") + file_handler.setLevel(getattr(logging, log_level)) # Get logger and add handler logger = self.get_logger(logger_name) From ac107e1136dab98f31605a508052a45568a7684e Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 13:58:28 +0800 Subject: [PATCH 18/30] refactor: Replace standard logging with custom logger manager in patch_normalization_node.py --- prometheus/lang_graph/nodes/patch_normalization_node.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 6e8062ce..47a781a6 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -4,7 +4,6 @@ Provides standardized patch candidates with direct best patch selection. """ -import logging import re import threading from collections import defaultdict @@ -12,6 +11,7 @@ from typing import Dict, List, Sequence from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState +from prometheus.utils.logger_manager import get_logger @dataclass @@ -39,9 +39,7 @@ class PatchNormalizationNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.patch_normalization_node" - ) + self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") def normalize_patch(self, raw_patch: str) -> str: """Normalize patch content for deduplication From 6f034c947cd448a919154578eb4c609a5abbdd2a Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:51:12 +0800 Subject: [PATCH 19/30] refactor: Update KnowledgeGraph initialization parameters and adjust related tests --- prometheus/graph/knowledge_graph.py | 6 ++++-- prometheus/neo4j/knowledge_graph_handler.py | 6 ++++-- prometheus/utils/logger_manager.py | 6 ++++-- tests/graph/test_knowledge_graph.py | 18 ++++++++-------- tests/neo4j/test_knowledge_graph_handler.py | 12 +++++------ tests/test_utils/fixtures.py | 2 +- tests/tools/test_file_operation.py | 11 +++++++--- tests/utils/test_file_utils.py | 24 ++++++++++++++++----- 8 files changed, 55 insertions(+), 30 deletions(-) diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index eb2ef063..9c196a0e 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -19,6 +19,7 @@ import asyncio import itertools +import logging from collections import defaultdict, deque from pathlib import Path from typing import Mapping, Optional, Sequence @@ -42,7 +43,8 @@ Neo4jTextNode, TextNode, ) -from prometheus.utils.logger_manager import get_logger + +# from prometheus.utils.logger_manager import get_logger class KnowledgeGraph: @@ -79,7 +81,7 @@ def __init__( self._next_node_id = root_node_id + len(self._knowledge_graph_nodes) self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) - self._logger = get_logger(__name__) + self._logger = logging.getLogger(__name__) async def build_graph(self, root_dir: Path): """Asynchronously builds knowledge graph for a codebase at a location. diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index 1d11a5dd..de627253 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -1,5 +1,6 @@ """The neo4j handler for writing the knowledge graph to neo4j.""" +import logging from typing import Mapping, Sequence from neo4j import AsyncGraphDatabase, AsyncManagedTransaction @@ -15,7 +16,8 @@ Neo4jTextNode, ) from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.utils.logger_manager import get_logger + +# from prometheus.utils.logger_manager import get_logger class KnowledgeGraphHandler: @@ -30,7 +32,7 @@ def __init__(self, driver: AsyncGraphDatabase.driver, batch_size: int): self.driver = driver self.batch_size = batch_size # initialize the database and logger - self._logger = get_logger(__name__) + self._logger = logging.getLogger(__name__) async def init_database(self): """Initialization of the neo4j database.""" diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index 8053162f..b4c98f6b 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -120,7 +120,9 @@ def _setup_root_logger(self): # Create console handler (using colored formatter) console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(self.colored_formatter) - console_handler.setLevel(getattr(logging, log_level)) # Ensure console handler uses same level + console_handler.setLevel( + getattr(logging, log_level) + ) # Ensure console handler uses same level self.root_logger.addHandler(console_handler) # Prevent log propagation to parent logger @@ -193,7 +195,7 @@ def create_file_handler( # Create file handler (using plain formatter, without colors) file_handler = logging.FileHandler(log_file_path) file_handler.setFormatter(self.file_formatter) - + # Ensure file handler uses the same log level as console handler log_level = getattr(settings, "LOGGING_LEVEL", "DEBUG") file_handler.setLevel(getattr(logging, log_level)) diff --git a/tests/graph/test_knowledge_graph.py b/tests/graph/test_knowledge_graph.py index 9d598ade..551f8bd7 100644 --- a/tests/graph/test_knowledge_graph.py +++ b/tests/graph/test_knowledge_graph.py @@ -22,24 +22,24 @@ async def mock_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811 async def test_build_graph(): - knowledge_graph = KnowledgeGraph(1000, 100, 10, 0) + knowledge_graph = KnowledgeGraph(1, 1000, 100, 0) await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - assert knowledge_graph._next_node_id == 93 + assert knowledge_graph._next_node_id == 15 # 7 FileNode # 84 ASTnode # 2 TextNode - assert len(knowledge_graph._knowledge_graph_nodes) == 93 - assert len(knowledge_graph._knowledge_graph_edges) == 93 + assert len(knowledge_graph._knowledge_graph_nodes) == 15 + assert len(knowledge_graph._knowledge_graph_edges) == 14 assert len(knowledge_graph.get_file_nodes()) == 7 - assert len(knowledge_graph.get_ast_nodes()) == 84 - assert len(knowledge_graph.get_text_nodes()) == 2 - assert len(knowledge_graph.get_parent_of_edges()) == 81 + assert len(knowledge_graph.get_ast_nodes()) == 7 + assert len(knowledge_graph.get_text_nodes()) == 1 + assert len(knowledge_graph.get_parent_of_edges()) == 4 assert len(knowledge_graph.get_has_file_edges()) == 6 assert len(knowledge_graph.get_has_ast_edges()) == 3 - assert len(knowledge_graph.get_has_text_edges()) == 2 - assert len(knowledge_graph.get_next_chunk_edges()) == 1 + assert len(knowledge_graph.get_has_text_edges()) == 1 + assert len(knowledge_graph.get_next_chunk_edges()) == 0 async def test_get_file_tree(): diff --git a/tests/neo4j/test_knowledge_graph_handler.py b/tests/neo4j/test_knowledge_graph_handler.py index c077533e..cdf9c873 100644 --- a/tests/neo4j/test_knowledge_graph_handler.py +++ b/tests/neo4j/test_knowledge_graph_handler.py @@ -38,7 +38,7 @@ async def test_num_ast_nodes(mock_neo4j_service): async with mock_neo4j_service.neo4j_driver.session() as session: read_ast_nodes = await session.execute_read(handler._read_ast_nodes, root_node_id=0) - assert len(read_ast_nodes) == 84 + assert len(read_ast_nodes) == 7 @pytest.mark.slow @@ -56,7 +56,7 @@ async def test_num_text_nodes(mock_neo4j_service): async with mock_neo4j_service.neo4j_driver.session() as session: read_text_nodes = await session.execute_read(handler._read_text_nodes, root_node_id=0) - assert len(read_text_nodes) == 2 + assert len(read_text_nodes) == 1 @pytest.mark.slow @@ -67,7 +67,7 @@ async def test_num_parent_of_edges(mock_neo4j_service): read_parent_of_edges = await session.execute_read( handler._read_parent_of_edges, root_node_id=0 ) - assert len(read_parent_of_edges) == 81 + assert len(read_parent_of_edges) == 4 @pytest.mark.slow @@ -98,7 +98,7 @@ async def test_num_has_text_edges(mock_neo4j_service): read_has_text_edges = await session.execute_read( handler._read_has_text_edges, root_node_id=0 ) - assert len(read_has_text_edges) == 2 + assert len(read_has_text_edges) == 1 @pytest.mark.slow @@ -109,12 +109,12 @@ async def test_num_next_chunk_edges(mock_neo4j_service): read_next_chunk_edges = await session.execute_read( handler._read_next_chunk_edges, root_node_id=0 ) - assert len(read_next_chunk_edges) == 1 + assert len(read_next_chunk_edges) == 0 @pytest.mark.slow async def test_clear_knowledge_graph(mock_empty_neo4j_service): - kg = KnowledgeGraph(1000, 1000, 100, 0) + kg = KnowledgeGraph(1, 1000, 100, 0) await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) driver = mock_empty_neo4j_service.neo4j_driver diff --git a/tests/test_utils/fixtures.py b/tests/test_utils/fixtures.py index 7a30c6ca..4f0fb5b2 100644 --- a/tests/test_utils/fixtures.py +++ b/tests/test_utils/fixtures.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="session") async def neo4j_container_with_kg_fixture(): - kg = KnowledgeGraph(1000, 100, 10, 0) + kg = KnowledgeGraph(1, 1000, 100, 0) await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) container = ( Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) diff --git a/tests/tools/test_file_operation.py b/tests/tools/test_file_operation.py index c1a762d2..7c4a6098 100644 --- a/tests/tools/test_file_operation.py +++ b/tests/tools/test_file_operation.py @@ -1,3 +1,5 @@ +import shutil + import pytest from prometheus.graph.knowledge_graph import KnowledgeGraph @@ -7,9 +9,12 @@ @pytest.fixture(scope="function") -async def knowledge_graph_fixture(): - kg = KnowledgeGraph(1000, 100, 10, 0) - await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) +async def knowledge_graph_fixture(temp_test_dir): # noqa: F811 + if temp_test_dir.exists(): + shutil.rmtree(temp_test_dir) + shutil.copytree(test_project_paths.TEST_PROJECT_PATH, temp_test_dir) + kg = KnowledgeGraph(1, 1000, 100, 0) + await kg.build_graph(temp_test_dir) return kg diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index d5bd0925..c9a08107 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -4,18 +4,32 @@ from prometheus.utils.file_utils import ( read_file_with_line_numbers, ) -from tests.test_utils.fixtures import temp_test_dir # noqa: F401 +from tests.test_utils.test_project_paths import TEST_PROJECT_PATH -def test_read_file_with_line_numbers(temp_test_dir): # noqa: F811 +def test_read_file_with_line_numbers(): """Test reading specific line ranges from a file.""" # Test reading specific lines - result = read_file_with_line_numbers("foo/test.md", str(temp_test_dir), 1, 15) - expected = "2. line 2\n3. line 3\n4. line 4" + result = read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 1, 15) + expected = """1. # A +2. +3. Text under header A. +4. +5. ## B +6. +7. Text under header B. +8. +9. ## C +10. +11. Text under header C. +12. +13. ### D +14. +15. Text under header D.""" assert result == expected # Test invalid range should raise exception with pytest.raises(FileOperationException) as exc_info: - read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) + read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 4, 2) assert str(exc_info.value) == "The end line number must be greater than the start line number." From fa2b6d3f868ac6089e3a5f94f72fcd713a1fa3b9 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:59:29 +0800 Subject: [PATCH 20/30] refactor: Simplify Tavily client initialization and introduce custom exception for web search errors --- .../exceptions/web_search_tool_exception.py | 4 +++ prometheus/tools/web_search.py | 30 ++++++++----------- .../nodes/test_issue_bug_analyzer_node.py | 11 ------- 3 files changed, 17 insertions(+), 28 deletions(-) create mode 100644 prometheus/exceptions/web_search_tool_exception.py diff --git a/prometheus/exceptions/web_search_tool_exception.py b/prometheus/exceptions/web_search_tool_exception.py new file mode 100644 index 00000000..873bff43 --- /dev/null +++ b/prometheus/exceptions/web_search_tool_exception.py @@ -0,0 +1,4 @@ +class WebSearchToolException(Exception): + """Custom exception for web search tool errors.""" + + pass diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 5e59f8c8..1b86d833 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -5,6 +5,7 @@ from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError from prometheus.configuration.config import settings +from prometheus.exceptions.web_search_tool_exception import WebSearchToolException from prometheus.utils.logger_manager import get_logger @@ -75,14 +76,7 @@ def __init__(self): """Initialize the web search tool.""" # Load environment variables from .env file self._logger = get_logger(__name__) - - tavily_api_key = settings.TAVILY_API_KEY - if tavily_api_key is None: - self._logger.warning("Tavily API key is not set") - tavily_client = None - else: - tavily_client = TavilyClient(api_key=tavily_api_key) - self.tavily_client = tavily_client + self.tavily_client = TavilyClient(api_key=settings.TAVILY_API_KEY) def web_search( self, @@ -102,7 +96,7 @@ def web_search( Returns: Formatted search results as a string. """ - + # Set default include domains if not provided if include_domains is None: include_domains = [ "stackoverflow.com", @@ -114,8 +108,8 @@ def web_search( "pypi.org", "readthedocs.org", ] - if self.tavily_client is None: - raise RuntimeError("Tavily API key is not set") + + # Call the Tavily API try: response = self.tavily_client.search( query=query, @@ -125,12 +119,14 @@ def web_search( include_domains=include_domains or [], # Convert None to an empty list exclude_domains=exclude_domains or [], # Convert None to an empty list ) - format_response = format_results(response) - self._logger.info(f"web_search format_response: {format_response}") - return format_response except InvalidAPIKeyError: - raise ValueError("Invalid Tavily API key") + raise WebSearchToolException("Invalid Tavily API key") except UsageLimitExceededError: - raise RuntimeError("Usage limit exceeded") + raise WebSearchToolException("Usage limit exceeded") except Exception as e: - raise RuntimeError(f"An error occurred: {str(e)}") + raise WebSearchToolException(f"An error occurred: {str(e)}") + + # Format and return the results + format_response = format_results(response) + self._logger.info(f"web_search format_response: {format_response}") + return format_response diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index 70951f3b..c7e5faa7 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -79,17 +79,6 @@ def test_web_search_tool_call_with_correct_parameters(fake_llm): assert web_search_tool.args_schema is not None -@patch("prometheus.tools.web_search.tavily_client") -def test_web_search_tool_without_api_key(fake_llm): - """Test web_search tool behavior when API key is not available.""" - # Simulate no API key scenario - node = IssueBugAnalyzerNode(fake_llm) - web_search_tool = node.tools[0] - - # The tool should still be created but may handle missing API key gracefully - assert web_search_tool.name == "web_search" - - def test_system_prompt_contains_web_search_info(fake_llm): """Test that the system prompt mentions web_search tool.""" node = IssueBugAnalyzerNode(fake_llm) From 3c6f03104c7cc93fba16b97232d5a8fe75cb8b82 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:59:47 +0800 Subject: [PATCH 21/30] refactor: Remove unused import from test_issue_bug_analyzer_node.py --- tests/lang_graph/nodes/test_issue_bug_analyzer_node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index c7e5faa7..efe0e2da 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -1,4 +1,3 @@ -from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage From bb5b07aaef673a663bce13254dbe70d9cb9fe2f5 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:15:40 +0800 Subject: [PATCH 22/30] bump: Update project version to 1.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 52d59bc3..637cf474 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "Prometheus" -version = "0.0.1" +version = "1.2.0" dependencies = [ "langchain==0.3.3", "tree-sitter==0.21.3", From 0fb222808dde1cb65ab55946807a1ec3254171fc Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:17:35 +0800 Subject: [PATCH 23/30] refactor: Remove unused import from test_issue_bug_analyzer_node.py --- tests/lang_graph/nodes/test_issue_bug_analyzer_node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index efe0e2da..ccc08399 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -1,4 +1,3 @@ - import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages.tool import ToolCall From 04edac4f5d07fe4a6e5041709f11448e587457e1 Mon Sep 17 00:00:00 2001 From: cocoli Date: Sat, 6 Sep 2025 00:28:26 +0800 Subject: [PATCH 24/30] fix log --- prometheus/app/main.py | 8 +- prometheus/app/services/database_service.py | 4 +- .../app/services/invitation_code_service.py | 3 +- prometheus/app/services/issue_service.py | 16 +- .../app/services/knowledge_graph_service.py | 3 +- prometheus/app/services/neo4j_service.py | 3 +- prometheus/app/services/user_service.py | 3 +- prometheus/docker/base_container.py | 6 +- prometheus/git/git_repository.py | 6 +- prometheus/graph/knowledge_graph.py | 5 +- .../bug_fix_verification_subgraph_node.py | 4 +- .../lang_graph/nodes/bug_fix_verify_node.py | 4 +- .../nodes/bug_fix_verify_structured_node.py | 4 +- ...bug_get_regression_context_message_node.py | 5 +- ...bug_get_regression_tests_selection_node.py | 5 +- .../bug_get_regression_tests_subgraph_node.py | 5 +- .../nodes/bug_reproducing_execute_node.py | 4 +- .../nodes/bug_reproducing_file_node.py | 4 +- .../nodes/bug_reproducing_structured_node.py | 4 +- .../bug_reproducing_write_message_node.py | 4 +- .../nodes/bug_reproducing_write_node.py | 4 +- .../nodes/bug_reproduction_subgraph_node.py | 4 +- .../nodes/build_and_test_subgraph_node.py | 4 +- .../nodes/context_extraction_node.py | 4 +- .../lang_graph/nodes/context_provider_node.py | 4 +- .../nodes/context_query_message_node.py | 4 +- .../lang_graph/nodes/context_refine_node.py | 4 +- .../nodes/context_retrieval_subgraph_node.py | 4 +- .../lang_graph/nodes/edit_message_node.py | 4 +- prometheus/lang_graph/nodes/edit_node.py | 4 +- .../nodes/final_patch_selection_node.py | 4 +- .../lang_graph/nodes/general_build_node.py | 4 +- .../nodes/general_build_structured_node.py | 4 +- .../lang_graph/nodes/general_test_node.py | 4 +- .../nodes/general_test_structured_node.py | 4 +- ...regression_test_patch_check_result_node.py | 6 +- ...ass_regression_test_patch_subgraph_node.py | 6 +- ..._pass_regression_test_patch_update_node.py | 5 +- .../lang_graph/nodes/git_apply_patch_node.py | 5 +- prometheus/lang_graph/nodes/git_diff_node.py | 4 +- prometheus/lang_graph/nodes/git_reset_node.py | 4 +- .../nodes/issue_bug_analyzer_message_node.py | 4 +- .../nodes/issue_bug_analyzer_node.py | 4 +- .../nodes/issue_bug_context_message_node.py | 4 +- ...e_bug_reproduction_context_message_node.py | 4 +- .../nodes/issue_bug_responder_node.py | 4 +- .../nodes/issue_bug_subgraph_node.py | 4 +- ...sue_classification_context_message_node.py | 4 +- .../issue_classification_subgraph_node.py | 4 +- .../lang_graph/nodes/issue_classifier_node.py | 4 +- .../issue_not_verified_bug_subgraph_node.py | 4 +- .../nodes/issue_question_analyzer_node.py | 5 +- .../issue_question_context_message_node.py | 5 +- .../nodes/issue_question_subgraph_node.py | 5 +- .../nodes/issue_verified_bug_subgraph_node.py | 4 +- prometheus/lang_graph/nodes/noop_node.py | 4 +- .../nodes/patch_normalization_node.py | 4 +- .../lang_graph/nodes/reset_messages_node.py | 4 +- .../nodes/run_existing_tests_node.py | 5 +- .../run_existing_tests_structure_node.py | 5 +- .../nodes/run_existing_tests_subgraph_node.py | 5 +- .../nodes/run_regression_tests_node.py | 4 +- .../run_regression_tests_structure_node.py | 5 +- .../run_regression_tests_subgraph_node.py | 5 +- .../lang_graph/nodes/update_container_node.py | 4 +- .../nodes/user_defined_build_node.py | 4 +- .../nodes/user_defined_test_node.py | 4 +- prometheus/neo4j/knowledge_graph_handler.py | 5 +- prometheus/tools/file_operation.py | 4 +- prometheus/tools/web_search.py | 14 +- prometheus/utils/logger_manager.py | 195 +++++++++++------- 71 files changed, 272 insertions(+), 248 deletions(-) diff --git a/prometheus/app/main.py b/prometheus/app/main.py index a25877dd..16cfa09c 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -1,6 +1,8 @@ import inspect from contextlib import asynccontextmanager from datetime import datetime, timezone +from pathlib import Path +import threading from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -15,9 +17,11 @@ register_login_required_routes, ) from prometheus.configuration.config import settings -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger, remove_multi_threads_log_file_handler -logger = get_logger(__name__) + +# Create main thread logger with file handler - ONE LINE! +logger, file_handler = get_thread_logger(__name__) @asynccontextmanager diff --git a/prometheus/app/services/database_service.py b/prometheus/app/services/database_service.py index 1280f665..b261cf75 100644 --- a/prometheus/app/services/database_service.py +++ b/prometheus/app/services/database_service.py @@ -2,13 +2,13 @@ from sqlmodel import SQLModel from prometheus.app.services.base_service import BaseService -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class DatabaseService(BaseService): def __init__(self, DATABASE_URL: str): self.engine = create_async_engine(DATABASE_URL, echo=True) - self._logger = get_logger(__name__) + self._logger, file_handler = get_thread_logger(__name__) # Create the database and tables async def create_db_and_tables(self): diff --git a/prometheus/app/services/invitation_code_service.py b/prometheus/app/services/invitation_code_service.py index 6cd0ab07..e0962fea 100644 --- a/prometheus/app/services/invitation_code_service.py +++ b/prometheus/app/services/invitation_code_service.py @@ -9,13 +9,14 @@ from prometheus.app.entity.invitation_code import InvitationCode from prometheus.app.services.base_service import BaseService from prometheus.app.services.database_service import DatabaseService +from prometheus.utils.logger_manager import get_thread_logger class InvitationCodeService(BaseService): def __init__(self, database_service: DatabaseService): self.database_service = database_service self.engine = database_service.engine - self._logger = logging.getLogger("prometheus.app.services.invitation_code_service") + self._logger, file_handler = get_thread_logger(__name__) async def create_invitation_code(self) -> InvitationCode: """ diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index 522a8a77..bc2f1286 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -13,6 +13,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_graph import IssueGraph from prometheus.lang_graph.graphs.issue_state import IssueType +from prometheus.utils.logger_manager import get_thread_logger, remove_multi_threads_log_file_handler class IssueService(BaseService): @@ -79,15 +80,8 @@ def answer_issue( - issue_type (IssueType): The type of the issue (BUG or QUESTION). """ - # Set up a dedicated logger for this thread - logger = logging.getLogger(f"thread-{threading.get_ident()}.prometheus") - logger.setLevel(getattr(logging, self.logging_level)) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = self.answer_issue_log_dir / f"{timestamp}_{threading.get_ident()}.log" - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # Create thread-specific logger with file handler - ONE LINE! + logger, file_handler = get_thread_logger(__name__) # Construct the working directory if dockerfile_content or image_name: @@ -141,5 +135,5 @@ def answer_issue( logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") return None, False, False, False, None, None finally: - logger.removeHandler(file_handler) - file_handler.close() + # Remove multi-thread file handler + remove_multi_threads_log_file_handler(file_handler, logger.name) diff --git a/prometheus/app/services/knowledge_graph_service.py b/prometheus/app/services/knowledge_graph_service.py index e893dbc2..fe251a75 100644 --- a/prometheus/app/services/knowledge_graph_service.py +++ b/prometheus/app/services/knowledge_graph_service.py @@ -8,6 +8,7 @@ from prometheus.app.services.neo4j_service import Neo4jService from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.neo4j import knowledge_graph_handler +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraphService(BaseService): @@ -42,7 +43,7 @@ def __init__( self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.writing_lock = asyncio.Lock() - self._logger = logging.getLogger("prometheus.app.services.knowledge_graph_service") + self._logger, file_handler = get_thread_logger(__name__) async def start(self): # Initialize the Neo4j database for Knowledge Graph operations diff --git a/prometheus/app/services/neo4j_service.py b/prometheus/app/services/neo4j_service.py index 14b5312c..5a5a714c 100644 --- a/prometheus/app/services/neo4j_service.py +++ b/prometheus/app/services/neo4j_service.py @@ -5,11 +5,12 @@ from neo4j import AsyncGraphDatabase from prometheus.app.services.base_service import BaseService +from prometheus.utils.logger_manager import get_thread_logger class Neo4jService(BaseService): def __init__(self, neo4j_uri: str, neo4j_username: str, neo4j_password: str): - self._logger = logging.getLogger("prometheus.app.services.neo4j_service") + self._logger, file_handler = get_thread_logger(__name__) self.neo4j_driver = AsyncGraphDatabase.driver( neo4j_uri, auth=(neo4j_username, neo4j_password), diff --git a/prometheus/app/services/user_service.py b/prometheus/app/services/user_service.py index bfa7cc36..445a73ea 100644 --- a/prometheus/app/services/user_service.py +++ b/prometheus/app/services/user_service.py @@ -11,13 +11,14 @@ from prometheus.app.services.database_service import DatabaseService from prometheus.exceptions.server_exception import ServerException from prometheus.utils.jwt_utils import JWTUtils +from prometheus.utils.logger_manager import get_thread_logger class UserService(BaseService): def __init__(self, database_service: DatabaseService): self.database_service = database_service self.engine = database_service.engine - self._logger = logging.getLogger("prometheus.app.services.user_service") + self._logger, file_handler = get_thread_logger(__name__) self.ph = PasswordHasher() self.jwt_utils = JWTUtils() diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index 03c9809d..1b567881 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -9,7 +9,7 @@ import docker -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BaseContainer(ABC): @@ -43,9 +43,7 @@ def __init__( Args: project_path: Path to the project directory to be containerized. """ - self._logger = get_logger( - f"thread-{threading.get_ident()}.{self.__class__.__module__}.{self.__class__.__name__}" - ) + self._logger, file_handler = get_thread_logger(__name__) temp_dir = Path(tempfile.mkdtemp()) temp_project_path = temp_dir / project_path.name shutil.copytree(project_path, temp_project_path) diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index 5a96413e..f6f15ef5 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -8,7 +8,7 @@ from git import Git, GitCommandError, InvalidGitRepositoryError, Repo -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class GitRepository: @@ -24,12 +24,12 @@ def __init__(self): """ Initialize a GitRepository instance. """ - self._logger = get_logger(__name__) + self._logger, file_handler = get_thread_logger(__name__) # Configure git command to use our logger g = Git() type(g).GIT_PYTHON_TRACE = "full" - git_cmd_logger = get_logger("git.cmd") + git_cmd_logger, file_handler = get_thread_logger("git.cmd") # Ensure git command output goes to our logger for handler in git_cmd_logger.handlers: diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index 9c196a0e..56828446 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -43,8 +43,7 @@ Neo4jTextNode, TextNode, ) - -# from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraph: @@ -81,7 +80,7 @@ def __init__( self._next_node_id = root_node_id + len(self._knowledge_graph_nodes) self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) - self._logger = logging.getLogger(__name__) + self._logger, file_handler = get_thread_logger(__name__) async def build_graph(self, root_dir: Path): """Asynchronously builds knowledge graph for a codebase at a location. diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 507fd4d7..28d310d1 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -7,7 +7,7 @@ from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.bug_fix_verification_subgraph import BugFixVerificationSubgraph from prometheus.lang_graph.subgraphs.issue_verified_bug_state import IssueVerifiedBugState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerificationSubgraphNode: @@ -17,7 +17,7 @@ def __init__( container: BaseContainer, git_repo: GitRepository, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.subgraph = BugFixVerificationSubgraph( model=model, diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index fcff00b5..420a0e4b 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -8,7 +8,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState from prometheus.tools.container_command import ContainerCommandTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerifyNode: @@ -55,7 +55,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer): self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index 80f8052f..1bc5c6cf 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -6,7 +6,7 @@ from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState from prometheus.utils.lang_graph_util import get_last_message_content -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerifyStructureOutput(BaseModel): @@ -91,7 +91,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugFixVerificationState): bug_fix_verify_message = get_last_message_content(state["bug_fix_verify_messages"]) diff --git a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py index 225ca3b6..cef875aa 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py @@ -5,6 +5,7 @@ BugGetRegressionTestsState, ) from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class BugGetRegressionContextMessageNode: @@ -86,9 +87,7 @@ def test_offset_preserved(self): """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugGetRegressionTestsState): select_regression_query = self.SELECT_REGRESSION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py index 8fa8a8a1..f557bccb 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py @@ -9,6 +9,7 @@ BugGetRegressionTestsState, ) from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class RegressionTestStructuredOutPut(BaseModel): @@ -91,9 +92,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RegressionTestsStructuredOutPut) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_tests_selection_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: BugGetRegressionTestsState): return self.HUMAN_PROMPT.format( diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py index 24c1b638..9c0879b3 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py @@ -11,6 +11,7 @@ from prometheus.lang_graph.subgraphs.bug_get_regression_tests_subgraph import ( BugGetRegressionTestsSubgraph, ) +from prometheus.utils.logger_manager import get_thread_logger class BugGetRegressionTestsSubgraphNode: @@ -22,9 +23,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = BugGetRegressionTestsSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index 38727ec5..afd00799 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -11,7 +11,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.issue_util import format_test_commands -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.patch_util import get_updated_files @@ -57,7 +57,7 @@ def __init__( self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 684f077a..9d6628ba 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -9,7 +9,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.lang_graph_util import get_last_message_content -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingFileNode: @@ -42,7 +42,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph, local_path: str): self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """Initializes file operation tools.""" diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index cb2ca68f..3c1ecada 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -11,7 +11,7 @@ format_agent_tool_message_history, get_last_message_content, ) -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingStructuredOutput(BaseModel): @@ -136,7 +136,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugReproducingStructuredOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_log = format_agent_tool_message_history( diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index f2dd20cd..c10cbc25 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -4,7 +4,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingWriteMessageNode: @@ -25,7 +25,7 @@ class BugReproducingWriteMessageNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: BugReproductionState): if "reproduced_bug_failure_log" in state and state["reproduced_bug_failure_log"]: diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 9cd66b92..67c45f27 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -8,7 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.tools.file_operation import FileOperationTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingWriteNode: @@ -117,7 +117,7 @@ def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): self.tools = self._init_tools() self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 20519e0f..7827e42a 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -9,7 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_subgraph import BugReproductionSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BugReproductionSubgraphNode: @@ -22,7 +22,7 @@ def __init__( git_repo: GitRepository, test_commands: Optional[Sequence[str]], ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.bug_reproduction_subgraph = BugReproductionSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index f6f1f380..d7ee56ae 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -7,7 +7,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BuildAndTestSubgraphNode: @@ -26,7 +26,7 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueBugState): exist_build = None diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index 0d8bd0d9..eb04ece4 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -14,7 +14,7 @@ extract_last_tool_messages, transform_tool_messages_to_str, ) -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger SYS_PROMPT = """\ You are a context summary agent that summarizes code contexts which is relevant to a given query. @@ -115,7 +115,7 @@ def __init__(self, model: BaseChatModel, root_path: str): structured_llm = model.with_structured_output(ContextExtractionStructuredOutput) self.model = prompt | structured_llm self.root_path = root_path - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: ContextRetrievalState): """ diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index b90cb798..44cef057 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -16,7 +16,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.tools.file_operation import FileOperationTool from prometheus.tools.graph_traversal import GraphTraversalTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class ContextProviderNode: @@ -119,7 +119,7 @@ def __init__( ) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """ diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index e6bc476a..d1b77a4c 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -3,12 +3,12 @@ from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class ContextQueryMessageNode: def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: ContextRetrievalState): human_message = HumanMessage(state["query"]) diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 923f8b34..8a62cdff 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -7,7 +7,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class ContextRefineStructuredOutput(BaseModel): @@ -90,7 +90,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): ) structured_llm = model.with_structured_output(ContextRefineStructuredOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_refine_message(self, state: ContextRetrievalState): original_query = state["query"] diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index c9044f19..3105f798 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -7,7 +7,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_subgraph import ContextRetrievalSubgraph from prometheus.models.context import Context -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class ContextRetrievalSubgraphNode: @@ -19,7 +19,7 @@ def __init__( query_key_name: str, context_key_name: str, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index e7182501..97a58052 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -5,7 +5,7 @@ from prometheus.utils.issue_util import format_issue_info from prometheus.utils.lang_graph_util import get_last_message_content -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class EditMessageNode: @@ -43,7 +43,7 @@ class EditMessageNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index 9d70ad77..b46593e5 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -16,7 +16,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.tools.file_operation import FileOperationTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class EditNode: @@ -124,7 +124,7 @@ def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): self.file_operation_tool = FileOperationTool(local_path, kg) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 3d361e53..e7a0fa52 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -7,7 +7,7 @@ from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class FinalPatchSelectionStructuredOutput(BaseModel): @@ -128,7 +128,7 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): ) structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.majority_voting_times = 10 def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState): diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index 30ff5405..db6a774c 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -9,7 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.tools.container_command import ContainerCommandTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class GeneralBuildNode: @@ -48,7 +48,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): tools = [] diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index e1b10f2d..c2c60f8e 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -14,7 +14,7 @@ from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class BuildStructuredOutput(BaseModel): @@ -238,7 +238,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BuildStructuredOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes build state to generate structured build analysis. diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index 9b207ab5..3f3f0adf 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -9,7 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.tools.container_command import ContainerCommandTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class GeneralTestNode: @@ -65,7 +65,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): tools = [] diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 6b4dc55e..6c2156c9 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -14,7 +14,7 @@ from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class TestStructuredOutput(BaseModel): @@ -287,7 +287,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(TestStructuredOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes test state to generate structured test analysis. diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py index 728b7918..f8edc06c 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py @@ -6,6 +6,7 @@ GetPassRegressionTestPatchState, ) from prometheus.models.test_patch_result import TestedPatchResult +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchCheckResultNode: @@ -14,10 +15,7 @@ class GetPassRegressionTestPatchCheckResultNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes." - f"get_pass_regression_test_patch_check_result_node" - ) + self._logger, file_handler = get_thread_logger(__name__ + f"get_pass_regression_test_patch_check_result_node") def __call__(self, state: GetPassRegressionTestPatchState): """ diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py index 63d56579..f72392a7 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py @@ -11,7 +11,7 @@ GetPassRegressionTestPatchSubgraph, ) from prometheus.models.test_patch_result import TestedPatchResult - +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchSubgraphNode: def __init__( @@ -22,9 +22,7 @@ def __init__( testing_patch_key: str, is_testing_patch_list: bool, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.get_pass_regression_test_patch_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = GetPassRegressionTestPatchSubgraph( base_model=model, container=container, diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py index caf04160..095b5add 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py @@ -5,6 +5,7 @@ from prometheus.lang_graph.subgraphs.get_pass_regression_test_patch_state import ( GetPassRegressionTestPatchState, ) +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchUpdateNode: @@ -17,9 +18,7 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.get_pass_regression_test_patch_update_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: GetPassRegressionTestPatchState): """ diff --git a/prometheus/lang_graph/nodes/git_apply_patch_node.py b/prometheus/lang_graph/nodes/git_apply_patch_node.py index 242795d3..f73f03d3 100644 --- a/prometheus/lang_graph/nodes/git_apply_patch_node.py +++ b/prometheus/lang_graph/nodes/git_apply_patch_node.py @@ -3,6 +3,7 @@ from typing import Dict from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_thread_logger class GitApplyPatchNode: @@ -13,9 +14,7 @@ class GitApplyPatchNode: def __init__(self, git_repo: GitRepository, state_patch_name: str): self.git_repo = git_repo self.state_patch_name = state_patch_name - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_apply_patch_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """ diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index 6f9b0da0..87e42819 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -10,7 +10,7 @@ from typing import Dict, Optional from prometheus.git.git_repository import GitRepository -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class GitDiffNode: @@ -33,7 +33,7 @@ def __init__( self.state_patch_name = state_patch_name self.state_excluded_files_key = state_excluded_files_key self.return_list = return_list - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """Generates a Git diff for the current project state. diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index 5bda5375..77d74220 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,7 +1,7 @@ import threading from prometheus.git.git_repository import GitRepository -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class GitResetNode: @@ -10,7 +10,7 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _): self._logger.debug("Resetting the git repository") diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index d6ecae23..34b9f193 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -4,7 +4,7 @@ from langchain_core.messages import HumanMessage from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugAnalyzerMessageNode: @@ -65,7 +65,7 @@ class IssueBugAnalyzerMessageNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index ae568bc0..f6fdd728 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -7,7 +7,7 @@ from langchain_core.messages import SystemMessage from prometheus.tools.web_search import WebSearchTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugAnalyzerNode: @@ -71,7 +71,7 @@ def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """Initializes tools for the node.""" diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index f7d09aee..21edbb9e 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -2,7 +2,7 @@ from typing import Dict from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugContextMessageNode: @@ -20,7 +20,7 @@ class IssueBugContextMessageNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): bug_fix_query = self.BUG_FIX_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index 3d53d250..86bdfb66 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -2,7 +2,7 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugReproductionContextMessageNode: @@ -109,7 +109,7 @@ def test_file_permission_denied(self, mock_open, mock_access): """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_query = self.BUG_REPRODUCING_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index d980772a..083fa69b 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -5,7 +5,7 @@ from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugResponderNode: @@ -49,7 +49,7 @@ def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: IssueBugState) -> HumanMessage: verification_messages = [] diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index 4b550bdd..86446af6 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -9,7 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_bug_subgraph import IssueBugSubgraph -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueBugSubgraphNode: @@ -26,7 +26,7 @@ def __init__( git_repo: GitRepository, test_commands: Optional[Sequence[str]] = None, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.container = container self.issue_bug_subgraph = IssueBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index 5a6aeee3..7f82823c 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -2,7 +2,7 @@ from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueClassificationContextMessageNode: @@ -73,7 +73,7 @@ class IssueClassificationContextMessageNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueClassificationState): issue_classification_query = self.ISSUE_CLASSIFICATION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 3f2cdc06..2bf941c6 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -7,7 +7,7 @@ from prometheus.lang_graph.subgraphs.issue_classification_subgraph import ( IssueClassificationSubgraph, ) -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueClassificationSubgraphNode: @@ -17,7 +17,7 @@ def __init__( kg: KnowledgeGraph, local_path: str, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.issue_classification_subgraph = IssueClassificationSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index 6127a034..e792e420 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -7,7 +7,7 @@ from prometheus.lang_graph.graphs.issue_state import IssueType from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueClassifierOutput(BaseModel): @@ -125,7 +125,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(IssueClassifierOutput) self.model = prompt | structured_llm - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def format_context_info(self, state: IssueClassificationState) -> str: context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index 4fa5806c..12792837 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -10,7 +10,7 @@ from prometheus.lang_graph.subgraphs.issue_not_verified_bug_subgraph import ( IssueNotVerifiedBugSubgraph, ) -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueNotVerifiedBugSubgraphNode: @@ -22,7 +22,7 @@ def __init__( git_repo: GitRepository, container: BaseContainer, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py index 21bf01dd..4dc2e1dc 100644 --- a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py @@ -6,6 +6,7 @@ from prometheus.lang_graph.subgraphs.issue_question_state import IssueQuestionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionAnalyzerNode: @@ -46,9 +47,7 @@ class IssueQuestionAnalyzerNode: def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_analyzer_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueQuestionState): human_prompt = HumanMessage( diff --git a/prometheus/lang_graph/nodes/issue_question_context_message_node.py b/prometheus/lang_graph/nodes/issue_question_context_message_node.py index b9546398..9bd02f12 100644 --- a/prometheus/lang_graph/nodes/issue_question_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_question_context_message_node.py @@ -3,6 +3,7 @@ from typing import Dict from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionContextMessageNode: @@ -21,9 +22,7 @@ class IssueQuestionContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): question_query = self.QUESTION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py index 4bcaf0d2..4cb71c7b 100644 --- a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py @@ -8,6 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionSubgraphNode: @@ -23,9 +24,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.issue_question_subgraph = IssueQuestionSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index fccf8135..5a6f162a 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -8,7 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState from prometheus.lang_graph.subgraphs.issue_verified_bug_subgraph import IssueVerifiedBugSubgraph -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class IssueVerifiedBugSubgraphNode: @@ -24,7 +24,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index c5878235..b04c9928 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -8,7 +8,7 @@ import threading from typing import Dict -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class NoopNode: @@ -21,7 +21,7 @@ class NoopNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict) -> None: """Routes the workflow without performing any operations. diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 47a781a6..799e2866 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -11,7 +11,7 @@ from typing import Dict, List, Sequence from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger @dataclass @@ -39,7 +39,7 @@ class PatchNormalizationNode: """ def __init__(self): - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def normalize_patch(self, raw_patch: str) -> str: """Normalize patch content for deduplication diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index a58d8422..ba8f9b08 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -13,7 +13,7 @@ import threading from typing import Dict -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class ResetMessagesNode: @@ -37,7 +37,7 @@ def __init__(self, message_state_key: str): be reset during node execution. """ self.message_state_key = message_state_key - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """Resets the specified message state for the next iteration. diff --git a/prometheus/lang_graph/nodes/run_existing_tests_node.py b/prometheus/lang_graph/nodes/run_existing_tests_node.py index 1cddef6a..d78e8d41 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_node.py @@ -3,6 +3,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_existing_tests_state import RunExistingTestsState +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsNode: @@ -11,9 +12,7 @@ class RunExistingTestsNode: """ def __init__(self, container: BaseContainer): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.container = container def __call__(self, state: RunExistingTestsState): diff --git a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py index 5cbbe094..ffbcaa04 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from prometheus.lang_graph.subgraphs.run_existing_tests_state import RunExistingTestsState +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsStructureOutput(BaseModel): @@ -55,9 +56,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RunExistingTestsStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_existing_tests_structure_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: RunExistingTestsState): # Get human message from the state diff --git a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py index ac4b0bd0..4c8fb3bf 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py @@ -7,6 +7,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.run_existing_tests_subgraph import RunExistingTestsSubgraph +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsSubgraphNode: @@ -18,9 +19,7 @@ def __init__( testing_patch_key: str, existing_test_fail_log_key: str, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_existing_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = RunExistingTestsSubgraph( base_model=model, container=container, git_repo=git_repo ) diff --git a/prometheus/lang_graph/nodes/run_regression_tests_node.py b/prometheus/lang_graph/nodes/run_regression_tests_node.py index 18ddfcb4..6c461db7 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_node.py @@ -8,7 +8,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_regression_tests_state import RunRegressionTestsState from prometheus.tools.container_command import ContainerCommandTool -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsNode: @@ -59,7 +59,7 @@ def __init__(self, model: BaseChatModel, container: BaseContainer): self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): tools = [] diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index ad1da1b5..97d8afb7 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -8,6 +8,7 @@ from prometheus.lang_graph.subgraphs.run_regression_tests_state import RunRegressionTestsState from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsStructureOutput(BaseModel): @@ -104,9 +105,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RunRegressionTestsStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_structure_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def get_human_message(self, state: RunRegressionTestsState) -> str: # Format the human message using the state diff --git a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py index d07f7b6d..cadddf6f 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py @@ -7,15 +7,14 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_regression_tests_subgraph import RunRegressionTestsSubgraph +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsSubgraphNode: def __init__( self, model: BaseChatModel, container: BaseContainer, passed_regression_tests_key: str ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = RunRegressionTestsSubgraph( base_model=model, container=container, diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index c1dacb6d..38625e46 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -11,7 +11,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.patch_util import get_updated_files @@ -34,7 +34,7 @@ def __init__(self, container: BaseContainer, git_repo: GitRepository): """ self.container = container self.git_repo = git_repo - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Dict): """Synchronizes the current project state with the container.""" diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index 660b5373..d8283220 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -5,13 +5,13 @@ from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class UserDefinedBuildNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Any): build_output = self.container.run_build() diff --git a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index c3608864..309ae163 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -5,13 +5,13 @@ from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class UserDefinedTestNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = get_logger(f"thread-{threading.get_ident()}.{__name__}") + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Any): test_output = self.container.run_test() diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index de627253..b572dd34 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -16,8 +16,7 @@ Neo4jTextNode, ) from prometheus.graph.knowledge_graph import KnowledgeGraph - -# from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraphHandler: @@ -32,7 +31,7 @@ def __init__(self, driver: AsyncGraphDatabase.driver, batch_size: int): self.driver = driver self.batch_size = batch_size # initialize the database and logger - self._logger = logging.getLogger(__name__) + self._logger, file_handler = get_thread_logger(__name__) async def init_database(self): """Initialization of the neo4j database.""" diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index f9cee6df..c1057cea 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -8,10 +8,10 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.knowledge_graph_utils import format_knowledge_graph_data -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.str_util import pre_append_line_numbers -logger = get_logger(__name__) +logger, file_handler = get_thread_logger(__name__) @dataclass diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 1b86d833..0bb7fb40 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -5,8 +5,7 @@ from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError from prometheus.configuration.config import settings -from prometheus.exceptions.web_search_tool_exception import WebSearchToolException -from prometheus.utils.logger_manager import get_logger +from prometheus.utils.logger_manager import get_thread_logger @dataclass @@ -75,8 +74,15 @@ class WebSearchTool: def __init__(self): """Initialize the web search tool.""" # Load environment variables from .env file - self._logger = get_logger(__name__) - self.tavily_client = TavilyClient(api_key=settings.TAVILY_API_KEY) + self._logger, file_handler = get_thread_logger(__name__) + + tavily_api_key = settings.TAVILY_API_KEY + if tavily_api_key is None: + self._logger.warning("Tavily API key is not set") + tavily_client = None + else: + tavily_client = TavilyClient(api_key=tavily_api_key) + self.tavily_client = tavily_client def web_search( self, diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index b4c98f6b..1b8064aa 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -91,10 +91,13 @@ def __new__(cls) -> "LoggerManager": def __init__(self): """Initialize logger manager""" + self.log_level = getattr(settings, "LOGGING_LEVEL") + self.issue_log_dir = Path(getattr(settings, "WORKING_DIRECTORY")) / "answer_issue_logs" if not self._initialized: self._setup_root_logger() self._initialized = True + def _setup_root_logger(self): """Setup root logger""" # Get root logger @@ -104,25 +107,19 @@ def _setup_root_logger(self): self.root_logger.handlers.clear() # Set log level - log_level = getattr(settings, "LOGGING_LEVEL", "DEBUG") - self.root_logger.setLevel(getattr(logging, log_level)) + self.root_logger.setLevel(getattr(logging, self.log_level)) # Create colored formatter for console output - self.colored_formatter = ColoredFormatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - + self.colored_formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") # Create plain formatter for file output - self.file_formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + self.file_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + + # Create console handler (using colored formatter) console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(self.colored_formatter) - console_handler.setLevel( - getattr(logging, log_level) - ) # Ensure console handler uses same level + console_handler.setLevel(getattr(logging, self.log_level)) # Ensure console handler uses same level self.root_logger.addHandler(console_handler) # Prevent log propagation to parent logger @@ -131,6 +128,39 @@ def _setup_root_logger(self): # Log configuration information self._log_configuration() + def _set_multi_threads_log_file_handler(self, thread_id: int, logger_name: str): + """Set multi threads log file handler""" + # Find existing log file for this thread_id, or create new one if none exists + log_file_path = self._find_or_create_log_file(thread_id) + file_handler = self.create_file_handler(log_file_path, logger_name) + return file_handler + + def _find_or_create_log_file(self, thread_id: int) -> Path: + """ + Find existing log file for the thread_id, or create new one if none exists + + Args: + thread_id: Thread ID to find/create log file for + + Returns: + Path to the log file (existing earliest one or newly created) + """ + import glob + + # Pattern to match log files for this thread_id + pattern = str(self.issue_log_dir / f"*_{thread_id}.log") + existing_logs = glob.glob(pattern) + + if existing_logs: + # Sort by filename (which includes timestamp) to get chronological order + existing_logs.sort() + # Return the earliest (first) file + return Path(existing_logs[0]) + else: + # No existing log file found, create a new one + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return self.issue_log_dir / f"{timestamp}_{thread_id}.log" + def _log_configuration(self): """Log configuration information""" # 动态获取settings中所有可用的配置属性 @@ -163,9 +193,9 @@ def get_logger(self, name: str) -> logging.Logger: Returns: Configured logger instance """ - # Ensure logger name starts with prometheus - if not name.startswith("prometheus"): - name = f"prometheus.{name}" + # # Ensure logger name starts with prometheus + # if not name.startswith("prometheus"): + # name = f"prometheus.{name}" logger = logging.getLogger(name) @@ -176,9 +206,8 @@ def get_logger(self, name: str) -> logging.Logger: return logger - def create_file_handler( - self, log_file_path: Path, logger_name: str = "prometheus" - ) -> logging.FileHandler: + + def create_file_handler(self, log_file_path: Path, logger_name: str) -> logging.FileHandler: """ Create file handler for specified logger @@ -191,38 +220,31 @@ def create_file_handler( """ # Ensure log directory exists log_file_path.parent.mkdir(parents=True, exist_ok=True) - - # Create file handler (using plain formatter, without colors) - file_handler = logging.FileHandler(log_file_path) + + # Create file handler with append mode to preserve existing content + file_handler = logging.FileHandler(log_file_path, mode='a') + file_handler.setLevel(getattr(logging, self.log_level)) file_handler.setFormatter(self.file_formatter) - - # Ensure file handler uses the same log level as console handler - log_level = getattr(settings, "LOGGING_LEVEL", "DEBUG") - file_handler.setLevel(getattr(logging, log_level)) - - # Get logger and add handler - logger = self.get_logger(logger_name) - logger.addHandler(file_handler) - + + # # Get logger directly without going through get_logger to avoid recursion + # # Ensure logger name starts with prometheus + # if not logger_name.startswith("prometheus"): + # logger_name = f"prometheus.{logger_name}" + + logger = logging.getLogger(logger_name) + + # If it's a child logger, inherit root logger configuration + if logger_name != "prometheus": + logger.parent = self.root_logger + logger.propagate = True + + # Check if this logger already has a file handler to avoid duplicates + has_file_handler = any(isinstance(h, logging.FileHandler) for h in logger.handlers) + if not has_file_handler: + logger.addHandler(file_handler) + return file_handler - def create_timestamped_file_handler( - self, log_dir: Path, prefix: str = "prometheus", logger_name: str = "prometheus" - ) -> logging.FileHandler: - """ - Create file handler with timestamp - - Args: - log_dir: Log directory - prefix: Log file prefix - logger_name: Logger name - - Returns: - Configured file handler - """ - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = log_dir / f"{prefix}_{timestamp}.log" - return self.create_file_handler(log_file, logger_name) def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = "prometheus"): """ @@ -236,17 +258,23 @@ def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = " logger.removeHandler(handler) handler.close() - def enable_colors(self): - """Enable colored log output""" - self.colored_formatter.use_colors = True and self.colored_formatter._supports_color() + def remove_multi_thread_file_handler(self, handler: logging.FileHandler, logger_name: str = None): + """ + Remove multi-thread file handler from specific logger - def disable_colors(self): - """Disable colored log output""" - self.colored_formatter.use_colors = False + Args: + handler: File handler to remove + logger_name: Logger name to remove handler from + """ + if logger_name: + logger = self.get_logger(logger_name) + logger.removeHandler(handler) + else: + # Fallback: try to remove from root logger + self.root_logger.removeHandler(handler) + handler.close() - def is_colors_enabled(self) -> bool: - """Check if colored output is enabled""" - return self.colored_formatter.use_colors + # Create global logger manager instance @@ -264,40 +292,49 @@ def get_logger(name: str) -> logging.Logger: Configured logger instance Examples: - >>> logger = get_logger(__name__) - >>> logger = get_logger("prometheus.tools.web_search") + >>> logger, file_handler = get_thread_logger(__name__) + >>> logger, file_handler = get_thread_logger("prometheus.tools.web_search") """ return logger_manager.get_logger(name) -def create_file_handler( - log_file_path: Path, logger_name: str = "prometheus" -) -> logging.FileHandler: + + +def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_name: str = None): """ - Convenience function to create file handler + Convenience function to remove multi-thread file handler Args: - log_file_path: Log file path - logger_name: Logger name - - Returns: - Configured file handler + handler: File handler to remove + logger_name: Logger name (optional) """ - return logger_manager.create_file_handler(log_file_path, logger_name) + logger_manager.remove_multi_thread_file_handler(handler, logger_name) -def create_timestamped_file_handler( - log_dir: Path, prefix: str = "prometheus", logger_name: str = "prometheus" -) -> logging.FileHandler: +def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHandler]: """ - Convenience function to create timestamped file handler - + Convenience function to create a thread-specific logger with file handler in one call + Args: - log_dir: Log directory - prefix: Log file prefix - logger_name: Logger name - + module_name: Module name (usually __name__), if None, uses current module + Returns: - Configured file handler + Tuple of (logger, file_handler) for easy cleanup + + Examples: + >>> logger, file_handler = get_thread_logger(__name__) + >>> logger.info("This goes to both console and file") + >>> # In finally block: + >>> remove_multi_threads_log_file_handler(file_handler, logger.name) """ - return logger_manager.create_timestamped_file_handler(log_dir, prefix, logger_name) + import threading + + # Get thread ID + thread_id = threading.get_ident() + logger_name = f"thread-{thread_id}.{module_name}" + + # Create file handler and logger + file_handler = logger_manager._set_multi_threads_log_file_handler(thread_id, logger_name) + logger = get_logger(logger_name) + return logger, file_handler + From 591f18ab12d3b0a9516bdbad760afd1f0ea8d6db Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Sat, 6 Sep 2025 00:51:09 +0800 Subject: [PATCH 25/30] fix: Remove unused threading imports across multiple files --- prometheus/app/main.py | 5 +- .../app/services/invitation_code_service.py | 1 - prometheus/app/services/issue_service.py | 3 - .../app/services/knowledge_graph_service.py | 1 - prometheus/app/services/neo4j_service.py | 2 - prometheus/app/services/user_service.py | 1 - prometheus/docker/base_container.py | 1 - prometheus/graph/knowledge_graph.py | 1 - .../bug_fix_verification_subgraph_node.py | 2 - .../lang_graph/nodes/bug_fix_verify_node.py | 1 - .../nodes/bug_fix_verify_structured_node.py | 2 - ...bug_get_regression_context_message_node.py | 3 - ...bug_get_regression_tests_selection_node.py | 3 - .../bug_get_regression_tests_subgraph_node.py | 4 +- .../nodes/bug_reproducing_execute_node.py | 1 - .../nodes/bug_reproducing_file_node.py | 1 - .../nodes/bug_reproducing_structured_node.py | 1 - .../bug_reproducing_write_message_node.py | 2 - .../nodes/bug_reproducing_write_node.py | 1 - .../nodes/bug_reproduction_subgraph_node.py | 1 - .../nodes/build_and_test_subgraph_node.py | 1 - .../nodes/context_extraction_node.py | 1 - .../lang_graph/nodes/context_provider_node.py | 1 - .../nodes/context_query_message_node.py | 2 - .../lang_graph/nodes/context_refine_node.py | 2 - .../nodes/context_retrieval_subgraph_node.py | 1 - .../lang_graph/nodes/edit_message_node.py | 1 - prometheus/lang_graph/nodes/edit_node.py | 1 - .../nodes/final_patch_selection_node.py | 1 - .../lang_graph/nodes/general_build_node.py | 1 - .../nodes/general_build_structured_node.py | 2 - .../lang_graph/nodes/general_test_node.py | 1 - .../nodes/general_test_structured_node.py | 2 - ...regression_test_patch_check_result_node.py | 6 +- ...ass_regression_test_patch_subgraph_node.py | 3 +- ..._pass_regression_test_patch_update_node.py | 3 - .../lang_graph/nodes/git_apply_patch_node.py | 2 - prometheus/lang_graph/nodes/git_diff_node.py | 1 - prometheus/lang_graph/nodes/git_reset_node.py | 2 - .../nodes/issue_bug_analyzer_message_node.py | 1 - .../nodes/issue_bug_analyzer_node.py | 1 - .../nodes/issue_bug_context_message_node.py | 1 - ...e_bug_reproduction_context_message_node.py | 2 - .../nodes/issue_bug_responder_node.py | 2 - .../nodes/issue_bug_subgraph_node.py | 1 - ...sue_classification_context_message_node.py | 2 - .../issue_classification_subgraph_node.py | 2 - .../lang_graph/nodes/issue_classifier_node.py | 2 - .../issue_not_verified_bug_subgraph_node.py | 1 - .../nodes/issue_question_analyzer_node.py | 3 - .../issue_question_context_message_node.py | 2 - .../nodes/issue_question_subgraph_node.py | 3 - .../nodes/issue_verified_bug_subgraph_node.py | 2 - prometheus/lang_graph/nodes/noop_node.py | 1 - .../nodes/patch_normalization_node.py | 1 - .../lang_graph/nodes/reset_messages_node.py | 1 - .../nodes/run_existing_tests_node.py | 3 - .../run_existing_tests_structure_node.py | 3 - .../nodes/run_existing_tests_subgraph_node.py | 2 - .../nodes/run_regression_tests_node.py | 1 - .../run_regression_tests_structure_node.py | 4 +- .../run_regression_tests_subgraph_node.py | 2 - .../lang_graph/nodes/update_container_node.py | 1 - .../nodes/user_defined_build_node.py | 1 - .../nodes/user_defined_test_node.py | 1 - prometheus/neo4j/knowledge_graph_handler.py | 1 - prometheus/tools/web_search.py | 1 + prometheus/utils/logger_manager.py | 58 +++++++++---------- 68 files changed, 36 insertions(+), 141 deletions(-) diff --git a/prometheus/app/main.py b/prometheus/app/main.py index 16cfa09c..f59dfa2b 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -1,8 +1,6 @@ import inspect from contextlib import asynccontextmanager from datetime import datetime, timezone -from pathlib import Path -import threading from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -17,8 +15,7 @@ register_login_required_routes, ) from prometheus.configuration.config import settings -from prometheus.utils.logger_manager import get_thread_logger, remove_multi_threads_log_file_handler - +from prometheus.utils.logger_manager import get_thread_logger # Create main thread logger with file handler - ONE LINE! logger, file_handler = get_thread_logger(__name__) diff --git a/prometheus/app/services/invitation_code_service.py b/prometheus/app/services/invitation_code_service.py index e0962fea..cc3f239f 100644 --- a/prometheus/app/services/invitation_code_service.py +++ b/prometheus/app/services/invitation_code_service.py @@ -1,5 +1,4 @@ import datetime -import logging import uuid from typing import Sequence diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index bc2f1286..4287ecb9 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -1,7 +1,4 @@ -import logging -import threading import traceback -from datetime import datetime from pathlib import Path from typing import Mapping, Optional, Sequence diff --git a/prometheus/app/services/knowledge_graph_service.py b/prometheus/app/services/knowledge_graph_service.py index fe251a75..ed52a128 100644 --- a/prometheus/app/services/knowledge_graph_service.py +++ b/prometheus/app/services/knowledge_graph_service.py @@ -1,7 +1,6 @@ """Service for managing and interacting with Knowledge Graphs in Neo4j.""" import asyncio -import logging from pathlib import Path from prometheus.app.services.base_service import BaseService diff --git a/prometheus/app/services/neo4j_service.py b/prometheus/app/services/neo4j_service.py index 5a5a714c..81cc05f7 100644 --- a/prometheus/app/services/neo4j_service.py +++ b/prometheus/app/services/neo4j_service.py @@ -1,7 +1,5 @@ """Service for managing Neo4j database driver.""" -import logging - from neo4j import AsyncGraphDatabase from prometheus.app.services.base_service import BaseService diff --git a/prometheus/app/services/user_service.py b/prometheus/app/services/user_service.py index 445a73ea..914bf5df 100644 --- a/prometheus/app/services/user_service.py +++ b/prometheus/app/services/user_service.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence from argon2 import PasswordHasher diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index 1b567881..8cc5c782 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -2,7 +2,6 @@ import shutil import tarfile import tempfile -import threading from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Sequence diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index 56828446..a3e14dc1 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -19,7 +19,6 @@ import asyncio import itertools -import logging from collections import defaultdict, deque from pathlib import Path from typing import Mapping, Optional, Sequence diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 28d310d1..8bdee470 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 420a0e4b..fdb7f07c 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index 1bc5c6cf..cb344b90 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field diff --git a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py index cef875aa..8a244195 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from prometheus.lang_graph.subgraphs.bug_get_regression_tests_state import ( BugGetRegressionTestsState, ) diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py index f557bccb..87e9d311 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py index 9c0879b3..4555419c 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -23,7 +21,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger, file_handler = get_thread_logger(__name__) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = BugGetRegressionTestsSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index afd00799..39eb7449 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -1,5 +1,4 @@ import functools -import threading from pathlib import Path from typing import Optional, Sequence diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 9d6628ba..ca515b18 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 3c1ecada..b0cae515 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -1,4 +1,3 @@ -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index c10cbc25..9aefe39e 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 67c45f27..bf5e3120 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 7827e42a..efe3086a 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -1,4 +1,3 @@ -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index d7ee56ae..be26ec47 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -1,4 +1,3 @@ -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index eb04ece4..3ef109d6 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -1,4 +1,3 @@ -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index 44cef057..221e0293 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -6,7 +6,6 @@ """ import functools -import threading from typing import Dict from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index d1b77a4c..a45d2511 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 8a62cdff..309ca73c 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 3105f798..7484c101 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -1,4 +1,3 @@ -import threading from typing import Dict, Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index 97a58052..e4b07aa8 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -1,4 +1,3 @@ -import threading from typing import Dict from langchain_core.messages import HumanMessage diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index b46593e5..d79d0897 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -7,7 +7,6 @@ """ import functools -import threading from typing import Dict from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index e7a0fa52..530fe90e 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,4 +1,3 @@ -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index db6a774c..779c913d 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index c2c60f8e..c6c7f8fd 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -6,8 +6,6 @@ identify any failures. """ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index 3f3f0adf..7aebf0ac 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 6c2156c9..31efbf72 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -6,8 +6,6 @@ identify any failures. """ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py index f8edc06c..f598936b 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py @@ -1,5 +1,3 @@ -import logging -import threading from collections import Counter from prometheus.lang_graph.subgraphs.get_pass_regression_test_patch_state import ( @@ -15,7 +13,9 @@ class GetPassRegressionTestPatchCheckResultNode: """ def __init__(self): - self._logger, file_handler = get_thread_logger(__name__ + f"get_pass_regression_test_patch_check_result_node") + self._logger, file_handler = get_thread_logger( + __name__ + "get_pass_regression_test_patch_check_result_node" + ) def __call__(self, state: GetPassRegressionTestPatchState): """ diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py index f72392a7..8c6d038e 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -13,6 +11,7 @@ from prometheus.models.test_patch_result import TestedPatchResult from prometheus.utils.logger_manager import get_thread_logger + class GetPassRegressionTestPatchSubgraphNode: def __init__( self, diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py index 095b5add..fe05d82f 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.get_pass_regression_test_patch_state import ( GetPassRegressionTestPatchState, diff --git a/prometheus/lang_graph/nodes/git_apply_patch_node.py b/prometheus/lang_graph/nodes/git_apply_patch_node.py index f73f03d3..75a397b7 100644 --- a/prometheus/lang_graph/nodes/git_apply_patch_node.py +++ b/prometheus/lang_graph/nodes/git_apply_patch_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from prometheus.git.git_repository import GitRepository diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index 87e42819..9a7ad2e6 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -6,7 +6,6 @@ output. """ -import threading from typing import Dict, Optional from prometheus.git.git_repository import GitRepository diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index 77d74220..c7b26d20 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,5 +1,3 @@ -import threading - from prometheus.git.git_repository import GitRepository from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index 34b9f193..0217832e 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -1,4 +1,3 @@ -import threading from typing import Dict from langchain_core.messages import HumanMessage diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index f6fdd728..393cf44f 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,5 +1,4 @@ import functools -import threading from typing import Dict from langchain.tools import StructuredTool diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index 21edbb9e..cdf3aac3 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -1,4 +1,3 @@ -import threading from typing import Dict from prometheus.utils.issue_util import format_issue_info diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index 86bdfb66..5f646566 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -1,5 +1,3 @@ -import threading - from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index 083fa69b..519fb445 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index 86446af6..70fa53f6 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index 7f82823c..39cb9ec8 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -1,5 +1,3 @@ -import threading - from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 2bf941c6..21bbd73b 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from prometheus.graph.knowledge_graph import KnowledgeGraph diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index e792e420..83dc5ed1 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index 12792837..17d913b6 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -1,4 +1,3 @@ -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py index 4dc2e1dc..98e69e38 100644 --- a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage diff --git a/prometheus/lang_graph/nodes/issue_question_context_message_node.py b/prometheus/lang_graph/nodes/issue_question_context_message_node.py index 9bd02f12..929ec8a0 100644 --- a/prometheus/lang_graph/nodes/issue_question_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_question_context_message_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from prometheus.utils.issue_util import format_issue_info diff --git a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py index 4cb71c7b..9b739d12 100644 --- a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 5a6f162a..99e400cc 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -1,5 +1,3 @@ -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index b04c9928..33f0b94c 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -5,7 +5,6 @@ node graphs where a connection is needed but no processing is required. """ -import threading from typing import Dict from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 799e2866..12d84711 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -5,7 +5,6 @@ """ import re -import threading from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Sequence diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index ba8f9b08..90e6ebf3 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -10,7 +10,6 @@ - The same state attribute name is reused """ -import threading from typing import Dict from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/run_existing_tests_node.py b/prometheus/lang_graph/nodes/run_existing_tests_node.py index d78e8d41..24cd41e0 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_existing_tests_state import RunExistingTestsState from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py index ffbcaa04..5e80ff08 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel diff --git a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py index 4c8fb3bf..c4c9ac54 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/run_regression_tests_node.py b/prometheus/lang_graph/nodes/run_regression_tests_node.py index 6c461db7..5a8e7dd2 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_node.py @@ -1,5 +1,4 @@ import functools -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index 97d8afb7..6b0f5233 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -105,7 +103,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RunRegressionTestsStructureOutput) self.model = prompt | structured_llm - self._logger, file_handler = get_thread_logger(__name__) + self._logger, file_handler = get_thread_logger(__name__) def get_human_message(self, state: RunRegressionTestsState) -> str: # Format the human message using the state diff --git a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py index cadddf6f..678de625 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index 38625e46..be15191f 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -6,7 +6,6 @@ between the agent's workspace and the container environment. """ -import threading from typing import Dict from prometheus.docker.base_container import BaseContainer diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index d8283220..7cfa5650 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -1,4 +1,3 @@ -import threading import uuid from typing import Any diff --git a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index 309ae163..948c3eec 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -1,4 +1,3 @@ -import threading import uuid from typing import Any diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index b572dd34..ee5c79bf 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -1,6 +1,5 @@ """The neo4j handler for writing the knowledge graph to neo4j.""" -import logging from typing import Mapping, Sequence from neo4j import AsyncGraphDatabase, AsyncManagedTransaction diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index 0bb7fb40..cfd38b2f 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -5,6 +5,7 @@ from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError from prometheus.configuration.config import settings +from prometheus.exceptions.web_search_tool_exception import WebSearchToolException from prometheus.utils.logger_manager import get_thread_logger diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index 1b8064aa..f544fd63 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -97,7 +97,6 @@ def __init__(self): self._setup_root_logger() self._initialized = True - def _setup_root_logger(self): """Setup root logger""" # Get root logger @@ -110,16 +109,20 @@ def _setup_root_logger(self): self.root_logger.setLevel(getattr(logging, self.log_level)) # Create colored formatter for console output - self.colored_formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + self.colored_formatter = ColoredFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) # Create plain formatter for file output - self.file_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - + self.file_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) # Create console handler (using colored formatter) console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(self.colored_formatter) - console_handler.setLevel(getattr(logging, self.log_level)) # Ensure console handler uses same level + console_handler.setLevel( + getattr(logging, self.log_level) + ) # Ensure console handler uses same level self.root_logger.addHandler(console_handler) # Prevent log propagation to parent logger @@ -138,19 +141,19 @@ def _set_multi_threads_log_file_handler(self, thread_id: int, logger_name: str): def _find_or_create_log_file(self, thread_id: int) -> Path: """ Find existing log file for the thread_id, or create new one if none exists - + Args: thread_id: Thread ID to find/create log file for - + Returns: Path to the log file (existing earliest one or newly created) """ import glob - + # Pattern to match log files for this thread_id pattern = str(self.issue_log_dir / f"*_{thread_id}.log") existing_logs = glob.glob(pattern) - + if existing_logs: # Sort by filename (which includes timestamp) to get chronological order existing_logs.sort() @@ -206,7 +209,6 @@ def get_logger(self, name: str) -> logging.Logger: return logger - def create_file_handler(self, log_file_path: Path, logger_name: str) -> logging.FileHandler: """ Create file handler for specified logger @@ -220,31 +222,30 @@ def create_file_handler(self, log_file_path: Path, logger_name: str) -> logging. """ # Ensure log directory exists log_file_path.parent.mkdir(parents=True, exist_ok=True) - + # Create file handler with append mode to preserve existing content - file_handler = logging.FileHandler(log_file_path, mode='a') + file_handler = logging.FileHandler(log_file_path, mode="a") file_handler.setLevel(getattr(logging, self.log_level)) file_handler.setFormatter(self.file_formatter) - + # # Get logger directly without going through get_logger to avoid recursion # # Ensure logger name starts with prometheus # if not logger_name.startswith("prometheus"): # logger_name = f"prometheus.{logger_name}" - + logger = logging.getLogger(logger_name) - + # If it's a child logger, inherit root logger configuration if logger_name != "prometheus": logger.parent = self.root_logger logger.propagate = True - + # Check if this logger already has a file handler to avoid duplicates has_file_handler = any(isinstance(h, logging.FileHandler) for h in logger.handlers) if not has_file_handler: logger.addHandler(file_handler) - - return file_handler + return file_handler def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = "prometheus"): """ @@ -258,7 +259,9 @@ def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = " logger.removeHandler(handler) handler.close() - def remove_multi_thread_file_handler(self, handler: logging.FileHandler, logger_name: str = None): + def remove_multi_thread_file_handler( + self, handler: logging.FileHandler, logger_name: str = None + ): """ Remove multi-thread file handler from specific logger @@ -274,8 +277,6 @@ def remove_multi_thread_file_handler(self, handler: logging.FileHandler, logger_ self.root_logger.removeHandler(handler) handler.close() - - # Create global logger manager instance logger_manager = LoggerManager() @@ -298,8 +299,6 @@ def get_logger(name: str) -> logging.Logger: return logger_manager.get_logger(name) - - def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_name: str = None): """ Convenience function to remove multi-thread file handler @@ -314,13 +313,13 @@ def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_n def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHandler]: """ Convenience function to create a thread-specific logger with file handler in one call - + Args: module_name: Module name (usually __name__), if None, uses current module - + Returns: Tuple of (logger, file_handler) for easy cleanup - + Examples: >>> logger, file_handler = get_thread_logger(__name__) >>> logger.info("This goes to both console and file") @@ -328,13 +327,12 @@ def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHan >>> remove_multi_threads_log_file_handler(file_handler, logger.name) """ import threading - + # Get thread ID thread_id = threading.get_ident() logger_name = f"thread-{thread_id}.{module_name}" - + # Create file handler and logger file_handler = logger_manager._set_multi_threads_log_file_handler(thread_id, logger_name) logger = get_logger(logger_name) return logger, file_handler - From f7e827bdd6a41f9ddabaade8056bf69bfe14e8b3 Mon Sep 17 00:00:00 2001 From: cocoli Date: Mon, 8 Sep 2025 00:14:28 +0800 Subject: [PATCH 26/30] fix log --- prometheus/app/services/issue_service.py | 2 +- prometheus/utils/logger_manager.py | 33 ++++++++++++++++++++---- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index 4287ecb9..1cbc3225 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -78,7 +78,7 @@ def answer_issue( """ # Create thread-specific logger with file handler - ONE LINE! - logger, file_handler = get_thread_logger(__name__) + logger, file_handler = get_thread_logger(__name__, force_new_file=True) # Construct the working directory if dockerfile_content or image_name: diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index f544fd63..f1e5d35b 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -131,25 +131,40 @@ def _setup_root_logger(self): # Log configuration information self._log_configuration() - def _set_multi_threads_log_file_handler(self, thread_id: int, logger_name: str): + def _set_multi_threads_log_file_handler(self, thread_id: int, logger_name: str, force_new_file: bool = False): """Set multi threads log file handler""" # Find existing log file for this thread_id, or create new one if none exists - log_file_path = self._find_or_create_log_file(thread_id) + log_file_path = self._find_or_create_log_file(thread_id, force_new_file) file_handler = self.create_file_handler(log_file_path, logger_name) return file_handler - def _find_or_create_log_file(self, thread_id: int) -> Path: + def _find_or_create_log_file(self, thread_id: int, force_new_file: bool = False) -> Path: """ Find existing log file for the thread_id, or create new one if none exists Args: thread_id: Thread ID to find/create log file for +<<<<<<< HEAD +======= + force_new_file: If True, always create a new file with timestamp, even if existing files exist + +>>>>>>> 25fb802 (fix log) Returns: Path to the log file (existing earliest one or newly created) """ import glob +<<<<<<< HEAD +======= + + if force_new_file: + # Always create a new log file with current timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Include milliseconds for uniqueness + return self.issue_log_dir / f"{timestamp}_{thread_id}.log" + + # Original logic: find existing file or create new one +>>>>>>> 25fb802 (fix log) # Pattern to match log files for this thread_id pattern = str(self.issue_log_dir / f"*_{thread_id}.log") existing_logs = glob.glob(pattern) @@ -310,13 +325,18 @@ def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_n logger_manager.remove_multi_thread_file_handler(handler, logger_name) -def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHandler]: +def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[logging.Logger, logging.FileHandler]: """ Convenience function to create a thread-specific logger with file handler in one call Args: module_name: Module name (usually __name__), if None, uses current module +<<<<<<< HEAD +======= + force_new_file: If True, always create a new log file with timestamp, even if existing files exist + +>>>>>>> 25fb802 (fix log) Returns: Tuple of (logger, file_handler) for easy cleanup @@ -325,6 +345,9 @@ def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHan >>> logger.info("This goes to both console and file") >>> # In finally block: >>> remove_multi_threads_log_file_handler(file_handler, logger.name) + + >>> # Force creating a new file each time + >>> logger, file_handler = get_thread_logger(__name__, force_new_file=True) """ import threading @@ -333,6 +356,6 @@ def get_thread_logger(module_name: str) -> tuple[logging.Logger, logging.FileHan logger_name = f"thread-{thread_id}.{module_name}" # Create file handler and logger - file_handler = logger_manager._set_multi_threads_log_file_handler(thread_id, logger_name) + file_handler = logger_manager._set_multi_threads_log_file_handler(thread_id, logger_name, force_new_file) logger = get_logger(logger_name) return logger, file_handler From 11326c361dfbb1fa6bf8bc8d966f7a48c01c340b Mon Sep 17 00:00:00 2001 From: cocoli Date: Mon, 8 Sep 2025 00:43:52 +0800 Subject: [PATCH 27/30] fix --- prometheus/utils/logger_manager.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index f1e5d35b..2823e95b 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -144,27 +144,18 @@ def _find_or_create_log_file(self, thread_id: int, force_new_file: bool = False) Args: thread_id: Thread ID to find/create log file for -<<<<<<< HEAD - -======= force_new_file: If True, always create a new file with timestamp, even if existing files exist ->>>>>>> 25fb802 (fix log) Returns: Path to the log file (existing earliest one or newly created) """ import glob -<<<<<<< HEAD - -======= - if force_new_file: # Always create a new log file with current timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Include milliseconds for uniqueness return self.issue_log_dir / f"{timestamp}_{thread_id}.log" # Original logic: find existing file or create new one ->>>>>>> 25fb802 (fix log) # Pattern to match log files for this thread_id pattern = str(self.issue_log_dir / f"*_{thread_id}.log") existing_logs = glob.glob(pattern) @@ -331,12 +322,9 @@ def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[l Args: module_name: Module name (usually __name__), if None, uses current module -<<<<<<< HEAD -======= force_new_file: If True, always create a new log file with timestamp, even if existing files exist ->>>>>>> 25fb802 (fix log) Returns: Tuple of (logger, file_handler) for easy cleanup From 06084edd974fb7138927d5a1ca70372a7b6d6b23 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Mon, 8 Sep 2025 01:42:53 +0800 Subject: [PATCH 28/30] fix: Refactor logger_manager.py for improved readability and formatting --- prometheus/utils/logger_manager.py | 33 +++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py index 2823e95b..e3199770 100644 --- a/prometheus/utils/logger_manager.py +++ b/prometheus/utils/logger_manager.py @@ -131,7 +131,9 @@ def _setup_root_logger(self): # Log configuration information self._log_configuration() - def _set_multi_threads_log_file_handler(self, thread_id: int, logger_name: str, force_new_file: bool = False): + def _set_multi_threads_log_file_handler( + self, thread_id: int, logger_name: str, force_new_file: bool = False + ): """Set multi threads log file handler""" # Find existing log file for this thread_id, or create new one if none exists log_file_path = self._find_or_create_log_file(thread_id, force_new_file) @@ -145,16 +147,19 @@ def _find_or_create_log_file(self, thread_id: int, force_new_file: bool = False) Args: thread_id: Thread ID to find/create log file for force_new_file: If True, always create a new file with timestamp, even if existing files exist - + Returns: Path to the log file (existing earliest one or newly created) """ import glob + if force_new_file: # Always create a new log file with current timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Include milliseconds for uniqueness + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[ + :-3 + ] # Include milliseconds for uniqueness return self.issue_log_dir / f"{timestamp}_{thread_id}.log" - + # Original logic: find existing file or create new one # Pattern to match log files for this thread_id pattern = str(self.issue_log_dir / f"*_{thread_id}.log") @@ -172,7 +177,7 @@ def _find_or_create_log_file(self, thread_id: int, force_new_file: bool = False) def _log_configuration(self): """Log configuration information""" - # 动态获取settings中所有可用的配置属性 + # Dynamically get all attributes from settings config_attrs = [ attr for attr in dir(settings) if attr.isupper() and not attr.startswith("_") ] @@ -180,14 +185,14 @@ def _log_configuration(self): for attr in config_attrs: value = getattr(settings, attr, "Not Set") - # 使用通配符匹配敏感配置项(包含KEY、API、PASSWORD的) + # Check if the attribute name indicates a sensitive configuration is_sensitive = any( keyword in attr.upper() for keyword in ["KEY", "API", "PASSWORD", "SECRET"] ) - # 如果是敏感配置项,用星号代替 + # If sensitive, mask the value if is_sensitive and value and value != "Not Set": - masked_value = "*" * min(len(str(value)), 8) # 最多显示8个星号 + masked_value = "*" * min(len(str(value)), 8) self.root_logger.info(f"{attr}={masked_value}") else: self.root_logger.info(f"{attr}={value}") @@ -316,7 +321,9 @@ def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_n logger_manager.remove_multi_thread_file_handler(handler, logger_name) -def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[logging.Logger, logging.FileHandler]: +def get_thread_logger( + module_name: str, force_new_file: bool = False +) -> tuple[logging.Logger, logging.FileHandler]: """ Convenience function to create a thread-specific logger with file handler in one call @@ -324,7 +331,7 @@ def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[l module_name: Module name (usually __name__), if None, uses current module force_new_file: If True, always create a new log file with timestamp, even if existing files exist - + Returns: Tuple of (logger, file_handler) for easy cleanup @@ -333,7 +340,7 @@ def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[l >>> logger.info("This goes to both console and file") >>> # In finally block: >>> remove_multi_threads_log_file_handler(file_handler, logger.name) - + >>> # Force creating a new file each time >>> logger, file_handler = get_thread_logger(__name__, force_new_file=True) """ @@ -344,6 +351,8 @@ def get_thread_logger(module_name: str, force_new_file: bool = False) -> tuple[l logger_name = f"thread-{thread_id}.{module_name}" # Create file handler and logger - file_handler = logger_manager._set_multi_threads_log_file_handler(thread_id, logger_name, force_new_file) + file_handler = logger_manager._set_multi_threads_log_file_handler( + thread_id, logger_name, force_new_file + ) logger = get_logger(logger_name) return logger, file_handler From 5ccdb6665268270d293610ec8138c5ee7be47eb2 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Mon, 8 Sep 2025 01:59:10 +0800 Subject: [PATCH 29/30] fix: Clarify failure log descriptions in regression test structure --- .../lang_graph/nodes/run_regression_tests_structure_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index 6b0f5233..1248c0e6 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -14,7 +14,7 @@ class RunRegressionTestsStructureOutput(BaseModel): description="List of test identifier of regression tests that passed (e.g., class name and method name)" ) regression_test_fail_log: str = Field( - description="If the test failed, contains the complete test FAILURE log. Otherwise empty string" + description="If any test failed, contains the exact and complete test FAILURE log. Otherwise empty string" ) total_tests_run: int = Field( description="Total number of tests run, including both passed and failed tests, or 0 if no tests were run", @@ -31,7 +31,7 @@ class RunRegressionTestsStructuredNode: - Test summary showing "passed" or "PASSED" - Warning is ok - No "FAILURES" section -2. If a test fails, capture the complete failure output. Otherwise empty string for failure log +2. If a test fails, capture the exact and complete failure output. Otherwise empty string for failure log 3. Return the exact test identifiers that passed 4. Count the total number of tests run. Only count tests that were actually executed! If tests were unable to run due to an error, do not count them! From f9daf85639b90c63a7a179e9b073edb3a2db73c8 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Mon, 8 Sep 2025 02:01:54 +0800 Subject: [PATCH 30/30] fix: Simplify logger initialization in get_pass_regression_test_patch_check_result_node --- .../nodes/get_pass_regression_test_patch_check_result_node.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py index f598936b..b3a818a8 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py @@ -13,9 +13,7 @@ class GetPassRegressionTestPatchCheckResultNode: """ def __init__(self): - self._logger, file_handler = get_thread_logger( - __name__ + "get_pass_regression_test_patch_check_result_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: GetPassRegressionTestPatchState): """