From d900f398d37509972acf9edac36b7bbc24f72a31 Mon Sep 17 00:00:00 2001 From: schwerli <2064580160@qq.com> Date: Sat, 23 Aug 2025 08:37:58 +0800 Subject: [PATCH 1/2] Add normalized patch processing: simplified patch normalization and deduplication --- .../nodes/patch_normalization_node.py | 165 +++++++++++++ .../normalized_not_verified_bug_subgraph.py | 223 ++++++++++++++++++ 2 files changed, 388 insertions(+) create mode 100644 prometheus/lang_graph/nodes/patch_normalization_node.py create mode 100644 prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py new file mode 100644 index 0000000..54e9365 --- /dev/null +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -0,0 +1,165 @@ +"""Patch Normalization and Selection Node + +This module implements simplified patch normalization and direct selection functionality. +Provides standardized patch candidates with direct best patch selection. +""" + +import logging +import re +import threading +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class PatchMetrics: + """Patch basic metrics""" + occurrence_count: int = 1 + + +@dataclass +class NormalizedPatch: + """Normalized patch data structure""" + original_index: int + original_content: str + normalized_content: str + metrics: PatchMetrics + + +class PatchNormalizationNode: + """Patch Normalization and Direct Selection Node + + Implements patch normalization, deduplication and direct best patch selection. + Simplified approach without complex voting mechanisms. + """ + + def __init__(self): + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.patch_normalization_node" + ) + + def normalize_patch(self, raw_patch: str) -> str: + """Normalize patch content for deduplication + + Removes metadata lines and standardizes formatting to enable + accurate patch comparison and deduplication. + """ + if not raw_patch: + return "" + + lines = raw_patch.split('\n') + normalized_lines = [] + + for line in lines: + # Skip metadata lines + if self._is_metadata_line(line): + continue + + # Normalize file paths + if line.startswith('--- ') or line.startswith('+++ '): + line = self._normalize_file_path(line) + + normalized_lines.append(line) + + return '\n'.join(normalized_lines) + + def _is_metadata_line(self, line: str) -> bool: + """Check if line is metadata that should be ignored""" + metadata_patterns = [ + r'^diff --git', + r'^index [a-f0-9]+\.\.[a-f0-9]+', + r'^new file mode \d+', + r'^deleted file mode \d+', + r'^similarity index \d+%', + r'^rename from ', + r'^rename to ', + r'^Binary files ', + ] + + return any(re.match(pattern, line) for pattern in metadata_patterns) + + def _normalize_file_path(self, line: str) -> str: + """Normalize file path in diff header""" + # Remove timestamp and mode information + line = re.sub(r'\s+\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)? \+\d{4}', '', line) + line = re.sub(r'\s+\d{6}', '', line) + + return line + + def calculate_patch_metrics(self, normalized_patch: str) -> PatchMetrics: + """Calculate basic metrics for a patch""" + return PatchMetrics() + + + def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]: + """Deduplicate patches using normalization + + Returns list of unique normalized patches with occurrence counts. + """ + if not patches: + return [] + + # Normalize all patches + normalized_patches = [] + for i, patch in enumerate(patches): + normalized_content = self.normalize_patch(patch) + metrics = self.calculate_patch_metrics(normalized_content) + + normalized_patches.append(NormalizedPatch( + original_index=i, + original_content=patch, + normalized_content=normalized_content, + metrics=metrics + )) + + # Group by normalized content + patch_groups = defaultdict(list) + for patch in normalized_patches: + patch_groups[patch.normalized_content].append(patch) + + # Create deduplicated list with occurrence counts + deduplicated = [] + for normalized_content, group in patch_groups.items(): + # Use the first patch in the group as representative + representative = group[0] + # Update occurrence count + representative.metrics.occurrence_count = len(group) + deduplicated.append(representative) + + self._logger.info(f"Deduplication complete: {len(patches)} -> {len(deduplicated)} unique patches") + + return deduplicated + + def __call__(self, state: Dict) -> Dict: + """Node call interface + + Process edit_patches in state, return normalized, deduplicated patches and selected best patch + """ + patches = state.get("edit_patches", []) + + if not patches: + self._logger.warning("No patches found to process") + return { + "normalized_patches": [], + "final_patch": "", + "original_patch_count": 0, + "unique_patch_count": 0 + } + + self._logger.info(f"Starting to process {len(patches)} patches") + + # Execute deduplication and normalization + normalized_patches = self.deduplicate_patches(patches) + + # Return deduplicated patches (selection will be done by final_patch_selection_node) + deduplicated_patches = [patch.original_content for patch in normalized_patches] + + self._logger.info(f"Patch processing complete, deduplicated to {len(normalized_patches)} unique patches") + + return { + "normalized_patches": normalized_patches, + "edit_patches": deduplicated_patches, # Return deduplicated patches for selection + "original_patch_count": len(patches), + "unique_patch_count": len(normalized_patches) + } diff --git a/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py new file mode 100644 index 0000000..5852085 --- /dev/null +++ b/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py @@ -0,0 +1,223 @@ +"""Normalized Not Verified Bug Subgraph + +This module implements a simplified enhanced issue not verified bug subgraph +with patch normalization and deduplication, using standard final patch selection. +""" + +import logging +import threading +from typing import Optional, Sequence, Mapping + +import neo4j +from langchain_core.language_models import BaseChatModel +from langgraph.graph import StateGraph, END + +from prometheus.knowledge_graph.knowledge_graph import KnowledgeGraph +from prometheus.lang_graph.graphs.issue_state import IssueNotVerifiedBugState +from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode +from prometheus.lang_graph.nodes.edit_message_node import EditMessageNode +from prometheus.lang_graph.nodes.edit_node import EditNode +from prometheus.lang_graph.nodes.git_diff_node import GitDiffNode +from prometheus.lang_graph.nodes.git_reset_node import GitResetNode +from prometheus.lang_graph.nodes.issue_bug_analyzer_message_node import IssueBugAnalyzerMessageNode +from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode +from prometheus.lang_graph.nodes.issue_bug_context_message_node import IssueBugContextMessageNode +from prometheus.lang_graph.nodes.patch_normalization_node import PatchNormalizationNode +from prometheus.lang_graph.nodes.final_patch_selection_node import FinalPatchSelectionNode +from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode +from prometheus.repository.git_repository import GitRepository +from prometheus.container.base_container import BaseContainer + + +class NormalizedNotVerifiedBugSubgraph: + """Simplified Enhanced Issue Not Verified Bug Subgraph + + Simplified workflow with patch normalization and deduplication: + 1. Original context retrieval and bug analysis + 2. Patch generation and diff + 3. Patch normalization and deduplication + 4. Standard final patch selection + """ + + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + container: Optional[BaseContainer] = None, + ): + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.subgraphs.normalized_not_verified_bug_subgraph" + ) + + # === Initialize Nodes === + # Context retrieval subgraph node + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + container=container, + ) + + # Issue bug context message node + issue_bug_context_message_node = IssueBugContextMessageNode( + advanced_model=advanced_model, + base_model=base_model, + ) + + # Issue bug analyzer message node + issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode( + advanced_model=advanced_model, + base_model=base_model, + ) + + # Issue bug analyzer node + issue_bug_analyzer_node = IssueBugAnalyzerNode( + advanced_model=advanced_model, + base_model=base_model, + ) + + # Edit message node + edit_message_node = EditMessageNode( + advanced_model=advanced_model, + base_model=base_model, + ) + + # Edit node + edit_node = EditNode( + advanced_model=advanced_model, + base_model=base_model, + ) + + # Git diff node + git_diff_node = GitDiffNode( + git_repo=git_repo, + ) + + # Git reset node + git_reset_node = GitResetNode( + git_repo=git_repo, + ) + + # Reset messages nodes + reset_issue_bug_analyzer_messages_node = ResetMessagesNode( + message_key="issue_bug_analyzer_messages" + ) + reset_edit_messages_node = ResetMessagesNode( + message_key="edit_messages" + ) + + # Patch normalization node (only deduplication) + patch_normalization_node = PatchNormalizationNode() + + # Final patch selection node (intelligent selection) + final_patch_selection_node = FinalPatchSelectionNode( + model=advanced_model, + max_retries=2 + ) + + # === Build Workflow Graph === + workflow = StateGraph(IssueNotVerifiedBugState) + + # Add nodes + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) + workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) + workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + workflow.add_node("edit_message_node", edit_message_node) + workflow.add_node("edit_node", edit_node) + workflow.add_node("git_diff_node", git_diff_node) + workflow.add_node("git_reset_node", git_reset_node) + workflow.add_node("reset_issue_bug_analyzer_messages_node", reset_issue_bug_analyzer_messages_node) + workflow.add_node("reset_edit_messages_node", reset_edit_messages_node) + workflow.add_node("patch_normalization_node", patch_normalization_node) + workflow.add_node("final_patch_selection_node", final_patch_selection_node) + + # === Build Workflow Edges === + # Start with context retrieval + workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_context_message_node") + workflow.add_edge("issue_bug_context_message_node", "issue_bug_analyzer_message_node") + workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") + workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + workflow.add_edge("edit_message_node", "edit_node") + workflow.add_edge("edit_node", "git_diff_node") + + # === Decision Point: Continue Generation or Process Patches === + workflow.add_conditional_edges( + "git_diff_node", + self._routing_logic, + { + "continue_generation": "git_reset_node", # Continue generating more patches + "process_patches": "patch_normalization_node", # Process patches with normalization + } + ) + + # Continue generating patches - original flow + workflow.add_edge("git_reset_node", "reset_issue_bug_analyzer_messages_node") + workflow.add_edge("reset_issue_bug_analyzer_messages_node", "reset_edit_messages_node") + workflow.add_edge("reset_edit_messages_node", "issue_bug_analyzer_message_node") + + # === Patch Processing Flow === + # Flow: normalization -> final selection -> END + workflow.add_edge("patch_normalization_node", "final_patch_selection_node") + workflow.add_edge("final_patch_selection_node", END) + + self.subgraph = workflow.compile() + + def _routing_logic(self, state: IssueNotVerifiedBugState) -> str: + """Routing logic to decide whether to continue generation or process patches""" + patches = state.get("edit_patches", []) + target_patch_count = state.get("number_of_candidate_patch", 1) + current_patch_count = len(patches) + + if current_patch_count < target_patch_count: + return "continue_generation" + + return "process_patches" + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + number_of_candidate_patch: int, + recursion_limit: int = 500, + ): + """Invoke the subgraph with issue information""" + # Prepare initial state + initial_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "number_of_candidate_patch": number_of_candidate_patch, + "edit_patches": [], + "issue_bug_analyzer_messages": [], + "edit_messages": [], + } + + # Execute the workflow + output_state = self.subgraph.invoke( + initial_state, + config={"recursion_limit": recursion_limit} + ) + + # Extract results + result = { + "final_patch": output_state.get("final_patch", ""), + } + + # Add patch statistics if available + if "unique_patch_count" in output_state: + result["patch_statistics"] = { + "original_patch_count": output_state.get("original_patch_count", 0), + "unique_patch_count": output_state.get("unique_patch_count", 0), + "deduplication_ratio": output_state.get("unique_patch_count", 0) / max(output_state.get("original_patch_count", 1), 1) + } + + return result From 39d8c811e213b6c0b47e2d02a66fa2aecf0d73cb Mon Sep 17 00:00:00 2001 From: schwerli <2064580160@qq.com> Date: Sat, 23 Aug 2025 08:38:43 +0800 Subject: [PATCH 2/2] Fix ruff formatting and import sorting --- .../nodes/patch_normalization_node.py | 115 ++++++++++-------- .../normalized_not_verified_bug_subgraph.py | 85 +++++++------ 2 files changed, 102 insertions(+), 98 deletions(-) diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 54e9365..c2238ef 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -15,12 +15,14 @@ @dataclass class PatchMetrics: """Patch basic metrics""" + occurrence_count: int = 1 @dataclass class NormalizedPatch: """Normalized patch data structure""" + original_index: int original_content: str normalized_content: str @@ -29,95 +31,96 @@ class NormalizedPatch: class PatchNormalizationNode: """Patch Normalization and Direct Selection Node - + Implements patch normalization, deduplication and direct best patch selection. Simplified approach without complex voting mechanisms. """ - + def __init__(self): self._logger = logging.getLogger( f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.patch_normalization_node" ) - + def normalize_patch(self, raw_patch: str) -> str: """Normalize patch content for deduplication - + Removes metadata lines and standardizes formatting to enable accurate patch comparison and deduplication. """ if not raw_patch: return "" - - lines = raw_patch.split('\n') + + lines = raw_patch.split("\n") normalized_lines = [] - + for line in lines: # Skip metadata lines if self._is_metadata_line(line): continue - + # Normalize file paths - if line.startswith('--- ') or line.startswith('+++ '): + if line.startswith("--- ") or line.startswith("+++ "): line = self._normalize_file_path(line) - + normalized_lines.append(line) - - return '\n'.join(normalized_lines) - + + return "\n".join(normalized_lines) + def _is_metadata_line(self, line: str) -> bool: """Check if line is metadata that should be ignored""" metadata_patterns = [ - r'^diff --git', - r'^index [a-f0-9]+\.\.[a-f0-9]+', - r'^new file mode \d+', - r'^deleted file mode \d+', - r'^similarity index \d+%', - r'^rename from ', - r'^rename to ', - r'^Binary files ', + r"^diff --git", + r"^index [a-f0-9]+\.\.[a-f0-9]+", + r"^new file mode \d+", + r"^deleted file mode \d+", + r"^similarity index \d+%", + r"^rename from ", + r"^rename to ", + r"^Binary files ", ] - + return any(re.match(pattern, line) for pattern in metadata_patterns) - + def _normalize_file_path(self, line: str) -> str: """Normalize file path in diff header""" # Remove timestamp and mode information - line = re.sub(r'\s+\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)? \+\d{4}', '', line) - line = re.sub(r'\s+\d{6}', '', line) - + line = re.sub(r"\s+\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)? \+\d{4}", "", line) + line = re.sub(r"\s+\d{6}", "", line) + return line - + def calculate_patch_metrics(self, normalized_patch: str) -> PatchMetrics: """Calculate basic metrics for a patch""" return PatchMetrics() - - + def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]: """Deduplicate patches using normalization - + Returns list of unique normalized patches with occurrence counts. """ if not patches: return [] - + # Normalize all patches normalized_patches = [] for i, patch in enumerate(patches): normalized_content = self.normalize_patch(patch) metrics = self.calculate_patch_metrics(normalized_content) - - normalized_patches.append(NormalizedPatch( - original_index=i, - original_content=patch, - normalized_content=normalized_content, - metrics=metrics - )) - + + normalized_patches.append( + NormalizedPatch( + original_index=i, + original_content=patch, + normalized_content=normalized_content, + metrics=metrics, + ) + ) + # Group by normalized content patch_groups = defaultdict(list) for patch in normalized_patches: patch_groups[patch.normalized_content].append(patch) - + # Create deduplicated list with occurrence counts deduplicated = [] for normalized_content, group in patch_groups.items(): @@ -126,40 +129,44 @@ def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]: # Update occurrence count representative.metrics.occurrence_count = len(group) deduplicated.append(representative) - - self._logger.info(f"Deduplication complete: {len(patches)} -> {len(deduplicated)} unique patches") - + + self._logger.info( + f"Deduplication complete: {len(patches)} -> {len(deduplicated)} unique patches" + ) + return deduplicated - + def __call__(self, state: Dict) -> Dict: """Node call interface - + Process edit_patches in state, return normalized, deduplicated patches and selected best patch """ patches = state.get("edit_patches", []) - + if not patches: self._logger.warning("No patches found to process") return { "normalized_patches": [], "final_patch": "", "original_patch_count": 0, - "unique_patch_count": 0 + "unique_patch_count": 0, } - + self._logger.info(f"Starting to process {len(patches)} patches") - + # Execute deduplication and normalization normalized_patches = self.deduplicate_patches(patches) - + # Return deduplicated patches (selection will be done by final_patch_selection_node) deduplicated_patches = [patch.original_content for patch in normalized_patches] - - self._logger.info(f"Patch processing complete, deduplicated to {len(normalized_patches)} unique patches") - + + self._logger.info( + f"Patch processing complete, deduplicated to {len(normalized_patches)} unique patches" + ) + return { "normalized_patches": normalized_patches, "edit_patches": deduplicated_patches, # Return deduplicated patches for selection "original_patch_count": len(patches), - "unique_patch_count": len(normalized_patches) + "unique_patch_count": len(normalized_patches), } diff --git a/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py index 5852085..ebdc153 100644 --- a/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/normalized_not_verified_bug_subgraph.py @@ -6,39 +6,39 @@ import logging import threading -from typing import Optional, Sequence, Mapping +from typing import Mapping, Optional, Sequence import neo4j from langchain_core.language_models import BaseChatModel -from langgraph.graph import StateGraph, END +from langgraph.graph import END, StateGraph +from prometheus.container.base_container import BaseContainer from prometheus.knowledge_graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueNotVerifiedBugState from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode from prometheus.lang_graph.nodes.edit_message_node import EditMessageNode from prometheus.lang_graph.nodes.edit_node import EditNode +from prometheus.lang_graph.nodes.final_patch_selection_node import FinalPatchSelectionNode from prometheus.lang_graph.nodes.git_diff_node import GitDiffNode from prometheus.lang_graph.nodes.git_reset_node import GitResetNode from prometheus.lang_graph.nodes.issue_bug_analyzer_message_node import IssueBugAnalyzerMessageNode from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode from prometheus.lang_graph.nodes.issue_bug_context_message_node import IssueBugContextMessageNode from prometheus.lang_graph.nodes.patch_normalization_node import PatchNormalizationNode -from prometheus.lang_graph.nodes.final_patch_selection_node import FinalPatchSelectionNode from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode from prometheus.repository.git_repository import GitRepository -from prometheus.container.base_container import BaseContainer class NormalizedNotVerifiedBugSubgraph: """Simplified Enhanced Issue Not Verified Bug Subgraph - + Simplified workflow with patch normalization and deduplication: 1. Original context retrieval and bug analysis 2. Patch generation and diff 3. Patch normalization and deduplication 4. Standard final patch selection """ - + def __init__( self, advanced_model: BaseChatModel, @@ -52,7 +52,7 @@ def __init__( self._logger = logging.getLogger( f"thread-{threading.get_ident()}.prometheus.lang_graph.subgraphs.normalized_not_verified_bug_subgraph" ) - + # === Initialize Nodes === # Context retrieval subgraph node context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( @@ -64,67 +64,62 @@ def __init__( max_token_per_neo4j_result=max_token_per_neo4j_result, container=container, ) - + # Issue bug context message node issue_bug_context_message_node = IssueBugContextMessageNode( advanced_model=advanced_model, base_model=base_model, ) - + # Issue bug analyzer message node issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode( advanced_model=advanced_model, base_model=base_model, ) - + # Issue bug analyzer node issue_bug_analyzer_node = IssueBugAnalyzerNode( advanced_model=advanced_model, base_model=base_model, ) - + # Edit message node edit_message_node = EditMessageNode( advanced_model=advanced_model, base_model=base_model, ) - + # Edit node edit_node = EditNode( advanced_model=advanced_model, base_model=base_model, ) - + # Git diff node git_diff_node = GitDiffNode( git_repo=git_repo, ) - + # Git reset node git_reset_node = GitResetNode( git_repo=git_repo, ) - + # Reset messages nodes reset_issue_bug_analyzer_messages_node = ResetMessagesNode( message_key="issue_bug_analyzer_messages" ) - reset_edit_messages_node = ResetMessagesNode( - message_key="edit_messages" - ) - + reset_edit_messages_node = ResetMessagesNode(message_key="edit_messages") + # Patch normalization node (only deduplication) patch_normalization_node = PatchNormalizationNode() - + # Final patch selection node (intelligent selection) - final_patch_selection_node = FinalPatchSelectionNode( - model=advanced_model, - max_retries=2 - ) - + final_patch_selection_node = FinalPatchSelectionNode(model=advanced_model, max_retries=2) + # === Build Workflow Graph === workflow = StateGraph(IssueNotVerifiedBugState) - + # Add nodes workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) @@ -134,11 +129,13 @@ def __init__( workflow.add_node("edit_node", edit_node) workflow.add_node("git_diff_node", git_diff_node) workflow.add_node("git_reset_node", git_reset_node) - workflow.add_node("reset_issue_bug_analyzer_messages_node", reset_issue_bug_analyzer_messages_node) + workflow.add_node( + "reset_issue_bug_analyzer_messages_node", reset_issue_bug_analyzer_messages_node + ) workflow.add_node("reset_edit_messages_node", reset_edit_messages_node) workflow.add_node("patch_normalization_node", patch_normalization_node) workflow.add_node("final_patch_selection_node", final_patch_selection_node) - + # === Build Workflow Edges === # Start with context retrieval workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_context_message_node") @@ -147,7 +144,7 @@ def __init__( workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") workflow.add_edge("edit_message_node", "edit_node") workflow.add_edge("edit_node", "git_diff_node") - + # === Decision Point: Continue Generation or Process Patches === workflow.add_conditional_edges( "git_diff_node", @@ -155,32 +152,32 @@ def __init__( { "continue_generation": "git_reset_node", # Continue generating more patches "process_patches": "patch_normalization_node", # Process patches with normalization - } + }, ) - + # Continue generating patches - original flow workflow.add_edge("git_reset_node", "reset_issue_bug_analyzer_messages_node") workflow.add_edge("reset_issue_bug_analyzer_messages_node", "reset_edit_messages_node") workflow.add_edge("reset_edit_messages_node", "issue_bug_analyzer_message_node") - + # === Patch Processing Flow === # Flow: normalization -> final selection -> END workflow.add_edge("patch_normalization_node", "final_patch_selection_node") workflow.add_edge("final_patch_selection_node", END) - + self.subgraph = workflow.compile() - + def _routing_logic(self, state: IssueNotVerifiedBugState) -> str: """Routing logic to decide whether to continue generation or process patches""" patches = state.get("edit_patches", []) target_patch_count = state.get("number_of_candidate_patch", 1) current_patch_count = len(patches) - + if current_patch_count < target_patch_count: return "continue_generation" - + return "process_patches" - + def invoke( self, issue_title: str, @@ -200,24 +197,24 @@ def invoke( "issue_bug_analyzer_messages": [], "edit_messages": [], } - + # Execute the workflow output_state = self.subgraph.invoke( - initial_state, - config={"recursion_limit": recursion_limit} + initial_state, config={"recursion_limit": recursion_limit} ) - + # Extract results result = { "final_patch": output_state.get("final_patch", ""), } - + # Add patch statistics if available if "unique_patch_count" in output_state: result["patch_statistics"] = { "original_patch_count": output_state.get("original_patch_count", 0), "unique_patch_count": output_state.get("unique_patch_count", 0), - "deduplication_ratio": output_state.get("unique_patch_count", 0) / max(output_state.get("original_patch_count", 1), 1) + "deduplication_ratio": output_state.get("unique_patch_count", 0) + / max(output_state.get("original_patch_count", 1), 1), } - + return result