diff --git a/.github/workflows/pytest_and_coverage.yml b/.github/workflows/pytest_and_coverage.yml index 7cb16920..eb7adf12 100644 --- a/.github/workflows/pytest_and_coverage.yml +++ b/.github/workflows/pytest_and_coverage.yml @@ -53,6 +53,9 @@ jobs: # GitHub settings PROMETHEUS_GITHUB_ACCESS_TOKEN: github_access_token + # Tavily API key + PROMETHEUS_TAVILY_API_KEY: tavily_api_key + # DATABASE settings PROMETHEUS_DATABASE_URL: postgresql://postgres:password@localhost:5432/postgres?sslmode=disable diff --git a/docker-compose.win_mac.yml b/docker-compose.win_mac.yml index 73fa8b0d..ae62fedf 100644 --- a/docker-compose.win_mac.yml +++ b/docker-compose.win_mac.yml @@ -69,6 +69,9 @@ services: - PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS=${PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS} - PROMETHEUS_BASE_MODEL_TEMPERATURE=${PROMETHEUS_BASE_MODEL_TEMPERATURE} + # Tavily API key + - PROMETHEUS_TAVILY_API_KEY=${PROMETHEUS_TAVILY_API_KEY} + # Database settings - PROMETHEUS_DATABASE_URL=${PROMETHEUS_DATABASE_URL} diff --git a/docker-compose.yml b/docker-compose.yml index 51af2e8b..9db92412 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -88,6 +88,9 @@ services: - PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS=${PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS} - PROMETHEUS_BASE_MODEL_TEMPERATURE=${PROMETHEUS_BASE_MODEL_TEMPERATURE} + # Tavily API key + - PROMETHEUS_TAVILY_API_KEY=${PROMETHEUS_TAVILY_API_KEY} + # Database settings - PROMETHEUS_DATABASE_URL=${PROMETHEUS_DATABASE_URL} diff --git a/example.env b/example.env index 84186ced..44df0990 100644 --- a/example.env +++ b/example.env @@ -38,6 +38,9 @@ PROMETHEUS_BASE_MODEL_MAX_INPUT_TOKENS=64000 PROMETHEUS_BASE_MODEL_TEMPERATURE=0.3 PROMETHEUS_BASE_MODEL_MAX_OUTPUT_TOKENS=15000 +# Tavily API settings +PROMETHEUS_TAVILY_API_KEY=your_tavily_api_key + # Database settings PROMETHEUS_DATABASE_URL=postgresql+asyncpg://postgres:password@postgres:5432/postgres diff --git a/prometheus/app/api/routes/repository.py b/prometheus/app/api/routes/repository.py index 232b782c..dfd48a23 100644 --- a/prometheus/app/api/routes/repository.py +++ b/prometheus/app/api/routes/repository.py @@ -177,7 +177,11 @@ async def list_repositories(request: Request): response_model=Response, ) @requireLogin -async def delete(repository_id: int, request: Request): +async def delete( + repository_id: int, + request: Request, + force: bool = False, +): knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ "knowledge_graph_service" ] @@ -189,7 +193,7 @@ async def delete(repository_id: int, request: Request): if not repository: raise ServerException(code=404, message="Repository not found") # Check if the repository is being processed - if repository.is_working: + if repository.is_working and not force: raise ServerException( code=400, message="Repository is currently being processed, please try again later" ) diff --git a/prometheus/app/main.py b/prometheus/app/main.py index 9109a07c..f59dfa2b 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -1,5 +1,4 @@ import inspect -import logging from contextlib import asynccontextmanager from datetime import datetime, timezone @@ -16,27 +15,10 @@ register_login_required_routes, ) from prometheus.configuration.config import settings +from prometheus.utils.logger_manager import get_thread_logger -# Create a logger for the application's namespace -logger = logging.getLogger("prometheus") -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) -logger.addHandler(console_handler) -logger.setLevel(getattr(logging, settings.LOGGING_LEVEL)) -logger.propagate = False - -# Log the configuration settings -logger.info(f"LOGGING_LEVEL={settings.LOGGING_LEVEL}") -logger.info(f"ENVIRONMENT={settings.ENVIRONMENT}") -logger.info(f"BACKEND_CORS_ORIGINS={settings.BACKEND_CORS_ORIGINS}") -logger.info(f"ADVANCED_MODEL={settings.ADVANCED_MODEL}") -logger.info(f"BASE_MODEL={settings.BASE_MODEL}") -logger.info(f"NEO4J_BATCH_SIZE={settings.NEO4J_BATCH_SIZE}") -logger.info(f"WORKING_DIRECTORY={settings.WORKING_DIRECTORY}") -logger.info(f"KNOWLEDGE_GRAPH_MAX_AST_DEPTH={settings.KNOWLEDGE_GRAPH_MAX_AST_DEPTH}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_SIZE={settings.KNOWLEDGE_GRAPH_CHUNK_SIZE}") -logger.info(f"KNOWLEDGE_GRAPH_CHUNK_OVERLAP={settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP}") +# Create main thread logger with file handler - ONE LINE! +logger, file_handler = get_thread_logger(__name__) @asynccontextmanager diff --git a/prometheus/app/services/database_service.py b/prometheus/app/services/database_service.py index 1442ea1e..b261cf75 100644 --- a/prometheus/app/services/database_service.py +++ b/prometheus/app/services/database_service.py @@ -1,15 +1,14 @@ -import logging - from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from prometheus.app.services.base_service import BaseService +from prometheus.utils.logger_manager import get_thread_logger class DatabaseService(BaseService): def __init__(self, DATABASE_URL: str): self.engine = create_async_engine(DATABASE_URL, echo=True) - self._logger = logging.getLogger("prometheus.app.services.database_service") + self._logger, file_handler = get_thread_logger(__name__) # Create the database and tables async def create_db_and_tables(self): diff --git a/prometheus/app/services/invitation_code_service.py b/prometheus/app/services/invitation_code_service.py index 6cd0ab07..cc3f239f 100644 --- a/prometheus/app/services/invitation_code_service.py +++ b/prometheus/app/services/invitation_code_service.py @@ -1,5 +1,4 @@ import datetime -import logging import uuid from typing import Sequence @@ -9,13 +8,14 @@ from prometheus.app.entity.invitation_code import InvitationCode from prometheus.app.services.base_service import BaseService from prometheus.app.services.database_service import DatabaseService +from prometheus.utils.logger_manager import get_thread_logger class InvitationCodeService(BaseService): def __init__(self, database_service: DatabaseService): self.database_service = database_service self.engine = database_service.engine - self._logger = logging.getLogger("prometheus.app.services.invitation_code_service") + self._logger, file_handler = get_thread_logger(__name__) async def create_invitation_code(self) -> InvitationCode: """ diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index 522a8a77..1cbc3225 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -1,7 +1,4 @@ -import logging -import threading import traceback -from datetime import datetime from pathlib import Path from typing import Mapping, Optional, Sequence @@ -13,6 +10,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_graph import IssueGraph from prometheus.lang_graph.graphs.issue_state import IssueType +from prometheus.utils.logger_manager import get_thread_logger, remove_multi_threads_log_file_handler class IssueService(BaseService): @@ -79,15 +77,8 @@ def answer_issue( - issue_type (IssueType): The type of the issue (BUG or QUESTION). """ - # Set up a dedicated logger for this thread - logger = logging.getLogger(f"thread-{threading.get_ident()}.prometheus") - logger.setLevel(getattr(logging, self.logging_level)) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = self.answer_issue_log_dir / f"{timestamp}_{threading.get_ident()}.log" - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # Create thread-specific logger with file handler - ONE LINE! + logger, file_handler = get_thread_logger(__name__, force_new_file=True) # Construct the working directory if dockerfile_content or image_name: @@ -141,5 +132,5 @@ def answer_issue( logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") return None, False, False, False, None, None finally: - logger.removeHandler(file_handler) - file_handler.close() + # Remove multi-thread file handler + remove_multi_threads_log_file_handler(file_handler, logger.name) diff --git a/prometheus/app/services/knowledge_graph_service.py b/prometheus/app/services/knowledge_graph_service.py index e893dbc2..ed52a128 100644 --- a/prometheus/app/services/knowledge_graph_service.py +++ b/prometheus/app/services/knowledge_graph_service.py @@ -1,13 +1,13 @@ """Service for managing and interacting with Knowledge Graphs in Neo4j.""" import asyncio -import logging from pathlib import Path from prometheus.app.services.base_service import BaseService from prometheus.app.services.neo4j_service import Neo4jService from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.neo4j import knowledge_graph_handler +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraphService(BaseService): @@ -42,7 +42,7 @@ def __init__( self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.writing_lock = asyncio.Lock() - self._logger = logging.getLogger("prometheus.app.services.knowledge_graph_service") + self._logger, file_handler = get_thread_logger(__name__) async def start(self): # Initialize the Neo4j database for Knowledge Graph operations diff --git a/prometheus/app/services/neo4j_service.py b/prometheus/app/services/neo4j_service.py index 14b5312c..81cc05f7 100644 --- a/prometheus/app/services/neo4j_service.py +++ b/prometheus/app/services/neo4j_service.py @@ -1,15 +1,14 @@ """Service for managing Neo4j database driver.""" -import logging - from neo4j import AsyncGraphDatabase from prometheus.app.services.base_service import BaseService +from prometheus.utils.logger_manager import get_thread_logger class Neo4jService(BaseService): def __init__(self, neo4j_uri: str, neo4j_username: str, neo4j_password: str): - self._logger = logging.getLogger("prometheus.app.services.neo4j_service") + self._logger, file_handler = get_thread_logger(__name__) self.neo4j_driver = AsyncGraphDatabase.driver( neo4j_uri, auth=(neo4j_username, neo4j_password), diff --git a/prometheus/app/services/user_service.py b/prometheus/app/services/user_service.py index bfa7cc36..914bf5df 100644 --- a/prometheus/app/services/user_service.py +++ b/prometheus/app/services/user_service.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence from argon2 import PasswordHasher @@ -11,13 +10,14 @@ from prometheus.app.services.database_service import DatabaseService from prometheus.exceptions.server_exception import ServerException from prometheus.utils.jwt_utils import JWTUtils +from prometheus.utils.logger_manager import get_thread_logger class UserService(BaseService): def __init__(self, database_service: DatabaseService): self.database_service = database_service self.engine = database_service.engine - self._logger = logging.getLogger("prometheus.app.services.user_service") + self._logger, file_handler = get_thread_logger(__name__) self.ph = PasswordHasher() self.jwt_utils = JWTUtils() diff --git a/prometheus/configuration/config.py b/prometheus/configuration/config.py index e6327eb0..00483912 100644 --- a/prometheus/configuration/config.py +++ b/prometheus/configuration/config.py @@ -66,5 +66,8 @@ class Settings(BaseSettings): # Default normal user repository number DEFAULT_USER_REPOSITORY_LIMIT: int = 5 + # tool for Websearch + TAVILY_API_KEY: str + settings = Settings() diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index 0b80e605..8cc5c782 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -2,13 +2,14 @@ import shutil import tarfile import tempfile -import threading from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Sequence import docker +from prometheus.utils.logger_manager import get_thread_logger + class BaseContainer(ABC): """An abstract base class for managing Docker containers with file synchronization capabilities. @@ -41,9 +42,7 @@ def __init__( Args: project_path: Path to the project directory to be containerized. """ - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.{self.__class__.__module__}.{self.__class__.__name__}" - ) + self._logger, file_handler = get_thread_logger(__name__) temp_dir = Path(tempfile.mkdtemp()) temp_project_path = temp_dir / project_path.name shutil.copytree(project_path, temp_project_path) diff --git a/prometheus/exceptions/web_search_tool_exception.py b/prometheus/exceptions/web_search_tool_exception.py new file mode 100644 index 00000000..873bff43 --- /dev/null +++ b/prometheus/exceptions/web_search_tool_exception.py @@ -0,0 +1,4 @@ +class WebSearchToolException(Exception): + """Custom exception for web search tool errors.""" + + pass diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index a8ec9d5b..f6f15ef5 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -1,7 +1,6 @@ """Git repository management module.""" import asyncio -import logging import shutil import tempfile from pathlib import Path @@ -9,6 +8,8 @@ from git import Git, GitCommandError, InvalidGitRepositoryError, Repo +from prometheus.utils.logger_manager import get_thread_logger + class GitRepository: """A class for managing Git repositories with support for both local and remote operations. @@ -23,12 +24,12 @@ def __init__(self): """ Initialize a GitRepository instance. """ - self._logger = logging.getLogger("prometheus.git.git_repository") + self._logger, file_handler = get_thread_logger(__name__) # Configure git command to use our logger g = Git() type(g).GIT_PYTHON_TRACE = "full" - git_cmd_logger = logging.getLogger("git.cmd") + git_cmd_logger, file_handler = get_thread_logger("git.cmd") # Ensure git command output goes to our logger for handler in git_cmd_logger.handlers: diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index cbeda724..a3e14dc1 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -19,7 +19,6 @@ import asyncio import itertools -import logging from collections import defaultdict, deque from pathlib import Path from typing import Mapping, Optional, Sequence @@ -43,6 +42,7 @@ Neo4jTextNode, TextNode, ) +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraph: @@ -79,7 +79,7 @@ def __init__( self._next_node_id = root_node_id + len(self._knowledge_graph_nodes) self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) - self._logger = logging.getLogger("prometheus.graph.knowledge_graph") + self._logger, file_handler = get_thread_logger(__name__) async def build_graph(self, root_dir: Path): """Asynchronously builds knowledge graph for a codebase at a location. diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 15fa6fa2..8bdee470 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError @@ -8,6 +5,7 @@ from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.bug_fix_verification_subgraph import BugFixVerificationSubgraph from prometheus.lang_graph.subgraphs.issue_verified_bug_state import IssueVerifiedBugState +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerificationSubgraphNode: @@ -17,9 +15,7 @@ def __init__( container: BaseContainer, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.subgraph = BugFixVerificationSubgraph( model=model, diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 176202e3..fdb7f07c 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -8,7 +6,8 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerifyNode: @@ -51,22 +50,21 @@ class BugFixVerifyNode: """ def __init__(self, model: BaseChatModel, container: BaseContainer): - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_verify_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index 7399c58d..cb344b90 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -1,12 +1,10 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_thread_logger class BugFixVerifyStructureOutput(BaseModel): @@ -91,9 +89,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_fix_verify_structured_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugFixVerificationState): bug_fix_verify_message = get_last_message_content(state["bug_fix_verify_messages"]) diff --git a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py index 225ca3b6..8a244195 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_context_message_node.py @@ -1,10 +1,8 @@ -import logging -import threading - from prometheus.lang_graph.subgraphs.bug_get_regression_tests_state import ( BugGetRegressionTestsState, ) from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class BugGetRegressionContextMessageNode: @@ -86,9 +84,7 @@ def test_offset_preserved(self): """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugGetRegressionTestsState): select_regression_query = self.SELECT_REGRESSION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py index 8fa8a8a1..87e9d311 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_selection_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field @@ -9,6 +6,7 @@ BugGetRegressionTestsState, ) from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class RegressionTestStructuredOutPut(BaseModel): @@ -91,9 +89,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RegressionTestsStructuredOutPut) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_tests_selection_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: BugGetRegressionTestsState): return self.HUMAN_PROMPT.format( diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py index 24c1b638..4555419c 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -11,6 +9,7 @@ from prometheus.lang_graph.subgraphs.bug_get_regression_tests_subgraph import ( BugGetRegressionTestsSubgraph, ) +from prometheus.utils.logger_manager import get_thread_logger class BugGetRegressionTestsSubgraphNode: @@ -22,9 +21,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_get_regression_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = BugGetRegressionTestsSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index b6ea0751..39eb7449 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from pathlib import Path from typing import Optional, Sequence @@ -10,8 +8,9 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool from prometheus.utils.issue_util import format_test_commands +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.patch_util import get_updated_files @@ -53,22 +52,21 @@ def __init__( test_commands: Optional[Sequence[str]] = None, ): self.test_commands = test_commands - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_execute_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index 2c63f8d7..ca515b18 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -8,8 +6,9 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import file_operation +from prometheus.tools.file_operation import FileOperationTool from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingFileNode: @@ -38,39 +37,31 @@ class BugReproducingFileNode: def __init__(self, model: BaseChatModel, kg: KnowledgeGraph, local_path: str): self.kg = kg - self.tools = self._init_tools(local_path) + self.file_operation_tool = FileOperationTool(local_path, kg) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_file_node" - ) - - def _init_tools(self, root_path: str): - """Initializes file operation tools with the given root path. - - Args: - root_path: Base directory path for all file operations. + self._logger, file_handler = get_thread_logger(__name__) - Returns: - List of StructuredTool instances configured for file operations. - """ + def _init_tools(self): + """Initializes file operation tools.""" tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=FileOperationTool.read_file.__name__, + description=FileOperationTool.read_file_spec.description, + args_schema=FileOperationTool.read_file_spec.input_schema, ) tools.append(read_file_tool) - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_fn = functools.partial(self.file_operation_tool.create_file) create_file_tool = StructuredTool.from_function( func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, + name=FileOperationTool.create_file.__name__, + description=FileOperationTool.create_file_spec.description, + args_schema=FileOperationTool.create_file_spec.input_schema, ) tools.append(create_file_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 42b3be3e..b0cae515 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -12,6 +10,7 @@ format_agent_tool_message_history, get_last_message_content, ) +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingStructuredOutput(BaseModel): @@ -136,9 +135,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BugReproducingStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_structured_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_log = format_agent_tool_message_history( diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index f2808f0a..9aefe39e 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -1,10 +1,8 @@ -import logging -import threading - from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingWriteMessageNode: @@ -25,9 +23,7 @@ class BugReproducingWriteMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_write_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: BugReproductionState): if "reproduced_bug_failure_log" in state and state["reproduced_bug_failure_log"]: diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 827bbcf6..bf5e3120 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -1,13 +1,13 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage +from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState -from prometheus.tools import file_operation +from prometheus.tools.file_operation import FileOperationTool +from prometheus.utils.logger_manager import get_thread_logger class BugReproducingWriteNode: @@ -111,15 +111,14 @@ def test_empty_array_parsing(parser): ''' - def __init__(self, model: BaseChatModel, local_path: str): - self.tools = self._init_tools(local_path) + def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): + self.file_operation_tool = FileOperationTool(local_path, kg) + self.tools = self._init_tools() self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_write_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, root_path: str): + def _init_tools(self): """Initializes file operation tools with the given root path. Args: @@ -130,12 +129,12 @@ def _init_tools(self, root_path: str): """ tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=self.file_operation_tool.read_file.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, ) tools.append(read_file_tool) diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 032ce0ef..efe3086a 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -10,6 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.bug_reproduction_subgraph import BugReproductionSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState +from prometheus.utils.logger_manager import get_thread_logger class BugReproductionSubgraphNode: @@ -22,9 +21,7 @@ def __init__( git_repo: GitRepository, test_commands: Optional[Sequence[str]], ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproduction_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.bug_reproduction_subgraph = BugReproductionSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index 946d1bb2..be26ec47 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState +from prometheus.utils.logger_manager import get_thread_logger class BuildAndTestSubgraphNode: @@ -26,9 +25,7 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.build_and_test_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueBugState): exist_build = None diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index c04ca1f4..9debb6d6 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -16,6 +14,7 @@ extract_last_tool_messages, transform_tool_messages_to_str, ) +from prometheus.utils.logger_manager import get_thread_logger SYS_PROMPT = """\ You are a context summary agent that summarizes code contexts which is relevant to a given query. @@ -138,9 +137,7 @@ def __init__(self, model: BaseChatModel, root_path: str): structured_llm = model.with_structured_output(ContextExtractionStructuredOutput) self.model = prompt | structured_llm self.root_path = root_path - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_extraction_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: ContextRetrievalState): """ diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index f1168972..221e0293 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -6,8 +6,6 @@ """ import functools -import logging -import threading from typing import Dict from langchain.tools import StructuredTool @@ -15,7 +13,9 @@ from langchain_core.messages import SystemMessage from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools import file_operation, graph_traversal +from prometheus.tools.file_operation import FileOperationTool +from prometheus.tools.graph_traversal import GraphTraversalTool +from prometheus.utils.logger_manager import get_thread_logger class ContextProviderNode: @@ -103,17 +103,22 @@ def __init__( kg: Knowledge graph instance containing the processed codebase structure. Used to obtain the file tree for system prompts. """ - ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) + # self.neo4j_driver = neo4j_driver + self.root_node_id = kg.root_node_id self.kg = kg self.root_path = local_path + # self.max_token_per_result = max_token_per_result + # Initialize GraphTraversalTool with the driver and token limit + self.graph_traversal_tool = GraphTraversalTool(kg) + self.file_operation_tool = FileOperationTool(local_path, kg) + + ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) self.system_prompt = SystemMessage( self.SYS_PROMPT.format(file_tree=kg.get_file_tree(), ast_node_types=ast_node_types_str) ) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_provider_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def _init_tools(self): """ @@ -129,14 +134,13 @@ def _init_tools(self): # Tool: Find file node by filename (basename) # Used when only the filename (not full path) is known find_file_node_with_basename_fn = functools.partial( - graph_traversal.find_file_node_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_file_node_with_basename ) find_file_node_with_basename_tool = StructuredTool.from_function( func=find_file_node_with_basename_fn, - name=graph_traversal.find_file_node_with_basename.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithBasenameInput, + name=self.graph_traversal_tool.find_file_node_with_basename.__name__, + description=self.graph_traversal_tool.find_file_node_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_file_node_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_file_node_with_basename_tool) @@ -144,14 +148,13 @@ def _init_tools(self): # Tool: Find file node by relative path # Preferred method when the exact file path is known find_file_node_with_relative_path_fn = functools.partial( - graph_traversal.find_file_node_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_file_node_with_relative_path ) find_file_node_with_relative_path_tool = StructuredTool.from_function( func=find_file_node_with_relative_path_fn, - name=graph_traversal.find_file_node_with_relative_path.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithRelativePathInput, + name=self.graph_traversal_tool.find_file_node_with_relative_path.__name__, + description=self.graph_traversal_tool.find_file_node_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_file_node_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_file_node_with_relative_path_tool) @@ -161,28 +164,26 @@ def _init_tools(self): # Tool: Find AST node by text match in file (by basename) # Useful for searching specific snippets or patterns in unknown locations find_ast_node_with_text_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename ) find_ast_node_with_text_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_basename_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithBasenameInput, + name=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename.__name__, + description=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_text_in_file_with_basename_tool) # Tool: Find AST node by text match in file (by relative path) find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path ) find_ast_node_with_text_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_text_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithRelativePathInput, + name=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path.__name__, + description=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_text_in_file_with_relative_path_tool) @@ -190,28 +191,26 @@ def _init_tools(self): # Tool: Find AST node by type in file (by basename) # Example types: FunctionDef, ClassDef, Assign, etc. find_ast_node_with_type_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_basename, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename ) find_ast_node_with_type_in_file_with_basename_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_basename_fn, - name=graph_traversal.find_ast_node_with_type_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithBasenameInput, + name=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename.__name__, + description=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_basename_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_type_in_file_with_basename_tool) # Tool: Find AST node by type in file (by relative path) find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path ) find_ast_node_with_type_in_file_with_relative_path_tool = StructuredTool.from_function( func=find_ast_node_with_type_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_type_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithRelativePathInput, + name=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path.__name__, + description=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_ast_node_with_type_in_file_with_relative_path_tool) @@ -220,42 +219,39 @@ def _init_tools(self): # Tool: Find text node globally by keyword find_text_node_with_text_fn = functools.partial( - graph_traversal.find_text_node_with_text, - kg=self.kg, + self.graph_traversal_tool.find_text_node_with_text ) find_text_node_with_text_tool = StructuredTool.from_function( func=find_text_node_with_text_fn, - name=graph_traversal.find_text_node_with_text.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInput, + name=self.graph_traversal_tool.find_text_node_with_text.__name__, + description=self.graph_traversal_tool.find_text_node_with_text_spec.description, + args_schema=self.graph_traversal_tool.find_text_node_with_text_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_text_node_with_text_tool) # Tool: Find text node by keyword in specific file find_text_node_with_text_in_file_fn = functools.partial( - graph_traversal.find_text_node_with_text_in_file, - kg=self.kg, + self.graph_traversal_tool.find_text_node_with_text_in_file ) find_text_node_with_text_in_file_tool = StructuredTool.from_function( func=find_text_node_with_text_in_file_fn, - name=graph_traversal.find_text_node_with_text_in_file.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInFileInput, + name=self.graph_traversal_tool.find_text_node_with_text_in_file.__name__, + description=self.graph_traversal_tool.find_text_node_with_text_in_file_spec.description, + args_schema=self.graph_traversal_tool.find_text_node_with_text_in_file_spec.input_schema, response_format="content_and_artifact", ) tools.append(find_text_node_with_text_in_file_tool) # Tool: Fetch the next text node chunk in a chain (used for long docs/comments) get_next_text_node_with_node_id_fn = functools.partial( - graph_traversal.get_next_text_node_with_node_id, - kg=self.kg, + self.graph_traversal_tool.get_next_text_node_with_node_id ) get_next_text_node_with_node_id_tool = StructuredTool.from_function( func=get_next_text_node_with_node_id_fn, - name=graph_traversal.get_next_text_node_with_node_id.__name__, - description=graph_traversal.GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION, - args_schema=graph_traversal.GetNextTextNodeWithNodeIdInput, + name=self.graph_traversal_tool.get_next_text_node_with_node_id.__name__, + description=self.graph_traversal_tool.get_next_text_node_with_node_id_spec.description, + args_schema=self.graph_traversal_tool.get_next_text_node_with_node_id_spec.input_schema, response_format="content_and_artifact", ) tools.append(get_next_text_node_with_node_id_tool) @@ -264,29 +260,26 @@ def _init_tools(self): # Tool: Preview contents of file by relative path read_file_fn = functools.partial( - file_operation.read_file_with_knowledge_graph_data, - root_path=self.root_path, - kg=self.kg, + self.file_operation_tool.read_file_with_knowledge_graph_data ) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=self.file_operation_tool.read_file_with_knowledge_graph_data.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, response_format="content_and_artifact", ) tools.append(read_file_tool) # Tool: Read entire code file by relative path read_code_with_relative_path_fn = functools.partial( - graph_traversal.read_code_with_relative_path, - kg=self.kg, + self.graph_traversal_tool.read_code_with_relative_path ) read_code_with_relative_path_tool = StructuredTool.from_function( func=read_code_with_relative_path_fn, - name=graph_traversal.read_code_with_relative_path.__name__, - description=graph_traversal.READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.ReadCodeWithRelativePathInput, + name=self.graph_traversal_tool.read_code_with_relative_path.__name__, + description=self.graph_traversal_tool.read_code_with_relative_path_spec.description, + args_schema=self.graph_traversal_tool.read_code_with_relative_path_spec.input_schema, response_format="content_and_artifact", ) tools.append(read_code_with_relative_path_tool) diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index e5372b59..a45d2511 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -1,16 +1,12 @@ -import logging -import threading - from langchain_core.messages import HumanMessage from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.logger_manager import get_thread_logger class ContextQueryMessageNode: def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_query_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: ContextRetrievalState): human_message = HumanMessage(state["query"]) diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 25cd9bc3..309ca73c 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate @@ -8,6 +5,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.logger_manager import get_thread_logger class ContextRefineStructuredOutput(BaseModel): @@ -90,9 +88,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): ) structured_llm = model.with_structured_output(ContextRefineStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_refine_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_refine_message(self, state: ContextRetrievalState): original_query = state["query"] diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 47bc7170..7484c101 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_subgraph import ContextRetrievalSubgraph from prometheus.models.context import Context +from prometheus.utils.logger_manager import get_thread_logger class ContextRetrievalSubgraphNode: @@ -19,9 +18,7 @@ def __init__( query_key_name: str, context_key_name: str, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_retrieval_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index e238895e..e4b07aa8 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -1,11 +1,10 @@ -import logging -import threading from typing import Dict from langchain_core.messages import HumanMessage from prometheus.utils.issue_util import format_issue_info from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_thread_logger class EditMessageNode: @@ -43,9 +42,7 @@ class EditMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.edit_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index 157194cd..d79d0897 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -7,15 +7,15 @@ """ import functools -import logging -import threading from typing import Dict from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage -from prometheus.tools import file_operation +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.tools.file_operation import FileOperationTool +from prometheus.utils.logger_manager import get_thread_logger class EditNode: @@ -118,15 +118,14 @@ def other_method(): 7. NEVER write or run any tests, your change will be tested by reproduction tests and regression tests later """ - def __init__(self, model: BaseChatModel, local_path: str): + def __init__(self, model: BaseChatModel, local_path: str, kg: KnowledgeGraph): self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.tools = self._init_tools(local_path) + self.file_operation_tool = FileOperationTool(local_path, kg) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.edit_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, root_path: str): + def _init_tools(self): """Initializes file operation tools with the given root path. Args: @@ -137,50 +136,50 @@ def _init_tools(self, root_path: str): """ tools = [] - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_fn = functools.partial(self.file_operation_tool.read_file) read_file_tool = StructuredTool.from_function( func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, + name=self.file_operation_tool.read_file.__name__, + description=self.file_operation_tool.read_file_spec.description, + args_schema=self.file_operation_tool.read_file_spec.input_schema, ) tools.append(read_file_tool) read_file_with_line_numbers_fn = functools.partial( - file_operation.read_file_with_line_numbers, root_path=root_path + self.file_operation_tool.read_file_with_line_numbers ) read_file_with_line_numbers_tool = StructuredTool.from_function( func=read_file_with_line_numbers_fn, - name=file_operation.read_file_with_line_numbers.__name__, - description=file_operation.READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION, - args_schema=file_operation.ReadFileWithLineNumbersInput, + name=self.file_operation_tool.read_file_with_line_numbers.__name__, + description=self.file_operation_tool.read_file_with_line_numbers_spec.description, + args_schema=self.file_operation_tool.read_file_with_line_numbers_spec.input_schema, ) tools.append(read_file_with_line_numbers_tool) - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_fn = functools.partial(self.file_operation_tool.create_file) create_file_tool = StructuredTool.from_function( func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, + name=self.file_operation_tool.create_file.__name__, + description=self.file_operation_tool.create_file_spec.description, + args_schema=self.file_operation_tool.create_file_spec.input_schema, ) tools.append(create_file_tool) - delete_fn = functools.partial(file_operation.delete, root_path=root_path) + delete_fn = functools.partial(self.file_operation_tool.delete) delete_tool = StructuredTool.from_function( func=delete_fn, - name=file_operation.delete.__name__, - description=file_operation.DELETE_DESCRIPTION, - args_schema=file_operation.DeleteInput, + name=self.file_operation_tool.delete.__name__, + description=self.file_operation_tool.delete_spec.description, + args_schema=self.file_operation_tool.delete_spec.input_schema, ) tools.append(delete_tool) - edit_file_fn = functools.partial(file_operation.edit_file, root_path=root_path) + edit_file_fn = functools.partial(self.file_operation_tool.edit_file) edit_file_tool = StructuredTool.from_function( func=edit_file_fn, - name=file_operation.edit_file.__name__, - description=file_operation.EDIT_FILE_DESCRIPTION, - args_schema=file_operation.EditFileInput, + name=self.file_operation_tool.edit_file.__name__, + description=self.file_operation_tool.edit_file_spec.description, + args_schema=self.file_operation_tool.edit_file_spec.input_schema, ) tools.append(edit_file_tool) diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index fbca8c2b..13554498 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +6,7 @@ from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class FinalPatchSelectionStructuredOutput(BaseModel): @@ -128,9 +127,7 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): ) structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.final_patch_selection_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.majority_voting_times = 10 def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState): diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index 486763a8..779c913d 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -9,7 +7,8 @@ from prometheus.docker.base_container import BaseContainer from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool +from prometheus.utils.logger_manager import get_thread_logger class GeneralBuildNode: @@ -44,22 +43,21 @@ class GeneralBuildNode: def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): self.kg = kg - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_build_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index e41463d7..c6c7f8fd 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -6,15 +6,13 @@ identify any failures. """ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history +from prometheus.utils.logger_manager import get_thread_logger class BuildStructuredOutput(BaseModel): @@ -238,9 +236,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(BuildStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_build_structured_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes build state to generate structured build analysis. diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index d875c539..7aebf0ac 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -9,7 +7,8 @@ from prometheus.docker.base_container import BaseContainer from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool +from prometheus.utils.logger_manager import get_thread_logger class GeneralTestNode: @@ -61,22 +60,21 @@ class GeneralTestNode: def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): self.kg = kg - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_test_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 8da0b868..31efbf72 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -6,15 +6,13 @@ identify any failures. """ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState from prometheus.utils.lang_graph_util import format_agent_tool_message_history +from prometheus.utils.logger_manager import get_thread_logger class TestStructuredOutput(BaseModel): @@ -287,9 +285,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(TestStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_test_structured_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BuildAndTestState): """Processes test state to generate structured test analysis. diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py index 728b7918..b3a818a8 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_check_result_node.py @@ -1,11 +1,10 @@ -import logging -import threading from collections import Counter from prometheus.lang_graph.subgraphs.get_pass_regression_test_patch_state import ( GetPassRegressionTestPatchState, ) from prometheus.models.test_patch_result import TestedPatchResult +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchCheckResultNode: @@ -14,10 +13,7 @@ class GetPassRegressionTestPatchCheckResultNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes." - f"get_pass_regression_test_patch_check_result_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: GetPassRegressionTestPatchState): """ diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py index 63d56579..8c6d038e 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -11,6 +9,7 @@ GetPassRegressionTestPatchSubgraph, ) from prometheus.models.test_patch_result import TestedPatchResult +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchSubgraphNode: @@ -22,9 +21,7 @@ def __init__( testing_patch_key: str, is_testing_patch_list: bool, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.get_pass_regression_test_patch_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = GetPassRegressionTestPatchSubgraph( base_model=model, container=container, diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py index caf04160..fe05d82f 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_update_node.py @@ -1,10 +1,8 @@ -import logging -import threading - from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.get_pass_regression_test_patch_state import ( GetPassRegressionTestPatchState, ) +from prometheus.utils.logger_manager import get_thread_logger class GetPassRegressionTestPatchUpdateNode: @@ -17,9 +15,7 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.get_pass_regression_test_patch_update_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: GetPassRegressionTestPatchState): """ diff --git a/prometheus/lang_graph/nodes/git_apply_patch_node.py b/prometheus/lang_graph/nodes/git_apply_patch_node.py index 242795d3..75a397b7 100644 --- a/prometheus/lang_graph/nodes/git_apply_patch_node.py +++ b/prometheus/lang_graph/nodes/git_apply_patch_node.py @@ -1,8 +1,7 @@ -import logging -import threading from typing import Dict from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_thread_logger class GitApplyPatchNode: @@ -13,9 +12,7 @@ class GitApplyPatchNode: def __init__(self, git_repo: GitRepository, state_patch_name: str): self.git_repo = git_repo self.state_patch_name = state_patch_name - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_apply_patch_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """ diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index f6923cd4..9a7ad2e6 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -6,11 +6,10 @@ output. """ -import logging -import threading from typing import Dict, Optional from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_thread_logger class GitDiffNode: @@ -33,9 +32,7 @@ def __init__( self.state_patch_name = state_patch_name self.state_excluded_files_key = state_excluded_files_key self.return_list = return_list - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_diff_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """Generates a Git diff for the current project state. diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index 1bf03f6b..c7b26d20 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,7 +1,5 @@ -import logging -import threading - from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_thread_logger class GitResetNode: @@ -10,9 +8,7 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_reset_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _): self._logger.debug("Resetting the git repository") diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index 2a8a14c7..0217832e 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -1,10 +1,9 @@ -import logging -import threading from typing import Dict from langchain_core.messages import HumanMessage from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueBugAnalyzerMessageNode: @@ -65,9 +64,7 @@ class IssueBugAnalyzerMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_analyzer_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index d3c501aa..393cf44f 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,10 +1,13 @@ -import logging -import threading +import functools from typing import Dict +from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import SystemMessage +from prometheus.tools.web_search import WebSearchTool +from prometheus.utils.logger_manager import get_thread_logger + class IssueBugAnalyzerNode: SYS_PROMPT = """\ @@ -31,26 +34,62 @@ class IssueBugAnalyzerNode: - Identifying what went wrong with the previous attempt - Suggesting revised high-level changes that avoid the previous issues +MANDATORY TOOL USAGE: +- You MUST use the web_search tool for EVERY bug analysis +- Before providing any analysis, search for: + * Similar error messages or exceptions + * Known issues with the specific libraries/frameworks involved + * Best practices for the type of bug you're analyzing + * Official documentation for error resolution +- Only proceed with analysis after gathering relevant web information + +Tools available: +- web_search: Searches the web for technical information to aid in bug analysis and resolution. +When using the web_search tool, ALWAYS include these parameters: + - exclude_domains: ["*swe-bench*"] + - include_domains: ['stackoverflow.com', 'github.com', 'developer.mozilla.org', 'learn.microsoft.com', 'fastapi.tiangolo.com' + 'docs.python.org', 'pydantic.dev', 'pypi.org', 'readthedocs.org', 'docs.djangoproject.com','flask.palletsprojects.com'] + - search_depth: "advanced" + + Make sure to explicitly pass these parameters in your tool call. + Important: - Do NOT provide actual code snippets or diffs - DO provide clear file paths and function names where changes are needed - Focus on describing WHAT needs to change and WHY, not HOW to change it - Keep descriptions precise and actionable, as they will be used by another agent to implement the changes +- ALWAYS start your analysis with web search results Communicate in a clear, technical manner focused on accurate analysis and practical suggestions rather than implementation details. """ def __init__(self, model: BaseChatModel): - self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.web_search_tool = WebSearchTool() self.model = model - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_analyzer_node" + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.tools = self._init_tools() + self.model_with_tools = model.bind_tools(self.tools) + self._logger, file_handler = get_thread_logger(__name__) + + def _init_tools(self): + """Initializes tools for the node.""" + tools = [] + + web_search_fn = functools.partial(self.web_search_tool.web_search) + web_search_tool = StructuredTool.from_function( + func=web_search_fn, + name=self.web_search_tool.web_search.__name__, + description=self.web_search_tool.web_search_spec.description, + args_schema=self.web_search_tool.web_search_spec.input_schema, ) + tools.append(web_search_tool) + + return tools def __call__(self, state: Dict): message_history = [self.system_prompt] + state["issue_bug_analyzer_messages"] - response = self.model.invoke(message_history) + response = self.model_with_tools.invoke(message_history) self._logger.debug(response) return {"issue_bug_analyzer_messages": [response]} diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index 25e1e5c3..cdf3aac3 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -1,8 +1,7 @@ -import logging -import threading from typing import Dict from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueBugContextMessageNode: @@ -20,9 +19,7 @@ class IssueBugContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): bug_fix_query = self.BUG_FIX_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index f66f2cb0..5f646566 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -1,8 +1,6 @@ -import logging -import threading - from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueBugReproductionContextMessageNode: @@ -109,9 +107,7 @@ def test_file_permission_denied(self, mock_open, mock_access): """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: BugReproductionState): bug_reproducing_query = self.BUG_REPRODUCING_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index 2e079d33..519fb445 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -1,11 +1,9 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueBugResponderNode: @@ -49,9 +47,7 @@ def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_responder_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_human_message(self, state: IssueBugState) -> HumanMessage: verification_messages = [] diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index e1fe479f..70fa53f6 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -10,6 +8,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_bug_subgraph import IssueBugSubgraph +from prometheus.utils.logger_manager import get_thread_logger class IssueBugSubgraphNode: @@ -26,9 +25,7 @@ def __init__( git_repo: GitRepository, test_commands: Optional[Sequence[str]] = None, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.container = container self.issue_bug_subgraph = IssueBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index 552c50ec..39cb9ec8 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -1,8 +1,6 @@ -import logging -import threading - from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueClassificationContextMessageNode: @@ -73,9 +71,7 @@ class IssueClassificationContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classification_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueClassificationState): issue_classification_query = self.ISSUE_CLASSIFICATION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 508f77a0..21bbd73b 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from prometheus.graph.knowledge_graph import KnowledgeGraph @@ -8,6 +5,7 @@ from prometheus.lang_graph.subgraphs.issue_classification_subgraph import ( IssueClassificationSubgraph, ) +from prometheus.utils.logger_manager import get_thread_logger class IssueClassificationSubgraphNode: @@ -17,9 +15,7 @@ def __init__( kg: KnowledgeGraph, local_path: str, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classification_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.issue_classification_subgraph = IssueClassificationSubgraph( model=model, kg=kg, diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index 2c259a8f..83dc5ed1 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field @@ -8,6 +5,7 @@ from prometheus.lang_graph.graphs.issue_state import IssueType from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueClassifierOutput(BaseModel): @@ -125,9 +123,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(IssueClassifierOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classifier_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def format_context_info(self, state: IssueClassificationState) -> str: context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index 45fc621d..17d913b6 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -11,6 +9,7 @@ from prometheus.lang_graph.subgraphs.issue_not_verified_bug_subgraph import ( IssueNotVerifiedBugSubgraph, ) +from prometheus.utils.logger_manager import get_thread_logger class IssueNotVerifiedBugSubgraphNode: @@ -22,9 +21,8 @@ def __init__( git_repo: GitRepository, container: BaseContainer, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) + self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py index 21bf01dd..98e69e38 100644 --- a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py @@ -1,11 +1,9 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage from prometheus.lang_graph.subgraphs.issue_question_state import IssueQuestionState from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionAnalyzerNode: @@ -46,9 +44,7 @@ class IssueQuestionAnalyzerNode: def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_analyzer_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: IssueQuestionState): human_prompt = HumanMessage( diff --git a/prometheus/lang_graph/nodes/issue_question_context_message_node.py b/prometheus/lang_graph/nodes/issue_question_context_message_node.py index b9546398..929ec8a0 100644 --- a/prometheus/lang_graph/nodes/issue_question_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_question_context_message_node.py @@ -1,8 +1,7 @@ -import logging -import threading from typing import Dict from prometheus.utils.issue_util import format_issue_info +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionContextMessageNode: @@ -21,9 +20,7 @@ class IssueQuestionContextMessageNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_context_message_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): question_query = self.QUESTION_QUERY.format( diff --git a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py index 4bcaf0d2..9b739d12 100644 --- a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError @@ -8,6 +5,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph +from prometheus.utils.logger_manager import get_thread_logger class IssueQuestionSubgraphNode: @@ -23,9 +21,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.issue_question_subgraph = IssueQuestionSubgraph( advanced_model=advanced_model, base_model=base_model, diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 0288e270..99e400cc 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -1,6 +1,3 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langgraph.errors import GraphRecursionError @@ -9,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState from prometheus.lang_graph.subgraphs.issue_verified_bug_subgraph import IssueVerifiedBugSubgraph +from prometheus.utils.logger_manager import get_thread_logger class IssueVerifiedBugSubgraphNode: @@ -24,9 +22,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.git_repo = git_repo self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index de06eae3..33f0b94c 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -5,10 +5,10 @@ node graphs where a connection is needed but no processing is required. """ -import logging -import threading from typing import Dict +from prometheus.utils.logger_manager import get_thread_logger + class NoopNode: """No-operation node that routes workflow without processing. @@ -20,9 +20,7 @@ class NoopNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.noop_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict) -> None: """Routes the workflow without performing any operations. diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 6e8062ce..12d84711 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -4,14 +4,13 @@ Provides standardized patch candidates with direct best patch selection. """ -import logging import re -import threading from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Sequence from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState +from prometheus.utils.logger_manager import get_thread_logger @dataclass @@ -39,9 +38,7 @@ class PatchNormalizationNode: """ def __init__(self): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.patch_normalization_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def normalize_patch(self, raw_patch: str) -> str: """Normalize patch content for deduplication diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index 1c37bdde..90e6ebf3 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -10,10 +10,10 @@ - The same state attribute name is reused """ -import logging -import threading from typing import Dict +from prometheus.utils.logger_manager import get_thread_logger + class ResetMessagesNode: """Resets message states for workflow loop iterations. @@ -36,9 +36,7 @@ def __init__(self, message_state_key: str): be reset during node execution. """ self.message_state_key = message_state_key - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.reset_messages_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: Dict): """Resets the specified message state for the next iteration. diff --git a/prometheus/lang_graph/nodes/run_existing_tests_node.py b/prometheus/lang_graph/nodes/run_existing_tests_node.py index 1cddef6a..24cd41e0 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_node.py @@ -1,8 +1,6 @@ -import logging -import threading - from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_existing_tests_state import RunExistingTestsState +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsNode: @@ -11,9 +9,7 @@ class RunExistingTestsNode: """ def __init__(self, container: BaseContainer): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.container = container def __call__(self, state: RunExistingTestsState): diff --git a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py index 5cbbe094..5e80ff08 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_structure_node.py @@ -1,11 +1,9 @@ -import logging -import threading - from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel from prometheus.lang_graph.subgraphs.run_existing_tests_state import RunExistingTestsState +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsStructureOutput(BaseModel): @@ -55,9 +53,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RunExistingTestsStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_existing_tests_structure_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, state: RunExistingTestsState): # Get human message from the state diff --git a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py index ac4b0bd0..c4c9ac54 100644 --- a/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_existing_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -7,6 +5,7 @@ from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository from prometheus.lang_graph.subgraphs.run_existing_tests_subgraph import RunExistingTestsSubgraph +from prometheus.utils.logger_manager import get_thread_logger class RunExistingTestsSubgraphNode: @@ -18,9 +17,7 @@ def __init__( testing_patch_key: str, existing_test_fail_log_key: str, ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_existing_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = RunExistingTestsSubgraph( base_model=model, container=container, git_repo=git_repo ) diff --git a/prometheus/lang_graph/nodes/run_regression_tests_node.py b/prometheus/lang_graph/nodes/run_regression_tests_node.py index ace9a769..5a8e7dd2 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_node.py @@ -1,6 +1,4 @@ import functools -import logging -import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -8,7 +6,8 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_regression_tests_state import RunRegressionTestsState -from prometheus.tools import container_command +from prometheus.tools.container_command import ContainerCommandTool +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsNode: @@ -55,22 +54,21 @@ class RunRegressionTestsNode: """ def __init__(self, model: BaseChatModel, container: BaseContainer): - self.tools = self._init_tools(container) + self.container_command_tool = ContainerCommandTool(container) + self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_node" - ) + self._logger, file_handler = get_thread_logger(__name__) - def _init_tools(self, container: BaseContainer): + def _init_tools(self): tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_fn = functools.partial(self.container_command_tool.run_command) run_command_tool = StructuredTool.from_function( func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, + name=self.container_command_tool.run_command.__name__, + description=self.container_command_tool.run_command_spec.description, + args_schema=self.container_command_tool.run_command_spec.input_schema, ) tools.append(run_command_tool) diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index ad1da1b5..1248c0e6 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -8,6 +6,7 @@ from prometheus.lang_graph.subgraphs.run_regression_tests_state import RunRegressionTestsState from prometheus.utils.lang_graph_util import get_last_message_content +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsStructureOutput(BaseModel): @@ -15,7 +14,7 @@ class RunRegressionTestsStructureOutput(BaseModel): description="List of test identifier of regression tests that passed (e.g., class name and method name)" ) regression_test_fail_log: str = Field( - description="If the test failed, contains the complete test FAILURE log. Otherwise empty string" + description="If any test failed, contains the exact and complete test FAILURE log. Otherwise empty string" ) total_tests_run: int = Field( description="Total number of tests run, including both passed and failed tests, or 0 if no tests were run", @@ -32,7 +31,7 @@ class RunRegressionTestsStructuredNode: - Test summary showing "passed" or "PASSED" - Warning is ok - No "FAILURES" section -2. If a test fails, capture the complete failure output. Otherwise empty string for failure log +2. If a test fails, capture the exact and complete failure output. Otherwise empty string for failure log 3. Return the exact test identifiers that passed 4. Count the total number of tests run. Only count tests that were actually executed! If tests were unable to run due to an error, do not count them! @@ -104,9 +103,7 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(RunRegressionTestsStructureOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_structure_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def get_human_message(self, state: RunRegressionTestsState) -> str: # Format the human message using the state diff --git a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py index d07f7b6d..678de625 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_subgraph_node.py @@ -1,5 +1,3 @@ -import logging -import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -7,15 +5,14 @@ from prometheus.docker.base_container import BaseContainer from prometheus.lang_graph.subgraphs.run_regression_tests_subgraph import RunRegressionTestsSubgraph +from prometheus.utils.logger_manager import get_thread_logger class RunRegressionTestsSubgraphNode: def __init__( self, model: BaseChatModel, container: BaseContainer, passed_regression_tests_key: str ): - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.run_regression_tests_subgraph_node" - ) + self._logger, file_handler = get_thread_logger(__name__) self.subgraph = RunRegressionTestsSubgraph( base_model=model, container=container, diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index 41a42cf5..be15191f 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -6,12 +6,11 @@ between the agent's workspace and the container environment. """ -import logging -import threading from typing import Dict from prometheus.docker.base_container import BaseContainer from prometheus.git.git_repository import GitRepository +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.patch_util import get_updated_files @@ -34,9 +33,7 @@ def __init__(self, container: BaseContainer, git_repo: GitRepository): """ self.container = container self.git_repo = git_repo - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.update_container_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Dict): """Synchronizes the current project state with the container.""" diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index f5e47d1a..7cfa5650 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -1,19 +1,16 @@ -import logging -import threading import uuid from typing import Any from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer +from prometheus.utils.logger_manager import get_thread_logger class UserDefinedBuildNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.user_defined_build_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Any): build_output = self.container.run_build() diff --git a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index 47a43859..948c3eec 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -1,19 +1,16 @@ -import logging -import threading import uuid from typing import Any from langchain_core.messages import ToolMessage from prometheus.docker.base_container import BaseContainer +from prometheus.utils.logger_manager import get_thread_logger class UserDefinedTestNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger( - f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.user_defined_test_node" - ) + self._logger, file_handler = get_thread_logger(__name__) def __call__(self, _: Any): test_output = self.container.run_test() diff --git a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py index dff4c0e7..ea678434 100644 --- a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py @@ -70,7 +70,7 @@ def __init__( # Step 3: Write a patch to reproduce the bug bug_reproducing_write_message_node = BugReproducingWriteMessageNode() bug_reproducing_write_node = BugReproducingWriteNode( - advanced_model, git_repo.playground_path + advanced_model, git_repo.playground_path, kg ) bug_reproducing_write_tools = ToolNode( tools=bug_reproducing_write_node.tools, diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 9ffb5cf6..43a1753c 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -45,9 +45,14 @@ def __init__( issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + issue_bug_analyzer_tools = ToolNode( + tools=issue_bug_analyzer_node.tools, + name="issue_bug_analyzer_tools", + messages_key="issue_bug_analyzer_messages", + ) edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, git_repo.playground_path) + edit_node = EditNode(advanced_model, git_repo.playground_path, kg) edit_tools = ToolNode( tools=edit_node.tools, name="edit_tools", @@ -81,6 +86,7 @@ def __init__( workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + workflow.add_node("issue_bug_analyzer_tools", issue_bug_analyzer_tools) workflow.add_node("edit_message_node", edit_message_node) workflow.add_node("edit_node", edit_node) @@ -106,7 +112,15 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + # Conditionally invoke tools or continue to edit message + workflow.add_conditional_edges( + "issue_bug_analyzer_node", + functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), + {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, + ) + + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") workflow.add_conditional_edges( diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 3d23f66e..de1a1c0d 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -79,10 +79,15 @@ def __init__( # Phase 2: Analyze the bug and generate hypotheses issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + issue_bug_analyzer_tools = ToolNode( + tools=issue_bug_analyzer_node.tools, + name="issue_bug_analyzer_tools", + messages_key="issue_bug_analyzer_messages", + ) # Phase 3: Generate code edits and optionally apply toolchains edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, git_repo.playground_path) + edit_node = EditNode(advanced_model, git_repo.playground_path, kg) edit_tools = ToolNode( tools=edit_node.tools, name="edit_tools", @@ -128,6 +133,7 @@ def __init__( workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + workflow.add_node("issue_bug_analyzer_tools", issue_bug_analyzer_tools) workflow.add_node("edit_message_node", edit_message_node) workflow.add_node("edit_node", edit_node) @@ -150,7 +156,15 @@ def __init__( workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + # Conditionally invoke tools or continue to edit message + workflow.add_conditional_edges( + "issue_bug_analyzer_node", + functools.partial(tools_condition, messages_key="issue_bug_analyzer_messages"), + {"tools": "issue_bug_analyzer_tools", END: "edit_message_node"}, + ) + + workflow.add_edge("issue_bug_analyzer_tools", "issue_bug_analyzer_node") workflow.add_edge("edit_message_node", "edit_node") # Conditionally invoke tools or continue to diffing diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index c953cb49..ee5c79bf 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -1,6 +1,5 @@ """The neo4j handler for writing the knowledge graph to neo4j.""" -import logging from typing import Mapping, Sequence from neo4j import AsyncGraphDatabase, AsyncManagedTransaction @@ -16,6 +15,7 @@ Neo4jTextNode, ) from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.utils.logger_manager import get_thread_logger class KnowledgeGraphHandler: @@ -30,7 +30,7 @@ def __init__(self, driver: AsyncGraphDatabase.driver, batch_size: int): self.driver = driver self.batch_size = batch_size # initialize the database and logger - self._logger = logging.getLogger("prometheus.neo4j.knowledge_graph_handler") + self._logger, file_handler = get_thread_logger(__name__) async def init_database(self): """Initialization of the neo4j database.""" diff --git a/prometheus/tools/container_command.py b/prometheus/tools/container_command.py index 0d764613..e906528f 100644 --- a/prometheus/tools/container_command.py +++ b/prometheus/tools/container_command.py @@ -1,17 +1,43 @@ +from dataclasses import dataclass + from pydantic import BaseModel, Field from prometheus.docker.base_container import BaseContainer +@dataclass +class ToolSpec: + description: str + input_schema: type + + class RunCommandInput(BaseModel): command: str = Field("The shell command to be run in the container") -RUN_COMMAND_DESCRIPTION = """\ -Run a shell command in the container and return the result of the command. You are always at the root -of the codebase. -""" +class ContainerCommandTool: + """Tool class for executing shell commands in containers.""" + + run_command_spec = ToolSpec( + description="""\ + Run a shell command in the container and return the result of the command. You are always at the root + of the codebase. + """, + input_schema=RunCommandInput, + ) + def __init__(self, container: BaseContainer): + """Initialize the container command tool. + Args: + container: The GeneralContainer instance to execute commands in. + """ + self.container = container -def run_command(command: str, container: BaseContainer) -> str: - return container.execute_command(command) + def run_command(self, command: str) -> str: + """Run a shell command in the container and return the result. + Args: + command: The shell command to be run in the container. + Returns: + The output of the command execution. + """ + return self.container.execute_command(command) diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index eb7bc7df..c1057cea 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -1,6 +1,6 @@ -import logging import os import shutil +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Tuple, Union @@ -8,80 +8,20 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.utils.knowledge_graph_utils import format_knowledge_graph_data +from prometheus.utils.logger_manager import get_thread_logger from prometheus.utils.str_util import pre_append_line_numbers -logger = logging.getLogger("prometheus.tools.file_operation") +logger, file_handler = get_thread_logger(__name__) -class ReadFileInput(BaseModel): - relative_path: str = Field("The relative path of the file to read") - - -READ_FILE_DESCRIPTION = """\ -Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. -Returns up to the first 1000 lines by default to prevent context issues with large files. -Returns an error message if the file doesn't exist. -""" - - -def read_file(relative_path: str, root_path: str, n_lines: int = 1000) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - with file_path.open() as f: - lines = f.readlines() - - return pre_append_line_numbers("".join(lines[:n_lines]), 1) - - -def read_file_with_knowledge_graph_data( - relative_path: str, root_path: str, kg: KnowledgeGraph -) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: - """ - Read the content of a file and return it along with structured knowledge graph data. - Used for context provider node - """ - - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path.", None - - file_node = None - for file_node_ in kg.get_file_nodes(): - if file_node_.node.relative_path == relative_path: - file_node = file_node_ - break - - # Check if file node exists in the knowledge graph - if not file_node: - return f"The file {relative_path} does not exist.", None +@dataclass +class ToolSpec: + description: str + input_schema: type - file_path = Path(os.path.join(root_path, file_node.node.relative_path)) - # Read the file content - with file_path.open() as f: - lines = f.readlines() - # Limit to first 1000 lines to avoid context issues - selected_text_with_line_numbers = pre_append_line_numbers("".join(lines[:1000]), 1) - - result_data = [ - { - "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "preview": { - "text": selected_text_with_line_numbers, - "start_line": 1, - "end_line": len(selected_text_with_line_numbers.splitlines()), - }, - } - ] - return format_knowledge_graph_data(result_data), result_data +class ReadFileInput(BaseModel): + relative_path: str = Field("The relative path of the file to read") class ReadFileWithLineNumbersInput(BaseModel): @@ -92,38 +32,6 @@ class ReadFileWithLineNumbersInput(BaseModel): end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") -READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION = """\ -Read a specific range of lines from a file and return the content with line numbers prepended. -The line numbers are 1-indexed where start_line is inclusive and end_line is exclusive. -For best results when analyzing code or text files, consider reading chunks of 500-1000 lines at a time. -""" - - -def read_file_with_line_numbers( - relative_path: str, root_path: str, start_line: int, end_line: int -) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - if end_line < start_line: - return f"The end line number {end_line} must be greater than the start line number {start_line}." - - zero_based_start_line = start_line - 1 - zero_based_end_line = end_line - 1 - - with file_path.open() as f: - lines = f.readlines() - final_content = "".join(lines[zero_based_start_line:zero_based_end_line]) - if not final_content: - return f"No content found between lines {start_line} and {end_line} in {relative_path}!" - - return pre_append_line_numbers(final_content, start_line) - - class CreateFileInput(BaseModel): relative_path: str = Field( description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" @@ -131,55 +39,12 @@ class CreateFileInput(BaseModel): content: str = Field(description="The content of the file to create") -CREATE_FILE_DESCRIPTION = """\ -Create a new file at the specified path with the given content. -If the parent directories don't exist, they will be created automatically. -Returns an error message if the file already exists. -""" - - -def create_file(relative_path: str, root_path: str, content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if file_path.exists(): - return f"The file {relative_path} already exists." - - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(content) - return f"The file {relative_path} has been created." - - class DeleteInput(BaseModel): relative_path: str = Field( description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" ) -DELETE_DESCRIPTION = """\ -Delete a file or directory at the specified path. -For directories, it will recursively delete all contents. -Returns an error message if the path doesn't exist. -""" - - -def delete(relative_path: str, root_path: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." - - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." - - if file_path.is_dir(): - shutil.rmtree(file_path) - return f"The directory {relative_path} has been deleted." - - file_path.unlink() - return f"The file {relative_path} has been deleted." - - class EditFileInput(BaseModel): relative_path: str = Field( description="The relative path of the file to edit, eg. foo/bar/test.py, not absolute path" @@ -192,46 +57,208 @@ class EditFileInput(BaseModel): ) -EDIT_FILE_DESCRIPTION = """\ -Edit a file by replacing specific content with new content. -Performs an exact string replacement of old_content with new_content. -Returns an error message if: -- The file doesn't exist -- The old_content is not found in the file -- The old_content matches multiple locations (in which case more context is needed) -- The provided path is absolute instead of relative +class FileOperationTool: + """Tool class for file operations including reading, creating, editing, and deleting files.""" -Example usage: -edit_file( - relative_path="src/calculator.py", - old_content="return a * b", - new_content="return a / b" -) -""" + # File operation tools + read_file_spec = ToolSpec( + description="""\ + Read the content of a file with line numbers prepended from the codebase with a safety limit on the number of lines. + Returns up to the first 1000 lines by default to prevent context issues with large files. + Returns an error message if the file doesn't exist. + """, + input_schema=ReadFileInput, + ) + read_file_with_line_numbers_spec = ToolSpec( + description="""\ + Read a specific range of lines from a file and return the content with line numbers prepended. + The line numbers are 1-indexed where start_line is inclusive and end_line is exclusive. + For best results when analyzing code or text files, consider reading chunks of 500-1000 lines at a time. + """, + input_schema=ReadFileWithLineNumbersInput, + ) -def edit_file(relative_path: str, root_path: str, old_content: str, new_content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a absolute path, not relative path." + create_file_spec = ToolSpec( + description="""\ + Create a new file at the specified path with the given content. + If the parent directories don't exist, they will be created automatically. + Returns an error message if the file already exists. + """, + input_schema=CreateFileInput, + ) - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." + delete_spec = ToolSpec( + description="""\ + Delete a file or directory at the specified path. + For directories, it will recursively delete all contents. + Returns an error message if the path doesn't exist. + """, + input_schema=DeleteInput, + ) - content = file_path.read_text() + edit_file_spec = ToolSpec( + description="""\ + Edit a file by replacing specific content with new content. + Performs an exact string replacement of old_content with new_content. + Returns an error message if: + - The file doesn't exist + - The old_content is not found in the file + - The old_content matches multiple locations (in which case more context is needed) + - The provided path is absolute instead of relative + + Example usage: + edit_file( + relative_path="src/calculator.py", + old_content="return a * b", + new_content="return a / b" + ) + """, + input_schema=EditFileInput, + ) - occurrences = content.count(old_content) + def __init__(self, root_path: str, kg: KnowledgeGraph): + """Initialize the file operation tool. + Args: + root_path: The root path of the codebase for relative path operations. + kg: The knowledge graph for context provider node. + Args: + root_path: The root path of the codebase for relative path operations. + kg: The knowledge graph for context provider node. + """ + self.root_path = root_path + self.kg = kg - if occurrences == 0: - return f"No match found for the specified content in {relative_path}. Please verify the content to replace." + def read_file(self, relative_path: str, n_lines: int = 1000) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path." - if occurrences > 1: - return ( - f"Found {occurrences} occurrences of the specified content in {relative_path}. " - "Please provide more context to ensure a unique match." - ) + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + with file_path.open() as f: + lines = f.readlines() + + return pre_append_line_numbers("".join(lines[:n_lines]), 1) - new_content_full = content.replace(old_content, new_content) - file_path.write_text(new_content_full) + def read_file_with_line_numbers( + self, relative_path: str, start_line: int, end_line: int + ) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + if end_line < start_line: + return f"The end line number {end_line} must be greater than the start line number {start_line}." + + zero_based_start_line = start_line - 1 + zero_based_end_line = end_line - 1 + + with file_path.open() as f: + lines = f.readlines() + + return pre_append_line_numbers( + "".join(lines[zero_based_start_line:zero_based_end_line]), start_line + ) - return f"Successfully edited {relative_path}." + def create_file(self, relative_path: str, content: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if file_path.exists(): + return f"The file {relative_path} already exists." + + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + return f"The file {relative_path} has been created." + + def delete(self, relative_path: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + if file_path.is_dir(): + shutil.rmtree(file_path) + return f"The directory {relative_path} has been deleted." + + file_path.unlink() + return f"The file {relative_path} has been deleted." + + def edit_file(self, relative_path: str, old_content: str, new_content: str) -> str: + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path." + + file_path = Path(os.path.join(self.root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." + + content = file_path.read_text() + + occurrences = content.count(old_content) + + if occurrences == 0: + return f"No match found for the specified content in {relative_path}. Please verify the content to replace." + + if occurrences > 1: + return ( + f"Found {occurrences} occurrences of the specified content in {relative_path}. " + "Please provide more context to ensure a unique match." + ) + + new_content_full = content.replace(old_content, new_content) + file_path.write_text(new_content_full) + + return f"Successfully edited {relative_path}." + + def read_file_with_knowledge_graph_data( + self, relative_path: str + ) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + """ + Read the content of a file and return it along with structured knowledge graph data. + Used for context provider node + """ + + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a absolute path, not relative path.", None + + file_node = None + for file_node_ in self.kg.get_file_nodes(): + if file_node_.node.relative_path == relative_path: + file_node = file_node_ + break + + # Check if file node exists in the knowledge graph + if not file_node: + return f"The file {relative_path} does not exist.", None + + file_path = Path(os.path.join(self.root_path, file_node.node.relative_path)) + + # Read the file content + with file_path.open() as f: + lines = f.readlines() + # Limit to first 1000 lines to avoid context issues + selected_text_with_line_numbers = pre_append_line_numbers("".join(lines[:1000]), 1) + + result_data = [ + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "preview": { + "text": selected_text_with_line_numbers, + "start_line": 1, + "end_line": len(selected_text_with_line_numbers.splitlines()), + }, + } + ] + return format_knowledge_graph_data(result_data), result_data diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index 8a1c0c0d..2cec41ad 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Tuple, Union @@ -20,395 +21,405 @@ """ -############################################################################### -# FileNode retrieval # -############################################################################### +@dataclass +class ToolSpec: + description: str + input_schema: type class FindFileNodeWithBasenameInput(BaseModel): basename: str = Field("The basename of FileNode to search for") -FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION = """\ -Find all FileNode in the graph with this basename of a file/dir. The basename must -include the extension, like 'bar.py', 'baz.java' or 'foo' -(in this case foo is a directory or a file without extension). - -You can use this tool to check if a file/dir with this basename exists or get all -attributes related to the file/dir.""" - - -def find_file_node_with_basename( - basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all FileNodes with the given basename.""" - results = [] - for kg_node in kg.get_file_nodes(): - if kg_node.node.basename == basename: - results.append( - { - "FileNode": { - "node_id": kg_node.node_id, - "basename": kg_node.node.basename, - "relative_path": kg_node.node.relative_path, - } - } - ) - results.sort(key=lambda x: x["FileNode"]["node_id"]) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - class FindFileNodeWithRelativePathInput(BaseModel): relative_path: str = Field("The relative_path of FileNode to search for") -FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Search FileNode in the graph with this relative_path of a file/dir. The relative_path is -the relative path from the root path of codebase. The relative_path must include the extension, -like 'foo/bar/baz.java'. - -You can use this tool to check if a file/dir with this relative_path exists or get all -attributes related to the file/dir.""" - - -def find_file_node_with_relative_path( - relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all FileNodes with the given relative path.""" - results = [] - for kg_node in kg.get_file_nodes(): - if kg_node.node.relative_path == relative_path: - results.append( - { - "FileNode": { - "node_id": kg_node.node_id, - "basename": kg_node.node.basename, - "relative_path": kg_node.node.relative_path, - } - } - ) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - -############################################################################### -# ASTNode retrieval # -############################################################################### - - -def find_ast_node_with_text_in_file( - text: str, target_files_nodes: List[KnowledgeGraphNode], kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - results = [] - - # Get HAS_AST edges to find which AST nodes belong to these files - has_ast_edges = kg.get_has_ast_edges() - file_to_ast_map = { - edge.source.node_id: edge.target - for edge in has_ast_edges - if edge.source.node_id in [n.node_id for n in target_files_nodes] - } - - # Construct parent to children map for AST traversal - parent_to_children = kg.get_parent_to_children_map() - - # Get root AstNode id list - root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) - - for file_node in target_files_nodes: - # Start with root AST node for this file - root_ast = file_to_ast_map[file_node.node_id] - - # Add all descendant AST nodes - stack = [root_ast] - while stack: - current_node = stack.pop() - - # Check if the current node contains the text - # Don't include the root AST node itself - if text in current_node.node.text and current_node.node_id not in root_ast_node_ids: - results.append( - { - "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "ASTNode": { - "node_id": current_node.node_id, - "type": current_node.node.type, - "start_line": current_node.node.start_line, - "end_line": current_node.node.end_line, - "text": current_node.node.text, - }, - } - ) - - # Add children to stack - stack += parent_to_children.get(current_node.node_id, []) - - # Sort by text length (smaller first) - results.sort(key=lambda x: len(x["ASTNode"]["text"])) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") basename: str = Field("The basename of file/directory to search under for ASTNodes.") -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file with this basename. -For reliable results, search for longer, distinct text sequences rather than short common words or fragments. -The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. -For best results, use unique text segments of at least several words. The basename can be either a file (like -'bar.py', 'baz.java').""" - - -def find_ast_node_with_text_in_file_with_basename( - text: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.basename == basename - ] - return find_ast_node_with_text_in_file(text, target_files_nodes, kg) - - class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): text: str = Field("Search ASTNode that exactly contains this text.") relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") -FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that exactly contains this text in any source file with this relative path. -For reliable results, search for longer, distinct text sequences rather than short common words or fragments. - The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. -Therefore the search text should be exact as well. The relative path should be the path from the root of codebase -(like 'src/core/parser.py').""" - - -def find_ast_node_with_text_in_file_with_relative_path( - text: str, relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given relative path.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.relative_path == relative_path - ] - return find_ast_node_with_text_in_file(text, target_files_nodes, kg) - - -def find_ast_node_with_type_in_file( - type: str, target_files_nodes: List[KnowledgeGraphNode], kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes containing the given text in files with the given basename.""" - results = [] - - # Get HAS_AST edges to find which AST nodes belong to these files - has_ast_edges = kg.get_has_ast_edges() - file_to_ast_map = { - edge.source.node_id: edge.target - for edge in has_ast_edges - if edge.source.node_id in [n.node_id for n in target_files_nodes] - } - - # Construct parent to children map for AST traversal - parent_to_children = kg.get_parent_to_children_map() - - # Get root AstNode id list - root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) - - for file_node in target_files_nodes: - # Start with root AST node for this file - root_ast = file_to_ast_map[file_node.node_id] - - # Add all descendant AST nodes - stack = [root_ast] - while stack: - current_node = stack.pop() - - # Check if current node contains the text - # Don't include the root AST node itself - if current_node.node.type == type and current_node.node_id not in root_ast_node_ids: - results.append( - { - "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "ASTNode": { - "node_id": current_node.node_id, - "type": current_node.node.type, - "start_line": current_node.node.start_line, - "end_line": current_node.node.end_line, - "text": current_node.node.text, - }, - } - ) +class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): + type: str = Field("Search ASTNode with this tree-sitter node type.") + basename: str = Field("The basename of file/directory to search under for ASTNodes.") - # Add children to stack - stack += parent_to_children.get(current_node.node_id, []) - # Sort by text length (smaller first) - results.sort(key=lambda x: len(x["ASTNode"]["text"])) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] +class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): + type: str = Field("Search ASTNode with this tree-sitter node type.") + relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") -class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): - type: str = Field("Search ASTNode with this tree-sitter node type.") - basename: str = Field("The basename of file/directory to search under for ASTNodes.") +class FindTextNodeWithTextInput(BaseModel): + text: str = Field("Search TextNode that exactly contains this text.") -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file with this basename. -This tool is useful for searching class/function/method under files.""" +class FindTextNodeWithTextInFileInput(BaseModel): + text: str = Field("Search TextNode that exactly contains this text.") + basename: str = Field("The basename of FileNode to search TextNode.") -def find_ast_node_with_type_in_file_with_basename( - type: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes with the given type in files with the given basename.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.basename == basename - ] - return find_ast_node_with_type_in_file(type, target_files_nodes, kg) +class GetNextTextNodeWithNodeIdInput(BaseModel): + node_id: int = Field("Get the next TextNode of this given node_id.") -class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): - type: str = Field("Search ASTNode with this tree-sitter node type.") - relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") +class PreviewFileContentWithBasenameInput(BaseModel): + basename: str = Field("The basename of FileNode to preview.") -FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Find all ASTNode in the graph that has this tree-sitter node type in any source file with this relative path. -This tool is useful for searching class/function/method under a file.""" +class PreviewFileContentWithRelativePathInput(BaseModel): + relative_path: str = Field("The relative path of FileNode to preview.") -def find_ast_node_with_type_in_file_with_relative_path( - type: str, relative_path: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all ASTNodes with the given type in files with the given relative path.""" - # Get file nodes with the given basename - target_files_nodes: List[KnowledgeGraphNode] = [ - node for node in kg.get_file_nodes() if node.node.relative_path == relative_path - ] - return find_ast_node_with_type_in_file(type, target_files_nodes, kg) +class ReadCodeWithBasenameInput(BaseModel): + basename: str = Field("The basename of FileNode to read.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") -############################################################################### -# TextNode retrieval # -############################################################################### +class ReadCodeWithRelativePathInput(BaseModel): + relative_path: str = Field("The relative path of FileNode to read from root of codebase.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") -def find_file_node_of_a_text_node( - text_node: KnowledgeGraphNode, kg: KnowledgeGraph -) -> KnowledgeGraphNode: - """ - Find a file node that contains the given text node. - """ - next_chunk_reverse_map = { - edge.target.node_id: edge.source for edge in kg.get_next_chunk_edges() - } - has_file_node_map = {edge.target.node_id: edge.source for edge in kg.get_has_text_edges()} +class GraphTraversalTool: + # FileNode retrieval tools + find_file_node_with_basename_spec = ToolSpec( + description="""Find all FileNode in the graph with this basename of a file/dir. The basename must + include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a directory or a file without extension). - # Find the root text node - current_text_node = text_node - while next_chunk_reverse_map.get(current_text_node.node_id, None) is not None: - current_text_node = next_chunk_reverse_map[current_text_node.node_id] + You can use this tool to check if a file/dir with this basename exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithBasenameInput, + ) - # Now current_text_node is the root text node - file_node = has_file_node_map[current_text_node.node_id] - return file_node + find_file_node_with_relative_path_spec = ToolSpec( + description="""Search FileNode in the graph with this relative_path of a file/dir. The relative_path is + the relative path from the root path of codebase. The relative_path must include the extension, + like 'foo/bar/baz.java'. + You can use this tool to check if a file/dir with this relative_path exists or get all + attributes related to the file/dir.""", + input_schema=FindFileNodeWithRelativePathInput, + ) -class FindTextNodeWithTextInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") + # ASTNode retrieval tools + find_ast_node_with_text_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file with this basename. + For reliable results, search for longer, distinct text sequences rather than short common words or fragments. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. + For best results, use unique text segments of at least several words. The basename can be either a file (like + 'bar.py', 'baz.java').""", + input_schema=FindASTNodeWithTextInFileWithBasenameInput, + ) + find_ast_node_with_text_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that exactly contains this text in any source file with this relative path. + For reliable results, search for longer, distinct text sequences rather than short common words or fragments. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is looking for exact matches. + Therefore the search text should be exact as well. The relative path should be the path from the root of codebase + (like 'src/core/parser.py').""", + input_schema=FindASTNodeWithTextInFileWithRelativePathInput, + ) -FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text. The contains is -same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. + find_ast_node_with_type_in_file_with_basename_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this basename. + This tool is useful for searching class/function/method under files.""", + input_schema=FindASTNodeWithTypeInFileWithBasenameInput, + ) -Text Node is a chunk of text extracted from a text file, such as comments or documentation. -Source code files are not split into TextNodes! + find_ast_node_with_type_in_file_with_relative_path_spec = ToolSpec( + description="""Find all ASTNode in the graph that has this tree-sitter node type in any source file with this relative path. + This tool is useful for searching class/function/method under a file.""", + input_schema=FindASTNodeWithTypeInFileWithRelativePathInput, + ) -You can use this tool to find all text/documentation in codebase that contains this text.""" + # TextNode retrieval tools + find_text_node_with_text_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text. The contains is + same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + You can use this tool to find all text/documentation in codebase that contains this text.""", + input_schema=FindTextNodeWithTextInput, + ) -def find_text_node_with_text(text: str, kg: KnowledgeGraph) -> Tuple[str, List[Dict[str, Any]]]: - """Find all TextNodes containing the given text.""" - results = [] - # Find text nodes that contain the given text - text_nodes_with_text = [node for node in kg.get_text_nodes() if text in node.node.text] + find_text_node_with_text_in_file_spec = ToolSpec( + description="""Find all TextNode in the graph that exactly contains this text in a file with this basename. + The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is + looking for exact matches. Therefore the search text should be exact as well. + The basename must include the extension, like 'bar.py', 'baz.java' or 'foo' + (in this case foo is a file without extension). - # If no text nodes found, return early - if not text_nodes_with_text: - return format_knowledge_graph_data([]), [] - for text_node in text_nodes_with_text: - # Find the file node that contains this text node - file_node = find_file_node_of_a_text_node(text_node, kg) - results.append( - { - "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "TextNode": { - "node_id": text_node.node_id, - "text": text_node.node.text, - "start_line": text_node.node.start_line, - "end_line": text_node.node.end_line, - }, - } - ) + You can use this tool to find text/documentation in a specific file that contains this text.""", + input_schema=FindTextNodeWithTextInFileInput, + ) - # Sort by node_id - results.sort(key=lambda x: x["TextNode"]["node_id"]) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + get_next_text_node_with_node_id_spec = ToolSpec( + description="""Get the next TextNode of this given node_id. + You can use this tool to read the next section of text that you are interested in.""", + input_schema=GetNextTextNodeWithNodeIdInput, + ) -class FindTextNodeWithTextInFileInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") - basename: str = Field("The basename of FileNode to search TextNode.") + read_code_with_relative_path_spec = ToolSpec( + description="""Read a specific section of a source code file's content by specifying its relative path and line range. + The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or + 'test/unit/test_parser.java'. + + This tool ONLY works with source code files (not text files or documentation). It is designed + to read large sections of code at once - you should request substantial chunks (hundreds of lines) + rather than making multiple small requests of 10-20 lines each, which would be inefficient. + Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. -FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION = """\ -Find all TextNode in the graph that exactly contains this text in a file with this basename. -The contains is same as python's check `'foo' in text`, ie. it is case sensitive and is -looking for exact matches. Therefore the search text should be exact as well. -The basename must include the extension, like 'README.md' or 'foo' -(in this case foo is a file without extension). + This tool is useful for examining specific sections of source code files when you know + the exact line range you want to analyze. The function will return an error message if + end_line is less than start_line.""", + input_schema=ReadCodeWithRelativePathInput, + ) + + def __init__(self, kg: KnowledgeGraph): + self.kg = kg -Text Node is a chunk of text extracted from a text file, such as comments or documentation. -Source code files are not split into TextNodes! + ############################################################################### + # FileNode retrieval # + ############################################################################### -You can use this tool to find text/documentation in a specific file that contains this text.""" + def find_file_node_with_basename(self, basename: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all FileNodes with the given basename.""" + results = [] + for kg_node in self.kg.get_file_nodes(): + if kg_node.node.basename == basename: + results.append( + { + "FileNode": { + "node_id": kg_node.node_id, + "basename": kg_node.node.basename, + "relative_path": kg_node.node.relative_path, + } + } + ) + results.sort(key=lambda x: x["FileNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + + def find_file_node_with_relative_path( + self, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all FileNodes with the given relative path.""" + results = [] + for kg_node in self.kg.get_file_nodes(): + if kg_node.node.relative_path == relative_path: + results.append( + { + "FileNode": { + "node_id": kg_node.node_id, + "basename": kg_node.node.basename, + "relative_path": kg_node.node.relative_path, + } + } + ) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + + ############################################################################### + # ASTNode retrieval # + ############################################################################### + + def find_ast_node_with_text_in_file( + self, text: str, target_files_nodes: List[KnowledgeGraphNode] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + results = [] + + # Get HAS_AST edges to find which AST nodes belong to these files + has_ast_edges = self.kg.get_has_ast_edges() + file_to_ast_map = { + edge.source.node_id: edge.target + for edge in has_ast_edges + if edge.source.node_id in [n.node_id for n in target_files_nodes] + } + # Construct parent to children map for AST traversal + parent_to_children = self.kg.get_parent_to_children_map() + + # Get root AstNode id list + root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) + + for file_node in target_files_nodes: + # Start with root AST node for this file + root_ast = file_to_ast_map[file_node.node_id] + + # Add all descendant AST nodes + stack = [root_ast] + while stack: + current_node = stack.pop() + + # Check if the current node contains the text + # Don't include the root AST node itself + if text in current_node.node.text and current_node.node_id not in root_ast_node_ids: + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "ASTNode": { + "node_id": current_node.node_id, + "type": current_node.node.type, + "start_line": current_node.node.start_line, + "end_line": current_node.node.end_line, + "text": current_node.node.text, + }, + } + ) + + # Add children to stack + stack += parent_to_children.get(current_node.node_id, []) + + # Sort by text length (smaller first) + results.sort(key=lambda x: len(x["ASTNode"]["text"])) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + + def find_ast_node_with_text_in_file_with_basename( + self, text: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.basename == basename + ] + return self.find_ast_node_with_text_in_file(text, target_files_nodes) + + def find_ast_node_with_text_in_file_with_relative_path( + self, text: str, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given relative path.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.relative_path == relative_path + ] + return self.find_ast_node_with_text_in_file(text, target_files_nodes) + + def find_ast_node_with_type_in_file( + self, type: str, target_files_nodes: List[KnowledgeGraphNode] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes containing the given text in files with the given basename.""" + results = [] + + # Get HAS_AST edges to find which AST nodes belong to these files + has_ast_edges = self.kg.get_has_ast_edges() + file_to_ast_map = { + edge.source.node_id: edge.target + for edge in has_ast_edges + if edge.source.node_id in [n.node_id for n in target_files_nodes] + } -def find_text_node_with_text_in_file( - text: str, basename: str, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Find all TextNodes containing the given text in files with the given basename.""" - results = [] - # Find text nodes that contain the given text - text_nodes_with_text = [node for node in kg.get_text_nodes() if text in node.node.text] + # Construct parent to children map for AST traversal + parent_to_children = self.kg.get_parent_to_children_map() + + # Get root AstNode id list + root_ast_node_ids = set([node.node_id for node in file_to_ast_map.values()]) + + for file_node in target_files_nodes: + # Start with root AST node for this file + root_ast = file_to_ast_map[file_node.node_id] + + # Add all descendant AST nodes + stack = [root_ast] + while stack: + current_node = stack.pop() + + # Check if current node contains the text + # Don't include the root AST node itself + if current_node.node.type == type and current_node.node_id not in root_ast_node_ids: + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "ASTNode": { + "node_id": current_node.node_id, + "type": current_node.node.type, + "start_line": current_node.node.start_line, + "end_line": current_node.node.end_line, + "text": current_node.node.text, + }, + } + ) + + # Add children to stack + stack += parent_to_children.get(current_node.node_id, []) + + # Sort by text length (smaller first) + results.sort(key=lambda x: len(x["ASTNode"]["text"])) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] + + def find_ast_node_with_type_in_file_with_basename( + self, type: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes with the given type in files with the given basename.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.basename == basename + ] + return self.find_ast_node_with_type_in_file(type, target_files_nodes) + + def find_ast_node_with_type_in_file_with_relative_path( + self, type: str, relative_path: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all ASTNodes with the given type in files with the given relative path.""" + # Get file nodes with the given basename + target_files_nodes: List[KnowledgeGraphNode] = [ + node for node in self.kg.get_file_nodes() if node.node.relative_path == relative_path + ] + return self.find_ast_node_with_type_in_file(type, target_files_nodes) + + ############################################################################### + # TextNode retrieval # + ############################################################################### + + def find_file_node_of_a_text_node(self, text_node: KnowledgeGraphNode) -> KnowledgeGraphNode: + """ + Find a file node that contains the given text node. + """ + next_chunk_reverse_map = { + edge.target.node_id: edge.source for edge in self.kg.get_next_chunk_edges() + } + has_file_node_map = { + edge.target.node_id: edge.source for edge in self.kg.get_has_text_edges() + } - # If no text nodes found, return early - if not text_nodes_with_text: - return format_knowledge_graph_data([]), [] + # Find the root text node + current_text_node = text_node + while next_chunk_reverse_map.get(current_text_node.node_id, None) is not None: + current_text_node = next_chunk_reverse_map[current_text_node.node_id] - for text_node in text_nodes_with_text: # Now current_text_node is the root text node - file_node = find_file_node_of_a_text_node(text_node, kg) - - # If the file node matches the given basename, add to results - if file_node.node.basename == basename: + file_node = has_file_node_map[current_text_node.node_id] + return file_node + + def find_text_node_with_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]: + """Find all TextNodes containing the given text.""" + results = [] + # Find text nodes that contain the given text + text_nodes_with_text = [node for node in self.kg.get_text_nodes() if text in node.node.text] + + # If no text nodes found, return early + if not text_nodes_with_text: + return format_knowledge_graph_data([]), [] + for text_node in text_nodes_with_text: + # Find the file node that contains this text node + file_node = self.find_file_node_of_a_text_node(text_node) results.append( { "FileNode": { @@ -425,144 +436,145 @@ def find_text_node_with_text_in_file( } ) - # Sort by node_id - results.sort(key=lambda x: x["TextNode"]["node_id"]) - return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - - -class GetNextTextNodeWithNodeIdInput(BaseModel): - node_id: int = Field("Get the next TextNode of this given node_id.") - + # Sort by node_id + results.sort(key=lambda x: x["TextNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] -GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION = """\ -Get the next TextNode of this given node_id. + def find_text_node_with_text_in_file( + self, text: str, basename: str + ) -> Tuple[str, List[Dict[str, Any]]]: + """Find all TextNodes containing the given text in files with the given basename.""" + results = [] + # Find text nodes that contain the given text + text_nodes_with_text = [node for node in self.kg.get_text_nodes() if text in node.node.text] -Text Node is a chunk of text extracted from a text file, such as comments or documentation. -Source code files are not split into TextNodes! + # If no text nodes found, return early + if not text_nodes_with_text: + return format_knowledge_graph_data([]), [] -You can use this tool to read the next section of text that you are interested in.""" + for text_node in text_nodes_with_text: + # Now current_text_node is the root text node + file_node = self.find_file_node_of_a_text_node(text_node) + # If the file node matches the given basename, add to results + if file_node.node.basename == basename: + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "TextNode": { + "node_id": text_node.node_id, + "text": text_node.node.text, + "start_line": text_node.node.start_line, + "end_line": text_node.node.end_line, + }, + } + ) -def get_next_text_node_with_node_id( - node_id: int, kg: KnowledgeGraph -) -> Tuple[str, List[Dict[str, Any]]]: - """Get the next TextNode for the given node_id.""" - - results = [] - - # Find the current text node - current_text_node = None - for node in kg.get_text_nodes(): - if node.node_id == node_id: - current_text_node = node - break + # Sort by node_id + results.sort(key=lambda x: x["TextNode"]["node_id"]) + return format_knowledge_graph_data(results[:MAX_RESULT]), results[:MAX_RESULT] - # If the current text node does not exist, return empty result - if not current_text_node: - return format_knowledge_graph_data([]), [] + def get_next_text_node_with_node_id(self, node_id: int) -> Tuple[str, List[Dict[str, Any]]]: + """Get the next TextNode for the given node_id.""" - # Get next chunk map - next_chunk_map = {edge.source.node_id: edge.target for edge in kg.get_next_chunk_edges()} + results = [] - # Get the next text node - next_text_node = next_chunk_map.get(current_text_node.node_id, None) + # Find the current text node + current_text_node = None + for node in self.kg.get_text_nodes(): + if node.node_id == node_id: + current_text_node = node + break - # if the next text node does not exist, return empty result - if not next_text_node: - return format_knowledge_graph_data([]), [] + # If the current text node does not exist, return empty result + if not current_text_node: + return format_knowledge_graph_data([]), [] - # Find the file node that contains this text node - file_node = find_file_node_of_a_text_node(next_text_node, kg) - results.append( - { - "FileNode": { - "node_id": file_node.node_id, - "basename": file_node.node.basename, - "relative_path": file_node.node.relative_path, - }, - "TextNode": { - "node_id": next_text_node.node_id, - "text": next_text_node.node.text, - "start_line": next_text_node.node.start_line, - "end_line": next_text_node.node.end_line, - }, + # Get next chunk map + next_chunk_map = { + edge.source.node_id: edge.target for edge in self.kg.get_next_chunk_edges() } - ) - return format_knowledge_graph_data(results), results - - -############################################################################### -# Other # -############################################################################### - - -class ReadCodeWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative path of FileNode to read from root of codebase.") - start_line: int = Field("The starting line number, 1-indexed and inclusive.") - end_line: int = Field("The ending line number, 1-indexed and exclusive.") - -READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ -Read a specific section of a source code file's content by specifying its relative path and line range. -The relative path must be the full path from the root of codebase, like 'src/core/parser.py' or -'test/unit/test_parser.java'. + # Get the next text node + next_text_node = next_chunk_map.get(current_text_node.node_id, None) -This tool ONLY works with source code files (not text files or documentation). It is designed -to read large sections of code at once - you should request substantial chunks (hundreds of lines) -rather than making multiple small requests of 10-20 lines each, which would be inefficient. + # if the next text node does not exist, return empty result + if not next_text_node: + return format_knowledge_graph_data([]), [] -Line numbers are 1-indexed, where start_line is inclusive and end_line is exclusive. - -This tool is useful for examining specific sections of source code files when you know -the exact line range you want to analyze. The function will return an error message if -end_line is less than start_line. -""" - - -def read_code_with_relative_path( - relative_path: str, start_line: int, end_line: int, kg: KnowledgeGraph -) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: - """Read a specific section of a source code file by relative path and line range.""" - if end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}!", None - - # Find file nodes with the given relative path - target_file = None - for node in kg.get_file_nodes(): - if node.node.relative_path == relative_path: - target_file = node - break - - # Check if the file exists - if not target_file: - return format_knowledge_graph_data([]), [] - - # Check if it is a source code file - if not tree_sitter_parser.supports_file(Path(target_file.node.relative_path)): - return f"The file {relative_path} is not a source code file!", None - - # Get the first ast node for this file - first_ast_node = [ - edge.target for edge in kg.get_has_ast_edges() if edge.source.node_id == target_file.node_id - ][0] - text = first_ast_node.node.text - lines = text.split("\n") - selected_lines = lines[start_line - 1 : end_line] # Convert to 0-indexed - selected_text = "\n".join(selected_lines) - selected_text_with_line_numbers = pre_append_line_numbers(selected_text, start_line) - - result_data = [ - { - "FileNode": { - "node_id": target_file.node_id, - "basename": target_file.node.basename, - "relative_path": target_file.node.relative_path, - }, - "SelectedLines": { - "text": selected_text_with_line_numbers, - "start_line": start_line, - "end_line": end_line, - }, - } - ] - return format_knowledge_graph_data(result_data), result_data + # Find the file node that contains this text node + file_node = self.find_file_node_of_a_text_node(next_text_node) + results.append( + { + "FileNode": { + "node_id": file_node.node_id, + "basename": file_node.node.basename, + "relative_path": file_node.node.relative_path, + }, + "TextNode": { + "node_id": next_text_node.node_id, + "text": next_text_node.node.text, + "start_line": next_text_node.node.start_line, + "end_line": next_text_node.node.end_line, + }, + } + ) + return format_knowledge_graph_data(results), results + + ############################################################################### + # Other # + ############################################################################### + + def read_code_with_relative_path( + self, relative_path: str, start_line: int, end_line: int + ) -> Union[Tuple[str, List[Dict[str, Any]]], Tuple[str, None]]: + """Read a specific section of a source code file by relative path and line range.""" + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}!", None + + # Find file nodes with the given relative path + target_file = None + for node in self.kg.get_file_nodes(): + if node.node.relative_path == relative_path: + target_file = node + break + + # Check if the file exists + if not target_file: + return format_knowledge_graph_data([]), [] + + # Check if it is a source code file + if not tree_sitter_parser.supports_file(Path(target_file.node.relative_path)): + return f"The file {relative_path} is not a source code file!", None + + # Get the first ast node for this file + first_ast_node = [ + edge.target + for edge in self.kg.get_has_ast_edges() + if edge.source.node_id == target_file.node_id + ][0] + text = first_ast_node.node.text + lines = text.split("\n") + selected_lines = lines[start_line - 1 : end_line] # Convert to 0-indexed + selected_text = "\n".join(selected_lines) + selected_text_with_line_numbers = pre_append_line_numbers(selected_text, start_line) + + result_data = [ + { + "FileNode": { + "node_id": target_file.node_id, + "basename": target_file.node.basename, + "relative_path": target_file.node.relative_path, + }, + "SelectedLines": { + "text": selected_text_with_line_numbers, + "start_line": start_line, + "end_line": end_line, + }, + } + ] + return format_knowledge_graph_data(result_data), result_data diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py new file mode 100644 index 00000000..cfd38b2f --- /dev/null +++ b/prometheus/tools/web_search.py @@ -0,0 +1,139 @@ +from dataclasses import dataclass +from typing import Annotated + +from pydantic import BaseModel, Field +from tavily import InvalidAPIKeyError, TavilyClient, UsageLimitExceededError + +from prometheus.configuration.config import settings +from prometheus.exceptions.web_search_tool_exception import WebSearchToolException +from prometheus.utils.logger_manager import get_thread_logger + + +@dataclass +class ToolSpec: + description: str + input_schema: type + + +class WebSearchInput(BaseModel): + """Base parameters for Tavily search.""" + + query: Annotated[str, Field(description="Search query")] + + +def format_results(response: dict) -> str: + """Format Tavily search results into a readable string.""" + output = [] + + # Add domain filter information if present + if response.get("included_domains") or response.get("excluded_domains"): + filters = [] + if response.get("included_domains"): + filters.append(f"Including domains: {', '.join(response['included_domains'])}") + if response.get("excluded_domains"): + filters.append(f"Excluding domains: {', '.join(response['excluded_domains'])}") + output.append("Search Filters:") + output.extend(filters) + output.append("") # Empty line for separation + + if response.get("answer"): + output.append(f"Answer: {response['answer']}") + output.append("\nSources:") + # Add immediate source references for the answer + for result in response["results"]: + output.append(f"- {result['title']}: {result['url']}") + output.append("") # Empty line for separation + + output.append("Detailed Results:") + for result in response["results"]: + output.append(f"\nTitle: {result['title']}") + output.append(f"URL: {result['url']}") + output.append(f"Content: {result['content']}") + if result.get("published_date"): + output.append(f"Published: {result['published_date']}") + + return "\n".join(output) + + +class WebSearchTool: + """Tool class for web search functionality.""" + + web_search_spec = ToolSpec( + description="""\ + Searches the web for technical information to aid in bug analysis and resolution. + Use this when you need external context, such as: + 1. Looking up unfamiliar error messages, exceptions, or stack traces. + 2. Finding official documentation or usage examples for a specific library, framework, or API. + 3. Searching for known bugs, common pitfalls, or compatibility issues related to a software version. + 4. Learning the best practices or design patterns for fixing a class of vulnerability (e.g., SQL injection, XSS). + + Queries should be specific and include relevant keywords like library names, version numbers, and error codes. + """, + input_schema=WebSearchInput, + ) + + def __init__(self): + """Initialize the web search tool.""" + # Load environment variables from .env file + self._logger, file_handler = get_thread_logger(__name__) + + tavily_api_key = settings.TAVILY_API_KEY + if tavily_api_key is None: + self._logger.warning("Tavily API key is not set") + tavily_client = None + else: + tavily_client = TavilyClient(api_key=tavily_api_key) + self.tavily_client = tavily_client + + def web_search( + self, + query: str, + max_results: int = 5, + include_domains=None, + exclude_domains: list[str] = None, + ) -> str: + """Search the web for technical information to aid in bug analysis and resolution. + + Args: + query: Search query string. + max_results: Maximum number of results to return (default: 5). + include_domains: List of domains to include in search. + exclude_domains: List of domains to exclude from search. + + Returns: + Formatted search results as a string. + """ + # Set default include domains if not provided + if include_domains is None: + include_domains = [ + "stackoverflow.com", + "github.com", + "developer.mozilla.org", + "learn.microsoft.com", + "docs.python.org", + "pydantic.dev", + "pypi.org", + "readthedocs.org", + ] + + # Call the Tavily API + try: + response = self.tavily_client.search( + query=query, + max_results=max_results, + search_depth="advanced", + include_answer=True, + include_domains=include_domains or [], # Convert None to an empty list + exclude_domains=exclude_domains or [], # Convert None to an empty list + ) + except InvalidAPIKeyError: + raise WebSearchToolException("Invalid Tavily API key") + except UsageLimitExceededError: + raise WebSearchToolException("Usage limit exceeded") + except Exception as e: + raise WebSearchToolException(f"An error occurred: {str(e)}") + + # Format and return the results + format_response = format_results(response) + self._logger.info(f"web_search format_response: {format_response}") + return format_response diff --git a/prometheus/utils/logger_manager.py b/prometheus/utils/logger_manager.py new file mode 100644 index 00000000..e3199770 --- /dev/null +++ b/prometheus/utils/logger_manager.py @@ -0,0 +1,358 @@ +""" +Unified Log Manager + +This module provides a centralized logging management solution for the entire Prometheus project. +All logger configuration and retrieval should be done through this module. +""" + +import logging +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Optional + +from prometheus.configuration.config import settings + + +class ColoredFormatter(logging.Formatter): + """Colored log formatter""" + + # ANSI color codes + COLORS = { + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Purple + "RESET": "\033[0m", # Reset color + } + + # Colored level names + COLORED_LEVELNAMES = { + "DEBUG": f"{COLORS['DEBUG']}DEBUG{COLORS['RESET']}", + "INFO": f"{COLORS['INFO']}INFO{COLORS['RESET']}", + "WARNING": f"{COLORS['WARNING']}WARNING{COLORS['RESET']}", + "ERROR": f"{COLORS['ERROR']}ERROR{COLORS['RESET']}", + "CRITICAL": f"{COLORS['CRITICAL']}CRITICAL{COLORS['RESET']}", + } + + def __init__(self, fmt=None, datefmt=None, use_colors=True): + """ + Initialize colored formatter + + Args: + fmt: Log format string + datefmt: Date format string + use_colors: Whether to use colors + """ + super().__init__(fmt, datefmt) + self.use_colors = use_colors and self._supports_color() + + def _supports_color(self) -> bool: + """Check if terminal supports colors""" + # Check if running in a color-supporting terminal + return ( + hasattr(sys.stdout, "isatty") + and sys.stdout.isatty() + and sys.platform != "win32" # Windows may need special handling + ) or "FORCE_COLOR" in os.environ + + def format(self, record): + """Format log record""" + if self.use_colors and record.levelname in self.COLORED_LEVELNAMES: + # Save original level name + original_levelname = record.levelname + # Use colored level name + record.levelname = self.COLORED_LEVELNAMES[record.levelname] + + # Format message + formatted = super().format(record) + + # Restore original level name + record.levelname = original_levelname + + return formatted + else: + return super().format(record) + + +class LoggerManager: + """Logger manager class, responsible for creating and configuring all loggers""" + + _instance: Optional["LoggerManager"] = None + _initialized: bool = False + + def __new__(cls) -> "LoggerManager": + """Singleton pattern implementation""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize logger manager""" + self.log_level = getattr(settings, "LOGGING_LEVEL") + self.issue_log_dir = Path(getattr(settings, "WORKING_DIRECTORY")) / "answer_issue_logs" + if not self._initialized: + self._setup_root_logger() + self._initialized = True + + def _setup_root_logger(self): + """Setup root logger""" + # Get root logger + self.root_logger = logging.getLogger("prometheus") + + # Clear existing handlers to avoid duplication + self.root_logger.handlers.clear() + + # Set log level + self.root_logger.setLevel(getattr(logging, self.log_level)) + + # Create colored formatter for console output + self.colored_formatter = ColoredFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + # Create plain formatter for file output + self.file_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Create console handler (using colored formatter) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(self.colored_formatter) + console_handler.setLevel( + getattr(logging, self.log_level) + ) # Ensure console handler uses same level + self.root_logger.addHandler(console_handler) + + # Prevent log propagation to parent logger + self.root_logger.propagate = False + + # Log configuration information + self._log_configuration() + + def _set_multi_threads_log_file_handler( + self, thread_id: int, logger_name: str, force_new_file: bool = False + ): + """Set multi threads log file handler""" + # Find existing log file for this thread_id, or create new one if none exists + log_file_path = self._find_or_create_log_file(thread_id, force_new_file) + file_handler = self.create_file_handler(log_file_path, logger_name) + return file_handler + + def _find_or_create_log_file(self, thread_id: int, force_new_file: bool = False) -> Path: + """ + Find existing log file for the thread_id, or create new one if none exists + + Args: + thread_id: Thread ID to find/create log file for + force_new_file: If True, always create a new file with timestamp, even if existing files exist + + Returns: + Path to the log file (existing earliest one or newly created) + """ + import glob + + if force_new_file: + # Always create a new log file with current timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[ + :-3 + ] # Include milliseconds for uniqueness + return self.issue_log_dir / f"{timestamp}_{thread_id}.log" + + # Original logic: find existing file or create new one + # Pattern to match log files for this thread_id + pattern = str(self.issue_log_dir / f"*_{thread_id}.log") + existing_logs = glob.glob(pattern) + + if existing_logs: + # Sort by filename (which includes timestamp) to get chronological order + existing_logs.sort() + # Return the earliest (first) file + return Path(existing_logs[0]) + else: + # No existing log file found, create a new one + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return self.issue_log_dir / f"{timestamp}_{thread_id}.log" + + def _log_configuration(self): + """Log configuration information""" + # Dynamically get all attributes from settings + config_attrs = [ + attr for attr in dir(settings) if attr.isupper() and not attr.startswith("_") + ] + + for attr in config_attrs: + value = getattr(settings, attr, "Not Set") + + # Check if the attribute name indicates a sensitive configuration + is_sensitive = any( + keyword in attr.upper() for keyword in ["KEY", "API", "PASSWORD", "SECRET"] + ) + + # If sensitive, mask the value + if is_sensitive and value and value != "Not Set": + masked_value = "*" * min(len(str(value)), 8) + self.root_logger.info(f"{attr}={masked_value}") + else: + self.root_logger.info(f"{attr}={value}") + + def get_logger(self, name: str) -> logging.Logger: + """ + Get logger with specified name + + Args: + name: Logger name, recommended to use full module path + + Returns: + Configured logger instance + """ + # # Ensure logger name starts with prometheus + # if not name.startswith("prometheus"): + # name = f"prometheus.{name}" + + logger = logging.getLogger(name) + + # If it's a child logger, inherit root logger configuration + if name != "prometheus": + logger.parent = self.root_logger + logger.propagate = True + + return logger + + def create_file_handler(self, log_file_path: Path, logger_name: str) -> logging.FileHandler: + """ + Create file handler for specified logger + + Args: + log_file_path: Log file path + logger_name: Logger name + + Returns: + Configured file handler + """ + # Ensure log directory exists + log_file_path.parent.mkdir(parents=True, exist_ok=True) + + # Create file handler with append mode to preserve existing content + file_handler = logging.FileHandler(log_file_path, mode="a") + file_handler.setLevel(getattr(logging, self.log_level)) + file_handler.setFormatter(self.file_formatter) + + # # Get logger directly without going through get_logger to avoid recursion + # # Ensure logger name starts with prometheus + # if not logger_name.startswith("prometheus"): + # logger_name = f"prometheus.{logger_name}" + + logger = logging.getLogger(logger_name) + + # If it's a child logger, inherit root logger configuration + if logger_name != "prometheus": + logger.parent = self.root_logger + logger.propagate = True + + # Check if this logger already has a file handler to avoid duplicates + has_file_handler = any(isinstance(h, logging.FileHandler) for h in logger.handlers) + if not has_file_handler: + logger.addHandler(file_handler) + + return file_handler + + def remove_file_handler(self, handler: logging.FileHandler, logger_name: str = "prometheus"): + """ + Remove file handler + + Args: + handler: Handler to remove + logger_name: Logger name + """ + logger = self.get_logger(logger_name) + logger.removeHandler(handler) + handler.close() + + def remove_multi_thread_file_handler( + self, handler: logging.FileHandler, logger_name: str = None + ): + """ + Remove multi-thread file handler from specific logger + + Args: + handler: File handler to remove + logger_name: Logger name to remove handler from + """ + if logger_name: + logger = self.get_logger(logger_name) + logger.removeHandler(handler) + else: + # Fallback: try to remove from root logger + self.root_logger.removeHandler(handler) + handler.close() + + +# Create global logger manager instance +logger_manager = LoggerManager() + + +def get_logger(name: str) -> logging.Logger: + """ + Convenience function to get logger + + Args: + name: Logger name, recommended to use __name__ or module path + + Returns: + Configured logger instance + + Examples: + >>> logger, file_handler = get_thread_logger(__name__) + >>> logger, file_handler = get_thread_logger("prometheus.tools.web_search") + """ + return logger_manager.get_logger(name) + + +def remove_multi_threads_log_file_handler(handler: logging.FileHandler, logger_name: str = None): + """ + Convenience function to remove multi-thread file handler + + Args: + handler: File handler to remove + logger_name: Logger name (optional) + """ + logger_manager.remove_multi_thread_file_handler(handler, logger_name) + + +def get_thread_logger( + module_name: str, force_new_file: bool = False +) -> tuple[logging.Logger, logging.FileHandler]: + """ + Convenience function to create a thread-specific logger with file handler in one call + + Args: + module_name: Module name (usually __name__), if None, uses current module + + force_new_file: If True, always create a new log file with timestamp, even if existing files exist + + Returns: + Tuple of (logger, file_handler) for easy cleanup + + Examples: + >>> logger, file_handler = get_thread_logger(__name__) + >>> logger.info("This goes to both console and file") + >>> # In finally block: + >>> remove_multi_threads_log_file_handler(file_handler, logger.name) + + >>> # Force creating a new file each time + >>> logger, file_handler = get_thread_logger(__name__, force_new_file=True) + """ + import threading + + # Get thread ID + thread_id = threading.get_ident() + logger_name = f"thread-{thread_id}.{module_name}" + + # Create file handler and logger + file_handler = logger_manager._set_multi_threads_log_file_handler( + thread_id, logger_name, force_new_file + ) + logger = get_logger(logger_name) + return logger, file_handler diff --git a/pyproject.toml b/pyproject.toml index c9f9da4f..522821bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "Prometheus" -version = "0.0.1" +version = "1.2.0" dependencies = [ "langchain==0.3.3", "tree-sitter==0.21.3", @@ -27,6 +27,9 @@ dependencies = [ "sqlmodel==0.0.24", "asyncpg", "pyjwt==2.6.0", + "mcp>=1.4.1", + "tavily-python>=0.5.1", + "langchain-mcp-adapters>=0.1.9", "httpx==0.28.1", ] requires-python = ">= 3.11" diff --git a/tests/graph/test_knowledge_graph.py b/tests/graph/test_knowledge_graph.py index 9d598ade..551f8bd7 100644 --- a/tests/graph/test_knowledge_graph.py +++ b/tests/graph/test_knowledge_graph.py @@ -22,24 +22,24 @@ async def mock_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811 async def test_build_graph(): - knowledge_graph = KnowledgeGraph(1000, 100, 10, 0) + knowledge_graph = KnowledgeGraph(1, 1000, 100, 0) await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - assert knowledge_graph._next_node_id == 93 + assert knowledge_graph._next_node_id == 15 # 7 FileNode # 84 ASTnode # 2 TextNode - assert len(knowledge_graph._knowledge_graph_nodes) == 93 - assert len(knowledge_graph._knowledge_graph_edges) == 93 + assert len(knowledge_graph._knowledge_graph_nodes) == 15 + assert len(knowledge_graph._knowledge_graph_edges) == 14 assert len(knowledge_graph.get_file_nodes()) == 7 - assert len(knowledge_graph.get_ast_nodes()) == 84 - assert len(knowledge_graph.get_text_nodes()) == 2 - assert len(knowledge_graph.get_parent_of_edges()) == 81 + assert len(knowledge_graph.get_ast_nodes()) == 7 + assert len(knowledge_graph.get_text_nodes()) == 1 + assert len(knowledge_graph.get_parent_of_edges()) == 4 assert len(knowledge_graph.get_has_file_edges()) == 6 assert len(knowledge_graph.get_has_ast_edges()) == 3 - assert len(knowledge_graph.get_has_text_edges()) == 2 - assert len(knowledge_graph.get_next_chunk_edges()) == 1 + assert len(knowledge_graph.get_has_text_edges()) == 1 + assert len(knowledge_graph.get_next_chunk_edges()) == 0 async def test_get_file_tree(): diff --git a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py index 30cfb974..720a3b93 100644 --- a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py +++ b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py @@ -6,6 +6,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.bug_reproducing_write_node import BugReproducingWriteNode from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState +from tests.test_utils.fixtures import temp_test_dir # noqa: F401 from tests.test_utils.util import FakeListChatWithToolsModel @@ -35,11 +36,11 @@ def test_state(): ) -def test_call_method(mock_kg, test_state): +def test_call_method(mock_kg, test_state, temp_test_dir): # noqa: F811 """Test the __call__ method execution.""" fake_response = "Created test file" fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) - node = BugReproducingWriteNode(fake_llm, mock_kg) + node = BugReproducingWriteNode(fake_llm, temp_test_dir, mock_kg) result = node(test_state) diff --git a/tests/lang_graph/nodes/test_edit_node.py b/tests/lang_graph/nodes/test_edit_node.py index d7de2632..39a49b40 100644 --- a/tests/lang_graph/nodes/test_edit_node.py +++ b/tests/lang_graph/nodes/test_edit_node.py @@ -5,6 +5,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.edit_node import EditNode +from tests.test_utils.fixtures import temp_test_dir # noqa: F401 from tests.test_utils.util import FakeListChatWithToolsModel @@ -19,18 +20,18 @@ def fake_llm(): return FakeListChatWithToolsModel(responses=["File edit completed successfully"]) -def test_init_edit_node(mock_kg, fake_llm): +def test_init_edit_node(mock_kg, fake_llm, temp_test_dir): # noqa: F811 """Test EditNode initialization.""" - node = EditNode(fake_llm, mock_kg) + node = EditNode(fake_llm, temp_test_dir, mock_kg) assert isinstance(node.system_prompt, SystemMessage) assert len(node.tools) == 5 # Should have 5 file operation tools assert node.model_with_tools is not None -def test_call_method_basic(mock_kg, fake_llm): +def test_call_method_basic(mock_kg, fake_llm, temp_test_dir): # noqa: F811 """Test basic call functionality without tool execution.""" - node = EditNode(fake_llm, mock_kg) + node = EditNode(fake_llm, temp_test_dir, mock_kg) state = {"edit_messages": [HumanMessage(content="Make the following changes: ...")]} result = node(state) diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index 38842e7e..ccc08399 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -1,5 +1,6 @@ import pytest -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages.tool import ToolCall from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode from tests.test_utils.util import FakeListChatWithToolsModel @@ -10,6 +11,24 @@ def fake_llm(): return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"]) +@pytest.fixture +def fake_llm_with_tool_call(): + """LLM that simulates making a web_search tool call.""" + return FakeListChatWithToolsModel( + responses=["I need to search for information about this error."] + ) + + +def test_init_issue_bug_analyzer_node(fake_llm): + """Test IssueBugAnalyzerNode initialization.""" + node = IssueBugAnalyzerNode(fake_llm) + + assert node.system_prompt is not None + assert len(node.tools) == 1 # Should have web_search tool + assert node.tools[0].name == "web_search" + assert node.model_with_tools is not None + + def test_call_method_basic(fake_llm): """Test basic call functionality.""" node = IssueBugAnalyzerNode(fake_llm) @@ -20,3 +39,97 @@ def test_call_method_basic(fake_llm): assert "issue_bug_analyzer_messages" in result assert len(result["issue_bug_analyzer_messages"]) == 1 assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" + + +def test_web_search_tool_integration(fake_llm_with_tool_call): + """Test that the web_search tool is properly integrated and can be called.""" + node = IssueBugAnalyzerNode(fake_llm_with_tool_call) + state = { + "issue_bug_analyzer_messages": [ + HumanMessage( + content="I'm getting a ValueError in my Python code. Can you help analyze it?" + ) + ] + } + + result = node(state) + + # Verify the result contains the response message + assert "issue_bug_analyzer_messages" in result + assert len(result["issue_bug_analyzer_messages"]) == 1 + assert ( + result["issue_bug_analyzer_messages"][0].content + == "I need to search for information about this error." + ) + + +def test_web_search_tool_call_with_correct_parameters(fake_llm): + """Test that web_search tool has correct configuration and can be called.""" + node = IssueBugAnalyzerNode(fake_llm) + + # Test that the tool exists and has correct configuration + web_search_tool = node.tools[0] + assert web_search_tool.name == "web_search" + assert "technical information" in web_search_tool.description.lower() + + # Test that the tool has the correct args schema + assert hasattr(web_search_tool, "args_schema") + assert web_search_tool.args_schema is not None + + +def test_system_prompt_contains_web_search_info(fake_llm): + """Test that the system prompt mentions web_search tool.""" + node = IssueBugAnalyzerNode(fake_llm) + + system_prompt_content = node.system_prompt.content.lower() + assert "web_search" in system_prompt_content + assert "technical information" in system_prompt_content + + +def test_web_search_tool_schema_validation(fake_llm): + """Test that the web_search tool has proper input validation.""" + node = IssueBugAnalyzerNode(fake_llm) + web_search_tool = node.tools[0] + + # Check that the tool has an args_schema + assert hasattr(web_search_tool, "args_schema") + assert web_search_tool.args_schema is not None + + # Test with valid input + valid_input = {"query": "Python debugging techniques"} + # This should not raise an exception + validated_input = web_search_tool.args_schema(**valid_input) + assert validated_input.query == "Python debugging techniques" + + +def test_multiple_tool_calls_in_conversation(fake_llm): + """Test handling multiple web_search calls in a conversation.""" + node = IssueBugAnalyzerNode(fake_llm) + + # Simulate a conversation with tool calls + state = { + "issue_bug_analyzer_messages": [ + HumanMessage(content="Analyze this bug: ImportError in my application"), + AIMessage( + content="Let me search for information about this error.", + tool_calls=[ + ToolCall( + name="web_search", + args={"query": "Python ImportError debugging"}, + id="call_1", + ) + ], + ), + ToolMessage( + content="Search results: ImportError occurs when...", tool_call_id="call_1" + ), + HumanMessage(content="The error still persists after trying the suggested fixes"), + ] + } + + result = node(state) + + assert "issue_bug_analyzer_messages" in result + assert len(result["issue_bug_analyzer_messages"]) == 1 + # The new response should be added to the conversation + assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" diff --git a/tests/neo4j/test_knowledge_graph_handler.py b/tests/neo4j/test_knowledge_graph_handler.py index c077533e..cdf9c873 100644 --- a/tests/neo4j/test_knowledge_graph_handler.py +++ b/tests/neo4j/test_knowledge_graph_handler.py @@ -38,7 +38,7 @@ async def test_num_ast_nodes(mock_neo4j_service): async with mock_neo4j_service.neo4j_driver.session() as session: read_ast_nodes = await session.execute_read(handler._read_ast_nodes, root_node_id=0) - assert len(read_ast_nodes) == 84 + assert len(read_ast_nodes) == 7 @pytest.mark.slow @@ -56,7 +56,7 @@ async def test_num_text_nodes(mock_neo4j_service): async with mock_neo4j_service.neo4j_driver.session() as session: read_text_nodes = await session.execute_read(handler._read_text_nodes, root_node_id=0) - assert len(read_text_nodes) == 2 + assert len(read_text_nodes) == 1 @pytest.mark.slow @@ -67,7 +67,7 @@ async def test_num_parent_of_edges(mock_neo4j_service): read_parent_of_edges = await session.execute_read( handler._read_parent_of_edges, root_node_id=0 ) - assert len(read_parent_of_edges) == 81 + assert len(read_parent_of_edges) == 4 @pytest.mark.slow @@ -98,7 +98,7 @@ async def test_num_has_text_edges(mock_neo4j_service): read_has_text_edges = await session.execute_read( handler._read_has_text_edges, root_node_id=0 ) - assert len(read_has_text_edges) == 2 + assert len(read_has_text_edges) == 1 @pytest.mark.slow @@ -109,12 +109,12 @@ async def test_num_next_chunk_edges(mock_neo4j_service): read_next_chunk_edges = await session.execute_read( handler._read_next_chunk_edges, root_node_id=0 ) - assert len(read_next_chunk_edges) == 1 + assert len(read_next_chunk_edges) == 0 @pytest.mark.slow async def test_clear_knowledge_graph(mock_empty_neo4j_service): - kg = KnowledgeGraph(1000, 1000, 100, 0) + kg = KnowledgeGraph(1, 1000, 100, 0) await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) driver = mock_empty_neo4j_service.neo4j_driver diff --git a/tests/test_utils/fixtures.py b/tests/test_utils/fixtures.py index 7a30c6ca..4f0fb5b2 100644 --- a/tests/test_utils/fixtures.py +++ b/tests/test_utils/fixtures.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="session") async def neo4j_container_with_kg_fixture(): - kg = KnowledgeGraph(1000, 100, 10, 0) + kg = KnowledgeGraph(1, 1000, 100, 0) await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) container = ( Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) diff --git a/tests/tools/test_file_operation.py b/tests/tools/test_file_operation.py index 62fb8c81..7c4a6098 100644 --- a/tests/tools/test_file_operation.py +++ b/tests/tools/test_file_operation.py @@ -1,32 +1,34 @@ +import shutil + import pytest from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools.file_operation import ( - create_file, - delete, - edit_file, - read_file, - read_file_with_knowledge_graph_data, - read_file_with_line_numbers, -) +from prometheus.tools.file_operation import FileOperationTool from tests.test_utils import test_project_paths from tests.test_utils.fixtures import temp_test_dir # noqa: F401 @pytest.fixture(scope="function") -async def knowledge_graph_fixture(): - kg = KnowledgeGraph(1000, 100, 10, 0) - await kg.build_graph(test_project_paths.TEST_PROJECT_PATH) +async def knowledge_graph_fixture(temp_test_dir): # noqa: F811 + if temp_test_dir.exists(): + shutil.rmtree(temp_test_dir) + shutil.copytree(test_project_paths.TEST_PROJECT_PATH, temp_test_dir) + kg = KnowledgeGraph(1, 1000, 100, 0) + await kg.build_graph(temp_test_dir) return kg -def test_read_file_with_knowledge_graph_data(temp_test_dir, knowledge_graph_fixture): # noqa: F811 +@pytest.fixture(scope="function") +def file_operation_tool(temp_test_dir, knowledge_graph_fixture): # noqa: F811 + file_operation = FileOperationTool(temp_test_dir, knowledge_graph_fixture) + return file_operation + + +def test_read_file_with_knowledge_graph_data(file_operation_tool): relative_path = str( test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = read_file_with_knowledge_graph_data( - relative_path, test_project_paths.TEST_PROJECT_PATH, knowledge_graph_fixture - ) + result = file_operation_tool.read_file_with_knowledge_graph_data(relative_path) result_data = result[1] assert len(result_data) > 0 for result_row in result_data: @@ -36,87 +38,87 @@ def test_read_file_with_knowledge_graph_data(temp_test_dir, knowledge_graph_fixt assert result_row["FileNode"].get("relative_path", "") == relative_path -def test_create_and_read_file(temp_test_dir): # noqa: F811 +def test_create_and_read_file(temp_test_dir, file_operation_tool): # noqa: F811 """Test creating a file and reading its contents.""" test_file = temp_test_dir / "test.txt" content = "line 1\nline 2\nline 3" # Test create_file - result = create_file("test.txt", str(temp_test_dir), content) + result = file_operation_tool.create_file("test.txt", content) assert test_file.exists() assert test_file.read_text() == content assert result == "The file test.txt has been created." # Test read_file - result = read_file("test.txt", str(temp_test_dir)) + result = file_operation_tool.read_file("test.txt") expected = "1. line 1\n2. line 2\n3. line 3" assert result == expected -def test_read_file_nonexistent(temp_test_dir): # noqa: F811 +def test_read_file_nonexistent(file_operation_tool): """Test reading a nonexistent file.""" - result = read_file("nonexistent_file.txt", str(temp_test_dir)) + result = file_operation_tool.read_file("nonexistent_file.txt") assert result == "The file nonexistent_file.txt does not exist." -def test_read_file_with_line_numbers(temp_test_dir): # noqa: F811 +def test_read_file_with_line_numbers(file_operation_tool): """Test reading specific line ranges from a file.""" content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("test_lines.txt", str(temp_test_dir), content) + file_operation_tool.create_file("test_lines.txt", content) # Test reading specific lines - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) + result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 2, 4) expected = "2. line 2\n3. line 3" assert result == expected # Test invalid range - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) + result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 4, 2) assert result == "The end line number 2 must be greater than the start line number 4." -def test_delete(temp_test_dir): # noqa: F811 +def test_delete(file_operation_tool, temp_test_dir): # noqa: F811 """Test file and directory deletion.""" # Test file deletion test_file = temp_test_dir / "to_delete.txt" - create_file("to_delete.txt", str(temp_test_dir), "content") + file_operation_tool.create_file("to_delete.txt", "content") assert test_file.exists() - result = delete("to_delete.txt", str(temp_test_dir)) + result = file_operation_tool.delete("to_delete.txt") assert result == "The file to_delete.txt has been deleted." assert not test_file.exists() # Test directory deletion test_subdir = temp_test_dir / "subdir" test_subdir.mkdir() - create_file("subdir/file.txt", str(temp_test_dir), "content") - result = delete("subdir", str(temp_test_dir)) + file_operation_tool.create_file("subdir/file.txt", "content") + result = file_operation_tool.delete("subdir") assert result == "The directory subdir has been deleted." assert not test_subdir.exists() -def test_delete_nonexistent(temp_test_dir): # noqa: F811 +def test_delete_nonexistent(file_operation_tool): """Test deleting a nonexistent path.""" - result = delete("nonexistent_path", str(temp_test_dir)) + result = file_operation_tool.delete("nonexistent_path") assert result == "The file nonexistent_path does not exist." -def test_edit_file(temp_test_dir): # noqa: F811 +def test_edit_file(file_operation_tool): """Test editing specific lines in a file.""" # Test case 1: Successfully edit a single occurrence initial_content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("edit_test.txt", str(temp_test_dir), initial_content) - result = edit_file("edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + file_operation_tool.create_file("edit_test.txt", initial_content) + result = file_operation_tool.edit_file("edit_test.txt", "line 2", "new line 2") assert result == "Successfully edited edit_test.txt." # Test case 2: Absolute path error - result = edit_file("/edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + result = file_operation_tool.edit_file("/edit_test.txt", "line 2", "new line 2") assert result == "relative_path: /edit_test.txt is a absolute path, not relative path." # Test case 3: File doesn't exist - result = edit_file("nonexistent.txt", str(temp_test_dir), "line 2", "new line 2") + result = file_operation_tool.edit_file("nonexistent.txt", "line 2", "new line 2") assert result == "The file nonexistent.txt does not exist." # Test case 4: No matches found - result = edit_file("edit_test.txt", str(temp_test_dir), "nonexistent line", "new content") + result = file_operation_tool.edit_file("edit_test.txt", "nonexistent line", "new content") assert ( result == "No match found for the specified content in edit_test.txt. Please verify the content to replace." @@ -124,16 +126,16 @@ def test_edit_file(temp_test_dir): # noqa: F811 # Test case 5: Multiple occurrences duplicate_content = "line 1\nline 2\nline 2\nline 3" - create_file("duplicate_test.txt", str(temp_test_dir), duplicate_content) - result = edit_file("duplicate_test.txt", str(temp_test_dir), "line 2", "new line 2") + file_operation_tool.create_file("duplicate_test.txt", duplicate_content) + result = file_operation_tool.edit_file("duplicate_test.txt", "line 2", "new line 2") assert ( result == "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match." ) -def test_create_file_already_exists(temp_test_dir): # noqa: F811 +def test_create_file_already_exists(file_operation_tool): """Test creating a file that already exists.""" - create_file("existing.txt", str(temp_test_dir), "content") - result = create_file("existing.txt", str(temp_test_dir), "new content") + file_operation_tool.create_file("existing.txt", "content") + result = file_operation_tool.create_file("existing.txt", "new content") assert result == "The file existing.txt already exists." diff --git a/tests/tools/test_graph_traversal.py b/tests/tools/test_graph_traversal.py index aa8f2a93..f3c30579 100644 --- a/tests/tools/test_graph_traversal.py +++ b/tests/tools/test_graph_traversal.py @@ -1,7 +1,7 @@ import pytest from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.tools import graph_traversal +from prometheus.tools.graph_traversal import GraphTraversalTool from tests.test_utils import test_project_paths @@ -12,11 +12,15 @@ async def knowledge_graph_fixture(): return kg +@pytest.fixture(scope="function") +def graph_traversal_tool(knowledge_graph_fixture): + graph_traversal_tool = GraphTraversalTool(knowledge_graph_fixture) + return graph_traversal_tool + + @pytest.mark.slow -async def test_find_file_node_with_basename(knowledge_graph_fixture): - result = graph_traversal.find_file_node_with_basename( - test_project_paths.PYTHON_FILE.name, knowledge_graph_fixture - ) +async def test_find_file_node_with_basename(graph_traversal_tool): + result = graph_traversal_tool.find_file_node_with_basename(test_project_paths.PYTHON_FILE.name) basename = test_project_paths.PYTHON_FILE.name relative_path = str( @@ -31,13 +35,11 @@ async def test_find_file_node_with_basename(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_file_node_with_relative_path(knowledge_graph_fixture): +async def test_find_file_node_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.find_file_node_with_relative_path( - relative_path, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_file_node_with_relative_path(relative_path) basename = test_project_paths.MD_FILE.name @@ -49,10 +51,10 @@ async def test_find_file_node_with_relative_path(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_ast_node_with_text_in_file_with_basename(knowledge_graph_fixture): # noqa: F811 +async def test_find_ast_node_with_text_in_file_with_basename(graph_traversal_tool): basename = test_project_paths.PYTHON_FILE.name - result = graph_traversal.find_ast_node_with_text_in_file_with_basename( - "Hello world!", basename, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_text_in_file_with_basename( + "Hello world!", basename ) result_data = result[1] @@ -65,12 +67,12 @@ async def test_find_ast_node_with_text_in_file_with_basename(knowledge_graph_fix @pytest.mark.slow -async def test_find_ast_node_with_text_in_file_with_relative_path(knowledge_graph_fixture): +async def test_find_ast_node_with_text_in_file_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.find_ast_node_with_text_in_file_with_relative_path( - "Hello world!", relative_path, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path( + "Hello world!", relative_path ) result_data = result[1] @@ -83,12 +85,10 @@ async def test_find_ast_node_with_text_in_file_with_relative_path(knowledge_grap @pytest.mark.slow -async def test_find_ast_node_with_type_in_file_with_basename(knowledge_graph_fixture): +async def test_find_ast_node_with_type_in_file_with_basename(graph_traversal_tool): basename = test_project_paths.C_FILE.name node_type = "function_definition" - result = graph_traversal.find_ast_node_with_type_in_file_with_basename( - node_type, basename, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_ast_node_with_type_in_file_with_basename(node_type, basename) result_data = result[1] assert len(result_data) > 0 @@ -100,13 +100,13 @@ async def test_find_ast_node_with_type_in_file_with_basename(knowledge_graph_fix @pytest.mark.slow -async def test_find_ast_node_with_type_in_file_with_relative_path(knowledge_graph_fixture): # noqa: F811 +async def test_find_ast_node_with_type_in_file_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) node_type = "string_literal" - result = graph_traversal.find_ast_node_with_type_in_file_with_relative_path( - node_type, relative_path, knowledge_graph_fixture + result = graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path( + node_type, relative_path ) result_data = result[1] @@ -119,9 +119,9 @@ async def test_find_ast_node_with_type_in_file_with_relative_path(knowledge_grap @pytest.mark.slow -async def test_find_text_node_with_text(knowledge_graph_fixture): +async def test_find_text_node_with_text(graph_traversal_tool): text = "Text under header C" - result = graph_traversal.find_text_node_with_text(text, knowledge_graph_fixture) + result = graph_traversal_tool.find_text_node_with_text(text) result_data = result[1] assert len(result_data) > 0 @@ -137,12 +137,10 @@ async def test_find_text_node_with_text(knowledge_graph_fixture): @pytest.mark.slow -async def test_find_text_node_with_text_in_file(knowledge_graph_fixture): +async def test_find_text_node_with_text_in_file(graph_traversal_tool): basename = test_project_paths.MD_FILE.name text = "Text under header B" - result = graph_traversal.find_text_node_with_text_in_file( - text, basename, knowledge_graph_fixture - ) + result = graph_traversal_tool.find_text_node_with_text_in_file(text, basename) result_data = result[1] assert len(result_data) > 0 @@ -158,9 +156,9 @@ async def test_find_text_node_with_text_in_file(knowledge_graph_fixture): @pytest.mark.slow -async def test_get_next_text_node_with_node_id(knowledge_graph_fixture): +async def test_get_next_text_node_with_node_id(graph_traversal_tool): node_id = 34 - result = graph_traversal.get_next_text_node_with_node_id(node_id, knowledge_graph_fixture) + result = graph_traversal_tool.get_next_text_node_with_node_id(node_id) result_data = result[1] assert len(result_data) > 0 @@ -176,13 +174,11 @@ async def test_get_next_text_node_with_node_id(knowledge_graph_fixture): @pytest.mark.slow -async def test_read_code_with_relative_path(knowledge_graph_fixture): # noqa: F811 +async def test_read_code_with_relative_path(graph_traversal_tool): relative_path = str( test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) - result = graph_traversal.read_code_with_relative_path( - relative_path, 5, 6, knowledge_graph_fixture - ) + result = graph_traversal_tool.read_code_with_relative_path(relative_path, 5, 6) result_data = result[1] assert len(result_data) > 0 diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index 2c46a699..c9a08107 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -1,25 +1,35 @@ import pytest from prometheus.exceptions.file_operation_exception import FileOperationException -from prometheus.tools.file_operation import create_file from prometheus.utils.file_utils import ( read_file_with_line_numbers, ) -from tests.test_utils.fixtures import temp_test_dir # noqa: F401 +from tests.test_utils.test_project_paths import TEST_PROJECT_PATH -def test_read_file_with_line_numbers(temp_test_dir): # noqa: F811 +def test_read_file_with_line_numbers(): """Test reading specific line ranges from a file.""" - content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("test_lines.txt", str(temp_test_dir), content) - # Test reading specific lines - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) - expected = "2. line 2\n3. line 3\n4. line 4" + result = read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 1, 15) + expected = """1. # A +2. +3. Text under header A. +4. +5. ## B +6. +7. Text under header B. +8. +9. ## C +10. +11. Text under header C. +12. +13. ### D +14. +15. Text under header D.""" assert result == expected # Test invalid range should raise exception with pytest.raises(FileOperationException) as exc_info: - read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) + read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 4, 2) assert str(exc_info.value) == "The end line number must be greater than the start line number."