Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions src/otdf_python/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,14 @@ def load_client_credentials(creds_file_path: str) -> tuple[str, str]:
) from e


def build_sdk(args) -> SDK:
"""Build SDK instance from CLI arguments."""
builder = SDKBuilder()

if args.platform_url:
builder.set_platform_endpoint(args.platform_url)

# Auto-detect HTTP URLs and enable plaintext mode
if args.platform_url.startswith("http://") and (
not hasattr(args, "plaintext") or not args.plaintext
):
logger.debug(
f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode"
)
builder.use_insecure_plaintext_connection(True)

if args.oidc_endpoint:
builder.set_issuer_endpoint(args.oidc_endpoint)

def _configure_auth(builder: SDKBuilder, args) -> None:
"""Configure authentication on the SDK builder."""
if args.client_id and args.client_secret:
builder.client_secret(args.client_id, args.client_secret)
elif hasattr(args, "with_client_creds_file") and args.with_client_creds_file:
# Load credentials from file
client_id, client_secret = load_client_credentials(args.with_client_creds_file)
builder.client_secret(client_id, client_secret)
elif hasattr(args, "auth") and args.auth:
# Parse combined auth string (clientId:clientSecret) - legacy support
auth_parts = args.auth.split(":")
if len(auth_parts) != 2:
raise CLIError(
Expand All @@ -152,12 +133,49 @@ def build_sdk(args) -> SDK:
"Authentication required: provide --with-client-creds-file OR --client-id and --client-secret",
)


def _configure_kas_allowlist(builder: SDKBuilder, args) -> None:
"""Configure KAS allowlist on the SDK builder."""
if hasattr(args, "ignore_kas_allowlist") and args.ignore_kas_allowlist:
logger.warning(
"KAS allowlist validation is disabled. This may leak credentials "
"to malicious servers if decrypting untrusted TDF files."
)
builder.with_ignore_kas_allowlist(True)
elif hasattr(args, "kas_allowlist") and args.kas_allowlist:
kas_urls = [url.strip() for url in args.kas_allowlist.split(",") if url.strip()]
logger.debug(f"Using KAS allowlist: {kas_urls}")
builder.with_kas_allowlist(kas_urls)


def build_sdk(args) -> SDK:
"""Build SDK instance from CLI arguments."""
builder = SDKBuilder()

if args.platform_url:
builder.set_platform_endpoint(args.platform_url)
# Auto-detect HTTP URLs and enable plaintext mode
if args.platform_url.startswith("http://") and (
not hasattr(args, "plaintext") or not args.plaintext
):
logger.debug(
f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode"
)
builder.use_insecure_plaintext_connection(True)

if args.oidc_endpoint:
builder.set_issuer_endpoint(args.oidc_endpoint)

_configure_auth(builder, args)

if hasattr(args, "plaintext") and args.plaintext:
builder.use_insecure_plaintext_connection(True)

if args.insecure:
builder.use_insecure_skip_verify(True)

_configure_kas_allowlist(builder, args)

return builder.build()


Expand Down Expand Up @@ -476,6 +494,17 @@ def create_parser() -> argparse.ArgumentParser:
security_group.add_argument(
"--insecure", action="store_true", help="Skip TLS verification"
)
security_group.add_argument(
"--kas-allowlist",
help="Comma-separated list of trusted KAS URLs. "
"By default, only the platform URL's KAS endpoint is trusted.",
)
security_group.add_argument(
"--ignore-kas-allowlist",
action="store_true",
help="WARNING: Disable KAS allowlist validation. This is insecure and "
"should only be used for testing. May leak credentials to malicious servers.",
)

# Subcommands
subparsers = parser.add_subparsers(dest="command", help="Available commands")
Expand Down
182 changes: 182 additions & 0 deletions src/otdf_python/kas_allowlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""KAS Allowlist: Validates KAS URLs against a list of trusted hosts.

This module provides protection against SSRF attacks where malicious TDF files
could contain attacker-controlled KAS URLs to steal OIDC credentials.
"""

import logging
from urllib.parse import urlparse


class KASAllowlist:
"""Validates KAS URLs against an allowlist of trusted hosts.

This class prevents credential theft by ensuring the SDK only sends
authentication tokens to trusted KAS endpoints.

Example:
allowlist = KASAllowlist(["https://kas.example.com"])
allowlist.is_allowed("https://kas.example.com/kas") # True
allowlist.is_allowed("https://evil.com/kas") # False

"""

def __init__(self, allowed_urls: list[str] | None = None, allow_all: bool = False):
"""Initialize the KAS allowlist.

Args:
allowed_urls: List of trusted KAS URLs. Each URL is normalized to
its origin (scheme://host:port) for comparison.
allow_all: If True, all URLs are allowed. Use only for testing.
A warning is logged when this is enabled.

"""
self._allowed_origins: set[str] = set()
self._allow_all = allow_all

if allow_all:
logging.warning(
"KAS allowlist is disabled (allow_all=True). "
"This is insecure and should only be used for testing."
)

if allowed_urls:
for url in allowed_urls:
self.add(url)

def add(self, url: str) -> None:
"""Add a URL to the allowlist.

The URL is normalized to its origin (scheme://host:port) before storage.
Paths and query strings are stripped.

Args:
url: The KAS URL to allow. Can include path components which
will be stripped for origin comparison.

"""
origin = self._get_origin(url)
self._allowed_origins.add(origin)
logging.debug(f"Added KAS origin to allowlist: {origin}")

def is_allowed(self, url: str) -> bool:
"""Check if a URL is allowed by the allowlist.

Args:
url: The KAS URL to check.

Returns:
True if the URL's origin is in the allowlist or allow_all is True.
False otherwise.

"""
if self._allow_all:
logging.debug(f"KAS URL allowed (allow_all=True): {url}")
return True

if not self._allowed_origins:
logging.debug(f"KAS URL rejected (empty allowlist): {url}")
return False

origin = self._get_origin(url)
allowed = origin in self._allowed_origins
if allowed:
logging.debug(f"KAS URL allowed: {url} (origin: {origin})")
else:
logging.debug(
f"KAS URL rejected: {url} (origin: {origin}, "
f"allowed: {self._allowed_origins})"
)
return allowed

def validate(self, url: str) -> None:
"""Validate a URL against the allowlist, raising an exception if not allowed.

Args:
url: The KAS URL to validate.

Raises:
SDK.KasAllowlistException: If the URL is not in the allowlist.

"""
if not self.is_allowed(url):
# Import here to avoid circular imports
from .sdk import SDK

raise SDK.KasAllowlistException(url, self._allowed_origins)

@property
def allowed_origins(self) -> set[str]:
"""Return the set of allowed origins (read-only copy)."""
return self._allowed_origins.copy()

@property
def allow_all(self) -> bool:
"""Return whether all URLs are allowed."""
return self._allow_all

@staticmethod
def _get_origin(url: str) -> str:
"""Extract the origin (scheme://host:port) from a URL.

This normalizes URLs for comparison by stripping paths and query strings.
Default ports (80 for http, 443 for https) are included explicitly.

Args:
url: The URL to extract the origin from.

Returns:
Normalized origin string in format scheme://host:port

"""
# Add scheme if missing
if "://" not in url:
url = "https://" + url

try:
parsed = urlparse(url)
except Exception as e:
logging.warning(f"Failed to parse URL {url}: {e}")
# Return the URL as-is if parsing fails
return url.lower()

scheme = (parsed.scheme or "https").lower()
hostname = (parsed.hostname or "").lower()

if not hostname:
# URL might be malformed, return as-is
logging.warning(f"Could not extract hostname from URL: {url}")
return url.lower()

# Determine port (use explicit port or default for scheme)
if parsed.port:
port = parsed.port
elif scheme == "http":
port = 80
else:
port = 443

return f"{scheme}://{hostname}:{port}"

@classmethod
def from_platform_url(cls, platform_url: str) -> "KASAllowlist":
"""Create an allowlist from a platform URL.

This is the default behavior: auto-allow the platform's KAS endpoint.

Args:
platform_url: The OpenTDF platform URL. The KAS endpoint is
assumed to be at {platform_url}/kas.

Returns:
KASAllowlist configured to allow the platform's KAS endpoint.

"""
allowlist = cls()
# Add the platform URL itself (KAS might be at root or /kas)
allowlist.add(platform_url)
# Also construct the /kas endpoint explicitly
kas_url = platform_url.rstrip("/") + "/kas"
allowlist.add(kas_url)
logging.info(f"Created KAS allowlist from platform URL: {platform_url}")
return allowlist
26 changes: 25 additions & 1 deletion src/otdf_python/kas_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,26 @@ def __init__(
cache=None,
use_plaintext=False,
verify_ssl=True,
kas_allowlist=None,
):
"""Initialize KAS client."""
"""Initialize KAS client.

Args:
kas_url: Default KAS URL
token_source: Function that returns an authentication token
cache: Optional KASKeyCache for caching public keys
use_plaintext: Whether to use HTTP instead of HTTPS
verify_ssl: Whether to verify SSL certificates
kas_allowlist: Optional KASAllowlist for URL validation. If provided,
only URLs in the allowlist will be contacted.

"""
self.kas_url = kas_url
self.token_source = token_source
self.cache = cache or KASKeyCache()
self.use_plaintext = use_plaintext
self.verify_ssl = verify_ssl
self.kas_allowlist = kas_allowlist
self.decryptor = None
self.client_public_key = None

Expand Down Expand Up @@ -86,15 +99,26 @@ def close(self):
def _normalize_kas_url(self, url: str) -> str:
"""Normalize KAS URLs based on client security settings.

This method also validates the URL against the KAS allowlist if one
is configured. This prevents SSRF attacks where malicious TDF files
could contain attacker-controlled KAS URLs to steal OIDC credentials.

Args:
url: The KAS URL to normalize

Returns:
Normalized URL with appropriate protocol and port

Raises:
KASAllowlistException: If the URL is not in the allowlist

"""
from urllib.parse import urlparse

# Validate against allowlist BEFORE making any requests
if self.kas_allowlist is not None:
self.kas_allowlist.validate(url)

try:
# Parse the URL
parsed = urlparse(url)
Expand Down
30 changes: 30 additions & 0 deletions src/otdf_python/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
token_source=None,
sdk_ssl_verify=True,
use_plaintext=False,
kas_allowlist=None,
):
"""Initialize the KAS client.

Expand All @@ -45,6 +46,7 @@ def __init__(
token_source: Function that returns an authentication token
sdk_ssl_verify: Whether to verify SSL certificates
use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS
kas_allowlist: Optional KASAllowlist for URL validation

"""
from .kas_client import KASClient
Expand All @@ -54,6 +56,7 @@ def __init__(
token_source=token_source,
verify_ssl=sdk_ssl_verify,
use_plaintext=use_plaintext,
kas_allowlist=kas_allowlist,
)
# Store the parameters for potential use
self._sdk_ssl_verify = sdk_ssl_verify
Expand Down Expand Up @@ -405,6 +408,33 @@ class KasBadRequestException(SDKException):
class KasAllowlistException(SDKException):
"""Throw when KAS allowlist check fails."""

def __init__(
self,
url: str,
allowed_origins: set[str] | None = None,
message: str | None = None,
):
"""Initialize exception.

Args:
url: The KAS URL that was rejected
allowed_origins: Set of allowed origin URLs
message: Optional custom message (auto-generated if not provided)

"""
self.url = url
self.allowed_origins = allowed_origins or set()
if message is None:
origins_str = (
", ".join(sorted(self.allowed_origins))
if self.allowed_origins
else "none"
)
message = (
f"KAS URL not in allowlist: {url}. Allowed origins: {origins_str}"
)
super().__init__(message)

class AssertionException(SDKException):
"""Throw when an assertion validation fails."""

Expand Down
Loading
Loading