diff --git a/prometheus/app/api/main.py b/prometheus/app/api/main.py index d1184089..f0e3f815 100644 --- a/prometheus/app/api/main.py +++ b/prometheus/app/api/main.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from prometheus.app.api.routes import auth, issue, repository +from prometheus.app.api.routes import auth, invitation_code, issue, repository, user from prometheus.configuration.config import settings api_router = APIRouter() @@ -9,3 +9,7 @@ if settings.ENABLE_AUTHENTICATION: api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) + api_router.include_router( + invitation_code.router, prefix="/invitation-code", tags=["invitation_code"] + ) + api_router.include_router(user.router, prefix="/user", tags=["user"]) diff --git a/prometheus/app/api/routes/auth.py b/prometheus/app/api/routes/auth.py index cb55a352..1dd84dbe 100644 --- a/prometheus/app/api/routes/auth.py +++ b/prometheus/app/api/routes/auth.py @@ -1,9 +1,12 @@ from fastapi import APIRouter, Request -from prometheus.app.models.requests.auth import LoginRequest +from prometheus.app.models.requests.auth import CreateUserRequest, LoginRequest from prometheus.app.models.response.auth import LoginResponse from prometheus.app.models.response.response import Response +from prometheus.app.services.invitation_code_service import InvitationCodeService from prometheus.app.services.user_service import UserService +from prometheus.configuration.config import settings +from prometheus.exceptions.server_exception import ServerException router = APIRouter() @@ -28,3 +31,38 @@ def login(login_request: LoginRequest, request: Request) -> Response[LoginRespon password=login_request.password, ) return Response(data=LoginResponse(access_token=access_token)) + + +@router.post( + "/register/", + summary="Register a new user", + description="Register a new user with username, email, password and invitation code.", + response_description="Returns a success message upon successful registration", + response_model=Response, +) +def register(request: Request, create_user_request: CreateUserRequest) -> Response: + """ + Register a new user with username, email, password and invitation code. + Returns a success message upon successful registration. + """ + invitation_code_service: InvitationCodeService = request.app.state.service[ + "invitation_code_service" + ] + user_service: UserService = request.app.state.service["user_service"] + + # Check if the invitation code is valid + if not invitation_code_service.check_invitation_code(create_user_request.invitation_code): + raise ServerException(code=400, message="Invalid or expired invitation code") + + # Create the user + user_service.create_user( + username=create_user_request.username, + email=create_user_request.email, + password=create_user_request.password, + issue_credit=settings.DEFAULT_USER_ISSUE_CREDIT, + ) + + # Mark the invitation code as used + invitation_code_service.mark_code_as_used(create_user_request.invitation_code) + + return Response(message="User registered successfully") diff --git a/prometheus/app/api/routes/invitation_code.py b/prometheus/app/api/routes/invitation_code.py new file mode 100644 index 00000000..5f05ad69 --- /dev/null +++ b/prometheus/app/api/routes/invitation_code.py @@ -0,0 +1,57 @@ +from typing import Sequence + +from fastapi import APIRouter, Request + +from prometheus.app.decorators.require_login import requireLogin +from prometheus.app.entity.invitation_code import InvitationCode +from prometheus.app.models.response.response import Response +from prometheus.app.services.user_service import UserService +from prometheus.exceptions.server_exception import ServerException + +router = APIRouter() + + +@router.post( + "/create/", + summary="Create a new invitation code", + description="Generates a new invitation code for user registration.", + response_description="Returns the newly created invitation code", + response_model=Response[InvitationCode], +) +@requireLogin +def create_invitation_code(request: Request) -> Response[InvitationCode]: + """ + Create a new invitation code. + """ + # Check if the user is an admin + user_service: UserService = request.app.state.service["user_service"] + if not user_service.is_admin(request.state.user_id): + raise ServerException(code=403, message="Only admins can create invitation codes") + + # Create a new invitation code + invitation_code_service = request.app.state.service["invitation_code_service"] + invitation_code = invitation_code_service.create_invitation_code() + return Response(data=invitation_code) + + +@router.get( + "/list/", + summary="List all invitation codes", + description="Retrieves a list of all invitation codes.", + response_description="Returns a list of invitation codes", + response_model=Response[Sequence[InvitationCode]], +) +@requireLogin +def list_invitation_codes(request: Request) -> Response[Sequence[InvitationCode]]: + """ + List all invitation codes. + """ + # Check if the user is an admin + user_service: UserService = request.app.state.service["user_service"] + if not user_service.is_admin(request.state.user_id): + raise ServerException(code=403, message="Only admins can list invitation codes") + + # List all invitation codes + invitation_code_service = request.app.state.service["invitation_code_service"] + invitation_codes = invitation_code_service.list_invitation_codes() + return Response(data=invitation_codes) diff --git a/prometheus/app/api/routes/issue.py b/prometheus/app/api/routes/issue.py index 91115871..e0f26d32 100644 --- a/prometheus/app/api/routes/issue.py +++ b/prometheus/app/api/routes/issue.py @@ -26,7 +26,14 @@ ) @requireLogin async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueResponse]: + # Retrieve necessary services from the application state repository_service: RepositoryService = request.app.state.service["repository_service"] + user_service: UserService = request.app.state.service["user_service"] + issue_service: IssueService = request.app.state.service["issue_service"] + knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ + "knowledge_graph_service" + ] + # Fetch the repository by ID repository = repository_service.get_repository_by_id(issue.repository_id) # Ensure the repository exists @@ -36,17 +43,15 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR if settings.ENABLE_AUTHENTICATION and repository.user_id != request.state.user_id: raise ServerException(code=403, message="You do not have access to this repository") - # Deduct issue credit if authentication is enabled - user_service: UserService = request.app.state.service["user_service"] + # Check issue credit + user_issue_credit = None if settings.ENABLE_AUTHENTICATION: - # Check and deduct issue credit user_issue_credit = user_service.get_issue_credit(request.state.user_id) if user_issue_credit <= 0: raise ServerException( code=403, message="Insufficient issue credits. Please purchase more to continue.", ) - user_service.update_issue_credit(request.state.user_id, user_issue_credit - 1) # Validate Dockerfile and workdir inputs if issue.dockerfile_content or issue.image_name: @@ -62,10 +67,7 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR message="The repository is currently being used. Please try again later.", ) - knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ - "knowledge_graph_service" - ] - + # Load the git repository and knowledge graph git_repository = repository_service.get_repository(repository.playground_path) knowledge_graph = knowledge_graph_service.get_knowledge_graph( repository.kg_root_node_id, @@ -74,8 +76,7 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR repository.kg_chunk_overlap, ) - issue_service: IssueService = request.app.state.service["issue_service"] - + # Process the issue in a separate thread to avoid blocking the event loop ( patch, passed_reproducing_test, @@ -104,6 +105,8 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR build_commands=issue.build_commands, test_commands=issue.test_commands, ) + + # Check if all outputs are in their initial state, indicating a failure if ( patch, passed_reproducing_test, @@ -117,6 +120,12 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR code=500, message="Failed to process the issue. Please try again later.", ) + + # Deduct issue credit after successful processing + if settings.ENABLE_AUTHENTICATION: + user_service.update_issue_credit(request.state.user_id, user_issue_credit - 1) + + # Return the response return Response( data=IssueResponse( patch=patch, diff --git a/prometheus/app/api/routes/repository.py b/prometheus/app/api/routes/repository.py index bbf5956d..898f7153 100644 --- a/prometheus/app/api/routes/repository.py +++ b/prometheus/app/api/routes/repository.py @@ -57,28 +57,40 @@ async def upload_github_repository( ): # Get the repository and knowledge graph services repository_service: RepositoryService = request.app.state.service["repository_service"] - repository = repository_service.get_repository_by_url_and_commit_id( - upload_repository_request.https_url, commit_id=upload_repository_request.commit_id - ) + knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ + "knowledge_graph_service" + ] + + # Check if the repository already exists if settings.ENABLE_AUTHENTICATION: - if repository and request.state.user_id == repository.user_id: - return Response( - message="Repository already exists", data={"repository_id": repository.id} - ) + repository = repository_service.get_repository_by_url_commit_id_and_user_id( + upload_repository_request.https_url, + upload_repository_request.commit_id, + request.state.user_id, + ) else: - if repository: - # If the repository already exists, return its ID - return Response( - message="Repository already exists", data={"repository_id": repository.id} + repository = repository_service.get_repository_by_url_and_commit_id( + upload_repository_request.https_url, commit_id=upload_repository_request.commit_id + ) + + # If the repository already exists, return its ID + if repository: + return Response(message="Repository already exists", data={"repository_id": repository.id}) + + # Check if the number of repositories exceeds the limit + if settings.ENABLE_AUTHENTICATION: + user_repositories = repository_service.get_repositories_by_user_id(request.state.user_id) + if len(user_repositories) >= settings.DEFAULT_USER_REPOSITORY_LIMIT: + raise ServerException( + code=400, + message=f"You have reached the maximum number of repositories ({settings.DEFAULT_USER_REPOSITORY_LIMIT}). Please delete some repositories before uploading new ones.", ) - knowledge_graph_service: KnowledgeGraphService = request.app.state.service[ - "knowledge_graph_service" - ] + # Get the GitHub token github_token = get_github_token(request, upload_repository_request.github_token) + # Clone the repository try: - # Clone the repository saved_path = await repository_service.clone_github_repo( github_token, upload_repository_request.https_url, upload_repository_request.commit_id ) @@ -86,6 +98,7 @@ async def upload_github_repository( raise ServerException( code=400, message=f"Unable to clone {upload_repository_request.https_url}." ) + # Build and save the knowledge graph from the cloned repository root_node_id = await knowledge_graph_service.build_and_save_knowledge_graph(saved_path) repository_id = repository_service.create_new_repository( diff --git a/prometheus/app/api/routes/user.py b/prometheus/app/api/routes/user.py new file mode 100644 index 00000000..5e7c6382 --- /dev/null +++ b/prometheus/app/api/routes/user.py @@ -0,0 +1,54 @@ +from typing import Sequence + +from fastapi import APIRouter, Request + +from prometheus.app.decorators.require_login import requireLogin +from prometheus.app.entity.user import User +from prometheus.app.models.requests.user import SetGithubTokenRequest +from prometheus.app.models.response.response import Response +from prometheus.app.models.response.user import UserResponse +from prometheus.app.services.user_service import UserService +from prometheus.exceptions.server_exception import ServerException + +router = APIRouter() + + +@router.get( + "/list/", + summary="List all users in the database", + description="Retrieves a list of all users.", + response_description="Returns a list of users", + response_model=Response[Sequence[UserResponse]], +) +@requireLogin +def list_users(request: Request) -> Response[Sequence[User]]: + """ + List all users in the database. + """ + # Check if the user is an admin + user_service: UserService = request.app.state.service["user_service"] + if not user_service.is_admin(request.state.user_id): + raise ServerException(code=403, message="Only admins can list users") + + # List all users + users = user_service.list_users() + return Response(data=[UserResponse.model_validate(user) for user in users]) + + +@router.put( + "/set-github-token/", + summary="Set GitHub token for the user", + description="Sets the GitHub token for the authenticated user.", + response_description="Returns the updated user information", + response_model=Response, +) +@requireLogin +def set_github_token(request: Request, set_github_token_request: SetGithubTokenRequest) -> Response: + """ + Set GitHub token for the user. + """ + user_service: UserService = request.app.state.service["user_service"] + + # Update the user's GitHub token + user_service.set_github_token(request.state.user_id, set_github_token_request.github_token) + return Response() diff --git a/prometheus/app/dependencies.py b/prometheus/app/dependencies.py index 22042330..972226c3 100644 --- a/prometheus/app/dependencies.py +++ b/prometheus/app/dependencies.py @@ -2,6 +2,7 @@ from prometheus.app.services.base_service import BaseService from prometheus.app.services.database_service import DatabaseService +from prometheus.app.services.invitation_code_service import InvitationCodeService from prometheus.app.services.issue_service import IssueService from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService from prometheus.app.services.llm_service import LLMService @@ -69,6 +70,7 @@ def initialize_services() -> dict[str, BaseService]: ) user_service = UserService(database_service) + invitation_code_service = InvitationCodeService(database_service) return { "neo4j_service": neo4j_service, @@ -78,4 +80,5 @@ def initialize_services() -> dict[str, BaseService]: "issue_service": issue_service, "database_service": database_service, "user_service": user_service, + "invitation_code_service": invitation_code_service, } diff --git a/prometheus/app/entity/invitation_code.py b/prometheus/app/entity/invitation_code.py new file mode 100644 index 00000000..a0790cb5 --- /dev/null +++ b/prometheus/app/entity/invitation_code.py @@ -0,0 +1,19 @@ +from datetime import datetime, timedelta, timezone + +from sqlmodel import Field, SQLModel + +from prometheus.configuration.config import settings + + +class InvitationCode(SQLModel, table=True): + """ + InvitationCode model for managing invitation codes. + """ + + id: int = Field(primary_key=True, description="ID") + code: str = Field(index=True, unique=True, max_length=36, description="Invitation code") + is_used: bool = Field(default=False, description="Whether the invitation code has been used") + expiration_time: datetime = Field( + default=datetime.now(timezone.utc) + timedelta(days=settings.INVITATION_CODE_EXPIRE_TIME), + description="Expiration time of the invitation code", + ) diff --git a/prometheus/app/models/requests/auth.py b/prometheus/app/models/requests/auth.py index 29c0fc9d..28d7c2cd 100644 --- a/prometheus/app/models/requests/auth.py +++ b/prometheus/app/models/requests/auth.py @@ -19,6 +19,11 @@ class LoginRequest(BaseModel): @field_validator("email", mode="after") def validate_email_format(cls, v: str) -> str: + # Allow empty email + if not v: + return v + + # Simple regex for email validation pattern = r"^[^@\s]+@[^@\s]+\.[^@\s]+$" if not re.match(pattern, v): raise ValueError("Invalid email format") @@ -29,3 +34,31 @@ def check_username_or_email(self) -> "LoginRequest": if not self.username and not self.email: raise ValueError("At least one of 'username' or 'email' must be provided.") return self + + +class CreateUserRequest(BaseModel): + username: str = Field(description="username of the user", max_length=20) + email: str = Field( + description="email of the user", + examples=["your_email@gmail.com"], + max_length=30, + ) + password: str = Field( + description="password of the user", + examples=["P@ssw0rd!"], + min_length=8, + max_length=30, + ) + invitation_code: str = Field( + description="invitation code for registration", + examples=["abcd-efgh-ijkl-mnop"], + max_length=36, + min_length=36, + ) + + @field_validator("email", mode="after") + def validate_email_format(cls, v: str) -> str: + pattern = r"^[^@\s]+@[^@\s]+\.[^@\s]+$" + if not re.match(pattern, v): + raise ValueError("Invalid email format") + return v diff --git a/prometheus/app/models/requests/user.py b/prometheus/app/models/requests/user.py index 6a0e3144..171f14a7 100644 --- a/prometheus/app/models/requests/user.py +++ b/prometheus/app/models/requests/user.py @@ -1,26 +1,5 @@ -import re +from pydantic import BaseModel, Field -from pydantic import BaseModel, Field, field_validator - -class CreateUserRequest(BaseModel): - username: str = Field(description="username of the user", max_length=20) - email: str = Field( - description="email of the user", - examples=["your_email@gmail.com"], - max_length=30, - ) - password: str = Field( - description="password of the user", - examples=["P@ssw0rd!"], - min_length=12, - max_length=30, - ) - github_token: str = Field(description="github token of the user", max_length=100) - - @field_validator("email", mode="after") - def validate_email_format(self, v: str) -> str: - pattern = r"^[^@\s]+@[^@\s]+\.[^@\s]+$" - if not re.match(pattern, v): - raise ValueError("Invalid email format") - return v +class SetGithubTokenRequest(BaseModel): + github_token: str = Field(description="GitHub token of the user", max_length=100) diff --git a/prometheus/app/models/response/user.py b/prometheus/app/models/response/user.py new file mode 100644 index 00000000..0b596c1c --- /dev/null +++ b/prometheus/app/models/response/user.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + + +class UserResponse(BaseModel): + """ + Response model for a user. + """ + + model_config = { + "from_attributes": True, + } + + id: int + username: str + email: str + issue_credit: int + is_superuser: bool diff --git a/prometheus/app/services/invitation_code_service.py b/prometheus/app/services/invitation_code_service.py new file mode 100644 index 00000000..45ee2fd9 --- /dev/null +++ b/prometheus/app/services/invitation_code_service.py @@ -0,0 +1,76 @@ +import datetime +import logging +import uuid +from typing import Sequence + +from sqlmodel import Session, select + +from prometheus.app.entity.invitation_code import InvitationCode +from prometheus.app.services.base_service import BaseService +from prometheus.app.services.database_service import DatabaseService + + +class InvitationCodeService(BaseService): + def __init__(self, database_service: DatabaseService): + self.database_service = database_service + self.engine = database_service.engine + self._logger = logging.getLogger("prometheus.app.services.invitation_code_service") + + def create_invitation_code(self) -> InvitationCode: + """ + Create a new invitation code and commit it to the database. + + Returns: + InvitationCode: The created invitation code instance. + """ + + with Session(self.engine) as session: + code = str(uuid.uuid4()) + invitation_code = InvitationCode(code=code) + session.add(invitation_code) + session.commit() + session.refresh(invitation_code) + return invitation_code + + def list_invitation_codes(self) -> Sequence[InvitationCode]: + """ + List all invitation codes from the database. + + Returns: + Sequence[InvitationCode]: A list of all invitation code instances. + """ + with Session(self.engine) as session: + statement = select(InvitationCode) + return session.exec(statement).all() + + def check_invitation_code(self, code: str) -> bool: + """ + Check if an invitation code is valid (exists, not used and not expired). + """ + with Session(self.engine) as session: + statement = select(InvitationCode).where(InvitationCode.code == code) + invitation_code = session.exec(statement).first() + if not invitation_code: + return False + if invitation_code.is_used: + return False + + exp = invitation_code.expiration_time + # If our database returned a naive datetime, assume it's UTC + if exp.tzinfo is None: + exp = exp.replace(tzinfo=datetime.timezone.utc) + if exp < datetime.datetime.now(datetime.timezone.utc): + return False + return True + + def mark_code_as_used(self, code: str) -> None: + """ + Mark an invitation code as used. + """ + with Session(self.engine) as session: + statement = select(InvitationCode).where(InvitationCode.code == code) + invitation_code = session.exec(statement).first() + if invitation_code: + invitation_code.is_used = True + session.add(invitation_code) + session.commit() diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index 53ad5a27..d1a2f177 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -153,6 +153,5 @@ def answer_issue( return None, False, False, False, False, None, None finally: self.repository_service.update_repository_status(repository_id, is_working=False) - repository.reset_repository() logger.removeHandler(file_handler) file_handler.close() diff --git a/prometheus/app/services/repository_service.py b/prometheus/app/services/repository_service.py index df5e538e..22caa12d 100644 --- a/prometheus/app/services/repository_service.py +++ b/prometheus/app/services/repository_service.py @@ -150,6 +150,28 @@ def get_repository_by_url_and_commit_id(self, url: str, commit_id: str) -> Optio ) return session.exec(statement).first() + def get_repository_by_url_commit_id_and_user_id( + self, url: str, commit_id: str, user_id: int + ) -> Optional[Repository]: + """ + Retrieves a repository by its URL commit ID and User ID. + + Args: + url: The URL of the repository. + commit_id: The commit ID of the repository. + user_id: The user ID of the repository. + + Returns: + The Repository instance if found, otherwise None. + """ + with Session(self.engine) as session: + statement = select(Repository).where( + Repository.url == url, + Repository.commit_id == commit_id, + Repository.user_id == user_id, + ) + return session.exec(statement).first() + def update_repository_status(self, repository_id: int, is_working: bool): """ Updates the working status of a repository. diff --git a/prometheus/app/services/user_service.py b/prometheus/app/services/user_service.py index 92d5c628..bdd4091a 100644 --- a/prometheus/app/services/user_service.py +++ b/prometheus/app/services/user_service.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Optional, Sequence from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError @@ -45,10 +45,10 @@ def create_user( with Session(self.engine) as session: statement = select(User).where(User.username == username) if session.exec(statement).first(): - raise ValueError(f"Username '{username}' already exists") + raise ServerException(400, f"Username '{username}' already exists") statement = select(User).where(User.email == email) if session.exec(statement).first(): - raise ValueError(f"Email '{email}' already exists") + raise ServerException(400, f"Email '{email}' already exists") hashed_password = self.ph.hash(password) @@ -150,3 +150,34 @@ def update_issue_credit(self, user_id: int, new_issue_credit) -> None: user.issue_credit = new_issue_credit session.add(user) session.commit() + + def is_admin(self, user_id): + """ + Check if a user is an admin (superuser) by their ID. + """ + with Session(self.engine) as session: + statement = select(User).where(User.id == user_id) + user = session.exec(statement).first() + return user.is_superuser if user else False + + def list_users(self) -> Sequence[User]: + """ + List all users in the database. + """ + with Session(self.engine) as session: + statement = select(User) + users = session.exec(statement).all() + return users + + def set_github_token(self, user_id: int, github_token: str): + """ + Set GitHub token for a user by their ID. + """ + with Session(self.engine) as session: + statement = select(User).where(User.id == user_id) + user = session.exec(statement).first() + if user: + user.github_token = github_token + session.add(user) + session.commit() + session.refresh(user) diff --git a/prometheus/configuration/config.py b/prometheus/configuration/config.py index ad8ca34e..c754f47c 100644 --- a/prometheus/configuration/config.py +++ b/prometheus/configuration/config.py @@ -56,7 +56,16 @@ class Settings(BaseSettings): # JWT Configuration JWT_SECRET_KEY: str - ACCESS_TOKEN_EXPIRE_TIME: int = 7 # days + ACCESS_TOKEN_EXPIRE_TIME: int = 30 # days + + # Invitation Code Expire Time + INVITATION_CODE_EXPIRE_TIME: int = 14 # days + + # Default normal user issue credit + DEFAULT_USER_ISSUE_CREDIT: int = 20 + + # Default normal user repository number + DEFAULT_USER_REPOSITORY_LIMIT: int = 5 settings = Settings() diff --git a/prometheus/lang_graph/graphs/issue_graph.py b/prometheus/lang_graph/graphs/issue_graph.py index 47143941..1f4a8e23 100644 --- a/prometheus/lang_graph/graphs/issue_graph.py +++ b/prometheus/lang_graph/graphs/issue_graph.py @@ -12,6 +12,7 @@ from prometheus.lang_graph.nodes.issue_classification_subgraph_node import ( IssueClassificationSubgraphNode, ) +from prometheus.lang_graph.nodes.issue_question_subgraph_node import IssueQuestionSubgraphNode from prometheus.lang_graph.nodes.noop_node import NoopNode @@ -39,7 +40,8 @@ def __init__( # Entrance point for the issue handling workflow issue_type_branch_node = NoopNode() - # Subgraph nodes for issue classification and bug handling + + # Subgraph nodes for issue classification issue_classification_subgraph_node = IssueClassificationSubgraphNode( model=base_model, kg=kg, @@ -47,6 +49,8 @@ def __init__( neo4j_driver=neo4j_driver, max_token_per_neo4j_result=max_token_per_neo4j_result, ) + + # Subgraph node for handling bug issues issue_bug_subgraph_node = IssueBugSubgraphNode( advanced_model=advanced_model, base_model=base_model, @@ -58,12 +62,24 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) + + # Subgraph node for handling question issues + issue_question_subgraph_node = IssueQuestionSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + # Create the state graph for the issue handling workflow workflow = StateGraph(IssueState) # Add nodes to the workflow workflow.add_node("issue_type_branch_node", issue_type_branch_node) workflow.add_node("issue_classification_subgraph_node", issue_classification_subgraph_node) workflow.add_node("issue_bug_subgraph_node", issue_bug_subgraph_node) + workflow.add_node("issue_question_subgraph_node", issue_question_subgraph_node) # Set the entry point for the workflow workflow.set_entry_point("issue_type_branch_node") # Define the edges and conditions for the workflow @@ -76,7 +92,7 @@ def __init__( IssueType.BUG: "issue_bug_subgraph_node", IssueType.FEATURE: END, IssueType.DOCUMENTATION: END, - IssueType.QUESTION: END, + IssueType.QUESTION: "issue_question_subgraph_node", }, ) # Add edges for the issue classification subgraph @@ -87,11 +103,12 @@ def __init__( IssueType.BUG: "issue_bug_subgraph_node", IssueType.FEATURE: END, IssueType.DOCUMENTATION: END, - IssueType.QUESTION: END, + IssueType.QUESTION: "issue_question_subgraph_node", }, ) # Add edges for ending the workflow workflow.add_edge("issue_bug_subgraph_node", END) + workflow.add_edge("issue_question_subgraph_node", END) self.graph = workflow.compile() diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 35f83eda..8ec73943 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -162,6 +162,10 @@ def __call__(self, state: IssueNotVerifiedBugState): if not patches: self._logger.warning("No candidate patches available for selection.") return {"final_patch": ""} + # Handle the case with only one candidate patch + elif len(patches) == 1: + self._logger.info("Only one candidate patch available, selecting it by default.") + return {"final_patch": patches[0]} # Formalize Human Message human_prompt = self.format_human_message(patches, state) diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index 3256c6b7..de4f6e7f 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -55,13 +55,14 @@ def __call__(self, state: Dict): ) except GraphRecursionError: self._logger.debug("GraphRecursionError encountered, returning empty patch") - self.git_repo.reset_repository() return { "edit_patch": None, "passed_reproducing_test": False, "passed_build": False, "passed_existing_test": False, } + finally: + self.git_repo.reset_repository() self._logger.info(f"final_patch:\n{output_state['final_patch']}") diff --git a/prometheus/lang_graph/nodes/issue_question_analyzer_node.py b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py new file mode 100644 index 00000000..21bf01dd --- /dev/null +++ b/prometheus/lang_graph/nodes/issue_question_analyzer_node.py @@ -0,0 +1,68 @@ +import logging +import threading + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage, SystemMessage + +from prometheus.lang_graph.subgraphs.issue_question_state import IssueQuestionState +from prometheus.utils.issue_util import format_issue_info + + +class IssueQuestionAnalyzerNode: + SYS_PROMPT = """ +You are an expert software engineer specializing in analysis and answering issue. Your role is to: + +1. Carefully analyze reported software issues and question by: + - Understanding issue descriptions and symptoms + - Identifying related code components + +2. Answer the question through systematic investigation: + - Identify which specific code elements are related to the question + - Understand the context and interactions related to the question or issue + +3. Provide high-level answer suggestions step by step + +Important: +- You may provide actual code snippets or diffs if necessary +- Keep descriptions precise and actionable + +Communicate in a clear, technical manner focused on accurate analysis and practical suggestions +rather than implementation details. +""" + HUMAN_PROMPT = """ + Here is a Github issue description: + -- BEGIN ISSUE -- + {issue_info} + -- END ISSUE -- + + Here is the relevant code context and documentation needed to understand and answer this issue: + --- BEGIN CONTEXT -- + {question_context} + --- END CONTEXT -- + + Based on the above information, please provide a detailed answer to the question. + """ + + def __init__(self, model: BaseChatModel): + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.model = model + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_analyzer_node" + ) + + def __call__(self, state: IssueQuestionState): + human_prompt = HumanMessage( + self.HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + question_context="\n\n".join( + [str(context) for context in state["question_context"]] + ), + ) + ) + message_history = [self.system_prompt, human_prompt] + response = self.model.invoke(message_history) + + self._logger.debug(response) + return {"question_response": response.content} diff --git a/prometheus/lang_graph/nodes/issue_question_context_message_node.py b/prometheus/lang_graph/nodes/issue_question_context_message_node.py new file mode 100644 index 00000000..b9546398 --- /dev/null +++ b/prometheus/lang_graph/nodes/issue_question_context_message_node.py @@ -0,0 +1,35 @@ +import logging +import threading +from typing import Dict + +from prometheus.utils.issue_util import format_issue_info + + +class IssueQuestionContextMessageNode: + QUESTION_QUERY = """\ +{issue_info} + +Find all relevant source code context and documentation needed to understand and answer this issue. +Focus on both production code (ignore test files) and documentations (e.g. README.md) and follow these steps: +1. Identify key components mentioned in the issue (functions, classes, types, etc.) +2. Find their complete implementations and class definitions +3. Include related code from the same module that affects the behavior +4. Follow imports to find dependent code that directly impacts the issue +5. Include relevant documentation that helps understand the issue + +Skip any test files +""" + + def __init__(self): + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_context_message_node" + ) + + def __call__(self, state: Dict): + question_query = self.QUESTION_QUERY.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + ) + self._logger.debug(f"Sending query to context provider:\n{question_query}") + return {"question_query": question_query} diff --git a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py new file mode 100644 index 00000000..451dd185 --- /dev/null +++ b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py @@ -0,0 +1,71 @@ +import logging +import threading + +import neo4j +from langchain_core.language_models.chat_models import BaseChatModel +from langgraph.errors import GraphRecursionError + +from prometheus.git.git_repository import GitRepository +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.lang_graph.graphs.issue_state import IssueState +from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph + + +class IssueQuestionSubgraphNode: + """ + A LangGraph node that handles the issue question subgraph, which is responsible for answering question in a GitHub issue. + """ + + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + self._logger = logging.getLogger( + f"thread-{threading.get_ident()}.prometheus.lang_graph.nodes.issue_question_subgraph_node" + ) + self.issue_question_subgraph = IssueQuestionSubgraph( + advanced_model=advanced_model, + base_model=base_model, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + + def __call__(self, state: IssueState): + # Logging entry into the node + self._logger.info("Enter IssueQuestionSubgraphNode") + + try: + output_state = self.issue_question_subgraph.invoke( + issue_title=state["issue_title"], + issue_body=state["issue_body"], + issue_comments=state["issue_comments"], + ) + except GraphRecursionError: + # Handle recursion error gracefully + self._logger.critical("Please increase the recursion limit of IssueQuestionSubgraph") + return { + "edit_patch": None, + "passed_reproducing_test": False, + "passed_build": False, + "passed_regression_test": False, + "passed_existing_test": False, + "issue_response": None, + } + + # Logging the issue response for debugging + self._logger.info(f"issue_response:\n{output_state['issue_response']}") + return { + "edit_patch": output_state["edit_patch"], + "passed_reproducing_test": output_state["passed_reproducing_test"], + "passed_build": output_state["passed_build"], + "passed_regression_test": output_state["passed_regression_test"], + "passed_existing_test": output_state["passed_existing_test"], + "issue_response": output_state["issue_response"], + } diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index bed44acf..cf214559 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -65,13 +65,14 @@ def __call__(self, state: IssueBugState): ) except GraphRecursionError: self._logger.info("Recursion limit reached") - self.git_repo.reset_repository() return { "edit_patch": None, "passed_reproducing_test": False, "passed_build": False, "passed_existing_test": False, } + finally: + self.git_repo.reset_repository() # if all the tests passed passed_reproducing_test = not bool(output_state["reproducing_test_fail_log"]) # if the build passed diff --git a/prometheus/lang_graph/subgraphs/issue_question_state.py b/prometheus/lang_graph/subgraphs/issue_question_state.py new file mode 100644 index 00000000..a434599f --- /dev/null +++ b/prometheus/lang_graph/subgraphs/issue_question_state.py @@ -0,0 +1,16 @@ +from typing import Mapping, Sequence, TypedDict + +from prometheus.models.context import Context + + +class IssueQuestionState(TypedDict): + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] + + max_refined_query_loop: int + + question_query: str + question_context: Sequence[Context] + + question_response: str diff --git a/prometheus/lang_graph/subgraphs/issue_question_subgraph.py b/prometheus/lang_graph/subgraphs/issue_question_subgraph.py new file mode 100644 index 00000000..feeb85ce --- /dev/null +++ b/prometheus/lang_graph/subgraphs/issue_question_subgraph.py @@ -0,0 +1,91 @@ +from typing import Mapping, Sequence + +import neo4j +from langchain_core.language_models.chat_models import BaseChatModel +from langgraph.constants import END +from langgraph.graph import StateGraph + +from prometheus.git.git_repository import GitRepository +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode +from prometheus.lang_graph.nodes.issue_question_analyzer_node import IssueQuestionAnalyzerNode +from prometheus.lang_graph.nodes.issue_question_context_message_node import ( + IssueQuestionContextMessageNode, +) +from prometheus.lang_graph.subgraphs.issue_question_state import IssueQuestionState + + +class IssueQuestionSubgraph: + """ + A LangGraph-based subgraph to analyze and answer questions related to GitHub issues. + This subgraph processes issue details, retrieves relevant context, and generates a comprehensive response. + """ + + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + # Step 1: Retrieve relevant context based on the issue details + issue_question_context_message_node = IssueQuestionContextMessageNode() + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + model=base_model, + kg=kg, + local_path=git_repo.playground_path, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + query_key_name="question_query", + context_key_name="question_context", + ) + + # Step 2: Analyze the issue and retrieved context to generate a response + issue_question_analyzer_node = IssueQuestionAnalyzerNode(model=advanced_model) + + # Define the subgraph structure + workflow = StateGraph(IssueQuestionState) + workflow.add_node( + "issue_question_context_message_node", issue_question_context_message_node + ) + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + workflow.add_node("issue_question_analyzer_node", issue_question_analyzer_node) + + # Define the entry point + workflow.set_entry_point("issue_question_context_message_node") + + # Define the workflow transitions + workflow.add_edge("issue_question_context_message_node", "context_retrieval_subgraph_node") + workflow.add_edge("context_retrieval_subgraph_node", "issue_question_analyzer_node") + workflow.add_edge("issue_question_analyzer_node", END) + + # Compile the workflow into an executable subgraph + self.subgraph = workflow.compile() + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + recursion_limit: int = 30, + ): + config = {"recursion_limit": recursion_limit} + + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "max_refined_query_loop": 3, + } + + output_state = self.subgraph.invoke(input_state, config) + return { + "edit_patch": None, + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + "passed_regression_test": False, + "issue_response": output_state["question_response"], + } diff --git a/tests/app/api/test_auth.py b/tests/app/api/test_auth.py index ec710ce7..a6aefc9b 100644 --- a/tests/app/api/test_auth.py +++ b/tests/app/api/test_auth.py @@ -36,3 +36,18 @@ def test_login(mock_service): "message": "success", "data": {"access_token": "your_access_token"}, } + + +def test_register(mock_service): + mock_service["invitation_code_service"].check_invitation_code.return_value = True + response = client.post( + "/auth/register", + json={ + "username": "testuser", + "email": "test@gmail.com", + "password": "passwordpassword", + "invitation_code": "f23ee204-ff33-401d-8291-1f128d0db08a", + }, + ) + assert response.status_code == 200 + assert response.json() == {"code": 200, "message": "User registered successfully", "data": None} diff --git a/tests/app/api/test_invitation_code.py b/tests/app/api/test_invitation_code.py new file mode 100644 index 00000000..448abe1f --- /dev/null +++ b/tests/app/api/test_invitation_code.py @@ -0,0 +1,87 @@ +import datetime +from unittest import mock + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from prometheus.app.api.routes import invitation_code +from prometheus.app.entity.invitation_code import InvitationCode +from prometheus.app.exception_handler import register_exception_handlers + +app = FastAPI() +register_exception_handlers(app) +app.include_router(invitation_code.router, prefix="/invitation-code", tags=["invitation_code"]) + + +@app.middleware("mock_jwt_middleware") +async def add_user_id(request: Request, call_next): + request.state.user_id = 1 # Set user_id to 1 for testing purposes + response = await call_next(request) + return response + + +client = TestClient(app) + + +@pytest.fixture +def mock_service(): + service = mock.MagicMock() + app.state.service = service + yield service + + +def test_create_invitation_code(mock_service): + # Mock the return value of create_invitation_code + mock_service["invitation_code_service"].create_invitation_code.return_value = InvitationCode( + id=1, + code="testcode", + is_used=False, + expiration_time=datetime.datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0), + ) + mock_service["user_service"].is_admin.return_value = True + + # Test the creation endpoint + response = client.post("invitation-code/create/") + assert response.status_code == 200 + assert response.json() == { + "code": 200, + "message": "success", + "data": { + "id": 1, + "code": "testcode", + "is_used": False, + "expiration_time": "2025-01-01T00:00:00", + }, + } + + +def test_list(mock_service): + # Mock user as admin and return a list of invitation codes + mock_service["invitation_code_service"].list_invitation_codes.return_value = [ + InvitationCode( + id=1, + code="testcode", + is_used=False, + expiration_time=datetime.datetime( + year=2025, month=1, day=1, hour=0, minute=0, second=0 + ), + ) + ] + mock_service["user_service"].is_admin.return_value = True + + # Test the list endpoint + response = client.get("invitation-code/list/") + assert response.status_code == 200 + assert response.json() == { + "code": 200, + "message": "success", + "data": [ + { + "id": 1, + "code": "testcode", + "is_used": False, + "expiration_time": "2025-01-01T00:00:00", + } + ], + } diff --git a/tests/app/api/test_user.py b/tests/app/api/test_user.py new file mode 100644 index 00000000..4c59bb5c --- /dev/null +++ b/tests/app/api/test_user.py @@ -0,0 +1,79 @@ +from unittest import mock + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from prometheus.app.api.routes import user +from prometheus.app.entity.user import User +from prometheus.app.exception_handler import register_exception_handlers + +app = FastAPI() +register_exception_handlers(app) +app.include_router(user.router, prefix="/user", tags=["user"]) + + +@app.middleware("mock_jwt_middleware") +async def add_user_id(request: Request, call_next): + request.state.user_id = 1 # Set user_id to 1 for testing purposes + response = await call_next(request) + return response + + +client = TestClient(app) + + +@pytest.fixture +def mock_service(): + service = mock.MagicMock() + app.state.service = service + yield service + + +def test_list(mock_service): + # Mock user as admin and return a list of users + mock_service["user_service"].list_users.return_value = [ + User( + id=1, + username="testuser", + email="test@gmail.com", + password_hash="hashedpassword", + github_token="ghp_1234567890abcdef1234567890abcdef1234", + issue_credit=10, + is_superuser=False, + ) + ] + mock_service["user_service"].is_admin.return_value = True + + # Test the list endpoint + response = client.get("user/list/") + assert response.status_code == 200 + assert response.json() == { + "code": 200, + "message": "success", + "data": [ + { + "id": 1, + "username": "testuser", + "email": "test@gmail.com", + "issue_credit": 10, + "is_superuser": False, + } + ], + } + + +def test_set_github_token(mock_service): + # Mock user as admin and return a list of users + mock_service["user_service"].set_github_token.return_value = None + + # Test the list endpoint + response = client.put( + "user/set-github-token/", json={"github_token": "ghp_1234567890abcdef1234567890abcdef1234"} + ) + assert response.status_code == 200 + assert response.json() == { + "code": 200, + "message": "success", + "data": None, + } diff --git a/tests/app/services/test_invitation_code.py b/tests/app/services/test_invitation_code.py new file mode 100644 index 00000000..97ca8d23 --- /dev/null +++ b/tests/app/services/test_invitation_code.py @@ -0,0 +1,124 @@ +from datetime import datetime, timedelta, timezone + +import pytest +from sqlmodel import Session, SQLModel + +from prometheus.app.entity.invitation_code import InvitationCode +from prometheus.app.services.database_service import DatabaseService +from prometheus.app.services.invitation_code_service import InvitationCodeService +from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401 + + +@pytest.fixture +def mock_database_service(postgres_container_fixture): # noqa: F811 + """Fixture: provide a clean DatabaseService using the Postgres test container.""" + service = DatabaseService(postgres_container_fixture.get_connection_url()) + service.start() + # Initialize schema + SQLModel.metadata.create_all(service.engine) + yield service + service.close() + + +@pytest.fixture +def service(mock_database_service): + """Fixture: construct an InvitationCodeService with the database service.""" + return InvitationCodeService(database_service=mock_database_service) + + +def _insert_code( + session: Session, code: str, is_used: bool = False, expires_in_seconds: int = 3600 +) -> InvitationCode: + """Helper: insert a single InvitationCode with given state and expiration.""" + obj = InvitationCode( + code=code, + is_used=is_used, + expiration_time=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds), + ) + session.add(obj) + session.commit() + session.refresh(obj) + return obj + + +def test_create_invitation_code(service): + """Test that create_invitation_code correctly generates and returns an InvitationCode.""" + invitation_code = service.create_invitation_code() + + # Verify the returned object is an InvitationCode instance + assert isinstance(invitation_code, InvitationCode) + assert isinstance(invitation_code.code, str) + assert len(invitation_code.code) == 36 # uuid4 string length + assert invitation_code.id is not None + + # Verify the object is persisted in the database + with Session(service.engine) as session: + db_obj = session.get(InvitationCode, invitation_code.id) + assert db_obj is not None + assert db_obj.code == invitation_code.code + + +def test_list_invitation_codes(service): + """Test that list_invitation_codes returns all stored invitation codes.""" + # Insert two invitation codes first + code1 = service.create_invitation_code() + code2 = service.create_invitation_code() + + codes = service.list_invitation_codes() + + # Verify length + assert len(codes) >= 2 + # Verify both created codes are included + all_codes = [c.code for c in codes] + assert code1.code in all_codes + assert code2.code in all_codes + + +def test_check_invitation_code_returns_false_when_not_exists(service): + """check_invitation_code should return False if the code does not exist.""" + ok = service.check_invitation_code("non-existent-code") + assert ok is False + + +def test_check_invitation_code_returns_false_when_used(service): + """check_invitation_code should return False if the code is already used.""" + with Session(service.engine) as session: + _insert_code(session, "used-code", is_used=True, expires_in_seconds=3600) + + ok = service.check_invitation_code("used-code") + assert ok is False + + +def test_check_invitation_code_returns_false_when_expired(service): + """check_invitation_code should return False if the code is expired.""" + with Session(service.engine) as session: + # Negative expires_in_seconds makes it expire in the past + _insert_code(session, "expired-code", is_used=False, expires_in_seconds=-60) + + ok = service.check_invitation_code("expired-code") + assert ok is False + + +def test_check_invitation_code_returns_true_when_valid(service): + """check_invitation_code should return True if the code exists, not used, and not expired.""" + with Session(service.engine) as session: + _insert_code(session, "valid-code", is_used=False, expires_in_seconds=3600) + + ok = service.check_invitation_code("valid-code") + assert ok is True + + +def test_mark_code_as_used_persists_state(service): + """mark_code_as_used should set 'used' to True and persist to DB.""" + with Session(service.engine) as session: + created = _insert_code(session, "to-use", is_used=False, expires_in_seconds=3600) + created_id = created.id + + # Act + service.mark_code_as_used("to-use") + + # Assert persisted state + with Session(service.engine) as session: + refreshed = session.get(InvitationCode, created_id) + assert refreshed is not None + assert refreshed.is_used is True diff --git a/tests/app/services/test_repository_service.py b/tests/app/services/test_repository_service.py index 23b796f9..f52292e1 100644 --- a/tests/app/services/test_repository_service.py +++ b/tests/app/services/test_repository_service.py @@ -168,7 +168,7 @@ def test_get_repository_returns_git_repo_instance(service): assert result == mock_git_repo_instance -def test_create_superuser(service): +def test_create_new_repository(service): # Exercise service.create_new_repository( url="https://github.com/test/repo", @@ -177,3 +177,10 @@ def test_create_superuser(service): user_id=None, kg_root_node_id=0, ) + + +def test_get_all_repositories(service): + # Exercise + repos = service.get_all_repositories() + # Verify + assert len(repos) == 1 diff --git a/tests/app/services/test_user_service.py b/tests/app/services/test_user_service.py index 3136972f..8194a695 100644 --- a/tests/app/services/test_user_service.py +++ b/tests/app/services/test_user_service.py @@ -25,3 +25,12 @@ def test_login(mock_database_service): access_token = service.login("testuser", "test@gmail.com", "password123") # Verify assert access_token is not None + + +def test_set_github_token(mock_database_service): + # Exercise + service = UserService(mock_database_service) + service.set_github_token(1, "new_gh_token") + # Verify + user = service.get_user_by_id(1) + assert user.github_token == "new_gh_token" diff --git a/tests/lang_graph/subgraphs/test_issue_question_subgraph.py b/tests/lang_graph/subgraphs/test_issue_question_subgraph.py new file mode 100644 index 00000000..03d4812a --- /dev/null +++ b/tests/lang_graph/subgraphs/test_issue_question_subgraph.py @@ -0,0 +1,58 @@ +from unittest.mock import Mock + +import neo4j +import pytest + +from prometheus.docker.base_container import BaseContainer +from prometheus.git.git_repository import GitRepository +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph +from tests.test_utils.util import FakeListChatWithToolsModel + + +@pytest.fixture +def mock_container(): + return Mock(spec=BaseContainer) + + +@pytest.fixture +def mock_kg(): + kg = Mock(spec=KnowledgeGraph) + # Configure the mock to return a list of AST node types + kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] + kg.root_node_id = 0 + return kg + + +@pytest.fixture +def mock_git_repo(): + git_repo = Mock(spec=GitRepository) + git_repo.playground_path = "mock/playground/path" + return git_repo + + +@pytest.fixture +def mock_neo4j_driver(): + return Mock(spec=neo4j.Driver) + + +def test_issue_question_subgraph_basic_initialization( + mock_container, mock_kg, mock_git_repo, mock_neo4j_driver +): + """Test that IssueQuestionSubgraph initializes correctly with basic components.""" + # Initialize fake model with empty responses + fake_advanced_model = FakeListChatWithToolsModel(responses=[]) + fake_base_model = FakeListChatWithToolsModel(responses=[]) + + # Initialize the subgraph with required parameters + subgraph = IssueQuestionSubgraph( + advanced_model=fake_advanced_model, + base_model=fake_base_model, + kg=mock_kg, + git_repo=mock_git_repo, + neo4j_driver=mock_neo4j_driver, + max_token_per_neo4j_result=1000, + ) + + # Verify the subgraph was created + assert subgraph.subgraph is not None