diff --git a/README.md b/README.md index 780984ac..23ad9015 100644 --- a/README.md +++ b/README.md @@ -177,39 +177,6 @@ Verify Neo4J at: [http://localhost:7474](http://localhost:7474) --- -## ⚙️ Configuration - -Set the following variables in your `.env` file: - -### 🔹 Neo4j - -* `PROMETHEUS_NEO4J_URI` -* `PROMETHEUS_NEO4J_USERNAME` -* `PROMETHEUS_NEO4J_PASSWORD` - -### 🔹 LLM Models - -* `PROMETHEUS_ADVANCED_MODEL` -* `PROMETHEUS_BASE_MODEL` -* API Keys: - - * `PROMETHEUS_OPENAI_FORMAT_API_KEY` - * `PROMETHEUS_ANTHROPIC_API_KEY` - * `PROMETHEUS_GEMINI_API_KEY` -* Base URL for LLMs: - - * `PROMETHEUS_OPENAI_FORMAT_BASE_URL` - -### 🔹 Other Settings - -* `PROMETHEUS_WORKING_DIRECTORY` -* `PROMETHEUS_GITHUB_ACCESS_TOKEN` -* `PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH` -* `PROMETHEUS_NEO4J_BATCH_SIZE` -* `PROMETHEUS_POSTGRES_URL` - ---- - ## 🧪 Development ### Requirements diff --git a/prometheus/app/api/routes/issue.py b/prometheus/app/api/routes/issue.py index 6b19a723..e58582c6 100644 --- a/prometheus/app/api/routes/issue.py +++ b/prometheus/app/api/routes/issue.py @@ -4,6 +4,7 @@ from prometheus.app.models.requests.issue import IssueRequest from prometheus.app.models.response.issue import IssueResponse from prometheus.app.models.response.response import Response +from prometheus.app.services.issue_service import IssueService from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService from prometheus.app.services.repository_service import RepositoryService from prometheus.configuration.config import settings @@ -21,7 +22,7 @@ response_model=Response[IssueResponse], ) @requireLogin -def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueResponse]: +async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueResponse]: repository_service: RepositoryService = request.app.state.service["repository_service"] repository = repository_service.get_repository_by_id(issue.repository_id) if not repository: 
@@ -36,6 +37,12 @@ def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueRespons message="workdir must be provided for user defined environment", ) + if repository.is_working: + raise ServerException( + code=400, + message="The repository is currently being used. Please try again later.", + ) + knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ "knowledge_graph_service" ] @@ -47,6 +54,9 @@ def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueRespons repository.kg_chunk_size, repository.kg_chunk_overlap, ) + + issue_service: IssueService = request.app.state.service["issue_service"] + ( remote_branch_name, patch, @@ -54,7 +64,8 @@ def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueRespons passed_build, passed_existing_test, issue_response, - ) = request.app.state.service["issue_service"].answer_issue( + ) = await issue_service.answer_issue( + repository_id=repository.id, repository=git_repository, knowledge_graph=knowledge_graph, issue_number=issue.issue_number, @@ -73,6 +84,7 @@ def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueRespons test_commands=issue.test_commands, push_to_remote=issue.push_to_remote, ) + repository_service.update_repository_status(repository.id, is_working=False) return Response( data=IssueResponse( patch=patch, diff --git a/prometheus/app/dependencies.py b/prometheus/app/dependencies.py index f67ba593..787052b2 100644 --- a/prometheus/app/dependencies.py +++ b/prometheus/app/dependencies.py @@ -57,9 +57,11 @@ def initialize_services() -> dict[str, BaseService]: ) issue_service = IssueService( neo4j_service, + repository_service, llm_service, settings.MAX_TOKEN_PER_NEO4J_RESULT, settings.WORKING_DIRECTORY, + settings.LOGGING_LEVEL, ) user_service = UserService(database_service) diff --git a/prometheus/app/entity/repository.py b/prometheus/app/entity/repository.py index 1fb3052f..c97c965c 100644 --- a/prometheus/app/entity/repository.py 
+++ b/prometheus/app/entity/repository.py @@ -20,6 +20,10 @@ class Repository(SQLModel, table=True): max_length=300, description="The playground path of the repository where the repository was cloned.", ) + is_working: bool = Field( + default=False, + description="Indicates whether the repository is currently being used for processing or not.", + ) user_id: int = Field( index=True, nullable=True, description="The ID of the user who upload this repository." ) diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index 9731a766..79df5f39 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -1,4 +1,6 @@ +import asyncio import logging +import threading import traceback import uuid from datetime import datetime @@ -10,6 +12,7 @@ from prometheus.app.services.neo4j_service import Neo4jService from prometheus.docker.general_container import GeneralContainer from prometheus.docker.user_defined_container import UserDefinedContainer +from prometheus.exceptions.server_exception import ServerException from prometheus.git.git_repository import GitRepository from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_graph import IssueGraph @@ -20,19 +23,24 @@ class IssueService(BaseService): def __init__( self, neo4j_service: Neo4jService, + repository_service, llm_service: LLMService, max_token_per_neo4j_result: int, working_directory: str, + logging_level: str, ): self.neo4j_service = neo4j_service + self.repository_service = repository_service self.llm_service = llm_service self.max_token_per_neo4j_result = max_token_per_neo4j_result self.working_directory = working_directory self.answer_issue_log_dir = Path(self.working_directory) / "answer_issue_logs" self.answer_issue_log_dir.mkdir(parents=True, exist_ok=True) + self.logging_level = logging_level - def answer_issue( + async def answer_issue( self, + repository_id: int, repository: 
GitRepository, knowledge_graph: KnowledgeGraph, issue_number: int, @@ -55,6 +63,7 @@ def answer_issue( Processes an issue, generates patches if needed, runs optional builds and tests, and returning the results. Args: + repository_id: The ID of the repository to update. repository (GitRepository): The Git repository instance. knowledge_graph (KnowledgeGraph): The knowledge graph instance. issue_number (int): The number of the issue. @@ -80,39 +89,130 @@ def answer_issue( - passed_existing_test (bool): Whether the existing tests passed. - issue_response (str): Response generated for the issue. """ - logger = logging.getLogger("prometheus") + + # Initialize the issue graph with the necessary services and parameters + ( + edit_patch, + passed_reproducing_test, + passed_build, + passed_existing_test, + issue_response, + issue_type, + ) = await asyncio.to_thread( + self.__answer, + repository_id=repository_id, + issue_title=issue_title, + issue_body=issue_body, + issue_comments=issue_comments, + issue_type=issue_type, + run_build=run_build, + run_existing_test=run_existing_test, + run_reproduce_test=run_reproduce_test, + number_of_candidate_patch=number_of_candidate_patch, + knowledge_graph=knowledge_graph, + repository=repository, + build_commands=build_commands, + test_commands=test_commands, + dockerfile_content=dockerfile_content, + image_name=image_name, + workdir=workdir, + ) + if ( + edit_patch, + passed_reproducing_test, + passed_build, + passed_existing_test, + issue_response, + issue_type, + ) == (None, False, False, False, None, None): + raise ServerException(500, "Failed to process the issue due to an internal error.") + if issue_type == IssueType.BUG: + # push to remote if requested + remote_branch_name = None + if edit_patch and push_to_remote: + remote_branch_name = f"prometheus_fix_{uuid.uuid4().hex[:10]}" + await repository.create_and_push_branch( + remote_branch_name, f"Fixes #{issue_number}", edit_patch + ) + + return ( + remote_branch_name, + 
edit_patch, + passed_reproducing_test, + passed_build, + passed_existing_test, + issue_response, + ) + elif issue_type == IssueType.QUESTION: + return ( + None, + None, + False, + False, + False, + issue_response, + ) + else: + raise ValueError(f"Unknown issue type: {issue_type}. Expected BUG or QUESTION.") + + def __answer( + self, + knowledge_graph: KnowledgeGraph, + repository: GitRepository, + repository_id: int, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + issue_type: IssueType, + run_build: bool, + run_existing_test: bool, + run_reproduce_test: bool, + number_of_candidate_patch: int, + build_commands: Optional[Sequence[str]], + test_commands: Optional[Sequence[str]], + dockerfile_content: Optional[str] = None, + image_name: Optional[str] = None, + workdir: Optional[str] = None, + ) -> tuple[None, bool, bool, bool, None, None] | tuple[str, bool, bool, bool, str, IssueType]: + # Set up a dedicated logger for this thread + logger = logging.getLogger(f"thread-{threading.get_ident()}.prometheus") + logger.setLevel(getattr(logging, self.logging_level)) formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = self.answer_issue_log_dir / f"{timestamp}.log" + log_file = self.answer_issue_log_dir / f"{timestamp}_{threading.get_ident()}.log" file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) logger.addHandler(file_handler) - try: - # Construct the working directory - if dockerfile_content or image_name: - container = UserDefinedContainer( - repository.get_working_directory(), - workdir, - build_commands, - test_commands, - dockerfile_content, - image_name, - ) - else: - container = GeneralContainer(repository.get_working_directory()) - # Initialize the issue graph with the necessary services and parameters - issue_graph = IssueGraph( - advanced_model=self.llm_service.advanced_model, - 
base_model=self.llm_service.base_model, - kg=knowledge_graph, - git_repo=repository, - neo4j_driver=self.neo4j_service.neo4j_driver, - max_token_per_neo4j_result=self.max_token_per_neo4j_result, - container=container, - build_commands=build_commands, - test_commands=test_commands, + # Construct the working directory + if dockerfile_content or image_name: + container = UserDefinedContainer( + repository.get_working_directory(), + workdir, + build_commands, + test_commands, + dockerfile_content, + image_name, ) + else: + container = GeneralContainer(repository.get_working_directory()) + + # Initialize the IssueGraph with the provided services and parameters + issue_graph = IssueGraph( + advanced_model=self.llm_service.advanced_model, + base_model=self.llm_service.base_model, + kg=knowledge_graph, + git_repo=repository, + neo4j_driver=self.neo4j_service.neo4j_driver, + max_token_per_neo4j_result=self.max_token_per_neo4j_result, + container=container, + build_commands=build_commands, + test_commands=test_commands, + ) + + # Update the repository status to working + self.repository_service.update_repository_status(repository_id, is_working=True) + try: # Invoke the issue graph with the provided parameters output_state = issue_graph.invoke( issue_title, @@ -124,40 +224,18 @@ def answer_issue( run_reproduce_test, number_of_candidate_patch, ) - - if output_state["issue_type"] == IssueType.BUG: - # push to remote if requested - remote_branch_name = None - if output_state["edit_patch"] and push_to_remote: - remote_branch_name = f"prometheus_fix_{uuid.uuid4().hex[:10]}" - repository.create_and_push_branch( - remote_branch_name, f"Fixes #{issue_number}", output_state["edit_patch"] - ) - - return ( - remote_branch_name, - output_state["edit_patch"], - output_state["passed_reproducing_test"], - output_state["passed_build"], - output_state["passed_existing_test"], - output_state["issue_response"], - ) - elif output_state["issue_type"] == IssueType.QUESTION: - return ( - None, - 
None, - False, - False, - False, - output_state["issue_response"], - ) - - raise ValueError( - f"Unknown issue type: {output_state['issue_type']}. Expected BUG or QUESTION." + return ( + output_state["edit_patch"], + output_state["passed_reproducing_test"], + output_state["passed_build"], + output_state["passed_existing_test"], + output_state["issue_response"], + output_state["issue_type"], ) except Exception as e: logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") - return None, None, False, False, False, None + return None, False, False, False, None, None finally: + self.repository_service.update_repository_status(repository_id, is_working=False) logger.removeHandler(file_handler) file_handler.close() diff --git a/prometheus/app/services/knowledge_graph_service.py b/prometheus/app/services/knowledge_graph_service.py index e536127a..4a1b7c54 100644 --- a/prometheus/app/services/knowledge_graph_service.py +++ b/prometheus/app/services/knowledge_graph_service.py @@ -1,5 +1,6 @@ """Service for managing and interacting with Knowledge Graphs in Neo4j.""" +import asyncio from pathlib import Path from prometheus.app.services.base_service import BaseService @@ -39,6 +40,7 @@ def __init__( self.max_ast_depth = max_ast_depth self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap + self.writing_lock = asyncio.Lock() async def build_and_save_knowledge_graph(self, path: Path) -> int: """Builds a new Knowledge Graph from source code and saves it to Neo4j. @@ -52,11 +54,14 @@ async def build_and_save_knowledge_graph(self, path: Path) -> int: Returns: The root node ID of the newly created Knowledge Graph. 
""" - root_node_id = self.kg_handler.get_new_knowledge_graph_root_node_id() - kg = KnowledgeGraph(self.max_ast_depth, self.chunk_size, self.chunk_overlap, root_node_id) - await kg.build_graph(path) - self.kg_handler.write_knowledge_graph(kg) - return kg.root_node_id + async with self.writing_lock: # Ensure only one build operation at a time + root_node_id = self.kg_handler.get_new_knowledge_graph_root_node_id() + kg = KnowledgeGraph( + self.max_ast_depth, self.chunk_size, self.chunk_overlap, root_node_id + ) + await kg.build_graph(path) + self.kg_handler.write_knowledge_graph(kg) + return kg.root_node_id def clear_kg(self, root_node_id: int): self.kg_handler.clear_knowledge_graph(root_node_id) diff --git a/prometheus/app/services/repository_service.py b/prometheus/app/services/repository_service.py index 8f6ed9c9..5721b6c6 100644 --- a/prometheus/app/services/repository_service.py +++ b/prometheus/app/services/repository_service.py @@ -150,6 +150,21 @@ def get_repository_by_url_and_commit_id(self, url: str, commit_id: str) -> Optio ) return session.exec(statement).first() + def update_repository_status(self, repository_id: int, is_working: bool): + """ + Updates the working status of a repository. + + Args: + repository_id: The ID of the repository to update. + is_working: The new working status to set for the repository. 
+ """ + with Session(self.engine) as session: + repository = session.get(Repository, repository_id) + if repository: + repository.is_working = is_working + session.add(repository) + session.commit() + def clean_repository(self, repository: Repository): path = Path(repository.playground_path) if path.exists(): diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index c09f5187..bf486e17 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -2,6 +2,7 @@ import shutil import tarfile import tempfile +import threading from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Sequence @@ -34,7 +35,9 @@ def __init__(self, project_path: Path, workdir: Optional[str] = None): Args: project_path: Path to the project directory to be containerized. """ - self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.{self.__class__.__module__}.{self.__class__.__name__}" + ) temp_dir = Path(tempfile.mkdtemp()) temp_project_path = temp_dir / project_path.name shutil.copytree(project_path, temp_project_path) diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index 95b23498..6f65ae08 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -131,7 +131,7 @@ def remove_repository(self): shutil.rmtree(self.repo.working_dir) self.repo = None - def create_and_push_branch(self, branch_name: str, commit_message: str, patch: str): + async def create_and_push_branch(self, branch_name: str, commit_message: str, patch: str): """Create a new branch, commit changes, and push to remote. 
This method creates a new branch, switches to it, stages all changes, @@ -155,6 +155,6 @@ def create_and_push_branch(self, branch_name: str, commit_message: str, patch: s self.repo.git.apply(tmp_file.name) self.repo.git.add(A=True) self.repo.index.commit(commit_message) - self.repo.git.push("--set-upstream", "origin", branch_name) + await asyncio.to_thread(self.repo.git.push, "--set-upstream", "origin", branch_name) self.reset_repository() self.switch_branch(self.default_branch) diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 25620677..4486e8b5 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel @@ -14,7 +15,7 @@ def __init__( container: BaseContainer, ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" ) self.subgraph = BugFixVerificationSubgraph( model=model, diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index 27094be6..21d2ca6f 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -53,7 +54,9 @@ def __init__(self, model: BaseChatModel, container: BaseContainer): self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = 
logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_verify_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_verify_node" + ) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index ac4fb533..1c86faf2 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -91,7 +92,7 @@ def __init__(self, model: BaseChatModel): structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) self.model = prompt | structured_llm self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_fix_verify_structured_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_fix_verify_structured_node" ) def __call__(self, state: BugFixVerficationState): diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index e7121965..f99f3a4f 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading from pathlib import Path from typing import Optional, Sequence @@ -55,7 +56,9 @@ def __init__( self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_execute_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_execute_node" + ) 
def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index f38cd088..2c63f8d7 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -40,7 +41,9 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph, local_path: str): self.tools = self._init_tools(local_path) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_file_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_file_node" + ) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. 
diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 9fa204aa..6628d903 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -135,7 +136,7 @@ def __init__(self, model: BaseChatModel): structured_llm = model.with_structured_output(BugReproducingStructuredOutput) self.model = prompt | structured_llm self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproducing_structured_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_structured_node" ) def __call__(self, state: BugReproductionState): diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index f1c85d97..f2808f0a 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.messages import HumanMessage @@ -25,7 +26,7 @@ class BugReproducingWriteMessageNode: def __init__(self): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproducing_write_message_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_write_message_node" ) def format_human_message(self, state: BugReproductionState): diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index a70d91e4..827bbcf6 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading 
from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -114,7 +115,9 @@ def __init__(self, model: BaseChatModel, local_path: str): self.tools = self._init_tools(local_path) self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_write_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproducing_write_node" + ) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 39b0d0a5..8caa4853 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Optional, Sequence import neo4j @@ -25,7 +26,7 @@ def __init__( test_commands: Optional[Sequence[str]], ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproduction_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.bug_reproduction_subgraph_node" ) self.git_repo = git_repo self.bug_reproduction_subgraph = BugReproductionSubgraph( diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index 7f84c90b..946d1bb2 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Optional, Sequence from langchain_core.language_models.chat_models import BaseChatModel @@ -25,7 +26,9 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) - self._logger = 
logging.getLogger("prometheus.lang_graph.nodes.build_and_test_subgraph_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.build_and_test_subgraph_node" + ) def __call__(self, state: IssueBugState): exist_build = None diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index cf331b34..c01876e6 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -1,27 +1,32 @@ import logging +import threading from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import SystemMessage +from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field from prometheus.exceptions.file_operation_exception import FileOperationException from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState from prometheus.models.context import Context from prometheus.utils.file_utils import read_file_with_line_numbers +from prometheus.utils.lang_graph_util import ( + extract_last_tool_messages, + transform_tool_messages_to_str, +) SYS_PROMPT = """\ -You are a context summary agent that summarizes code context that is relevant to a given query from history messages. - Your goal is to extract, evaluate and summary code context that directly answers the query requirements. +You are a context summary agent that summarizes code contexts which is relevant to a given query. + Your goal is to extract, evaluate and summary code contexts that directly answers the query requirements. Your evaluation and summarization must consider two key aspects: -1. Query Match: Which set of history messages directly address specific requirements mentioned in the query? -2. Extended relevance: Which set of history messages provide essential information needed to understand the query topic? +1. 
Query Match: Which set of contexts directly address specific requirements mentioned in the query? +2. Extended relevance: Which set of contexts provide essential information needed to understand the query topic? Follow these strict evaluation steps: 1. First, identify specific requirements in the query -2. Check which set of history messages directly addresses these requirements -3. Check which parts of code context are relevant to the query +2. Check which set of contexts directly addresses these requirements +3. Check which parts of code contexts are relevant to the query 4. Consider if they provides essential context by examining: - Function dependencies - Type definitions @@ -37,27 +42,37 @@ CRITICAL RULE: - You don't have to select whole piece of code that you have seen, ONLY select the parts that are relevant to the query. - Each context MUST be SHORT and CONCISE, focusing ONLY on the lines that are relevant to the query. -- Several context can be extracted from the same file, but each context must be concise and relevant to the query. +- Several contexts can be extracted from the same file, but each context must be concise and relevant to the query. - Do NOT include any irrelevant lines or comments that do not contribute to answering the query. -- Do NOT include same context multiple times, even if it appears in different history messages. +- Do NOT include same context multiple times. -Remember: Your primary goal is to summarize context that directly helps answer the query requirements. +Remember: Your primary goal is to summarize contexts that directly helps answer the query requirements. Provide your analysis in a structured format matching the ContextExtractionStructuredOutput model. Example output: ```json -{ - "context": [{ +{{ + "context": [{{ "reasoning": "1. Query requirement analysis:\n - Query specifically asks about password validation\n - Context provides implementation details for password validation\n2. 
Extended relevance:\n - This function is essential for understanding how passwords are validated in the system", "relative_path": "pychemia/code/fireball/fireball.py", "start_line": 270, # Must be greater than or equal to 1 "end_line": 293 # Must be greater than or equal to start_line - } ......] -} + }} ......] +}} ``` -Your task is to summarize the context from the provided history messages and return it in the specified format. +Your task is to summarize the relevant contexts to a given query and return it in the specified format. +""" + +HUMAN_MESSAGE = """\ +This is the original user query: +{original_query} + +The context or file content that you have seen so far (Some of the context may be IRRELEVANT to the query!!!): +{context} + +REMEMBER: Your task is to summarize the relevant contexts to a given query and return it in the specified format! """ @@ -84,20 +99,38 @@ class ContextExtractionStructuredOutput(BaseModel): class ContextExtractionNode: def __init__(self, model: BaseChatModel, root_path: str): + prompt = ChatPromptTemplate.from_messages( + [ + ("system", SYS_PROMPT), + ("human", "{human_prompt}"), + ] + ) structured_llm = model.with_structured_output(ContextExtractionStructuredOutput) - self.model = structured_llm - self.system_prompt = SystemMessage(SYS_PROMPT) + self.model = prompt | structured_llm self.root_path = root_path - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_extraction_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_extraction_node" + ) + + def get_human_message(self, state: ContextRetrievalState) -> str: + full_context_str = transform_tool_messages_to_str( + extract_last_tool_messages(state["context_provider_messages"]) + ) + original_query = state["query"] + return HUMAN_MESSAGE.format( + original_query=original_query, + context=full_context_str, + ) def __call__(self, state: ContextRetrievalState): self._logger.info("Starting context 
extraction process") # Get Context List with existing context final_context = state.get("context", []) - # Get chat history messages from the state - last_messages = state["context_provider_messages"] + # Get a human message + human_message = self.get_human_message(state) + self._logger.debug(human_message) # Summarize the context based on the last messages and system prompt - response = self.model.invoke([self.system_prompt] + last_messages) + response = self.model.invoke({"human_prompt": human_message}) self._logger.debug(f"Model response: {response}") context_list = response.context for context_ in context_list: @@ -116,22 +149,20 @@ def __call__(self, state: ContextRetrievalState): except FileOperationException as e: self._logger.error(e) continue - if content: - final_context.append( - Context( - relative_path=context_.relative_path, - start_line_number=context_.start_line, - end_line_number=context_.end_line, - content=content, - ) + if not content: + self._logger.warning( + f"Skipping context with empty content for {context_.relative_path} " + f"from line {context_.start_line} to {context_.end_line}" ) - # Filter out duplicate Context entries - seen = set() - unique_context = [] - for ctx in final_context: - key = (ctx.relative_path, ctx.start_line_number, ctx.end_line_number) - if key not in seen: - seen.add(key) - unique_context.append(ctx) - self._logger.info(f"Context extraction complete, returning context {unique_context}") - return {"context": unique_context} + continue + context = Context( + relative_path=context_.relative_path, + start_line_number=context_.start_line, + end_line_number=context_.end_line, + content=content, + ) + if context not in final_context: + final_context = final_context + [context] + + self._logger.info(f"Context extraction complete, returning context {final_context}") + return {"context": final_context} diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py 
index 30fd8345..cbf70a7c 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -7,6 +7,7 @@ import functools import logging +import threading from typing import Dict import neo4j @@ -115,8 +116,9 @@ def __init__( ) self.tools = self._init_tools() self.model_with_tools = model.bind_tools(self.tools) - - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_provider_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_provider_node" + ) def _init_tools(self): """ diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index ba43a5b1..e5372b59 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.messages import HumanMessage @@ -7,7 +8,9 @@ class ContextQueryMessageNode: def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_query_message_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_query_message_node" + ) def __call__(self, state: ContextRetrievalState): human_message = HumanMessage(state["query"]) diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 05861bcf..1cfc4ee1 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -79,7 +80,9 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): ) structured_llm = model.with_structured_output(ContextRefineStructuredOutput) self.model = 
prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_refine_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_refine_node" + ) def format_refine_message(self, state: ContextRetrievalState): original_query = state["query"] diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 7e9877d5..5e817354 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict, Sequence import neo4j @@ -21,7 +22,7 @@ def __init__( context_key_name: str, ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.context_retrieval_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.context_retrieval_subgraph_node" ) self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index bbbb93a8..11eacc80 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from langchain_core.messages import HumanMessage @@ -32,7 +33,9 @@ class EditMessageNode: """ def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_message_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.edit_message_node" + ) def format_human_message(self, state: Dict): edit_error = "" diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index 37544d18..4e10f441 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ 
-8,6 +8,7 @@ import functools import logging +import threading from typing import Dict from langchain.tools import StructuredTool @@ -119,7 +120,9 @@ def __init__(self, model: BaseChatModel, local_path: str): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.tools = self._init_tools(local_path) self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.edit_node" + ) def _init_tools(self, root_path: str): """Initializes file operation tools with the given root path. diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 5637a3c5..e8bef233 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -126,7 +127,9 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): ) structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.final_patch_selection_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.final_patch_selection_node" + ) def format_human_message(self, state: Dict): patches = "" diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index 1d3454f5..486763a8 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import 
BaseChatModel @@ -46,7 +47,9 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools(container) self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_build_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_build_node" + ) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index af9058bc..e41463d7 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -7,6 +7,7 @@ """ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -238,7 +239,7 @@ def __init__(self, model: BaseChatModel): structured_llm = model.with_structured_output(BuildStructuredOutput) self.model = prompt | structured_llm self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.general_build_structured_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_build_structured_node" ) def __call__(self, state: BuildAndTestState): diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index d35c1ee8..d875c539 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -1,5 +1,6 @@ import functools import logging +import threading from langchain.tools import StructuredTool from langchain_core.language_models.chat_models import BaseChatModel @@ -63,7 +64,9 @@ def __init__(self, model: BaseChatModel, container: BaseContainer, kg: Knowledge self.tools = self._init_tools(container) 
self.model_with_tools = model.bind_tools(self.tools) self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_test_node" + ) def _init_tools(self, container: BaseContainer): tools = [] diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 78758070..8da0b868 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -7,6 +7,7 @@ """ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -286,7 +287,9 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(TestStructuredOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_structured_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.general_test_structured_node" + ) def __call__(self, state: BuildAndTestState): """Processes test state to generate structured test analysis. 
diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index 4a8be50c..4cce12bd 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -7,6 +7,7 @@ """ import logging +import threading from typing import Dict, Optional from prometheus.git.git_repository import GitRepository @@ -32,7 +33,9 @@ def __init__( self.state_patch_name = state_patch_name self.state_excluded_files_key = state_excluded_files_key self.return_list = return_list - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_diff_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_diff_node" + ) def __call__(self, state: Dict): """Generates a Git diff for the current project state. diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index b9d63b30..1bf03f6b 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -1,4 +1,5 @@ import logging +import threading from prometheus.git.git_repository import GitRepository @@ -9,7 +10,9 @@ def __init__( git_repo: GitRepository, ): self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_reset_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.git_reset_node" + ) def __call__(self, _): self._logger.debug("Resetting the git repository") diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index 89845ced..ceff2bcc 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from langchain_core.messages import HumanMessage @@ -65,7 +66,7 @@ class 
IssueBugAnalyzerMessageNode: def __init__(self): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_analyzer_message_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_analyzer_message_node" ) def format_human_message(self, state: Dict): diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index 74586fb9..d3c501aa 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -43,7 +44,9 @@ class IssueBugAnalyzerNode: def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_analyzer_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_analyzer_node" + ) def __call__(self, state: Dict): message_history = [self.system_prompt] + state["issue_bug_analyzer_messages"] diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index acf6ccd8..25e1e5c3 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from prometheus.utils.issue_util import format_issue_info @@ -20,7 +21,7 @@ class IssueBugContextMessageNode: def __init__(self): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_context_message_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_context_message_node" ) def __call__(self, state: Dict): diff --git 
a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index e951e2d4..3c4946a3 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_info @@ -109,7 +110,7 @@ def test_file_permission_denied(self, mock_open, mock_access): def __init__(self): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" ) def __call__(self, state: BugReproductionState): diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index 1be5f22e..50578820 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict from langchain_core.language_models.chat_models import BaseChatModel @@ -48,7 +49,9 @@ def __init__(self, model: BaseChatModel): self.system_prompt = SystemMessage(self.SYS_PROMPT) self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_responder_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_responder_node" + ) def format_human_message(self, state: Dict) -> HumanMessage: verification_messages = [] diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index ac34dc92..f954ca1a 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py 
+++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Optional, Sequence import neo4j @@ -29,7 +30,9 @@ def __init__( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_subgraph_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_bug_subgraph_node" + ) self.container = container self.issue_bug_subgraph = IssueBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index a90e484a..552c50ec 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -1,4 +1,5 @@ import logging +import threading from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState from prometheus.utils.issue_util import format_issue_info @@ -73,7 +74,7 @@ class IssueClassificationContextMessageNode: def __init__(self): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_context_message_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classification_context_message_node" ) def __call__(self, state: IssueClassificationState): diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index bc71d849..6fcf059e 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading import neo4j from langchain_core.language_models.chat_models import BaseChatModel @@ -20,7 +21,7 @@ def 
__init__( max_token_per_neo4j_result: int, ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classification_subgraph_node" ) self.issue_classification_subgraph = IssueClassificationSubgraph( model=model, diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index eb8a60df..2c259a8f 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -1,4 +1,5 @@ import logging +import threading from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -124,7 +125,9 @@ def __init__(self, model: BaseChatModel): ) structured_llm = model.with_structured_output(IssueClassifierOutput) self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_classifier_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_classifier_node" + ) def format_context_info(self, state: IssueClassificationState) -> str: context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index c0ce236c..c34d5ae4 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict import neo4j @@ -22,7 +23,7 @@ def __init__( max_token_per_neo4j_result: int, ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" ) 
self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( advanced_model=advanced_model, diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 43ace6f2..9a0f6146 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Dict, Optional, Sequence import neo4j @@ -29,7 +30,7 @@ def __init__( test_commands: Optional[Sequence[str]] = None, ): self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node" + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node" ) self.git_repo = git_repo self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index e21b3f70..de06eae3 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -6,6 +6,7 @@ """ import logging +import threading from typing import Dict @@ -19,7 +20,9 @@ class NoopNode: """ def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.noop_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.noop_node" + ) def __call__(self, state: Dict) -> None: """Routes the workflow without performing any operations. diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index bf87c134..1c37bdde 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -11,6 +11,7 @@ """ import logging +import threading from typing import Dict @@ -35,7 +36,9 @@ def __init__(self, message_state_key: str): be reset during node execution. 
""" self.message_state_key = message_state_key - self._logger = logging.getLogger("prometheus.lang_graph.nodes.reset_messages_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.reset_messages_node" + ) def __call__(self, state: Dict): """Resets the specified message state for the next iteration. diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index 5c473dc6..76f73a90 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -7,6 +7,7 @@ """ import logging +import threading from typing import Dict from prometheus.docker.base_container import BaseContainer @@ -33,7 +34,9 @@ def __init__(self, container: BaseContainer, git_repo: GitRepository): """ self.container = container self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.update_container_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.update_container_node" + ) def __call__(self, _: Dict): """Synchronizes the current project state with the container.""" diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index aa8dc40c..f5e47d1a 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -1,4 +1,5 @@ import logging +import threading import uuid from typing import Any @@ -10,7 +11,9 @@ class UserDefinedBuildNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_build_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.user_defined_build_node" + ) def __call__(self, _: Any): build_output = self.container.run_build() diff --git 
a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index 7f3e6e61..47a43859 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -1,4 +1,5 @@ import logging +import threading import uuid from typing import Any @@ -10,7 +11,9 @@ class UserDefinedTestNode: def __init__(self, container: BaseContainer): self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_test_node") + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.user_defined_test_node" + ) def __call__(self, _: Any): test_output = self.container.run_test() diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py index ea1cf304..c27c6de2 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py @@ -118,7 +118,7 @@ def __init__( self.subgraph = workflow.compile() def invoke( - self, query: str, max_refined_query_loop: int, recursion_limit: int = 300 + self, query: str, max_refined_query_loop: int, recursion_limit: int = 120 ) -> Dict[str, Sequence[Context]]: """ Executes the context retrieval subgraph given an initial query. 
diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 59c7c4f5..8211c021 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -115,7 +115,7 @@ def invoke( issue_body: str, issue_comments: Sequence[Mapping[str, str]], number_of_candidate_patch: int, - recursion_limit: int = 999, + recursion_limit: int = 300, ): config = {"recursion_limit": recursion_limit} diff --git a/prometheus/utils/file_utils.py b/prometheus/utils/file_utils.py index c63f79dd..f1691e90 100644 --- a/prometheus/utils/file_utils.py +++ b/prometheus/utils/file_utils.py @@ -2,6 +2,7 @@ from pathlib import Path from prometheus.exceptions.file_operation_exception import FileOperationException +from prometheus.utils.str_util import pre_append_line_numbers def read_file_with_line_numbers( @@ -32,4 +33,6 @@ def read_file_with_line_numbers( with file_path.open() as f: lines = f.readlines() - return "".join(lines[zero_based_start_line:zero_based_end_line]) + return pre_append_line_numbers( + "".join(lines[zero_based_start_line:zero_based_end_line]), zero_based_start_line + ) diff --git a/prometheus/utils/lang_graph_util.py b/prometheus/utils/lang_graph_util.py index f3eaa6fd..9f9a999a 100644 --- a/prometheus/utils/lang_graph_util.py +++ b/prometheus/utils/lang_graph_util.py @@ -8,6 +8,8 @@ ) from langchain_core.output_parsers import StrOutputParser +from prometheus.utils.neo4j_util import neo4j_data_for_context_generator + def check_remaining_steps( state: Dict, @@ -64,6 +66,15 @@ def extract_last_tool_messages(messages: Sequence[BaseMessage]) -> Sequence[Tool return tool_messages +def transform_tool_messages_to_str(messages: Sequence[ToolMessage]) -> str: + result = "" + for message in messages: + for context in neo4j_data_for_context_generator(message.artifact): + result += str(context) + result += "\n" + 
return result + + def get_last_message_content(messages: Sequence[BaseMessage]) -> str: output_parser = StrOutputParser() return output_parser.invoke(messages[-1]) diff --git a/prometheus/utils/neo4j_util.py b/prometheus/utils/neo4j_util.py index b594c5a9..faf24758 100644 --- a/prometheus/utils/neo4j_util.py +++ b/prometheus/utils/neo4j_util.py @@ -1,7 +1,8 @@ -from typing import Any, Mapping, Sequence, Tuple +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple import neo4j +from prometheus.models.context import Context from prometheus.utils.str_util import truncate_text EMPTY_DATA_MESSAGE = "Your query returned empty result, please try a different query!" @@ -29,6 +30,39 @@ def format_neo4j_data(data: Sequence[Mapping[str, Any]], max_token_per_result: i return truncate_text(output.strip(), max_token_per_result) +def neo4j_data_for_context_generator( + data: Optional[Sequence[Mapping[str, Any]]], +) -> Iterator[Context]: + if data is None: + return + + for search_result in data: + search_result_keys = search_result.keys() + # Skip if the result has no keys or only contains the "FileNode" key + if len(search_result_keys) == 1: + continue + + context = Context( + relative_path=search_result["FileNode"]["relative_path"], + content=( + search_result.get("ASTNode", {}).get("text") + or search_result.get("TextNode", {}).get("text") + or search_result.get("preview", {}).get("text") + or search_result.get("SelectedLines", {}).get("text") + ), + start_line_number=( + search_result.get("ASTNode", {}).get("start_line") + or search_result.get("SelectedLines", {}).get("start_line") + or search_result.get("preview", {}).get("start_line") + ), + end_line_number=search_result.get("ASTNode", {}).get("end_line") + or search_result.get("SelectedLines", {}).get("end_line") + or search_result.get("preview", {}).get("end_line"), + ) + + yield context + + def run_neo4j_query( query: str, driver: neo4j.GraphDatabase.driver, max_token_per_result: int ) -> Tuple[str, 
Sequence[Mapping[str, Any]]]: diff --git a/tests/app/api/test_issue.py b/tests/app/api/test_issue.py index e7217606..c93ffe97 100644 --- a/tests/app/api/test_issue.py +++ b/tests/app/api/test_issue.py @@ -36,17 +36,19 @@ def test_answer_issue(mock_service): kg_chunk_size=1000, kg_chunk_overlap=100, ) - mock_service["issue_service"].answer_issue.return_value = ( - "feature/fix-42", # remote_branch_name - "test patch", # patch - True, # passed_reproducing_test - True, # passed_build - True, # passed_existing_test - "Issue fixed", # issue_response + mock_service["issue_service"].answer_issue = mock.AsyncMock( + return_value=( + "feature/fix-42", # remote_branch_name + "test patch", # patch + True, # passed_reproducing_test + True, # passed_build + True, # passed_existing_test + "Issue fixed", # issue_response + ) ) response = client.post( - "/issue/answer", + "/issue/answer/", json={ "repository_id": 1, "issue_number": 42, @@ -74,7 +76,7 @@ def test_answer_issue_no_repository(mock_service): mock_service["repository_service"].get_repository_by_id.return_value = None response = client.post( - "/issue/answer", + "/issue/answer/", json={ "repository_id": 1, "issue_number": 42, @@ -100,7 +102,7 @@ def test_answer_issue_invalid_container_config(mock_service): ) response = client.post( - "/issue/answer", + "/issue/answer/", json={ "repository_id": 1, "issue_number": 42, @@ -135,13 +137,15 @@ def test_answer_issue_with_container(mock_service): mock_service["repository_service"].get_repository.return_value = git_repo - mock_service["issue_service"].answer_issue.return_value = ( - "feature/fix-42", - "test patch", - True, - True, - True, - "Issue fixed", + mock_service["issue_service"].answer_issue = mock.AsyncMock( + return_value=( + "feature/fix-42", + "test patch", + True, + True, + True, + "Issue fixed", + ) ) test_payload = { @@ -156,10 +160,11 @@ def test_answer_issue_with_container(mock_service): "test_commands": ["pytest ."], } - response = client.post("/issue/answer", 
json=test_payload) + response = client.post("/issue/answer/", json=test_payload) assert response.status_code == 200 mock_service["issue_service"].answer_issue.assert_called_once_with( + repository_id=1, repository=git_repo, knowledge_graph=knowledge_graph, issue_number=42, diff --git a/tests/app/services/test_issue_service.py b/tests/app/services/test_issue_service.py index 06c43a0e..79d9a0d6 100644 --- a/tests/app/services/test_issue_service.py +++ b/tests/app/services/test_issue_service.py @@ -5,6 +5,7 @@ from prometheus.app.services.issue_service import IssueService from prometheus.app.services.llm_service import LLMService from prometheus.app.services.neo4j_service import Neo4jService +from prometheus.app.services.repository_service import RepositoryService from prometheus.git.git_repository import GitRepository from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueType @@ -26,16 +27,25 @@ def mock_llm_service(): @pytest.fixture -def issue_service(mock_neo4j_service, mock_llm_service): +def mock_repository_service(): + service = create_autospec(RepositoryService, instance=True) + return service + + +@pytest.fixture +def issue_service(mock_neo4j_service, mock_llm_service, mock_repository_service): return IssueService( - mock_neo4j_service, - mock_llm_service, + neo4j_service=mock_neo4j_service, + llm_service=mock_llm_service, + repository_service=mock_repository_service, max_token_per_neo4j_result=1000, working_directory="/tmp/working_dir/", + logging_level="DEBUG", ) -def test_answer_issue_with_general_container(issue_service, monkeypatch): +@pytest.mark.asyncio +async def test_answer_issue_with_general_container(issue_service, monkeypatch): # Setup mock_issue_graph = Mock() mock_issue_graph_class = Mock(return_value=mock_issue_graph) @@ -64,7 +74,8 @@ def test_answer_issue_with_general_container(issue_service, monkeypatch): mock_issue_graph.invoke.return_value = mock_output_state # Exercise - 
result = issue_service.answer_issue( + result = await issue_service.answer_issue( + repository_id=1, repository=repository, knowledge_graph=knowledge_graph, issue_number=-1, @@ -94,7 +105,8 @@ def test_answer_issue_with_general_container(issue_service, monkeypatch): assert result == (None, "test_patch", True, True, True, "test_response") -def test_answer_issue_with_user_defined_container(issue_service, monkeypatch): +@pytest.mark.asyncio +async def test_answer_issue_with_user_defined_container(issue_service, monkeypatch): # Setup mock_issue_graph = Mock() mock_issue_graph_class = Mock(return_value=mock_issue_graph) @@ -112,11 +124,19 @@ def test_answer_issue_with_user_defined_container(issue_service, monkeypatch): knowledge_graph = Mock(spec=KnowledgeGraph) # Mock output state for a question type - mock_output_state = {"issue_type": IssueType.QUESTION, "issue_response": "test_response"} + mock_output_state = { + "issue_type": IssueType.QUESTION, + "edit_patch": None, + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + "issue_response": "test_response", + } mock_issue_graph.invoke.return_value = mock_output_state # Exercise - result = issue_service.answer_issue( + result = await issue_service.answer_issue( + repository_id=1, repository=repository, knowledge_graph=knowledge_graph, issue_number=-1,