diff --git a/.env.dist b/.env.dist index 16f47bc5c..c23486e1d 100644 --- a/.env.dist +++ b/.env.dist @@ -166,6 +166,7 @@ RALPH_RUNSERVER_HOST=0.0.0.0 RALPH_RUNSERVER_MAX_SEARCH_HITS_COUNT=100 RALPH_RUNSERVER_POINT_IN_TIME_KEEP_ALIVE=1m RALPH_RUNSERVER_PORT=8100 +RALPH_RUNSERVER_CORS_ALLOW_ORIGINS=["http://my-allowed-host.com","http://my-other-allowed-host.com"] # Authentication # RALPH_AUTH_FILE=/app/.ralph/auth.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f3b6ce9f..d56cf3cce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to ## [Unreleased] +### Added + +- API: support for CORS request + ### Removed - Drop support for Python 3.8 diff --git a/src/ralph/api/__init__.py b/src/ralph/api/__init__.py index e960a0701..98d38efbf 100644 --- a/src/ralph/api/__init__.py +++ b/src/ralph/api/__init__.py @@ -6,6 +6,7 @@ import sentry_sdk from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware from ralph.conf import settings @@ -42,6 +43,15 @@ def filter_transactions(event: Dict, hint) -> Union[Dict, None]: # noqa: ARG001 ) app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=settings.RUNSERVER_CORS_ALLOW_ORIGINS, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT"], + allow_headers=[ + "Authorization,User-Agent,Keep-Alive,Content-Type,X-Experience-API-Version" + ], +) app.include_router(statements.router) app.include_router(health.router) diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 6dfba7ed6..1e7d36e7e 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -3,7 +3,8 @@ import io from enum import Enum from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union +from urllib.parse import urlparse from pydantic import ( AfterValidator, @@ -14,6 +15,9 @@ Field, StringConstraints, TypeAdapter, + UrlConstraints, + ValidatorFunctionWrapHandler, + WrapValidator, model_validator, ) from pydantic_settings import BaseSettings, SettingsConfigDict @@ -132,7 +136,7 @@ class AuthBackend(str, Enum): def validate_auth_backends( - value: Union[str, Tuple[str, ...], List[str]] + value: Union[str, Tuple[str, ...], List[str]], ) -> Tuple[AuthBackend]: """Check whether the value is a comma separated string or a list/tuple.""" if isinstance(value, (tuple, list)): @@ -148,6 +152,52 @@ def validate_auth_backends( Union[str, Tuple[str, ...], List[str]], AfterValidator(validate_auth_backends) ] +CorsAllowOriginUrlTypeAdapter = TypeAdapter( + Annotated[ + AnyHttpUrl, + UrlConstraints( + allowed_schemes=["http", "https"], + host_required=True, + preserve_empty_path=True, + ), + ] +) + + +def validate_cors_allow_origin_url( + value: Any, _handler: ValidatorFunctionWrapHandler +) -> str: + """Accepts url strings set in the 'Origin' header of a HTTP/HTTPS request. + + Validates scheme (http or https), host, and an port. + + Note: Pydantic's URL parser automatically adds the default port for + a given schema to URLs. + This is good in most cases, except preflight requests + expect the 'Origin' header (which do not specify port if not default) + to match our allowed URL EXACTLY. + So we remove the port from Pydantic's output if we did not provide it + in the first place. + """ + url: AnyHttpUrl = CorsAllowOriginUrlTypeAdapter.validate_python(value) + if url.username is not None or url.password is not None: + raise ValueError( + "CORS AllowOrigin URLs should not include username or password" + ) + if url.path is not None: + raise ValueError("CORS AllowOrigin URLs should have an empty path") + parsed_url = urlparse(value) + explicit_port = parsed_url.port + origin = f"{url.scheme}://{url.host}{ + ':' + str(explicit_port) if explicit_port is not None + else ''}" + if origin != str(value): + raise ValueError("CORS AllowOrigin URL format incorrect") + return origin + + +CorsAllowOriginUrl = Annotated[str, WrapValidator(validate_cors_allow_origin_url)] + class Settings(BaseSettings): """Pydantic model for Ralph's global environment & configuration settings.""" @@ -208,6 +258,14 @@ class Settings(BaseSettings): RUNSERVER_MAX_SEARCH_HITS_COUNT: int = 100 RUNSERVER_POINT_IN_TIME_KEEP_ALIVE: str = "1m" RUNSERVER_PORT: int = 8100 + RUNSERVER_CORS_ALLOW_ORIGINS: List[CorsAllowOriginUrl] = Field( + [], + title="CORS Allowed Origins", + description="List of allowed origins URL when using CORS", + examples=[ + ["https://my-allowed-origin.com", "https://my-other-allowed-origin.com"] + ], + ) LRS_RESTRICT_BY_AUTHORITY: bool = False LRS_RESTRICT_BY_SCOPES: bool = False SENTRY_CLI_TRACES_SAMPLE_RATE: float = 1.0 diff --git a/tests/api/test_cors.py b/tests/api/test_cors.py new file mode 100644 index 000000000..7190aeb01 --- /dev/null +++ b/tests/api/test_cors.py @@ -0,0 +1,57 @@ +"""Tests for the health check endpoints.""" + +import json +import logging + +import pytest +from pydantic import ValidationError + +from ralph.conf import Settings + +ALLOW_ORIGINS = [ + "https://my-allowed-origin.com", + "https://my-other-allowed-origin.com", + "https://yet.another.origin.com", + "http://my-local-origin:8080", +] +ALLOW_ORIGINS_INVALID = [ + "htts://wrong-scheme.com", + "https-wrong-format.com", + "http:/another.wrong.format" "https://trailing-slash.com/", +] + + +def test_cors_allow_origin_valid_configuration( + monkeypatch, +): + """Test the settings, given a valid CORS AllowOrigin valid configuration, + should not raise an exception. + """ + monkeypatch.delenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", raising=False) + settings = Settings() + assert settings.RUNSERVER_CORS_ALLOW_ORIGINS == [] + + monkeypatch.setenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", json.dumps(ALLOW_ORIGINS)) + settings = Settings() + assert len(settings.RUNSERVER_CORS_ALLOW_ORIGINS) == len(ALLOW_ORIGINS) + for i in range(len(settings.RUNSERVER_CORS_ALLOW_ORIGINS)): + + assert settings.RUNSERVER_CORS_ALLOW_ORIGINS[i] == ALLOW_ORIGINS[i] + + +def test_cors_allow_origin_invalid_configuration( + monkeypatch, +): + """Test the settings, given an invalid CORS AllowOrigin valid configuration, + should raise an exception. + """ + for invalid_origin in ALLOW_ORIGINS_INVALID: + monkeypatch.delenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", raising=False) + settings = Settings() + assert settings.RUNSERVER_CORS_ALLOW_ORIGINS == [] + monkeypatch.setenv( + "RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", json.dumps([invalid_origin]) + ) + with pytest.raises(ValidationError): + settings = Settings() + logging.critical(settings.RUNSERVER_CORS_ALLOW_ORIGINS)