diff --git a/mcpgateway/main.py b/mcpgateway/main.py index d00d07376..0155dde9f 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -805,7 +805,7 @@ async def call_next(_req: starletteRequest) -> starletteResponse: Returns: starletteResponse: A response generated from the streamable HTTP call. """ - return await self._call_streamable_http(scope, receive, send) + return await self._call_streamable_http(scope, receive, send) # type: ignore[return-value] response = await self.dispatch(request, call_next) @@ -870,7 +870,7 @@ async def _call_streamable_http(self, scope, receive, send): cors_origins = [] app.add_middleware( - CORSMiddleware, + CORSMiddleware, # type: ignore[arg-type] allow_origins=cors_origins, allow_credentials=settings.cors_allow_credentials, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], @@ -880,22 +880,22 @@ async def _call_streamable_http(self, scope, receive, send): # Add security headers middleware -app.add_middleware(SecurityHeadersMiddleware) +app.add_middleware(SecurityHeadersMiddleware) # type: ignore[arg-type] # Add token scoping middleware (only when email auth is enabled) if settings.email_auth_enabled: - app.add_middleware(BaseHTTPMiddleware, dispatch=token_scoping_middleware) + app.add_middleware(BaseHTTPMiddleware, dispatch=token_scoping_middleware) # type: ignore[arg-type] # Add streamable HTTP middleware for /mcp routes with token scoping - app.add_middleware(MCPPathRewriteMiddleware, dispatch=token_scoping_middleware) + app.add_middleware(MCPPathRewriteMiddleware, dispatch=token_scoping_middleware) # type: ignore[arg-type] else: # Add streamable HTTP middleware for /mcp routes - app.add_middleware(MCPPathRewriteMiddleware) + app.add_middleware(MCPPathRewriteMiddleware) # type: ignore[arg-type] # Add custom DocsAuthMiddleware -app.add_middleware(DocsAuthMiddleware) +app.add_middleware(DocsAuthMiddleware) # type: ignore[arg-type] # Trust all proxies (or lock down with a list of host patterns) -app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") +app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") # type: ignore[arg-type] # Set up Jinja2 templates and store in app state for later use @@ -2079,7 +2079,7 @@ async def list_tools( tools_dict_list = [tool.to_dict(use_alias=True) for tool in data] - return jsonpath_modifier(tools_dict_list, apijsonpath.jsonpath, apijsonpath.mapping) + return jsonpath_modifier(tools_dict_list, apijsonpath.jsonpath or "", apijsonpath.mapping) @tool_router.post("", response_model=ToolRead) @@ -2152,7 +2152,9 @@ async def create_tool( raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(ex)) if isinstance(ex, (ValidationError, ValueError)): logger.error(f"Validation error while creating tool: {ex}") - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex)) + if isinstance(ex, ValidationError): + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex)) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(ex)) if isinstance(ex, IntegrityError): logger.error(f"Integrity error while creating tool: {ex}") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=ErrorFormatter.format_database_error(ex)) @@ -2194,7 +2196,7 @@ async def get_tool( data_dict = data.to_dict(use_alias=True) - return jsonpath_modifier(data_dict, apijsonpath.jsonpath, apijsonpath.mapping) + return jsonpath_modifier(data_dict, apijsonpath.jsonpath or "", apijsonpath.mapping) # type: ignore[return-type] except Exception as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @@ -2247,7 +2249,9 @@ async def update_tool( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)) if isinstance(ex, ValidationError): logger.error(f"Validation error while creating tool: {ex}") - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex)) + if isinstance(ex, ValidationError): + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex)) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(ex)) if isinstance(ex, IntegrityError): logger.error(f"Integrity error while creating tool: {ex}") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=ErrorFormatter.format_database_error(ex)) @@ -2425,7 +2429,14 @@ async def list_resources( return cached data = await resource_service.list_resources(db, include_inactive=include_inactive, tags=tags_list) resource_cache.set("resource_list", data) - return data + # Convert ResourceRead objects to dictionaries for API response + result: List[Dict[str, Any]] = [] + for r in data: + if hasattr(r, "model_dump"): + result.append(r.model_dump(by_alias=True)) + else: + result.append(r) # type: ignore[arg-type] + return result @resource_router.post("", response_model=ResourceRead) @@ -2657,7 +2668,12 @@ async def subscribe_resource(uri: str, user=Depends(get_current_user_with_permis StreamingResponse: A streaming response with event updates. """ logger.debug(f"User {user} is subscribing to resource with URI {uri}") - return StreamingResponse(resource_service.subscribe_events(uri), media_type="text/event-stream") + + async def event_generator(): + async for event in resource_service.subscribe_events(uri): + yield f"data: {json.dumps(event)}\n\n".encode() + + return StreamingResponse(event_generator(), media_type="text/event-stream") ############### @@ -2742,7 +2758,14 @@ async def list_prompts( # Use existing method for backward compatibility when no team filtering logger.debug(f"User: {user_email} requested prompt list with include_inactive={include_inactive}, cursor={cursor}, tags={tags_list}") data = await prompt_service.list_prompts(db, cursor=cursor, include_inactive=include_inactive, tags=tags_list) - return data + # Convert PromptRead objects to dictionaries for API response + result: List[Dict[str, Any]] = [] + for p in data: + if hasattr(p, "model_dump"): + result.append(p.model_dump(by_alias=True)) + else: + result.append(p) # type: ignore[arg-type] + return result @prompt_router.post("", response_model=PromptRead) @@ -3384,6 +3407,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen PluginError: If encounters issue with plugin PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy. """ + req_id = None # Initialize req_id outside try block try: # Extract user identifier from either RBAC user object or JWT payload if hasattr(user, "email"): @@ -3836,7 +3860,7 @@ async def readiness_check(db: Session = Depends(get_db)): """ try: # Run the blocking DB check in a thread to avoid blocking the event loop - await asyncio.to_thread(db.execute, text("SELECT 1")) + await asyncio.to_thread(db.execute, text("SELECT 1")) # type: ignore[arg-type] return JSONResponse(content={"status": "ready"}, status_code=200) except Exception as e: error_message = f"Readiness check failed: {str(e)}" @@ -4155,7 +4179,8 @@ async def import_configuration( try: strategy = ConflictStrategy(conflict_strategy.lower()) except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {[s.value for s in list(ConflictStrategy)]}") + valid_strategies = [s.value for s in ConflictStrategy.__members__.values()] + raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {valid_strategies}") # Extract username from user (which is now an EmailUser object) if hasattr(user, "email"): diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 07887eb89..5652c7b18 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -286,28 +286,31 @@ def _convert_tool_to_read(self, tool: DbTool) -> ToolRead: tool_dict["request_type"] = tool.request_type tool_dict["annotations"] = tool.annotations or {} - decoded_auth_value = decode_auth(tool.auth_value) - if tool.auth_type == "basic": - decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1]) - username, password = decoded_bytes.decode("utf-8").split(":") - tool_dict["auth"] = { - "auth_type": "basic", - "username": username, - "password": "********" if password else None, - } - elif tool.auth_type == "bearer": - tool_dict["auth"] = { - "auth_type": "bearer", - "token": "********" if decoded_auth_value["Authorization"] else None, - } - elif tool.auth_type == "authheaders": - tool_dict["auth"] = { - "auth_type": "authheaders", - "auth_header_key": next(iter(decoded_auth_value)), - "auth_header_value": "********" if decoded_auth_value[next(iter(decoded_auth_value))] else None, - } - else: + if not tool.auth_value: tool_dict["auth"] = None + else: + decoded_auth_value = decode_auth(tool.auth_value) + if tool.auth_type == "basic" and "Authorization" in decoded_auth_value: + decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1]) + username, password = decoded_bytes.decode("utf-8").split(":") + tool_dict["auth"] = { + "auth_type": "basic", + "username": username, + "password": "********" if password else None, + } + elif tool.auth_type == "bearer" and "Authorization" in decoded_auth_value: + tool_dict["auth"] = { + "auth_type": "bearer", + "token": "********" if decoded_auth_value["Authorization"] else None, + } + elif tool.auth_type == "authheaders" and decoded_auth_value: + tool_dict["auth"] = { + "auth_type": "authheaders", + "auth_header_key": next(iter(decoded_auth_value)), + "auth_header_value": "********" if decoded_auth_value[next(iter(decoded_auth_value))] else None, + } + else: + tool_dict["auth"] = None tool_dict["name"] = tool.name # Handle displayName with fallback and None checks @@ -359,7 +362,7 @@ async def register_tool( federation_source: Optional[str] = None, team_id: Optional[str] = None, owner_email: Optional[str] = None, - visibility: str = None, + visibility: Optional[str] = None, ) -> ToolRead: """Register a new tool with team support. @@ -425,13 +428,13 @@ async def register_tool( # Check for existing tool with the same name and visibility if visibility.lower() == "public": # Check for existing public tool with the same name - existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "public")).scalar_one_or_none() + existing_tool = db.execute(select(DbTool).where(and_(DbTool.name == tool.name, DbTool.visibility == "public"))).scalar_one_or_none() if existing_tool: raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) elif visibility.lower() == "team" and team_id: # Check for existing team tool with the same name, team_id existing_tool = db.execute( - select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id) # pylint: disable=comparison-with-callable + select(DbTool).where(and_(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id)) # pylint: disable=comparison-with-callable ).scalar_one_or_none() if existing_tool: raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) @@ -789,7 +792,8 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r ToolNotFoundError: If tool not found. ToolInvocationError: If invocation fails. PluginViolationError: If plugin blocks tool invocation. - PluginError: If encounters issue with plugin + PluginError: If encounters issue with plugin. + AssertionError: Should never be raised (unreachable code guard) Examples: >>> from mcpgateway.services.tool_service import ToolService @@ -805,9 +809,9 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r True """ # pylint: disable=comparison-with-callable - tool = db.execute(select(DbTool).where(DbTool.name == name).where(DbTool.enabled)).scalar_one_or_none() + tool = db.execute(select(DbTool).where(and_(DbTool.name == name, DbTool.enabled))).scalar_one_or_none() if not tool: - inactive_tool = db.execute(select(DbTool).where(DbTool.name == name).where(not_(DbTool.enabled))).scalar_one_or_none() + inactive_tool = db.execute(select(DbTool).where(and_(DbTool.name == name, not_(DbTool.enabled)))).scalar_one_or_none() if inactive_tool: raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") raise ToolNotFoundError(f"Tool not found: {name}") @@ -881,7 +885,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r if pre_result.modified_payload: payload = pre_result.modified_payload name = payload.name - arguments = payload.args + arguments = payload.args or {} if payload.headers is not None: headers = payload.headers.model_dump() @@ -976,7 +980,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r logger.error(f"Failed to obtain OAuth access token for gateway {gateway.name}: {e}") raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") else: - headers = decode_auth(gateway.auth_value if gateway else None) + headers = decode_auth(gateway.auth_value if gateway and gateway.auth_value else "") # Get combined headers including gateway auth and passthrough if request_headers: @@ -990,10 +994,15 @@ async def connect_to_sse_server(server_url: str): Returns: ToolResult: Result of tool call + + Raises: + ToolInvocationError: If tool has no original_name """ async with sse_client(url=server_url, headers=headers) as streams: async with ClientSession(*streams) as session: await session.initialize() + if not tool or not tool.original_name: + raise ToolInvocationError(f"Tool '{name}' has no original_name") tool_call_result = await session.call_tool(tool.original_name, arguments) return tool_call_result @@ -1005,10 +1014,15 @@ async def connect_to_streamablehttp_server(server_url: str): Returns: ToolResult: Result of tool call + + Raises: + ToolInvocationError: If tool has no original_name """ async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() + if not tool or not tool.original_name: + raise ToolInvocationError(f"Tool '{name}' has no original_name") tool_call_result = await session.call_tool(tool.original_name, arguments) return tool_call_result @@ -1030,15 +1044,20 @@ async def connect_to_streamablehttp_server(server_url: str): if pre_result.modified_payload: payload = pre_result.modified_payload name = payload.name - arguments = payload.args + arguments = payload.args or {} if payload.headers is not None: headers = payload.headers.model_dump() tool_call_result = ToolResult(content=[TextContent(text="", type="text")]) if transport == "sse": + if not tool_gateway or not tool_gateway.url: + raise ToolInvocationError("Tool gateway not found or has no URL") tool_call_result = await connect_to_sse_server(tool_gateway.url) elif transport == "streamablehttp": + if not tool_gateway or not tool_gateway.url: + raise ToolInvocationError("Tool gateway not found or has no URL") tool_call_result = await connect_to_streamablehttp_server(tool_gateway.url) + # If transport is neither sse nor streamablehttp, use the default empty result content = tool_call_result.model_dump(by_alias=True).get("content", []) filtered_response = extract_using_jq(content, tool.jsonpath_filter) @@ -1082,6 +1101,8 @@ async def connect_to_streamablehttp_server(server_url: str): span.set_attribute("success", success) span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) await self._record_tool_metric(db, tool, start_time, success, error_message) + # Unreachable - all paths above return or raise + raise AssertionError("Unreachable code") async def update_tool( self, @@ -1142,13 +1163,13 @@ async def update_tool( # Check for existing tool with the same name and visibility if tool_update.visibility.lower() == "public": # Check for existing public tool with the same name - existing_tool = db.execute(select(DbTool).where(DbTool.custom_name == tool_update.custom_name, DbTool.visibility == "public")).scalar_one_or_none() + existing_tool = db.execute(select(DbTool).where(and_(DbTool.custom_name == tool_update.custom_name, DbTool.visibility == "public"))).scalar_one_or_none() if existing_tool: raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) - elif tool_update.visibility.lower() == "team" and tool_update.team_id: + elif tool_update.visibility.lower() == "team" and hasattr(tool_update, "team_id") and tool_update.team_id: # Check for existing team tool with the same name existing_tool = db.execute( - select(DbTool).where(DbTool.custom_name == tool_update.custom_name, DbTool.visibility == "team", DbTool.team_id == tool_update.team_id) + select(DbTool).where(and_(DbTool.custom_name == tool_update.custom_name, DbTool.visibility == "team", DbTool.team_id == tool_update.team_id)) ).scalar_one_or_none() if existing_tool: raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) @@ -1535,6 +1556,8 @@ async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, ToolNotFoundError: If the A2A agent is not found. """ # Extract A2A agent ID from tool annotations + if not tool.annotations: + raise ToolNotFoundError(f"A2A tool '{tool.name}' has no annotations") agent_id = tool.annotations.get("a2a_agent_id") if not agent_id: raise ToolNotFoundError(f"A2A tool '{tool.name}' missing agent ID in annotations") diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 43be222d7..9e487ba4b 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -765,7 +765,11 @@ def test_toggle_resource_status(self, mock_toggle, test_client, auth_headers): @patch("mcpgateway.main.resource_service.subscribe_events") def test_subscribe_resource_events(self, mock_subscribe, test_client, auth_headers): """Test subscribing to resource change events via SSE.""" - mock_subscribe.return_value = iter(["data: test\n\n"]) + # Create an async generator for the mock + async def async_event_generator(): + yield {"type": "resource_update", "data": "test"} + + mock_subscribe.return_value = async_event_generator() response = test_client.post("/resources/subscribe/test/resource", headers=auth_headers) assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream; charset=utf-8"