Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.dist
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ RALPH_RUNSERVER_HOST=0.0.0.0
RALPH_RUNSERVER_MAX_SEARCH_HITS_COUNT=100
RALPH_RUNSERVER_POINT_IN_TIME_KEEP_ALIVE=1m
RALPH_RUNSERVER_PORT=8100
RALPH_RUNSERVER_CORS_ALLOW_ORIGINS=["http://my-allowed-host.com","http://my-other-allowed-host.com"]

# Authentication
# RALPH_AUTH_FILE=/app/.ralph/auth.json
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to

## [Unreleased]

### Added

- API: support for CORS request

### Removed

- Drop support for Python 3.8
Expand Down
10 changes: 10 additions & 0 deletions src/ralph/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import sentry_sdk
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware

from ralph.conf import settings

Expand Down Expand Up @@ -42,6 +43,15 @@ def filter_transactions(event: Dict, hint) -> Union[Dict, None]: # noqa: ARG001
)

app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=settings.RUNSERVER_CORS_ALLOW_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT"],
allow_headers=[
"Authorization,User-Agent,Keep-Alive,Content-Type,X-Experience-API-Version"
],
)
app.include_router(statements.router)
app.include_router(health.router)

Expand Down
62 changes: 60 additions & 2 deletions src/ralph/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import io
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from urllib.parse import urlparse

from pydantic import (
AfterValidator,
Expand All @@ -14,6 +15,9 @@
Field,
StringConstraints,
TypeAdapter,
UrlConstraints,
ValidatorFunctionWrapHandler,
WrapValidator,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
Expand Down Expand Up @@ -132,7 +136,7 @@ class AuthBackend(str, Enum):


def validate_auth_backends(
value: Union[str, Tuple[str, ...], List[str]]
value: Union[str, Tuple[str, ...], List[str]],
) -> Tuple[AuthBackend]:
"""Check whether the value is a comma separated string or a list/tuple."""
if isinstance(value, (tuple, list)):
Expand All @@ -148,6 +152,52 @@ def validate_auth_backends(
Union[str, Tuple[str, ...], List[str]], AfterValidator(validate_auth_backends)
]

CorsAllowOriginUrlTypeAdapter = TypeAdapter(
Annotated[
AnyHttpUrl,
UrlConstraints(
allowed_schemes=["http", "https"],
host_required=True,
preserve_empty_path=True,
),
]
)


def validate_cors_allow_origin_url(
value: Any, _handler: ValidatorFunctionWrapHandler
) -> str:
"""Accepts url strings set in the 'Origin' header of a HTTP/HTTPS request.

Validates scheme (http or https), host, and an port.

Note: Pydantic's URL parser automatically adds the default port for
a given schema to URLs.
This is good in most cases, except preflight requests
expect the 'Origin' header (which do not specify port if not default)
to match our allowed URL EXACTLY.
So we remove the port from Pydantic's output if we did not provide it
in the first place.
"""
url: AnyHttpUrl = CorsAllowOriginUrlTypeAdapter.validate_python(value)
if url.username is not None or url.password is not None:
raise ValueError(
"CORS AllowOrigin URLs should not include username or password"
)
if url.path is not None:
raise ValueError("CORS AllowOrigin URLs should have an empty path")
parsed_url = urlparse(value)
explicit_port = parsed_url.port
origin = f"{url.scheme}://{url.host}{
':' + str(explicit_port) if explicit_port is not None
else ''}"
if origin != str(value):
raise ValueError("CORS AllowOrigin URL format incorrect")
return origin


CorsAllowOriginUrl = Annotated[str, WrapValidator(validate_cors_allow_origin_url)]


class Settings(BaseSettings):
"""Pydantic model for Ralph's global environment & configuration settings."""
Expand Down Expand Up @@ -208,6 +258,14 @@ class Settings(BaseSettings):
RUNSERVER_MAX_SEARCH_HITS_COUNT: int = 100
RUNSERVER_POINT_IN_TIME_KEEP_ALIVE: str = "1m"
RUNSERVER_PORT: int = 8100
RUNSERVER_CORS_ALLOW_ORIGINS: List[CorsAllowOriginUrl] = Field(
[],
title="CORS Allowed Origins",
description="List of allowed origins URL when using CORS",
examples=[
["https://my-allowed-origin.com", "https://my-other-allowed-origin.com"]
],
)
LRS_RESTRICT_BY_AUTHORITY: bool = False
LRS_RESTRICT_BY_SCOPES: bool = False
SENTRY_CLI_TRACES_SAMPLE_RATE: float = 1.0
Expand Down
57 changes: 57 additions & 0 deletions tests/api/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Tests for the health check endpoints."""

import json
import logging

import pytest
from pydantic import ValidationError

from ralph.conf import Settings

ALLOW_ORIGINS = [
"https://my-allowed-origin.com",
"https://my-other-allowed-origin.com",
"https://yet.another.origin.com",
"http://my-local-origin:8080",
]
ALLOW_ORIGINS_INVALID = [
"htts://wrong-scheme.com",
"https-wrong-format.com",
"http:/another.wrong.format" "https://trailing-slash.com/",
]


def test_cors_allow_origin_valid_configuration(
monkeypatch,
):
"""Test the settings, given a valid CORS AllowOrigin valid configuration,
should not raise an exception.
"""
monkeypatch.delenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", raising=False)
settings = Settings()
assert settings.RUNSERVER_CORS_ALLOW_ORIGINS == []

monkeypatch.setenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", json.dumps(ALLOW_ORIGINS))
settings = Settings()
assert len(settings.RUNSERVER_CORS_ALLOW_ORIGINS) == len(ALLOW_ORIGINS)
for i in range(len(settings.RUNSERVER_CORS_ALLOW_ORIGINS)):

assert settings.RUNSERVER_CORS_ALLOW_ORIGINS[i] == ALLOW_ORIGINS[i]


def test_cors_allow_origin_invalid_configuration(
monkeypatch,
):
"""Test the settings, given an invalid CORS AllowOrigin valid configuration,
should raise an exception.
"""
for invalid_origin in ALLOW_ORIGINS_INVALID:
monkeypatch.delenv("RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", raising=False)
settings = Settings()
assert settings.RUNSERVER_CORS_ALLOW_ORIGINS == []
monkeypatch.setenv(
"RALPH_RUNSERVER_CORS_ALLOW_ORIGINS", json.dumps([invalid_origin])
)
with pytest.raises(ValidationError):
settings = Settings()
logging.critical(settings.RUNSERVER_CORS_ALLOW_ORIGINS)