From a594b961739d57063b1f899c29f2be7d641d792e Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Mon, 22 Jun 2026 19:21:18 +0200 Subject: [PATCH 01/26] Extract stitch-service package --- packages/stitch-service/README.md | 33 ++++++++++ packages/stitch-service/pyproject.toml | 30 +++++++++ .../src/stitch/service/__init__.py | 25 ++++++++ .../stitch-service/src/stitch/service/app.py | 64 +++++++++++++++++++ .../src/stitch/service/health.py | 45 +++++++++++++ .../src/stitch/service/middleware.py | 42 ++++++++++++ packages/stitch-service/tests/conftest.py | 6 ++ packages/stitch-service/tests/test_app.py | 52 +++++++++++++++ 8 files changed, 297 insertions(+) create mode 100644 packages/stitch-service/README.md create mode 100644 packages/stitch-service/pyproject.toml create mode 100644 packages/stitch-service/src/stitch/service/__init__.py create mode 100644 packages/stitch-service/src/stitch/service/app.py create mode 100644 packages/stitch-service/src/stitch/service/health.py create mode 100644 packages/stitch-service/src/stitch/service/middleware.py create mode 100644 packages/stitch-service/tests/conftest.py create mode 100644 packages/stitch-service/tests/test_app.py diff --git a/packages/stitch-service/README.md b/packages/stitch-service/README.md new file mode 100644 index 00000000..b876f930 --- /dev/null +++ b/packages/stitch-service/README.md @@ -0,0 +1,33 @@ +# stitch-service + +Shared FastAPI scaffolding for Stitch non-core services — the boilerplate that +`entity-linkage`, the ETL services, and `stitch-llm` otherwise each copy. + +- `create_app(...)` — app factory: sets `app.state.started_at`, registers CORS, + mounts routers under `/api/v1`, and runs service-provided startup/shutdown + hooks inside the lifespan. +- `register_cors(app, origins=...)` — the standard CORS policy. +- health helpers — `make_basic_health_router(service)` for liveness, plus + `runtime_block`/`format_started_at`/`uptime_seconds` for assembling a + service-specific `/health/details`. + +```python +from stitch.service import create_app + +def _startup(app): + validate_auth_config_at_startup() + validate_downstream_auth_config_at_startup() + +app = create_app( + routers=[health_router, start_router], + cors_origins=[str(settings.frontend_origin_url)], + on_startup=_startup, +) +``` + +## Out of scope (for now) + +- **Observability/logging** — in flight on a separate branch; will hook into the + app factory's lifespan later. +- **Auth** — each service still owns its auth wiring (settings-coupled); a future + pass may extract a configurable auth provider here. diff --git a/packages/stitch-service/pyproject.toml b/packages/stitch-service/pyproject.toml new file mode 100644 index 00000000..f48e67a3 --- /dev/null +++ b/packages/stitch-service/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "stitch-service" +version = "0.1.0" +description = "Shared FastAPI scaffolding for Stitch non-core services (app factory, health, CORS)" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "fastapi[standard-no-fastapi-cloud-cli]>=0.135.1", +] + +[build-system] +requires = ["uv_build>=0.9.30,<0.10.0"] +build-backend = "uv_build" + +[tool.uv.build-backend] +module-name = "stitch.service" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = ["-v", "--strict-markers", "--tb=short"] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-anyio>=0.0.0", + "httpx>=0.28.0", +] diff --git a/packages/stitch-service/src/stitch/service/__init__.py b/packages/stitch-service/src/stitch/service/__init__.py new file mode 100644 index 00000000..eccda2c8 --- /dev/null +++ b/packages/stitch-service/src/stitch/service/__init__.py @@ -0,0 +1,25 @@ +"""Shared FastAPI scaffolding for Stitch non-core services. + +Provides the app factory, CORS wiring, and health helpers that every service +otherwise copies. Observability and auth extraction are intentionally out of +scope for now (observability is in flight on a separate branch); the app +factory leaves lifecycle hooks open so they can be added later. +""" + +from .app import create_app +from .health import ( + format_started_at, + make_basic_health_router, + runtime_block, + uptime_seconds, +) +from .middleware import register_cors + +__all__ = [ + "create_app", + "format_started_at", + "make_basic_health_router", + "register_cors", + "runtime_block", + "uptime_seconds", +] diff --git a/packages/stitch-service/src/stitch/service/app.py b/packages/stitch-service/src/stitch/service/app.py new file mode 100644 index 00000000..1dec3d03 --- /dev/null +++ b/packages/stitch-service/src/stitch/service/app.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable, Sequence +from contextlib import asynccontextmanager +from datetime import UTC, datetime + +from fastapi import APIRouter, FastAPI + +from .middleware import register_cors + +#: A startup/shutdown hook: receives the app, may be sync or async. +LifecycleHook = Callable[[FastAPI], Awaitable[None] | None] + + +async def _maybe_await(value: Awaitable[None] | None) -> None: + if inspect.isawaitable(value): + await value + + +def create_app( + *, + title: str | None = None, + routers: Sequence[APIRouter] = (), + api_prefix: str = "/api/v1", + cors_origins: Sequence[str] = (), + on_startup: LifecycleHook | None = None, + on_shutdown: LifecycleHook | None = None, + **fastapi_kwargs: object, +) -> FastAPI: + """Build a FastAPI app with the scaffolding every non-core service repeats. + + Sets ``app.state.started_at`` for health/uptime, registers CORS, mounts the + given routers under ``api_prefix``, and runs the optional ``on_startup`` / + ``on_shutdown`` hooks inside the lifespan. + + ``on_startup`` is where a service does its own startup validation (auth / + downstream config). Keeping it a service-provided callback — rather than + baking specific validators in here — lets each service own and test that + logic. Observability wiring (deferred to a later pass) will hook in here + too, without reshaping this signature. + """ + + @asynccontextmanager + async def lifespan(app: FastAPI): + app.state.started_at = datetime.now(UTC) + if on_startup is not None: + await _maybe_await(on_startup(app)) + yield + if on_shutdown is not None: + await _maybe_await(on_shutdown(app)) + + if title is not None: + fastapi_kwargs["title"] = title + app = FastAPI(lifespan=lifespan, **fastapi_kwargs) + + register_cors(app, origins=cors_origins) + + base_router = APIRouter(prefix=api_prefix) + for router in routers: + base_router.include_router(router) + app.include_router(base_router) + + return app diff --git a/packages/stitch-service/src/stitch/service/health.py b/packages/stitch-service/src/stitch/service/health.py new file mode 100644 index 00000000..a93bcf7b --- /dev/null +++ b/packages/stitch-service/src/stitch/service/health.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from fastapi import APIRouter +from fastapi.responses import JSONResponse +from starlette.status import HTTP_200_OK + + +def format_started_at(value: object) -> str | None: + """Render an ``app.state.started_at`` value as an ISO-8601 UTC string.""" + if isinstance(value, datetime): + return value.astimezone(UTC).isoformat() + return None + + +def uptime_seconds(value: object) -> float | None: + if isinstance(value, datetime): + return round((datetime.now(UTC) - value).total_seconds(), 3) + return None + + +def runtime_block(started_at: object) -> dict[str, object]: + """The ``runtime`` sub-object shared by every service's /health/details.""" + return { + "started_at": format_started_at(started_at), + "uptime_seconds": uptime_seconds(started_at), + } + + +def make_basic_health_router(service: str) -> APIRouter: + """A liveness ``GET /health`` returning ``{"service", "status": "ok"}``. + + Readiness/dependency probes belong in a service-specific ``/health/details`` + (they differ per service); compose this for the trivial liveness check. + """ + router = APIRouter() + + @router.get("/health") + async def check_health() -> JSONResponse: + return JSONResponse( + {"service": service, "status": "ok"}, status_code=HTTP_200_OK + ) + + return router diff --git a/packages/stitch-service/src/stitch/service/middleware.py b/packages/stitch-service/src/stitch/service/middleware.py new file mode 100644 index 00000000..bf76685a --- /dev/null +++ b/packages/stitch-service/src/stitch/service/middleware.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Final + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +ALLOWED_METHODS: Final[tuple[str, ...]] = ( + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", +) + +ALLOWED_HEADERS: Final[tuple[str, ...]] = ( + "Authorization", + "Content-Type", + "Accept", + "Origin", +) + + +def register_cors( + app: FastAPI, + *, + origins: Sequence[str], + allow_credentials: bool = True, +) -> None: + """Register the standard CORS policy shared across Stitch services. + + Origins are normalised (trailing slash stripped) to match how browsers send + the ``Origin`` header. + """ + app.add_middleware( + CORSMiddleware, + allow_origins=[origin.rstrip("/") for origin in origins], + allow_credentials=allow_credentials, + allow_methods=list(ALLOWED_METHODS), + allow_headers=list(ALLOWED_HEADERS), + ) diff --git a/packages/stitch-service/tests/conftest.py b/packages/stitch-service/tests/conftest.py new file mode 100644 index 00000000..5c53fe0a --- /dev/null +++ b/packages/stitch-service/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/packages/stitch-service/tests/test_app.py b/packages/stitch-service/tests/test_app.py new file mode 100644 index 00000000..831dd382 --- /dev/null +++ b/packages/stitch-service/tests/test_app.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from fastapi import APIRouter +from fastapi.testclient import TestClient + +from stitch.service import create_app, make_basic_health_router, runtime_block + + +def test_create_app_mounts_routers_under_prefix_and_runs_startup() -> None: + events: list[str] = [] + + router = APIRouter() + + @router.get("/ping") + async def ping() -> dict[str, str]: + return {"pong": "ok"} + + def on_startup(app) -> None: + events.append("startup") + app.state.ready = True + + app = create_app( + routers=[router, make_basic_health_router("svc")], + cors_origins=["http://localhost:3000/"], + on_startup=on_startup, + ) + + with TestClient(app) as client: + assert client.get("/api/v1/ping").json() == {"pong": "ok"} + health = client.get("/api/v1/health").json() + assert health == {"service": "svc", "status": "ok"} + + assert events == ["startup"] + assert app.state.ready is True + assert app.state.started_at is not None + + +def test_async_startup_hook_is_awaited() -> None: + events: list[str] = [] + + async def on_startup(app) -> None: + events.append("async-startup") + + app = create_app(on_startup=on_startup) + with TestClient(app): + pass + assert events == ["async-startup"] + + +def test_runtime_block_shape() -> None: + block = runtime_block(None) + assert block == {"started_at": None, "uptime_seconds": None} From bc93ec58b497af100c5e1b29d19feac845b107df Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Mon, 22 Jun 2026 19:31:35 +0200 Subject: [PATCH 02/26] extract jobs manager into package --- packages/stitch-jobs/README.md | 48 +++++ packages/stitch-jobs/pyproject.toml | 31 +++ .../stitch-jobs/src/stitch/jobs/__init__.py | 33 +++ .../stitch-jobs/src/stitch/jobs/manager.py | 115 +++++++++++ .../stitch-jobs/src/stitch/jobs/models.py | 47 +++++ .../stitch-jobs/src/stitch/jobs/routers.py | 88 ++++++++ packages/stitch-jobs/src/stitch/jobs/store.py | 113 +++++++++++ .../stitch-jobs/src/stitch/jobs/uniqueness.py | 82 ++++++++ packages/stitch-jobs/tests/conftest.py | 6 + packages/stitch-jobs/tests/test_manager.py | 191 ++++++++++++++++++ packages/stitch-jobs/tests/test_router.py | 135 +++++++++++++ 11 files changed, 889 insertions(+) create mode 100644 packages/stitch-jobs/README.md create mode 100644 packages/stitch-jobs/pyproject.toml create mode 100644 packages/stitch-jobs/src/stitch/jobs/__init__.py create mode 100644 packages/stitch-jobs/src/stitch/jobs/manager.py create mode 100644 packages/stitch-jobs/src/stitch/jobs/models.py create mode 100644 packages/stitch-jobs/src/stitch/jobs/routers.py create mode 100644 packages/stitch-jobs/src/stitch/jobs/store.py create mode 100644 packages/stitch-jobs/src/stitch/jobs/uniqueness.py create mode 100644 packages/stitch-jobs/tests/conftest.py create mode 100644 packages/stitch-jobs/tests/test_manager.py create mode 100644 packages/stitch-jobs/tests/test_router.py diff --git a/packages/stitch-jobs/README.md b/packages/stitch-jobs/README.md new file mode 100644 index 00000000..da67bd6c --- /dev/null +++ b/packages/stitch-jobs/README.md @@ -0,0 +1,48 @@ +# stitch-jobs + +Shared **"FastAPI wrapper around a terminating process"** framework for Stitch +non-core services. + +A service supplies a `run_fn(params) -> result` coroutine and gets: + +- `POST /start` — launch the work in the background; returns immediately with a + `job_id` (`202`), or joins an existing matching run (`200`). +- `GET /status/{job_id}` — poll the job's state and, once finished, its result. +- `GET /jobs` — list recent runs, newest first. + +## Deduplication ("the same request across users") + +Whether two requests are "the same" is a **per-service policy**: + +- `SingletonPolicy` — one job at a time, regardless of params. +- `FingerprintPolicy(exclude={"payload_limit"})` — same job unless meaningful + params differ (here a run capped at 500 and one at 501 collapse into one). +- `CallablePolicy(fn)` / `NoDedupPolicy` — custom, or never dedupe. + +`JobManager(recent_within=...)` controls how long after a run finishes a new +identical request reuses it (so callers see results instead of re-running). + +## Usage + +```python +from stitch.jobs import JobManager, FingerprintPolicy, make_job_router + +manager = JobManager( + run_etl, # async (params) -> result + policy=FingerprintPolicy(exclude={"payload_limit"}), + recent_within=timedelta(minutes=5), +) +router = make_job_router( + manager, + start_request_model=StartRequest, + result_model=EtlResult, + dependencies=[Depends(require_permissions(SOURCE_WRITE))], + initiated_by=current_user_label, +) +``` + +## Scope + +The default `InMemoryJobStore` is single-replica and loses state on restart. +The `JobStore` protocol is the seam for a future DB-backed store; the manager +and routers are unaffected by that swap. diff --git a/packages/stitch-jobs/pyproject.toml b/packages/stitch-jobs/pyproject.toml new file mode 100644 index 00000000..5434f71b --- /dev/null +++ b/packages/stitch-jobs/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "stitch-jobs" +version = "0.1.0" +description = "Shared FastAPI job framework: start/status/results around a terminating process" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "fastapi[standard-no-fastapi-cloud-cli]>=0.135.1", + "pydantic>=2.12.5", +] + +[build-system] +requires = ["uv_build>=0.9.30,<0.10.0"] +build-backend = "uv_build" + +[tool.uv.build-backend] +module-name = "stitch.jobs" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = ["-v", "--strict-markers", "--tb=short"] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-anyio>=0.0.0", + "httpx>=0.28.0", +] diff --git a/packages/stitch-jobs/src/stitch/jobs/__init__.py b/packages/stitch-jobs/src/stitch/jobs/__init__.py new file mode 100644 index 00000000..b51f811a --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/__init__.py @@ -0,0 +1,33 @@ +"""Shared FastAPI job framework for Stitch non-core services. + +Wraps a terminating process (``run_fn(params) -> result``) with a ``/start`` +endpoint, a ``/status`` poll, and a ``/jobs`` listing, plus per-service +deduplication so a request can be observed/reused across users. +""" + +from .manager import JobManager +from .models import TERMINAL_STATES, JobRecord, JobState +from .routers import make_job_router +from .store import InMemoryJobStore, JobStore +from .uniqueness import ( + CallablePolicy, + FingerprintPolicy, + NoDedupPolicy, + SingletonPolicy, + UniquenessPolicy, +) + +__all__ = [ + "TERMINAL_STATES", + "CallablePolicy", + "FingerprintPolicy", + "InMemoryJobStore", + "JobManager", + "JobRecord", + "JobState", + "JobStore", + "NoDedupPolicy", + "SingletonPolicy", + "UniquenessPolicy", + "make_job_router", +] diff --git a/packages/stitch-jobs/src/stitch/jobs/manager.py b/packages/stitch-jobs/src/stitch/jobs/manager.py new file mode 100644 index 00000000..b0d2ed20 --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/manager.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime, timedelta +from typing import Generic +from uuid import uuid4 + +from .models import P, R, JobRecord, JobState +from .store import InMemoryJobStore, JobStore +from .uniqueness import SingletonPolicy, UniquenessPolicy + +logger = logging.getLogger("stitch.jobs") + + +def _utcnow() -> datetime: + return datetime.now(UTC) + + +class JobManager(Generic[P, R]): + """Runs a terminating process as a background job and tracks its state. + + Wraps a ``run_fn(params) -> result`` coroutine. ``start()`` launches it as + an ``asyncio.Task`` and returns immediately; the record's state transitions + ``running -> succeeded|failed`` as the task completes. Callers observe + progress via :meth:`get` / :meth:`list` (exposed over HTTP by + :func:`stitch.jobs.routers.make_job_router`). + + Deduplication is governed by the injected :class:`UniquenessPolicy`: before + starting, the manager looks for an existing run with the same key that is + still active — or finished within ``recent_within`` — and returns it instead + of starting a duplicate. That is what lets a second user observe (and reuse + the results of) a run another user already kicked off. + """ + + def __init__( + self, + run_fn: Callable[[P], Awaitable[R]], + *, + store: JobStore | None = None, + policy: UniquenessPolicy | None = None, + recent_within: timedelta = timedelta(0), + clock: Callable[[], datetime] = _utcnow, + ) -> None: + self._run_fn = run_fn + self._store: JobStore = store or InMemoryJobStore(clock=clock) + self._policy = policy or SingletonPolicy() + self._recent_within = recent_within + self._clock = clock + self._lock = asyncio.Lock() + # Hold strong refs so tasks aren't garbage-collected mid-flight. + self._tasks: set[asyncio.Task[None]] = set() + + async def start( + self, params: P, *, initiated_by: str | None = None + ) -> tuple[JobRecord[P, R], bool]: + """Start a run, or join an existing matching one. + + Returns ``(record, created)`` where ``created`` is ``False`` when an + existing active/recent run with the same dedup key was returned instead + of launching a new task. + """ + async with self._lock: + key = self._policy.key(params) + if key is not None: + existing = await self._store.find_active_or_recent( + key, recent_within=self._recent_within + ) + if existing is not None: + return existing, False + + record: JobRecord[P, R] = JobRecord( + job_id=str(uuid4()), + state=JobState.running, + dedup_key=key, + initiated_by=initiated_by, + params=params, + started_at=self._clock(), + ) + await self._store.create(record) + task = asyncio.create_task(self._run(record, params)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + return record, True + + async def _run(self, record: JobRecord[P, R], params: P) -> None: + try: + record.result = await self._run_fn(params) + record.state = JobState.succeeded + except Exception as exc: # noqa: BLE001 - captured into the record + logger.exception("job %s failed", record.job_id) + record.error = str(exc) + record.state = JobState.failed + finally: + record.finished_at = self._clock() + + def reset(self) -> None: + """Cancel in-flight tasks and drop all run state. + + For tests that share a module-level manager; not part of the request + flow. + """ + for task in self._tasks: + task.cancel() + self._tasks.clear() + clear = getattr(self._store, "clear", None) + if callable(clear): + clear() + + async def get(self, job_id: str) -> JobRecord[P, R] | None: + return await self._store.get(job_id) + + async def list(self, *, limit: int | None = None) -> list[JobRecord[P, R]]: + return await self._store.list(limit=limit) diff --git a/packages/stitch-jobs/src/stitch/jobs/models.py b/packages/stitch-jobs/src/stitch/jobs/models.py new file mode 100644 index 00000000..7ea7e83c --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/models.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Generic, TypeVar + +from pydantic import BaseModel + +P = TypeVar("P", bound=BaseModel) +R = TypeVar("R", bound=BaseModel) + + +class JobState(str, Enum): + running = "running" + succeeded = "succeeded" + failed = "failed" + + +#: States a job can no longer leave. +TERMINAL_STATES: frozenset[JobState] = frozenset({JobState.succeeded, JobState.failed}) + + +class JobRecord(BaseModel, Generic[P, R]): + """The full, observable state of a single job run. + + Generic over the per-service ``params`` and ``result`` Pydantic models so a + service gets typed request params and typed results in its OpenAPI schema. + + Records are mutated in place by :class:`~stitch.jobs.manager.JobManager` as + the run progresses (``state``/``result``/``error``/``finished_at``). + """ + + job_id: str + state: JobState + #: Per-service uniqueness key; ``None`` when the job is not deduplicated. + dedup_key: str | None = None + #: Human label of the user who first started the run (best-effort). + initiated_by: str | None = None + params: P + started_at: datetime + finished_at: datetime | None = None + result: R | None = None + error: str | None = None + + @property + def is_terminal(self) -> bool: + return self.state in TERMINAL_STATES diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py new file mode 100644 index 00000000..2a6d0971 --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -0,0 +1,88 @@ +# NOTE: deliberately no `from __future__ import annotations` here. The /start +# endpoint is generated with the caller-supplied request model as a real +# annotation object; stringized annotations would break FastAPI's body parsing. + +import logging +from collections.abc import Awaitable, Callable, Sequence +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query, Response +from pydantic import BaseModel +from starlette.status import HTTP_200_OK, HTTP_202_ACCEPTED, HTTP_404_NOT_FOUND + +from .manager import JobManager +from .models import JobRecord + +logger = logging.getLogger("stitch.jobs") + + +def make_job_router( + manager: JobManager, + *, + start_request_model: type[BaseModel], + result_model: type[BaseModel], + params_model: type[BaseModel] | None = None, + to_params: Callable[[Any], BaseModel] | None = None, + dependencies: Sequence[Any] = (), + initiated_by: Callable[..., Awaitable[str | None] | str | None] | None = None, + tags: Sequence[str] | None = None, + default_list_limit: int = 20, +) -> APIRouter: + """Build a reusable ``/start`` + ``/status`` + ``/jobs`` router for a job. + + ``start_request_model`` is the POST body; ``result_model`` is what + ``run_fn`` returns. By default the request body *is* the params; pass + ``params_model`` + ``to_params`` when the stored params differ from the + wire request. ``dependencies`` is where the service plugs in its permission + gate (e.g. ``[Depends(require_permissions(...))]``); ``initiated_by`` is an + optional dependency returning the caller's display label. + """ + params_model = params_model or start_request_model + to_params = to_params or (lambda request: request) + resolve_initiated_by = initiated_by or (lambda: None) + + record_model = JobRecord[params_model, result_model] # type: ignore[valid-type] + + router = APIRouter(tags=list(tags) if tags else None) + + @router.post( + "/start", + status_code=HTTP_202_ACCEPTED, + response_model=record_model, + dependencies=list(dependencies), + ) + async def start( + request: start_request_model, # type: ignore[valid-type] + response: Response, + initiated_by_label: Any = Depends(resolve_initiated_by), + ): + """Start the job, or join an existing matching run. + + Returns ``202`` with a fresh record, or ``200`` with the existing record + when a recent/active run with the same dedup key is found (so a second + caller observes that run rather than starting a duplicate). + """ + params = to_params(request) + record, created = await manager.start(params, initiated_by=initiated_by_label) + if not created: + response.status_code = HTTP_200_OK + return record + + @router.get("/status/{job_id}", response_model=record_model) + async def status(job_id: str): + record = await manager.get(job_id) + if record is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"No job found with id {job_id}.", + ) + return record + + @router.get("/jobs", response_model=list[record_model]) + async def jobs( + limit: int = Query(default=default_list_limit, ge=1, le=200), + ): + """List recent jobs, newest first — for discovering an in-flight run.""" + return await manager.list(limit=limit) + + return router diff --git a/packages/stitch-jobs/src/stitch/jobs/store.py b/packages/stitch-jobs/src/stitch/jobs/store.py new file mode 100644 index 00000000..082e989f --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/store.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from collections.abc import Callable +from datetime import UTC, datetime, timedelta +from typing import Protocol + +from .models import JobRecord, JobState + + +def _utcnow() -> datetime: + return datetime.now(UTC) + + +class JobStore(Protocol): + """Persistence seam for job records. + + The in-memory implementation below is sufficient for a single replica. A + DB-backed store (surviving restarts and shared across replicas) can be + dropped in later behind this same interface without touching the manager or + routers. + """ + + async def create(self, record: JobRecord) -> None: ... + + async def get(self, job_id: str) -> JobRecord | None: ... + + async def find_active_or_recent( + self, dedup_key: str, *, recent_within: timedelta + ) -> JobRecord | None: ... + + async def list(self, *, limit: int | None = None) -> list[JobRecord]: ... + + def clear(self) -> None: ... + + +class InMemoryJobStore: + """Process-local job store backed by a dict. + + Completed records are retained so a just-finished run is still discoverable + (for cross-user result reuse and ``GET /status``), then evicted once older + than ``retention``. State is lost on restart and is not shared across + replicas — acceptable for the current single-replica deployments. + """ + + def __init__( + self, + *, + retention: timedelta | None = timedelta(hours=1), + clock: Callable[[], datetime] = _utcnow, + ) -> None: + self._records: dict[str, JobRecord] = {} + self._retention = retention + self._clock = clock + + def _evict_expired(self) -> None: + if self._retention is None: + return + cutoff = self._clock() - self._retention + stale = [ + job_id + for job_id, record in self._records.items() + if record.finished_at is not None and record.finished_at < cutoff + ] + for job_id in stale: + del self._records[job_id] + + async def create(self, record: JobRecord) -> None: + self._evict_expired() + self._records[record.job_id] = record + + async def get(self, job_id: str) -> JobRecord | None: + self._evict_expired() + return self._records.get(job_id) + + async def find_active_or_recent( + self, dedup_key: str, *, recent_within: timedelta + ) -> JobRecord | None: + """Return the newest matching job that is still running, or that + finished within ``recent_within``. Newest-first so callers join/observe + the most relevant run. + """ + self._evict_expired() + now = self._clock() + candidates = [ + record + for record in self._records.values() + if record.dedup_key == dedup_key + and ( + record.state == JobState.running + or ( + record.finished_at is not None + and now - record.finished_at <= recent_within + ) + ) + ] + if not candidates: + return None + return max(candidates, key=lambda record: record.started_at) + + def clear(self) -> None: + """Drop all records. For tests; not part of the request flow.""" + self._records.clear() + + async def list(self, *, limit: int | None = None) -> list[JobRecord]: + self._evict_expired() + records = sorted( + self._records.values(), + key=lambda record: record.started_at, + reverse=True, + ) + if limit is not None: + records = records[:limit] + return records diff --git a/packages/stitch-jobs/src/stitch/jobs/uniqueness.py b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py new file mode 100644 index 00000000..976f2476 --- /dev/null +++ b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import hashlib +import json +from collections.abc import Callable, Iterable +from typing import Protocol, runtime_checkable + +from pydantic import BaseModel + + +@runtime_checkable +class UniquenessPolicy(Protocol): + """Decides whether two requests are "the same" job. + + ``key(params)`` returns a stable string for params that should collapse to a + single shared run, or ``None`` to opt that request out of deduplication + entirely (always start a fresh job). + """ + + def key(self, params: BaseModel) -> str | None: ... + + +class SingletonPolicy: + """One job at a time, regardless of params. + + Every request maps to the same key, so while a run is active (or recently + completed, within the manager's window) a second caller joins it instead of + starting another. Use for services that must never run two jobs at once. + """ + + def __init__(self, key: str = "singleton") -> None: + self._key = key + + def key(self, params: BaseModel) -> str | None: # noqa: ARG002 - params ignored by design + return self._key + + +class FingerprintPolicy: + """Deduplicate by a hash of (a subset of) the request params. + + By default every field participates, so only byte-identical requests + collapse. Narrow the key with ``include`` (allowlist) or widen what counts + as "the same" with ``exclude`` (drop noisy/irrelevant fields). For example a + GEM ETL can ``exclude={"payload_limit"}`` so a run capped at 500 and one + capped at 501 are treated as the same job. + """ + + def __init__( + self, + *, + include: Iterable[str] | None = None, + exclude: Iterable[str] = (), + ) -> None: + self._include = set(include) if include is not None else None + self._exclude = set(exclude) + + def key(self, params: BaseModel) -> str | None: + data = params.model_dump(mode="json") + if self._include is not None: + data = {k: v for k, v in data.items() if k in self._include} + if self._exclude: + data = {k: v for k, v in data.items() if k not in self._exclude} + blob = json.dumps(data, sort_keys=True, separators=(",", ":")) + digest = hashlib.sha256(blob.encode("utf-8")).hexdigest() + return f"{type(params).__name__}:{digest}" + + +class CallablePolicy: + """Adapt an arbitrary ``params -> key`` function into a policy.""" + + def __init__(self, fn: Callable[[BaseModel], str | None]) -> None: + self._fn = fn + + def key(self, params: BaseModel) -> str | None: + return self._fn(params) + + +class NoDedupPolicy: + """Never deduplicate: every request starts a new job.""" + + def key(self, params: BaseModel) -> str | None: # noqa: ARG002 + return None diff --git a/packages/stitch-jobs/tests/conftest.py b/packages/stitch-jobs/tests/conftest.py new file mode 100644 index 00000000..5c53fe0a --- /dev/null +++ b/packages/stitch-jobs/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/packages/stitch-jobs/tests/test_manager.py b/packages/stitch-jobs/tests/test_manager.py new file mode 100644 index 00000000..b0a33203 --- /dev/null +++ b/packages/stitch-jobs/tests/test_manager.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime, timedelta + +import pytest +from pydantic import BaseModel + +from stitch.jobs import ( + FingerprintPolicy, + InMemoryJobStore, + JobManager, + JobState, + SingletonPolicy, +) + + +class Params(BaseModel): + name: str + payload_limit: int | None = None + + +class Result(BaseModel): + value: int + + +async def _wait_until_terminal(manager: JobManager, job_id: str, *, timeout=2.0): + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + record = await manager.get(job_id) + if record is not None and record.is_terminal: + return record + await asyncio.sleep(0.005) + raise AssertionError("job did not reach a terminal state in time") + + +@pytest.mark.anyio +async def test_start_runs_and_succeeds() -> None: + async def run(params: Params) -> Result: + return Result(value=len(params.name)) + + manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) + record, created = await manager.start(Params(name="alpha"), initiated_by="Tester") + + assert created is True + assert record.state == JobState.running + assert record.initiated_by == "Tester" + + final = await _wait_until_terminal(manager, record.job_id) + assert final.state == JobState.succeeded + assert final.result == Result(value=5) + assert final.error is None + assert final.finished_at is not None + + +@pytest.mark.anyio +async def test_failure_is_captured_in_record() -> None: + async def run(params: Params) -> Result: + raise RuntimeError("boom") + + manager: JobManager[Params, Result] = JobManager(run) + record, _ = await manager.start(Params(name="x")) + + final = await _wait_until_terminal(manager, record.job_id) + assert final.state == JobState.failed + assert final.error == "boom" + assert final.result is None + + +@pytest.mark.anyio +async def test_singleton_joins_active_run() -> None: + release = asyncio.Event() + + async def run(params: Params) -> Result: + await release.wait() + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) + first, first_created = await manager.start(Params(name="a")) + second, second_created = await manager.start(Params(name="b")) + + assert first_created is True + # Different params, but singleton policy → same active job is returned. + assert second_created is False + assert second.job_id == first.job_id + + release.set() + await _wait_until_terminal(manager, first.job_id) + + +@pytest.mark.anyio +async def test_fingerprint_splits_by_params() -> None: + release = asyncio.Event() + + async def run(params: Params) -> Result: + await release.wait() + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager(run, policy=FingerprintPolicy()) + a, a_created = await manager.start(Params(name="a")) + b, b_created = await manager.start(Params(name="b")) + a_again, a_again_created = await manager.start(Params(name="a")) + + assert a_created and b_created + assert a.job_id != b.job_id # different params → independent jobs + assert a_again_created is False # identical params → joins the active 'a' run + assert a_again.job_id == a.job_id + + release.set() + await _wait_until_terminal(manager, a.job_id) + await _wait_until_terminal(manager, b.job_id) + + +@pytest.mark.anyio +async def test_fingerprint_exclude_collapses_ignored_fields() -> None: + release = asyncio.Event() + + async def run(params: Params) -> Result: + await release.wait() + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager( + run, policy=FingerprintPolicy(exclude={"payload_limit"}) + ) + first, first_created = await manager.start(Params(name="gem", payload_limit=500)) + second, second_created = await manager.start(Params(name="gem", payload_limit=501)) + + # payload_limit excluded from the key → 500 and 501 are "the same" job. + assert first_created is True + assert second_created is False + assert second.job_id == first.job_id + + release.set() + await _wait_until_terminal(manager, first.job_id) + + +@pytest.mark.anyio +async def test_recent_completed_run_is_reused_within_window() -> None: + now = {"t": datetime(2026, 1, 1, tzinfo=UTC)} + + def clock() -> datetime: + return now["t"] + + async def run(params: Params) -> Result: + return Result(value=1) + + store = InMemoryJobStore(clock=clock, retention=timedelta(hours=1)) + manager: JobManager[Params, Result] = JobManager( + run, + store=store, + policy=FingerprintPolicy(), + recent_within=timedelta(minutes=5), + clock=clock, + ) + + first, _ = await manager.start(Params(name="a")) + final = await _wait_until_terminal(manager, first.job_id) + assert final.state == JobState.succeeded + + # Two minutes later: identical request reuses the just-finished run. + now["t"] = now["t"] + timedelta(minutes=2) + reused, created = await manager.start(Params(name="a")) + assert created is False + assert reused.job_id == first.job_id + + # Ten minutes after that: outside the window → a fresh run starts. + now["t"] = now["t"] + timedelta(minutes=10) + fresh, created = await manager.start(Params(name="a")) + assert created is True + assert fresh.job_id != first.job_id + await _wait_until_terminal(manager, fresh.job_id) + + +@pytest.mark.anyio +async def test_terminal_records_evicted_after_retention() -> None: + now = {"t": datetime(2026, 1, 1, tzinfo=UTC)} + + def clock() -> datetime: + return now["t"] + + async def run(params: Params) -> Result: + return Result(value=1) + + store = InMemoryJobStore(clock=clock, retention=timedelta(minutes=30)) + manager: JobManager[Params, Result] = JobManager(run, store=store, clock=clock) + + record, _ = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, record.job_id) + + now["t"] = now["t"] + timedelta(hours=1) + assert await manager.get(record.job_id) is None diff --git a/packages/stitch-jobs/tests/test_router.py b/packages/stitch-jobs/tests/test_router.py new file mode 100644 index 00000000..b86459cf --- /dev/null +++ b/packages/stitch-jobs/tests/test_router.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import asyncio +import time + +from fastapi import Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.status import HTTP_403_FORBIDDEN + +from stitch.jobs import FingerprintPolicy, JobManager, SingletonPolicy, make_job_router + + +class StartRequest(BaseModel): + name: str + + +class Result(BaseModel): + value: int + + +def _poll(client: TestClient, job_id: str, *, timeout: float = 5.0) -> dict: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + body = client.get(f"/api/v1/status/{job_id}").json() + if body["state"] != "running": + return body + time.sleep(0.02) + raise AssertionError("job did not finish in time") + + +def build_app(manager: JobManager, **kwargs) -> FastAPI: + app = FastAPI() + router = make_job_router( + manager, + start_request_model=StartRequest, + result_model=Result, + **kwargs, + ) + app.include_router(router, prefix="/api/v1") + return app + + +def test_start_returns_202_and_status_succeeds() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=len(params.name)) + + app = build_app( + JobManager(run, policy=SingletonPolicy()), initiated_by=lambda: "Tester" + ) + + with TestClient(app) as client: + response = client.post("/api/v1/start", json={"name": "alpha"}) + assert response.status_code == 202 + body = response.json() + assert body["state"] == "running" + assert body["initiated_by"] == "Tester" + + final = _poll(client, body["job_id"]) + assert final["state"] == "succeeded" + assert final["result"] == {"value": 5} + assert final["params"] == {"name": "alpha"} + + +def test_second_caller_joins_existing_run_with_200() -> None: + async def slow_run(params: StartRequest) -> Result: + await asyncio.sleep(0.3) + return Result(value=1) + + app = build_app(JobManager(slow_run, policy=SingletonPolicy())) + + with TestClient(app) as client: + first = client.post("/api/v1/start", json={"name": "a"}) + assert first.status_code == 202 + + # Different user/params, singleton policy → joins the active run (200). + second = client.post("/api/v1/start", json={"name": "b"}) + assert second.status_code == 200 + assert second.json()["job_id"] == first.json()["job_id"] + + _poll(client, first.json()["job_id"]) + + +def test_fingerprint_policy_allows_distinct_jobs() -> None: + async def slow_run(params: StartRequest) -> Result: + await asyncio.sleep(0.3) + return Result(value=1) + + app = build_app(JobManager(slow_run, policy=FingerprintPolicy())) + + with TestClient(app) as client: + a = client.post("/api/v1/start", json={"name": "a"}) + b = client.post("/api/v1/start", json={"name": "b"}) + assert a.status_code == 202 and b.status_code == 202 + assert a.json()["job_id"] != b.json()["job_id"] + + _poll(client, a.json()["job_id"]) + _poll(client, b.json()["job_id"]) + + +def test_status_404_for_unknown_job() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=1) + + app = build_app(JobManager(run)) + with TestClient(app) as client: + assert client.get("/api/v1/status/does-not-exist").status_code == 404 + + +def test_jobs_listing_returns_recent_runs() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=1) + + app = build_app(JobManager(run, policy=FingerprintPolicy())) + + with TestClient(app) as client: + first = client.post("/api/v1/start", json={"name": "a"}) + _poll(client, first.json()["job_id"]) + second = client.post("/api/v1/start", json={"name": "b"}) + _poll(client, second.json()["job_id"]) + + listed = client.get("/api/v1/jobs").json() + assert {job["params"]["name"] for job in listed} == {"a", "b"} + + +def test_dependencies_gate_start() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=1) + + def forbid() -> None: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="nope") + + app = build_app(JobManager(run), dependencies=[Depends(forbid)]) + with TestClient(app) as client: + assert client.post("/api/v1/start", json={"name": "a"}).status_code == 403 From add454eb5ac426194b5cf34755b1f86ab2535549 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Mon, 22 Jun 2026 19:32:47 +0200 Subject: [PATCH 03/26] update repo infrastructure --- Makefile | 22 +++++++++++++++--- pyproject.toml | 2 ++ uv.lock | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 7daaf3a8..c460bd96 100644 --- a/Makefile +++ b/Makefile @@ -105,9 +105,23 @@ pkg-test-ogsi: pkg-test-exact-ogsi: $(MAKE) uv-test-target-exact PKG=stitch-ogsi TEST_PATH=packages/stitch-ogsi -pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi -pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi -pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi +pkg-build-service: + $(UV) build --package stitch-service +pkg-test-service: + $(MAKE) uv-test-target PKG=stitch-service TEST_PATH=packages/stitch-service +pkg-test-exact-service: + $(MAKE) uv-test-target-exact PKG=stitch-service TEST_PATH=packages/stitch-service + +pkg-build-jobs: + $(UV) build --package stitch-jobs +pkg-test-jobs: + $(MAKE) uv-test-target PKG=stitch-jobs TEST_PATH=packages/stitch-jobs +pkg-test-exact-jobs: + $(MAKE) uv-test-target-exact PKG=stitch-jobs TEST_PATH=packages/stitch-jobs + +pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi pkg-build-service pkg-build-jobs +pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi pkg-test-service pkg-test-jobs +pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi pkg-test-exact-service pkg-test-exact-jobs # --------------------------------------------------------------------- # Deployments @@ -282,6 +296,8 @@ follow-stack-logs: pkg-build-client pkg-test-client pkg-test-exact-client \ pkg-build-models pkg-test-models pkg-test-exact-models \ pkg-build-ogsi pkg-test-ogsi pkg-test-exact-ogsi \ + pkg-build-service pkg-test-service pkg-test-exact-service \ + pkg-build-jobs pkg-test-jobs pkg-test-exact-jobs \ \ # API api-build api-test api-test-exact api-dev stack-api-dev \ diff --git a/pyproject.toml b/pyproject.toml index 4b0a94d7..1da6fff7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,10 @@ members = [ "deployments/stitch-llm", "packages/stitch-auth", "packages/stitch-client", + "packages/stitch-jobs", "packages/stitch-models", "packages/stitch-ogsi", + "packages/stitch-service", ] [tool.uv.sources] diff --git a/uv.lock b/uv.lock index 1c362a4c..049eaed0 100644 --- a/uv.lock +++ b/uv.lock @@ -9,10 +9,12 @@ members = [ "stitch-auth", "stitch-client", "stitch-entity-linkage", + "stitch-jobs", "stitch-llm", "stitch-models", "stitch-ogsi", "stitch-seed", + "stitch-service", ] [[package]] @@ -1112,8 +1114,10 @@ dependencies = [ { name = "pydantic-settings" }, { name = "stitch-auth" }, { name = "stitch-client" }, + { name = "stitch-jobs" }, { name = "stitch-models" }, { name = "stitch-ogsi" }, + { name = "stitch-service" }, ] [package.dev-dependencies] @@ -1132,8 +1136,10 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, { name = "stitch-client", editable = "packages/stitch-client" }, + { name = "stitch-jobs", editable = "packages/stitch-jobs" }, { name = "stitch-models", editable = "packages/stitch-models" }, { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, + { name = "stitch-service", editable = "packages/stitch-service" }, ] [package.metadata.requires-dev] @@ -1145,6 +1151,35 @@ dev = [ { name = "pytest-anyio", specifier = ">=0.0.0" }, ] +[[package]] +name = "stitch-jobs" +version = "0.1.0" +source = { editable = "packages/stitch-jobs" } +dependencies = [ + { name = "fastapi", extra = ["standard-no-fastapi-cloud-cli"] }, + { name = "pydantic" }, +] + +[package.dev-dependencies] +dev = [ + { name = "httpx" }, + { name = "pytest" }, + { name = "pytest-anyio" }, +] + +[package.metadata] +requires-dist = [ + { name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }, + { name = "pydantic", specifier = ">=2.12.5" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "httpx", specifier = ">=0.28.0" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-anyio", specifier = ">=0.0.0" }, +] + [[package]] name = "stitch-llm" version = "0.1.0" @@ -1257,6 +1292,31 @@ dev = [ { name = "pytest-anyio", specifier = ">=0.0.0" }, ] +[[package]] +name = "stitch-service" +version = "0.1.0" +source = { editable = "packages/stitch-service" } +dependencies = [ + { name = "fastapi", extra = ["standard-no-fastapi-cloud-cli"] }, +] + +[package.dev-dependencies] +dev = [ + { name = "httpx" }, + { name = "pytest" }, + { name = "pytest-anyio" }, +] + +[package.metadata] +requires-dist = [{ name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }] + +[package.metadata.requires-dev] +dev = [ + { name = "httpx", specifier = ">=0.28.0" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-anyio", specifier = ">=0.0.0" }, +] + [[package]] name = "typer" version = "0.25.1" From 395e2ba2e80da20b7b9fe2600acfdb9f695237f5 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Mon, 22 Jun 2026 20:06:52 +0200 Subject: [PATCH 04/26] Update entity linkage to use new packages --- deployments/entity-linkage/pyproject.toml | 4 + .../src/stitch/entity_linkage/linkage.py | 140 ++++++++++++++ .../src/stitch/entity_linkage/main.py | 26 +-- .../stitch/entity_linkage/routers/start.py | 175 +++--------------- deployments/entity-linkage/tests/conftest.py | 6 + .../entity-linkage/tests/test_start.py | 111 ++++------- .../entity-linkage/tests/test_start_api.py | 165 ++++++++++------- 7 files changed, 315 insertions(+), 312 deletions(-) create mode 100644 deployments/entity-linkage/src/stitch/entity_linkage/linkage.py create mode 100644 deployments/entity-linkage/tests/conftest.py diff --git a/deployments/entity-linkage/pyproject.toml b/deployments/entity-linkage/pyproject.toml index ca5936a2..a043f2e3 100644 --- a/deployments/entity-linkage/pyproject.toml +++ b/deployments/entity-linkage/pyproject.toml @@ -11,8 +11,10 @@ dependencies = [ "pydantic-settings>=2.12.0", "stitch-auth", "stitch-client", + "stitch-jobs", "stitch-models", "stitch-ogsi", + "stitch-service", ] [build-system] @@ -41,5 +43,7 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-client = { workspace = true } +stitch-jobs = { workspace = true } stitch-models = { workspace = true } stitch-ogsi = { workspace = true } +stitch-service = { workspace = true } diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py b/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py new file mode 100644 index 00000000..16a52bd6 --- /dev/null +++ b/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from collections import defaultdict + +from pydantic import BaseModel, Field + +from stitch.entity_linkage.client import StitchApiClient +from stitch.entity_linkage.entities import FieldCandidate, MatchGroup + + +class LinkageParams(BaseModel): + """Tunable inputs for one entity-linkage pass. + + Doubles as the ``POST /start`` request body and the params stored on the + job record. + """ + + apply_merges: bool = Field( + default=False, + description=( + "When true, submit confirmed match groups to the Stitch API as " + "merge candidates." + ), + ) + page: int = Field(default=1, ge=1) + page_size: int = Field(default=50, ge=1, le=200) + max_pages: int | None = Field( + default=None, + ge=1, + le=1000, + description="Optional cap on pages fetched. Null means fetch all pages.", + ) + + +class LinkageResult(BaseModel): + """Summary of a completed entity-linkage pass.""" + + pages_fetched: int + total_records_fetched: int + duplicate_name_candidate_count: int + detail_records_fetched: int + match_groups: list[list[int]] + merge_results: list[dict] + + +def _group_duplicate_names( + items: list[FieldCandidate], +) -> dict[str, list[FieldCandidate]]: + grouped: dict[str, list[FieldCandidate]] = defaultdict(list) + for item in items: + if item.normalized_name is None: + continue + grouped[item.normalized_name].append(item) + return { + normalized_name: grouped_items + for normalized_name, grouped_items in grouped.items() + if len(grouped_items) > 1 + } + + +def _normalize_country(country: str | None) -> str | None: + if country is None: + return None + normalized = country.strip().upper() + return normalized or None + + +async def _resolve_match_groups( + client: StitchApiClient, + duplicate_groups: dict[str, list[FieldCandidate]], +) -> tuple[list[MatchGroup], int]: + match_groups: list[MatchGroup] = [] + detail_records_fetched = 0 + + for normalized_name, candidates in duplicate_groups.items(): + by_country: dict[str, list[int]] = defaultdict(list) + + for candidate in candidates: + detail = await client.get_oil_gas_field_detail(candidate.id) + detail_records_fetched += 1 + normalized_country = _normalize_country(detail.country) + if normalized_country is None: + continue + by_country[normalized_country].append(detail.id) + + for country, ids in by_country.items(): + if len(ids) > 1: + match_groups.append( + MatchGroup( + ids=sorted(ids), + normalized_name=normalized_name, + country=country, + ) + ) + + return match_groups, detail_records_fetched + + +async def run_linkage(params: LinkageParams) -> LinkageResult: + """Run one entity-linkage pass and return a summary. + + - fetch paginated oil-gas-fields list + - group exact case-insensitive duplicate names + - fetch detail records for candidate duplicates + - confirm same-country matches + - optionally submit merge candidates + + Invoked as the background job body by the ``JobManager`` in + :mod:`stitch.entity_linkage.routers.start`. Downstream failures + (``StitchAPIError``) propagate and are captured as a failed job, observable + via ``GET /status/{job_id}``. + """ + async with StitchApiClient() as client: + items, pages_fetched = await client.collect_oil_gas_fields( + start_page=params.page, + page_size=params.page_size, + max_pages=params.max_pages, + ) + duplicate_groups = _group_duplicate_names(items) + match_groups, detail_records_fetched = await _resolve_match_groups( + client=client, + duplicate_groups=duplicate_groups, + ) + + merge_results: list[dict] = [] + if params.apply_merges: + for group in match_groups: + response = await client.create_merge_candidate(resource_ids=group.ids) + merge_results.append({"ids": group.ids, "response": response}) + + return LinkageResult( + pages_fetched=pages_fetched, + total_records_fetched=len(items), + duplicate_name_candidate_count=sum( + len(group) for group in duplicate_groups.values() + ), + detail_records_fetched=detail_records_fetched, + match_groups=[group.ids for group in match_groups], + merge_results=merge_results, + ) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/main.py b/deployments/entity-linkage/src/stitch/entity_linkage/main.py index ab95c02f..a4f6c9dd 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/main.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/main.py @@ -1,7 +1,6 @@ -from contextlib import asynccontextmanager -from datetime import UTC, datetime -from fastapi import APIRouter, FastAPI -from .middleware import register_middlewares +from fastapi import FastAPI +from stitch.service import create_app + from .auth import validate_auth_config_at_startup from .client import validate_downstream_auth_config_at_startup from .settings import get_settings @@ -9,27 +8,20 @@ from .routers.health import router as health_router from .routers.start import router as start_router -base_router = APIRouter(prefix="/api/v1") -base_router.include_router(health_router) -base_router.include_router(start_router) - -@asynccontextmanager -async def lifespan(app: FastAPI): - app.state.started_at = datetime.now(UTC) +def _run_startup(app: FastAPI) -> None: app.state.auth_config_validated = False app.state.downstream_auth_config_validated = False validate_auth_config_at_startup() app.state.auth_config_validated = True validate_downstream_auth_config_at_startup() app.state.downstream_auth_config_validated = True - yield -app = FastAPI(lifespan=lifespan) - settings = get_settings() -register_middlewares(application=app, settings=settings) - -app.include_router(base_router) +app = create_app( + routers=[health_router, start_router], + cors_origins=[str(settings.frontend_origin_url)], + on_startup=_run_startup, +) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py index 7f284bfb..f287b345 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py @@ -1,168 +1,43 @@ from __future__ import annotations -from collections import defaultdict +from datetime import timedelta -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, Field -from starlette.status import HTTP_502_BAD_GATEWAY +from fastapi import Depends from stitch.auth.permissions import SERVICE_ENTITY_LINKAGE_RUN +from stitch.jobs import FingerprintPolicy, JobManager, make_job_router from stitch.entity_linkage.auth import AuthContext, require_permissions -from stitch.entity_linkage.client import StitchApiClient -from stitch.entity_linkage.entities import FieldCandidate, MatchGroup, User -from stitch.entity_linkage.errors import StitchAPIError - -router = APIRouter(tags=["entity-linkage"]) - - -class StartRequest(BaseModel): - apply_merges: bool = Field( - default=False, - description=( - "When true, submit confirmed match groups to the Stitch API as " - "merge candidates." - ), - ) - page: int = Field(default=1, ge=1) - page_size: int = Field(default=50, ge=1, le=200) - max_pages: int | None = Field( - default=None, - ge=1, - le=1000, - description="Optional cap on pages fetched. Null means fetch all pages.", - ) +from stitch.entity_linkage.entities import User +from stitch.entity_linkage.linkage import LinkageParams, LinkageResult, run_linkage + +# Two requests are "the same" run when all tunable params match. Identical +# requests (same paging + apply_merges) collapse onto one job — so a second +# user sees the in-flight run, and reuses its result for `recent_within` after +# it finishes — while different params run independently. +_manager: JobManager[LinkageParams, LinkageResult] = JobManager( + run_linkage, + policy=FingerprintPolicy(), + recent_within=timedelta(minutes=5), +) -class StartResponse(BaseModel): - initiated_by: str - apply_merges: bool - pages_fetched: int - total_records_fetched: int - duplicate_name_candidate_count: int - detail_records_fetched: int - match_groups: list[list[int]] - merge_results: list[dict] +def get_job_manager() -> JobManager[LinkageParams, LinkageResult]: + return _manager def _extract_user_label(user: User) -> str: return user.name or user.email or user.sub -def _group_duplicate_names( - items: list[FieldCandidate], -) -> dict[str, list[FieldCandidate]]: - grouped: dict[str, list[FieldCandidate]] = defaultdict(list) - for item in items: - if item.normalized_name is None: - continue - grouped[item.normalized_name].append(item) - return { - normalized_name: grouped_items - for normalized_name, grouped_items in grouped.items() - if len(grouped_items) > 1 - } - - -def _normalize_country(country: str | None) -> str | None: - if country is None: - return None - normalized = country.strip().upper() - return normalized or None - - -async def _resolve_match_groups( - client: StitchApiClient, - duplicate_groups: dict[str, list[FieldCandidate]], -) -> tuple[list[MatchGroup], int]: - match_groups: list[MatchGroup] = [] - detail_records_fetched = 0 +async def initiated_by(auth_context: AuthContext) -> str: + return _extract_user_label(auth_context.user) - for normalized_name, candidates in duplicate_groups.items(): - by_country: dict[str, list[int]] = defaultdict(list) - for candidate in candidates: - detail = await client.get_oil_gas_field_detail(candidate.id) - detail_records_fetched += 1 - normalized_country = _normalize_country(detail.country) - if normalized_country is None: - continue - by_country[normalized_country].append(detail.id) - - for country, ids in by_country.items(): - if len(ids) > 1: - match_groups.append( - MatchGroup( - ids=sorted(ids), - normalized_name=normalized_name, - country=country, - ) - ) - - return match_groups, detail_records_fetched - - -@router.post( - "/start", - response_model=StartResponse, +router = make_job_router( + _manager, + start_request_model=LinkageParams, + result_model=LinkageResult, dependencies=[Depends(require_permissions(SERVICE_ENTITY_LINKAGE_RUN))], + initiated_by=initiated_by, + tags=["entity-linkage"], ) -async def start( - request: StartRequest, - auth_context: AuthContext, -) -> StartResponse: - """ - In-memory entity-linkage pass: - - fetch paginated oil-gas-fields list - - group exact case-insensitive duplicate names - - fetch detail records for candidate duplicates - - confirm same-country matches - - optionally submit merge candidates - - Not implemented: - - add concurrency controls for detail fetches - - add stronger second-phase inspection beyond country equality - - add alternate downstream auth modes beyond env-token auth - """ - try: - async with StitchApiClient() as client: - items, pages_fetched = await client.collect_oil_gas_fields( - start_page=request.page, - page_size=request.page_size, - max_pages=request.max_pages, - ) - duplicate_groups = _group_duplicate_names(items) - match_groups, detail_records_fetched = await _resolve_match_groups( - client=client, - duplicate_groups=duplicate_groups, - ) - - merge_results: list[dict] = [] - if request.apply_merges: - for group in match_groups: - response = await client.create_merge_candidate( - resource_ids=group.ids - ) - merge_results.append( - { - "ids": group.ids, - "response": response, - } - ) - except StitchAPIError as exc: - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - - return StartResponse( - initiated_by=_extract_user_label(auth_context.user), - apply_merges=request.apply_merges, - pages_fetched=pages_fetched, - total_records_fetched=len(items), - duplicate_name_candidate_count=sum( - len(group) for group in duplicate_groups.values() - ), - detail_records_fetched=detail_records_fetched, - match_groups=[group.ids for group in match_groups], - merge_results=merge_results, - ) diff --git a/deployments/entity-linkage/tests/conftest.py b/deployments/entity-linkage/tests/conftest.py new file mode 100644 index 00000000..5c53fe0a --- /dev/null +++ b/deployments/entity-linkage/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/deployments/entity-linkage/tests/test_start.py b/deployments/entity-linkage/tests/test_start.py index 828e05cf..8dde571b 100644 --- a/deployments/entity-linkage/tests/test_start.py +++ b/deployments/entity-linkage/tests/test_start.py @@ -3,43 +3,23 @@ from contextlib import AbstractAsyncContextManager import pytest -from fastapi import HTTPException +from stitch.entity_linkage import linkage as linkage_module from stitch.entity_linkage.entities import ( FieldCandidate, FieldDetailCandidate, MatchGroup, - RequestAuthContext, User, ) from stitch.entity_linkage.errors import StitchAPIError -from stitch.entity_linkage.routers import start as start_module -from stitch.entity_linkage.routers.start import ( - StartRequest, - _extract_user_label, +from stitch.entity_linkage.linkage import ( + LinkageParams, _group_duplicate_names, _normalize_country, _resolve_match_groups, - start, + run_linkage, ) - - -def make_auth_context( - *, - name: str = "Test User", - email: str = "test@example.com", - sub: str = "auth0|user-123", - bearer_token: str | None = "token-123", -) -> RequestAuthContext: - return RequestAuthContext( - user=User( - id=1, - sub=sub, - email=email, - name=name, - ), - bearer_token=bearer_token, - ) +from stitch.entity_linkage.routers.start import _extract_user_label class FakeStitchApiClient(AbstractAsyncContextManager["FakeStitchApiClient"]): @@ -80,11 +60,7 @@ async def collect_oil_gas_fields( max_pages: int | None = None, ) -> tuple[list[FieldCandidate], int]: self.collect_calls.append( - { - "start_page": start_page, - "page_size": page_size, - "max_pages": max_pages, - } + {"start_page": start_page, "page_size": page_size, "max_pages": max_pages} ) if self.collect_error is not None: raise self.collect_error @@ -128,10 +104,6 @@ def test_extract_user_label_prefers_name_then_email_then_sub() -> None: _extract_user_label(User(id=1, sub="sub-2", email="b@example.com", name="")) == "b@example.com" ) - assert ( - _extract_user_label(User(id=1, sub="sub-3", email="c@example.com", name="")) - != "sub-3" - ) def test_group_duplicate_names_uses_casefold_and_strips_whitespace() -> None: @@ -189,10 +161,9 @@ async def test_resolve_match_groups_groups_only_same_country_duplicates() -> Non @pytest.mark.anyio -async def test_start_returns_summary_without_merges( +async def test_run_linkage_returns_summary_without_merges( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context(name="Alex Reviewer") fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="US"), @@ -208,21 +179,18 @@ async def test_start_returns_summary_without_merges( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=False, page=2, page_size=25, max_pages=4), - auth_context=auth_context, + result = await run_linkage( + LinkageParams(apply_merges=False, page=2, page_size=25, max_pages=4) ) - assert response.initiated_by == "Alex Reviewer" - assert response.apply_merges is False - assert response.pages_fetched == 3 - assert response.total_records_fetched == 4 - assert response.duplicate_name_candidate_count == 3 - assert response.detail_records_fetched == 3 - assert response.match_groups == [[1, 3]] - assert response.merge_results == [] + assert result.pages_fetched == 3 + assert result.total_records_fetched == 4 + assert result.duplicate_name_candidate_count == 3 + assert result.detail_records_fetched == 3 + assert result.match_groups == [[1, 3]] + assert result.merge_results == [] assert fake_client.collect_calls == [ {"start_page": 2, "page_size": 25, "max_pages": 4} @@ -231,10 +199,9 @@ async def test_start_returns_summary_without_merges( @pytest.mark.anyio -async def test_start_applies_merges_for_each_match_group( +async def test_run_linkage_applies_merges_for_each_match_group( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="ignored"), @@ -254,26 +221,22 @@ async def test_start_applies_merges_for_each_match_group( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=True), - auth_context=auth_context, - ) + result = await run_linkage(LinkageParams(apply_merges=True)) - assert response.match_groups == [[1, 2], [3, 4]] + assert result.match_groups == [[1, 2], [3, 4]] assert fake_client.merge_calls == [[1, 2], [3, 4]] - assert response.merge_results == [ + assert result.merge_results == [ {"ids": [1, 2], "response": {"merged_ids": [1, 2], "winner": 1}}, {"ids": [3, 4], "response": {"merged_ids": [3, 4], "winner": 3}}, ] @pytest.mark.anyio -async def test_start_returns_no_matches_when_duplicate_names_do_not_confirm( +async def test_run_linkage_returns_no_matches_when_duplicate_names_do_not_confirm( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="ignored"), @@ -286,36 +249,30 @@ async def test_start_returns_no_matches_when_duplicate_names_do_not_confirm( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=True), - auth_context=auth_context, - ) + result = await run_linkage(LinkageParams(apply_merges=True)) - assert response.duplicate_name_candidate_count == 2 - assert response.detail_records_fetched == 2 - assert response.match_groups == [] - assert response.merge_results == [] + assert result.duplicate_name_candidate_count == 2 + assert result.detail_records_fetched == 2 + assert result.match_groups == [] + assert result.merge_results == [] assert fake_client.merge_calls == [] @pytest.mark.anyio -async def test_start_translates_stitch_api_error_to_502( +async def test_run_linkage_propagates_stitch_api_error( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() + # In the job model, downstream errors propagate out of run_linkage and are + # captured by the JobManager as a failed job (no synchronous 502). fake_client = FakeStitchApiClient( collect_error=StitchAPIError( "GET /oil-gas-fields/ failed with status 500: boom" ), ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) - - with pytest.raises(HTTPException) as exc_info: - await start(StartRequest(), auth_context=auth_context) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - exc = exc_info.value - assert exc.status_code == 502 - assert exc.detail == "GET /oil-gas-fields/ failed with status 500: boom" + with pytest.raises(StitchAPIError): + await run_linkage(LinkageParams()) diff --git a/deployments/entity-linkage/tests/test_start_api.py b/deployments/entity-linkage/tests/test_start_api.py index 0f1587e5..79b72746 100644 --- a/deployments/entity-linkage/tests/test_start_api.py +++ b/deployments/entity-linkage/tests/test_start_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from contextlib import AbstractAsyncContextManager import pytest @@ -7,6 +8,9 @@ from stitch.auth import TokenClaims from stitch.auth.permissions import SERVICE_ENTITY_LINKAGE_RUN +import stitch.entity_linkage.main as main_module +from stitch.entity_linkage import linkage as linkage_module +from stitch.entity_linkage.auth import get_request_auth_context, get_token_claims from stitch.entity_linkage.entities import ( FieldCandidate, FieldDetailCandidate, @@ -16,9 +20,7 @@ from stitch.entity_linkage.errors import StitchAPIError from stitch.entity_linkage.main import app from stitch.entity_linkage.routers import health as health_module -from stitch.entity_linkage.routers import start as start_module -from stitch.entity_linkage.auth import get_request_auth_context, get_token_claims -from stitch.entity_linkage import main as main_module +from stitch.entity_linkage.routers.start import get_job_manager def make_auth_context( @@ -29,12 +31,7 @@ def make_auth_context( bearer_token: str | None = "integration-token", ) -> RequestAuthContext: return RequestAuthContext( - user=User( - id=1, - sub=sub, - email=email, - name=name, - ), + user=User(id=1, sub=sub, email=email, name=name), bearer_token=bearer_token, ) @@ -84,11 +81,7 @@ async def collect_oil_gas_fields( max_pages: int | None = None, ) -> tuple[list[FieldCandidate], int]: self.collect_calls.append( - { - "start_page": start_page, - "page_size": page_size, - "max_pages": max_pages, - } + {"start_page": start_page, "page_size": page_size, "max_pages": max_pages} ) if self.collect_error is not None: raise self.collect_error @@ -113,6 +106,14 @@ async def get_auth_me(self) -> dict: return self.auth_me_response +@pytest.fixture(autouse=True) +def reset_job_manager(): + """Each test starts with a clean, isolated job store.""" + get_job_manager().reset() + yield + get_job_manager().reset() + + @pytest.fixture def auth_context() -> RequestAuthContext: return make_auth_context() @@ -144,7 +145,9 @@ def install( merge_error=merge_error, ) created_clients.append(client) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: client) + # The job runs run_linkage in the background, which constructs the + # client from the linkage module's namespace. + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: client) return client return install, created_clients @@ -177,6 +180,16 @@ def override_token_claims() -> TokenClaims: app.dependency_overrides.clear() +def _poll(client: TestClient, job_id: str, *, timeout: float = 5.0) -> dict: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + body = client.get(f"/api/v1/status/{job_id}").json() + if body["state"] != "running": + return body + time.sleep(0.02) + raise AssertionError("job did not finish within timeout") + + def test_post_start_requires_service_permission( auth_context: RequestAuthContext, monkeypatch: pytest.MonkeyPatch, @@ -203,7 +216,7 @@ def override_token_claims() -> TokenClaims: assert SERVICE_ENTITY_LINKAGE_RUN in response.json()["detail"] -def test_post_start_returns_serialized_response_model( +def test_post_start_accepts_job_and_status_reports_result( test_client: TestClient, api_client_factory, ) -> None: @@ -221,20 +234,27 @@ def test_post_start_returns_serialized_response_model( }, ) - response = test_client.post( + started = test_client.post( "/api/v1/start", - json={ - "apply_merges": False, - "page": 3, - "page_size": 25, - "max_pages": 7, - }, + json={"apply_merges": False, "page": 3, "page_size": 25, "max_pages": 7}, ) - assert response.status_code == 200 - assert response.json() == { - "initiated_by": "Integration Tester", + assert started.status_code == 202 + body = started.json() + assert body["state"] == "running" + assert body["initiated_by"] == "Integration Tester" + assert body["params"] == { "apply_merges": False, + "page": 3, + "page_size": 25, + "max_pages": 7, + } + + final = _poll(test_client, body["job_id"]) + assert final["state"] == "succeeded" + assert final["error"] is None + assert final["finished_at"] is not None + assert final["result"] == { "pages_fetched": 2, "total_records_fetched": 3, "duplicate_name_candidate_count": 2, @@ -245,17 +265,13 @@ def test_post_start_returns_serialized_response_model( assert len(created_clients) == 1 assert fake_client.collect_calls == [ - { - "start_page": 3, - "page_size": 25, - "max_pages": 7, - } + {"start_page": 3, "page_size": 25, "max_pages": 7} ] assert fake_client.detail_calls == [1, 2] assert fake_client.merge_calls == [] -def test_post_start_applies_merges_and_returns_merge_results( +def test_post_start_applies_merges_and_reports_merge_results( test_client: TestClient, api_client_factory, ) -> None: @@ -279,21 +295,20 @@ def test_post_start_applies_merges_and_returns_merge_results( }, ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": True}, - ) + started = test_client.post("/api/v1/start", json={"apply_merges": True}) + assert started.status_code == 202 - assert response.status_code == 200 - assert response.json()["match_groups"] == [[10, 11], [20, 21]] - assert response.json()["merge_results"] == [ + final = _poll(test_client, started.json()["job_id"]) + assert final["state"] == "succeeded" + assert final["result"]["match_groups"] == [[10, 11], [20, 21]] + assert final["result"]["merge_results"] == [ {"ids": [10, 11], "response": {"merged_ids": [10, 11], "winner": 10}}, {"ids": [20, 21], "response": {"merged_ids": [20, 21], "winner": 20}}, ] assert fake_client.merge_calls == [[10, 11], [20, 21]] -def test_post_start_returns_empty_matches_when_country_check_does_not_confirm( +def test_post_start_reports_empty_matches_when_country_check_does_not_confirm( test_client: TestClient, api_client_factory, ) -> None: @@ -310,19 +325,17 @@ def test_post_start_returns_empty_matches_when_country_check_does_not_confirm( }, ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": True}, - ) + started = test_client.post("/api/v1/start", json={"apply_merges": True}) + final = _poll(test_client, started.json()["job_id"]) - assert response.status_code == 200 - assert response.json()["duplicate_name_candidate_count"] == 2 - assert response.json()["detail_records_fetched"] == 2 - assert response.json()["match_groups"] == [] - assert response.json()["merge_results"] == [] + assert final["state"] == "succeeded" + assert final["result"]["duplicate_name_candidate_count"] == 2 + assert final["result"]["detail_records_fetched"] == 2 + assert final["result"]["match_groups"] == [] + assert final["result"]["merge_results"] == [] -def test_post_start_translates_stitch_api_error_to_502( +def test_job_records_failure_when_downstream_errors( test_client: TestClient, api_client_factory, ) -> None: @@ -333,46 +346,62 @@ def test_post_start_translates_stitch_api_error_to_502( ), ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": False}, - ) + started = test_client.post("/api/v1/start", json={"apply_merges": False}) + assert started.status_code == 202 - assert response.status_code == 502 - assert response.json() == { - "detail": "GET /oil-gas-fields/ failed with status 500: boom", - } + final = _poll(test_client, started.json()["job_id"]) + assert final["state"] == "failed" + assert final["result"] is None + assert "GET /oil-gas-fields/ failed with status 500: boom" in final["error"] -def test_post_start_validates_request_body_constraints( +def test_second_caller_observes_existing_run( test_client: TestClient, api_client_factory, ) -> None: install, _ = api_client_factory install( - items=[], - details_by_id={}, + items=[ + FieldCandidate(id=1, name="Alpha", country="ignored"), + FieldCandidate(id=2, name="alpha", country="ignored"), + ], + details_by_id={ + 1: FieldDetailCandidate(id=1, name="Alpha", country="US"), + 2: FieldDetailCandidate(id=2, name="Alpha", country="US"), + }, ) + first = test_client.post("/api/v1/start", json={"apply_merges": False}) + job_id = first.json()["job_id"] + _poll(test_client, job_id) + + # Same params within the reuse window → returns the existing run (200), not + # a fresh job. This is the cross-user "request already made" behavior. + second = test_client.post("/api/v1/start", json={"apply_merges": False}) + assert second.status_code == 200 + assert second.json()["job_id"] == job_id + + +def test_post_start_validates_request_body_constraints( + test_client: TestClient, +) -> None: response = test_client.post( "/api/v1/start", - json={ - "apply_merges": False, - "page": 0, - "page_size": 500, - "max_pages": 0, - }, + json={"apply_merges": False, "page": 0, "page_size": 500, "max_pages": 0}, ) assert response.status_code == 422 detail = response.json()["detail"] - fields = {tuple(item["loc"]) for item in detail} assert ("body", "page") in fields assert ("body", "page_size") in fields assert ("body", "max_pages") in fields +def test_status_404_for_unknown_job(test_client: TestClient) -> None: + assert test_client.get("/api/v1/status/nope").status_code == 404 + + def test_health_details_reports_ready_when_downstream_auth_probe_succeeds( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, From b08f6436ab66f462f27c1d29ed76adbc31b4f428 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Mon, 22 Jun 2026 20:25:11 +0200 Subject: [PATCH 05/26] Update EL frontend --- .../src/pages/EntityLinkagePage.jsx | 138 ++++++++++++++---- .../src/pages/EntityLinkagePage.test.jsx | 132 +++++++++++++---- 2 files changed, 218 insertions(+), 52 deletions(-) diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx index ff7ee1d2..cd2d9360 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx @@ -4,6 +4,36 @@ import { useConfig } from "../config/useConfig"; import StructuredDataView from "../components/StructuredDataView"; import Button from "../components/Button"; +const STATE_STYLES = { + running: "border-warning/30 bg-warning-soft text-warning", + succeeded: "border-success/25 bg-success-soft text-success-strong", + failed: "border-danger/25 bg-danger-soft text-danger", +}; + +function StateBadge({ state }) { + if (!state) return null; + + const classes = STATE_STYLES[state] ?? "border-line bg-surface text-ink"; + + return ( + + {state} + + ); +} + +async function parseJsonResponse(response) { + const text = await response.text(); + + try { + return text ? JSON.parse(text) : null; + } catch { + return { raw: text }; + } +} + function formatCount(count, singular, plural = `${singular}s`) { return `${count} ${count === 1 ? singular : plural}`; } @@ -66,10 +96,6 @@ function RunResult({ result }) { const matchGroups = getMatchGroups(result); const details = getResultDetails(result); - if (!result) { - return

No run has completed yet.

; - } - return (
@@ -98,14 +124,14 @@ export default function EntityLinkagePage() { const { getAccessTokenSilently } = useAuth0(); const [applyMerges, setApplyMerges] = useState(false); - const [loading, setLoading] = useState(false); - const [result, setResult] = useState(null); + const [starting, setStarting] = useState(false); + const [refreshing, setRefreshing] = useState(false); + const [record, setRecord] = useState(null); const [error, setError] = useState(null); async function handleStart() { - setLoading(true); + setStarting(true); setError(null); - setResult(null); try { const token = await getAccessTokenSilently({ @@ -123,34 +149,59 @@ export default function EntityLinkagePage() { }), }); - const text = await response.text(); + const parsed = await parseJsonResponse(response); - let parsed; - try { - parsed = text ? JSON.parse(text) : null; - } catch { - parsed = { raw: text }; + if (!response.ok) { + setError({ status: response.status, body: parsed }); + return; } + // 202 starts a new run; 200 means an identical run is already active or + // recently finished — either way `parsed` is the job record to track. + setRecord(parsed); + } catch (err) { + setError({ + status: null, + body: err instanceof Error ? err.message : String(err), + }); + } finally { + setStarting(false); + } + } + + async function handleRefresh() { + const jobId = record?.job_id; + if (!jobId) return; + + setRefreshing(true); + setError(null); + + try { + // GET /status/{job_id} is unauthenticated, like the other job services. + const response = await fetch( + `${config.entityLinkageBaseUrl}/status/${jobId}`, + ); + const parsed = await parseJsonResponse(response); + if (!response.ok) { - setError({ - status: response.status, - body: parsed, - }); + setError({ status: response.status, body: parsed }); return; } - setResult(parsed); + setRecord(parsed); } catch (err) { setError({ status: null, body: err instanceof Error ? err.message : String(err), }); } finally { - setLoading(false); + setRefreshing(false); } } + const state = record?.state; + const isRunning = state === "running"; + return (
@@ -159,7 +210,9 @@ export default function EntityLinkagePage() {

Entity Linkage

- Start an entity-linkage run and review the result. + Start an entity-linkage run, then refresh to check its status and + review the result. An identical run already in progress is shared + rather than started again.

@@ -169,14 +222,26 @@ export default function EntityLinkagePage() { type="checkbox" checked={applyMerges} onChange={(e) => setApplyMerges(e.target.checked)} + disabled={isRunning} className="accent-primary" /> Initiate merges -
- +
@@ -191,9 +256,30 @@ export default function EntityLinkagePage() { ) : null}
-

Run result

+
+

Run status

+ +
- + {!record ? ( +

+ No run started yet. Start a run to begin. +

+ ) : isRunning ? ( +

+ Run in progress — refresh to check for the result. +

+ ) : state === "failed" ? ( +
+

Run failed.

+ +
+ ) : ( + + )}
diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx index cc92aa93..6059faa1 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx @@ -5,6 +5,43 @@ import { useAuth0 } from "@auth0/auth0-react"; import EntityLinkagePage from "./EntityLinkagePage"; import { auth0TestDefaults, renderWithQueryClient } from "../test/utils"; +const RUNNING_RECORD = { + job_id: "job-123", + state: "running", + dedup_key: "LinkageParams:abc", + initiated_by: "Test User", + params: { apply_merges: false, page: 1, page_size: 50, max_pages: null }, + started_at: "2026-01-01T00:00:00Z", + finished_at: null, + result: null, + error: null, +}; + +const SUCCEEDED_RECORD = { + ...RUNNING_RECORD, + state: "succeeded", + finished_at: "2026-01-01T00:00:05Z", + result: { + pages_fetched: 1, + total_records_fetched: 4, + duplicate_name_candidate_count: 4, + detail_records_fetched: 4, + match_groups: [ + [101, 102], + [203, 204, 205], + ], + merge_results: [], + }, +}; + +function mockResponse(status, body) { + return { + ok: status >= 200 && status < 300, + status, + text: async () => JSON.stringify(body), + }; +} + describe("EntityLinkagePage", () => { let getAccessTokenSilently; @@ -16,29 +53,50 @@ describe("EntityLinkagePage", () => { }); }); - it("renders match groups as visually separated groups", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 200, - text: async () => - JSON.stringify({ - initiated_by: "Test User", - apply_merges: false, - pages_fetched: 1, - total_records_fetched: 4, - duplicate_name_candidate_count: 4, - detail_records_fetched: 4, - match_groups: [ - [101, 102], - [203, 204, 205], - ], - merge_results: [], - }), + it("starts a job (202) and authenticates the start request", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(mockResponse(202, RUNNING_RECORD)); + + renderWithQueryClient(); + + await userEvent.click(screen.getByRole("button", { name: "Start run" })); + + await waitFor(() => { + expect(screen.getByText("running")).toBeInTheDocument(); }); + expect( + screen.getByText("Run in progress — refresh to check for the result."), + ).toBeInTheDocument(); + + const [startUrl, startOptions] = fetchSpy.mock.calls[0]; + expect(startUrl).toMatch(/\/start$/); + expect(startOptions.method).toBe("POST"); + expect(startOptions.headers.Authorization).toBe("Bearer test-access-token"); + expect(getAccessTokenSilently).toHaveBeenCalledWith({ + authorizationParams: { audience: "https://stitch-api.local" }, + }); + expect(getAccessTokenSilently.mock.invocationCallOrder[0]).toBeLessThan( + fetchSpy.mock.invocationCallOrder[0], + ); + }); + + it("polls /status/{job_id} on refresh and renders the completed result", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValueOnce(mockResponse(202, RUNNING_RECORD)) + .mockResolvedValueOnce(mockResponse(200, SUCCEEDED_RECORD)); renderWithQueryClient(); await userEvent.click(screen.getByRole("button", { name: "Start run" })); + await waitFor(() => { + expect(screen.getByText("running")).toBeInTheDocument(); + }); + + await userEvent.click( + screen.getByRole("button", { name: "Refresh status" }), + ); await waitFor(() => { expect( @@ -46,20 +104,42 @@ describe("EntityLinkagePage", () => { ).toBeInTheDocument(); }); + expect(screen.getByText("succeeded")).toBeInTheDocument(); expect(screen.getByText("2 groups")).toBeInTheDocument(); expect( screen.getByRole("heading", { name: "Match group 1" }), ).toBeInTheDocument(); - expect( - screen.getByRole("heading", { name: "Match group 2" }), - ).toBeInTheDocument(); expect(screen.getByText("Resource 101")).toBeInTheDocument(); expect(screen.getByText("Resource 205")).toBeInTheDocument(); - expect(getAccessTokenSilently).toHaveBeenCalledWith({ - authorizationParams: { audience: "https://stitch-api.local" }, - }); - expect(getAccessTokenSilently.mock.invocationCallOrder[0]).toBeLessThan( - fetch.mock.invocationCallOrder[0], + + // The status poll hits /status/{job_id} (unauthenticated GET). + const [statusUrl] = fetchSpy.mock.calls[1]; + expect(statusUrl).toMatch(/\/status\/job-123$/); + }); + + it("surfaces a failed run", async () => { + vi.spyOn(globalThis, "fetch") + .mockResolvedValueOnce(mockResponse(202, RUNNING_RECORD)) + .mockResolvedValueOnce( + mockResponse(200, { + ...RUNNING_RECORD, + state: "failed", + finished_at: "2026-01-01T00:00:05Z", + error: "GET /oil-gas-fields/ failed with status 500: boom", + }), + ); + + renderWithQueryClient(); + + await userEvent.click(screen.getByRole("button", { name: "Start run" })); + await waitFor(() => screen.getByText("running")); + await userEvent.click( + screen.getByRole("button", { name: "Refresh status" }), ); + + await waitFor(() => { + expect(screen.getByText("Run failed.")).toBeInTheDocument(); + }); + expect(screen.getByText("failed")).toBeInTheDocument(); }); }); From 225d9fefb7750aa2eeb60c9af43bd55839a0467e Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 10:58:58 +0200 Subject: [PATCH 06/26] style: forbidden patterns --- packages/stitch-jobs/src/stitch/jobs/manager.py | 4 +++- packages/stitch-jobs/src/stitch/jobs/routers.py | 4 ++-- packages/stitch-jobs/src/stitch/jobs/uniqueness.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/packages/stitch-jobs/src/stitch/jobs/manager.py b/packages/stitch-jobs/src/stitch/jobs/manager.py index b0d2ed20..f9db2f9a 100644 --- a/packages/stitch-jobs/src/stitch/jobs/manager.py +++ b/packages/stitch-jobs/src/stitch/jobs/manager.py @@ -88,7 +88,9 @@ async def _run(self, record: JobRecord[P, R], params: P) -> None: try: record.result = await self._run_fn(params) record.state = JobState.succeeded - except Exception as exc: # noqa: BLE001 - captured into the record + except Exception as exc: + # Broad on purpose: any run_fn failure is captured onto the record + # (state=failed, error set) rather than crashing the background task. logger.exception("job %s failed", record.job_id) record.error = str(exc) record.state = JobState.failed diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py index 2a6d0971..ad583e6c 100644 --- a/packages/stitch-jobs/src/stitch/jobs/routers.py +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -41,7 +41,7 @@ def make_job_router( to_params = to_params or (lambda request: request) resolve_initiated_by = initiated_by or (lambda: None) - record_model = JobRecord[params_model, result_model] # type: ignore[valid-type] + record_model = JobRecord[params_model, result_model] router = APIRouter(tags=list(tags) if tags else None) @@ -52,7 +52,7 @@ def make_job_router( dependencies=list(dependencies), ) async def start( - request: start_request_model, # type: ignore[valid-type] + request: start_request_model, response: Response, initiated_by_label: Any = Depends(resolve_initiated_by), ): diff --git a/packages/stitch-jobs/src/stitch/jobs/uniqueness.py b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py index 976f2476..49113362 100644 --- a/packages/stitch-jobs/src/stitch/jobs/uniqueness.py +++ b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py @@ -31,7 +31,8 @@ class SingletonPolicy: def __init__(self, key: str = "singleton") -> None: self._key = key - def key(self, params: BaseModel) -> str | None: # noqa: ARG002 - params ignored by design + def key(self, params: BaseModel) -> str | None: + # params intentionally unused: every request maps to the same key. return self._key @@ -78,5 +79,6 @@ def key(self, params: BaseModel) -> str | None: class NoDedupPolicy: """Never deduplicate: every request starts a new job.""" - def key(self, params: BaseModel) -> str | None: # noqa: ARG002 + def key(self, params: BaseModel) -> str | None: + # params intentionally unused: opt every request out of dedup. return None From 288d8565f11b394bc98a9578fe57e721f8b98fde Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 11:05:50 +0200 Subject: [PATCH 07/26] CodeQL: sub ellipsis for one-line docstrings --- packages/stitch-jobs/src/stitch/jobs/store.py | 15 ++++++++++----- .../stitch-jobs/src/stitch/jobs/uniqueness.py | 3 ++- packages/stitch-service/src/stitch/service/app.py | 3 ++- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/packages/stitch-jobs/src/stitch/jobs/store.py b/packages/stitch-jobs/src/stitch/jobs/store.py index 082e989f..4af9434b 100644 --- a/packages/stitch-jobs/src/stitch/jobs/store.py +++ b/packages/stitch-jobs/src/stitch/jobs/store.py @@ -20,17 +20,22 @@ class JobStore(Protocol): routers. """ - async def create(self, record: JobRecord) -> None: ... + async def create(self, record: JobRecord) -> None: + """Persist a newly started job record.""" - async def get(self, job_id: str) -> JobRecord | None: ... + async def get(self, job_id: str) -> JobRecord | None: + """Return the record for ``job_id``, or ``None`` if unknown.""" async def find_active_or_recent( self, dedup_key: str, *, recent_within: timedelta - ) -> JobRecord | None: ... + ) -> JobRecord | None: + """Return a matching job that is running or finished recently.""" - async def list(self, *, limit: int | None = None) -> list[JobRecord]: ... + async def list(self, *, limit: int | None = None) -> list[JobRecord]: + """Return recent records, newest first.""" - def clear(self) -> None: ... + def clear(self) -> None: + """Drop all records (test affordance).""" class InMemoryJobStore: diff --git a/packages/stitch-jobs/src/stitch/jobs/uniqueness.py b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py index 49113362..5a3c94a6 100644 --- a/packages/stitch-jobs/src/stitch/jobs/uniqueness.py +++ b/packages/stitch-jobs/src/stitch/jobs/uniqueness.py @@ -17,7 +17,8 @@ class UniquenessPolicy(Protocol): entirely (always start a fresh job). """ - def key(self, params: BaseModel) -> str | None: ... + def key(self, params: BaseModel) -> str | None: + """Return the dedup key for ``params``, or ``None`` to skip dedup.""" class SingletonPolicy: diff --git a/packages/stitch-service/src/stitch/service/app.py b/packages/stitch-service/src/stitch/service/app.py index 1dec3d03..b7625d23 100644 --- a/packages/stitch-service/src/stitch/service/app.py +++ b/packages/stitch-service/src/stitch/service/app.py @@ -15,7 +15,8 @@ async def _maybe_await(value: Awaitable[None] | None) -> None: if inspect.isawaitable(value): - await value + return await value + return None def create_app( From ebea67c2ba39be4d1e88b98c1f23b277a72a2695 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 11:57:01 +0200 Subject: [PATCH 08/26] Extract Auth to `stitch-service` package --- .../src/stitch/entity_linkage/auth.py | 216 +------------ .../src/stitch/entity_linkage/client.py | 15 +- .../src/stitch/entity_linkage/entities.py | 44 ++- .../tests/test_downstream_auth.py | 31 ++ packages/stitch-service/pyproject.toml | 6 + .../src/stitch/service/__init__.py | 25 +- .../stitch-service/src/stitch/service/auth.py | 294 ++++++++++++++++++ packages/stitch-service/tests/test_auth.py | 121 +++++++ .../tests/test_downstream_client.py | 48 +++ uv.lock | 8 +- 10 files changed, 573 insertions(+), 235 deletions(-) create mode 100644 deployments/entity-linkage/tests/test_downstream_auth.py create mode 100644 packages/stitch-service/src/stitch/service/auth.py create mode 100644 packages/stitch-service/tests/test_auth.py create mode 100644 packages/stitch-service/tests/test_downstream_client.py diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py index f04cf08d..45ac3bb9 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py @@ -1,208 +1,22 @@ -import asyncio -import logging -from functools import lru_cache -from typing import Annotated, Literal, NoReturn +"""Entity-linkage auth wiring. -from fastapi import Depends, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +All the mechanics live in :mod:`stitch.service.auth`; here we just bind a +:class:`~stitch.service.auth.ServiceAuth` to this service's settings and +re-export the dependencies the routers and tests import by name. +""" -from stitch.auth import ( - ALL_PERMISSIONS, - AuthError, - InsufficientPermissionsError, - JWKSFetchError, - JWTValidator, - OIDCSettings, - TokenClaims, - check_permissions, -) +from stitch.service.auth import ServiceAuth -from stitch.entity_linkage.entities import RequestAuthContext, User from stitch.entity_linkage.settings import get_settings -logger = logging.getLogger(__name__) +_auth = ServiceAuth(is_auth_disabled=lambda: get_settings().auth_disabled) +validate_auth_config_at_startup = _auth.validate_auth_config_at_startup +get_token_claims = _auth.get_token_claims +require_permissions = _auth.require_permissions +get_current_user = _auth.get_current_user +get_request_auth_context = _auth.get_request_auth_context -@lru_cache -def get_oidc_settings() -> OIDCSettings: - return OIDCSettings() - - -@lru_cache -def get_jwt_validator() -> JWTValidator: - return JWTValidator(get_oidc_settings()) - - -_DEV_CLAIMS = TokenClaims( - sub="dev|local-placeholder", - email="dev@example.com", - name="Dev User", - permissions=ALL_PERMISSIONS, - raw={}, -) - -# auto_error=False so that when AUTH_DISABLED=true the missing header -# doesn't trigger a 403 before our custom handler runs. -_bearer_scheme = HTTPBearer(auto_error=False) - - -def validate_auth_config_at_startup() -> None: - settings = get_settings() - - if settings.auth_disabled: - logger.warning("Auth is disabled — all requests use dev credentials") - return - - # fail fast if OIDC config is invalid - get_oidc_settings() - - -def _extract_bearer_token_from_request(request: Request) -> str | None: - """ - Return the raw bearer token from the Authorization header. - - This exists separately from JWT validation so that downstream callers can - opt into explicit bearer-token relay in the future. - """ - auth_header = request.headers.get("Authorization") - if not auth_header: - return None - - scheme, _, token = auth_header.partition(" ") - if scheme.lower() != "bearer" or not token: - return None - - return token - - -def _dev_bearer_token() -> str: - """ - Placeholder token used only when auth is disabled in local development. - """ - return "dev-placeholder-token" - - -def _claims_to_user(claims: TokenClaims) -> User: - return User( - id=1, - sub=claims.sub, - email=claims.email or "unknown@example.com", - name=claims.name or claims.email or claims.sub, - ) - - -async def get_token_claims( - request: Request, - _credential: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme), -) -> TokenClaims: - """Extract and validate JWT from Authorization header. - - The ``_credential`` parameter exists solely so FastAPI registers the - HTTPBearer security scheme in the OpenAPI spec (Swagger "Authorize" - button). Actual token parsing still uses the raw header so we can - return precise 401 messages for missing/malformed values. - """ - if get_settings().auth_disabled: - return _DEV_CLAIMS - - auth_header = request.headers.get("Authorization") - if not auth_header: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Missing Authorization header", - headers={"WWW-Authenticate": "Bearer"}, - ) - - scheme, _, token = auth_header.partition(" ") - if scheme.lower() != "bearer" or not token: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid Authorization header format", - headers={"WWW-Authenticate": "Bearer"}, - ) - - validator = get_jwt_validator() - try: - return await asyncio.to_thread(validator.validate, token) - except JWKSFetchError: - logger.error( - "JWKS endpoint unreachable or returned invalid data", exc_info=True - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - except AuthError as e: - logger.warning("JWT validation failed: %s", e, exc_info=True) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -Claims = Annotated[TokenClaims, Depends(get_token_claims)] - - -def _permission_exception_handler(exc: InsufficientPermissionsError) -> NoReturn: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=exc.detail) - - -def require_permissions( - *required_permissions: str, check: Literal["all", "any"] = "all" -): - async def dependency(claims: Claims) -> None: - check_permissions( - granted=claims.permissions, - required=required_permissions, - check=check, - exc_handler=_permission_exception_handler, - ) - - return dependency - - -async def get_current_user(claims: Claims) -> User: - """ - Resolve validated token claims to a lightweight request user. - """ - if get_settings().auth_disabled: - return User( - id=1, - sub=_DEV_CLAIMS.sub, - email=_DEV_CLAIMS.email or "dev@example.com", - name=_DEV_CLAIMS.name or "Dev User", - ) - return _claims_to_user(claims) - - -CurrentUser = Annotated[User, Depends(get_current_user)] - - -async def get_request_auth_context( - request: Request, - user: CurrentUser, -) -> RequestAuthContext: - """ - Build the request-scoped auth context used by downstream API clients. - - The current deployment wiring uses env-based downstream auth, but we keep - the raw bearer token available for future explicit relay or OBO modes. - """ - if get_settings().auth_disabled: - bearer_token = _dev_bearer_token() - else: - bearer_token = _extract_bearer_token_from_request(request) - - return RequestAuthContext( - user=user, - bearer_token=bearer_token, - ) - - -AuthContext = Annotated[ - RequestAuthContext, - Depends(get_request_auth_context), -] +Claims = _auth.Claims +CurrentUser = _auth.CurrentUser +AuthContext = _auth.AuthContext diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/client.py b/deployments/entity-linkage/src/stitch/entity_linkage/client.py index 907020db..23971b73 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/client.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/client.py @@ -2,11 +2,17 @@ from typing import Any -from stitch.client import AsyncStitchClient, env_bearer_token_headers_provider +from stitch.client import AsyncStitchClient +from stitch.service.auth import AuthMode, build_headers_provider from stitch.entity_linkage.entities import FieldCandidate, FieldDetailCandidate from stitch.entity_linkage.settings import get_settings +# Entity-linkage does its work in a detached background job, so the caller's +# token is gone by the time the run executes — it authenticates downstream with +# its own machine identity (STITCH_CLIENT_BEARER_TOKEN), not on-behalf-of. +_DOWNSTREAM_AUTH_MODE = AuthMode.machine + def _get_api_base_url() -> str: """ @@ -16,8 +22,8 @@ def _get_api_base_url() -> str: def validate_downstream_auth_config_at_startup() -> None: - headers_provider = env_bearer_token_headers_provider() - headers_provider() + # Fail fast at startup if the machine token isn't configured. + build_headers_provider(_DOWNSTREAM_AUTH_MODE)() class StitchApiClient: @@ -29,11 +35,10 @@ def __init__( self._client = client return - headers_provider = env_bearer_token_headers_provider() self._client = AsyncStitchClient( base_url=_get_api_base_url(), timeout=30.0, - headers_provider=headers_provider, + headers_provider=build_headers_provider(_DOWNSTREAM_AUTH_MODE), ) async def __aenter__(self) -> "StitchApiClient": diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/entities.py b/deployments/entity-linkage/src/stitch/entity_linkage/entities.py index b4c2e550..a8864849 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/entities.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/entities.py @@ -1,10 +1,13 @@ -from dataclasses import dataclass from datetime import datetime from math import ceil from typing import Literal -from pydantic import BaseModel, EmailStr, Field, computed_field +from pydantic import BaseModel, Field, computed_field +# Identity is shared scaffolding now; re-exported here so existing imports +# (`from stitch.entity_linkage.entities import User, RequestAuthContext`) keep +# working. +from stitch.service.auth import RequestAuthContext, ServiceUser as User from stitch.ogsi.model.types import ( FieldStatus, LocationType, @@ -13,34 +16,27 @@ ProductionConventionality, ) +__all__ = [ + "FieldCandidate", + "FieldDetailCandidate", + "MatchGroup", + "OGFieldFilterParams", + "OGFieldQueryParams", + "OGFieldSortParams", + "PaginatedResponse", + "PaginationParams", + "RequestAuthContext", + "SortableField", + "Timestamped", + "User", +] + class Timestamped(BaseModel): created: datetime = Field(default_factory=datetime.now) updated: datetime = Field(default_factory=datetime.now) -class User(BaseModel): - id: int = Field(...) - sub: str = Field(...) - role: str | None = None - email: EmailStr - name: str - - -@dataclass(frozen=True, slots=True) -class RequestAuthContext: - """ - Request-scoped auth context for inbound request identity. - - not implemented: - - re-enable downstream relay or OBO auth as an explicit client mode - - keep user attribution/provenance as separate metadata - """ - - user: User - bearer_token: str | None - - class FieldCandidate(BaseModel): id: int name: str | None = None diff --git a/deployments/entity-linkage/tests/test_downstream_auth.py b/deployments/entity-linkage/tests/test_downstream_auth.py new file mode 100644 index 00000000..206c11a1 --- /dev/null +++ b/deployments/entity-linkage/tests/test_downstream_auth.py @@ -0,0 +1,31 @@ +"""Entity-linkage authenticates downstream with its own machine identity. + +It runs its work in a detached background job, so the caller's token is gone by +the time the run executes — passthrough is not an option here. +""" + +import pytest +from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR +from stitch.service.auth import AuthMode + +from stitch.entity_linkage import client as client_module + + +def test_downstream_uses_machine_identity() -> None: + assert client_module._DOWNSTREAM_AUTH_MODE is AuthMode.machine + + +def test_validate_downstream_requires_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, raising=False) + with pytest.raises(ValueError): + client_module.validate_downstream_auth_config_at_startup() + + +def test_validate_downstream_passes_with_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, "machine-tok") + # Should not raise. + client_module.validate_downstream_auth_config_at_startup() diff --git a/packages/stitch-service/pyproject.toml b/packages/stitch-service/pyproject.toml index f48e67a3..120f9b95 100644 --- a/packages/stitch-service/pyproject.toml +++ b/packages/stitch-service/pyproject.toml @@ -6,6 +6,8 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "fastapi[standard-no-fastapi-cloud-cli]>=0.135.1", + "stitch-auth", + "stitch-client", ] [build-system] @@ -15,6 +17,10 @@ build-backend = "uv_build" [tool.uv.build-backend] module-name = "stitch.service" +[tool.uv.sources] +stitch-auth = { workspace = true } +stitch-client = { workspace = true } + [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/packages/stitch-service/src/stitch/service/__init__.py b/packages/stitch-service/src/stitch/service/__init__.py index eccda2c8..03c7778d 100644 --- a/packages/stitch-service/src/stitch/service/__init__.py +++ b/packages/stitch-service/src/stitch/service/__init__.py @@ -1,12 +1,22 @@ """Shared FastAPI scaffolding for Stitch non-core services. -Provides the app factory, CORS wiring, and health helpers that every service -otherwise copies. Observability and auth extraction are intentionally out of -scope for now (observability is in flight on a separate branch); the app -factory leaves lifecycle hooks open so they can be added later. +Provides the app factory, CORS wiring, health helpers, and the auth seam (both +inbound request validation and the downstream machine / on-behalf-of modes) that +every service otherwise copies. Observability is intentionally out of scope for +now (in flight on a separate branch); the app factory leaves lifecycle hooks +open so it can be added later. """ from .app import create_app +from .auth import ( + AuthMode, + RequestAuthContext, + ServiceAuth, + ServiceUser, + build_headers_provider, + machine_token_headers_provider, + relay_token_headers_provider, +) from .health import ( format_started_at, make_basic_health_router, @@ -16,10 +26,17 @@ from .middleware import register_cors __all__ = [ + "AuthMode", + "RequestAuthContext", + "ServiceAuth", + "ServiceUser", + "build_headers_provider", "create_app", "format_started_at", + "machine_token_headers_provider", "make_basic_health_router", "register_cors", + "relay_token_headers_provider", "runtime_block", "uptime_seconds", ] diff --git a/packages/stitch-service/src/stitch/service/auth.py b/packages/stitch-service/src/stitch/service/auth.py new file mode 100644 index 00000000..612aa14f --- /dev/null +++ b/packages/stitch-service/src/stitch/service/auth.py @@ -0,0 +1,294 @@ +# NOTE: no `from __future__ import annotations` here. The dependency callables +# built in ServiceAuth.__init__ carry real Annotated objects (Claims/CurrentUser) +# as parameter annotations; stringized annotations would not resolve from the +# closure scope when FastAPI inspects the signature. + +import asyncio +import logging +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from enum import Enum +from typing import Annotated, Literal, NoReturn + +from fastapi import Depends, HTTPException, Request +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from stitch.auth import ( + ALL_PERMISSIONS, + AuthError, + InsufficientPermissionsError, + JWKSFetchError, + JWTValidator, + OIDCSettings, + TokenClaims, + check_permissions, +) +from stitch.client import env_bearer_token_headers_provider + +logger = logging.getLogger("stitch.service.auth") + + +# --------------------------------------------------------------------------- # +# Identity models +# --------------------------------------------------------------------------- # + + +class ServiceUser(BaseModel): + """Lightweight request identity resolved from validated token claims. + + ``id`` defaults to a placeholder; services that need a persisted user row + supply their own ``user_factory`` to :class:`ServiceAuth`. + """ + + id: int = 1 + sub: str + email: str + name: str + role: str | None = None + + +@dataclass(frozen=True, slots=True) +class RequestAuthContext: + """Request-scoped identity plus the raw caller bearer token. + + The token is retained so a request-scoped (synchronous) service can relay it + downstream in on-behalf-of mode. Background jobs cannot use it (the request + is gone by the time the job runs) and should use machine identity instead. + """ + + user: ServiceUser + bearer_token: str | None + + +# --------------------------------------------------------------------------- # +# Downstream auth seam — how a service authenticates when calling other services +# --------------------------------------------------------------------------- # + + +class AuthMode(str, Enum): + #: Call downstream with the service's own machine identity (env token). + machine = "machine" + #: Forward the caller's token downstream unchanged. This is token + #: *passthrough*, NOT RFC 8693 on-behalf-of: no new token is minted and + #: nothing records the intermediate hop, so it relies on the downstream + #: accepting the same token (shared audience). True OBO (token exchange with + #: an ``act`` actor claim) would be added as a separate mode if needed. + passthrough = "passthrough" + + +def machine_token_headers_provider() -> Callable[[], Mapping[str, str]]: + """Machine identity: bearer token read from the env (STITCH_CLIENT_BEARER_TOKEN).""" + return env_bearer_token_headers_provider() + + +def relay_token_headers_provider(token: str) -> Callable[[], Mapping[str, str]]: + """Passthrough: relay a specific caller token on each downstream request.""" + header = {"Authorization": f"Bearer {token}"} + + def provider() -> Mapping[str, str]: + return dict(header) + + return provider + + +def build_headers_provider( + mode: AuthMode, *, token: str | None = None +) -> Callable[[], Mapping[str, str]]: + """Build the downstream ``headers_provider`` for the chosen auth mode. + + ``machine`` reads the env token; ``passthrough`` requires ``token`` (the + caller's bearer token, e.g. ``RequestAuthContext.bearer_token``) and + forwards it unchanged. + """ + if mode is AuthMode.machine: + return machine_token_headers_provider() + if mode is AuthMode.passthrough: + if not token: + raise ValueError("passthrough mode requires a caller token") + return relay_token_headers_provider(token) + raise ValueError(f"unknown auth mode: {mode!r}") + + +# --------------------------------------------------------------------------- # +# Inbound auth — validating incoming requests +# --------------------------------------------------------------------------- # + + +DEFAULT_DEV_CLAIMS = TokenClaims( + sub="dev|local-placeholder", + email="dev@example.com", + name="Dev User", + permissions=ALL_PERMISSIONS, + raw={}, +) + + +def _dev_bearer_token() -> str: + """Placeholder token used only when auth is disabled in local development.""" + return "dev-placeholder-token" + + +def _extract_bearer_token_from_request(request: Request) -> str | None: + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + scheme, _, token = auth_header.partition(" ") + if scheme.lower() != "bearer" or not token: + return None + return token + + +def _default_user_from_claims(claims: TokenClaims) -> ServiceUser: + return ServiceUser( + id=1, + sub=claims.sub, + email=claims.email or "unknown@example.com", + name=claims.name or claims.email or claims.sub, + ) + + +def _permission_exception_handler(exc: InsufficientPermissionsError) -> NoReturn: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=exc.detail) + + +class ServiceAuth: + """Inbound auth wiring shared by Stitch services. + + Produces the FastAPI dependencies a service needs (``get_token_claims``, + ``require_permissions``, ``get_current_user``, ``get_request_auth_context`` + and their ``Annotated`` aliases ``Claims``/``CurrentUser``/``AuthContext``). + A service constructs one instance and re-exports the attributes it uses. + + Config seams: + - ``is_auth_disabled``: callable read per request; when true, all requests + resolve to ``dev_claims`` (local-dev bypass). + - ``user_factory``: maps validated claims to a user (override to hit a DB). + - ``oidc_settings_factory`` / ``dev_claims``: rarely overridden. + """ + + def __init__( + self, + *, + is_auth_disabled: Callable[[], bool], + oidc_settings_factory: Callable[[], OIDCSettings] = OIDCSettings, + dev_claims: TokenClaims | None = None, + user_factory: Callable[[TokenClaims], ServiceUser] = _default_user_from_claims, + ) -> None: + self._is_auth_disabled = is_auth_disabled + self._oidc_settings_factory = oidc_settings_factory + self._dev_claims = dev_claims if dev_claims is not None else DEFAULT_DEV_CLAIMS + self._user_factory = user_factory + self._oidc_settings: OIDCSettings | None = None + self._validator: JWTValidator | None = None + + # auto_error=False so a missing header doesn't 403 before our handler + # runs (and so AUTH_DISABLED can short-circuit). + bearer_scheme = HTTPBearer(auto_error=False) + + async def get_token_claims( + request: Request, + _credential: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), + ) -> TokenClaims: + """Extract and validate the JWT from the Authorization header. + + ``_credential`` exists only so FastAPI registers the HTTPBearer + scheme in OpenAPI (the Swagger "Authorize" button); token parsing + uses the raw header for precise 401 messages. + """ + if self._is_auth_disabled(): + return self._dev_claims + + auth_header = request.headers.get("Authorization") + if not auth_header: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Missing Authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + + scheme, _, token = auth_header.partition(" ") + if scheme.lower() != "bearer" or not token: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid Authorization header format", + headers={"WWW-Authenticate": "Bearer"}, + ) + + validator = self._jwt_validator() + try: + return await asyncio.to_thread(validator.validate, token) + except JWKSFetchError: + logger.error( + "JWKS endpoint unreachable or returned invalid data", + exc_info=True, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + except AuthError as exc: + logger.warning("JWT validation failed: %s", exc, exc_info=True) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + Claims = Annotated[TokenClaims, Depends(get_token_claims)] + + def require_permissions( + *required_permissions: str, check: Literal["all", "any"] = "all" + ): + async def dependency(claims: Claims) -> None: + check_permissions( + granted=claims.permissions, + required=required_permissions, + check=check, + exc_handler=_permission_exception_handler, + ) + + return dependency + + async def get_current_user(claims: Claims) -> ServiceUser: + # When auth is disabled, `claims` is already the dev claims. + return self._user_factory(claims) + + CurrentUser = Annotated[ServiceUser, Depends(get_current_user)] + + async def get_request_auth_context( + request: Request, user: CurrentUser + ) -> RequestAuthContext: + if self._is_auth_disabled(): + bearer_token = _dev_bearer_token() + else: + bearer_token = _extract_bearer_token_from_request(request) + return RequestAuthContext(user=user, bearer_token=bearer_token) + + AuthContext = Annotated[RequestAuthContext, Depends(get_request_auth_context)] + + self.get_token_claims = get_token_claims + self.require_permissions = require_permissions + self.get_current_user = get_current_user + self.get_request_auth_context = get_request_auth_context + self.Claims = Claims + self.CurrentUser = CurrentUser + self.AuthContext = AuthContext + + def oidc_settings(self) -> OIDCSettings: + if self._oidc_settings is None: + self._oidc_settings = self._oidc_settings_factory() + return self._oidc_settings + + def _jwt_validator(self) -> JWTValidator: + if self._validator is None: + self._validator = JWTValidator(self.oidc_settings()) + return self._validator + + def validate_auth_config_at_startup(self) -> None: + if self._is_auth_disabled(): + logger.warning("Auth is disabled — all requests use dev credentials") + return + # Fail fast if OIDC config is invalid. + self.oidc_settings() diff --git a/packages/stitch-service/tests/test_auth.py b/packages/stitch-service/tests/test_auth.py new file mode 100644 index 00000000..8c5315bb --- /dev/null +++ b/packages/stitch-service/tests/test_auth.py @@ -0,0 +1,121 @@ +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from stitch.auth import SOURCE_WRITE, TokenClaims + +from stitch.service.auth import ( + AuthMode, + ServiceAuth, + build_headers_provider, + machine_token_headers_provider, + relay_token_headers_provider, +) +from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR + + +# --------------------------------------------------------------------------- # +# Downstream auth seam +# --------------------------------------------------------------------------- # + + +def test_machine_provider_reads_env_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, "machine-tok") + provider = build_headers_provider(AuthMode.machine) + assert provider() == {"Authorization": "Bearer machine-tok"} + # Sanity: the helper and the dispatcher agree. + assert machine_token_headers_provider()() == {"Authorization": "Bearer machine-tok"} + + +def test_machine_provider_requires_env_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, raising=False) + provider = build_headers_provider(AuthMode.machine) + with pytest.raises(ValueError): + provider() + + +def test_passthrough_provider_relays_caller_token() -> None: + provider = build_headers_provider(AuthMode.passthrough, token="caller-jwt") + assert provider() == {"Authorization": "Bearer caller-jwt"} + assert relay_token_headers_provider("x")() == {"Authorization": "Bearer x"} + + +def test_passthrough_requires_token() -> None: + with pytest.raises(ValueError): + build_headers_provider(AuthMode.passthrough) + + +# --------------------------------------------------------------------------- # +# Inbound auth +# --------------------------------------------------------------------------- # + + +def build_app(auth: ServiceAuth) -> FastAPI: + app = FastAPI() + + @app.get("/me") + async def me(user: auth.CurrentUser): + return {"sub": user.sub, "name": user.name} + + @app.get("/context") + async def context(ctx: auth.AuthContext): + return {"sub": ctx.user.sub, "bearer_token": ctx.bearer_token} + + @app.post( + "/guarded", + dependencies=[Depends(auth.require_permissions(SOURCE_WRITE))], + ) + async def guarded(): + return {"ok": True} + + return app + + +def test_auth_disabled_resolves_dev_user_without_a_token() -> None: + auth = ServiceAuth(is_auth_disabled=lambda: True) + app = build_app(auth) + + with TestClient(app) as client: + me = client.get("/me") + assert me.status_code == 200 + assert me.json()["sub"] == "dev|local-placeholder" + + # Dev claims carry all permissions, so the guarded route is allowed. + assert client.post("/guarded").status_code == 200 + + # In disabled mode the relayed token is the dev placeholder. + ctx = client.get("/context") + assert ctx.json()["bearer_token"] == "dev-placeholder-token" + + +def test_require_permissions_rejects_missing_permission() -> None: + auth = ServiceAuth(is_auth_disabled=lambda: False) + app = build_app(auth) + + def claims_without_permission() -> TokenClaims: + return TokenClaims(sub="user|1", permissions=frozenset()) + + app.dependency_overrides[auth.get_token_claims] = claims_without_permission + + with TestClient(app) as client: + response = client.post("/guarded") + + assert response.status_code == 403 + assert SOURCE_WRITE in response.json()["detail"] + + +def test_request_context_relays_caller_bearer_token() -> None: + auth = ServiceAuth(is_auth_disabled=lambda: False) + app = build_app(auth) + + def claims() -> TokenClaims: + return TokenClaims(sub="user|1", permissions=frozenset({SOURCE_WRITE})) + + app.dependency_overrides[auth.get_token_claims] = claims + + with TestClient(app) as client: + response = client.get( + "/context", headers={"Authorization": "Bearer caller-jwt"} + ) + + assert response.status_code == 200 + assert response.json()["bearer_token"] == "caller-jwt" diff --git a/packages/stitch-service/tests/test_downstream_client.py b/packages/stitch-service/tests/test_downstream_client.py new file mode 100644 index 00000000..d39ca717 --- /dev/null +++ b/packages/stitch-service/tests/test_downstream_client.py @@ -0,0 +1,48 @@ +"""Integration-level checks that the downstream auth modes actually attach the +expected Authorization header on outgoing requests via AsyncStitchClient.""" + +from collections.abc import Callable, Mapping + +import httpx +import pytest +from stitch.client import AsyncStitchClient +from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR + +from stitch.service.auth import AuthMode, build_headers_provider + + +def _capturing_client( + seen: dict, headers_provider: Callable[[], Mapping[str, str]] +) -> AsyncStitchClient: + def handler(request: httpx.Request) -> httpx.Response: + seen["authorization"] = request.headers.get("Authorization") + return httpx.Response(200, json={}) + + raw = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + base_url="http://downstream.test/api/v1", + ) + return AsyncStitchClient(client=raw, headers_provider=headers_provider) + + +@pytest.mark.anyio +async def test_passthrough_mode_forwards_caller_token() -> None: + seen: dict = {} + provider = build_headers_provider(AuthMode.passthrough, token="caller-jwt") + + async with _capturing_client(seen, provider) as client: + await client.get_auth_me() + + assert seen["authorization"] == "Bearer caller-jwt" + + +@pytest.mark.anyio +async def test_machine_mode_sends_env_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, "machine-tok") + seen: dict = {} + provider = build_headers_provider(AuthMode.machine) + + async with _capturing_client(seen, provider) as client: + await client.get_auth_me() + + assert seen["authorization"] == "Bearer machine-tok" diff --git a/uv.lock b/uv.lock index 049eaed0..9ef0427e 100644 --- a/uv.lock +++ b/uv.lock @@ -1298,6 +1298,8 @@ version = "0.1.0" source = { editable = "packages/stitch-service" } dependencies = [ { name = "fastapi", extra = ["standard-no-fastapi-cloud-cli"] }, + { name = "stitch-auth" }, + { name = "stitch-client" }, ] [package.dev-dependencies] @@ -1308,7 +1310,11 @@ dev = [ ] [package.metadata] -requires-dist = [{ name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }] +requires-dist = [ + { name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }, + { name = "stitch-auth", editable = "packages/stitch-auth" }, + { name = "stitch-client", editable = "packages/stitch-client" }, +] [package.metadata.requires-dev] dev = [ From f5cc3d3561b52ce3e36d95d7b2d17c5a5ed6707e Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 12:53:45 +0200 Subject: [PATCH 09/26] better handling of terminal states and job restarts --- .../stitch-jobs/src/stitch/jobs/manager.py | 28 +++++- .../stitch-jobs/src/stitch/jobs/routers.py | 10 +- packages/stitch-jobs/src/stitch/jobs/store.py | 23 ++++- packages/stitch-jobs/tests/test_manager.py | 97 +++++++++++++++++++ 4 files changed, 147 insertions(+), 11 deletions(-) diff --git a/packages/stitch-jobs/src/stitch/jobs/manager.py b/packages/stitch-jobs/src/stitch/jobs/manager.py index f9db2f9a..85dec8b6 100644 --- a/packages/stitch-jobs/src/stitch/jobs/manager.py +++ b/packages/stitch-jobs/src/stitch/jobs/manager.py @@ -13,6 +13,9 @@ logger = logging.getLogger("stitch.jobs") +#: Terminal states that, by default, an existing run may be reused from. +_DEFAULT_REUSABLE_TERMINAL = frozenset({JobState.succeeded, JobState.failed}) + def _utcnow() -> datetime: return datetime.now(UTC) @@ -32,6 +35,13 @@ class JobManager(Generic[P, R]): still active — or finished within ``recent_within`` — and returns it instead of starting a duplicate. That is what lets a second user observe (and reuse the results of) a run another user already kicked off. + + Reuse is tunable: + - ``recent_within`` — how long after finishing a terminal run stays + reusable. ``None`` means forever (no expiry). + - ``reuse_failed`` — when ``False``, failed runs are kept/visible but are + not reused, so the next request retries (transient failures self-heal). + - ``start(force=True)`` — bypass reuse entirely and always launch a new run. """ def __init__( @@ -40,32 +50,40 @@ def __init__( *, store: JobStore | None = None, policy: UniquenessPolicy | None = None, - recent_within: timedelta = timedelta(0), + recent_within: timedelta | None = timedelta(0), + reuse_failed: bool = True, clock: Callable[[], datetime] = _utcnow, ) -> None: self._run_fn = run_fn self._store: JobStore = store or InMemoryJobStore(clock=clock) self._policy = policy or SingletonPolicy() self._recent_within = recent_within + self._reusable_states = frozenset({JobState.running}) | ( + _DEFAULT_REUSABLE_TERMINAL + if reuse_failed + else frozenset({JobState.succeeded}) + ) self._clock = clock self._lock = asyncio.Lock() # Hold strong refs so tasks aren't garbage-collected mid-flight. self._tasks: set[asyncio.Task[None]] = set() async def start( - self, params: P, *, initiated_by: str | None = None + self, params: P, *, initiated_by: str | None = None, force: bool = False ) -> tuple[JobRecord[P, R], bool]: """Start a run, or join an existing matching one. Returns ``(record, created)`` where ``created`` is ``False`` when an existing active/recent run with the same dedup key was returned instead - of launching a new task. + of launching a new task. ``force=True`` always launches a new run. """ async with self._lock: key = self._policy.key(params) - if key is not None: + if not force and key is not None: existing = await self._store.find_active_or_recent( - key, recent_within=self._recent_within + key, + recent_within=self._recent_within, + reusable_states=self._reusable_states, ) if existing is not None: return existing, False diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py index ad583e6c..06752a10 100644 --- a/packages/stitch-jobs/src/stitch/jobs/routers.py +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -25,6 +25,7 @@ def make_job_router( to_params: Callable[[Any], BaseModel] | None = None, dependencies: Sequence[Any] = (), initiated_by: Callable[..., Awaitable[str | None] | str | None] | None = None, + force_attr: str | None = None, tags: Sequence[str] | None = None, default_list_limit: int = 20, ) -> APIRouter: @@ -36,6 +37,10 @@ def make_job_router( wire request. ``dependencies`` is where the service plugs in its permission gate (e.g. ``[Depends(require_permissions(...))]``); ``initiated_by`` is an optional dependency returning the caller's display label. + + ``force_attr`` names a boolean field on the request body that, when true, + bypasses dedup and forces a fresh run. Keep that field out of ``params`` (via + ``to_params``) so it never participates in the dedup key. """ params_model = params_model or start_request_model to_params = to_params or (lambda request: request) @@ -63,7 +68,10 @@ async def start( caller observes that run rather than starting a duplicate). """ params = to_params(request) - record, created = await manager.start(params, initiated_by=initiated_by_label) + force = bool(getattr(request, force_attr)) if force_attr else False + record, created = await manager.start( + params, initiated_by=initiated_by_label, force=force + ) if not created: response.status_code = HTTP_200_OK return record diff --git a/packages/stitch-jobs/src/stitch/jobs/store.py b/packages/stitch-jobs/src/stitch/jobs/store.py index 4af9434b..c7c3561f 100644 --- a/packages/stitch-jobs/src/stitch/jobs/store.py +++ b/packages/stitch-jobs/src/stitch/jobs/store.py @@ -27,7 +27,11 @@ async def get(self, job_id: str) -> JobRecord | None: """Return the record for ``job_id``, or ``None`` if unknown.""" async def find_active_or_recent( - self, dedup_key: str, *, recent_within: timedelta + self, + dedup_key: str, + *, + recent_within: timedelta | None, + reusable_states: frozenset[JobState], ) -> JobRecord | None: """Return a matching job that is running or finished recently.""" @@ -78,11 +82,18 @@ async def get(self, job_id: str) -> JobRecord | None: return self._records.get(job_id) async def find_active_or_recent( - self, dedup_key: str, *, recent_within: timedelta + self, + dedup_key: str, + *, + recent_within: timedelta | None, + reusable_states: frozenset[JobState], ) -> JobRecord | None: - """Return the newest matching job that is still running, or that - finished within ``recent_within``. Newest-first so callers join/observe - the most relevant run. + """Return the newest matching, reusable job. + + A record matches when its key equals ``dedup_key``, its state is in + ``reusable_states``, and it is either still running or finished within + ``recent_within`` (``None`` means no age limit — reuse forever). + Newest-first so callers join/observe the most relevant run. """ self._evict_expired() now = self._clock() @@ -90,8 +101,10 @@ async def find_active_or_recent( record for record in self._records.values() if record.dedup_key == dedup_key + and record.state in reusable_states and ( record.state == JobState.running + or recent_within is None or ( record.finished_at is not None and now - record.finished_at <= recent_within diff --git a/packages/stitch-jobs/tests/test_manager.py b/packages/stitch-jobs/tests/test_manager.py index b0a33203..2d9b62dd 100644 --- a/packages/stitch-jobs/tests/test_manager.py +++ b/packages/stitch-jobs/tests/test_manager.py @@ -171,6 +171,103 @@ async def run(params: Params) -> Result: await _wait_until_terminal(manager, fresh.job_id) +@pytest.mark.anyio +async def test_force_bypasses_an_active_run() -> None: + release = asyncio.Event() + + async def run(params: Params) -> Result: + await release.wait() + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) + first, first_created = await manager.start(Params(name="a")) + forced, forced_created = await manager.start(Params(name="a"), force=True) + + assert first_created is True + assert forced_created is True # force ignores the active run + assert forced.job_id != first.job_id + + release.set() + await _wait_until_terminal(manager, first.job_id) + await _wait_until_terminal(manager, forced.job_id) + + +@pytest.mark.anyio +async def test_recent_within_none_reuses_indefinitely() -> None: + now = {"t": datetime(2026, 1, 1, tzinfo=UTC)} + + def clock() -> datetime: + return now["t"] + + async def run(params: Params) -> Result: + return Result(value=1) + + store = InMemoryJobStore(clock=clock, retention=None) + manager: JobManager[Params, Result] = JobManager( + run, + store=store, + policy=FingerprintPolicy(), + recent_within=None, + clock=clock, + ) + + first, _ = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, first.job_id) + + # A year later, the same params still reuse the original run. + now["t"] = now["t"] + timedelta(days=365) + reused, created = await manager.start(Params(name="a")) + assert created is False + assert reused.job_id == first.job_id + + +@pytest.mark.anyio +async def test_failed_runs_are_not_reused_when_reuse_failed_false() -> None: + calls = {"n": 0} + + async def run(params: Params) -> Result: + calls["n"] += 1 + raise RuntimeError("boom") + + manager: JobManager[Params, Result] = JobManager( + run, + policy=FingerprintPolicy(), + recent_within=None, + reuse_failed=False, + ) + + first, first_created = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, first.job_id) + assert first_created is True + + # The failed run is not reused — the next request retries with a new job. + second, second_created = await manager.start(Params(name="a")) + assert second_created is True + assert second.job_id != first.job_id + await _wait_until_terminal(manager, second.job_id) + assert calls["n"] == 2 + + +@pytest.mark.anyio +async def test_succeeded_runs_reused_even_when_reuse_failed_false() -> None: + async def run(params: Params) -> Result: + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager( + run, + policy=FingerprintPolicy(), + recent_within=None, + reuse_failed=False, + ) + + first, _ = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, first.job_id) + + reused, created = await manager.start(Params(name="a")) + assert created is False + assert reused.job_id == first.job_id + + @pytest.mark.anyio async def test_terminal_records_evicted_after_retention() -> None: now = {"t": datetime(2026, 1, 1, tzinfo=UTC)} From baf555b9faee17f29fc6da057c72713b9b969c9d Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 13:01:54 +0200 Subject: [PATCH 10/26] Convert `stitch-llm` to use new packages --- deployments/stitch-llm/pyproject.toml | 4 + deployments/stitch-llm/src/stitch/llm/auth.py | 164 +------- .../stitch-llm/src/stitch/llm/client.py | 14 +- .../stitch-llm/src/stitch/llm/entities.py | 11 +- deployments/stitch-llm/src/stitch/llm/jobs.py | 106 +++++ deployments/stitch-llm/src/stitch/llm/main.py | 26 +- .../stitch-llm/src/stitch/llm/middleware.py | 26 -- .../src/stitch/llm/routers/oil_gas_fields.py | 202 +++------ .../stitch-llm/tests/test_downstream_auth.py | 30 ++ .../tests/test_oil_gas_fields_api.py | 384 ++++++++++-------- uv.lock | 4 + 11 files changed, 441 insertions(+), 530 deletions(-) create mode 100644 deployments/stitch-llm/src/stitch/llm/jobs.py delete mode 100644 deployments/stitch-llm/src/stitch/llm/middleware.py create mode 100644 deployments/stitch-llm/tests/test_downstream_auth.py diff --git a/deployments/stitch-llm/pyproject.toml b/deployments/stitch-llm/pyproject.toml index 4d1f3f7c..bf76d1d5 100644 --- a/deployments/stitch-llm/pyproject.toml +++ b/deployments/stitch-llm/pyproject.toml @@ -11,8 +11,10 @@ dependencies = [ "pydantic-settings>=2.12.0", "stitch-auth", "stitch-client", + "stitch-jobs", "stitch-models", "stitch-ogsi", + "stitch-service", ] [project.scripts] @@ -42,5 +44,7 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-client = { workspace = true } +stitch-jobs = { workspace = true } stitch-models = { workspace = true } stitch-ogsi = { workspace = true } +stitch-service = { workspace = true } diff --git a/deployments/stitch-llm/src/stitch/llm/auth.py b/deployments/stitch-llm/src/stitch/llm/auth.py index fc8acff9..c059193e 100644 --- a/deployments/stitch-llm/src/stitch/llm/auth.py +++ b/deployments/stitch-llm/src/stitch/llm/auth.py @@ -1,156 +1,22 @@ -import asyncio -import logging -from functools import lru_cache -from typing import Annotated, Literal, NoReturn +"""stitch-llm auth wiring. -from fastapi import Depends, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +Mechanics live in :mod:`stitch.service.auth`; here we bind a +:class:`~stitch.service.auth.ServiceAuth` to this service's settings and +re-export the dependencies the router and tests import by name. +""" -from stitch.auth import ( - ALL_PERMISSIONS, - AuthError, - InsufficientPermissionsError, - JWKSFetchError, - JWTValidator, - OIDCSettings, - TokenClaims, - check_permissions, -) +from stitch.service.auth import ServiceAuth -from stitch.llm.entities import User from stitch.llm.settings import get_settings -logger = logging.getLogger(__name__) +_auth = ServiceAuth(is_auth_disabled=lambda: get_settings().auth_disabled) +validate_auth_config_at_startup = _auth.validate_auth_config_at_startup +get_token_claims = _auth.get_token_claims +require_permissions = _auth.require_permissions +get_current_user = _auth.get_current_user +get_request_auth_context = _auth.get_request_auth_context -@lru_cache -def get_oidc_settings() -> OIDCSettings: - return OIDCSettings() - - -@lru_cache -def get_jwt_validator() -> JWTValidator: - return JWTValidator(get_oidc_settings()) - - -_DEV_CLAIMS = TokenClaims( - sub="dev|local-placeholder", - email="dev@example.com", - name="Dev User", - permissions=ALL_PERMISSIONS, - raw={}, -) - -# auto_error=False so that when AUTH_DISABLED=true the missing header -# doesn't trigger a 403 before our custom handler runs. -_bearer_scheme = HTTPBearer(auto_error=False) - - -def validate_auth_config_at_startup() -> None: - settings = get_settings() - - if settings.auth_disabled: - logger.warning("Auth is disabled — all requests use dev credentials") - return - - # fail fast if OIDC config is invalid - get_oidc_settings() - - -def _claims_to_user(claims: TokenClaims) -> User: - return User( - id=1, - sub=claims.sub, - email=claims.email or "unknown@example.com", - name=claims.name or claims.email or claims.sub, - ) - - -async def get_token_claims( - request: Request, - _credential: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme), -) -> TokenClaims: - """Extract and validate JWT from Authorization header. - - The ``_credential`` parameter exists solely so FastAPI registers the - HTTPBearer security scheme in the OpenAPI spec (Swagger "Authorize" - button). Actual token parsing still uses the raw header so we can - return precise 401 messages for missing/malformed values. - """ - if get_settings().auth_disabled: - return _DEV_CLAIMS - - auth_header = request.headers.get("Authorization") - if not auth_header: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Missing Authorization header", - headers={"WWW-Authenticate": "Bearer"}, - ) - - scheme, _, token = auth_header.partition(" ") - if scheme.lower() != "bearer" or not token: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid Authorization header format", - headers={"WWW-Authenticate": "Bearer"}, - ) - - validator = get_jwt_validator() - try: - return await asyncio.to_thread(validator.validate, token) - except JWKSFetchError: - logger.error( - "JWKS endpoint unreachable or returned invalid data", exc_info=True - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - except AuthError as e: - logger.warning("JWT validation failed: %s", e, exc_info=True) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -Claims = Annotated[TokenClaims, Depends(get_token_claims)] - - -def _permission_exception_handler(exc: InsufficientPermissionsError) -> NoReturn: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=exc.detail) - - -def require_permissions( - *required_permissions: str, check: Literal["all", "any"] = "all" -): - async def dependency(claims: Claims) -> None: - check_permissions( - granted=claims.permissions, - required=required_permissions, - check=check, - exc_handler=_permission_exception_handler, - ) - - return dependency - - -async def get_current_user(claims: Claims) -> User: - """ - Resolve validated token claims to a lightweight request user. - """ - if get_settings().auth_disabled: - return User( - id=1, - sub=_DEV_CLAIMS.sub, - email=_DEV_CLAIMS.email or "dev@example.com", - name=_DEV_CLAIMS.name or "Dev User", - ) - return _claims_to_user(claims) - - -CurrentUser = Annotated[User, Depends(get_current_user)] +Claims = _auth.Claims +CurrentUser = _auth.CurrentUser +AuthContext = _auth.AuthContext diff --git a/deployments/stitch-llm/src/stitch/llm/client.py b/deployments/stitch-llm/src/stitch/llm/client.py index 562a147b..5e57315a 100644 --- a/deployments/stitch-llm/src/stitch/llm/client.py +++ b/deployments/stitch-llm/src/stitch/llm/client.py @@ -1,16 +1,21 @@ from __future__ import annotations -from stitch.client import AsyncStitchClient, env_bearer_token_headers_provider +from stitch.client import AsyncStitchClient from stitch.ogsi.model import OGFieldDetailView from pydantic import ValidationError +from stitch.service.auth import AuthMode, build_headers_provider from stitch.llm.errors import ModelOutputError from stitch.llm.settings import Settings, get_settings +# Suggestions run as detached background jobs, so the caller's token is gone by +# the time they execute — authenticate downstream with machine identity. +_DOWNSTREAM_AUTH_MODE = AuthMode.machine + def validate_downstream_auth_config_at_startup() -> None: - headers_provider = env_bearer_token_headers_provider() - headers_provider() + # Fail fast at startup if the machine token isn't configured. + build_headers_provider(_DOWNSTREAM_AUTH_MODE)() class StitchApiClient: @@ -24,11 +29,10 @@ def __init__( self._client = client return - headers_provider = env_bearer_token_headers_provider() self._client = AsyncStitchClient( base_url=str(self._settings.api_base_url), timeout=30.0, - headers_provider=headers_provider, + headers_provider=build_headers_provider(_DOWNSTREAM_AUTH_MODE), ) async def __aenter__(self) -> "StitchApiClient": diff --git a/deployments/stitch-llm/src/stitch/llm/entities.py b/deployments/stitch-llm/src/stitch/llm/entities.py index 5b92cc37..8c53a0e8 100644 --- a/deployments/stitch-llm/src/stitch/llm/entities.py +++ b/deployments/stitch-llm/src/stitch/llm/entities.py @@ -1,15 +1,10 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel +from stitch.service.auth import ServiceUser as User - -class User(BaseModel): - id: int = Field(...) - sub: str = Field(...) - role: str | None = None - email: EmailStr - name: str +__all__ = ["Citation", "FieldSuggestionResponse", "User"] class Citation(BaseModel): diff --git a/deployments/stitch-llm/src/stitch/llm/jobs.py b/deployments/stitch-llm/src/stitch/llm/jobs.py new file mode 100644 index 00000000..4187a20b --- /dev/null +++ b/deployments/stitch-llm/src/stitch/llm/jobs.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from pydantic import BaseModel + +from stitch.llm.azure_responses import AzureResponsesClient, extract_public_citations +from stitch.llm.client import StitchApiClient +from stitch.llm.entities import FieldSuggestionResponse +from stitch.llm.settings import get_settings +from stitch.llm.suggestions import ( + AllowedSuggestionField, + build_field_suggestion_input, + ensure_field_is_missing, + is_string_suggestion_field, + parse_field_suggestion_response, + sanitize_and_validate_suggested_value, +) + +PLACEHOLDER_LLM_VALUE = ":warning: placeholder LLM value" +PLACEHOLDER_LLM_MODEL = "placeholder-llm" + + +class FieldSuggestionParams(BaseModel): + """Identifies the suggestion to run; also the dedup key (resource_id, field).""" + + resource_id: int + field: AllowedSuggestionField + + +async def run_suggestion(params: FieldSuggestionParams) -> FieldSuggestionResponse: + """Produce an LLM field suggestion as a background job. + + Domain failures (resource missing, field already populated, LLM config/output + errors) propagate out and are captured by the JobManager as a failed record + (observable via ``GET /status/{job_id}``) — there is no synchronous HTTP + status mapping anymore. + """ + resource_id = params.resource_id + field = params.field + observed_at = datetime.now(UTC) + + async with StitchApiClient() as stitch_client: + detail_view = await stitch_client.get_oil_gas_field_detail(resource_id) + + ensure_field_is_missing(detail_view, field) + + input_messages = build_field_suggestion_input( + resource_id=resource_id, + field=field, + detail_view=detail_view, + ) + settings = get_settings() + + if settings.auth_disabled and not settings.azure_openai_configured: + fallback_value = ( + PLACEHOLDER_LLM_VALUE if is_string_suggestion_field(field) else None + ) + return FieldSuggestionResponse( + resource_id=resource_id, + field=field, + value=fallback_value, + citations=[], + query_succeeded=True, + model=PLACEHOLDER_LLM_MODEL, + rationale=( + "Foundry is not configured in auth-disabled mode; returned a local " + "placeholder value." + if fallback_value is not None + else "Foundry is not configured in auth-disabled mode; no safe " + "placeholder exists for this field type." + ), + observed_at=observed_at, + foundry_request={}, + foundry_response={}, + ) + + async with AzureResponsesClient() as llm_client: + llm_result = await llm_client.generate_field_suggestion( + field=field, + input_messages=input_messages, + ) + parsed = parse_field_suggestion_response(llm_result.output_text) + citations = extract_public_citations(llm_result.response_payload) + if parsed.value is None or not citations: + value = None + citations = [] + else: + value = sanitize_and_validate_suggested_value( + detail_data=detail_view.data, + field=field, + value=parsed.value, + ) + + return FieldSuggestionResponse( + resource_id=resource_id, + field=field, + value=value, + citations=citations, + query_succeeded=True, + model=llm_result.model, + rationale=parsed.rationale, + observed_at=observed_at, + foundry_request=llm_result.request_payload, + foundry_response=llm_result.response_payload, + ) diff --git a/deployments/stitch-llm/src/stitch/llm/main.py b/deployments/stitch-llm/src/stitch/llm/main.py index a6b74046..1c325f7e 100644 --- a/deployments/stitch-llm/src/stitch/llm/main.py +++ b/deployments/stitch-llm/src/stitch/llm/main.py @@ -1,36 +1,26 @@ -from contextlib import asynccontextmanager -from datetime import UTC, datetime - -from fastapi import APIRouter, FastAPI +from fastapi import FastAPI +from stitch.service import create_app from stitch.llm.auth import validate_auth_config_at_startup from stitch.llm.client import validate_downstream_auth_config_at_startup -from stitch.llm.middleware import register_middlewares from stitch.llm.routers.health import router as health_router from stitch.llm.routers.oil_gas_fields import router as oil_gas_fields_router from stitch.llm.settings import get_settings -base_router = APIRouter(prefix="/api/v1") -base_router.include_router(health_router) -base_router.include_router(oil_gas_fields_router) - -@asynccontextmanager -async def lifespan(app: FastAPI): - app.state.started_at = datetime.now(UTC) +def _run_startup(app: FastAPI) -> None: app.state.auth_config_validated = False app.state.downstream_auth_config_validated = False validate_auth_config_at_startup() app.state.auth_config_validated = True validate_downstream_auth_config_at_startup() app.state.downstream_auth_config_validated = True - yield -app = FastAPI(lifespan=lifespan) - settings = get_settings() -register_middlewares(application=app, settings=settings) - -app.include_router(base_router) +app = create_app( + routers=[health_router, oil_gas_fields_router], + cors_origins=[str(settings.frontend_origin_url)], + on_startup=_run_startup, +) diff --git a/deployments/stitch-llm/src/stitch/llm/middleware.py b/deployments/stitch-llm/src/stitch/llm/middleware.py deleted file mode 100644 index 278d4160..00000000 --- a/deployments/stitch-llm/src/stitch/llm/middleware.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Final -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from stitch.llm.settings import Settings - -ALLOWED_METHODS: Final[tuple[str, ...]] = ( - "GET", - "OPTIONS", -) - -ALLOWED_HEADERS: Final[tuple[str, ...]] = ( - "Authorization", - "Content-Type", - "Accept", - "Origin", -) - - -def register_middlewares(application: FastAPI, settings: Settings) -> None: - application.add_middleware( - CORSMiddleware, - allow_origins=[str(settings.frontend_origin_url).rstrip("/")], - allow_credentials=True, - allow_methods=ALLOWED_METHODS, - allow_headers=ALLOWED_HEADERS, - ) diff --git a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py index c0e2c8b3..bf4cd3bc 100644 --- a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py +++ b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py @@ -1,159 +1,69 @@ from __future__ import annotations -import logging -from datetime import UTC, datetime -from typing import Annotated - -from fastapi import APIRouter, Depends, HTTPException, Query -from starlette.status import ( - HTTP_404_NOT_FOUND, - HTTP_409_CONFLICT, - HTTP_502_BAD_GATEWAY, - HTTP_503_SERVICE_UNAVAILABLE, -) -from stitch.client import StitchAPIError +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field from stitch.auth.permissions import SERVICE_LLM_SUGGEST +from stitch.jobs import FingerprintPolicy, InMemoryJobStore, JobManager, make_job_router -from stitch.llm.auth import CurrentUser, require_permissions -from stitch.llm.azure_responses import AzureResponsesClient, extract_public_citations -from stitch.llm.client import StitchApiClient -from stitch.llm.entities import FieldSuggestionResponse -from stitch.llm.errors import ( - AzureResponsesError, - FieldAlreadyPopulatedError, - LLMConfigurationError, - ModelOutputError, -) -from stitch.llm.suggestions import ( +from stitch.llm.auth import AuthContext, require_permissions +from stitch.llm.entities import FieldSuggestionResponse, User +from stitch.llm.jobs import ( AllowedSuggestionField, - build_field_suggestion_input, - ensure_field_is_missing, - is_string_suggestion_field, - parse_field_suggestion_response, - sanitize_and_validate_suggested_value, + FieldSuggestionParams, + run_suggestion, ) -from stitch.llm.settings import get_settings -logger = logging.getLogger(__name__) +# Suggestions are tracked per (resource_id, field) with no expiry: once a pair +# has a result it is reused indefinitely (decoupled from the original caller, so +# a later user sees that a backfill was attempted). Failed runs are kept/visible +# but not reused, so the next request retries; `force` bypasses reuse entirely. +_manager: JobManager[FieldSuggestionParams, FieldSuggestionResponse] = JobManager( + run_suggestion, + policy=FingerprintPolicy(), + recent_within=None, + reuse_failed=False, + store=InMemoryJobStore(retention=None), +) -PLACEHOLDER_LLM_VALUE = ":warning: placeholder LLM value" -PLACEHOLDER_LLM_MODEL = "placeholder-llm" -router = APIRouter( - prefix="/oil-gas-fields", - tags=["oil_gas_fields"], - responses={404: {"description": "Not found"}}, -) +def get_job_manager() -> JobManager[FieldSuggestionParams, FieldSuggestionResponse]: + return _manager + + +class StartSuggestionRequest(BaseModel): + resource_id: int + field: AllowedSuggestionField + force: bool = Field( + default=False, + description="Re-run even if a suggestion for this (resource, field) exists.", + ) + +def _to_params(request: StartSuggestionRequest) -> FieldSuggestionParams: + # `force` is intentionally dropped so it never participates in the dedup key. + return FieldSuggestionParams(resource_id=request.resource_id, field=request.field) -@router.get( - "/{id}", - response_model=FieldSuggestionResponse, + +def _extract_user_label(user: User) -> str: + return user.name or user.email or user.sub + + +async def initiated_by(auth_context: AuthContext) -> str: + return _extract_user_label(auth_context.user) + + +_job_router = make_job_router( + _manager, + start_request_model=StartSuggestionRequest, + params_model=FieldSuggestionParams, + to_params=_to_params, + force_attr="force", + result_model=FieldSuggestionResponse, dependencies=[Depends(require_permissions(SERVICE_LLM_SUGGEST))], + initiated_by=initiated_by, + tags=["oil_gas_fields"], ) -async def suggest_oil_gas_field_value( - *, - _user: CurrentUser, - id: int, - field: Annotated[AllowedSuggestionField, Query()], -) -> FieldSuggestionResponse: - observed_at = datetime.now(UTC) - try: - async with StitchApiClient() as stitch_client: - detail_view = await stitch_client.get_oil_gas_field_detail(id) - except StitchAPIError as exc: - if exc.status_code == HTTP_404_NOT_FOUND: - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - logger.exception("Stitch API request failed for resource %s", id) - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail="Failed to fetch resource detail from Stitch API.", - ) from exc - except LLMConfigurationError as exc: - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail=str(exc), - ) from exc - except ModelOutputError as exc: - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - - try: - ensure_field_is_missing(detail_view, field) - except FieldAlreadyPopulatedError as exc: - raise HTTPException(status_code=HTTP_409_CONFLICT, detail=str(exc)) from exc - - input_messages = build_field_suggestion_input( - resource_id=id, - field=field, - detail_view=detail_view, - ) - settings = get_settings() - - if settings.auth_disabled and not settings.azure_openai_configured: - fallback_value = ( - PLACEHOLDER_LLM_VALUE if is_string_suggestion_field(field) else None - ) - return FieldSuggestionResponse( - resource_id=id, - field=field, - value=fallback_value, - citations=[], - query_succeeded=True, - model=PLACEHOLDER_LLM_MODEL, - rationale=( - "Foundry is not configured in auth-disabled mode; returned a local " - "placeholder value." - if fallback_value is not None - else "Foundry is not configured in auth-disabled mode; no safe " - "placeholder exists for this field type." - ), - observed_at=observed_at, - foundry_request={}, - foundry_response={}, - ) - - try: - async with AzureResponsesClient() as llm_client: - llm_result = await llm_client.generate_field_suggestion( - field=field, - input_messages=input_messages, - ) - parsed = parse_field_suggestion_response(llm_result.output_text) - citations = extract_public_citations(llm_result.response_payload) - if parsed.value is None or not citations: - value = None - citations = [] - else: - value = sanitize_and_validate_suggested_value( - detail_data=detail_view.data, - field=field, - value=parsed.value, - ) - except LLMConfigurationError as exc: - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail=str(exc), - ) from exc - except (AzureResponsesError, ModelOutputError) as exc: - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - - return FieldSuggestionResponse( - resource_id=id, - field=field, - value=value, - citations=citations, - query_succeeded=True, - model=llm_result.model, - rationale=parsed.rationale, - observed_at=observed_at, - foundry_request=llm_result.request_payload, - foundry_response=llm_result.response_payload, - ) + +# Namespace the job endpoints under /oil-gas-fields (→ /api/v1/oil-gas-fields/start, …). +router = APIRouter(prefix="/oil-gas-fields") +router.include_router(_job_router) diff --git a/deployments/stitch-llm/tests/test_downstream_auth.py b/deployments/stitch-llm/tests/test_downstream_auth.py new file mode 100644 index 00000000..93cface0 --- /dev/null +++ b/deployments/stitch-llm/tests/test_downstream_auth.py @@ -0,0 +1,30 @@ +"""stitch-llm authenticates downstream with its own machine identity. + +Suggestions run as detached background jobs, so the caller's token is gone when +the job executes — passthrough is not an option here. +""" + +import pytest +from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR +from stitch.service.auth import AuthMode + +from stitch.llm import client as client_module + + +def test_downstream_uses_machine_identity() -> None: + assert client_module._DOWNSTREAM_AUTH_MODE is AuthMode.machine + + +def test_validate_downstream_requires_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, raising=False) + with pytest.raises(ValueError): + client_module.validate_downstream_auth_config_at_startup() + + +def test_validate_downstream_passes_with_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, "machine-tok") + client_module.validate_downstream_auth_config_at_startup() diff --git a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py index 19627a4d..c281daed 100644 --- a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py +++ b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py @@ -1,26 +1,28 @@ from __future__ import annotations -from contextlib import AbstractAsyncContextManager import json +import time +from contextlib import AbstractAsyncContextManager +from datetime import UTC, datetime import pytest from fastapi.testclient import TestClient -from stitch.client import StitchAPIError - from stitch.auth import TokenClaims from stitch.auth.permissions import SERVICE_LLM_SUGGEST -from stitch.llm import auth as auth_module -from stitch.llm.auth import get_current_user, get_token_claims +from stitch.client import StitchAPIError +from stitch.ogsi.model import GemSource, OGFieldDetailView, SourceRecord +from stitch.ogsi.model.og_field import OilGasFieldBase +from stitch.service.auth import RequestAuthContext + +from stitch.llm import jobs as jobs_module +from stitch.llm import main as main_module +from stitch.llm.auth import get_request_auth_context, get_token_claims from stitch.llm.azure_responses import AzureResponsesResult from stitch.llm.entities import User from stitch.llm.errors import LLMConfigurationError -from stitch.llm import main as main_module from stitch.llm.main import app -from stitch.llm.routers import oil_gas_fields as route_module +from stitch.llm.routers.oil_gas_fields import get_job_manager from stitch.llm.settings import Settings -from stitch.ogsi.model import GemSource, OGFieldDetailView, SourceRecord -from stitch.ogsi.model.og_field import OilGasFieldBase -from datetime import UTC, datetime def make_detail_view(**data) -> OGFieldDetailView: @@ -112,78 +114,52 @@ async def generate_field_suggestion(self, *, field, input_messages): ) -@pytest.fixture -def test_client(monkeypatch: pytest.MonkeyPatch): - async def override_current_user() -> User: - return User( - id=1, - sub="test|user", - email="test@example.com", - name="Test User", - ) - - def override_token_claims() -> TokenClaims: - return TokenClaims( - sub="test|user", - permissions=frozenset({SERVICE_LLM_SUGGEST}), - ) - - test_settings = Settings( - auth_disabled=True, +def _settings(*, auth_disabled: bool) -> Settings: + return Settings( + auth_disabled=auth_disabled, azure_openai_base_url=None, azure_openai_api_key=None, azure_openai_model=None, ) - monkeypatch.setattr(auth_module, "get_settings", lambda: test_settings) - monkeypatch.setattr(route_module, "get_settings", lambda: test_settings) - monkeypatch.setattr( - main_module, "validate_downstream_auth_config_at_startup", lambda: None - ) - app.dependency_overrides[get_current_user] = override_current_user - app.dependency_overrides[get_token_claims] = override_token_claims - with TestClient(app) as client: - yield client - app.dependency_overrides.clear() +@pytest.fixture(autouse=True) +def reset_job_manager(): + get_job_manager().reset() + yield + get_job_manager().reset() -def test_get_suggestion_requires_service_permission( - monkeypatch: pytest.MonkeyPatch, -) -> None: - async def override_current_user() -> User: - return User( - id=1, - sub="test|user", - email="test@example.com", - name="Test User", - ) - - def override_token_claims() -> TokenClaims: - return TokenClaims(sub="test|user", permissions=frozenset()) - - test_settings = Settings( - auth_disabled=True, - azure_openai_base_url=None, - azure_openai_api_key=None, - azure_openai_model=None, - ) - monkeypatch.setattr(auth_module, "get_settings", lambda: test_settings) - monkeypatch.setattr(route_module, "get_settings", lambda: test_settings) +@pytest.fixture +def test_client(monkeypatch: pytest.MonkeyPatch): + # Default: auth-disabled, Azure unconfigured (placeholder mode for the job). + test_settings = _settings(auth_disabled=True) + monkeypatch.setattr(jobs_module, "get_settings", lambda: test_settings) monkeypatch.setattr( main_module, "validate_downstream_auth_config_at_startup", lambda: None ) - app.dependency_overrides[get_current_user] = override_current_user + + def override_token_claims() -> TokenClaims: + return TokenClaims( + sub="test|user", permissions=frozenset({SERVICE_LLM_SUGGEST}) + ) + + async def override_request_auth_context() -> RequestAuthContext: + return RequestAuthContext( + user=User( + id=1, sub="test|user", email="test@example.com", name="Test User" + ), + bearer_token="test-token", + ) + app.dependency_overrides[get_token_claims] = override_token_claims + app.dependency_overrides[get_request_auth_context] = override_request_auth_context with TestClient(app) as client: - response = client.get("/api/v1/oil-gas-fields/42?field=basin") + yield client app.dependency_overrides.clear() - assert response.status_code == 403 - assert SERVICE_LLM_SUGGEST in response.json()["detail"] - def install_fakes( monkeypatch: pytest.MonkeyPatch, @@ -192,22 +168,74 @@ def install_fakes( azure_client: FakeAzureResponsesClient | None = None, ) -> FakeAzureResponsesClient: azure_client = azure_client or FakeAzureResponsesClient() - monkeypatch.setattr(route_module, "StitchApiClient", lambda: stitch_client) - monkeypatch.setattr(route_module, "AzureResponsesClient", lambda: azure_client) + monkeypatch.setattr(jobs_module, "StitchApiClient", lambda: stitch_client) + monkeypatch.setattr(jobs_module, "AzureResponsesClient", lambda: azure_client) return azure_client def enable_foundry_mode(monkeypatch: pytest.MonkeyPatch) -> None: - settings = Settings( - auth_disabled=False, - azure_openai_base_url=None, - azure_openai_api_key=None, - azure_openai_model=None, + monkeypatch.setattr( + jobs_module, "get_settings", lambda: _settings(auth_disabled=False) ) - monkeypatch.setattr(route_module, "get_settings", lambda: settings) -def test_get_suggestion_returns_validated_value( +def _start( + client: TestClient, *, resource_id: int = 42, field: str = "basin", force=False +): + return client.post( + "/api/v1/oil-gas-fields/start", + json={"resource_id": resource_id, "field": field, "force": force}, + ) + + +def _poll(client: TestClient, job_id: str, *, timeout: float = 5.0) -> dict: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + body = client.get(f"/api/v1/oil-gas-fields/status/{job_id}").json() + if body["state"] != "running": + return body + time.sleep(0.02) + raise AssertionError("job did not finish within timeout") + + +def _run(client: TestClient, **kwargs) -> dict: + started = _start(client, **kwargs) + assert started.status_code == 202 + return _poll(client, started.json()["job_id"]) + + +def test_start_requires_service_permission( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + jobs_module, "get_settings", lambda: _settings(auth_disabled=True) + ) + monkeypatch.setattr( + main_module, "validate_downstream_auth_config_at_startup", lambda: None + ) + + def override_token_claims() -> TokenClaims: + return TokenClaims(sub="test|user", permissions=frozenset()) + + async def override_request_auth_context() -> RequestAuthContext: + return RequestAuthContext( + user=User(id=1, sub="test|user", email="t@example.com", name="T"), + bearer_token="x", + ) + + app.dependency_overrides[get_token_claims] = override_token_claims + app.dependency_overrides[get_request_auth_context] = override_request_auth_context + + with TestClient(app) as client: + response = _start(client) + + app.dependency_overrides.clear() + + assert response.status_code == 403 + assert SERVICE_LLM_SUGGEST in response.json()["detail"] + + +def test_job_returns_validated_value( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -241,14 +269,13 @@ def test_get_suggestion_returns_validated_value( ), ) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") - - assert response.status_code == 200 - body = response.json() - assert body["observed_at"].endswith("Z") + final = _run(test_client) + assert final["state"] == "succeeded" + result = final["result"] + assert result["observed_at"].endswith("Z") prompt_payload = json.loads(azure_client.calls[0]["input_messages"][1]["content"]) assert "source_record" not in prompt_payload["source_records"][0] - assert body == { + assert result == { "resource_id": 42, "field": "basin", "value": "Permian Basin", @@ -258,7 +285,7 @@ def test_get_suggestion_returns_validated_value( "query_succeeded": True, "model": "test-model", "rationale": "Public sources identify the basin.", - "observed_at": body["observed_at"], + "observed_at": result["observed_at"], "foundry_request": { "model": "test-model", "input": azure_client.calls[0]["input_messages"], @@ -290,7 +317,7 @@ def test_get_suggestion_returns_validated_value( assert azure_client.calls[0]["field"] == "basin" -def test_get_suggestion_returns_409_when_field_populated( +def test_job_fails_when_field_populated( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -299,13 +326,13 @@ def test_get_suggestion_returns_409_when_field_populated( ) azure_client = install_fakes(monkeypatch, stitch_client=stitch_client) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") - - assert response.status_code == 409 + final = _run(test_client) + assert final["state"] == "failed" + assert final["error"] assert azure_client.calls == [] -def test_get_suggestion_maps_stitch_404( +def test_job_fails_on_stitch_404( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -314,35 +341,16 @@ def test_get_suggestion_maps_stitch_404( ) install_fakes(monkeypatch, stitch_client=stitch_client) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") + final = _run(test_client) + assert final["state"] == "failed" + assert "missing" in final["error"] - assert response.status_code == 404 - -def test_get_suggestion_maps_missing_azure_config( +def test_job_fails_on_missing_azure_config( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr( - auth_module, - "get_settings", - lambda: Settings( - auth_disabled=False, - azure_openai_base_url=None, - azure_openai_api_key=None, - azure_openai_model=None, - ), - ) - monkeypatch.setattr( - route_module, - "get_settings", - lambda: Settings( - auth_disabled=False, - azure_openai_base_url=None, - azure_openai_api_key=None, - azure_openai_model=None, - ), - ) + enable_foundry_mode(monkeypatch) stitch_client = FakeStitchApiClient(detail_view=make_detail_view(basin=None)) install_fakes( monkeypatch, @@ -352,29 +360,29 @@ def test_get_suggestion_maps_missing_azure_config( ), ) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") - - assert response.status_code == 503 + final = _run(test_client) + assert final["state"] == "failed" + assert "Azure OpenAI" in final["error"] -def test_get_suggestion_returns_placeholder_when_auth_disabled_and_azure_missing( +def test_job_placeholder_when_auth_disabled_and_azure_missing( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: stitch_client = FakeStitchApiClient(detail_view=make_detail_view(basin=None)) azure_client = install_fakes(monkeypatch, stitch_client=stitch_client) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") - - assert response.status_code == 200 - assert response.json()["value"] == ":warning: placeholder LLM value" - assert response.json()["citations"] == [] - assert response.json()["model"] == "placeholder-llm" - assert response.json()["observed_at"].endswith("Z") + final = _run(test_client) + assert final["state"] == "succeeded" + result = final["result"] + assert result["value"] == ":warning: placeholder LLM value" + assert result["citations"] == [] + assert result["model"] == "placeholder-llm" + assert result["observed_at"].endswith("Z") assert azure_client.calls == [] -def test_get_suggestion_returns_null_for_non_string_placeholder_fallback( +def test_job_null_for_non_string_placeholder_fallback( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -383,17 +391,15 @@ def test_get_suggestion_returns_null_for_non_string_placeholder_fallback( ) azure_client = install_fakes(monkeypatch, stitch_client=stitch_client) - response = test_client.get("/api/v1/oil-gas-fields/42?field=discovery_year") - - assert response.status_code == 200 - assert response.json()["value"] is None - assert response.json()["citations"] == [] - assert response.json()["model"] == "placeholder-llm" - assert response.json()["observed_at"].endswith("Z") + final = _run(test_client, field="discovery_year") + assert final["state"] == "succeeded" + result = final["result"] + assert result["value"] is None + assert result["model"] == "placeholder-llm" assert azure_client.calls == [] -def test_get_suggestion_maps_invalid_model_output( +def test_job_fails_on_invalid_model_output( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -409,7 +415,7 @@ def test_get_suggestion_maps_invalid_model_output( response_payload={ "id": "resp_test", "model": "test-model", - "output_text": "VALUE: Subsea\nRATIONALE: Public sources identify the location type.", + "output_text": "VALUE: Subsea\nRATIONALE: ...", "output": [ { "content": [ @@ -429,12 +435,11 @@ def test_get_suggestion_maps_invalid_model_output( ), ) - response = test_client.get("/api/v1/oil-gas-fields/42?field=location_type") + final = _run(test_client, field="location_type") + assert final["state"] == "failed" - assert response.status_code == 502 - -def test_get_suggestion_returns_null_when_no_public_citation_found( +def test_job_null_when_no_public_citation_found( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -448,55 +453,78 @@ def test_get_suggestion_returns_null_when_no_public_citation_found( ), ) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") + final = _run(test_client) + assert final["state"] == "succeeded" + result = final["result"] + assert result["value"] is None + assert result["citations"] == [] + assert result["query_succeeded"] is True + - assert response.status_code == 200 - assert response.json()["value"] is None - assert response.json()["citations"] == [] - assert response.json()["query_succeeded"] is True +# --------------------------------------------------------------------------- # +# Job-specific behavior: dedup per (resource_id, field), force, failed-retry +# --------------------------------------------------------------------------- # -def test_get_suggestion_returns_null_when_annotations_absent( +def test_same_resource_field_reuses_existing_job( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, ) -> None: - enable_foundry_mode(monkeypatch) - output_text = ( - "VALUE: Songliao Basin\n" - "RATIONALE: Public sources describing Daqing Oil Field place it in the Songliao Basin." - ) stitch_client = FakeStitchApiClient(detail_view=make_detail_view(basin=None)) - install_fakes( - monkeypatch, - stitch_client=stitch_client, - azure_client=FakeAzureResponsesClient( - output_text=output_text, - response_payload={ - "id": "resp_test", - "model": "test-model", - "output": [ - { - "type": "message", - "content": [ - { - "type": "output_text", - "annotations": [], - "text": output_text, - } - ], - } - ], - }, - ), - ) + install_fakes(monkeypatch, stitch_client=stitch_client) - response = test_client.get("/api/v1/oil-gas-fields/42?field=basin") + first = _start(test_client) + job_id = first.json()["job_id"] + _poll(test_client, job_id) - assert response.status_code == 200 - assert response.json()["value"] is None - assert response.json()["citations"] == [] - assert ( - response.json()["rationale"] - == "Public sources describing Daqing Oil Field place it in the Songliao Basin." - ) - assert response.json()["query_succeeded"] is True + # Same (resource_id, field) → reused (200, same job), even for a new caller. + second = _start(test_client) + assert second.status_code == 200 + assert second.json()["job_id"] == job_id + + +def test_distinct_pairs_get_distinct_jobs( + test_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + stitch_client = FakeStitchApiClient(detail_view=make_detail_view(basin=None)) + install_fakes(monkeypatch, stitch_client=stitch_client) + + a = _start(test_client, field="basin") + b = _start(test_client, field="state_province") + assert a.status_code == 202 and b.status_code == 202 + assert a.json()["job_id"] != b.json()["job_id"] + + +def test_force_starts_a_new_run( + test_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + stitch_client = FakeStitchApiClient(detail_view=make_detail_view(basin=None)) + install_fakes(monkeypatch, stitch_client=stitch_client) + + first = _start(test_client) + _poll(test_client, first.json()["job_id"]) + + forced = _start(test_client, force=True) + assert forced.status_code == 202 + assert forced.json()["job_id"] != first.json()["job_id"] + _poll(test_client, forced.json()["job_id"]) + + +def test_failed_pair_auto_retries( + test_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + stitch_client = FakeStitchApiClient(error=StitchAPIError("boom", status_code=500)) + install_fakes(monkeypatch, stitch_client=stitch_client) + + first = _start(test_client) + first_final = _poll(test_client, first.json()["job_id"]) + assert first_final["state"] == "failed" + + # Failed runs are not reused → the next request retries with a new job. + second = _start(test_client) + assert second.status_code == 202 + assert second.json()["job_id"] != first.json()["job_id"] + _poll(test_client, second.json()["job_id"]) diff --git a/uv.lock b/uv.lock index 9ef0427e..d3176f71 100644 --- a/uv.lock +++ b/uv.lock @@ -1190,8 +1190,10 @@ dependencies = [ { name = "pydantic-settings" }, { name = "stitch-auth" }, { name = "stitch-client" }, + { name = "stitch-jobs" }, { name = "stitch-models" }, { name = "stitch-ogsi" }, + { name = "stitch-service" }, ] [package.dev-dependencies] @@ -1208,8 +1210,10 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, { name = "stitch-client", editable = "packages/stitch-client" }, + { name = "stitch-jobs", editable = "packages/stitch-jobs" }, { name = "stitch-models", editable = "packages/stitch-models" }, { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, + { name = "stitch-service", editable = "packages/stitch-service" }, ] [package.metadata.requires-dev] From fa4c791c9fdac508f3a399ec9248df440aebb364 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 13:05:00 +0200 Subject: [PATCH 11/26] Update LLM frontend to reflect changes in deployment --- .../src/pages/ResourceDetailPage.jsx | 44 ++- .../src/pages/ResourceDetailPage.test.jsx | 324 +++++++++++++----- .../stitch-frontend/src/queries/api.js | 35 +- .../stitch-frontend/src/queries/api.test.js | 90 ++--- 4 files changed, 348 insertions(+), 145 deletions(-) diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx index 6354c680..abba657c 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx @@ -5,9 +5,10 @@ import { useResourceDetail, useSourceDetail } from "../hooks/useResources"; import { createAuthenticatedFetcher } from "../auth/api"; import { useConfig } from "../config/useConfig"; import { - createLLMSuggestion, createMergeCandidate, createResource, + getLLMSuggestionStatus, + startLLMSuggestion, } from "../queries/api"; import SourceMixBar from "../components/SourceMixBar"; import SectionHeader from "../components/SectionHeader"; @@ -23,6 +24,10 @@ import { const LLM_AUDIT_PRODUCER = "stitch-frontend"; +// LLM suggestions run as async jobs; poll their status until terminal. +const SUGGESTION_POLL_INTERVAL_MS = 1000; +const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); + const OBSERVED_AT_FORMATTER = new Intl.DateTimeFormat(undefined, { year: "numeric", month: "short", @@ -182,6 +187,7 @@ function AISuggestionPanel({ endpoint, resourceId }) { const { getAccessTokenSilently } = useAuth0(); const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); const [selectedField, setSelectedField] = useState(AI_SUGGESTION_FIELDS[0]); + const [forceRerun, setForceRerun] = useState(false); const [result, setResult] = useState(null); const [error, setError] = useState(""); const [isLoading, setIsLoading] = useState(false); @@ -201,14 +207,31 @@ function AISuggestionPanel({ endpoint, resourceId }) { setPersistState(null); try { - const suggestion = await createLLMSuggestion( + // Start (or join an existing) suggestion job, then poll until it + // finishes. A repeat for the same (resource, field) returns the prior + // run's result unless "Re-run" is checked. + let record = await startLLMSuggestion( config, - resourceId, - selectedField, + { resourceId, field: selectedField, force: forceRerun }, fetcher, endpoint, ); - setResult(suggestion); + + while (record.state === "running") { + await sleep(SUGGESTION_POLL_INTERVAL_MS); + record = await getLLMSuggestionStatus( + config, + record.job_id, + fetcher, + endpoint, + ); + } + + if (record.state === "failed") { + setError(record.error || "Suggestion job failed."); + } else { + setResult(record.result); + } } catch (err) { setError(err.message || "Failed to generate suggestion."); } finally { @@ -301,6 +324,17 @@ function AISuggestionPanel({ endpoint, resourceId }) { + + {error && (
{error} diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx index 5a0a841f..c6495de1 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx @@ -347,22 +347,26 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [ - { - url: "https://example.com/daqing", - title: "Daqing citation", - }, - ], - query_succeeded: true, - model: "test-model", - rationale: "Public sources place Daqing in the Songliao Basin.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [ + { + url: "https://example.com/daqing", + title: "Daqing citation", + }, + ], + query_succeeded: true, + model: "test-model", + rationale: "Public sources place Daqing in the Songliao Basin.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); const user = userEvent.setup(); @@ -380,22 +384,128 @@ describe("ResourceDetailPage", () => { ).toHaveAttribute("href", "https://example.com/daqing"); }); + it("polls the status endpoint until the job finishes, then renders the result", async () => { + vi.mocked(useResourceDetail).mockReturnValue({ + ...defaultHookReturn, + data: mockDetailView, + }); + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "running", + result: null, + }); + vi.spyOn(apiModule, "getLLMSuggestionStatus").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, + }); + const user = userEvent.setup(); + + renderWithQueryClient(); + await user.click( + screen.getByRole("button", { name: /generate suggestion/i }), + ); + + expect( + await screen.findByText("Songliao", {}, { timeout: 3000 }), + ).toBeInTheDocument(); + expect(apiModule.getLLMSuggestionStatus).toHaveBeenCalled(); + }); + + it("renders the failure when the job fails", async () => { + vi.mocked(useResourceDetail).mockReturnValue({ + ...defaultHookReturn, + data: mockDetailView, + }); + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "failed", + result: null, + error: "field already populated", + }); + const user = userEvent.setup(); + + renderWithQueryClient(); + await user.click( + screen.getByRole("button", { name: /generate suggestion/i }), + ); + + expect( + await screen.findByText("field already populated"), + ).toBeInTheDocument(); + }); + + it("passes force=true when Re-run is checked", async () => { + vi.mocked(useResourceDetail).mockReturnValue({ + ...defaultHookReturn, + data: mockDetailView, + }); + const startSpy = vi + .spyOn(apiModule, "startLLMSuggestion") + .mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, + }); + const user = userEvent.setup(); + + renderWithQueryClient(); + await user.click(screen.getByRole("checkbox", { name: /re-run/i })); + await user.click( + screen.getByRole("button", { name: /generate suggestion/i }), + ); + + await screen.findByText("Songliao"); + expect(startSpy).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ field: "basin", force: true }), + expect.anything(), + expect.anything(), + ); + }); + it("renders a no-answer suggestion state without treating it as an error", async () => { vi.mocked(useResourceDetail).mockReturnValue({ ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: null, - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "I could not find a grounded public source for this field.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: null, + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "I could not find a grounded public source for this field.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); const user = userEvent.setup(); @@ -419,17 +529,21 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); const user = userEvent.setup(); @@ -448,19 +562,23 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [ - { url: "https://example.com/source", title: "Example Source" }, - ], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: { request: true }, - foundry_response: { response: true }, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [ + { url: "https://example.com/source", title: "Example Source" }, + ], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: { request: true }, + foundry_response: { response: true }, + }, }); const createResourceSpy = vi .spyOn(apiModule, "createResource") @@ -545,17 +663,21 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); vi.spyOn(apiModule, "createResource").mockRejectedValue( new Error( @@ -596,17 +718,21 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); vi.spyOn(apiModule, "createResource").mockResolvedValue({ id: 123 }); const createMergeCandidateSpy = vi @@ -636,17 +762,21 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); vi.spyOn(apiModule, "createResource").mockRejectedValue( new Error("create failed"), @@ -671,17 +801,21 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(apiModule, "createLLMSuggestion").mockResolvedValue({ - resource_id: 1, - field: "basin", - value: "Songliao", - citations: [], - query_succeeded: true, - model: "test-model", - rationale: "Supported.", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, + vi.spyOn(apiModule, "startLLMSuggestion").mockResolvedValue({ + job_id: "job-1", + state: "succeeded", + result: { + resource_id: 1, + field: "basin", + value: "Songliao", + citations: [], + query_succeeded: true, + model: "test-model", + rationale: "Supported.", + observed_at: "2026-05-13T12:00:00Z", + foundry_request: {}, + foundry_response: {}, + }, }); vi.spyOn(apiModule, "createResource").mockResolvedValue({ id: 123 }); vi.spyOn(apiModule, "createMergeCandidate").mockResolvedValue({ diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index 63bc3d4d..1cd431d9 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -65,17 +65,21 @@ export async function getResourceDetail( return data; } -export async function createLLMSuggestion( +// LLM suggestions run as async jobs (decoupled from the caller): start one, +// then poll its status until it leaves the "running" state. The job is tracked +// per (resource_id, field) with no expiry — a repeat start returns the existing +// job (200) unless `force` is set. +export async function startLLMSuggestion( config, - id, - field, + { resourceId, field, force = false }, fetcher, endpoint = "resources", ) { - const url = new URL(`${config.stitchLlmBaseUrl}/${endpoint}/${id}`); - url.searchParams.set("field", field); + const url = `${config.stitchLlmBaseUrl}/${endpoint}/start`; const response = await fetcher(url, { - method: "GET", + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ resource_id: resourceId, field, force }), }); if (!response.ok) { @@ -88,6 +92,25 @@ export async function createLLMSuggestion( return await response.json(); } +export async function getLLMSuggestionStatus( + config, + jobId, + fetcher, + endpoint = "resources", +) { + const url = `${config.stitchLlmBaseUrl}/${endpoint}/status/${jobId}`; + const response = await fetcher(url, { method: "GET" }); + + if (!response.ok) { + const detail = await getErrorDetail(response); + const error = new Error(detail); + error.status = response.status; + throw error; + } + + return await response.json(); +} + function formatApiErrorDetail(detail, fallbackStatus) { if (typeof detail === "string" && detail) return detail; if (Array.isArray(detail) || (detail && typeof detail === "object")) { diff --git a/deployments/stitch-frontend/src/queries/api.test.js b/deployments/stitch-frontend/src/queries/api.test.js index 6a9acab0..3e155ec5 100644 --- a/deployments/stitch-frontend/src/queries/api.test.js +++ b/deployments/stitch-frontend/src/queries/api.test.js @@ -1,12 +1,13 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { - createLLMSuggestion, createMergeCandidate, createResource, + getLLMSuggestionStatus, getResourceFilterOptions, getResources, getResource, reviewMergeCandidate, + startLLMSuggestion, } from "./api"; describe("API Functions", () => { @@ -236,70 +237,81 @@ describe("API Functions", () => { }); }); - describe("createLLMSuggestion", () => { - it("calls the stitch-llm GET endpoint with the requested field", async () => { + describe("startLLMSuggestion", () => { + it("POSTs resource_id/field/force to the stitch-llm start endpoint", async () => { mockFetcher.mockResolvedValueOnce({ ok: true, - status: 200, - json: async () => ({ - resource_id: 42, - field: "basin", - value: "Songliao Basin", - citations: [], - query_succeeded: true, - model: "test-model", - observed_at: "2026-05-13T12:00:00Z", - foundry_request: {}, - foundry_response: {}, - }), + status: 202, + json: async () => ({ job_id: "job-1", state: "running", result: null }), }); - const result = await createLLMSuggestion( + const record = await startLLMSuggestion( config, - 42, - "basin", + { resourceId: 42, field: "basin", force: true }, mockFetcher, "oil-gas-fields", ); expect(mockFetcher).toHaveBeenCalledWith( - new URL("http://localhost:8002/api/v1/oil-gas-fields/42?field=basin"), - { method: "GET" }, + "http://localhost:8002/api/v1/oil-gas-fields/start", + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + resource_id: 42, + field: "basin", + force: true, + }), + }, ); - expect(result.value).toBe("Songliao Basin"); + expect(record.job_id).toBe("job-1"); }); it("surfaces structured JSON detail and status on failure", async () => { mockFetcher.mockResolvedValueOnce({ ok: false, - status: 502, - text: async () => - JSON.stringify({ - detail: "LLM upstream returned an invalid response", - }), + status: 403, + text: async () => JSON.stringify({ detail: "missing permission" }), }); await expect( - createLLMSuggestion(config, 42, "basin", mockFetcher, "oil-gas-fields"), + startLLMSuggestion( + config, + { resourceId: 42, field: "basin" }, + mockFetcher, + "oil-gas-fields", + ), ).rejects.toMatchObject({ - message: "LLM upstream returned an invalid response", - status: 502, + message: "missing permission", + status: 403, }); }); + }); - it("falls back to plain-text error bodies and preserves status", async () => { + describe("getLLMSuggestionStatus", () => { + it("GETs the status endpoint for a job id", async () => { mockFetcher.mockResolvedValueOnce({ - ok: false, - status: 503, - text: async () => "Service temporarily unavailable", + ok: true, + status: 200, + json: async () => ({ + job_id: "job-1", + state: "succeeded", + result: { field: "basin", value: "Songliao Basin" }, + }), }); - await expect( - createLLMSuggestion(config, 42, "basin", mockFetcher, "oil-gas-fields"), - ).rejects.toMatchObject({ - message: "Service temporarily unavailable", - status: 503, - }); + const record = await getLLMSuggestionStatus( + config, + "job-1", + mockFetcher, + "oil-gas-fields", + ); + + expect(mockFetcher).toHaveBeenCalledWith( + "http://localhost:8002/api/v1/oil-gas-fields/status/job-1", + { method: "GET" }, + ); + expect(record.result.value).toBe("Songliao Basin"); }); }); From eff24dcb1b15ccda2b9b7bf39b3657f91a5ca4f3 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 14:19:27 +0200 Subject: [PATCH 12/26] Better interface for jobs results in UI --- .../src/components/JobResultList.jsx | 84 +++++++++ .../src/components/JobTriggerButton.jsx | 43 +++++ .../src/components/LastUpdated.jsx | 28 +++ .../stitch-frontend/src/hooks/useJobRunner.js | 132 ++++++++++++++ .../src/pages/ResourceDetailPage.jsx | 161 ++++++++++-------- .../src/pages/ResourceDetailPage.test.jsx | 98 +++++++---- .../stitch-frontend/src/queries/api.js | 47 +---- .../stitch-frontend/src/queries/api.test.js | 80 --------- .../stitch-frontend/src/queries/jobs.js | 60 +++++++ .../stitch-frontend/src/queries/jobs.test.js | 75 ++++++++ 10 files changed, 581 insertions(+), 227 deletions(-) create mode 100644 deployments/stitch-frontend/src/components/JobResultList.jsx create mode 100644 deployments/stitch-frontend/src/components/JobTriggerButton.jsx create mode 100644 deployments/stitch-frontend/src/components/LastUpdated.jsx create mode 100644 deployments/stitch-frontend/src/hooks/useJobRunner.js create mode 100644 deployments/stitch-frontend/src/queries/jobs.js create mode 100644 deployments/stitch-frontend/src/queries/jobs.test.js diff --git a/deployments/stitch-frontend/src/components/JobResultList.jsx b/deployments/stitch-frontend/src/components/JobResultList.jsx new file mode 100644 index 00000000..58e78437 --- /dev/null +++ b/deployments/stitch-frontend/src/components/JobResultList.jsx @@ -0,0 +1,84 @@ +import { useState } from "react"; + +const STATE_STYLES = { + running: "border-warning/30 bg-warning-soft text-warning", + succeeded: "border-success/25 bg-success-soft text-success-strong", + failed: "border-danger/25 bg-danger-soft text-danger", +}; + +const DATE_FORMATTER = new Intl.DateTimeFormat(undefined, { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", +}); + +function StateBadge({ state }) { + if (!state) return null; + const classes = STATE_STYLES[state] ?? "border-line bg-surface text-ink"; + return ( + + {state} + + ); +} + +function formatStartedAt(value) { + if (!value) return "—"; + const date = new Date(value); + return Number.isNaN(date.getTime()) ? "—" : DATE_FORMATTER.format(date); +} + +// Collapsible list of job records, newest first, with the most recent expanded +// by default. Each service supplies `renderResult(record)` for the body. +export default function JobResultList({ records, renderResult }) { + // Per-item user overrides; absent → default (newest open, others collapsed). + // Derived rather than effect-driven so the newest run stays expanded as new + // runs arrive, without a set-state-in-effect. + const [overrides, setOverrides] = useState({}); + const newestId = records[0]?.job_id; + + if (!records.length) return null; + + const isOpen = (id) => overrides[id] ?? id === newestId; + + function toggle(id) { + setOverrides((prev) => ({ ...prev, [id]: !isOpen(id) })); + } + + return ( +
    + {records.map((record) => { + const open = isOpen(record.job_id); + return ( +
  1. + + {open && ( +
    + {renderResult(record)} +
    + )} +
  2. + ); + })} +
+ ); +} diff --git a/deployments/stitch-frontend/src/components/JobTriggerButton.jsx b/deployments/stitch-frontend/src/components/JobTriggerButton.jsx new file mode 100644 index 00000000..3b7ea5a7 --- /dev/null +++ b/deployments/stitch-frontend/src/components/JobTriggerButton.jsx @@ -0,0 +1,43 @@ +import Button from "./Button"; + +// Smart trigger for a job: its label reflects whether a result already exists +// and whether a forced re-run is requested, and it shows a spinner while the +// job is running/polling. +// +// labels: { running, show, create, recreate } +// - running → shown with a spinner while a run is in flight +// - show → a prior result exists and hasn't been revealed yet +// - recreate → force is toggled on (re-run) +// - create → no prior result; first run +export default function JobTriggerButton({ + running, + force, + hasExisting, + revealed, + labels, + onClick, + disabled = false, + variant = "secondary", +}) { + let label; + if (running) label = labels.running; + else if (force) label = labels.recreate; + else if (hasExisting && !revealed) label = labels.show; + else label = labels.create; + + return ( + + ); +} diff --git a/deployments/stitch-frontend/src/components/LastUpdated.jsx b/deployments/stitch-frontend/src/components/LastUpdated.jsx new file mode 100644 index 00000000..7a819131 --- /dev/null +++ b/deployments/stitch-frontend/src/components/LastUpdated.jsx @@ -0,0 +1,28 @@ +import { useEffect, useState } from "react"; + +function relativeLabel(at) { + const seconds = Math.max(0, Math.round((Date.now() - at) / 1000)); + if (seconds < 5) return "just now"; + if (seconds < 60) return `${seconds}s ago`; + const minutes = Math.round(seconds / 60); + if (minutes < 60) return `${minutes}m ago`; + return `${Math.round(minutes / 60)}h ago`; +} + +// Live "Updated N ago" indicator; re-renders on its own so the relative time +// keeps counting up after the last poll. +export default function LastUpdated({ at }) { + const [, tick] = useState(0); + + useEffect(() => { + if (!at) return undefined; + const id = setInterval(() => tick((n) => n + 1), 5000); + return () => clearInterval(id); + }, [at]); + + if (!at) return null; + + return ( + Updated {relativeLabel(at)} + ); +} diff --git a/deployments/stitch-frontend/src/hooks/useJobRunner.js b/deployments/stitch-frontend/src/hooks/useJobRunner.js new file mode 100644 index 00000000..18bd9a07 --- /dev/null +++ b/deployments/stitch-frontend/src/hooks/useJobRunner.js @@ -0,0 +1,132 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { getJobStatus, listJobs, startJob } from "../queries/jobs"; + +const POLL_INTERVAL_MS = 1000; + +function sortNewestFirst(records) { + return [...records].sort( + (a, b) => + new Date(b.started_at).getTime() - new Date(a.started_at).getTime(), + ); +} + +// Drives a Stitch job from the UI: loads prior jobs for the current params on +// mount, starts/auto-polls runs, and tracks the records (newest first). Shared +// by every job-shaped service (LLM, entity-linkage, ETL). +// +// - baseUrl: where the job routes live (POST /start, GET /status, GET /jobs). +// - fetcher: authenticated fetch wrapper (may change each render — captured by ref). +// - paramsKey: stable string for the current params; reloading keys off it. +// - matchesParams: predicate used to filter /jobs down to the current params. +export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { + const [records, setRecords] = useState([]); + const [isStarting, setIsStarting] = useState(false); + const [isPolling, setIsPolling] = useState(false); + const [error, setError] = useState(""); + const [lastUpdatedAt, setLastUpdatedAt] = useState(null); + + // Stable refs so the load effect doesn't churn on every parent re-render. + const fetcherRef = useRef(fetcher); + fetcherRef.current = fetcher; + const matchesRef = useRef(matchesParams); + matchesRef.current = matchesParams; + // Bumped whenever params change / on unmount, to cancel stale polls. + const generationRef = useRef(0); + + const upsert = useCallback((record) => { + setRecords((prev) => + sortNewestFirst([ + ...prev.filter((r) => r.job_id !== record.job_id), + record, + ]), + ); + setLastUpdatedAt(Date.now()); + }, []); + + const poll = useCallback( + async (jobId, generation) => { + setIsPolling(true); + try { + while (generationRef.current === generation) { + const record = await getJobStatus(baseUrl, jobId, fetcherRef.current); + if (generationRef.current !== generation) return; + upsert(record); + if (record.state !== "running") return; + await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS)); + } + } catch (err) { + if (generationRef.current === generation) { + setError(err.message || "Failed to check job status."); + } + } finally { + if (generationRef.current === generation) setIsPolling(false); + } + }, + [baseUrl, upsert], + ); + + // Load existing jobs for the current params on mount / when params change. + useEffect(() => { + generationRef.current += 1; + const generation = generationRef.current; + setRecords([]); + setError(""); + setIsPolling(false); + + (async () => { + try { + const all = await listJobs(baseUrl, fetcherRef.current); + if (generationRef.current !== generation) return; + const mine = sortNewestFirst(all.filter((r) => matchesRef.current(r))); + setRecords(mine); + setLastUpdatedAt(Date.now()); + const running = mine.find((r) => r.state === "running"); + if (running) poll(running.job_id, generation); + } catch { + // No prior jobs (or list unavailable) — start from a clean slate. + if (generationRef.current === generation) setRecords([]); + } + })(); + + return () => { + generationRef.current += 1; // cancel any in-flight poll for this generation + }; + }, [baseUrl, paramsKey, poll]); + + const start = useCallback( + async (body) => { + setIsStarting(true); + setError(""); + const generation = generationRef.current; + try { + const record = await startJob(baseUrl, body, fetcherRef.current); + if (generationRef.current !== generation) return record; + upsert(record); + if (record.state === "running") poll(record.job_id, generation); + return record; + } catch (err) { + setError(err.message || "Failed to start job."); + return null; + } finally { + setIsStarting(false); + } + }, + [baseUrl, poll, upsert], + ); + + const current = records[0] ?? null; + const latestSucceeded = records.find((r) => r.state === "succeeded") ?? null; + + return { + records, + current, + latestSucceeded, + hasExisting: records.length > 0, + isRunning: isStarting || isPolling || current?.state === "running", + isStarting, + isPolling, + error, + lastUpdatedAt, + start, + }; +} diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx index abba657c..00c6cc11 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx @@ -4,12 +4,11 @@ import { useParams, useNavigate } from "react-router-dom"; import { useResourceDetail, useSourceDetail } from "../hooks/useResources"; import { createAuthenticatedFetcher } from "../auth/api"; import { useConfig } from "../config/useConfig"; -import { - createMergeCandidate, - createResource, - getLLMSuggestionStatus, - startLLMSuggestion, -} from "../queries/api"; +import { createMergeCandidate, createResource } from "../queries/api"; +import { useJobRunner } from "../hooks/useJobRunner"; +import JobTriggerButton from "../components/JobTriggerButton"; +import JobResultList from "../components/JobResultList"; +import LastUpdated from "../components/LastUpdated"; import SourceMixBar from "../components/SourceMixBar"; import SectionHeader from "../components/SectionHeader"; import { FieldCard, FieldGrid } from "../components/FieldCard"; @@ -24,10 +23,6 @@ import { const LLM_AUDIT_PRODUCER = "stitch-frontend"; -// LLM suggestions run as async jobs; poll their status until terminal. -const SUGGESTION_POLL_INTERVAL_MS = 1000; -const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); - const OBSERVED_AT_FORMATTER = new Intl.DateTimeFormat(undefined, { year: "numeric", month: "short", @@ -188,62 +183,62 @@ function AISuggestionPanel({ endpoint, resourceId }) { const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); const [selectedField, setSelectedField] = useState(AI_SUGGESTION_FIELDS[0]); const [forceRerun, setForceRerun] = useState(false); - const [result, setResult] = useState(null); - const [error, setError] = useState(""); - const [isLoading, setIsLoading] = useState(false); + const [revealed, setRevealed] = useState(false); const [isPersisting, setIsPersisting] = useState(false); const [persistState, setPersistState] = useState(null); + const [persistError, setPersistError] = useState(""); + + const job = useJobRunner({ + baseUrl: `${config.stitchLlmBaseUrl}/${endpoint}`, + fetcher, + paramsKey: `${resourceId}:${selectedField}`, + matchesParams: (record) => + record.params?.resource_id === resourceId && + record.params?.field === selectedField, + }); + // Persist (and the value/citation rendering) act on the latest succeeded run. + const result = job.latestSucceeded?.result ?? null; const canPersist = result?.value != null; const isPersistedCurrentSuggestion = result && persistState?.status === "success" && persistState.suggestionKey === getSuggestionSubmissionKey(result); + const error = job.error || persistError; - async function handleGenerateSuggestion() { - setIsLoading(true); - setError(""); - setResult(null); + function handleFieldChange(event) { + setSelectedField(event.target.value); + setForceRerun(false); + setRevealed(false); setPersistState(null); + setPersistError(""); + } - try { - // Start (or join an existing) suggestion job, then poll until it - // finishes. A repeat for the same (resource, field) returns the prior - // run's result unless "Re-run" is checked. - let record = await startLLMSuggestion( - config, - { resourceId, field: selectedField, force: forceRerun }, - fetcher, - endpoint, - ); - - while (record.state === "running") { - await sleep(SUGGESTION_POLL_INTERVAL_MS); - record = await getLLMSuggestionStatus( - config, - record.job_id, - fetcher, - endpoint, - ); - } + async function handleTrigger() { + setPersistState(null); + setPersistError(""); - if (record.state === "failed") { - setError(record.error || "Suggestion job failed."); - } else { - setResult(record.result); - } - } catch (err) { - setError(err.message || "Failed to generate suggestion."); - } finally { - setIsLoading(false); + // A prior suggestion exists and we're not forcing a new one → just reveal + // it; no LLM call. + if (job.hasExisting && !forceRerun && !revealed) { + setRevealed(true); + return; } + + setRevealed(true); + await job.start({ + resource_id: resourceId, + field: selectedField, + force: forceRerun, + }); + setForceRerun(false); } async function handlePersistSuggestion() { if (!result || result.value == null) return; setIsPersisting(true); - setError(""); + setPersistError(""); const persistIntentId = createPersistIntentId(); const resourcePayload = buildLLMResourcePayload({ @@ -280,13 +275,13 @@ function AISuggestionPanel({ endpoint, resourceId }) { resourceId: createdResource.id, suggestionKey, }); - setError( + setPersistError( `Suggestion saved as resource ${createdResource.id}, but the merge draft was not created.`, ); } } catch (err) { setPersistState(null); - setError(err.message || "Failed to persist suggestion."); + setPersistError(err.message || "Failed to persist suggestion."); } finally { setIsPersisting(false); } @@ -301,11 +296,7 @@ function AISuggestionPanel({ endpoint, resourceId }) { Field - +
- +
+ + +
{error && (
@@ -341,9 +341,24 @@ function AISuggestionPanel({ endpoint, resourceId }) {
)} - {result && } + {revealed && ( + + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Suggestion job failed."} +

+ ) : ( +

Generating…

+ ) + } + /> + )} - {canPersist && ( + {revealed && canPersist && (
- +
+
+ + +
+
- {error ? ( -
-

Run error

-
- -
-
- ) : null} - -
-
-

Run status

- + {job.error && ( +
+ {job.error}
-
- {!record ? ( -

- No run started yet. Start a run to begin. -

- ) : isRunning ? ( -

- Run in progress — refresh to check for the result. -

- ) : state === "failed" ? ( -
-

Run failed.

- -
- ) : ( - - )} -
-
+ )} + + {revealed && ( +
+

Runs

+ + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Run failed."} +

+ ) : ( +

Running…

+ ) + } + /> +
+ )} ); } diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx index 6059faa1..6ff932df 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx @@ -3,12 +3,12 @@ import { screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { useAuth0 } from "@auth0/auth0-react"; import EntityLinkagePage from "./EntityLinkagePage"; +import * as jobsModule from "../queries/jobs"; import { auth0TestDefaults, renderWithQueryClient } from "../test/utils"; const RUNNING_RECORD = { job_id: "job-123", state: "running", - dedup_key: "LinkageParams:abc", initiated_by: "Test User", params: { apply_merges: false, page: 1, page_size: 50, max_pages: null }, started_at: "2026-01-01T00:00:00Z", @@ -34,112 +34,100 @@ const SUCCEEDED_RECORD = { }, }; -function mockResponse(status, body) { - return { - ok: status >= 200 && status < 300, - status, - text: async () => JSON.stringify(body), - }; -} - describe("EntityLinkagePage", () => { - let getAccessTokenSilently; - beforeEach(() => { - getAccessTokenSilently = vi.fn().mockResolvedValue("test-access-token"); - vi.mocked(useAuth0).mockReturnValue({ - ...auth0TestDefaults, - getAccessTokenSilently, - }); + vi.clearAllMocks(); + vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); + // Default: no prior runs (loaded on mount). + vi.spyOn(jobsModule, "listJobs").mockResolvedValue([]); }); - it("starts a job (202) and authenticates the start request", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue(mockResponse(202, RUNNING_RECORD)); + it("starts a run, auto-polls, and renders the completed result", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(RUNNING_RECORD); + vi.spyOn(jobsModule, "getJobStatus").mockResolvedValue(SUCCEEDED_RECORD); renderWithQueryClient(); await userEvent.click(screen.getByRole("button", { name: "Start run" })); await waitFor(() => { - expect(screen.getByText("running")).toBeInTheDocument(); + expect( + screen.getByRole("heading", { name: "Match groups" }), + ).toBeInTheDocument(); }); - expect( - screen.getByText("Run in progress — refresh to check for the result."), - ).toBeInTheDocument(); + expect(screen.getByText("2 groups")).toBeInTheDocument(); + expect(screen.getByText("Resource 101")).toBeInTheDocument(); + expect(screen.getByText("Resource 205")).toBeInTheDocument(); - const [startUrl, startOptions] = fetchSpy.mock.calls[0]; - expect(startUrl).toMatch(/\/start$/); - expect(startOptions.method).toBe("POST"); - expect(startOptions.headers.Authorization).toBe("Bearer test-access-token"); - expect(getAccessTokenSilently).toHaveBeenCalledWith({ - authorizationParams: { audience: "https://stitch-api.local" }, - }); - expect(getAccessTokenSilently.mock.invocationCallOrder[0]).toBeLessThan( - fetchSpy.mock.invocationCallOrder[0], + // start body carries apply_merges + force; auto-poll happened (no manual refresh). + expect(startSpy).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ apply_merges: false, force: false }), + expect.anything(), ); + expect(jobsModule.getJobStatus).toHaveBeenCalled(); + expect( + screen.queryByRole("button", { name: /refresh status/i }), + ).not.toBeInTheDocument(); }); - it("polls /status/{job_id} on refresh and renders the completed result", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValueOnce(mockResponse(202, RUNNING_RECORD)) - .mockResolvedValueOnce(mockResponse(200, SUCCEEDED_RECORD)); + it("offers 'Show result' for a recent run and reveals it without re-running", async () => { + vi.spyOn(jobsModule, "listJobs").mockResolvedValue([SUCCEEDED_RECORD]); + const startSpy = vi.spyOn(jobsModule, "startJob"); renderWithQueryClient(); - await userEvent.click(screen.getByRole("button", { name: "Start run" })); - await waitFor(() => { - expect(screen.getByText("running")).toBeInTheDocument(); + const showButton = await screen.findByRole("button", { + name: /show result/i, }); + await userEvent.click(showButton); - await userEvent.click( - screen.getByRole("button", { name: "Refresh status" }), - ); + expect( + await screen.findByRole("heading", { name: "Match groups" }), + ).toBeInTheDocument(); + expect(startSpy).not.toHaveBeenCalled(); + }); + + it("forces a re-run when Re-run is checked", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(SUCCEEDED_RECORD); + + renderWithQueryClient(); + + await userEvent.click(screen.getByRole("checkbox", { name: /re-run/i })); + await userEvent.click(screen.getByRole("button", { name: "Re-run" })); await waitFor(() => { expect( screen.getByRole("heading", { name: "Match groups" }), ).toBeInTheDocument(); }); - - expect(screen.getByText("succeeded")).toBeInTheDocument(); - expect(screen.getByText("2 groups")).toBeInTheDocument(); - expect( - screen.getByRole("heading", { name: "Match group 1" }), - ).toBeInTheDocument(); - expect(screen.getByText("Resource 101")).toBeInTheDocument(); - expect(screen.getByText("Resource 205")).toBeInTheDocument(); - - // The status poll hits /status/{job_id} (unauthenticated GET). - const [statusUrl] = fetchSpy.mock.calls[1]; - expect(statusUrl).toMatch(/\/status\/job-123$/); + expect(startSpy).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ apply_merges: false, force: true }), + expect.anything(), + ); }); it("surfaces a failed run", async () => { - vi.spyOn(globalThis, "fetch") - .mockResolvedValueOnce(mockResponse(202, RUNNING_RECORD)) - .mockResolvedValueOnce( - mockResponse(200, { - ...RUNNING_RECORD, - state: "failed", - finished_at: "2026-01-01T00:00:05Z", - error: "GET /oil-gas-fields/ failed with status 500: boom", - }), - ); + vi.spyOn(jobsModule, "startJob").mockResolvedValue({ + ...RUNNING_RECORD, + state: "failed", + finished_at: "2026-01-01T00:00:05Z", + error: "GET /oil-gas-fields/ failed with status 500: boom", + }); renderWithQueryClient(); await userEvent.click(screen.getByRole("button", { name: "Start run" })); - await waitFor(() => screen.getByText("running")); - await userEvent.click( - screen.getByRole("button", { name: "Refresh status" }), - ); await waitFor(() => { - expect(screen.getByText("Run failed.")).toBeInTheDocument(); + expect( + screen.getByText("GET /oil-gas-fields/ failed with status 500: boom"), + ).toBeInTheDocument(); }); - expect(screen.getByText("failed")).toBeInTheDocument(); }); }); From 250e243b730760f6bbbe0bd585af784e2ab6ad72 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 15:23:33 +0200 Subject: [PATCH 14/26] rework ETL ui with job components forward looking to ETL integration of stitch-jobs and -service --- .../stitch-frontend/src/pages/EtlPage.jsx | 224 +++++++----------- .../src/pages/EtlPage.test.jsx | 129 +++++----- 2 files changed, 140 insertions(+), 213 deletions(-) diff --git a/deployments/stitch-frontend/src/pages/EtlPage.jsx b/deployments/stitch-frontend/src/pages/EtlPage.jsx index cbafee6f..647d6034 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.jsx @@ -1,10 +1,18 @@ import { useState } from "react"; import { useAuth0 } from "@auth0/auth0-react"; import { useConfig } from "../config/useConfig"; +import { createAuthenticatedFetcher } from "../auth/api"; +import { useJobRunner } from "../hooks/useJobRunner"; +import JobTriggerButton from "../components/JobTriggerButton"; +import JobResultList from "../components/JobResultList"; +import LastUpdated from "../components/LastUpdated"; import StructuredDataView from "../components/StructuredDataView"; -import Button from "../components/Button"; import Input from "../components/Input"; +// NOTE: the ETL services aren't on the shared `stitch-jobs` framework yet. This +// UI targets that contract (POST /start, GET /status/{job_id}, GET /jobs) so it +// lights up once the backend adopts it. + // Per-ETL run parameters. Empty number/text fields are omitted from the // request body so the service falls back to its env-derived defaults. const GEM_FIELDS = [ @@ -43,37 +51,7 @@ const WOODMAC_FIELDS = [ }, ]; -const STATE_STYLES = { - running: "border-warning/30 bg-warning-soft text-warning", - succeeded: "border-success/25 bg-success-soft text-success-strong", - failed: "border-danger/25 bg-danger-soft text-danger", -}; - -function StateBadge({ state }) { - if (!state) return null; - - const classes = STATE_STYLES[state] ?? "border-line bg-surface text-ink"; - - return ( - - {state} - - ); -} - -async function parseJsonResponse(response) { - const text = await response.text(); - - try { - return text ? JSON.parse(text) : null; - } catch { - return { raw: text }; - } -} - -function EtlPanel({ title, description, baseUrl, fields, getToken }) { +function EtlPanel({ title, description, baseUrl, fields, fetcher }) { const [values, setValues] = useState(() => Object.fromEntries( fields.map((field) => [ @@ -82,17 +60,23 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { ]), ), ); - const [starting, setStarting] = useState(false); - const [refreshing, setRefreshing] = useState(false); - const [record, setRecord] = useState(null); - const [error, setError] = useState(null); + const [forceRerun, setForceRerun] = useState(false); + const [revealed, setRevealed] = useState(false); + + // All runs at this service's base URL belong to this pipeline. + const job = useJobRunner({ + baseUrl, + fetcher, + paramsKey: baseUrl, + matchesParams: () => true, + }); function setField(key, value) { setValues((prev) => ({ ...prev, [key]: value })); } function buildRequestBody() { - const body = {}; + const body = { force: forceRerun }; for (const field of fields) { const value = values[field.key]; @@ -107,81 +91,22 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { return body; } - async function handleStart() { - setStarting(true); - setError(null); - - try { - const token = await getToken(); - - const response = await fetch(`${baseUrl}/start`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(buildRequestBody()), - }); - - const parsed = await parseJsonResponse(response); - - if (response.status === 409) { - setError({ - status: 409, - message: "A run is already in progress — refresh status to check.", - body: parsed, - }); - return; - } - - if (!response.ok) { - setError({ status: response.status, body: parsed }); - return; - } - - setRecord(parsed); - } catch (err) { - setError({ - status: null, - body: err instanceof Error ? err.message : String(err), - }); - } finally { - setStarting(false); + async function handleTrigger() { + // A recent run exists and we're not forcing → just reveal it. + if (job.hasExisting && !forceRerun && !revealed) { + setRevealed(true); + return; } + setRevealed(true); + await job.start(buildRequestBody()); + setForceRerun(false); } - async function handleRefresh() { - setRefreshing(true); - setError(null); - - try { - // GET /status is unauthenticated per the ETL OpenAPI spec. - const response = await fetch(`${baseUrl}/status`); - const parsed = await parseJsonResponse(response); - - if (!response.ok) { - setError({ status: response.status, body: parsed }); - return; - } - - setRecord(parsed); - } catch (err) { - setError({ - status: null, - body: err instanceof Error ? err.message : String(err), - }); - } finally { - setRefreshing(false); - } - } - - const isRunning = record?.state === "running"; - return (

{title}

- +

{description}

@@ -194,6 +119,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { type="checkbox" checked={values[field.key]} onChange={(e) => setField(field.key, e.target.checked)} + disabled={job.isRunning} className="accent-primary" /> {field.label} @@ -209,6 +135,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { onChange={(e) => setField(field.key, e.target.value)} placeholder={field.placeholder} min={field.type === "number" ? 1 : undefined} + disabled={job.isRunning} className="w-full" /> @@ -220,43 +147,62 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { )} -
- - + /> +
- {error ? ( + {job.error ? (
- {error.message ? ( -

{error.message}

- ) : null} - + {job.error}
) : null}
-

Run status

- {record ? ( - +

Runs

+ {revealed && job.records.length ? ( + + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Run failed."} +

+ ) : ( +

Running…

+ ) + } + /> ) : (

- No run started yet. Start a run or refresh to fetch the latest - status. + No run started yet. Start a run to begin.

)}
@@ -267,11 +213,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { export default function EtlPage() { const config = useConfig(); const { getAccessTokenSilently } = useAuth0(); - - const getToken = () => - getAccessTokenSilently({ - authorizationParams: { audience: config.auth0.audience }, - }); + const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); return (
@@ -281,8 +223,8 @@ export default function EtlPage() {

ETL Pipelines

- Start an ETL run and check its status. Only one run per pipeline may - be active at a time. + Start an ETL run and watch its status. A recent run for a pipeline is + shown rather than started again; use “Re-run” to force a fresh run.

@@ -292,14 +234,14 @@ export default function EtlPage() { description="Load GEM oil & gas data from the configured spreadsheet and post it to Stitch." baseUrl={config.etlGemBaseUrl} fields={GEM_FIELDS} - getToken={getToken} + fetcher={fetcher} /> diff --git a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx index 148a3e9c..097e8844 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx @@ -3,24 +3,36 @@ import { screen, waitFor, within } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { useAuth0 } from "@auth0/auth0-react"; import EtlPage from "./EtlPage"; +import * as jobsModule from "../queries/jobs"; import { auth0TestDefaults, renderWithQueryClient } from "../test/utils"; +const GEM_BASE = "http://localhost:8101/api/v1"; + function getPanel(title) { return screen.getByRole("heading", { name: title }).closest("section"); } -describe("EtlPage", () => { - let getAccessTokenSilently; +function succeededRecord(overrides = {}) { + return { + job_id: "job-123", + state: "succeeded", + started_at: "2026-06-11T10:00:00Z", + finished_at: "2026-06-11T10:05:00Z", + params: {}, + result: { payloads_posted: 42 }, + error: null, + ...overrides, + }; +} +describe("EtlPage", () => { beforeEach(() => { - getAccessTokenSilently = vi.fn().mockResolvedValue("test-access-token"); - vi.mocked(useAuth0).mockReturnValue({ - ...auth0TestDefaults, - getAccessTokenSilently, - }); + vi.clearAllMocks(); + vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); + vi.spyOn(jobsModule, "listJobs").mockResolvedValue([]); }); - it("renders a panel for each ETL pipeline", () => { + it("renders a panel for each ETL pipeline with no manual refresh", () => { renderWithQueryClient(); expect(screen.getByRole("heading", { name: "GEM" })).toBeInTheDocument(); @@ -31,22 +43,14 @@ describe("EtlPage", () => { 2, ); expect( - screen.getAllByRole("button", { name: "Refresh status" }), - ).toHaveLength(2); + screen.queryByRole("button", { name: /refresh status/i }), + ).not.toBeInTheDocument(); }); - it("starts a GEM run with an authenticated token and shows the returned state", async () => { - const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 202, - text: async () => - JSON.stringify({ - job_id: "job-123", - state: "running", - started_at: "2026-06-11T10:00:00Z", - initiated_by: "Test User", - }), - }); + it("starts a GEM run and renders the completed result", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(succeededRecord()); renderWithQueryClient(); @@ -56,77 +60,58 @@ describe("EtlPage", () => { ); await waitFor(() => { - expect(within(gemPanel).getAllByText("running").length).toBeGreaterThan( - 0, - ); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); - expect(getAccessTokenSilently).toHaveBeenCalledWith({ - authorizationParams: { audience: "https://stitch-api.local" }, - }); - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:8101/api/v1/start", - expect.objectContaining({ - method: "POST", - headers: expect.objectContaining({ - Authorization: "Bearer test-access-token", - }), - }), + expect(startSpy).toHaveBeenCalledWith( + GEM_BASE, + expect.objectContaining({ force: false }), + expect.anything(), ); }); - it("surfaces a friendly message when a run is already in progress (409)", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: false, - status: 409, - text: async () => JSON.stringify({ detail: "A run is already active" }), - }); + it("offers 'Show result' for a recent run and reveals it without re-running", async () => { + vi.spyOn(jobsModule, "listJobs").mockImplementation(async (baseUrl) => + baseUrl === GEM_BASE ? [succeededRecord()] : [], + ); + const startSpy = vi.spyOn(jobsModule, "startJob"); renderWithQueryClient(); - const woodmacPanel = getPanel("WoodMac"); - await userEvent.click( - within(woodmacPanel).getByRole("button", { name: "Start run" }), - ); + const gemPanel = getPanel("GEM"); + const showButton = await within(gemPanel).findByRole("button", { + name: /show result/i, + }); + await userEvent.click(showButton); await waitFor(() => { - expect( - within(woodmacPanel).getByText( - "A run is already in progress — refresh status to check.", - ), - ).toBeInTheDocument(); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); + expect(startSpy).not.toHaveBeenCalled(); }); - it("refreshes status via an unauthenticated GET", async () => { - const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 200, - text: async () => - JSON.stringify({ - job_id: "job-789", - state: "succeeded", - started_at: "2026-06-11T10:00:00Z", - finished_at: "2026-06-11T10:05:00Z", - result: { payloads_posted: 42 }, - }), - }); + it("forces a re-run when Re-run is checked", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(succeededRecord()); renderWithQueryClient(); - const woodmacPanel = getPanel("WoodMac"); + const gemPanel = getPanel("GEM"); + await userEvent.click( + within(gemPanel).getByRole("checkbox", { name: /re-run/i }), + ); await userEvent.click( - within(woodmacPanel).getByRole("button", { name: "Refresh status" }), + within(gemPanel).getByRole("button", { name: "Re-run" }), ); await waitFor(() => { - expect( - within(woodmacPanel).getAllByText("succeeded").length, - ).toBeGreaterThan(0); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:8102/api/v1/status", + expect(startSpy).toHaveBeenCalledWith( + GEM_BASE, + expect.objectContaining({ force: true }), + expect.anything(), ); }); }); From c9d24c82ff0faf3b65495a5859b2de7e4e811452 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 15:40:35 +0200 Subject: [PATCH 15/26] fix test failing on `--exact` --- .../stitch-llm/tests/test_oil_gas_fields_api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py index c281daed..0df54ee5 100644 --- a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py +++ b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py @@ -14,6 +14,7 @@ from stitch.ogsi.model.og_field import OilGasFieldBase from stitch.service.auth import RequestAuthContext +from stitch.llm import auth as auth_module from stitch.llm import jobs as jobs_module from stitch.llm import main as main_module from stitch.llm.auth import get_request_auth_context, get_token_claims @@ -133,8 +134,11 @@ def reset_job_manager(): @pytest.fixture def test_client(monkeypatch: pytest.MonkeyPatch): # Default: auth-disabled, Azure unconfigured (placeholder mode for the job). + # Patch auth's settings too, so startup auth validation short-circuits + # instead of building OIDCSettings (which has no env config in CI). test_settings = _settings(auth_disabled=True) monkeypatch.setattr(jobs_module, "get_settings", lambda: test_settings) + monkeypatch.setattr(auth_module, "get_settings", lambda: test_settings) monkeypatch.setattr( main_module, "validate_downstream_auth_config_at_startup", lambda: None ) @@ -207,9 +211,9 @@ def _run(client: TestClient, **kwargs) -> dict: def test_start_requires_service_permission( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr( - jobs_module, "get_settings", lambda: _settings(auth_disabled=True) - ) + test_settings = _settings(auth_disabled=True) + monkeypatch.setattr(jobs_module, "get_settings", lambda: test_settings) + monkeypatch.setattr(auth_module, "get_settings", lambda: test_settings) monkeypatch.setattr( main_module, "validate_downstream_auth_config_at_startup", lambda: None ) From 2803a54a22f186fa9e8e410c96023084457959d5 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 17:07:59 +0200 Subject: [PATCH 16/26] safer default --- packages/stitch-jobs/src/stitch/jobs/routers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py index 06752a10..ef15b37e 100644 --- a/packages/stitch-jobs/src/stitch/jobs/routers.py +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -68,7 +68,9 @@ async def start( caller observes that run rather than starting a duplicate). """ params = to_params(request) - force = bool(getattr(request, force_attr)) if force_attr else False + # Default to False so a mis-set force_attr degrades to "no force" + # rather than raising AttributeError (500). + force = bool(getattr(request, force_attr, False)) if force_attr else False record, created = await manager.start( params, initiated_by=initiated_by_label, force=force ) From ba0d8cb247e4c0bc0c74dc9f33e6ae09b3211925 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 18:16:51 +0200 Subject: [PATCH 17/26] document expected behavior in LLM --- deployments/stitch-llm/src/stitch/llm/jobs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deployments/stitch-llm/src/stitch/llm/jobs.py b/deployments/stitch-llm/src/stitch/llm/jobs.py index 4187a20b..3b8f8541 100644 --- a/deployments/stitch-llm/src/stitch/llm/jobs.py +++ b/deployments/stitch-llm/src/stitch/llm/jobs.py @@ -43,6 +43,10 @@ async def run_suggestion(params: FieldSuggestionParams) -> FieldSuggestionRespon async with StitchApiClient() as stitch_client: detail_view = await stitch_client.get_oil_gas_field_detail(resource_id) + # Expected behavior: if the field is already populated this raises and the + # run is recorded as a failed job (surfaced in the UI as a failed run), + # rather than the old synchronous 409. That's intentional — requesting a + # suggestion for an already-filled field is a no-op the user can see. ensure_field_is_missing(detail_view, field) input_messages = build_field_suggestion_input( From 4f58e8dbc1709447a12c3994d72ed80fbee9a0e8 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 18:35:05 +0200 Subject: [PATCH 18/26] remove limit on `/jobs` endpoint client-side filter lists all recent jobs (across cache keys), which means frequently ysed services would silently ignore dedupe. `find` endpoint returns a params payload (same as POST start) and returns jobs with those params --- .../stitch-frontend/src/hooks/useJobRunner.js | 44 ++++++++++++------- .../src/pages/EntityLinkagePage.jsx | 3 +- .../src/pages/EntityLinkagePage.test.jsx | 4 +- .../stitch-frontend/src/pages/EtlPage.jsx | 11 ++--- .../src/pages/EtlPage.test.jsx | 4 +- .../src/pages/ResourceDetailPage.jsx | 5 +-- .../src/pages/ResourceDetailPage.test.jsx | 4 +- .../stitch-frontend/src/queries/jobs.js | 13 ++++++ .../stitch-frontend/src/queries/jobs.test.js | 23 +++++++++- .../stitch-jobs/src/stitch/jobs/manager.py | 16 +++++++ .../stitch-jobs/src/stitch/jobs/routers.py | 9 ++++ packages/stitch-jobs/src/stitch/jobs/store.py | 18 ++++++++ packages/stitch-jobs/tests/test_manager.py | 32 ++++++++++++++ packages/stitch-jobs/tests/test_router.py | 16 +++++++ 14 files changed, 165 insertions(+), 37 deletions(-) diff --git a/deployments/stitch-frontend/src/hooks/useJobRunner.js b/deployments/stitch-frontend/src/hooks/useJobRunner.js index 18bd9a07..87410e7b 100644 --- a/deployments/stitch-frontend/src/hooks/useJobRunner.js +++ b/deployments/stitch-frontend/src/hooks/useJobRunner.js @@ -1,5 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react"; -import { getJobStatus, listJobs, startJob } from "../queries/jobs"; +import { findJobs, getJobStatus, startJob } from "../queries/jobs"; const POLL_INTERVAL_MS = 1000; @@ -10,15 +10,16 @@ function sortNewestFirst(records) { ); } -// Drives a Stitch job from the UI: loads prior jobs for the current params on -// mount, starts/auto-polls runs, and tracks the records (newest first). Shared -// by every job-shaped service (LLM, entity-linkage, ETL). +// Drives a Stitch job from the UI: loads the prior runs for the current params +// on mount, starts/auto-polls runs, and tracks the records (newest first). +// Shared by every job-shaped service (LLM, entity-linkage, ETL). // -// - baseUrl: where the job routes live (POST /start, GET /status, GET /jobs). +// - baseUrl: where the job routes live (POST /start, POST /find, GET /status). // - fetcher: authenticated fetch wrapper (may change each render — captured by ref). -// - paramsKey: stable string for the current params; reloading keys off it. -// - matchesParams: predicate used to filter /jobs down to the current params. -export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { +// - lookupBody: the request params (without `force`) used to look up existing +// runs via /find; the server filters by the same dedup policy as /start, so +// there's no fetch-everything-then-filter and no client/server filter drift. +export function useJobRunner({ baseUrl, fetcher, lookupBody }) { const [records, setRecords] = useState([]); const [isStarting, setIsStarting] = useState(false); const [isPolling, setIsPolling] = useState(false); @@ -28,8 +29,10 @@ export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { // Stable refs so the load effect doesn't churn on every parent re-render. const fetcherRef = useRef(fetcher); fetcherRef.current = fetcher; - const matchesRef = useRef(matchesParams); - matchesRef.current = matchesParams; + const lookupRef = useRef(lookupBody); + lookupRef.current = lookupBody; + // Serialized lookup params double as the effect's reload key. + const lookupKey = JSON.stringify(lookupBody ?? null); // Bumped whenever params change / on unmount, to cancel stale polls. const generationRef = useRef(0); @@ -65,7 +68,7 @@ export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { [baseUrl, upsert], ); - // Load existing jobs for the current params on mount / when params change. + // Load the runs for the current params on mount / when params change. useEffect(() => { generationRef.current += 1; const generation = generationRef.current; @@ -75,15 +78,19 @@ export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { (async () => { try { - const all = await listJobs(baseUrl, fetcherRef.current); + const mine = await findJobs( + baseUrl, + lookupRef.current ?? {}, + fetcherRef.current, + ); if (generationRef.current !== generation) return; - const mine = sortNewestFirst(all.filter((r) => matchesRef.current(r))); - setRecords(mine); + const sorted = sortNewestFirst(mine); + setRecords(sorted); setLastUpdatedAt(Date.now()); - const running = mine.find((r) => r.state === "running"); + const running = sorted.find((r) => r.state === "running"); if (running) poll(running.job_id, generation); } catch { - // No prior jobs (or list unavailable) — start from a clean slate. + // No prior runs (or lookup unavailable) — start from a clean slate. if (generationRef.current === generation) setRecords([]); } })(); @@ -91,7 +98,7 @@ export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { return () => { generationRef.current += 1; // cancel any in-flight poll for this generation }; - }, [baseUrl, paramsKey, poll]); + }, [baseUrl, lookupKey, poll]); const start = useCallback( async (body) => { @@ -114,6 +121,9 @@ export function useJobRunner({ baseUrl, fetcher, paramsKey, matchesParams }) { [baseUrl, poll, upsert], ); + // Known behavior: `current` is the newest run by start time (which drives the + // running/spinner state), while `latestSucceeded` (used for results/persist) + // is the newest succeeded run. With force re-runs these can differ briefly. const current = records[0] ?? null; const latestSucceeded = records.find((r) => r.state === "succeeded") ?? null; diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx index e13bef02..ef4258f9 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx @@ -105,8 +105,7 @@ export default function EntityLinkagePage() { const job = useJobRunner({ baseUrl: config.entityLinkageBaseUrl, fetcher, - paramsKey: `${applyMerges}`, - matchesParams: (record) => record.params?.apply_merges === applyMerges, + lookupBody: { apply_merges: applyMerges }, }); function handleToggleApplyMerges(event) { diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx index 6ff932df..2c96fbae 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx @@ -39,7 +39,7 @@ describe("EntityLinkagePage", () => { vi.clearAllMocks(); vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); // Default: no prior runs (loaded on mount). - vi.spyOn(jobsModule, "listJobs").mockResolvedValue([]); + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([]); }); it("starts a run, auto-polls, and renders the completed result", async () => { @@ -74,7 +74,7 @@ describe("EntityLinkagePage", () => { }); it("offers 'Show result' for a recent run and reveals it without re-running", async () => { - vi.spyOn(jobsModule, "listJobs").mockResolvedValue([SUCCEEDED_RECORD]); + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([SUCCEEDED_RECORD]); const startSpy = vi.spyOn(jobsModule, "startJob"); renderWithQueryClient(); diff --git a/deployments/stitch-frontend/src/pages/EtlPage.jsx b/deployments/stitch-frontend/src/pages/EtlPage.jsx index 647d6034..75af0123 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.jsx @@ -63,13 +63,10 @@ function EtlPanel({ title, description, baseUrl, fields, fetcher }) { const [forceRerun, setForceRerun] = useState(false); const [revealed, setRevealed] = useState(false); - // All runs at this service's base URL belong to this pipeline. - const job = useJobRunner({ - baseUrl, - fetcher, - paramsKey: baseUrl, - matchesParams: () => true, - }); + // Look up this pipeline's runs with default params (a stable key, so editing + // the tunable fields doesn't refetch on every keystroke). The pipeline is its + // own service, so /find returns its runs per the backend's dedup policy. + const job = useJobRunner({ baseUrl, fetcher, lookupBody: {} }); function setField(key, value) { setValues((prev) => ({ ...prev, [key]: value })); diff --git a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx index 097e8844..94f02754 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx @@ -29,7 +29,7 @@ describe("EtlPage", () => { beforeEach(() => { vi.clearAllMocks(); vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); - vi.spyOn(jobsModule, "listJobs").mockResolvedValue([]); + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([]); }); it("renders a panel for each ETL pipeline with no manual refresh", () => { @@ -71,7 +71,7 @@ describe("EtlPage", () => { }); it("offers 'Show result' for a recent run and reveals it without re-running", async () => { - vi.spyOn(jobsModule, "listJobs").mockImplementation(async (baseUrl) => + vi.spyOn(jobsModule, "findJobs").mockImplementation(async (baseUrl) => baseUrl === GEM_BASE ? [succeededRecord()] : [], ); const startSpy = vi.spyOn(jobsModule, "startJob"); diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx index 00c6cc11..cb4f6895 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx @@ -191,10 +191,7 @@ function AISuggestionPanel({ endpoint, resourceId }) { const job = useJobRunner({ baseUrl: `${config.stitchLlmBaseUrl}/${endpoint}`, fetcher, - paramsKey: `${resourceId}:${selectedField}`, - matchesParams: (record) => - record.params?.resource_id === resourceId && - record.params?.field === selectedField, + lookupBody: { resource_id: resourceId, field: selectedField }, }); // Persist (and the value/citation rendering) act on the latest succeeded run. diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx index f0c2020c..9fdba777 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.test.jsx @@ -83,7 +83,7 @@ beforeEach(() => { vi.mocked(useSourceDetail).mockReturnValue(defaultSourceDetailHookReturn); // Default: no prior jobs for the current resource/field (the panel loads // these on mount). Individual tests override to exercise "Show suggestion". - vi.spyOn(jobsModule, "listJobs").mockResolvedValue([]); + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([]); vi.stubGlobal("crypto", { randomUUID: () => "persist-uuid-123", }); @@ -492,7 +492,7 @@ describe("ResourceDetailPage", () => { ...defaultHookReturn, data: mockDetailView, }); - vi.spyOn(jobsModule, "listJobs").mockResolvedValue([ + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([ { job_id: "prior-1", state: "succeeded", diff --git a/deployments/stitch-frontend/src/queries/jobs.js b/deployments/stitch-frontend/src/queries/jobs.js index ade16235..c9fedd98 100644 --- a/deployments/stitch-frontend/src/queries/jobs.js +++ b/deployments/stitch-frontend/src/queries/jobs.js @@ -58,3 +58,16 @@ export async function listJobs(baseUrl, fetcher, { limit = 50 } = {}) { if (!response.ok) throw await errorFromResponse(response); return await response.json(); } + +// Return the runs matching a request's params (server applies the same dedup +// policy as /start), newest first. Lets the UI discover/reuse the existing run +// for exactly these params without fetching-then-filtering the whole job list. +export async function findJobs(baseUrl, body, fetcher) { + const response = await fetcher(`${baseUrl}/find`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + if (!response.ok) throw await errorFromResponse(response); + return await response.json(); +} diff --git a/deployments/stitch-frontend/src/queries/jobs.test.js b/deployments/stitch-frontend/src/queries/jobs.test.js index 6c8b0468..125528ba 100644 --- a/deployments/stitch-frontend/src/queries/jobs.test.js +++ b/deployments/stitch-frontend/src/queries/jobs.test.js @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; -import { getJobStatus, listJobs, startJob } from "./jobs"; +import { findJobs, getJobStatus, listJobs, startJob } from "./jobs"; const BASE = "http://localhost:8002/api/v1/oil-gas-fields"; @@ -72,4 +72,25 @@ describe("job client", () => { }); expect(records).toHaveLength(1); }); + + it("findJobs POSTs the lookup params to /find", async () => { + fetcher.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => [{ job_id: "job-1", state: "succeeded" }], + }); + + const records = await findJobs( + BASE, + { resource_id: 42, field: "basin" }, + fetcher, + ); + + expect(fetcher).toHaveBeenCalledWith(`${BASE}/find`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ resource_id: 42, field: "basin" }), + }); + expect(records).toHaveLength(1); + }); }); diff --git a/packages/stitch-jobs/src/stitch/jobs/manager.py b/packages/stitch-jobs/src/stitch/jobs/manager.py index 85dec8b6..f4c3f766 100644 --- a/packages/stitch-jobs/src/stitch/jobs/manager.py +++ b/packages/stitch-jobs/src/stitch/jobs/manager.py @@ -133,3 +133,19 @@ async def get(self, job_id: str) -> JobRecord[P, R] | None: async def list(self, *, limit: int | None = None) -> list[JobRecord[P, R]]: return await self._store.list(limit=limit) + + async def list_for_params( + self, params: P, *, limit: int | None = None + ) -> list[JobRecord[P, R]]: + """Return runs whose dedup key matches ``params``, newest first. + + Lets a caller discover the runs for a specific request (e.g. a given + resource/field) without scanning the whole job list — the server + applies the same uniqueness policy used for dedup, so there is no + client/server filter drift. Returns ``[]`` when the policy opts the + params out of deduplication (no stable key). + """ + key = self._policy.key(params) + if key is None: + return [] + return await self._store.list_by_key(key, limit=limit) diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py index ef15b37e..9c992591 100644 --- a/packages/stitch-jobs/src/stitch/jobs/routers.py +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -95,4 +95,13 @@ async def jobs( """List recent jobs, newest first — for discovering an in-flight run.""" return await manager.list(limit=limit) + @router.post("/find", response_model=list[record_model]) + async def find(request: start_request_model): + """Return the runs matching a request's params (same dedup policy as + ``/start``), newest first — so a caller can discover/reuse the existing + run for exactly these params without scanning the whole job list. + """ + params = to_params(request) + return await manager.list_for_params(params, limit=default_list_limit) + return router diff --git a/packages/stitch-jobs/src/stitch/jobs/store.py b/packages/stitch-jobs/src/stitch/jobs/store.py index c7c3561f..0cb0c852 100644 --- a/packages/stitch-jobs/src/stitch/jobs/store.py +++ b/packages/stitch-jobs/src/stitch/jobs/store.py @@ -38,6 +38,11 @@ async def find_active_or_recent( async def list(self, *, limit: int | None = None) -> list[JobRecord]: """Return recent records, newest first.""" + async def list_by_key( + self, dedup_key: str, *, limit: int | None = None + ) -> list[JobRecord]: + """Return records with this dedup key, newest first.""" + def clear(self) -> None: """Drop all records (test affordance).""" @@ -129,3 +134,16 @@ async def list(self, *, limit: int | None = None) -> list[JobRecord]: if limit is not None: records = records[:limit] return records + + async def list_by_key( + self, dedup_key: str, *, limit: int | None = None + ) -> list[JobRecord]: + self._evict_expired() + records = sorted( + (r for r in self._records.values() if r.dedup_key == dedup_key), + key=lambda record: record.started_at, + reverse=True, + ) + if limit is not None: + records = records[:limit] + return records diff --git a/packages/stitch-jobs/tests/test_manager.py b/packages/stitch-jobs/tests/test_manager.py index 2d9b62dd..05771268 100644 --- a/packages/stitch-jobs/tests/test_manager.py +++ b/packages/stitch-jobs/tests/test_manager.py @@ -11,6 +11,7 @@ InMemoryJobStore, JobManager, JobState, + NoDedupPolicy, SingletonPolicy, ) @@ -268,6 +269,37 @@ async def run(params: Params) -> Result: assert reused.job_id == first.job_id +@pytest.mark.anyio +async def test_list_for_params_returns_only_matching_key_newest_first() -> None: + async def run(params: Params) -> Result: + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager( + run, policy=FingerprintPolicy(), recent_within=None + ) + a, _ = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, a.job_id) + b, _ = await manager.start(Params(name="b")) + await _wait_until_terminal(manager, b.job_id) + a2, _ = await manager.start(Params(name="a"), force=True) + await _wait_until_terminal(manager, a2.job_id) + + runs = await manager.list_for_params(Params(name="a")) + assert [r.job_id for r in runs] == [a2.job_id, a.job_id] # newest first, no "b" + + +@pytest.mark.anyio +async def test_list_for_params_empty_when_policy_opts_out() -> None: + async def run(params: Params) -> Result: + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager(run, policy=NoDedupPolicy()) + record, _ = await manager.start(Params(name="a")) + await _wait_until_terminal(manager, record.job_id) + + assert await manager.list_for_params(Params(name="a")) == [] + + @pytest.mark.anyio async def test_terminal_records_evicted_after_retention() -> None: now = {"t": datetime(2026, 1, 1, tzinfo=UTC)} diff --git a/packages/stitch-jobs/tests/test_router.py b/packages/stitch-jobs/tests/test_router.py index b86459cf..87126caa 100644 --- a/packages/stitch-jobs/tests/test_router.py +++ b/packages/stitch-jobs/tests/test_router.py @@ -123,6 +123,22 @@ async def run(params: StartRequest) -> Result: assert {job["params"]["name"] for job in listed} == {"a", "b"} +def test_find_returns_runs_matching_params() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=len(params.name)) + + app = build_app(JobManager(run, policy=FingerprintPolicy(), recent_within=None)) + + with TestClient(app) as client: + a = client.post("/api/v1/start", json={"name": "a"}) + _poll(client, a.json()["job_id"]) + b = client.post("/api/v1/start", json={"name": "b"}) + _poll(client, b.json()["job_id"]) + + found = client.post("/api/v1/find", json={"name": "a"}).json() + assert [r["params"]["name"] for r in found] == ["a"] + + def test_dependencies_gate_start() -> None: async def run(params: StartRequest) -> Result: return Result(value=1) From 0112054fd810b13fd1e7330951dd17258f4b2d74 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 18:51:59 +0200 Subject: [PATCH 19/26] extract duplicate `initiated_by` logic --- .../src/stitch/entity_linkage/auth.py | 1 + .../src/stitch/entity_linkage/routers/start.py | 11 +---------- deployments/entity-linkage/tests/test_start.py | 15 --------------- deployments/stitch-llm/src/stitch/llm/auth.py | 1 + .../src/stitch/llm/routers/oil_gas_fields.py | 12 ++---------- .../stitch-service/src/stitch/service/auth.py | 10 ++++++++++ packages/stitch-service/tests/test_auth.py | 18 ++++++++++++++++++ 7 files changed, 33 insertions(+), 35 deletions(-) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py index 45ac3bb9..7efeb2fe 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py @@ -16,6 +16,7 @@ require_permissions = _auth.require_permissions get_current_user = _auth.get_current_user get_request_auth_context = _auth.get_request_auth_context +initiated_by = _auth.initiated_by Claims = _auth.Claims CurrentUser = _auth.CurrentUser diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py index 0b8db3e6..18b50dfc 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py @@ -12,8 +12,7 @@ make_job_router, ) -from stitch.entity_linkage.auth import AuthContext, require_permissions -from stitch.entity_linkage.entities import User +from stitch.entity_linkage.auth import initiated_by, require_permissions from stitch.entity_linkage.linkage import LinkageParams, LinkageResult, run_linkage # Two requests are "the same" run when all tunable params match. Identical @@ -47,14 +46,6 @@ def _to_params(request: StartLinkageRequest) -> LinkageParams: return LinkageParams(**request.model_dump(exclude={"force"})) -def _extract_user_label(user: User) -> str: - return user.name or user.email or user.sub - - -async def initiated_by(auth_context: AuthContext) -> str: - return _extract_user_label(auth_context.user) - - router = make_job_router( _manager, start_request_model=StartLinkageRequest, diff --git a/deployments/entity-linkage/tests/test_start.py b/deployments/entity-linkage/tests/test_start.py index 8dde571b..f725081e 100644 --- a/deployments/entity-linkage/tests/test_start.py +++ b/deployments/entity-linkage/tests/test_start.py @@ -9,7 +9,6 @@ FieldCandidate, FieldDetailCandidate, MatchGroup, - User, ) from stitch.entity_linkage.errors import StitchAPIError from stitch.entity_linkage.linkage import ( @@ -19,7 +18,6 @@ _resolve_match_groups, run_linkage, ) -from stitch.entity_linkage.routers.start import _extract_user_label class FakeStitchApiClient(AbstractAsyncContextManager["FakeStitchApiClient"]): @@ -93,19 +91,6 @@ def test_normalize_country(country: str | None, expected: str | None) -> None: assert _normalize_country(country) == expected -def test_extract_user_label_prefers_name_then_email_then_sub() -> None: - assert ( - _extract_user_label( - User(id=1, sub="sub-1", email="a@example.com", name="Alice") - ) - == "Alice" - ) - assert ( - _extract_user_label(User(id=1, sub="sub-2", email="b@example.com", name="")) - == "b@example.com" - ) - - def test_group_duplicate_names_uses_casefold_and_strips_whitespace() -> None: items = [ FieldCandidate(id=1, name="Alpha", country="US"), diff --git a/deployments/stitch-llm/src/stitch/llm/auth.py b/deployments/stitch-llm/src/stitch/llm/auth.py index c059193e..a4c53ad8 100644 --- a/deployments/stitch-llm/src/stitch/llm/auth.py +++ b/deployments/stitch-llm/src/stitch/llm/auth.py @@ -16,6 +16,7 @@ require_permissions = _auth.require_permissions get_current_user = _auth.get_current_user get_request_auth_context = _auth.get_request_auth_context +initiated_by = _auth.initiated_by Claims = _auth.Claims CurrentUser = _auth.CurrentUser diff --git a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py index bf4cd3bc..f02dff1c 100644 --- a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py +++ b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py @@ -5,8 +5,8 @@ from stitch.auth.permissions import SERVICE_LLM_SUGGEST from stitch.jobs import FingerprintPolicy, InMemoryJobStore, JobManager, make_job_router -from stitch.llm.auth import AuthContext, require_permissions -from stitch.llm.entities import FieldSuggestionResponse, User +from stitch.llm.auth import initiated_by, require_permissions +from stitch.llm.entities import FieldSuggestionResponse from stitch.llm.jobs import ( AllowedSuggestionField, FieldSuggestionParams, @@ -44,14 +44,6 @@ def _to_params(request: StartSuggestionRequest) -> FieldSuggestionParams: return FieldSuggestionParams(resource_id=request.resource_id, field=request.field) -def _extract_user_label(user: User) -> str: - return user.name or user.email or user.sub - - -async def initiated_by(auth_context: AuthContext) -> str: - return _extract_user_label(auth_context.user) - - _job_router = make_job_router( _manager, start_request_model=StartSuggestionRequest, diff --git a/packages/stitch-service/src/stitch/service/auth.py b/packages/stitch-service/src/stitch/service/auth.py index 612aa14f..8161be99 100644 --- a/packages/stitch-service/src/stitch/service/auth.py +++ b/packages/stitch-service/src/stitch/service/auth.py @@ -47,6 +47,11 @@ class ServiceUser(BaseModel): name: str role: str | None = None + @property + def label(self) -> str: + """Human label for attributing actions (e.g. a job's ``initiated_by``).""" + return self.name or self.email or self.sub + @dataclass(frozen=True, slots=True) class RequestAuthContext: @@ -268,10 +273,15 @@ async def get_request_auth_context( AuthContext = Annotated[RequestAuthContext, Depends(get_request_auth_context)] + async def initiated_by(auth_context: AuthContext) -> str: + """Caller label for attributing a job's ``initiated_by``.""" + return auth_context.user.label + self.get_token_claims = get_token_claims self.require_permissions = require_permissions self.get_current_user = get_current_user self.get_request_auth_context = get_request_auth_context + self.initiated_by = initiated_by self.Claims = Claims self.CurrentUser = CurrentUser self.AuthContext = AuthContext diff --git a/packages/stitch-service/tests/test_auth.py b/packages/stitch-service/tests/test_auth.py index 8c5315bb..2e5f1d8d 100644 --- a/packages/stitch-service/tests/test_auth.py +++ b/packages/stitch-service/tests/test_auth.py @@ -5,7 +5,9 @@ from stitch.service.auth import ( AuthMode, + RequestAuthContext, ServiceAuth, + ServiceUser, build_headers_provider, machine_token_headers_provider, relay_token_headers_provider, @@ -13,6 +15,22 @@ from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR +def test_service_user_label_prefers_name_then_email_then_sub() -> None: + assert ServiceUser(sub="s", email="e@example.com", name="Alice").label == "Alice" + assert ServiceUser(sub="s", email="e@example.com", name="").label == "e@example.com" + assert ServiceUser(sub="s", email="", name="").label == "s" + + +@pytest.mark.anyio +async def test_initiated_by_returns_user_label() -> None: + auth = ServiceAuth(is_auth_disabled=lambda: True) + ctx = RequestAuthContext( + user=ServiceUser(sub="s", email="e@example.com", name="Alice"), + bearer_token=None, + ) + assert await auth.initiated_by(ctx) == "Alice" + + # --------------------------------------------------------------------------- # # Downstream auth seam # --------------------------------------------------------------------------- # From 70cc8beb72ff5718027e97494bdf935dac38b6e1 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 18:57:13 +0200 Subject: [PATCH 20/26] Sanitize LLM errors rather than surfacing raw error to user --- deployments/stitch-llm/src/stitch/llm/jobs.py | 39 ++++++++++++++++--- .../tests/test_oil_gas_fields_api.py | 22 ++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/deployments/stitch-llm/src/stitch/llm/jobs.py b/deployments/stitch-llm/src/stitch/llm/jobs.py index 3b8f8541..32624051 100644 --- a/deployments/stitch-llm/src/stitch/llm/jobs.py +++ b/deployments/stitch-llm/src/stitch/llm/jobs.py @@ -1,12 +1,16 @@ from __future__ import annotations +import logging from datetime import UTC, datetime from pydantic import BaseModel +from starlette.status import HTTP_404_NOT_FOUND +from stitch.client import StitchAPIError from stitch.llm.azure_responses import AzureResponsesClient, extract_public_citations from stitch.llm.client import StitchApiClient from stitch.llm.entities import FieldSuggestionResponse +from stitch.llm.errors import AzureResponsesError from stitch.llm.settings import get_settings from stitch.llm.suggestions import ( AllowedSuggestionField, @@ -17,6 +21,8 @@ sanitize_and_validate_suggested_value, ) +logger = logging.getLogger(__name__) + PLACEHOLDER_LLM_VALUE = ":warning: placeholder LLM value" PLACEHOLDER_LLM_MODEL = "placeholder-llm" @@ -40,8 +46,21 @@ async def run_suggestion(params: FieldSuggestionParams) -> FieldSuggestionRespon field = params.field observed_at = datetime.now(UTC) - async with StitchApiClient() as stitch_client: - detail_view = await stitch_client.get_oil_gas_field_detail(resource_id) + try: + async with StitchApiClient() as stitch_client: + detail_view = await stitch_client.get_oil_gas_field_detail(resource_id) + except StitchAPIError as exc: + if exc.status_code == HTTP_404_NOT_FOUND: + raise StitchAPIError( + f"Resource {resource_id} was not found.", status_code=404 + ) from exc + # Don't leak raw downstream response text into the user-facing job + # record; log the detail and surface a generic summary (as the old + # synchronous endpoint did with its 502). + logger.exception("Stitch API request failed for resource %s", resource_id) + raise StitchAPIError( + "Failed to fetch resource detail from Stitch API." + ) from exc # Expected behavior: if the field is already populated this raises and the # run is recorded as a failed job (surfaced in the UI as a failed run), @@ -79,11 +98,19 @@ async def run_suggestion(params: FieldSuggestionParams) -> FieldSuggestionRespon foundry_response={}, ) - async with AzureResponsesClient() as llm_client: - llm_result = await llm_client.generate_field_suggestion( - field=field, - input_messages=input_messages, + try: + async with AzureResponsesClient() as llm_client: + llm_result = await llm_client.generate_field_suggestion( + field=field, + input_messages=input_messages, + ) + except AzureResponsesError as exc: + # Same rationale: keep raw LLM-transport detail out of the user-facing + # record, log the detail, surface a generic summary. + logger.exception( + "LLM request failed for resource %s field %s", resource_id, field ) + raise AzureResponsesError("The language model request failed.") from exc parsed = parse_field_suggestion_response(llm_result.output_text) citations = extract_public_citations(llm_result.response_payload) if parsed.value is None or not citations: diff --git a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py index 0df54ee5..b3053125 100644 --- a/deployments/stitch-llm/tests/test_oil_gas_fields_api.py +++ b/deployments/stitch-llm/tests/test_oil_gas_fields_api.py @@ -347,7 +347,27 @@ def test_job_fails_on_stitch_404( final = _run(test_client) assert final["state"] == "failed" - assert "missing" in final["error"] + assert "not found" in final["error"] + assert "42" in final["error"] + + +def test_job_sanitizes_non_404_downstream_error( + test_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + stitch_client = FakeStitchApiClient( + error=StitchAPIError( + "GET /oil-gas-fields/42/detail failed with status 500: secret-internal-trace", + status_code=500, + ) + ) + install_fakes(monkeypatch, stitch_client=stitch_client) + + final = _run(test_client) + assert final["state"] == "failed" + # The raw downstream text must not leak into the user-facing record. + assert "secret-internal-trace" not in final["error"] + assert "Failed to fetch resource detail" in final["error"] def test_job_fails_on_missing_azure_config( From 7f41c16029360f0481cc2705f6ae05203cac18cc Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 19:03:10 +0200 Subject: [PATCH 21/26] reuse error parsing in frontend --- .../stitch-frontend/src/queries/api.js | 5 +++- .../stitch-frontend/src/queries/jobs.js | 27 ++++--------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index 77473fc9..473792b6 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -76,7 +76,10 @@ function formatApiErrorDetail(detail, fallbackStatus) { return `HTTP error! status: ${fallbackStatus}`; } -async function getErrorDetail(response) { +// Extract a human-readable error detail from a failed response. Canonical +// parser shared with the job client (queries/jobs.js) so every path surfaces +// the same message for a given backend response. +export async function getErrorDetail(response) { const fallback = formatApiErrorDetail(null, response.status); try { diff --git a/deployments/stitch-frontend/src/queries/jobs.js b/deployments/stitch-frontend/src/queries/jobs.js index c9fedd98..95e7c23f 100644 --- a/deployments/stitch-frontend/src/queries/jobs.js +++ b/deployments/stitch-frontend/src/queries/jobs.js @@ -6,29 +6,12 @@ // `${stitchLlmBaseUrl}/oil-gas-fields`, entity-linkage jobs at // `${entityLinkageBaseUrl}`. +import { getErrorDetail } from "./api"; + +// Build an Error from a failed response, reusing the shared detail parser so +// job and CRUD paths surface identical messages. async function errorFromResponse(response) { - let detail = response.statusText || `HTTP error! status: ${response.status}`; - try { - const text = await response.text(); - if (text) { - try { - const body = JSON.parse(text); - const parsed = body?.detail; - if (typeof parsed === "string" && parsed) { - detail = parsed; - } else if (parsed != null) { - detail = JSON.stringify(parsed, null, 2); - } else { - detail = text; - } - } catch { - detail = text; - } - } - } catch { - // fall back to statusText - } - const error = new Error(detail); + const error = new Error(await getErrorDetail(response)); error.status = response.status; return error; } From 3edb8efbe52f06aa1a985c2b1681c514bf60a2a2 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 19:12:41 +0200 Subject: [PATCH 22/26] unify call parameters for jobs --- .../stitch/entity_linkage/routers/start.py | 16 ---- .../src/stitch/llm/routers/oil_gas_fields.py | 24 +----- packages/stitch-jobs/README.md | 4 +- .../stitch-jobs/src/stitch/jobs/routers.py | 79 ++++++++++++------- packages/stitch-jobs/tests/test_router.py | 26 +++++- 5 files changed, 79 insertions(+), 70 deletions(-) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py index 18b50dfc..c8195329 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py @@ -3,7 +3,6 @@ from datetime import timedelta from fastapi import Depends -from pydantic import Field from stitch.auth.permissions import SERVICE_ENTITY_LINKAGE_RUN from stitch.jobs import ( FingerprintPolicy, @@ -34,24 +33,9 @@ def get_job_manager() -> JobManager[LinkageParams, LinkageResult]: return _manager -class StartLinkageRequest(LinkageParams): - force: bool = Field( - default=False, - description="Re-run even if a recent identical run exists.", - ) - - -def _to_params(request: StartLinkageRequest) -> LinkageParams: - # `force` is dropped here so it never participates in the dedup key. - return LinkageParams(**request.model_dump(exclude={"force"})) - - router = make_job_router( _manager, - start_request_model=StartLinkageRequest, params_model=LinkageParams, - to_params=_to_params, - force_attr="force", result_model=LinkageResult, dependencies=[Depends(require_permissions(SERVICE_ENTITY_LINKAGE_RUN))], initiated_by=initiated_by, diff --git a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py index f02dff1c..61ab69c0 100644 --- a/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py +++ b/deployments/stitch-llm/src/stitch/llm/routers/oil_gas_fields.py @@ -1,17 +1,12 @@ from __future__ import annotations from fastapi import APIRouter, Depends -from pydantic import BaseModel, Field from stitch.auth.permissions import SERVICE_LLM_SUGGEST from stitch.jobs import FingerprintPolicy, InMemoryJobStore, JobManager, make_job_router from stitch.llm.auth import initiated_by, require_permissions from stitch.llm.entities import FieldSuggestionResponse -from stitch.llm.jobs import ( - AllowedSuggestionField, - FieldSuggestionParams, - run_suggestion, -) +from stitch.llm.jobs import FieldSuggestionParams, run_suggestion # Suggestions are tracked per (resource_id, field) with no expiry: once a pair # has a result it is reused indefinitely (decoupled from the original caller, so @@ -30,26 +25,9 @@ def get_job_manager() -> JobManager[FieldSuggestionParams, FieldSuggestionRespon return _manager -class StartSuggestionRequest(BaseModel): - resource_id: int - field: AllowedSuggestionField - force: bool = Field( - default=False, - description="Re-run even if a suggestion for this (resource, field) exists.", - ) - - -def _to_params(request: StartSuggestionRequest) -> FieldSuggestionParams: - # `force` is intentionally dropped so it never participates in the dedup key. - return FieldSuggestionParams(resource_id=request.resource_id, field=request.field) - - _job_router = make_job_router( _manager, - start_request_model=StartSuggestionRequest, params_model=FieldSuggestionParams, - to_params=_to_params, - force_attr="force", result_model=FieldSuggestionResponse, dependencies=[Depends(require_permissions(SERVICE_LLM_SUGGEST))], initiated_by=initiated_by, diff --git a/packages/stitch-jobs/README.md b/packages/stitch-jobs/README.md index da67bd6c..26d66511 100644 --- a/packages/stitch-jobs/README.md +++ b/packages/stitch-jobs/README.md @@ -34,11 +34,13 @@ manager = JobManager( ) router = make_job_router( manager, - start_request_model=StartRequest, + params_model=EtlParams, # request body + dedup params result_model=EtlResult, dependencies=[Depends(require_permissions(SOURCE_WRITE))], initiated_by=current_user_label, ) +# /start gains a `force` field automatically (force=True by default); set it to +# bypass dedup. The router strips `force` before computing the dedup key. ``` ## Scope diff --git a/packages/stitch-jobs/src/stitch/jobs/routers.py b/packages/stitch-jobs/src/stitch/jobs/routers.py index 9c992591..357b8e41 100644 --- a/packages/stitch-jobs/src/stitch/jobs/routers.py +++ b/packages/stitch-jobs/src/stitch/jobs/routers.py @@ -7,7 +7,7 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException, Query, Response -from pydantic import BaseModel +from pydantic import BaseModel, Field, create_model from starlette.status import HTTP_200_OK, HTTP_202_ACCEPTED, HTTP_404_NOT_FOUND from .manager import JobManager @@ -19,35 +19,57 @@ def make_job_router( manager: JobManager, *, - start_request_model: type[BaseModel], + params_model: type[BaseModel], result_model: type[BaseModel], - params_model: type[BaseModel] | None = None, - to_params: Callable[[Any], BaseModel] | None = None, + force: bool = True, dependencies: Sequence[Any] = (), initiated_by: Callable[..., Awaitable[str | None] | str | None] | None = None, - force_attr: str | None = None, tags: Sequence[str] | None = None, default_list_limit: int = 20, ) -> APIRouter: - """Build a reusable ``/start`` + ``/status`` + ``/jobs`` router for a job. - - ``start_request_model`` is the POST body; ``result_model`` is what - ``run_fn`` returns. By default the request body *is* the params; pass - ``params_model`` + ``to_params`` when the stored params differ from the - wire request. ``dependencies`` is where the service plugs in its permission - gate (e.g. ``[Depends(require_permissions(...))]``); ``initiated_by`` is an - optional dependency returning the caller's display label. - - ``force_attr`` names a boolean field on the request body that, when true, - bypasses dedup and forces a fresh run. Keep that field out of ``params`` (via - ``to_params``) so it never participates in the dedup key. + """Build a reusable ``/start`` + ``/status`` + ``/jobs`` + ``/find`` router. + + ``params_model`` is the request body *and* the dedup params; ``result_model`` + is what ``run_fn`` returns. ``dependencies`` is where the service plugs in + its permission gate (e.g. ``[Depends(require_permissions(...))]``); + ``initiated_by`` is an optional dependency returning the caller's label. + + When ``force`` is true (default) the request body gains a ``force: bool`` + field; setting it bypasses dedup and starts a fresh run. The router strips + ``force`` before computing the dedup key, so it can never pollute that key — + services get force without re-deriving the wrapper/strip boilerplate. """ - params_model = params_model or start_request_model - to_params = to_params or (lambda request: request) resolve_initiated_by = initiated_by or (lambda: None) - record_model = JobRecord[params_model, result_model] + if force: + # Synthesize " + force" so callers send/declare just the params. + request_model = create_model( + f"{params_model.__name__}StartRequest", + __base__=params_model, + force=( + bool, + Field( + default=False, + description="Re-run even if a matching recent run exists.", + ), + ), + ) + + def to_params(request: BaseModel) -> BaseModel: + return params_model(**request.model_dump(exclude={"force"})) + + def extract_force(request: BaseModel) -> bool: + return bool(getattr(request, "force", False)) + else: + request_model = params_model + + def to_params(request: BaseModel) -> BaseModel: + return request + + def extract_force(request: BaseModel) -> bool: + return False + router = APIRouter(tags=list(tags) if tags else None) @router.post( @@ -57,7 +79,7 @@ def make_job_router( dependencies=list(dependencies), ) async def start( - request: start_request_model, + request: request_model, response: Response, initiated_by_label: Any = Depends(resolve_initiated_by), ): @@ -67,12 +89,10 @@ async def start( when a recent/active run with the same dedup key is found (so a second caller observes that run rather than starting a duplicate). """ - params = to_params(request) - # Default to False so a mis-set force_attr degrades to "no force" - # rather than raising AttributeError (500). - force = bool(getattr(request, force_attr, False)) if force_attr else False record, created = await manager.start( - params, initiated_by=initiated_by_label, force=force + to_params(request), + initiated_by=initiated_by_label, + force=extract_force(request), ) if not created: response.status_code = HTTP_200_OK @@ -96,12 +116,13 @@ async def jobs( return await manager.list(limit=limit) @router.post("/find", response_model=list[record_model]) - async def find(request: start_request_model): + async def find(request: request_model): """Return the runs matching a request's params (same dedup policy as ``/start``), newest first — so a caller can discover/reuse the existing run for exactly these params without scanning the whole job list. """ - params = to_params(request) - return await manager.list_for_params(params, limit=default_list_limit) + return await manager.list_for_params( + to_params(request), limit=default_list_limit + ) return router diff --git a/packages/stitch-jobs/tests/test_router.py b/packages/stitch-jobs/tests/test_router.py index 87126caa..f04fa124 100644 --- a/packages/stitch-jobs/tests/test_router.py +++ b/packages/stitch-jobs/tests/test_router.py @@ -33,7 +33,7 @@ def build_app(manager: JobManager, **kwargs) -> FastAPI: app = FastAPI() router = make_job_router( manager, - start_request_model=StartRequest, + params_model=StartRequest, result_model=Result, **kwargs, ) @@ -123,6 +123,30 @@ async def run(params: StartRequest) -> Result: assert {job["params"]["name"] for job in listed} == {"a", "b"} +def test_synthesized_force_field_bypasses_dedup() -> None: + async def run(params: StartRequest) -> Result: + return Result(value=1) + + # No force_attr wiring — make_job_router adds the `force` field itself. + app = build_app(JobManager(run, policy=FingerprintPolicy(), recent_within=None)) + + with TestClient(app) as client: + first = client.post("/api/v1/start", json={"name": "a"}) + _poll(client, first.json()["job_id"]) + + # Same params, no force → reuses the prior run. + reused = client.post("/api/v1/start", json={"name": "a"}) + assert reused.status_code == 200 + assert reused.json()["job_id"] == first.json()["job_id"] + + # force=true → a fresh run, and `force` never lands in the dedup params. + forced = client.post("/api/v1/start", json={"name": "a", "force": True}) + assert forced.status_code == 202 + assert forced.json()["job_id"] != first.json()["job_id"] + assert forced.json()["params"] == {"name": "a"} + _poll(client, forced.json()["job_id"]) + + def test_find_returns_runs_matching_params() -> None: async def run(params: StartRequest) -> Result: return Result(value=len(params.name)) From 96fec46587daf60ae2f4aa3450754830d75bb5fa Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 23 Jun 2026 19:51:56 +0200 Subject: [PATCH 23/26] Decouple auth and OIDC configs --- .../src/stitch/service/__init__.py | 2 ++ .../stitch-service/src/stitch/service/auth.py | 30 +++++++++++++++---- packages/stitch-service/tests/test_auth.py | 28 +++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/packages/stitch-service/src/stitch/service/__init__.py b/packages/stitch-service/src/stitch/service/__init__.py index 03c7778d..e4812e7c 100644 --- a/packages/stitch-service/src/stitch/service/__init__.py +++ b/packages/stitch-service/src/stitch/service/__init__.py @@ -13,6 +13,7 @@ RequestAuthContext, ServiceAuth, ServiceUser, + TokenValidator, build_headers_provider, machine_token_headers_provider, relay_token_headers_provider, @@ -30,6 +31,7 @@ "RequestAuthContext", "ServiceAuth", "ServiceUser", + "TokenValidator", "build_headers_provider", "create_app", "format_started_at", diff --git a/packages/stitch-service/src/stitch/service/auth.py b/packages/stitch-service/src/stitch/service/auth.py index 8161be99..ba21d768 100644 --- a/packages/stitch-service/src/stitch/service/auth.py +++ b/packages/stitch-service/src/stitch/service/auth.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass from enum import Enum -from typing import Annotated, Literal, NoReturn +from typing import Annotated, Literal, NoReturn, Protocol from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -157,6 +157,18 @@ def _permission_exception_handler(exc: InsufficientPermissionsError) -> NoReturn raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=exc.detail) +class TokenValidator(Protocol): + """Turns a bearer token into claims — the seam ``ServiceAuth`` depends on. + + ``stitch.auth.JWTValidator`` satisfies this. Injecting a custom validator + lets a service run with auth *enabled* without OIDC being configured (tests, + or a non-OIDC verifier), decoupling "is auth on?" from "is OIDC configured?". + """ + + def validate(self, token: str) -> TokenClaims: + """Return the validated claims, or raise ``AuthError`` on failure.""" + + class ServiceAuth: """Inbound auth wiring shared by Stitch services. @@ -168,6 +180,10 @@ class ServiceAuth: Config seams: - ``is_auth_disabled``: callable read per request; when true, all requests resolve to ``dev_claims`` (local-dev bypass). + - ``validator``: a :class:`TokenValidator`. Inject one to run auth-enabled + without OIDC config; when omitted, one is built lazily from + ``oidc_settings_factory`` (the production default). This is what keeps + the OIDC-config seam independent of the dev auth-disabled bypass. - ``user_factory``: maps validated claims to a user (override to hit a DB). - ``oidc_settings_factory`` / ``dev_claims``: rarely overridden. """ @@ -176,6 +192,7 @@ def __init__( self, *, is_auth_disabled: Callable[[], bool], + validator: TokenValidator | None = None, oidc_settings_factory: Callable[[], OIDCSettings] = OIDCSettings, dev_claims: TokenClaims | None = None, user_factory: Callable[[TokenClaims], ServiceUser] = _default_user_from_claims, @@ -185,7 +202,7 @@ def __init__( self._dev_claims = dev_claims if dev_claims is not None else DEFAULT_DEV_CLAIMS self._user_factory = user_factory self._oidc_settings: OIDCSettings | None = None - self._validator: JWTValidator | None = None + self._validator: TokenValidator | None = validator # auto_error=False so a missing header doesn't 403 before our handler # runs (and so AUTH_DISABLED can short-circuit). @@ -291,7 +308,9 @@ def oidc_settings(self) -> OIDCSettings: self._oidc_settings = self._oidc_settings_factory() return self._oidc_settings - def _jwt_validator(self) -> JWTValidator: + def _jwt_validator(self) -> TokenValidator: + # Use the injected validator if provided; otherwise build the default + # OIDC-backed one lazily (this is the only place OIDC settings are read). if self._validator is None: self._validator = JWTValidator(self.oidc_settings()) return self._validator @@ -300,5 +319,6 @@ def validate_auth_config_at_startup(self) -> None: if self._is_auth_disabled(): logger.warning("Auth is disabled — all requests use dev credentials") return - # Fail fast if OIDC config is invalid. - self.oidc_settings() + # Fail fast if the validator can't be built (e.g. OIDC misconfigured). + # An injected validator skips OIDC entirely. + self._jwt_validator() diff --git a/packages/stitch-service/tests/test_auth.py b/packages/stitch-service/tests/test_auth.py index 2e5f1d8d..7243004b 100644 --- a/packages/stitch-service/tests/test_auth.py +++ b/packages/stitch-service/tests/test_auth.py @@ -137,3 +137,31 @@ def claims() -> TokenClaims: assert response.status_code == 200 assert response.json()["bearer_token"] == "caller-jwt" + + +class _StubValidator: + """A TokenValidator that accepts any token as a fixed user (no OIDC).""" + + def validate(self, token: str) -> TokenClaims: + return TokenClaims( + sub=f"stub|{token}", name="Stub", permissions=frozenset({SOURCE_WRITE}) + ) + + +def test_injected_validator_runs_auth_enabled_without_oidc() -> None: + # No OIDC env, no auth-disabled bypass — an injected validator is the only + # thing needed, proving OIDC config is decoupled from the dev bypass. If the + # two were still coupled, building the default OIDC validator here would fail. + auth = ServiceAuth(is_auth_disabled=lambda: False, validator=_StubValidator()) + # Should not raise (no OIDCSettings construction). + auth.validate_auth_config_at_startup() + app = build_app(auth) + + with TestClient(app) as client: + me = client.get("/me", headers={"Authorization": "Bearer abc"}) + assert me.status_code == 200 + assert me.json()["sub"] == "stub|abc" + assert ( + client.post("/guarded", headers={"Authorization": "Bearer abc"}).status_code + == 200 + ) From 21cae577593f1ebe058e5a7d890e2a0be6f97fb6 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 24 Jun 2026 15:49:04 +0200 Subject: [PATCH 24/26] Ensure AsyncClient closes --- .../tests/test_downstream_client.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/packages/stitch-service/tests/test_downstream_client.py b/packages/stitch-service/tests/test_downstream_client.py index d39ca717..2e0550af 100644 --- a/packages/stitch-service/tests/test_downstream_client.py +++ b/packages/stitch-service/tests/test_downstream_client.py @@ -1,7 +1,8 @@ """Integration-level checks that the downstream auth modes actually attach the expected Authorization header on outgoing requests via AsyncStitchClient.""" -from collections.abc import Callable, Mapping +from collections.abc import AsyncIterator, Callable, Mapping +from contextlib import asynccontextmanager import httpx import pytest @@ -11,18 +12,24 @@ from stitch.service.auth import AuthMode, build_headers_provider -def _capturing_client( +@asynccontextmanager +async def _capturing_client( seen: dict, headers_provider: Callable[[], Mapping[str, str]] -) -> AsyncStitchClient: +) -> AsyncIterator[AsyncStitchClient]: def handler(request: httpx.Request) -> httpx.Response: seen["authorization"] = request.headers.get("Authorization") return httpx.Response(200, json={}) + # AsyncStitchClient does not own (or close) an injected client, so close the + # raw transport ourselves to avoid leaking it. raw = httpx.AsyncClient( transport=httpx.MockTransport(handler), base_url="http://downstream.test/api/v1", ) - return AsyncStitchClient(client=raw, headers_provider=headers_provider) + try: + yield AsyncStitchClient(client=raw, headers_provider=headers_provider) + finally: + await raw.aclose() @pytest.mark.anyio From 33abede96af28211114e2341f202704ea87e503b Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 24 Jun 2026 17:51:48 +0200 Subject: [PATCH 25/26] Extract OTel from API into new package, and call in new pkgs --- Makefile | 14 +- deployments/api/pyproject.toml | 2 + .../src/stitch/api/observability/tracing.py | 159 ++++------------ .../api/tests/observability/test_tracing.py | 3 +- deployments/entity-linkage/conftest.py | 6 + deployments/entity-linkage/pyproject.toml | 2 + .../src/stitch/entity_linkage/main.py | 2 + .../src/stitch/entity_linkage/settings.py | 5 +- deployments/stitch-llm/conftest.py | 6 + deployments/stitch-llm/pyproject.toml | 2 + deployments/stitch-llm/src/stitch/llm/main.py | 2 + .../stitch-llm/src/stitch/llm/settings.py | 5 +- packages/stitch-jobs/pyproject.toml | 2 + .../stitch-jobs/src/stitch/jobs/manager.py | 57 ++++-- packages/stitch-jobs/tests/test_tracing.py | 87 +++++++++ packages/stitch-observability/README.md | 36 ++++ packages/stitch-observability/pyproject.toml | 31 ++++ .../src/stitch/observability/__init__.py | 29 +++ .../src/stitch/observability/settings.py | 23 +++ .../src/stitch/observability/tracing.py | 170 ++++++++++++++++++ .../tests/test_tracing.py | 64 +++++++ packages/stitch-service/pyproject.toml | 2 + .../stitch-service/src/stitch/service/app.py | 40 ++++- pyproject.toml | 1 + uv.lock | 60 +++++++ 25 files changed, 658 insertions(+), 152 deletions(-) create mode 100644 deployments/entity-linkage/conftest.py create mode 100644 deployments/stitch-llm/conftest.py create mode 100644 packages/stitch-jobs/tests/test_tracing.py create mode 100644 packages/stitch-observability/README.md create mode 100644 packages/stitch-observability/pyproject.toml create mode 100644 packages/stitch-observability/src/stitch/observability/__init__.py create mode 100644 packages/stitch-observability/src/stitch/observability/settings.py create mode 100644 packages/stitch-observability/src/stitch/observability/tracing.py create mode 100644 packages/stitch-observability/tests/test_tracing.py diff --git a/Makefile b/Makefile index 7a6d99f3..bcc784ac 100644 --- a/Makefile +++ b/Makefile @@ -121,9 +121,16 @@ pkg-test-jobs: pkg-test-exact-jobs: $(MAKE) uv-test-target-exact PKG=stitch-jobs TEST_PATH=packages/stitch-jobs -pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi pkg-build-service pkg-build-jobs -pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi pkg-test-service pkg-test-jobs -pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi pkg-test-exact-service pkg-test-exact-jobs +pkg-build-observability: + $(UV) build --package stitch-observability +pkg-test-observability: + $(MAKE) uv-test-target PKG=stitch-observability TEST_PATH=packages/stitch-observability +pkg-test-exact-observability: + $(MAKE) uv-test-target-exact PKG=stitch-observability TEST_PATH=packages/stitch-observability + +pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi pkg-build-service pkg-build-jobs pkg-build-observability +pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi pkg-test-service pkg-test-jobs pkg-test-observability +pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi pkg-test-exact-service pkg-test-exact-jobs pkg-test-exact-observability # --------------------------------------------------------------------- # Deployments @@ -307,6 +314,7 @@ follow-stack-logs: pkg-build-ogsi pkg-test-ogsi pkg-test-exact-ogsi \ pkg-build-service pkg-test-service pkg-test-exact-service \ pkg-build-jobs pkg-test-jobs pkg-test-exact-jobs \ + pkg-build-observability pkg-test-observability pkg-test-exact-observability \ \ # API api-build api-test api-test-exact api-dev stack-api-dev \ diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 8fe7745a..da164baa 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "sqlalchemy>=2.0.44", "stitch-auth", "stitch-models", + "stitch-observability", "stitch-ogsi", ] @@ -47,4 +48,5 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-models = { workspace = true } +stitch-observability = { workspace = true } stitch-ogsi = { workspace = true } diff --git a/deployments/api/src/stitch/api/observability/tracing.py b/deployments/api/src/stitch/api/observability/tracing.py index 2d6fc5bd..a49d1460 100644 --- a/deployments/api/src/stitch/api/observability/tracing.py +++ b/deployments/api/src/stitch/api/observability/tracing.py @@ -1,139 +1,46 @@ -"""OpenTelemetry tracing setup for the API. - -Span *generation* is handled by auto-instrumentation (FastAPI + SQLAlchemy); -this module owns span *export*, which is configurable: - -* ``console`` (default) — finished spans are emitted as structured log records - through the existing :class:`JsonFormatter` (see :mod:`logging_config`), so - local dev gets full trace data on stdout **without** running the collector / - Jaeger sidecars. This is the "log what OTel would send" path. -* ``otlp`` — spans are shipped via OTLP/gRPC to the collector (``→`` Jaeger). -* ``none`` — tracing is disabled entirely. - -Sampling uses ``ParentBased(root=TraceIdRatioBased(ratio))`` so the API honors -an upstream caller's sampling decision (propagated via the W3C ``traceparent`` -header) and only samples independently when it is the root of a trace. The -ratio defaults to 1.0 (capture everything) for local dev. +"""OpenTelemetry tracing for the API — a thin wrapper over the shared +``stitch.observability`` package (one source of truth across services). + +Keeps this module's historical surface (``SERVICE_NAME``, +``configure_tracing(settings)``, ``instrument_fastapi``, ``instrument_sqlalchemy``, +``LoggingSpanExporter``) so call sites (``main.py``, ``db/config.py``) and tests +don't change. The API's query-timing / request-logging / sinks layer stays +API-specific (it hangs off the SQLAlchemy engine). """ -import logging from typing import TYPE_CHECKING -from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter -from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import ( - BatchSpanProcessor, - SimpleSpanProcessor, - SpanExporter, - SpanExportResult, +from stitch.observability import ( + LoggingSpanExporter, + configure_tracing as _configure_tracing, + instrument_fastapi, + instrument_sqlalchemy, ) -from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased if TYPE_CHECKING: - from collections.abc import Sequence - - from fastapi import FastAPI - from opentelemetry.sdk.trace import ReadableSpan - from sqlalchemy.engine import Engine + from opentelemetry.sdk.trace import TracerProvider from ..settings import Settings SERVICE_NAME = "stitch-api" -_span_logger = logging.getLogger("stitch.api.observability.trace") - - -class LoggingSpanExporter(SpanExporter): - """Export finished spans as structured log records instead of shipping them - to a collector. - - Each span becomes one ``stitch.api.observability.trace`` log record whose - ``event`` dict the :class:`JsonFormatter` flattens to the top level, so - fields like ``trace_id`` / ``duration_ms`` are directly queryable and sit - alongside the request / query events on the same stdout stream. - """ - - def export(self, spans: "Sequence[ReadableSpan]") -> SpanExportResult: - for span in spans: - ctx = span.get_span_context() - parent = span.parent - duration_ms = ( - round((span.end_time - span.start_time) / 1e6, 2) - if span.end_time is not None and span.start_time is not None - else None - ) - _span_logger.info( - "span", - extra={ - "event": { - "span_name": span.name, - "trace_id": format(ctx.trace_id, "032x"), - "span_id": format(ctx.span_id, "016x"), - "parent_span_id": format(parent.span_id, "016x") - if parent is not None - else None, - "kind": span.kind.name, - "duration_ms": duration_ms, - "status": span.status.status_code.name, - "attributes": dict(span.attributes or {}), - } - }, - ) - return SpanExportResult.SUCCESS - - def force_flush(self, timeout_millis: int = 30_000) -> bool: - return True - - -def configure_tracing(settings: "Settings") -> TracerProvider | None: - """Install the global tracer provider, or return ``None`` if disabled. - - Call once at startup, before the first span is created. Idempotency is not - guaranteed — ``set_tracer_provider`` warns if called twice. - """ - if not settings.otel_enabled or settings.otel_traces_exporter == "none": - return None - - resource = Resource.create( - { - "service.name": SERVICE_NAME, - "service.version": settings.app_version or "unknown", - "deployment.environment": settings.environment_name, - } +__all__ = [ + "SERVICE_NAME", + "LoggingSpanExporter", + "configure_tracing", + "instrument_fastapi", + "instrument_sqlalchemy", +] + + +def configure_tracing(settings: "Settings") -> "TracerProvider | None": + """Install the API's global tracer provider, or ``None`` if disabled.""" + return _configure_tracing( + service_name=SERVICE_NAME, + enabled=settings.otel_enabled, + exporter=settings.otel_traces_exporter, + otlp_endpoint=settings.otel_exporter_otlp_endpoint, + sample_ratio=settings.otel_sample_ratio, + version=settings.app_version or "unknown", + environment=settings.environment_name, ) - sampler = ParentBased(root=TraceIdRatioBased(settings.otel_sample_ratio)) - provider = TracerProvider(resource=resource, sampler=sampler) - - if settings.otel_traces_exporter == "otlp": - # endpoint=None lets the exporter fall back to OTEL_EXPORTER_OTLP_ENDPOINT - # / the localhost default. - exporter = OTLPSpanExporter(endpoint=settings.otel_exporter_otlp_endpoint) - provider.add_span_processor(BatchSpanProcessor(exporter)) - else: # "console" — log spans to stdout, no sidecar required. - provider.add_span_processor(SimpleSpanProcessor(LoggingSpanExporter())) - - trace.set_tracer_provider(provider) - return provider - - -def instrument_fastapi(app: "FastAPI") -> None: - """Auto-instrument the FastAPI app (server spans + traceparent extraction). - - URL query strings are intentionally left intact — they're the diagnostic - payload for the performance work this serves. When a retained backend makes - aggregate PII a concern (cloud), scrub them at the collector's egress - (an ``attributes``/``redaction`` processor) rather than blinding local dev. - """ - FastAPIInstrumentor.instrument_app(app) - - -def instrument_sqlalchemy(engine: "Engine") -> None: - """Auto-instrument a (sync) SQLAlchemy engine for per-query spans. - - Pass ``async_engine.sync_engine`` for an ``AsyncEngine``. - """ - SQLAlchemyInstrumentor().instrument(engine=engine) diff --git a/deployments/api/tests/observability/test_tracing.py b/deployments/api/tests/observability/test_tracing.py index 27f59131..487b7749 100644 --- a/deployments/api/tests/observability/test_tracing.py +++ b/deployments/api/tests/observability/test_tracing.py @@ -15,7 +15,8 @@ from stitch.api.observability.tracing import LoggingSpanExporter, configure_tracing from stitch.api.settings import Settings -_TRACE_LOGGER = "stitch.api.observability.trace" +# Span log records now come from the shared stitch-observability exporter. +_TRACE_LOGGER = "stitch.observability.trace" @pytest.fixture diff --git a/deployments/entity-linkage/conftest.py b/deployments/entity-linkage/conftest.py new file mode 100644 index 00000000..343187ee --- /dev/null +++ b/deployments/entity-linkage/conftest.py @@ -0,0 +1,6 @@ +import os + +# Disable tracing for the suite before the app module imports and runs +# configure_tracing (mirrors the API's rootdir conftest). An env var set here +# wins over the .env file's value via pydantic-settings precedence. +os.environ.setdefault("OTEL_TRACES_EXPORTER", "none") diff --git a/deployments/entity-linkage/pyproject.toml b/deployments/entity-linkage/pyproject.toml index a043f2e3..b087cc81 100644 --- a/deployments/entity-linkage/pyproject.toml +++ b/deployments/entity-linkage/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "stitch-client", "stitch-jobs", "stitch-models", + "stitch-observability", "stitch-ogsi", "stitch-service", ] @@ -45,5 +46,6 @@ stitch-auth = { workspace = true } stitch-client = { workspace = true } stitch-jobs = { workspace = true } stitch-models = { workspace = true } +stitch-observability = { workspace = true } stitch-ogsi = { workspace = true } stitch-service = { workspace = true } diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/main.py b/deployments/entity-linkage/src/stitch/entity_linkage/main.py index a4f6c9dd..eaa96b11 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/main.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/main.py @@ -24,4 +24,6 @@ def _run_startup(app: FastAPI) -> None: routers=[health_router, start_router], cors_origins=[str(settings.frontend_origin_url)], on_startup=_run_startup, + service_name="stitch-entity-linkage", + otel=settings, ) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/settings.py b/deployments/entity-linkage/src/stitch/entity_linkage/settings.py index 34dfe81a..b800c9b8 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/settings.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/settings.py @@ -2,10 +2,11 @@ from typing import ClassVar from pydantic import AnyHttpUrl, Field -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import SettingsConfigDict +from stitch.observability import OTelSettings -class Settings(BaseSettings): +class Settings(OTelSettings): log_level: str = Field(default="INFO", alias="ENTITY_LINKAGE_LOG_LEVEL") frontend_origin_url: AnyHttpUrl = Field( default="http://localhost:3000", diff --git a/deployments/stitch-llm/conftest.py b/deployments/stitch-llm/conftest.py new file mode 100644 index 00000000..343187ee --- /dev/null +++ b/deployments/stitch-llm/conftest.py @@ -0,0 +1,6 @@ +import os + +# Disable tracing for the suite before the app module imports and runs +# configure_tracing (mirrors the API's rootdir conftest). An env var set here +# wins over the .env file's value via pydantic-settings precedence. +os.environ.setdefault("OTEL_TRACES_EXPORTER", "none") diff --git a/deployments/stitch-llm/pyproject.toml b/deployments/stitch-llm/pyproject.toml index bf76d1d5..3ad32e4d 100644 --- a/deployments/stitch-llm/pyproject.toml +++ b/deployments/stitch-llm/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "stitch-client", "stitch-jobs", "stitch-models", + "stitch-observability", "stitch-ogsi", "stitch-service", ] @@ -46,5 +47,6 @@ stitch-auth = { workspace = true } stitch-client = { workspace = true } stitch-jobs = { workspace = true } stitch-models = { workspace = true } +stitch-observability = { workspace = true } stitch-ogsi = { workspace = true } stitch-service = { workspace = true } diff --git a/deployments/stitch-llm/src/stitch/llm/main.py b/deployments/stitch-llm/src/stitch/llm/main.py index 1c325f7e..1a6ce516 100644 --- a/deployments/stitch-llm/src/stitch/llm/main.py +++ b/deployments/stitch-llm/src/stitch/llm/main.py @@ -23,4 +23,6 @@ def _run_startup(app: FastAPI) -> None: routers=[health_router, oil_gas_fields_router], cors_origins=[str(settings.frontend_origin_url)], on_startup=_run_startup, + service_name="stitch-llm", + otel=settings, ) diff --git a/deployments/stitch-llm/src/stitch/llm/settings.py b/deployments/stitch-llm/src/stitch/llm/settings.py index 392a4b34..2fd3c829 100644 --- a/deployments/stitch-llm/src/stitch/llm/settings.py +++ b/deployments/stitch-llm/src/stitch/llm/settings.py @@ -8,10 +8,11 @@ SecretStr, field_validator, ) -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import SettingsConfigDict +from stitch.observability import OTelSettings -class Settings(BaseSettings): +class Settings(OTelSettings): log_level: str = Field( default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "STITCH_LLM_LOG_LEVEL"), diff --git a/packages/stitch-jobs/pyproject.toml b/packages/stitch-jobs/pyproject.toml index 5434f71b..38e808c7 100644 --- a/packages/stitch-jobs/pyproject.toml +++ b/packages/stitch-jobs/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "fastapi[standard-no-fastapi-cloud-cli]>=0.135.1", + "opentelemetry-api>=1.30.0", "pydantic>=2.12.5", ] @@ -28,4 +29,5 @@ dev = [ "pytest>=9.0.2", "pytest-anyio>=0.0.0", "httpx>=0.28.0", + "opentelemetry-sdk>=1.30.0", ] diff --git a/packages/stitch-jobs/src/stitch/jobs/manager.py b/packages/stitch-jobs/src/stitch/jobs/manager.py index f4c3f766..0a41d681 100644 --- a/packages/stitch-jobs/src/stitch/jobs/manager.py +++ b/packages/stitch-jobs/src/stitch/jobs/manager.py @@ -7,12 +7,20 @@ from typing import Generic from uuid import uuid4 +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode + from .models import P, R, JobRecord, JobState from .store import InMemoryJobStore, JobStore from .uniqueness import SingletonPolicy, UniquenessPolicy logger = logging.getLogger("stitch.jobs") +# No-op when no provider is configured (tracing disabled), so jobs behave +# identically whether or not the host service has tracing on. +_tracer = trace.get_tracer("stitch.jobs") + #: Terminal states that, by default, an existing run may be reused from. _DEFAULT_REUSABLE_TERMINAL = frozenset({JobState.succeeded, JobState.failed}) @@ -97,23 +105,46 @@ async def start( started_at=self._clock(), ) await self._store.create(record) - task = asyncio.create_task(self._run(record, params)) + # Capture the triggering request's span so the (detached) job run can + # link back to it without nesting under an already-finished request. + trigger = trace.get_current_span().get_span_context() + task = asyncio.create_task(self._run(record, params, trigger)) self._tasks.add(task) task.add_done_callback(self._tasks.discard) return record, True - async def _run(self, record: JobRecord[P, R], params: P) -> None: - try: - record.result = await self._run_fn(params) - record.state = JobState.succeeded - except Exception as exc: - # Broad on purpose: any run_fn failure is captured onto the record - # (state=failed, error set) rather than crashing the background task. - logger.exception("job %s failed", record.job_id) - record.error = str(exc) - record.state = JobState.failed - finally: - record.finished_at = self._clock() + async def _run( + self, record: JobRecord[P, R], params: P, trigger: SpanContext | None = None + ) -> None: + links = [Link(trigger)] if trigger is not None and trigger.is_valid else None + # New root span (empty parent context) so a reused/decoupled job isn't + # buried under one caller's request; the link makes it navigable from the + # trigger. No-op span when tracing is disabled. + with _tracer.start_as_current_span( + "job.run", + context=otel_context.Context(), + kind=SpanKind.INTERNAL, + links=links, + ) as span: + span.set_attribute("stitch.job.id", record.job_id) + if record.dedup_key is not None: + span.set_attribute("stitch.job.dedup_key", record.dedup_key) + if record.initiated_by is not None: + span.set_attribute("stitch.job.initiated_by", record.initiated_by) + try: + record.result = await self._run_fn(params) + record.state = JobState.succeeded + except Exception as exc: + # Broad on purpose: any run_fn failure is captured onto the record + # (state=failed, error set) rather than crashing the background task. + logger.exception("job %s failed", record.job_id) + record.error = str(exc) + record.state = JobState.failed + span.set_status(Status(StatusCode.ERROR, str(exc))) + span.record_exception(exc) + finally: + record.finished_at = self._clock() + span.set_attribute("stitch.job.state", record.state.value) def reset(self) -> None: """Cancel in-flight tasks and drop all run state. diff --git a/packages/stitch-jobs/tests/test_tracing.py b/packages/stitch-jobs/tests/test_tracing.py new file mode 100644 index 00000000..3929b8b6 --- /dev/null +++ b/packages/stitch-jobs/tests/test_tracing.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import asyncio + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +from pydantic import BaseModel + +from stitch.jobs import JobManager, SingletonPolicy + + +class Params(BaseModel): + name: str + + +class Result(BaseModel): + value: int + + +# Sets the process-global provider once (the manager's module-level tracer is a +# proxy that resolves to it). This file sorts last in the package, so the +# earlier suites run with the default no-op tracer, unaffected. +@pytest.fixture(scope="module") +def span_exporter() -> InMemorySpanExporter: + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + return exporter + + +@pytest.fixture +def spans(span_exporter: InMemorySpanExporter) -> InMemorySpanExporter: + span_exporter.clear() + return span_exporter + + +async def _wait_terminal(manager: JobManager, job_id: str, *, timeout=2.0): + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + record = await manager.get(job_id) + if record is not None and record.is_terminal: + return record + await asyncio.sleep(0.005) + raise AssertionError("job did not finish in time") + + +@pytest.mark.anyio +async def test_job_run_emits_root_span_linked_to_trigger(spans) -> None: + async def run(params: Params) -> Result: + return Result(value=1) + + manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) + + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("trigger") as trigger: + trigger_ctx = trigger.get_span_context() + record, _ = await manager.start(Params(name="a")) + await _wait_terminal(manager, record.job_id) + + job_spans = [s for s in spans.get_finished_spans() if s.name == "job.run"] + assert len(job_spans) == 1 + job_span = job_spans[0] + assert job_span.attributes["stitch.job.id"] == record.job_id + assert job_span.attributes["stitch.job.state"] == "succeeded" + # New root (not a child of the trigger), but linked back to it. + assert job_span.parent is None + assert any(link.context.span_id == trigger_ctx.span_id for link in job_span.links) + + +@pytest.mark.anyio +async def test_failed_job_span_has_error_status(spans) -> None: + async def run(params: Params) -> Result: + raise RuntimeError("boom") + + manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) + record, _ = await manager.start(Params(name="a")) + await _wait_terminal(manager, record.job_id) + + job_span = next(s for s in spans.get_finished_spans() if s.name == "job.run") + assert job_span.status.status_code.name == "ERROR" + assert job_span.attributes["stitch.job.state"] == "failed" diff --git a/packages/stitch-observability/README.md b/packages/stitch-observability/README.md new file mode 100644 index 00000000..97a8bc90 --- /dev/null +++ b/packages/stitch-observability/README.md @@ -0,0 +1,36 @@ +# stitch-observability + +Shared OpenTelemetry tracing setup + instrumentation for Stitch services, so +every service traces the same way and **interactions between services land in +one trace**. + +```python +from stitch.observability import ( + configure_tracing, instrument_fastapi, instrument_httpx, shutdown_tracing, + OTelSettings, +) + +provider = configure_tracing( + service_name="stitch-entity-linkage", + enabled=settings.otel_enabled, + exporter=settings.otel_traces_exporter, + otlp_endpoint=settings.otel_exporter_otlp_endpoint, + sample_ratio=settings.otel_sample_ratio, +) +if provider is not None: + instrument_fastapi(app) # on the constructed app, before it serves + instrument_httpx() # outbound calls inject W3C traceparent +# ... shutdown_tracing(provider) on exit +``` + +- **`OTelSettings`** — pydantic-settings mixin with the shared `OTEL_*` fields; + a service's `Settings` inherits it. +- **`instrument_httpx()`** — the propagation piece: outbound `httpx` calls carry + `traceparent`, so a downstream service (FastAPI-instrumented) continues the + same trace rather than starting a disconnected one. +- Exporter modes: `console` (spans → structured stdout logs, no sidecar), + `otlp` (→ collector/Jaeger), `none` (disabled). + +`stitch-service`'s `create_app` wires this automatically when given a +`service_name` + `OTelSettings`; `stitch-jobs` emits a `job.run` span per run +via the global tracer (no-op when tracing is off). diff --git a/packages/stitch-observability/pyproject.toml b/packages/stitch-observability/pyproject.toml new file mode 100644 index 00000000..9b95b739 --- /dev/null +++ b/packages/stitch-observability/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "stitch-observability" +version = "0.1.0" +description = "Shared OpenTelemetry tracing setup + instrumentation for Stitch services" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "opentelemetry-sdk>=1.30.0", + "opentelemetry-exporter-otlp-proto-grpc>=1.30.0", + "opentelemetry-instrumentation-fastapi>=0.51b0", + "opentelemetry-instrumentation-sqlalchemy>=0.51b0", + "opentelemetry-instrumentation-httpx>=0.51b0", + "pydantic-settings>=2.11.0", +] + +[build-system] +requires = ["uv_build>=0.9.30,<0.10.0"] +build-backend = "uv_build" + +[tool.uv.build-backend] +module-name = "stitch.observability" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = ["-v", "--strict-markers", "--tb=short"] + +[dependency-groups] +dev = ["pytest>=9.0.2"] diff --git a/packages/stitch-observability/src/stitch/observability/__init__.py b/packages/stitch-observability/src/stitch/observability/__init__.py new file mode 100644 index 00000000..79251c2b --- /dev/null +++ b/packages/stitch-observability/src/stitch/observability/__init__.py @@ -0,0 +1,29 @@ +"""Shared OpenTelemetry tracing for Stitch services. + +`configure_tracing` builds the global provider (parametrized by ``service_name``); +`instrument_fastapi` / `instrument_httpx` / `instrument_sqlalchemy` auto-instrument +the relevant layers. httpx instrumentation is what propagates the W3C +``traceparent`` so a service's downstream calls join the same trace end-to-end. +""" + +from .settings import OTelSettings +from .tracing import ( + LoggingSpanExporter, + configure_tracing, + get_tracer, + instrument_fastapi, + instrument_httpx, + instrument_sqlalchemy, + shutdown_tracing, +) + +__all__ = [ + "LoggingSpanExporter", + "OTelSettings", + "configure_tracing", + "get_tracer", + "instrument_fastapi", + "instrument_httpx", + "instrument_sqlalchemy", + "shutdown_tracing", +] diff --git a/packages/stitch-observability/src/stitch/observability/settings.py b/packages/stitch-observability/src/stitch/observability/settings.py new file mode 100644 index 00000000..33e99217 --- /dev/null +++ b/packages/stitch-observability/src/stitch/observability/settings.py @@ -0,0 +1,23 @@ +from typing import Literal + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class OTelSettings(BaseSettings): + """Mixin of the shared ``OTEL_*`` tracing settings. + + A service's ``Settings`` inherits this so every service reads the same env + (``OTEL_ENABLED`` / ``OTEL_TRACES_EXPORTER`` / ``OTEL_EXPORTER_OTLP_ENDPOINT`` + / ``OTEL_SAMPLE_RATIO``), which are already shared across the compose network. + + Defaults: ``console`` exporter logs spans to stdout (no collector needed); + ``otlp`` ships to the collector; ``none`` disables tracing. ``otel_sample_ratio`` + feeds the root sampler (1.0 = capture everything); downstream spans honor the + upstream decision via ParentBased. + """ + + otel_enabled: bool = True + otel_traces_exporter: Literal["console", "otlp", "none"] = "console" + otel_exporter_otlp_endpoint: str | None = None + otel_sample_ratio: float = Field(default=1.0, ge=0.0, le=1.0) diff --git a/packages/stitch-observability/src/stitch/observability/tracing.py b/packages/stitch-observability/src/stitch/observability/tracing.py new file mode 100644 index 00000000..e52782e7 --- /dev/null +++ b/packages/stitch-observability/src/stitch/observability/tracing.py @@ -0,0 +1,170 @@ +"""Shared OpenTelemetry tracing setup for Stitch services. + +Span *generation* is handled by auto-instrumentation (FastAPI, httpx, +SQLAlchemy); this module owns span *export*, configurable via the exporter mode: + +* ``console`` (default) — finished spans are emitted as structured log records + (see :class:`LoggingSpanExporter`), so local dev gets full trace data on + stdout **without** running the collector / Jaeger sidecars. +* ``otlp`` — spans are shipped via OTLP/gRPC to the collector (``→`` Jaeger). +* ``none`` — tracing is disabled entirely. + +Sampling uses ``ParentBased(root=TraceIdRatioBased(ratio))`` so a service honors +an upstream caller's sampling decision (propagated via the W3C ``traceparent`` +header) and only samples independently when it is the root of a trace. +""" + +import logging +from typing import TYPE_CHECKING + +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, +) +from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased + +if TYPE_CHECKING: + from collections.abc import Sequence + + from fastapi import FastAPI + from opentelemetry.sdk.trace import ReadableSpan + from sqlalchemy.engine import Engine + +_span_logger = logging.getLogger("stitch.observability.trace") + + +def get_tracer(name: str) -> trace.Tracer: + """Return a tracer from the global provider (no-op when tracing is off).""" + return trace.get_tracer(name) + + +class LoggingSpanExporter(SpanExporter): + """Export finished spans as structured log records instead of shipping them + to a collector. + + Each span becomes one ``stitch.observability.trace`` log record whose + ``event`` dict a JSON log formatter can flatten, so fields like ``trace_id`` + / ``duration_ms`` sit alongside request / query events on the same stream. + """ + + def export(self, spans: "Sequence[ReadableSpan]") -> SpanExportResult: + for span in spans: + ctx = span.get_span_context() + parent = span.parent + duration_ms = ( + round((span.end_time - span.start_time) / 1e6, 2) + if span.end_time is not None and span.start_time is not None + else None + ) + _span_logger.info( + "span", + extra={ + "event": { + "span_name": span.name, + "trace_id": format(ctx.trace_id, "032x"), + "span_id": format(ctx.span_id, "016x"), + "parent_span_id": format(parent.span_id, "016x") + if parent is not None + else None, + "kind": span.kind.name, + "duration_ms": duration_ms, + "status": span.status.status_code.name, + "attributes": dict(span.attributes or {}), + } + }, + ) + return SpanExportResult.SUCCESS + + def force_flush(self, timeout_millis: int = 30_000) -> bool: + return True + + +def configure_tracing( + *, + service_name: str, + enabled: bool = True, + exporter: str = "console", + otlp_endpoint: str | None = None, + sample_ratio: float = 1.0, + version: str = "unknown", + environment: str = "unknown", +) -> TracerProvider | None: + """Install the global tracer provider, or return ``None`` if disabled. + + Call once at startup, before the first span is created. Idempotency is not + guaranteed — ``set_tracer_provider`` warns if called twice. + """ + if not enabled or exporter == "none": + return None + + resource = Resource.create( + { + "service.name": service_name, + "service.version": version or "unknown", + "deployment.environment": environment, + } + ) + sampler = ParentBased(root=TraceIdRatioBased(sample_ratio)) + provider = TracerProvider(resource=resource, sampler=sampler) + + if exporter == "otlp": + # endpoint=None lets the exporter fall back to OTEL_EXPORTER_OTLP_ENDPOINT + # / the localhost default. + provider.add_span_processor( + BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_endpoint)) + ) + else: # "console" — log spans to stdout, no sidecar required. + provider.add_span_processor(SimpleSpanProcessor(LoggingSpanExporter())) + + trace.set_tracer_provider(provider) + return provider + + +def shutdown_tracing(provider: TracerProvider | None) -> None: + """Flush and shut down the provider (e.g. a BatchSpanProcessor) on exit.""" + if provider is not None: + provider.shutdown() + + +def instrument_fastapi(app: "FastAPI") -> None: + """Auto-instrument a FastAPI app (server spans + traceparent extraction). + + Run on the constructed app before it serves requests — not inside a startup + hook, where middleware-stack timing makes it ineffective. Imported lazily so + the instrumentor's optional ``fastapi`` dependency is only required by + services that actually call this. + """ + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) + + +def instrument_httpx() -> None: + """Auto-instrument httpx so outbound calls inject the W3C ``traceparent``. + + This is what links a service's downstream calls (via ``AsyncStitchClient`` / + the Azure client) into the same trace the receiving service continues. + Imported lazily so the instrumentor's optional ``httpx`` dependency is only + required by services that actually call this. + """ + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + HTTPXClientInstrumentor().instrument() + + +def instrument_sqlalchemy(engine: "Engine") -> None: + """Auto-instrument a (sync) SQLAlchemy engine for per-query spans. + + Pass ``async_engine.sync_engine`` for an ``AsyncEngine``. Imported lazily so + services without SQLAlchemy (the instrumentor lists it as an optional + "instruments" dependency) don't need it installed to use this package. + """ + from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor + + SQLAlchemyInstrumentor().instrument(engine=engine) diff --git a/packages/stitch-observability/tests/test_tracing.py b/packages/stitch-observability/tests/test_tracing.py new file mode 100644 index 00000000..d328505e --- /dev/null +++ b/packages/stitch-observability/tests/test_tracing.py @@ -0,0 +1,64 @@ +import logging + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) + +from stitch.observability import OTelSettings, configure_tracing +from stitch.observability.tracing import LoggingSpanExporter + + +def test_configure_tracing_disabled_returns_none() -> None: + assert configure_tracing(service_name="svc", enabled=False) is None + assert configure_tracing(service_name="svc", enabled=True, exporter="none") is None + + +def test_configure_tracing_builds_provider_with_resource() -> None: + # Build directly (not via set_tracer_provider) to avoid mutating global state. + provider = configure_tracing( + service_name="stitch-test", + exporter="console", + version="1.2.3", + environment="test", + ) + assert isinstance(provider, TracerProvider) + attrs = provider.resource.attributes + assert attrs["service.name"] == "stitch-test" + assert attrs["service.version"] == "1.2.3" + assert attrs["deployment.environment"] == "test" + + +def test_logging_span_exporter_emits_one_record_per_span(caplog) -> None: + # A local provider + the exporter under test; never touches the global. + exporter = LoggingSpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + with caplog.at_level(logging.INFO, logger="stitch.observability.trace"): + with tracer.start_as_current_span("unit-span"): + pass + + records = [r for r in caplog.records if r.name == "stitch.observability.trace"] + assert len(records) == 1 + assert records[0].event["span_name"] == "unit-span" + assert "trace_id" in records[0].event + + +def test_otel_settings_defaults_and_bounds() -> None: + s = OTelSettings() + assert s.otel_enabled is True + assert s.otel_traces_exporter == "console" + assert s.otel_sample_ratio == 1.0 + assert OTelSettings(otel_sample_ratio=0.25).otel_sample_ratio == 0.25 + + +def test_in_memory_exporter_captures_spans() -> None: + # Demonstrates the local-provider pattern the jobs trace test uses. + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + provider.get_tracer("t").start_span("s").end() + assert [s.name for s in exporter.get_finished_spans()] == ["s"] diff --git a/packages/stitch-service/pyproject.toml b/packages/stitch-service/pyproject.toml index 120f9b95..369d29c7 100644 --- a/packages/stitch-service/pyproject.toml +++ b/packages/stitch-service/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "fastapi[standard-no-fastapi-cloud-cli]>=0.135.1", "stitch-auth", "stitch-client", + "stitch-observability", ] [build-system] @@ -20,6 +21,7 @@ module-name = "stitch.service" [tool.uv.sources] stitch-auth = { workspace = true } stitch-client = { workspace = true } +stitch-observability = { workspace = true } [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/packages/stitch-service/src/stitch/service/app.py b/packages/stitch-service/src/stitch/service/app.py index b7625d23..008cd0ec 100644 --- a/packages/stitch-service/src/stitch/service/app.py +++ b/packages/stitch-service/src/stitch/service/app.py @@ -6,6 +6,13 @@ from datetime import UTC, datetime from fastapi import APIRouter, FastAPI +from stitch.observability import ( + OTelSettings, + configure_tracing, + instrument_fastapi, + instrument_httpx, + shutdown_tracing, +) from .middleware import register_cors @@ -27,6 +34,10 @@ def create_app( cors_origins: Sequence[str] = (), on_startup: LifecycleHook | None = None, on_shutdown: LifecycleHook | None = None, + service_name: str | None = None, + otel: OTelSettings | None = None, + version: str = "unknown", + environment: str = "unknown", **fastapi_kwargs: object, ) -> FastAPI: """Build a FastAPI app with the scaffolding every non-core service repeats. @@ -35,12 +46,26 @@ def create_app( given routers under ``api_prefix``, and runs the optional ``on_startup`` / ``on_shutdown`` hooks inside the lifespan. - ``on_startup`` is where a service does its own startup validation (auth / - downstream config). Keeping it a service-provided callback — rather than - baking specific validators in here — lets each service own and test that - logic. Observability wiring (deferred to a later pass) will hook in here - too, without reshaping this signature. + Pass ``service_name`` + ``otel`` to enable OpenTelemetry: the global tracer + provider is configured before the app is built, the app and outbound httpx + are instrumented synchronously (before serving — not in ``on_startup``, + where middleware-stack timing makes FastAPI instrumentation ineffective), + and the provider is flushed/shut down on exit. Omit them to leave tracing + off (current behavior). """ + # Configure the global provider before the app exists; instrument the built + # app below (before it serves). `provider is None` when tracing is disabled. + provider = None + if service_name is not None and otel is not None: + provider = configure_tracing( + service_name=service_name, + enabled=otel.otel_enabled, + exporter=otel.otel_traces_exporter, + otlp_endpoint=otel.otel_exporter_otlp_endpoint, + sample_ratio=otel.otel_sample_ratio, + version=version, + environment=environment, + ) @asynccontextmanager async def lifespan(app: FastAPI): @@ -50,6 +75,7 @@ async def lifespan(app: FastAPI): yield if on_shutdown is not None: await _maybe_await(on_shutdown(app)) + shutdown_tracing(provider) if title is not None: fastapi_kwargs["title"] = title @@ -57,6 +83,10 @@ async def lifespan(app: FastAPI): register_cors(app, origins=cors_origins) + if provider is not None: + instrument_fastapi(app) + instrument_httpx() + base_router = APIRouter(prefix=api_prefix) for router in routers: base_router.include_router(router) diff --git a/pyproject.toml b/pyproject.toml index 1da6fff7..0b6441bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ members = [ "packages/stitch-client", "packages/stitch-jobs", "packages/stitch-models", + "packages/stitch-observability", "packages/stitch-ogsi", "packages/stitch-service", ] diff --git a/uv.lock b/uv.lock index 1f839800..c4d8237d 100644 --- a/uv.lock +++ b/uv.lock @@ -17,6 +17,7 @@ members = [ "stitch-jobs", "stitch-llm", "stitch-models", + "stitch-observability", "stitch-ogsi", "stitch-seed", "stitch-service", @@ -687,6 +688,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/3d/2eae63f13f36d7a8ab5bf03d06ecaf169c2069b524547f24947be6d92094/opentelemetry_instrumentation_fastapi-0.63b1-py3-none-any.whl", hash = "sha256:52ee2cde9a2ac094bdd45d79f85860e03a972928a2553006071fe61d94cf7281", size = 12795, upload-time = "2026-05-21T16:35:28.68Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.63b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/02/27/c2b4335bca030e893acbe5ff2b4f434868773bf94508be7e6bf5af981b24/opentelemetry_instrumentation_httpx-0.63b1.tar.gz", hash = "sha256:f41ec82f25c3abcdada621052db3e5fd648e3b43d55eec4b9c0c5d3ecb7b4ff4", size = 23557, upload-time = "2026-05-21T16:36:34.583Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/b8/f536780996195c3b9f2354998554671e05a7a262df8c043f63fe9e5a6f0b/opentelemetry_instrumentation_httpx-0.63b1-py3-none-any.whl", hash = "sha256:14df6e99d81be9a8cd238f6639b6fa52404c4d3ce219058fcb5dc8c0f2211f86", size = 16336, upload-time = "2026-05-21T16:35:32.221Z" }, +] + [[package]] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.63b1" @@ -1253,6 +1270,7 @@ dependencies = [ { name = "sqlalchemy" }, { name = "stitch-auth" }, { name = "stitch-models" }, + { name = "stitch-observability" }, { name = "stitch-ogsi" }, ] @@ -1279,6 +1297,7 @@ requires-dist = [ { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, { name = "stitch-models", editable = "packages/stitch-models" }, + { name = "stitch-observability", editable = "packages/stitch-observability" }, { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, ] @@ -1359,6 +1378,7 @@ dependencies = [ { name = "stitch-client" }, { name = "stitch-jobs" }, { name = "stitch-models" }, + { name = "stitch-observability" }, { name = "stitch-ogsi" }, { name = "stitch-service" }, ] @@ -1381,6 +1401,7 @@ requires-dist = [ { name = "stitch-client", editable = "packages/stitch-client" }, { name = "stitch-jobs", editable = "packages/stitch-jobs" }, { name = "stitch-models", editable = "packages/stitch-models" }, + { name = "stitch-observability", editable = "packages/stitch-observability" }, { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, { name = "stitch-service", editable = "packages/stitch-service" }, ] @@ -1400,12 +1421,14 @@ version = "0.1.0" source = { editable = "packages/stitch-jobs" } dependencies = [ { name = "fastapi", extra = ["standard-no-fastapi-cloud-cli"] }, + { name = "opentelemetry-api" }, { name = "pydantic" }, ] [package.dev-dependencies] dev = [ { name = "httpx" }, + { name = "opentelemetry-sdk" }, { name = "pytest" }, { name = "pytest-anyio" }, ] @@ -1413,12 +1436,14 @@ dev = [ [package.metadata] requires-dist = [ { name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }, + { name = "opentelemetry-api", specifier = ">=1.30.0" }, { name = "pydantic", specifier = ">=2.12.5" }, ] [package.metadata.requires-dev] dev = [ { name = "httpx", specifier = ">=0.28.0" }, + { name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-anyio", specifier = ">=0.0.0" }, ] @@ -1435,6 +1460,7 @@ dependencies = [ { name = "stitch-client" }, { name = "stitch-jobs" }, { name = "stitch-models" }, + { name = "stitch-observability" }, { name = "stitch-ogsi" }, { name = "stitch-service" }, ] @@ -1455,6 +1481,7 @@ requires-dist = [ { name = "stitch-client", editable = "packages/stitch-client" }, { name = "stitch-jobs", editable = "packages/stitch-jobs" }, { name = "stitch-models", editable = "packages/stitch-models" }, + { name = "stitch-observability", editable = "packages/stitch-observability" }, { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, { name = "stitch-service", editable = "packages/stitch-service" }, ] @@ -1485,6 +1512,37 @@ requires-dist = [{ name = "pydantic", specifier = ">=2.12.5" }] [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=9.0.2" }] +[[package]] +name = "stitch-observability" +version = "0.1.0" +source = { editable = "packages/stitch-observability" } +dependencies = [ + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-instrumentation-fastapi" }, + { name = "opentelemetry-instrumentation-httpx" }, + { name = "opentelemetry-instrumentation-sqlalchemy" }, + { name = "opentelemetry-sdk" }, + { name = "pydantic-settings" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = ">=1.30.0" }, + { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.51b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.51b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.51b0" }, + { name = "opentelemetry-sdk", specifier = ">=1.30.0" }, + { name = "pydantic-settings", specifier = ">=2.11.0" }, +] + +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = ">=9.0.2" }] + [[package]] name = "stitch-ogsi" version = "0.1.0" @@ -1547,6 +1605,7 @@ dependencies = [ { name = "fastapi", extra = ["standard-no-fastapi-cloud-cli"] }, { name = "stitch-auth" }, { name = "stitch-client" }, + { name = "stitch-observability" }, ] [package.dev-dependencies] @@ -1561,6 +1620,7 @@ requires-dist = [ { name = "fastapi", extras = ["standard-no-fastapi-cloud-cli"], specifier = ">=0.135.1" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, { name = "stitch-client", editable = "packages/stitch-client" }, + { name = "stitch-observability", editable = "packages/stitch-observability" }, ] [package.metadata.requires-dev] From 349a06e05c12c591765d12e94324a8971e9d6ab0 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 24 Jun 2026 18:20:57 +0200 Subject: [PATCH 26/26] Address code review * Update docs/comments * prevent leaking OTel state --- packages/stitch-jobs/src/stitch/jobs/store.py | 11 +++-- packages/stitch-jobs/tests/test_tracing.py | 41 ++++++++++--------- .../tests/test_tracing.py | 8 +++- packages/stitch-service/README.md | 9 +++- .../src/stitch/service/__init__.py | 7 ++-- 5 files changed, 46 insertions(+), 30 deletions(-) diff --git a/packages/stitch-jobs/src/stitch/jobs/store.py b/packages/stitch-jobs/src/stitch/jobs/store.py index 0cb0c852..b7139675 100644 --- a/packages/stitch-jobs/src/stitch/jobs/store.py +++ b/packages/stitch-jobs/src/stitch/jobs/store.py @@ -14,10 +14,13 @@ def _utcnow() -> datetime: class JobStore(Protocol): """Persistence seam for job records. - The in-memory implementation below is sufficient for a single replica. A - DB-backed store (surviving restarts and shared across replicas) can be - dropped in later behind this same interface without touching the manager or - routers. + The in-memory implementation below is sufficient for a single replica, where + the manager mutating a ``JobRecord`` in place *is* the persistence (the store + holds the same object). A DB-backed store (surviving restarts, shared across + replicas) fits behind this same read interface, but would additionally need + the manager to write state transitions back through the store — adding that + write-back hook is part of the agreed DB-persistence follow-up, not this + in-memory layer. """ async def create(self, record: JobRecord) -> None: diff --git a/packages/stitch-jobs/tests/test_tracing.py b/packages/stitch-jobs/tests/test_tracing.py index 3929b8b6..6da00b03 100644 --- a/packages/stitch-jobs/tests/test_tracing.py +++ b/packages/stitch-jobs/tests/test_tracing.py @@ -3,7 +3,6 @@ import asyncio import pytest -from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( @@ -12,6 +11,7 @@ from pydantic import BaseModel from stitch.jobs import JobManager, SingletonPolicy +from stitch.jobs import manager as manager_module class Params(BaseModel): @@ -22,22 +22,21 @@ class Result(BaseModel): value: int -# Sets the process-global provider once (the manager's module-level tracer is a -# proxy that resolves to it). This file sorts last in the package, so the -# earlier suites run with the default no-op tracer, unaffected. -@pytest.fixture(scope="module") -def span_exporter() -> InMemorySpanExporter: +@pytest.fixture +def tracing(monkeypatch) -> tuple[TracerProvider, InMemorySpanExporter]: + """Local provider + in-memory exporter, with the manager's module-level + tracer pointed at it for the duration of the test. + + Monkeypatching ``manager._tracer`` (rather than calling + ``trace.set_tracer_provider``) keeps the process-global provider untouched — + OTel makes the global set-once, so it can't be restored in teardown — so the + suite stays isolated and order-independent. + """ exporter = InMemorySpanExporter() provider = TracerProvider() provider.add_span_processor(SimpleSpanProcessor(exporter)) - trace.set_tracer_provider(provider) - return exporter - - -@pytest.fixture -def spans(span_exporter: InMemorySpanExporter) -> InMemorySpanExporter: - span_exporter.clear() - return span_exporter + monkeypatch.setattr(manager_module, "_tracer", provider.get_tracer("stitch.jobs")) + return provider, exporter async def _wait_terminal(manager: JobManager, job_id: str, *, timeout=2.0): @@ -51,19 +50,21 @@ async def _wait_terminal(manager: JobManager, job_id: str, *, timeout=2.0): @pytest.mark.anyio -async def test_job_run_emits_root_span_linked_to_trigger(spans) -> None: +async def test_job_run_emits_root_span_linked_to_trigger(tracing) -> None: + provider, exporter = tracing + async def run(params: Params) -> Result: return Result(value=1) manager: JobManager[Params, Result] = JobManager(run, policy=SingletonPolicy()) - tracer = trace.get_tracer("test") + tracer = provider.get_tracer("test") with tracer.start_as_current_span("trigger") as trigger: trigger_ctx = trigger.get_span_context() record, _ = await manager.start(Params(name="a")) await _wait_terminal(manager, record.job_id) - job_spans = [s for s in spans.get_finished_spans() if s.name == "job.run"] + job_spans = [s for s in exporter.get_finished_spans() if s.name == "job.run"] assert len(job_spans) == 1 job_span = job_spans[0] assert job_span.attributes["stitch.job.id"] == record.job_id @@ -74,7 +75,9 @@ async def run(params: Params) -> Result: @pytest.mark.anyio -async def test_failed_job_span_has_error_status(spans) -> None: +async def test_failed_job_span_has_error_status(tracing) -> None: + _provider, exporter = tracing + async def run(params: Params) -> Result: raise RuntimeError("boom") @@ -82,6 +85,6 @@ async def run(params: Params) -> Result: record, _ = await manager.start(Params(name="a")) await _wait_terminal(manager, record.job_id) - job_span = next(s for s in spans.get_finished_spans() if s.name == "job.run") + job_span = next(s for s in exporter.get_finished_spans() if s.name == "job.run") assert job_span.status.status_code.name == "ERROR" assert job_span.attributes["stitch.job.state"] == "failed" diff --git a/packages/stitch-observability/tests/test_tracing.py b/packages/stitch-observability/tests/test_tracing.py index d328505e..1c6c6b7c 100644 --- a/packages/stitch-observability/tests/test_tracing.py +++ b/packages/stitch-observability/tests/test_tracing.py @@ -1,5 +1,6 @@ import logging +from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( @@ -15,8 +16,11 @@ def test_configure_tracing_disabled_returns_none() -> None: assert configure_tracing(service_name="svc", enabled=True, exporter="none") is None -def test_configure_tracing_builds_provider_with_resource() -> None: - # Build directly (not via set_tracer_provider) to avoid mutating global state. +def test_configure_tracing_builds_provider_with_resource(monkeypatch) -> None: + # configure_tracing installs the provider globally via set_tracer_provider; + # stub that out so this test exercises only provider construction and leaves + # the process-global provider untouched (OTel makes it set-once). + monkeypatch.setattr(trace, "set_tracer_provider", lambda _provider: None) provider = configure_tracing( service_name="stitch-test", exporter="console", diff --git a/packages/stitch-service/README.md b/packages/stitch-service/README.md index b876f930..b348d959 100644 --- a/packages/stitch-service/README.md +++ b/packages/stitch-service/README.md @@ -10,6 +10,10 @@ Shared FastAPI scaffolding for Stitch non-core services — the boilerplate that - health helpers — `make_basic_health_router(service)` for liveness, plus `runtime_block`/`format_started_at`/`uptime_seconds` for assembling a service-specific `/health/details`. +- observability — pass `service_name` + `otel` (an `OTelSettings` from + `stitch-observability`) and `create_app` configures OpenTelemetry tracing: + FastAPI server spans, outbound httpx `traceparent` propagation, and provider + shutdown in the lifespan. Omit them and tracing stays off. ```python from stitch.service import create_app @@ -27,7 +31,8 @@ app = create_app( ## Out of scope (for now) -- **Observability/logging** — in flight on a separate branch; will hook into the - app factory's lifespan later. +- **Structured-log / query-timing layer** — the API's request-logging and + per-query timing sinks hang off its SQLAlchemy engine and stay API-specific; + only tracing is shared here. - **Auth** — each service still owns its auth wiring (settings-coupled); a future pass may extract a configurable auth provider here. diff --git a/packages/stitch-service/src/stitch/service/__init__.py b/packages/stitch-service/src/stitch/service/__init__.py index e4812e7c..4cdd6e6f 100644 --- a/packages/stitch-service/src/stitch/service/__init__.py +++ b/packages/stitch-service/src/stitch/service/__init__.py @@ -2,9 +2,10 @@ Provides the app factory, CORS wiring, health helpers, and the auth seam (both inbound request validation and the downstream machine / on-behalf-of modes) that -every service otherwise copies. Observability is intentionally out of scope for -now (in flight on a separate branch); the app factory leaves lifecycle hooks -open so it can be added later. +every service otherwise copies. When passed a ``service_name`` and ``otel`` +settings, ``create_app`` also configures OpenTelemetry tracing via +``stitch-observability`` (FastAPI server spans + outbound httpx propagation); +omitting them leaves tracing off. """ from .app import create_app