diff --git a/pyproject.toml b/pyproject.toml index 0a1021f..32b198a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,11 @@ omit = [ "src/oci_genai_auth/__about__.py", ] +[tool.pytest.ini_options] +markers = [ + "integration: live OCI endpoint tests (require OCI_GENAI_* env vars)", +] + [tool.coverage.paths] oci_genai_auth = ["src/oci_genai_auth", "*/oci-genai-auth/src/oci_genai_auth"] tests = ["tests", "*/oci-genai-auth/tests"] diff --git a/src/oci_genai_auth/__init__.py b/src/oci_genai_auth/__init__.py index 1d1716e..2c2382a 100644 --- a/src/oci_genai_auth/__init__.py +++ b/src/oci_genai_auth/__init__.py @@ -13,8 +13,8 @@ __all__ = [ "HttpxOciAuth", - "OciSessionAuth", - "OciResourcePrincipalAuth", "OciInstancePrincipalAuth", + "OciResourcePrincipalAuth", + "OciSessionAuth", "OciUserPrincipalAuth", ] diff --git a/src/oci_genai_auth/auth.py b/src/oci_genai_auth/auth.py index b19475b..f6c0b5a 100644 --- a/src/oci_genai_auth/auth.py +++ b/src/oci_genai_auth/auth.py @@ -33,6 +33,7 @@ class HttpxOciAuth(httpx.Auth, ABC): refresh_interval: Seconds between token refreshes (default: 3600 - 1 hour) _lock: Threading lock for thread-safe token refresh _last_refresh: Last refresh timestamp + _last_refresh_error: The last refresh exception, if any (None on success) """ def __init__(self, signer: OciAuthSigner, refresh_interval: int = 3600): @@ -46,7 +47,8 @@ def __init__(self, signer: OciAuthSigner, refresh_interval: int = 3600): self.refresh_interval = refresh_interval self._lock = threading.Lock() self._last_refresh: Optional[float] = time.time() - logger.info( + self._last_refresh_error: Optional[Exception] = None + logger.debug( "Initialized %s with refresh interval: %d seconds", self.__class__.__name__, refresh_interval, @@ -76,13 +78,20 @@ def _refresh_if_needed(self) -> OciAuthSigner: """ with self._lock: if self._should_refresh_token(): - logger.info("Time interval reached, refreshing %s ...", self.__class__.__name__) + logger.debug("Time interval reached, refreshing %s ...", self.__class__.__name__) try: self._refresh_signer() self._last_refresh = time.time() + self._last_refresh_error = None logger.info("%s token refresh completed successfully", self.__class__.__name__) - except Exception: - logger.exception("Token refresh failed") + except Exception as exc: + self._last_refresh_error = exc + logger.warning( + "Scheduled token refresh failed for %s, " + "continuing with existing signer: %s", + self.__class__.__name__, + exc, + ) return self.signer def _sign_request(self, request: httpx.Request, content: bytes, signer: OciAuthSigner) -> None: @@ -112,6 +121,8 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re 2. Signs the request using OCI signer 3. Yields the signed request 4. If 401 error is received, attempts token refresh and retries once + 5. If retry refresh also fails, the generator ends and the caller + receives the original 401 rather than a silently dropped response Args: request: The HTTPX request to be authenticated Yields: @@ -138,11 +149,18 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re try: self._refresh_signer() self._last_refresh = time.time() + self._last_refresh_error = None signer = self.signer self._sign_request(request, content, signer) yield request - except Exception: - logger.exception("Token refresh on 401 failed") + except Exception as exc: + self._last_refresh_error = exc + logger.error( + "Token refresh on 401 failed for %s: %s. " + "The original 401 response will be returned to the caller.", + self.__class__.__name__, + exc, + ) class OciSessionAuth(HttpxOciAuth): @@ -231,13 +249,13 @@ def _load_token(self, config: Mapping[str, Any]) -> str: with open(token_file, "r") as f: return f.read().strip() - def _load_private_key(self, config: Any) -> str: + def _load_private_key(self, config: Any) -> Any: """ Load private key from file specified in configuration. Args: config: OCI configuration dictionary Returns: - Private key object + Private key object (RSA/EC key from cryptography library) """ return oci.signer.load_private_key_from_file(config["key_file"]) diff --git a/tests/.env.example b/tests/.env.example new file mode 100644 index 0000000..f467b34 --- /dev/null +++ b/tests/.env.example @@ -0,0 +1,52 @@ +# OCI GenAI Auth - Integration Test Configuration +# +# Copy this file to tests/.env and fill in your values. +# These variables are used by the integration test suite. +# Without this file (or without these env vars), integration tests +# are automatically skipped -- unit tests still run. +# +# IMPORTANT: Never commit tests/.env -- it is gitignored. +# +# ── Prerequisites ──────────────────────────────────────────────────────── +# +# 1. An OCI tenancy with Generative AI enabled. +# +# 2. An OCI config profile in ~/.oci/config with either: +# - API key auth (auth_type=user_principal) or +# - Session token auth (auth_type=session, run `oci session authenticate` first) +# +# 3. A GenAI Project. Create one in the OCI Console: +# Console > Analytics & AI > Generative AI > Projects > Create Project +# Or via CLI: +# oci generative-ai generative-ai-project create \ +# --compartment-id \ +# --display-name "test-project" \ +# --profile --region us-chicago-1 +# +# 4. A model available on the /openai/v1 endpoint. +# Known working models: xai.grok-3-mini-fast, xai.grok-4-1-fast-reasoning, +# google.gemini-2.5-flash, openai.gpt-5.2 (if available on your tenancy). +# Known NOT working on /openai/v1: meta.llama-*, cohere.command-* (return 404). + +# ── OCI Authentication ─────────────────────────────────────────────────── +# Which OCI config profile to use (must exist in ~/.oci/config). +OCI_GENAI_PROFILE=DEFAULT + +# Auth type: "session" (SecurityTokenSigner) or "user_principal" (API key signer). +OCI_GENAI_AUTH_TYPE=session + +# ── OCI GenAI Project ──────────────────────────────────────────────────── +# Required. The GenAI Project OCID for the /openai/v1 endpoint. +OCI_GENAI_PROJECT_ID=ocid1.generativeaiproject.oc1.us-chicago-1.aaaaaaaaexample + +# Optional. Compartment OCID (required for /chat/completions endpoint). +OCI_GENAI_COMPARTMENT_ID=ocid1.compartment.oc1..aaaaaaaaexample + +# ── Region & Model ─────────────────────────────────────────────────────── +# OCI region where GenAI is enabled. +OCI_GENAI_REGION=us-chicago-1 + +# Model to use in integration tests. Must be available on your tenancy's +# /openai/v1 endpoint. xai.grok-3-mini-fast is recommended (fast, cheap, +# available on most tenancies). +OCI_GENAI_MODEL=xai.grok-3-mini-fast diff --git a/tests/conftest.py b/tests/conftest.py index 06fcc1e..a50085b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,22 @@ # Copyright (c) 2026 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +"""Shared fixtures for both unit and integration tests.""" + +from __future__ import annotations + import os +from pathlib import Path import pytest +# --------------------------------------------------------------------------- +# Global: disable noisy tracing from OpenAI Agents SDK +# --------------------------------------------------------------------------- + @pytest.fixture(autouse=True, scope="session") def _disable_openai_agents_tracing(): - # Prevent OpenAI Agents tracing from emitting external HTTP requests during tests. os.environ.setdefault("OPENAI_AGENTS_DISABLE_TRACING", "true") try: from agents.tracing import set_tracing_disabled @@ -17,3 +25,82 @@ def _disable_openai_agents_tracing(): return set_tracing_disabled(True) yield + + +# --------------------------------------------------------------------------- +# Integration test environment +# --------------------------------------------------------------------------- + + +def _load_env(): + """Load tests/.env if present (plain KEY=VALUE, no shell expansion).""" + env_file = Path(__file__).parent / ".env" + if not env_file.exists(): + return + for line in env_file.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + key, _, value = line.partition("=") + key, value = key.strip(), value.strip() + if key: + os.environ.setdefault(key, value) + + +_load_env() + + +# Required env vars for integration tests +_REQUIRED_VARS = ( + "OCI_GENAI_PROJECT_ID", + "OCI_GENAI_REGION", + "OCI_GENAI_MODEL", + "OCI_GENAI_PROFILE", + "OCI_GENAI_AUTH_TYPE", +) + + +def _env(var: str) -> str: + return os.environ.get(var, "") + + +def _integration_configured() -> bool: + """Return True if all required env vars are set to non-placeholder values.""" + return all(_env(v) and "example" not in _env(v).lower() for v in _REQUIRED_VARS) + + +# Marker: skip integration tests when env is not configured +requires_oci = pytest.mark.skipif( + not _integration_configured(), + reason="Integration tests require OCI_GENAI_* env vars (see tests/.env.example)", +) + + +@pytest.fixture(scope="session") +def oci_project_id(): + return _env("OCI_GENAI_PROJECT_ID") + + +@pytest.fixture(scope="session") +def oci_compartment_id(): + return _env("OCI_GENAI_COMPARTMENT_ID") + + +@pytest.fixture(scope="session") +def oci_region(): + return _env("OCI_GENAI_REGION") + + +@pytest.fixture(scope="session") +def oci_model(): + return _env("OCI_GENAI_MODEL") + + +@pytest.fixture(scope="session") +def oci_profile(): + return _env("OCI_GENAI_PROFILE") + + +@pytest.fixture(scope="session") +def oci_auth_type(): + return _env("OCI_GENAI_AUTH_TYPE") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..b38e643 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2026 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/tests/integration/test_auth_live.py b/tests/integration/test_auth_live.py new file mode 100644 index 0000000..0cbab45 --- /dev/null +++ b/tests/integration/test_auth_live.py @@ -0,0 +1,151 @@ +# Copyright (c) 2026 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Live integration tests for OCI auth classes. + +These tests make real API calls to the OCI Enterprise AI Agents endpoint +to verify that request signing works end-to-end. They are skipped +automatically when ``OCI_GENAI_*`` environment variables are not set +(see ``tests/.env.example``). + +The assertions verify that: + - The request was accepted (not 401/403/404) + - A non-empty response was returned + - OCI signing headers replaced SDK-injected headers + +They intentionally do NOT assert on response content (model output is +non-deterministic). The goal is to verify auth, not model behavior. +""" + +from __future__ import annotations + +import httpx +from openai import AsyncOpenAI, OpenAI + +from oci_genai_auth import OciSessionAuth, OciUserPrincipalAuth +from tests.conftest import requires_oci + +_ENDPOINT = "https://inference.generativeai.{region}.oci.oraclecloud.com/openai/v1" + + +def _build_auth(profile: str, auth_type: str): + """Create the appropriate auth instance based on env config.""" + if auth_type == "user_principal": + return OciUserPrincipalAuth(profile_name=profile) + return OciSessionAuth(profile_name=profile) + + +@requires_oci +class TestOciAuthSigningLive: + """Verify that OCI request signing produces valid, accepted requests.""" + + def test_sync_responses_api( + self, + oci_project_id, + oci_region, + oci_profile, + oci_auth_type, + oci_model, + ): + """Sign a sync request and verify the Responses API accepts it.""" + auth = _build_auth(oci_profile, oci_auth_type) + client = OpenAI( + base_url=_ENDPOINT.format(region=oci_region), + api_key="not-used", + project=oci_project_id, + http_client=httpx.Client(auth=auth), + ) + resp = client.responses.create( + model=oci_model, + input="What is 2+2?", + store=False, + ) + assert resp.output_text, "Expected non-empty response from Responses API" + + def test_async_responses_api( + self, + oci_project_id, + oci_region, + oci_profile, + oci_auth_type, + oci_model, + ): + """Sign an async request and verify the Responses API accepts it.""" + import asyncio + + auth = _build_auth(oci_profile, oci_auth_type) + + async def _run(): + client = AsyncOpenAI( + base_url=_ENDPOINT.format(region=oci_region), + api_key="not-used", + project=oci_project_id, + http_client=httpx.AsyncClient(auth=auth), + ) + resp = await client.responses.create( + model=oci_model, + input="What is 2+2?", + store=False, + ) + return resp.output_text + + output = asyncio.run(_run()) + assert output, "Expected non-empty response from async Responses API" + + def test_raw_httpx_request( + self, + oci_project_id, + oci_region, + oci_profile, + oci_auth_type, + oci_model, + ): + """Verify signing works at the raw httpx level (no OpenAI SDK).""" + auth = _build_auth(oci_profile, oci_auth_type) + client = httpx.Client(auth=auth) + + url = _ENDPOINT.format(region=oci_region) + "/responses" + body = { + "model": oci_model, + "input": "What is 2+2?", + "store": False, + } + headers = { + "Content-Type": "application/json", + "OpenAI-Project": oci_project_id, + } + resp = client.post(url, json=body, headers=headers, timeout=60) + + assert resp.status_code == 200, f"Expected 200, got {resp.status_code}: {resp.text[:200]}" + data = resp.json() + # Response must contain output (either output_text or output array) + assert data.get("output_text") or data.get( + "output" + ), "Expected non-empty output in response" + + def test_auth_headers_stripped( + self, + oci_project_id, + oci_region, + oci_profile, + oci_auth_type, + ): + """Verify that sdk-injected Authorization/X-Api-Key headers are + replaced by OCI signing headers, not sent alongside them.""" + auth = _build_auth(oci_profile, oci_auth_type) + request = httpx.Request( + "GET", + _ENDPOINT.format(region=oci_region) + "/responses", + headers={ + "Authorization": "Bearer should-be-stripped", + "X-Api-Key": "should-be-stripped", + }, + ) + flow = auth.auth_flow(request) + signed = next(flow) + + # OCI signing should have replaced Authorization + assert signed.headers["authorization"] != "Bearer should-be-stripped" + assert signed.headers["authorization"].startswith("Signature ") + # X-Api-Key should be gone + assert "x-api-key" not in signed.headers diff --git a/tests/test_auth.py b/tests/test_auth.py index f9e8784..10980c8 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -3,9 +3,11 @@ from __future__ import annotations +import contextlib from unittest.mock import patch import httpx +import pytest from oci_genai_auth.auth import ( HttpxOciAuth, @@ -39,6 +41,11 @@ def _refresh_signer(self) -> None: raise ConnectionError("metadata service unreachable") +# --------------------------------------------------------------------------- +# Core signing +# --------------------------------------------------------------------------- + + def test_auth_flow_signs_request(): auth = _DummyAuth(_DummySigner("signed-0")) request = httpx.Request( @@ -56,6 +63,43 @@ def test_auth_flow_signs_request(): assert signed_request.url.params.get("foo") == "bar" +def test_sign_request_strips_conflicting_headers(): + """Verify Authorization and X-Api-Key are removed before OCI signing.""" + auth = _DummyAuth(_DummySigner("oci-sig")) + request = httpx.Request( + "POST", + "https://example.com/api", + headers={ + "Authorization": "Bearer sdk-token", + "X-Api-Key": "sdk-key", + "Content-Type": "application/json", + }, + content=b'{"hello": "world"}', + ) + auth._sign_request(request, request.content, auth.signer) + assert request.headers["authorization"] == "oci-sig" + assert "x-api-key" not in request.headers + assert request.headers["content-type"] == "application/json" + + +def test_sign_request_with_body(): + """Verify signing works with POST bodies.""" + auth = _DummyAuth(_DummySigner("oci-sig")) + body = b'{"model": "grok", "input": "hello"}' + request = httpx.Request( + "POST", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/openai/v1/responses", + content=body, + ) + auth._sign_request(request, body, auth.signer) + assert request.headers["authorization"] == "oci-sig" + + +# --------------------------------------------------------------------------- +# 401 retry +# --------------------------------------------------------------------------- + + def test_auth_flow_refreshes_on_401(): auth = _DummyAuth(_DummySigner("signed-0")) request = httpx.Request("GET", "https://example.com") @@ -67,22 +111,83 @@ def test_auth_flow_refreshes_on_401(): assert retry_request.headers["authorization"] == "signed-1" +def test_auth_flow_no_retry_on_200(): + """Non-401 responses should not trigger a refresh.""" + auth = _DummyAuth(_DummySigner("signed-0")) + request = httpx.Request("GET", "https://example.com") + flow = auth.auth_flow(request) + signed_request = next(flow) + response = httpx.Response(200, request=signed_request) + with contextlib.suppress(StopIteration): + flow.send(response) + assert auth.refresh_calls == 0 + + +def test_auth_flow_401_refresh_failure_does_not_crash(caplog): + """When 401 retry refresh fails, the generator should end gracefully + and the caller receives the original 401 (not a crash).""" + auth = _BrokenRefreshAuth(_DummySigner("signed-0"), refresh_interval=99999) + request = httpx.Request("GET", "https://example.com") + + with caplog.at_level("ERROR"): + flow = auth.auth_flow(request) + signed_request = next(flow) + response = httpx.Response(401, request=signed_request) + with pytest.raises(StopIteration): + flow.send(response) + + assert auth._last_refresh_error is not None + assert any("Token refresh on 401 failed" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# Scheduled refresh +# --------------------------------------------------------------------------- + + def test_refresh_if_needed_calls_refresh_signer(): auth = _DummyAuth(_DummySigner("signed-0"), refresh_interval=0) auth._refresh_if_needed() assert auth.refresh_calls == 1 +def test_refresh_if_needed_skips_when_interval_not_reached(): + auth = _DummyAuth(_DummySigner("signed-0"), refresh_interval=99999) + auth._refresh_if_needed() + assert auth.refresh_calls == 0 + + def test_refresh_failure_does_not_break_auth_flow(caplog): auth = _BrokenRefreshAuth(_DummySigner("signed-0"), refresh_interval=0) request = httpx.Request("GET", "https://example.com") - with caplog.at_level("ERROR"): + with caplog.at_level("WARNING"): flow = auth.auth_flow(request) signed_request = next(flow) assert signed_request.headers["authorization"] == "signed-0" - assert any("Token refresh failed" in record.message for record in caplog.records) + assert any("Scheduled token refresh failed" in record.message for record in caplog.records) + + +def test_refresh_failure_tracks_last_error(): + """_last_refresh_error should be set on failure and cleared on success.""" + auth = _BrokenRefreshAuth(_DummySigner("signed-0"), refresh_interval=0) + auth._refresh_if_needed() + assert auth._last_refresh_error is not None + assert isinstance(auth._last_refresh_error, ConnectionError) + + +def test_refresh_success_clears_error(): + auth = _DummyAuth(_DummySigner("signed-0"), refresh_interval=0) + auth._last_refresh_error = ConnectionError("old error") + auth._refresh_if_needed() + assert auth._last_refresh_error is None + assert auth.refresh_calls == 1 + + +# --------------------------------------------------------------------------- +# OciSessionAuth +# --------------------------------------------------------------------------- def test_session_auth_initializes_signer_from_config(): @@ -115,6 +220,47 @@ def test_session_auth_initializes_signer_from_config(): assert auth.signer == mock_signer.return_value +def test_session_auth_missing_key_file(): + config = { + "security_token_file": "dummy.token", + "tenancy": "dummy_tenancy", + } + with ( + patch("oci.config.from_file", return_value=config), + patch("builtins.open", create=True) as mock_open, + ): + mock_open.return_value.__enter__.return_value.read.return_value = "dummy_token" + with pytest.raises(KeyError, match="key_file"): + OciSessionAuth(profile_name="DEFAULT") + + +def test_session_auth_refresh_reloads_config(): + """Verify refresh re-reads config and token files.""" + config = { + "key_file": "dummy.key", + "security_token_file": "dummy.token", + } + with ( + patch("oci.config.from_file", return_value=config) as mock_config, + patch("oci.signer.load_private_key_from_file", return_value="dummy_key"), + patch("oci.auth.signers.SecurityTokenSigner") as mock_signer, + patch("builtins.open", create=True) as mock_open, + ): + mock_open.return_value.__enter__.return_value.read.return_value = "dummy_token" + auth = OciSessionAuth(profile_name="TEST") + initial_calls = mock_config.call_count + + auth._refresh_signer() + + assert mock_config.call_count == initial_calls + 1 + assert mock_signer.call_count == 2 + + +# --------------------------------------------------------------------------- +# OciUserPrincipalAuth +# --------------------------------------------------------------------------- + + def test_user_principal_auth_uses_signer_from_config(): config = { "key_file": "dummy.key", @@ -133,6 +279,30 @@ def test_user_principal_auth_uses_signer_from_config(): assert auth.signer == mock_signer.return_value +def test_user_principal_auth_refresh_reloads_config(): + config = { + "key_file": "dummy.key", + "tenancy": "dummy_tenancy", + "user": "dummy_user", + "fingerprint": "dummy_fingerprint", + } + with ( + patch("oci.config.from_file", return_value=config) as mock_config, + patch("oci.config.validate_config", return_value=True), + patch("oci.signer.Signer") as mock_signer, + ): + auth = OciUserPrincipalAuth(profile_name="TEST") + initial_calls = mock_config.call_count + auth._refresh_signer() + assert mock_config.call_count == initial_calls + 1 + assert mock_signer.call_count == 2 + + +# --------------------------------------------------------------------------- +# OciResourcePrincipalAuth / OciInstancePrincipalAuth +# --------------------------------------------------------------------------- + + def test_resource_principal_refreshes_signer(): with patch( "oci.auth.signers.get_resource_principals_signer", return_value="signer-1"