Skip to content
Merged
Empty file.
6 changes: 6 additions & 0 deletions prometheus/exceptions/file_operation_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class FileOperationException(Exception):
    """Common base type for exceptions raised by file operations."""
130 changes: 130 additions & 0 deletions prometheus/lang_graph/nodes/context_extraction_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import logging
from typing import Sequence

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage
from pydantic import BaseModel, Field

from prometheus.exceptions.file_operation_exceptions import FileOperationException
from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState
from prometheus.models.context import Context
from prometheus.utils.file_utils import read_file_with_line_numbers

# System prompt for the context extraction model. ContextExtractionNode wraps
# this text in a SystemMessage and sends it ahead of the context provider's
# message history; the prompt asks for output matching
# ContextExtractionStructuredOutput (file path plus start/end line per snippet).
SYS_PROMPT = """\
You are a context summary agent that summarizes code context that is relevant to a given query from history messages.
Your goal is to extract, evaluate and summary code context that directly answers the query requirements.

Your evaluation and summarization must consider two key aspects:
1. Query Match: Which set of history messages directly address specific requirements mentioned in the query?
2. Extended relevance: Which set of history messages provide essential information needed to understand the query topic?

Follow these strict evaluation steps:
1. First, identify specific requirements in the query
2. Check which set of history messages directly addresses these requirements
3. Check which parts of code context are relevant to the query
4. Consider if they provides essential context by examining:
- Function dependencies
- Type definitions
- Configuration requirements
- Implementation details needed for completeness

Query relevance guidelines - include only if:
- It directly implements functionality mentioned in the query
- It contains specific elements the query asks about
- It's necessary to understand or implement query requirements
- It provides critical information needed to answer the query

CRITICAL RULE:
- You don't have to select whole piece of code that you have seen, ONLY select the parts that are relevant to the query.
- Each context MUST be SHORT and CONCISE, focusing ONLY on the lines that are relevant to the query.
- Several context can be extracted from the same file, but each context must be concise and relevant to the query.
- Do NOT include any irrelevant lines or comments that do not contribute to answering the query.
- Do NOT include same context multiple times, even if it appears in different history messages.

Remember: Your primary goal is to summarize context that directly helps answer the query requirements.

Provide your analysis in a structured format matching the ContextExtractionStructuredOutput model.

Example output:
```json
{
"context": [{
"reasoning": "1. Query requirement analysis:\n - Query specifically asks about password validation\n - Context provides implementation details for password validation\n2. Extended relevance:\n - This function is essential for understanding how passwords are validated in the system",
"relative_path": "pychemia/code/fireball/fireball.py",
"start_line": 270,
"end_line": 293
} ......]
}
```

Your task is to summarize the context from the provided history messages and return it in the specified format.
"""


class ContextOutput(BaseModel):
    # One snippet location selected by the model: a file path plus a 1-based
    # line range. Documented with comments rather than a class docstring on
    # purpose: pydantic copies a model docstring into the generated schema,
    # which would alter what the structured-output LLM is shown.

    # Model-written justification for picking this snippet.
    reasoning: str = Field(
        description="Your step-by-step reasoning why the context is relevant to the query",
    )

    # File location, relative to the codebase root.
    relative_path: str = Field(
        description="Relative path to the context file in the codebase",
    )

    # First line of the snippet; validated to be >= 1.
    start_line: int = Field(
        ge=1,
        description="Start line number of the context in the file, minimum is 1",
    )

    # Last line of the snippet; validated to be >= 1.
    end_line: int = Field(
        ge=1,
        description=(
            "End line number of the context in the file, minimum is 1. "
            "The Content in the end line is including"
        ),
    )


class ContextExtractionStructuredOutput(BaseModel):
    # Schema the chat model is bound to via with_structured_output: the list of
    # snippet locations it selected. A comment is used instead of a docstring
    # so the generated schema stays unchanged.
    context: Sequence[ContextOutput]


class ContextExtractionNode:
    """Graph node that turns the provider's message history into code snippets.

    A structured-output chat model is asked which file ranges are relevant to
    the query. Each range is then read from disk with line numbers and wrapped
    in a ``Context`` object, merged with the contexts already in the state.
    """

    def __init__(self, model: BaseChatModel, root_path: str):
        """Bind the chat model to the extraction schema.

        Args:
            model: Chat model; wrapped with ``with_structured_output`` so that
                it returns a ``ContextExtractionStructuredOutput``.
            root_path: Root directory handed to ``read_file_with_line_numbers``
                together with each snippet's relative path.
        """
        self.model = model.with_structured_output(ContextExtractionStructuredOutput)
        self.system_prompt = SystemMessage(SYS_PROMPT)
        self.root_path = root_path
        self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_extraction_node")

    def __call__(self, state: ContextRetrievalState) -> dict:
        """Extract new contexts and merge them with the ones already collected.

        Args:
            state: Must contain ``"context_provider_messages"``; may contain
                ``"context"`` with previously extracted contexts.

        Returns:
            ``{"context": [...]}`` holding the existing contexts followed by
            the newly extracted ones, de-duplicated on
            ``(relative_path, start_line, end_line)`` with first occurrence
            kept. The list stored in ``state`` is never modified.
        """
        self._logger.info("Starting context extraction process")
        # Previously collected contexts; a missing key or explicit None means none yet.
        existing_context = state.get("context") or []
        # Chat history produced by the context provider.
        last_messages = state["context_provider_messages"]
        # Ask the model which file ranges are relevant.
        response = self.model.invoke([self.system_prompt] + last_messages)
        self._logger.debug(f"Model response: {response}")

        # Build a fresh list instead of appending to the list owned by the
        # state, so the caller's state is not mutated behind its back.
        seen = set()
        unique_context = []
        for ctx in existing_context:
            key = (ctx.relative_path, ctx.start_line_number, ctx.end_line_number)
            if key not in seen:
                seen.add(key)
                unique_context.append(ctx)

        for context_ in response.context:
            key = (context_.relative_path, context_.start_line, context_.end_line)
            if key in seen:
                # Same range already collected: skip the redundant file read.
                continue
            try:
                content = read_file_with_line_numbers(
                    relative_path=context_.relative_path,
                    root_path=str(self.root_path),
                    start_line=context_.start_line,
                    end_line=context_.end_line,
                )
            except FileOperationException as e:
                # Best effort: one unreadable range must not discard the rest.
                self._logger.error(e)
                continue
            if not content:
                continue
            seen.add(key)
            unique_context.append(
                Context(
                    relative_path=context_.relative_path,
                    start_line_number=context_.start_line,
                    end_line_number=context_.end_line,
                    content=content,
                )
            )

        self._logger.info(f"Context extraction complete, returning context {unique_context}")
        return {"context": unique_context}
9 changes: 3 additions & 6 deletions prometheus/lang_graph/nodes/context_provider_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ContextProviderNode:

SYS_PROMPT = """\
You are a context gatherer that searches a Neo4j knowledge graph representation of a
codebase. Your role is to efficiently find relevant code and documentation
codebase. Your role is to understand the logic of the project and efficiently find relevant code and documentation
context based on user queries.

Knowledge Graph Structure:
Expand All @@ -55,7 +55,7 @@ class ContextProviderNode:
- Prioritize relative_path tools when exact file location is known
- Fall back to basename tools for filename-only searches
- Use AST node searches to find specific code structures
- Use preview_* or read_* tools with more than hundrend lines to get more context than class/function
- Use preview_* or read_* tools with more than hundred lines to get more context than class/function
- If a search returns no results, try alternative approaches with broader scope

2. Documentation/Text Search:
Expand All @@ -72,11 +72,8 @@ class ContextProviderNode:

4. Critical Rules:
- Do not repeat the same query!
- Do not select a whole file or directory as context, but rather specific code snippets.
- Each context should be a small, focused piece of code or documentation that directly addresses the query, which must be less than 100 lines!
- But several contexts snippets can be selected if they are relevant to the query.

In your response, just provide a short summary with a few setences (3-4 setences) on what you have done.
In your response, just provide a short summary with a few sentences (3-4 sentences) on what you have done.
As your searched are automatically visible to the user, you do not need to repeat them.

The file tree of the codebase:
Expand Down
4 changes: 3 additions & 1 deletion prometheus/lang_graph/nodes/context_refine_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ class ContextRefineStructuredOutput(BaseModel):
"""

def __init__(self, model: BaseChatModel, kg: KnowledgeGraph):
file_tree = kg.get_file_tree().replace("{", "{{").replace("}", "}}")
system_prompt = self.SYS_PROMPT.format(file_tree=file_tree)
prompt = ChatPromptTemplate.from_messages(
[
("system", self.SYS_PROMPT.format(file_tree=kg.get_file_tree())),
("system", system_prompt),
("human", "{human_prompt}"),
]
)
Expand Down
112 changes: 0 additions & 112 deletions prometheus/lang_graph/nodes/context_selection_node.py

This file was deleted.

4 changes: 3 additions & 1 deletion prometheus/lang_graph/nodes/final_patch_selection_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def __call__(self, state: Dict):
human_prompt = self.format_human_message(state)
for try_index in range(self.max_retries):
response = self.model.invoke({"human_prompt": human_prompt})
self._logger.info(f"FinalPatchSelectionNode response at {try_index} try:\n{response}")
self._logger.info(
f"FinalPatchSelectionNode response at {try_index + 1} try:\n{response}"
)

if 0 <= response.patch_index < len(state["edit_patches"]):
return {"final_patch": state["edit_patches"][response.patch_index]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def invoke(
"issue_title": issue_title,
"issue_body": issue_body,
"issue_comments": issue_comments,
"max_refined_query_loop": 1,
"max_refined_query_loop": 3,
}

try:
Expand Down
14 changes: 7 additions & 7 deletions prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from langgraph.prebuilt import ToolNode, tools_condition

from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.context_extraction_node import ContextExtractionNode
from prometheus.lang_graph.nodes.context_provider_node import ContextProviderNode
from prometheus.lang_graph.nodes.context_query_message_node import ContextQueryMessageNode
from prometheus.lang_graph.nodes.context_refine_node import ContextRefineNode
from prometheus.lang_graph.nodes.context_selection_node import ContextSelectionNode
from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode
from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState
from prometheus.models.context import Context
Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(
messages_key="context_provider_messages",
)

# Step 4: Select relevant context snippets from the candidates
context_selection_node = ContextSelectionNode(model)
# Step 4: Extract the Context
context_extraction_node = ContextExtractionNode(model, str(kg.get_local_path()))

# Step 5: Reset tool messages to prepare for the next iteration (if needed)
reset_context_provider_messages_node = ResetMessagesNode("context_provider_messages")
Expand All @@ -85,7 +85,7 @@ def __init__(
workflow.add_node("context_query_message_node", context_query_message_node)
workflow.add_node("context_provider_node", context_provider_node)
workflow.add_node("context_provider_tools", context_provider_tools)
workflow.add_node("context_selection_node", context_selection_node)
workflow.add_node("context_extraction_node", context_extraction_node)
workflow.add_node(
"reset_context_provider_messages_node", reset_context_provider_messages_node
)
Expand All @@ -100,10 +100,10 @@ def __init__(
workflow.add_conditional_edges(
"context_provider_node",
functools.partial(tools_condition, messages_key="context_provider_messages"),
{"tools": "context_provider_tools", END: "context_selection_node"},
{"tools": "context_provider_tools", END: "context_extraction_node"},
)
workflow.add_edge("context_provider_tools", "context_provider_node")
workflow.add_edge("context_selection_node", "reset_context_provider_messages_node")
workflow.add_edge("context_extraction_node", "reset_context_provider_messages_node")
workflow.add_edge("reset_context_provider_messages_node", "context_refine_node")

# If refined_query is non-empty, loop back to provider; else terminate
Expand All @@ -129,7 +129,7 @@ def invoke(

Returns:
Dict with a single key:
- "context" (Sequence[str]): A list of selected context snippets relevant to the query.
- "context" (Sequence[Context]): A list of selected context snippets relevant to the query.
"""
config = {"recursion_limit": recursion_limit}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def invoke(
"issue_title": issue_title,
"issue_body": issue_body,
"issue_comments": issue_comments,
"max_refined_query_loop": 1,
"max_refined_query_loop": 3,
}

output_state = self.subgraph.invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def invoke(
"issue_body": issue_body,
"issue_comments": issue_comments,
"number_of_candidate_patch": number_of_candidate_patch,
"max_refined_query_loop": 3,
"max_refined_query_loop": 5,
}

output_state = self.subgraph.invoke(input_state, config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def invoke(
"run_existing_test": run_existing_test,
"reproduced_bug_file": reproduced_bug_file,
"reproduced_bug_commands": reproduced_bug_commands,
"max_refined_query_loop": 3,
"max_refined_query_loop": 5,
}

output_state = self.subgraph.invoke(input_state, config)
Expand Down
Loading