diff --git a/example.env b/example.env index b261c07..3e316f0 100644 --- a/example.env +++ b/example.env @@ -16,9 +16,9 @@ PROMETHEUS_NEO4J_BATCH_SIZE=1000 # Knowledge Graph settings PROMETHEUS_WORKING_DIRECTORY=working_dir/ PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=3 -PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=10000 +PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=8000 PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=1000 -PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000 +PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=8000 # LLM model settings PROMETHEUS_ADVANCED_MODEL=gpt-4o diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index 4e10f44..bc799ee 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -51,6 +51,7 @@ class EditNode: - Only one match of old_content should exist in the file - If multiple matches exist, more context is needed - If no matches exist, content must be verified +- Do not write any tests, your change will be tested by reproduction tests and regression tests later EXAMPLES: @@ -114,6 +115,7 @@ def other_method(): 4. When replacing multiple lines, include all lines in old_content 5. If multiple matches found, include more context 6. Verify uniqueness of matches before changes +7. NEVER write tests, your change will be tested by reproduction tests and regression tests later """ def __init__(self, model: BaseChatModel, local_path: str): diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 45d28c4..a024e18 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,5 +1,6 @@ import logging import threading +from typing import Sequence from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate @@ -132,18 +133,13 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): ) self.majority_voting_times = 10 - def format_human_message(self, state: IssueNotVerifiedBugState): - if state["run_regression_test"]: - patches = [result.patch for result in state["tested_patch_result"] if result.passed] - else: - patches = state["edit_patches"] - + def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState): patches_str = "" for index, patch in enumerate(patches): patches_str += f"Patch at index {index}:\n" patches_str += f"{patch}\n\n" patches_str += ( - f"You must select a patch with index from 0 to {len(state['edit_patches']) - 1}," + f"You must select a patch with index from 0 to {len(patches) - 1}," f" and provide your reasoning." ) @@ -156,28 +152,42 @@ def format_human_message(self, state: IssueNotVerifiedBugState): ) def __call__(self, state: IssueNotVerifiedBugState): - human_prompt = self.format_human_message(state) - result = [0 for _ in range(len(state["edit_patches"]))] + # Determine candidate patches + if state["run_regression_test"]: + patches = [result.patch for result in state["tested_patch_result"] if result.passed] + else: + patches = state["deduplicated_patches"] + + # Formalize Human Message + human_prompt = self.format_human_message(patches, state) + + # Majority voting + result = [0 for _ in range(len(patches))] for turn in range(self.majority_voting_times): + # Call the model response = self.model.invoke({"human_prompt": human_prompt}) self._logger.info( f"FinalPatchSelectionNode response at {turn + 1}/{self.majority_voting_times} try:" f"Selected patch index: {response.patch_index}, " ) - if 0 <= response.patch_index < len(state["edit_patches"]): + # Tally the vote if the index is valid + if 0 <= response.patch_index < len(patches): result[response.patch_index] += 1 + + # Early stopping if a patch has already secured majority if max(result) > self.majority_voting_times // 2: selected_patch_index = result.index(max(result)) self._logger.info( f"FinalPatchSelectionNode early stopping at turn {turn + 1} with result: {result}," f"selected patch index: {selected_patch_index}" ) - return {"final_patch": state["edit_patches"][selected_patch_index]} + return {"final_patch": patches[selected_patch_index]} + # Select the maximum voted patch index selected_patch_index = result.index(max(result)) self._logger.info( f"FinalPatchSelectionNode voting results: {result}, " f"selected patch index: {selected_patch_index}" ) - return {"final_patch": state["edit_patches"][selected_patch_index]} + return {"final_patch": patches[selected_patch_index]} diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index c2238ef..6e8062c 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -9,7 +9,9 @@ import threading from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Sequence + +from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState @dataclass @@ -93,7 +95,7 @@ def calculate_patch_metrics(self, normalized_patch: str) -> PatchMetrics: """Calculate basic metrics for a patch""" return PatchMetrics() - def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]: + def deduplicate_patches(self, patches: Sequence[str]) -> List[NormalizedPatch]: """Deduplicate patches using normalization Returns list of unique normalized patches with occurrence counts. @@ -136,20 +138,17 @@ def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]: return deduplicated - def __call__(self, state: Dict) -> Dict: + def __call__(self, state: IssueNotVerifiedBugState) -> Dict: """Node call interface - Process edit_patches in state, return normalized, deduplicated patches and selected best patch + Process edit_patches in state, return normalized, deduplicated patches """ patches = state.get("edit_patches", []) if not patches: self._logger.warning("No patches found to process") return { - "normalized_patches": [], - "final_patch": "", - "original_patch_count": 0, - "unique_patch_count": 0, + "deduplicated_patches": [], } self._logger.info(f"Starting to process {len(patches)} patches") @@ -161,12 +160,9 @@ def __call__(self, state: Dict) -> Dict: deduplicated_patches = [patch.original_content for patch in normalized_patches] self._logger.info( - f"Patch processing complete, deduplicated to {len(normalized_patches)} unique patches" + f"Patch processing complete, deduplicated to {len(deduplicated_patches)} unique patches" ) return { - "normalized_patches": normalized_patches, - "edit_patches": deduplicated_patches, # Return deduplicated patches for selection - "original_patch_count": len(patches), - "unique_patch_count": len(normalized_patches), + "deduplicated_patches": deduplicated_patches, } diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index 4d62a32..ad1da1b 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -18,7 +18,8 @@ class RunRegressionTestsStructureOutput(BaseModel): description="If the test failed, contains the complete test FAILURE log. Otherwise empty string" ) total_tests_run: int = Field( - description="Total number of tests run, including both passed and failed tests" + description="Total number of tests run, including both passed and failed tests, or 0 if no tests were run", + default=0, ) @@ -31,13 +32,14 @@ class RunRegressionTestsStructuredNode: - Test summary showing "passed" or "PASSED" - Warning is ok - No "FAILURES" section -2. If a test fails, capture the complete failure output +2. If a test fails, capture the complete failure output. Otherwise empty string for failure log 3. Return the exact test identifiers that passed +4. Count the total number of tests run. Only count tests that were actually executed! If tests were unable to run due to an error, do not count them! Return: - passed_regression_tests: List of test identifier of regression tests that passed (e.g., class name and method name) - regression_test_fail_log: empty string if all tests pass, exact complete test output if a test fails -- total_tests_run: Total number of tests run, including both passed and failed tests +- total_tests_run: Total number of tests run, including both passed and failed tests. If you can't find any test run, return 0 Example 1: ``` @@ -67,7 +69,7 @@ class RunRegressionTestsStructuredNode: "test_file_operation.py::test_edit_file", "test_file_operation.py::test_create_file_already_exists" ], - "reproducing_test_fail_log": "" # ONLY output the log exact and complete test FAILURE log when test failure. Otherwise empty string, + "reproducing_test_fail_log": "", "total_tests_run": 7 }} @@ -76,6 +78,8 @@ class RunRegressionTestsStructuredNode: - A single failing test means the test is not passing - Include complete test output in failure log - Do Not output any log when where is no test executed. ONLY output the log exact and complete test FAILURE log when test failure! +- Do not forget to return the total number of tests run! If tests were unable to run due to an error, do not count them! +- If you can't find any test run, return 0 for total number of tests run! """ HUMAN_PROMPT = """ We have run the selected regression tests on the codebase. @@ -83,12 +87,15 @@ class RunRegressionTestsStructuredNode: --- BEGIN SELECTED REGRESSION TESTS --- {selected_regression_tests} --- END SELECTED REGRESSION TESTS --- + Run Regression Tests Logs: --- BEGIN LOG --- {run_regression_tests_messages} --- END LOG --- + Please analyze the logs and determine which regression tests passed!. You should return the exact test identifier that we give to you. +Don't forget to return the total number of tests run! """ def __init__(self, model: BaseChatModel): diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py index b46ac9e..8a186c6 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py @@ -130,7 +130,7 @@ def invoke(self, query: str, max_refined_query_loop: int) -> Dict[str, Sequence[ - "context" (Sequence[Context]): A list of selected context snippets relevant to the query. """ # Set the recursion limit based on the maximum number of refined query loops - config = {"recursion_limit": max_refined_query_loop * 50} + config = {"recursion_limit": (max_refined_query_loop + 1) * 40} input_state = { "query": query, diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py index db218e5..6cb5ace 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py @@ -24,6 +24,8 @@ class IssueNotVerifiedBugState(TypedDict): edit_patches: Annotated[Sequence[str], add] + deduplicated_patches: Sequence[str] + final_patch: str run_regression_test: bool diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 7fc162e..123fbc6 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -21,7 +21,7 @@ from prometheus.lang_graph.nodes.issue_bug_analyzer_message_node import IssueBugAnalyzerMessageNode from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode from prometheus.lang_graph.nodes.issue_bug_context_message_node import IssueBugContextMessageNode -from prometheus.lang_graph.nodes.noop_node import NoopNode +from prometheus.lang_graph.nodes.patch_normalization_node import PatchNormalizationNode from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState @@ -37,8 +37,6 @@ def __init__( neo4j_driver: neo4j.Driver, max_token_per_neo4j_result: int, ): - noop_node = NoopNode() - issue_bug_context_message_node = IssueBugContextMessageNode() context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( model=base_model, @@ -66,12 +64,15 @@ def __init__( reset_issue_bug_analyzer_messages_node = ResetMessagesNode("issue_bug_analyzer_messages") reset_edit_messages_node = ResetMessagesNode("edit_messages") + # Patch Normalization Node + patch_normalization_node = PatchNormalizationNode() + # Get pass regression test patch subgraph node get_pass_regression_test_patch_subgraph_node = GetPassRegressionTestPatchSubgraphNode( model=base_model, container=container, git_repo=git_repo, - testing_patch_key="edit_patches", + testing_patch_key="deduplicated_patches", is_testing_patch_list=True, ) @@ -79,7 +80,6 @@ def __init__( final_patch_selection_node = FinalPatchSelectionNode(advanced_model) workflow = StateGraph(IssueNotVerifiedBugState) - workflow.add_node("noop_node", noop_node) workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) @@ -98,6 +98,8 @@ def __init__( ) workflow.add_node("reset_edit_messages_node", reset_edit_messages_node) + workflow.add_node("patch_normalization_node", patch_normalization_node) + workflow.add_node( "get_pass_regression_test_patch_subgraph_node", get_pass_regression_test_patch_subgraph_node, @@ -124,11 +126,11 @@ def __init__( lambda state: len(state["edit_patches"]) < state["number_of_candidate_patch"], { True: "git_reset_node", - False: "noop_node", + False: "patch_normalization_node", }, ) workflow.add_conditional_edges( - "noop_node", + "patch_normalization_node", lambda state: state["run_regression_test"], { True: "get_pass_regression_test_patch_subgraph_node",