diff --git a/Makefile b/Makefile index d67ba5d4..bcc784ac 100644 --- a/Makefile +++ b/Makefile @@ -107,9 +107,30 @@ pkg-test-ogsi: pkg-test-exact-ogsi: $(MAKE) uv-test-target-exact PKG=stitch-ogsi TEST_PATH=packages/stitch-ogsi -pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi -pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi -pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi +pkg-build-service: + $(UV) build --package stitch-service +pkg-test-service: + $(MAKE) uv-test-target PKG=stitch-service TEST_PATH=packages/stitch-service +pkg-test-exact-service: + $(MAKE) uv-test-target-exact PKG=stitch-service TEST_PATH=packages/stitch-service + +pkg-build-jobs: + $(UV) build --package stitch-jobs +pkg-test-jobs: + $(MAKE) uv-test-target PKG=stitch-jobs TEST_PATH=packages/stitch-jobs +pkg-test-exact-jobs: + $(MAKE) uv-test-target-exact PKG=stitch-jobs TEST_PATH=packages/stitch-jobs + +pkg-build-observability: + $(UV) build --package stitch-observability +pkg-test-observability: + $(MAKE) uv-test-target PKG=stitch-observability TEST_PATH=packages/stitch-observability +pkg-test-exact-observability: + $(MAKE) uv-test-target-exact PKG=stitch-observability TEST_PATH=packages/stitch-observability + +pkg-build: pkg-build-auth pkg-build-client pkg-build-models pkg-build-ogsi pkg-build-service pkg-build-jobs pkg-build-observability +pkg-test: pkg-test-auth pkg-test-client pkg-test-models pkg-test-ogsi pkg-test-service pkg-test-jobs pkg-test-observability +pkg-test-exact: pkg-test-exact-auth pkg-test-exact-client pkg-test-exact-models pkg-test-exact-ogsi pkg-test-exact-service pkg-test-exact-jobs pkg-test-exact-observability # --------------------------------------------------------------------- # Deployments @@ -291,6 +312,9 @@ follow-stack-logs: pkg-build-client pkg-test-client pkg-test-exact-client \ pkg-build-models pkg-test-models pkg-test-exact-models \ pkg-build-ogsi pkg-test-ogsi pkg-test-exact-ogsi \ + pkg-build-service pkg-test-service pkg-test-exact-service \ + pkg-build-jobs pkg-test-jobs pkg-test-exact-jobs \ + pkg-build-observability pkg-test-observability pkg-test-exact-observability \ \ # API api-build api-test api-test-exact api-dev stack-api-dev \ diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 8fe7745a..da164baa 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "sqlalchemy>=2.0.44", "stitch-auth", "stitch-models", + "stitch-observability", "stitch-ogsi", ] @@ -47,4 +48,5 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-models = { workspace = true } +stitch-observability = { workspace = true } stitch-ogsi = { workspace = true } diff --git a/deployments/api/src/stitch/api/observability/tracing.py b/deployments/api/src/stitch/api/observability/tracing.py index 2d6fc5bd..a49d1460 100644 --- a/deployments/api/src/stitch/api/observability/tracing.py +++ b/deployments/api/src/stitch/api/observability/tracing.py @@ -1,139 +1,46 @@ -"""OpenTelemetry tracing setup for the API. - -Span *generation* is handled by auto-instrumentation (FastAPI + SQLAlchemy); -this module owns span *export*, which is configurable: - -* ``console`` (default) — finished spans are emitted as structured log records - through the existing :class:`JsonFormatter` (see :mod:`logging_config`), so - local dev gets full trace data on stdout **without** running the collector / - Jaeger sidecars. This is the "log what OTel would send" path. -* ``otlp`` — spans are shipped via OTLP/gRPC to the collector (``→`` Jaeger). -* ``none`` — tracing is disabled entirely. - -Sampling uses ``ParentBased(root=TraceIdRatioBased(ratio))`` so the API honors -an upstream caller's sampling decision (propagated via the W3C ``traceparent`` -header) and only samples independently when it is the root of a trace. The -ratio defaults to 1.0 (capture everything) for local dev. +"""OpenTelemetry tracing for the API — a thin wrapper over the shared +``stitch.observability`` package (one source of truth across services). + +Keeps this module's historical surface (``SERVICE_NAME``, +``configure_tracing(settings)``, ``instrument_fastapi``, ``instrument_sqlalchemy``, +``LoggingSpanExporter``) so call sites (``main.py``, ``db/config.py``) and tests +don't change. The API's query-timing / request-logging / sinks layer stays +API-specific (it hangs off the SQLAlchemy engine). """ -import logging from typing import TYPE_CHECKING -from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter -from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import ( - BatchSpanProcessor, - SimpleSpanProcessor, - SpanExporter, - SpanExportResult, +from stitch.observability import ( + LoggingSpanExporter, + configure_tracing as _configure_tracing, + instrument_fastapi, + instrument_sqlalchemy, ) -from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased if TYPE_CHECKING: - from collections.abc import Sequence - - from fastapi import FastAPI - from opentelemetry.sdk.trace import ReadableSpan - from sqlalchemy.engine import Engine + from opentelemetry.sdk.trace import TracerProvider from ..settings import Settings SERVICE_NAME = "stitch-api" -_span_logger = logging.getLogger("stitch.api.observability.trace") - - -class LoggingSpanExporter(SpanExporter): - """Export finished spans as structured log records instead of shipping them - to a collector. - - Each span becomes one ``stitch.api.observability.trace`` log record whose - ``event`` dict the :class:`JsonFormatter` flattens to the top level, so - fields like ``trace_id`` / ``duration_ms`` are directly queryable and sit - alongside the request / query events on the same stdout stream. - """ - - def export(self, spans: "Sequence[ReadableSpan]") -> SpanExportResult: - for span in spans: - ctx = span.get_span_context() - parent = span.parent - duration_ms = ( - round((span.end_time - span.start_time) / 1e6, 2) - if span.end_time is not None and span.start_time is not None - else None - ) - _span_logger.info( - "span", - extra={ - "event": { - "span_name": span.name, - "trace_id": format(ctx.trace_id, "032x"), - "span_id": format(ctx.span_id, "016x"), - "parent_span_id": format(parent.span_id, "016x") - if parent is not None - else None, - "kind": span.kind.name, - "duration_ms": duration_ms, - "status": span.status.status_code.name, - "attributes": dict(span.attributes or {}), - } - }, - ) - return SpanExportResult.SUCCESS - - def force_flush(self, timeout_millis: int = 30_000) -> bool: - return True - - -def configure_tracing(settings: "Settings") -> TracerProvider | None: - """Install the global tracer provider, or return ``None`` if disabled. - - Call once at startup, before the first span is created. Idempotency is not - guaranteed — ``set_tracer_provider`` warns if called twice. - """ - if not settings.otel_enabled or settings.otel_traces_exporter == "none": - return None - - resource = Resource.create( - { - "service.name": SERVICE_NAME, - "service.version": settings.app_version or "unknown", - "deployment.environment": settings.environment_name, - } +__all__ = [ + "SERVICE_NAME", + "LoggingSpanExporter", + "configure_tracing", + "instrument_fastapi", + "instrument_sqlalchemy", +] + + +def configure_tracing(settings: "Settings") -> "TracerProvider | None": + """Install the API's global tracer provider, or ``None`` if disabled.""" + return _configure_tracing( + service_name=SERVICE_NAME, + enabled=settings.otel_enabled, + exporter=settings.otel_traces_exporter, + otlp_endpoint=settings.otel_exporter_otlp_endpoint, + sample_ratio=settings.otel_sample_ratio, + version=settings.app_version or "unknown", + environment=settings.environment_name, ) - sampler = ParentBased(root=TraceIdRatioBased(settings.otel_sample_ratio)) - provider = TracerProvider(resource=resource, sampler=sampler) - - if settings.otel_traces_exporter == "otlp": - # endpoint=None lets the exporter fall back to OTEL_EXPORTER_OTLP_ENDPOINT - # / the localhost default. - exporter = OTLPSpanExporter(endpoint=settings.otel_exporter_otlp_endpoint) - provider.add_span_processor(BatchSpanProcessor(exporter)) - else: # "console" — log spans to stdout, no sidecar required. - provider.add_span_processor(SimpleSpanProcessor(LoggingSpanExporter())) - - trace.set_tracer_provider(provider) - return provider - - -def instrument_fastapi(app: "FastAPI") -> None: - """Auto-instrument the FastAPI app (server spans + traceparent extraction). - - URL query strings are intentionally left intact — they're the diagnostic - payload for the performance work this serves. When a retained backend makes - aggregate PII a concern (cloud), scrub them at the collector's egress - (an ``attributes``/``redaction`` processor) rather than blinding local dev. - """ - FastAPIInstrumentor.instrument_app(app) - - -def instrument_sqlalchemy(engine: "Engine") -> None: - """Auto-instrument a (sync) SQLAlchemy engine for per-query spans. - - Pass ``async_engine.sync_engine`` for an ``AsyncEngine``. - """ - SQLAlchemyInstrumentor().instrument(engine=engine) diff --git a/deployments/api/tests/observability/test_tracing.py b/deployments/api/tests/observability/test_tracing.py index 27f59131..487b7749 100644 --- a/deployments/api/tests/observability/test_tracing.py +++ b/deployments/api/tests/observability/test_tracing.py @@ -15,7 +15,8 @@ from stitch.api.observability.tracing import LoggingSpanExporter, configure_tracing from stitch.api.settings import Settings -_TRACE_LOGGER = "stitch.api.observability.trace" +# Span log records now come from the shared stitch-observability exporter. +_TRACE_LOGGER = "stitch.observability.trace" @pytest.fixture diff --git a/deployments/entity-linkage/conftest.py b/deployments/entity-linkage/conftest.py new file mode 100644 index 00000000..343187ee --- /dev/null +++ b/deployments/entity-linkage/conftest.py @@ -0,0 +1,6 @@ +import os + +# Disable tracing for the suite before the app module imports and runs +# configure_tracing (mirrors the API's rootdir conftest). An env var set here +# wins over the .env file's value via pydantic-settings precedence. +os.environ.setdefault("OTEL_TRACES_EXPORTER", "none") diff --git a/deployments/entity-linkage/pyproject.toml b/deployments/entity-linkage/pyproject.toml index ca5936a2..b087cc81 100644 --- a/deployments/entity-linkage/pyproject.toml +++ b/deployments/entity-linkage/pyproject.toml @@ -11,8 +11,11 @@ dependencies = [ "pydantic-settings>=2.12.0", "stitch-auth", "stitch-client", + "stitch-jobs", "stitch-models", + "stitch-observability", "stitch-ogsi", + "stitch-service", ] [build-system] @@ -41,5 +44,8 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-client = { workspace = true } +stitch-jobs = { workspace = true } stitch-models = { workspace = true } +stitch-observability = { workspace = true } stitch-ogsi = { workspace = true } +stitch-service = { workspace = true } diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py index f04cf08d..7efeb2fe 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/auth.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/auth.py @@ -1,208 +1,23 @@ -import asyncio -import logging -from functools import lru_cache -from typing import Annotated, Literal, NoReturn +"""Entity-linkage auth wiring. -from fastapi import Depends, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +All the mechanics live in :mod:`stitch.service.auth`; here we just bind a +:class:`~stitch.service.auth.ServiceAuth` to this service's settings and +re-export the dependencies the routers and tests import by name. +""" -from stitch.auth import ( - ALL_PERMISSIONS, - AuthError, - InsufficientPermissionsError, - JWKSFetchError, - JWTValidator, - OIDCSettings, - TokenClaims, - check_permissions, -) +from stitch.service.auth import ServiceAuth -from stitch.entity_linkage.entities import RequestAuthContext, User from stitch.entity_linkage.settings import get_settings -logger = logging.getLogger(__name__) +_auth = ServiceAuth(is_auth_disabled=lambda: get_settings().auth_disabled) +validate_auth_config_at_startup = _auth.validate_auth_config_at_startup +get_token_claims = _auth.get_token_claims +require_permissions = _auth.require_permissions +get_current_user = _auth.get_current_user +get_request_auth_context = _auth.get_request_auth_context +initiated_by = _auth.initiated_by -@lru_cache -def get_oidc_settings() -> OIDCSettings: - return OIDCSettings() - - -@lru_cache -def get_jwt_validator() -> JWTValidator: - return JWTValidator(get_oidc_settings()) - - -_DEV_CLAIMS = TokenClaims( - sub="dev|local-placeholder", - email="dev@example.com", - name="Dev User", - permissions=ALL_PERMISSIONS, - raw={}, -) - -# auto_error=False so that when AUTH_DISABLED=true the missing header -# doesn't trigger a 403 before our custom handler runs. -_bearer_scheme = HTTPBearer(auto_error=False) - - -def validate_auth_config_at_startup() -> None: - settings = get_settings() - - if settings.auth_disabled: - logger.warning("Auth is disabled — all requests use dev credentials") - return - - # fail fast if OIDC config is invalid - get_oidc_settings() - - -def _extract_bearer_token_from_request(request: Request) -> str | None: - """ - Return the raw bearer token from the Authorization header. - - This exists separately from JWT validation so that downstream callers can - opt into explicit bearer-token relay in the future. - """ - auth_header = request.headers.get("Authorization") - if not auth_header: - return None - - scheme, _, token = auth_header.partition(" ") - if scheme.lower() != "bearer" or not token: - return None - - return token - - -def _dev_bearer_token() -> str: - """ - Placeholder token used only when auth is disabled in local development. - """ - return "dev-placeholder-token" - - -def _claims_to_user(claims: TokenClaims) -> User: - return User( - id=1, - sub=claims.sub, - email=claims.email or "unknown@example.com", - name=claims.name or claims.email or claims.sub, - ) - - -async def get_token_claims( - request: Request, - _credential: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme), -) -> TokenClaims: - """Extract and validate JWT from Authorization header. - - The ``_credential`` parameter exists solely so FastAPI registers the - HTTPBearer security scheme in the OpenAPI spec (Swagger "Authorize" - button). Actual token parsing still uses the raw header so we can - return precise 401 messages for missing/malformed values. - """ - if get_settings().auth_disabled: - return _DEV_CLAIMS - - auth_header = request.headers.get("Authorization") - if not auth_header: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Missing Authorization header", - headers={"WWW-Authenticate": "Bearer"}, - ) - - scheme, _, token = auth_header.partition(" ") - if scheme.lower() != "bearer" or not token: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid Authorization header format", - headers={"WWW-Authenticate": "Bearer"}, - ) - - validator = get_jwt_validator() - try: - return await asyncio.to_thread(validator.validate, token) - except JWKSFetchError: - logger.error( - "JWKS endpoint unreachable or returned invalid data", exc_info=True - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - except AuthError as e: - logger.warning("JWT validation failed: %s", e, exc_info=True) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -Claims = Annotated[TokenClaims, Depends(get_token_claims)] - - -def _permission_exception_handler(exc: InsufficientPermissionsError) -> NoReturn: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=exc.detail) - - -def require_permissions( - *required_permissions: str, check: Literal["all", "any"] = "all" -): - async def dependency(claims: Claims) -> None: - check_permissions( - granted=claims.permissions, - required=required_permissions, - check=check, - exc_handler=_permission_exception_handler, - ) - - return dependency - - -async def get_current_user(claims: Claims) -> User: - """ - Resolve validated token claims to a lightweight request user. - """ - if get_settings().auth_disabled: - return User( - id=1, - sub=_DEV_CLAIMS.sub, - email=_DEV_CLAIMS.email or "dev@example.com", - name=_DEV_CLAIMS.name or "Dev User", - ) - return _claims_to_user(claims) - - -CurrentUser = Annotated[User, Depends(get_current_user)] - - -async def get_request_auth_context( - request: Request, - user: CurrentUser, -) -> RequestAuthContext: - """ - Build the request-scoped auth context used by downstream API clients. - - The current deployment wiring uses env-based downstream auth, but we keep - the raw bearer token available for future explicit relay or OBO modes. - """ - if get_settings().auth_disabled: - bearer_token = _dev_bearer_token() - else: - bearer_token = _extract_bearer_token_from_request(request) - - return RequestAuthContext( - user=user, - bearer_token=bearer_token, - ) - - -AuthContext = Annotated[ - RequestAuthContext, - Depends(get_request_auth_context), -] +Claims = _auth.Claims +CurrentUser = _auth.CurrentUser +AuthContext = _auth.AuthContext diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/client.py b/deployments/entity-linkage/src/stitch/entity_linkage/client.py index 907020db..23971b73 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/client.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/client.py @@ -2,11 +2,17 @@ from typing import Any -from stitch.client import AsyncStitchClient, env_bearer_token_headers_provider +from stitch.client import AsyncStitchClient +from stitch.service.auth import AuthMode, build_headers_provider from stitch.entity_linkage.entities import FieldCandidate, FieldDetailCandidate from stitch.entity_linkage.settings import get_settings +# Entity-linkage does its work in a detached background job, so the caller's +# token is gone by the time the run executes — it authenticates downstream with +# its own machine identity (STITCH_CLIENT_BEARER_TOKEN), not on-behalf-of. +_DOWNSTREAM_AUTH_MODE = AuthMode.machine + def _get_api_base_url() -> str: """ @@ -16,8 +22,8 @@ def _get_api_base_url() -> str: def validate_downstream_auth_config_at_startup() -> None: - headers_provider = env_bearer_token_headers_provider() - headers_provider() + # Fail fast at startup if the machine token isn't configured. + build_headers_provider(_DOWNSTREAM_AUTH_MODE)() class StitchApiClient: @@ -29,11 +35,10 @@ def __init__( self._client = client return - headers_provider = env_bearer_token_headers_provider() self._client = AsyncStitchClient( base_url=_get_api_base_url(), timeout=30.0, - headers_provider=headers_provider, + headers_provider=build_headers_provider(_DOWNSTREAM_AUTH_MODE), ) async def __aenter__(self) -> "StitchApiClient": diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/entities.py b/deployments/entity-linkage/src/stitch/entity_linkage/entities.py index b4c2e550..a8864849 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/entities.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/entities.py @@ -1,10 +1,13 @@ -from dataclasses import dataclass from datetime import datetime from math import ceil from typing import Literal -from pydantic import BaseModel, EmailStr, Field, computed_field +from pydantic import BaseModel, Field, computed_field +# Identity is shared scaffolding now; re-exported here so existing imports +# (`from stitch.entity_linkage.entities import User, RequestAuthContext`) keep +# working. +from stitch.service.auth import RequestAuthContext, ServiceUser as User from stitch.ogsi.model.types import ( FieldStatus, LocationType, @@ -13,34 +16,27 @@ ProductionConventionality, ) +__all__ = [ + "FieldCandidate", + "FieldDetailCandidate", + "MatchGroup", + "OGFieldFilterParams", + "OGFieldQueryParams", + "OGFieldSortParams", + "PaginatedResponse", + "PaginationParams", + "RequestAuthContext", + "SortableField", + "Timestamped", + "User", +] + class Timestamped(BaseModel): created: datetime = Field(default_factory=datetime.now) updated: datetime = Field(default_factory=datetime.now) -class User(BaseModel): - id: int = Field(...) - sub: str = Field(...) - role: str | None = None - email: EmailStr - name: str - - -@dataclass(frozen=True, slots=True) -class RequestAuthContext: - """ - Request-scoped auth context for inbound request identity. - - not implemented: - - re-enable downstream relay or OBO auth as an explicit client mode - - keep user attribution/provenance as separate metadata - """ - - user: User - bearer_token: str | None - - class FieldCandidate(BaseModel): id: int name: str | None = None diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py b/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py new file mode 100644 index 00000000..16a52bd6 --- /dev/null +++ b/deployments/entity-linkage/src/stitch/entity_linkage/linkage.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from collections import defaultdict + +from pydantic import BaseModel, Field + +from stitch.entity_linkage.client import StitchApiClient +from stitch.entity_linkage.entities import FieldCandidate, MatchGroup + + +class LinkageParams(BaseModel): + """Tunable inputs for one entity-linkage pass. + + Doubles as the ``POST /start`` request body and the params stored on the + job record. + """ + + apply_merges: bool = Field( + default=False, + description=( + "When true, submit confirmed match groups to the Stitch API as " + "merge candidates." + ), + ) + page: int = Field(default=1, ge=1) + page_size: int = Field(default=50, ge=1, le=200) + max_pages: int | None = Field( + default=None, + ge=1, + le=1000, + description="Optional cap on pages fetched. Null means fetch all pages.", + ) + + +class LinkageResult(BaseModel): + """Summary of a completed entity-linkage pass.""" + + pages_fetched: int + total_records_fetched: int + duplicate_name_candidate_count: int + detail_records_fetched: int + match_groups: list[list[int]] + merge_results: list[dict] + + +def _group_duplicate_names( + items: list[FieldCandidate], +) -> dict[str, list[FieldCandidate]]: + grouped: dict[str, list[FieldCandidate]] = defaultdict(list) + for item in items: + if item.normalized_name is None: + continue + grouped[item.normalized_name].append(item) + return { + normalized_name: grouped_items + for normalized_name, grouped_items in grouped.items() + if len(grouped_items) > 1 + } + + +def _normalize_country(country: str | None) -> str | None: + if country is None: + return None + normalized = country.strip().upper() + return normalized or None + + +async def _resolve_match_groups( + client: StitchApiClient, + duplicate_groups: dict[str, list[FieldCandidate]], +) -> tuple[list[MatchGroup], int]: + match_groups: list[MatchGroup] = [] + detail_records_fetched = 0 + + for normalized_name, candidates in duplicate_groups.items(): + by_country: dict[str, list[int]] = defaultdict(list) + + for candidate in candidates: + detail = await client.get_oil_gas_field_detail(candidate.id) + detail_records_fetched += 1 + normalized_country = _normalize_country(detail.country) + if normalized_country is None: + continue + by_country[normalized_country].append(detail.id) + + for country, ids in by_country.items(): + if len(ids) > 1: + match_groups.append( + MatchGroup( + ids=sorted(ids), + normalized_name=normalized_name, + country=country, + ) + ) + + return match_groups, detail_records_fetched + + +async def run_linkage(params: LinkageParams) -> LinkageResult: + """Run one entity-linkage pass and return a summary. + + - fetch paginated oil-gas-fields list + - group exact case-insensitive duplicate names + - fetch detail records for candidate duplicates + - confirm same-country matches + - optionally submit merge candidates + + Invoked as the background job body by the ``JobManager`` in + :mod:`stitch.entity_linkage.routers.start`. Downstream failures + (``StitchAPIError``) propagate and are captured as a failed job, observable + via ``GET /status/{job_id}``. + """ + async with StitchApiClient() as client: + items, pages_fetched = await client.collect_oil_gas_fields( + start_page=params.page, + page_size=params.page_size, + max_pages=params.max_pages, + ) + duplicate_groups = _group_duplicate_names(items) + match_groups, detail_records_fetched = await _resolve_match_groups( + client=client, + duplicate_groups=duplicate_groups, + ) + + merge_results: list[dict] = [] + if params.apply_merges: + for group in match_groups: + response = await client.create_merge_candidate(resource_ids=group.ids) + merge_results.append({"ids": group.ids, "response": response}) + + return LinkageResult( + pages_fetched=pages_fetched, + total_records_fetched=len(items), + duplicate_name_candidate_count=sum( + len(group) for group in duplicate_groups.values() + ), + detail_records_fetched=detail_records_fetched, + match_groups=[group.ids for group in match_groups], + merge_results=merge_results, + ) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/main.py b/deployments/entity-linkage/src/stitch/entity_linkage/main.py index ab95c02f..eaa96b11 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/main.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/main.py @@ -1,7 +1,6 @@ -from contextlib import asynccontextmanager -from datetime import UTC, datetime -from fastapi import APIRouter, FastAPI -from .middleware import register_middlewares +from fastapi import FastAPI +from stitch.service import create_app + from .auth import validate_auth_config_at_startup from .client import validate_downstream_auth_config_at_startup from .settings import get_settings @@ -9,27 +8,22 @@ from .routers.health import router as health_router from .routers.start import router as start_router -base_router = APIRouter(prefix="/api/v1") -base_router.include_router(health_router) -base_router.include_router(start_router) - -@asynccontextmanager -async def lifespan(app: FastAPI): - app.state.started_at = datetime.now(UTC) +def _run_startup(app: FastAPI) -> None: app.state.auth_config_validated = False app.state.downstream_auth_config_validated = False validate_auth_config_at_startup() app.state.auth_config_validated = True validate_downstream_auth_config_at_startup() app.state.downstream_auth_config_validated = True - yield -app = FastAPI(lifespan=lifespan) - settings = get_settings() -register_middlewares(application=app, settings=settings) - -app.include_router(base_router) +app = create_app( + routers=[health_router, start_router], + cors_origins=[str(settings.frontend_origin_url)], + on_startup=_run_startup, + service_name="stitch-entity-linkage", + otel=settings, +) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py index 7f284bfb..c8195329 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/routers/start.py @@ -1,168 +1,43 @@ from __future__ import annotations -from collections import defaultdict +from datetime import timedelta -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, Field -from starlette.status import HTTP_502_BAD_GATEWAY +from fastapi import Depends from stitch.auth.permissions import SERVICE_ENTITY_LINKAGE_RUN +from stitch.jobs import ( + FingerprintPolicy, + InMemoryJobStore, + JobManager, + make_job_router, +) -from stitch.entity_linkage.auth import AuthContext, require_permissions -from stitch.entity_linkage.client import StitchApiClient -from stitch.entity_linkage.entities import FieldCandidate, MatchGroup, User -from stitch.entity_linkage.errors import StitchAPIError - -router = APIRouter(tags=["entity-linkage"]) - - -class StartRequest(BaseModel): - apply_merges: bool = Field( - default=False, - description=( - "When true, submit confirmed match groups to the Stitch API as " - "merge candidates." - ), - ) - page: int = Field(default=1, ge=1) - page_size: int = Field(default=50, ge=1, le=200) - max_pages: int | None = Field( - default=None, - ge=1, - le=1000, - description="Optional cap on pages fetched. Null means fetch all pages.", - ) - - -class StartResponse(BaseModel): - initiated_by: str - apply_merges: bool - pages_fetched: int - total_records_fetched: int - duplicate_name_candidate_count: int - detail_records_fetched: int - match_groups: list[list[int]] - merge_results: list[dict] - - -def _extract_user_label(user: User) -> str: - return user.name or user.email or user.sub - - -def _group_duplicate_names( - items: list[FieldCandidate], -) -> dict[str, list[FieldCandidate]]: - grouped: dict[str, list[FieldCandidate]] = defaultdict(list) - for item in items: - if item.normalized_name is None: - continue - grouped[item.normalized_name].append(item) - return { - normalized_name: grouped_items - for normalized_name, grouped_items in grouped.items() - if len(grouped_items) > 1 - } - - -def _normalize_country(country: str | None) -> str | None: - if country is None: - return None - normalized = country.strip().upper() - return normalized or None - - -async def _resolve_match_groups( - client: StitchApiClient, - duplicate_groups: dict[str, list[FieldCandidate]], -) -> tuple[list[MatchGroup], int]: - match_groups: list[MatchGroup] = [] - detail_records_fetched = 0 - - for normalized_name, candidates in duplicate_groups.items(): - by_country: dict[str, list[int]] = defaultdict(list) - - for candidate in candidates: - detail = await client.get_oil_gas_field_detail(candidate.id) - detail_records_fetched += 1 - normalized_country = _normalize_country(detail.country) - if normalized_country is None: - continue - by_country[normalized_country].append(detail.id) +from stitch.entity_linkage.auth import initiated_by, require_permissions +from stitch.entity_linkage.linkage import LinkageParams, LinkageResult, run_linkage + +# Two requests are "the same" run when all tunable params match. Identical +# requests (same paging + apply_merges) collapse onto one job — so a second +# user sees the in-flight run, and reuses its result for `recent_within` after +# it finishes — while different params run independently. +# Reuse an identical run for 24h. Retention must cover the reuse window, else +# terminal records would be evicted before they could be reused. +_REUSE_WINDOW = timedelta(hours=24) +_manager: JobManager[LinkageParams, LinkageResult] = JobManager( + run_linkage, + policy=FingerprintPolicy(), + recent_within=_REUSE_WINDOW, + store=InMemoryJobStore(retention=_REUSE_WINDOW), +) - for country, ids in by_country.items(): - if len(ids) > 1: - match_groups.append( - MatchGroup( - ids=sorted(ids), - normalized_name=normalized_name, - country=country, - ) - ) - return match_groups, detail_records_fetched +def get_job_manager() -> JobManager[LinkageParams, LinkageResult]: + return _manager -@router.post( - "/start", - response_model=StartResponse, +router = make_job_router( + _manager, + params_model=LinkageParams, + result_model=LinkageResult, dependencies=[Depends(require_permissions(SERVICE_ENTITY_LINKAGE_RUN))], + initiated_by=initiated_by, + tags=["entity-linkage"], ) -async def start( - request: StartRequest, - auth_context: AuthContext, -) -> StartResponse: - """ - In-memory entity-linkage pass: - - fetch paginated oil-gas-fields list - - group exact case-insensitive duplicate names - - fetch detail records for candidate duplicates - - confirm same-country matches - - optionally submit merge candidates - - Not implemented: - - add concurrency controls for detail fetches - - add stronger second-phase inspection beyond country equality - - add alternate downstream auth modes beyond env-token auth - """ - try: - async with StitchApiClient() as client: - items, pages_fetched = await client.collect_oil_gas_fields( - start_page=request.page, - page_size=request.page_size, - max_pages=request.max_pages, - ) - duplicate_groups = _group_duplicate_names(items) - match_groups, detail_records_fetched = await _resolve_match_groups( - client=client, - duplicate_groups=duplicate_groups, - ) - - merge_results: list[dict] = [] - if request.apply_merges: - for group in match_groups: - response = await client.create_merge_candidate( - resource_ids=group.ids - ) - merge_results.append( - { - "ids": group.ids, - "response": response, - } - ) - except StitchAPIError as exc: - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - - return StartResponse( - initiated_by=_extract_user_label(auth_context.user), - apply_merges=request.apply_merges, - pages_fetched=pages_fetched, - total_records_fetched=len(items), - duplicate_name_candidate_count=sum( - len(group) for group in duplicate_groups.values() - ), - detail_records_fetched=detail_records_fetched, - match_groups=[group.ids for group in match_groups], - merge_results=merge_results, - ) diff --git a/deployments/entity-linkage/src/stitch/entity_linkage/settings.py b/deployments/entity-linkage/src/stitch/entity_linkage/settings.py index 34dfe81a..b800c9b8 100644 --- a/deployments/entity-linkage/src/stitch/entity_linkage/settings.py +++ b/deployments/entity-linkage/src/stitch/entity_linkage/settings.py @@ -2,10 +2,11 @@ from typing import ClassVar from pydantic import AnyHttpUrl, Field -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import SettingsConfigDict +from stitch.observability import OTelSettings -class Settings(BaseSettings): +class Settings(OTelSettings): log_level: str = Field(default="INFO", alias="ENTITY_LINKAGE_LOG_LEVEL") frontend_origin_url: AnyHttpUrl = Field( default="http://localhost:3000", diff --git a/deployments/entity-linkage/tests/conftest.py b/deployments/entity-linkage/tests/conftest.py new file mode 100644 index 00000000..5c53fe0a --- /dev/null +++ b/deployments/entity-linkage/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/deployments/entity-linkage/tests/test_downstream_auth.py b/deployments/entity-linkage/tests/test_downstream_auth.py new file mode 100644 index 00000000..206c11a1 --- /dev/null +++ b/deployments/entity-linkage/tests/test_downstream_auth.py @@ -0,0 +1,31 @@ +"""Entity-linkage authenticates downstream with its own machine identity. + +It runs its work in a detached background job, so the caller's token is gone by +the time the run executes — passthrough is not an option here. +""" + +import pytest +from stitch.client.auth import STITCH_CLIENT_BEARER_TOKEN_ENV_VAR +from stitch.service.auth import AuthMode + +from stitch.entity_linkage import client as client_module + + +def test_downstream_uses_machine_identity() -> None: + assert client_module._DOWNSTREAM_AUTH_MODE is AuthMode.machine + + +def test_validate_downstream_requires_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, raising=False) + with pytest.raises(ValueError): + client_module.validate_downstream_auth_config_at_startup() + + +def test_validate_downstream_passes_with_machine_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv(STITCH_CLIENT_BEARER_TOKEN_ENV_VAR, "machine-tok") + # Should not raise. + client_module.validate_downstream_auth_config_at_startup() diff --git a/deployments/entity-linkage/tests/test_start.py b/deployments/entity-linkage/tests/test_start.py index 828e05cf..f725081e 100644 --- a/deployments/entity-linkage/tests/test_start.py +++ b/deployments/entity-linkage/tests/test_start.py @@ -3,45 +3,23 @@ from contextlib import AbstractAsyncContextManager import pytest -from fastapi import HTTPException +from stitch.entity_linkage import linkage as linkage_module from stitch.entity_linkage.entities import ( FieldCandidate, FieldDetailCandidate, MatchGroup, - RequestAuthContext, - User, ) from stitch.entity_linkage.errors import StitchAPIError -from stitch.entity_linkage.routers import start as start_module -from stitch.entity_linkage.routers.start import ( - StartRequest, - _extract_user_label, +from stitch.entity_linkage.linkage import ( + LinkageParams, _group_duplicate_names, _normalize_country, _resolve_match_groups, - start, + run_linkage, ) -def make_auth_context( - *, - name: str = "Test User", - email: str = "test@example.com", - sub: str = "auth0|user-123", - bearer_token: str | None = "token-123", -) -> RequestAuthContext: - return RequestAuthContext( - user=User( - id=1, - sub=sub, - email=email, - name=name, - ), - bearer_token=bearer_token, - ) - - class FakeStitchApiClient(AbstractAsyncContextManager["FakeStitchApiClient"]): def __init__( self, @@ -80,11 +58,7 @@ async def collect_oil_gas_fields( max_pages: int | None = None, ) -> tuple[list[FieldCandidate], int]: self.collect_calls.append( - { - "start_page": start_page, - "page_size": page_size, - "max_pages": max_pages, - } + {"start_page": start_page, "page_size": page_size, "max_pages": max_pages} ) if self.collect_error is not None: raise self.collect_error @@ -117,23 +91,6 @@ def test_normalize_country(country: str | None, expected: str | None) -> None: assert _normalize_country(country) == expected -def test_extract_user_label_prefers_name_then_email_then_sub() -> None: - assert ( - _extract_user_label( - User(id=1, sub="sub-1", email="a@example.com", name="Alice") - ) - == "Alice" - ) - assert ( - _extract_user_label(User(id=1, sub="sub-2", email="b@example.com", name="")) - == "b@example.com" - ) - assert ( - _extract_user_label(User(id=1, sub="sub-3", email="c@example.com", name="")) - != "sub-3" - ) - - def test_group_duplicate_names_uses_casefold_and_strips_whitespace() -> None: items = [ FieldCandidate(id=1, name="Alpha", country="US"), @@ -189,10 +146,9 @@ async def test_resolve_match_groups_groups_only_same_country_duplicates() -> Non @pytest.mark.anyio -async def test_start_returns_summary_without_merges( +async def test_run_linkage_returns_summary_without_merges( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context(name="Alex Reviewer") fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="US"), @@ -208,21 +164,18 @@ async def test_start_returns_summary_without_merges( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=False, page=2, page_size=25, max_pages=4), - auth_context=auth_context, + result = await run_linkage( + LinkageParams(apply_merges=False, page=2, page_size=25, max_pages=4) ) - assert response.initiated_by == "Alex Reviewer" - assert response.apply_merges is False - assert response.pages_fetched == 3 - assert response.total_records_fetched == 4 - assert response.duplicate_name_candidate_count == 3 - assert response.detail_records_fetched == 3 - assert response.match_groups == [[1, 3]] - assert response.merge_results == [] + assert result.pages_fetched == 3 + assert result.total_records_fetched == 4 + assert result.duplicate_name_candidate_count == 3 + assert result.detail_records_fetched == 3 + assert result.match_groups == [[1, 3]] + assert result.merge_results == [] assert fake_client.collect_calls == [ {"start_page": 2, "page_size": 25, "max_pages": 4} @@ -231,10 +184,9 @@ async def test_start_returns_summary_without_merges( @pytest.mark.anyio -async def test_start_applies_merges_for_each_match_group( +async def test_run_linkage_applies_merges_for_each_match_group( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="ignored"), @@ -254,26 +206,22 @@ async def test_start_applies_merges_for_each_match_group( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=True), - auth_context=auth_context, - ) + result = await run_linkage(LinkageParams(apply_merges=True)) - assert response.match_groups == [[1, 2], [3, 4]] + assert result.match_groups == [[1, 2], [3, 4]] assert fake_client.merge_calls == [[1, 2], [3, 4]] - assert response.merge_results == [ + assert result.merge_results == [ {"ids": [1, 2], "response": {"merged_ids": [1, 2], "winner": 1}}, {"ids": [3, 4], "response": {"merged_ids": [3, 4], "winner": 3}}, ] @pytest.mark.anyio -async def test_start_returns_no_matches_when_duplicate_names_do_not_confirm( +async def test_run_linkage_returns_no_matches_when_duplicate_names_do_not_confirm( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() fake_client = FakeStitchApiClient( items=[ FieldCandidate(id=1, name="Alpha", country="ignored"), @@ -286,36 +234,30 @@ async def test_start_returns_no_matches_when_duplicate_names_do_not_confirm( }, ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - response = await start( - StartRequest(apply_merges=True), - auth_context=auth_context, - ) + result = await run_linkage(LinkageParams(apply_merges=True)) - assert response.duplicate_name_candidate_count == 2 - assert response.detail_records_fetched == 2 - assert response.match_groups == [] - assert response.merge_results == [] + assert result.duplicate_name_candidate_count == 2 + assert result.detail_records_fetched == 2 + assert result.match_groups == [] + assert result.merge_results == [] assert fake_client.merge_calls == [] @pytest.mark.anyio -async def test_start_translates_stitch_api_error_to_502( +async def test_run_linkage_propagates_stitch_api_error( monkeypatch: pytest.MonkeyPatch, ) -> None: - auth_context = make_auth_context() + # In the job model, downstream errors propagate out of run_linkage and are + # captured by the JobManager as a failed job (no synchronous 502). fake_client = FakeStitchApiClient( collect_error=StitchAPIError( "GET /oil-gas-fields/ failed with status 500: boom" ), ) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: fake_client) - - with pytest.raises(HTTPException) as exc_info: - await start(StartRequest(), auth_context=auth_context) + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: fake_client) - exc = exc_info.value - assert exc.status_code == 502 - assert exc.detail == "GET /oil-gas-fields/ failed with status 500: boom" + with pytest.raises(StitchAPIError): + await run_linkage(LinkageParams()) diff --git a/deployments/entity-linkage/tests/test_start_api.py b/deployments/entity-linkage/tests/test_start_api.py index 0f1587e5..f5b5623d 100644 --- a/deployments/entity-linkage/tests/test_start_api.py +++ b/deployments/entity-linkage/tests/test_start_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from contextlib import AbstractAsyncContextManager import pytest @@ -7,6 +8,9 @@ from stitch.auth import TokenClaims from stitch.auth.permissions import SERVICE_ENTITY_LINKAGE_RUN +import stitch.entity_linkage.main as main_module +from stitch.entity_linkage import linkage as linkage_module +from stitch.entity_linkage.auth import get_request_auth_context, get_token_claims from stitch.entity_linkage.entities import ( FieldCandidate, FieldDetailCandidate, @@ -16,9 +20,7 @@ from stitch.entity_linkage.errors import StitchAPIError from stitch.entity_linkage.main import app from stitch.entity_linkage.routers import health as health_module -from stitch.entity_linkage.routers import start as start_module -from stitch.entity_linkage.auth import get_request_auth_context, get_token_claims -from stitch.entity_linkage import main as main_module +from stitch.entity_linkage.routers.start import get_job_manager def make_auth_context( @@ -29,12 +31,7 @@ def make_auth_context( bearer_token: str | None = "integration-token", ) -> RequestAuthContext: return RequestAuthContext( - user=User( - id=1, - sub=sub, - email=email, - name=name, - ), + user=User(id=1, sub=sub, email=email, name=name), bearer_token=bearer_token, ) @@ -84,11 +81,7 @@ async def collect_oil_gas_fields( max_pages: int | None = None, ) -> tuple[list[FieldCandidate], int]: self.collect_calls.append( - { - "start_page": start_page, - "page_size": page_size, - "max_pages": max_pages, - } + {"start_page": start_page, "page_size": page_size, "max_pages": max_pages} ) if self.collect_error is not None: raise self.collect_error @@ -113,6 +106,14 @@ async def get_auth_me(self) -> dict: return self.auth_me_response +@pytest.fixture(autouse=True) +def reset_job_manager(): + """Each test starts with a clean, isolated job store.""" + get_job_manager().reset() + yield + get_job_manager().reset() + + @pytest.fixture def auth_context() -> RequestAuthContext: return make_auth_context() @@ -144,7 +145,9 @@ def install( merge_error=merge_error, ) created_clients.append(client) - monkeypatch.setattr(start_module, "StitchApiClient", lambda: client) + # The job runs run_linkage in the background, which constructs the + # client from the linkage module's namespace. + monkeypatch.setattr(linkage_module, "StitchApiClient", lambda: client) return client return install, created_clients @@ -177,6 +180,16 @@ def override_token_claims() -> TokenClaims: app.dependency_overrides.clear() +def _poll(client: TestClient, job_id: str, *, timeout: float = 5.0) -> dict: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + body = client.get(f"/api/v1/status/{job_id}").json() + if body["state"] != "running": + return body + time.sleep(0.02) + raise AssertionError("job did not finish within timeout") + + def test_post_start_requires_service_permission( auth_context: RequestAuthContext, monkeypatch: pytest.MonkeyPatch, @@ -203,7 +216,7 @@ def override_token_claims() -> TokenClaims: assert SERVICE_ENTITY_LINKAGE_RUN in response.json()["detail"] -def test_post_start_returns_serialized_response_model( +def test_post_start_accepts_job_and_status_reports_result( test_client: TestClient, api_client_factory, ) -> None: @@ -221,20 +234,27 @@ def test_post_start_returns_serialized_response_model( }, ) - response = test_client.post( + started = test_client.post( "/api/v1/start", - json={ - "apply_merges": False, - "page": 3, - "page_size": 25, - "max_pages": 7, - }, + json={"apply_merges": False, "page": 3, "page_size": 25, "max_pages": 7}, ) - assert response.status_code == 200 - assert response.json() == { - "initiated_by": "Integration Tester", + assert started.status_code == 202 + body = started.json() + assert body["state"] == "running" + assert body["initiated_by"] == "Integration Tester" + assert body["params"] == { "apply_merges": False, + "page": 3, + "page_size": 25, + "max_pages": 7, + } + + final = _poll(test_client, body["job_id"]) + assert final["state"] == "succeeded" + assert final["error"] is None + assert final["finished_at"] is not None + assert final["result"] == { "pages_fetched": 2, "total_records_fetched": 3, "duplicate_name_candidate_count": 2, @@ -245,17 +265,13 @@ def test_post_start_returns_serialized_response_model( assert len(created_clients) == 1 assert fake_client.collect_calls == [ - { - "start_page": 3, - "page_size": 25, - "max_pages": 7, - } + {"start_page": 3, "page_size": 25, "max_pages": 7} ] assert fake_client.detail_calls == [1, 2] assert fake_client.merge_calls == [] -def test_post_start_applies_merges_and_returns_merge_results( +def test_post_start_applies_merges_and_reports_merge_results( test_client: TestClient, api_client_factory, ) -> None: @@ -279,21 +295,20 @@ def test_post_start_applies_merges_and_returns_merge_results( }, ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": True}, - ) + started = test_client.post("/api/v1/start", json={"apply_merges": True}) + assert started.status_code == 202 - assert response.status_code == 200 - assert response.json()["match_groups"] == [[10, 11], [20, 21]] - assert response.json()["merge_results"] == [ + final = _poll(test_client, started.json()["job_id"]) + assert final["state"] == "succeeded" + assert final["result"]["match_groups"] == [[10, 11], [20, 21]] + assert final["result"]["merge_results"] == [ {"ids": [10, 11], "response": {"merged_ids": [10, 11], "winner": 10}}, {"ids": [20, 21], "response": {"merged_ids": [20, 21], "winner": 20}}, ] assert fake_client.merge_calls == [[10, 11], [20, 21]] -def test_post_start_returns_empty_matches_when_country_check_does_not_confirm( +def test_post_start_reports_empty_matches_when_country_check_does_not_confirm( test_client: TestClient, api_client_factory, ) -> None: @@ -310,19 +325,17 @@ def test_post_start_returns_empty_matches_when_country_check_does_not_confirm( }, ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": True}, - ) + started = test_client.post("/api/v1/start", json={"apply_merges": True}) + final = _poll(test_client, started.json()["job_id"]) - assert response.status_code == 200 - assert response.json()["duplicate_name_candidate_count"] == 2 - assert response.json()["detail_records_fetched"] == 2 - assert response.json()["match_groups"] == [] - assert response.json()["merge_results"] == [] + assert final["state"] == "succeeded" + assert final["result"]["duplicate_name_candidate_count"] == 2 + assert final["result"]["detail_records_fetched"] == 2 + assert final["result"]["match_groups"] == [] + assert final["result"]["merge_results"] == [] -def test_post_start_translates_stitch_api_error_to_502( +def test_job_records_failure_when_downstream_errors( test_client: TestClient, api_client_factory, ) -> None: @@ -333,46 +346,90 @@ def test_post_start_translates_stitch_api_error_to_502( ), ) - response = test_client.post( - "/api/v1/start", - json={"apply_merges": False}, + started = test_client.post("/api/v1/start", json={"apply_merges": False}) + assert started.status_code == 202 + + final = _poll(test_client, started.json()["job_id"]) + assert final["state"] == "failed" + assert final["result"] is None + assert "GET /oil-gas-fields/ failed with status 500: boom" in final["error"] + + +def test_second_caller_observes_existing_run( + test_client: TestClient, + api_client_factory, +) -> None: + install, _ = api_client_factory + install( + items=[ + FieldCandidate(id=1, name="Alpha", country="ignored"), + FieldCandidate(id=2, name="alpha", country="ignored"), + ], + details_by_id={ + 1: FieldDetailCandidate(id=1, name="Alpha", country="US"), + 2: FieldDetailCandidate(id=2, name="Alpha", country="US"), + }, ) - assert response.status_code == 502 - assert response.json() == { - "detail": "GET /oil-gas-fields/ failed with status 500: boom", - } + first = test_client.post("/api/v1/start", json={"apply_merges": False}) + job_id = first.json()["job_id"] + _poll(test_client, job_id) + # Same params within the reuse window → returns the existing run (200), not + # a fresh job. This is the cross-user "request already made" behavior. + second = test_client.post("/api/v1/start", json={"apply_merges": False}) + assert second.status_code == 200 + assert second.json()["job_id"] == job_id -def test_post_start_validates_request_body_constraints( + +def test_force_starts_a_new_run( test_client: TestClient, api_client_factory, ) -> None: install, _ = api_client_factory install( - items=[], - details_by_id={}, + items=[ + FieldCandidate(id=1, name="Alpha", country="ignored"), + FieldCandidate(id=2, name="alpha", country="ignored"), + ], + details_by_id={ + 1: FieldDetailCandidate(id=1, name="Alpha", country="US"), + 2: FieldDetailCandidate(id=2, name="Alpha", country="US"), + }, ) + first = test_client.post("/api/v1/start", json={"apply_merges": False}) + job_id = first.json()["job_id"] + _poll(test_client, job_id) + + forced = test_client.post( + "/api/v1/start", json={"apply_merges": False, "force": True} + ) + assert forced.status_code == 202 + assert forced.json()["job_id"] != job_id + _poll(test_client, forced.json()["job_id"]) + + +def test_post_start_validates_request_body_constraints( + test_client: TestClient, +) -> None: response = test_client.post( "/api/v1/start", - json={ - "apply_merges": False, - "page": 0, - "page_size": 500, - "max_pages": 0, - }, + json={"apply_merges": False, "page": 0, "page_size": 500, "max_pages": 0}, ) assert response.status_code == 422 detail = response.json()["detail"] - fields = {tuple(item["loc"]) for item in detail} assert ("body", "page") in fields assert ("body", "page_size") in fields assert ("body", "max_pages") in fields +def test_status_404_for_unknown_job(test_client: TestClient) -> None: + assert test_client.get("/api/v1/status/nope").status_code == 404 + + def test_health_details_reports_ready_when_downstream_auth_probe_succeeds( test_client: TestClient, monkeypatch: pytest.MonkeyPatch, diff --git a/deployments/stitch-frontend/src/components/JobResultList.jsx b/deployments/stitch-frontend/src/components/JobResultList.jsx new file mode 100644 index 00000000..58e78437 --- /dev/null +++ b/deployments/stitch-frontend/src/components/JobResultList.jsx @@ -0,0 +1,84 @@ +import { useState } from "react"; + +const STATE_STYLES = { + running: "border-warning/30 bg-warning-soft text-warning", + succeeded: "border-success/25 bg-success-soft text-success-strong", + failed: "border-danger/25 bg-danger-soft text-danger", +}; + +const DATE_FORMATTER = new Intl.DateTimeFormat(undefined, { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", +}); + +function StateBadge({ state }) { + if (!state) return null; + const classes = STATE_STYLES[state] ?? "border-line bg-surface text-ink"; + return ( + + {state} + + ); +} + +function formatStartedAt(value) { + if (!value) return "—"; + const date = new Date(value); + return Number.isNaN(date.getTime()) ? "—" : DATE_FORMATTER.format(date); +} + +// Collapsible list of job records, newest first, with the most recent expanded +// by default. Each service supplies `renderResult(record)` for the body. +export default function JobResultList({ records, renderResult }) { + // Per-item user overrides; absent → default (newest open, others collapsed). + // Derived rather than effect-driven so the newest run stays expanded as new + // runs arrive, without a set-state-in-effect. + const [overrides, setOverrides] = useState({}); + const newestId = records[0]?.job_id; + + if (!records.length) return null; + + const isOpen = (id) => overrides[id] ?? id === newestId; + + function toggle(id) { + setOverrides((prev) => ({ ...prev, [id]: !isOpen(id) })); + } + + return ( +
    + {records.map((record) => { + const open = isOpen(record.job_id); + return ( +
  1. + + {open && ( +
    + {renderResult(record)} +
    + )} +
  2. + ); + })} +
+ ); +} diff --git a/deployments/stitch-frontend/src/components/JobTriggerButton.jsx b/deployments/stitch-frontend/src/components/JobTriggerButton.jsx new file mode 100644 index 00000000..3b7ea5a7 --- /dev/null +++ b/deployments/stitch-frontend/src/components/JobTriggerButton.jsx @@ -0,0 +1,43 @@ +import Button from "./Button"; + +// Smart trigger for a job: its label reflects whether a result already exists +// and whether a forced re-run is requested, and it shows a spinner while the +// job is running/polling. +// +// labels: { running, show, create, recreate } +// - running → shown with a spinner while a run is in flight +// - show → a prior result exists and hasn't been revealed yet +// - recreate → force is toggled on (re-run) +// - create → no prior result; first run +export default function JobTriggerButton({ + running, + force, + hasExisting, + revealed, + labels, + onClick, + disabled = false, + variant = "secondary", +}) { + let label; + if (running) label = labels.running; + else if (force) label = labels.recreate; + else if (hasExisting && !revealed) label = labels.show; + else label = labels.create; + + return ( + + ); +} diff --git a/deployments/stitch-frontend/src/components/LastUpdated.jsx b/deployments/stitch-frontend/src/components/LastUpdated.jsx new file mode 100644 index 00000000..7a819131 --- /dev/null +++ b/deployments/stitch-frontend/src/components/LastUpdated.jsx @@ -0,0 +1,28 @@ +import { useEffect, useState } from "react"; + +function relativeLabel(at) { + const seconds = Math.max(0, Math.round((Date.now() - at) / 1000)); + if (seconds < 5) return "just now"; + if (seconds < 60) return `${seconds}s ago`; + const minutes = Math.round(seconds / 60); + if (minutes < 60) return `${minutes}m ago`; + return `${Math.round(minutes / 60)}h ago`; +} + +// Live "Updated N ago" indicator; re-renders on its own so the relative time +// keeps counting up after the last poll. +export default function LastUpdated({ at }) { + const [, tick] = useState(0); + + useEffect(() => { + if (!at) return undefined; + const id = setInterval(() => tick((n) => n + 1), 5000); + return () => clearInterval(id); + }, [at]); + + if (!at) return null; + + return ( + Updated {relativeLabel(at)} + ); +} diff --git a/deployments/stitch-frontend/src/hooks/useJobRunner.js b/deployments/stitch-frontend/src/hooks/useJobRunner.js new file mode 100644 index 00000000..87410e7b --- /dev/null +++ b/deployments/stitch-frontend/src/hooks/useJobRunner.js @@ -0,0 +1,142 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { findJobs, getJobStatus, startJob } from "../queries/jobs"; + +const POLL_INTERVAL_MS = 1000; + +function sortNewestFirst(records) { + return [...records].sort( + (a, b) => + new Date(b.started_at).getTime() - new Date(a.started_at).getTime(), + ); +} + +// Drives a Stitch job from the UI: loads the prior runs for the current params +// on mount, starts/auto-polls runs, and tracks the records (newest first). +// Shared by every job-shaped service (LLM, entity-linkage, ETL). +// +// - baseUrl: where the job routes live (POST /start, POST /find, GET /status). +// - fetcher: authenticated fetch wrapper (may change each render — captured by ref). +// - lookupBody: the request params (without `force`) used to look up existing +// runs via /find; the server filters by the same dedup policy as /start, so +// there's no fetch-everything-then-filter and no client/server filter drift. +export function useJobRunner({ baseUrl, fetcher, lookupBody }) { + const [records, setRecords] = useState([]); + const [isStarting, setIsStarting] = useState(false); + const [isPolling, setIsPolling] = useState(false); + const [error, setError] = useState(""); + const [lastUpdatedAt, setLastUpdatedAt] = useState(null); + + // Stable refs so the load effect doesn't churn on every parent re-render. + const fetcherRef = useRef(fetcher); + fetcherRef.current = fetcher; + const lookupRef = useRef(lookupBody); + lookupRef.current = lookupBody; + // Serialized lookup params double as the effect's reload key. + const lookupKey = JSON.stringify(lookupBody ?? null); + // Bumped whenever params change / on unmount, to cancel stale polls. + const generationRef = useRef(0); + + const upsert = useCallback((record) => { + setRecords((prev) => + sortNewestFirst([ + ...prev.filter((r) => r.job_id !== record.job_id), + record, + ]), + ); + setLastUpdatedAt(Date.now()); + }, []); + + const poll = useCallback( + async (jobId, generation) => { + setIsPolling(true); + try { + while (generationRef.current === generation) { + const record = await getJobStatus(baseUrl, jobId, fetcherRef.current); + if (generationRef.current !== generation) return; + upsert(record); + if (record.state !== "running") return; + await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS)); + } + } catch (err) { + if (generationRef.current === generation) { + setError(err.message || "Failed to check job status."); + } + } finally { + if (generationRef.current === generation) setIsPolling(false); + } + }, + [baseUrl, upsert], + ); + + // Load the runs for the current params on mount / when params change. + useEffect(() => { + generationRef.current += 1; + const generation = generationRef.current; + setRecords([]); + setError(""); + setIsPolling(false); + + (async () => { + try { + const mine = await findJobs( + baseUrl, + lookupRef.current ?? {}, + fetcherRef.current, + ); + if (generationRef.current !== generation) return; + const sorted = sortNewestFirst(mine); + setRecords(sorted); + setLastUpdatedAt(Date.now()); + const running = sorted.find((r) => r.state === "running"); + if (running) poll(running.job_id, generation); + } catch { + // No prior runs (or lookup unavailable) — start from a clean slate. + if (generationRef.current === generation) setRecords([]); + } + })(); + + return () => { + generationRef.current += 1; // cancel any in-flight poll for this generation + }; + }, [baseUrl, lookupKey, poll]); + + const start = useCallback( + async (body) => { + setIsStarting(true); + setError(""); + const generation = generationRef.current; + try { + const record = await startJob(baseUrl, body, fetcherRef.current); + if (generationRef.current !== generation) return record; + upsert(record); + if (record.state === "running") poll(record.job_id, generation); + return record; + } catch (err) { + setError(err.message || "Failed to start job."); + return null; + } finally { + setIsStarting(false); + } + }, + [baseUrl, poll, upsert], + ); + + // Known behavior: `current` is the newest run by start time (which drives the + // running/spinner state), while `latestSucceeded` (used for results/persist) + // is the newest succeeded run. With force re-runs these can differ briefly. + const current = records[0] ?? null; + const latestSucceeded = records.find((r) => r.state === "succeeded") ?? null; + + return { + records, + current, + latestSucceeded, + hasExisting: records.length > 0, + isRunning: isStarting || isPolling || current?.state === "running", + isStarting, + isPolling, + error, + lastUpdatedAt, + start, + }; +} diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx index ff7ee1d2..ef4258f9 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.jsx @@ -1,8 +1,12 @@ import { useState } from "react"; import { useAuth0 } from "@auth0/auth0-react"; import { useConfig } from "../config/useConfig"; +import { createAuthenticatedFetcher } from "../auth/api"; +import { useJobRunner } from "../hooks/useJobRunner"; +import JobTriggerButton from "../components/JobTriggerButton"; +import JobResultList from "../components/JobResultList"; +import LastUpdated from "../components/LastUpdated"; import StructuredDataView from "../components/StructuredDataView"; -import Button from "../components/Button"; function formatCount(count, singular, plural = `${singular}s`) { return `${count} ${count === 1 ? singular : plural}`; @@ -66,10 +70,6 @@ function RunResult({ result }) { const matchGroups = getMatchGroups(result); const details = getResultDetails(result); - if (!result) { - return

No run has completed yet.

; - } - return (
@@ -96,59 +96,33 @@ function RunResult({ result }) { export default function EntityLinkagePage() { const config = useConfig(); const { getAccessTokenSilently } = useAuth0(); + const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); const [applyMerges, setApplyMerges] = useState(false); - const [loading, setLoading] = useState(false); - const [result, setResult] = useState(null); - const [error, setError] = useState(null); - - async function handleStart() { - setLoading(true); - setError(null); - setResult(null); - - try { - const token = await getAccessTokenSilently({ - authorizationParams: { audience: config.auth0.audience }, - }); - - const response = await fetch(`${config.entityLinkageBaseUrl}/start`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - apply_merges: applyMerges, - }), - }); - - const text = await response.text(); - - let parsed; - try { - parsed = text ? JSON.parse(text) : null; - } catch { - parsed = { raw: text }; - } - - if (!response.ok) { - setError({ - status: response.status, - body: parsed, - }); - return; - } - - setResult(parsed); - } catch (err) { - setError({ - status: null, - body: err instanceof Error ? err.message : String(err), - }); - } finally { - setLoading(false); + const [forceRerun, setForceRerun] = useState(false); + const [revealed, setRevealed] = useState(false); + + const job = useJobRunner({ + baseUrl: config.entityLinkageBaseUrl, + fetcher, + lookupBody: { apply_merges: applyMerges }, + }); + + function handleToggleApplyMerges(event) { + setApplyMerges(event.target.checked); + setForceRerun(false); + setRevealed(false); + } + + async function handleTrigger() { + // A recent run with these params exists and we're not forcing → reveal it. + if (job.hasExisting && !forceRerun && !revealed) { + setRevealed(true); + return; } + setRevealed(true); + await job.start({ apply_merges: applyMerges, force: forceRerun }); + setForceRerun(false); } return ( @@ -159,43 +133,80 @@ export default function EntityLinkagePage() {

Entity Linkage

- Start an entity-linkage run and review the result. + Start an entity-linkage run and review the result. A run already in + progress (or recently completed) for the same options is shared rather + than started again.

-
+
-
- +
+
+ + +
+
- {error ? ( -
-

Run error

-
- -
-
- ) : null} - -
-

Run result

-
- + {job.error && ( +
+ {job.error}
-
+ )} + + {revealed && ( +
+

Runs

+ + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Run failed."} +

+ ) : ( +

Running…

+ ) + } + /> +
+ )}
); } diff --git a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx index cc92aa93..2c96fbae 100644 --- a/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EntityLinkagePage.test.jsx @@ -3,38 +3,50 @@ import { screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { useAuth0 } from "@auth0/auth0-react"; import EntityLinkagePage from "./EntityLinkagePage"; +import * as jobsModule from "../queries/jobs"; import { auth0TestDefaults, renderWithQueryClient } from "../test/utils"; -describe("EntityLinkagePage", () => { - let getAccessTokenSilently; +const RUNNING_RECORD = { + job_id: "job-123", + state: "running", + initiated_by: "Test User", + params: { apply_merges: false, page: 1, page_size: 50, max_pages: null }, + started_at: "2026-01-01T00:00:00Z", + finished_at: null, + result: null, + error: null, +}; + +const SUCCEEDED_RECORD = { + ...RUNNING_RECORD, + state: "succeeded", + finished_at: "2026-01-01T00:00:05Z", + result: { + pages_fetched: 1, + total_records_fetched: 4, + duplicate_name_candidate_count: 4, + detail_records_fetched: 4, + match_groups: [ + [101, 102], + [203, 204, 205], + ], + merge_results: [], + }, +}; +describe("EntityLinkagePage", () => { beforeEach(() => { - getAccessTokenSilently = vi.fn().mockResolvedValue("test-access-token"); - vi.mocked(useAuth0).mockReturnValue({ - ...auth0TestDefaults, - getAccessTokenSilently, - }); + vi.clearAllMocks(); + vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); + // Default: no prior runs (loaded on mount). + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([]); }); - it("renders match groups as visually separated groups", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 200, - text: async () => - JSON.stringify({ - initiated_by: "Test User", - apply_merges: false, - pages_fetched: 1, - total_records_fetched: 4, - duplicate_name_candidate_count: 4, - detail_records_fetched: 4, - match_groups: [ - [101, 102], - [203, 204, 205], - ], - merge_results: [], - }), - }); + it("starts a run, auto-polls, and renders the completed result", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(RUNNING_RECORD); + vi.spyOn(jobsModule, "getJobStatus").mockResolvedValue(SUCCEEDED_RECORD); renderWithQueryClient(); @@ -45,21 +57,77 @@ describe("EntityLinkagePage", () => { screen.getByRole("heading", { name: "Match groups" }), ).toBeInTheDocument(); }); - expect(screen.getByText("2 groups")).toBeInTheDocument(); + expect(screen.getByText("Resource 101")).toBeInTheDocument(); + expect(screen.getByText("Resource 205")).toBeInTheDocument(); + + // start body carries apply_merges + force; auto-poll happened (no manual refresh). + expect(startSpy).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ apply_merges: false, force: false }), + expect.anything(), + ); + expect(jobsModule.getJobStatus).toHaveBeenCalled(); expect( - screen.getByRole("heading", { name: "Match group 1" }), - ).toBeInTheDocument(); + screen.queryByRole("button", { name: /refresh status/i }), + ).not.toBeInTheDocument(); + }); + + it("offers 'Show result' for a recent run and reveals it without re-running", async () => { + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([SUCCEEDED_RECORD]); + const startSpy = vi.spyOn(jobsModule, "startJob"); + + renderWithQueryClient(); + + const showButton = await screen.findByRole("button", { + name: /show result/i, + }); + await userEvent.click(showButton); + expect( - screen.getByRole("heading", { name: "Match group 2" }), + await screen.findByRole("heading", { name: "Match groups" }), ).toBeInTheDocument(); - expect(screen.getByText("Resource 101")).toBeInTheDocument(); - expect(screen.getByText("Resource 205")).toBeInTheDocument(); - expect(getAccessTokenSilently).toHaveBeenCalledWith({ - authorizationParams: { audience: "https://stitch-api.local" }, + expect(startSpy).not.toHaveBeenCalled(); + }); + + it("forces a re-run when Re-run is checked", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(SUCCEEDED_RECORD); + + renderWithQueryClient(); + + await userEvent.click(screen.getByRole("checkbox", { name: /re-run/i })); + await userEvent.click(screen.getByRole("button", { name: "Re-run" })); + + await waitFor(() => { + expect( + screen.getByRole("heading", { name: "Match groups" }), + ).toBeInTheDocument(); }); - expect(getAccessTokenSilently.mock.invocationCallOrder[0]).toBeLessThan( - fetch.mock.invocationCallOrder[0], + expect(startSpy).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ apply_merges: false, force: true }), + expect.anything(), ); }); + + it("surfaces a failed run", async () => { + vi.spyOn(jobsModule, "startJob").mockResolvedValue({ + ...RUNNING_RECORD, + state: "failed", + finished_at: "2026-01-01T00:00:05Z", + error: "GET /oil-gas-fields/ failed with status 500: boom", + }); + + renderWithQueryClient(); + + await userEvent.click(screen.getByRole("button", { name: "Start run" })); + + await waitFor(() => { + expect( + screen.getByText("GET /oil-gas-fields/ failed with status 500: boom"), + ).toBeInTheDocument(); + }); + }); }); diff --git a/deployments/stitch-frontend/src/pages/EtlPage.jsx b/deployments/stitch-frontend/src/pages/EtlPage.jsx index cbafee6f..75af0123 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.jsx @@ -1,10 +1,18 @@ import { useState } from "react"; import { useAuth0 } from "@auth0/auth0-react"; import { useConfig } from "../config/useConfig"; +import { createAuthenticatedFetcher } from "../auth/api"; +import { useJobRunner } from "../hooks/useJobRunner"; +import JobTriggerButton from "../components/JobTriggerButton"; +import JobResultList from "../components/JobResultList"; +import LastUpdated from "../components/LastUpdated"; import StructuredDataView from "../components/StructuredDataView"; -import Button from "../components/Button"; import Input from "../components/Input"; +// NOTE: the ETL services aren't on the shared `stitch-jobs` framework yet. This +// UI targets that contract (POST /start, GET /status/{job_id}, GET /jobs) so it +// lights up once the backend adopts it. + // Per-ETL run parameters. Empty number/text fields are omitted from the // request body so the service falls back to its env-derived defaults. const GEM_FIELDS = [ @@ -43,37 +51,7 @@ const WOODMAC_FIELDS = [ }, ]; -const STATE_STYLES = { - running: "border-warning/30 bg-warning-soft text-warning", - succeeded: "border-success/25 bg-success-soft text-success-strong", - failed: "border-danger/25 bg-danger-soft text-danger", -}; - -function StateBadge({ state }) { - if (!state) return null; - - const classes = STATE_STYLES[state] ?? "border-line bg-surface text-ink"; - - return ( - - {state} - - ); -} - -async function parseJsonResponse(response) { - const text = await response.text(); - - try { - return text ? JSON.parse(text) : null; - } catch { - return { raw: text }; - } -} - -function EtlPanel({ title, description, baseUrl, fields, getToken }) { +function EtlPanel({ title, description, baseUrl, fields, fetcher }) { const [values, setValues] = useState(() => Object.fromEntries( fields.map((field) => [ @@ -82,17 +60,20 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { ]), ), ); - const [starting, setStarting] = useState(false); - const [refreshing, setRefreshing] = useState(false); - const [record, setRecord] = useState(null); - const [error, setError] = useState(null); + const [forceRerun, setForceRerun] = useState(false); + const [revealed, setRevealed] = useState(false); + + // Look up this pipeline's runs with default params (a stable key, so editing + // the tunable fields doesn't refetch on every keystroke). The pipeline is its + // own service, so /find returns its runs per the backend's dedup policy. + const job = useJobRunner({ baseUrl, fetcher, lookupBody: {} }); function setField(key, value) { setValues((prev) => ({ ...prev, [key]: value })); } function buildRequestBody() { - const body = {}; + const body = { force: forceRerun }; for (const field of fields) { const value = values[field.key]; @@ -107,81 +88,22 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { return body; } - async function handleStart() { - setStarting(true); - setError(null); - - try { - const token = await getToken(); - - const response = await fetch(`${baseUrl}/start`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(buildRequestBody()), - }); - - const parsed = await parseJsonResponse(response); - - if (response.status === 409) { - setError({ - status: 409, - message: "A run is already in progress — refresh status to check.", - body: parsed, - }); - return; - } - - if (!response.ok) { - setError({ status: response.status, body: parsed }); - return; - } - - setRecord(parsed); - } catch (err) { - setError({ - status: null, - body: err instanceof Error ? err.message : String(err), - }); - } finally { - setStarting(false); + async function handleTrigger() { + // A recent run exists and we're not forcing → just reveal it. + if (job.hasExisting && !forceRerun && !revealed) { + setRevealed(true); + return; } + setRevealed(true); + await job.start(buildRequestBody()); + setForceRerun(false); } - async function handleRefresh() { - setRefreshing(true); - setError(null); - - try { - // GET /status is unauthenticated per the ETL OpenAPI spec. - const response = await fetch(`${baseUrl}/status`); - const parsed = await parseJsonResponse(response); - - if (!response.ok) { - setError({ status: response.status, body: parsed }); - return; - } - - setRecord(parsed); - } catch (err) { - setError({ - status: null, - body: err instanceof Error ? err.message : String(err), - }); - } finally { - setRefreshing(false); - } - } - - const isRunning = record?.state === "running"; - return (

{title}

- +

{description}

@@ -194,6 +116,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { type="checkbox" checked={values[field.key]} onChange={(e) => setField(field.key, e.target.checked)} + disabled={job.isRunning} className="accent-primary" /> {field.label} @@ -209,6 +132,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { onChange={(e) => setField(field.key, e.target.value)} placeholder={field.placeholder} min={field.type === "number" ? 1 : undefined} + disabled={job.isRunning} className="w-full" /> @@ -220,43 +144,62 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { )}
-
- - + /> +
- {error ? ( + {job.error ? (
- {error.message ? ( -

{error.message}

- ) : null} - + {job.error}
) : null}
-

Run status

- {record ? ( - +

Runs

+ {revealed && job.records.length ? ( + + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Run failed."} +

+ ) : ( +

Running…

+ ) + } + /> ) : (

- No run started yet. Start a run or refresh to fetch the latest - status. + No run started yet. Start a run to begin.

)}
@@ -267,11 +210,7 @@ function EtlPanel({ title, description, baseUrl, fields, getToken }) { export default function EtlPage() { const config = useConfig(); const { getAccessTokenSilently } = useAuth0(); - - const getToken = () => - getAccessTokenSilently({ - authorizationParams: { audience: config.auth0.audience }, - }); + const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); return (
@@ -281,8 +220,8 @@ export default function EtlPage() {

ETL Pipelines

- Start an ETL run and check its status. Only one run per pipeline may - be active at a time. + Start an ETL run and watch its status. A recent run for a pipeline is + shown rather than started again; use “Re-run” to force a fresh run.

@@ -292,14 +231,14 @@ export default function EtlPage() { description="Load GEM oil & gas data from the configured spreadsheet and post it to Stitch." baseUrl={config.etlGemBaseUrl} fields={GEM_FIELDS} - getToken={getToken} + fetcher={fetcher} /> diff --git a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx index 148a3e9c..94f02754 100644 --- a/deployments/stitch-frontend/src/pages/EtlPage.test.jsx +++ b/deployments/stitch-frontend/src/pages/EtlPage.test.jsx @@ -3,24 +3,36 @@ import { screen, waitFor, within } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { useAuth0 } from "@auth0/auth0-react"; import EtlPage from "./EtlPage"; +import * as jobsModule from "../queries/jobs"; import { auth0TestDefaults, renderWithQueryClient } from "../test/utils"; +const GEM_BASE = "http://localhost:8101/api/v1"; + function getPanel(title) { return screen.getByRole("heading", { name: title }).closest("section"); } -describe("EtlPage", () => { - let getAccessTokenSilently; +function succeededRecord(overrides = {}) { + return { + job_id: "job-123", + state: "succeeded", + started_at: "2026-06-11T10:00:00Z", + finished_at: "2026-06-11T10:05:00Z", + params: {}, + result: { payloads_posted: 42 }, + error: null, + ...overrides, + }; +} +describe("EtlPage", () => { beforeEach(() => { - getAccessTokenSilently = vi.fn().mockResolvedValue("test-access-token"); - vi.mocked(useAuth0).mockReturnValue({ - ...auth0TestDefaults, - getAccessTokenSilently, - }); + vi.clearAllMocks(); + vi.mocked(useAuth0).mockReturnValue(auth0TestDefaults); + vi.spyOn(jobsModule, "findJobs").mockResolvedValue([]); }); - it("renders a panel for each ETL pipeline", () => { + it("renders a panel for each ETL pipeline with no manual refresh", () => { renderWithQueryClient(); expect(screen.getByRole("heading", { name: "GEM" })).toBeInTheDocument(); @@ -31,22 +43,14 @@ describe("EtlPage", () => { 2, ); expect( - screen.getAllByRole("button", { name: "Refresh status" }), - ).toHaveLength(2); + screen.queryByRole("button", { name: /refresh status/i }), + ).not.toBeInTheDocument(); }); - it("starts a GEM run with an authenticated token and shows the returned state", async () => { - const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 202, - text: async () => - JSON.stringify({ - job_id: "job-123", - state: "running", - started_at: "2026-06-11T10:00:00Z", - initiated_by: "Test User", - }), - }); + it("starts a GEM run and renders the completed result", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(succeededRecord()); renderWithQueryClient(); @@ -56,77 +60,58 @@ describe("EtlPage", () => { ); await waitFor(() => { - expect(within(gemPanel).getAllByText("running").length).toBeGreaterThan( - 0, - ); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); - expect(getAccessTokenSilently).toHaveBeenCalledWith({ - authorizationParams: { audience: "https://stitch-api.local" }, - }); - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:8101/api/v1/start", - expect.objectContaining({ - method: "POST", - headers: expect.objectContaining({ - Authorization: "Bearer test-access-token", - }), - }), + expect(startSpy).toHaveBeenCalledWith( + GEM_BASE, + expect.objectContaining({ force: false }), + expect.anything(), ); }); - it("surfaces a friendly message when a run is already in progress (409)", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: false, - status: 409, - text: async () => JSON.stringify({ detail: "A run is already active" }), - }); + it("offers 'Show result' for a recent run and reveals it without re-running", async () => { + vi.spyOn(jobsModule, "findJobs").mockImplementation(async (baseUrl) => + baseUrl === GEM_BASE ? [succeededRecord()] : [], + ); + const startSpy = vi.spyOn(jobsModule, "startJob"); renderWithQueryClient(); - const woodmacPanel = getPanel("WoodMac"); - await userEvent.click( - within(woodmacPanel).getByRole("button", { name: "Start run" }), - ); + const gemPanel = getPanel("GEM"); + const showButton = await within(gemPanel).findByRole("button", { + name: /show result/i, + }); + await userEvent.click(showButton); await waitFor(() => { - expect( - within(woodmacPanel).getByText( - "A run is already in progress — refresh status to check.", - ), - ).toBeInTheDocument(); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); + expect(startSpy).not.toHaveBeenCalled(); }); - it("refreshes status via an unauthenticated GET", async () => { - const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({ - ok: true, - status: 200, - text: async () => - JSON.stringify({ - job_id: "job-789", - state: "succeeded", - started_at: "2026-06-11T10:00:00Z", - finished_at: "2026-06-11T10:05:00Z", - result: { payloads_posted: 42 }, - }), - }); + it("forces a re-run when Re-run is checked", async () => { + const startSpy = vi + .spyOn(jobsModule, "startJob") + .mockResolvedValue(succeededRecord()); renderWithQueryClient(); - const woodmacPanel = getPanel("WoodMac"); + const gemPanel = getPanel("GEM"); + await userEvent.click( + within(gemPanel).getByRole("checkbox", { name: /re-run/i }), + ); await userEvent.click( - within(woodmacPanel).getByRole("button", { name: "Refresh status" }), + within(gemPanel).getByRole("button", { name: "Re-run" }), ); await waitFor(() => { - expect( - within(woodmacPanel).getAllByText("succeeded").length, - ).toBeGreaterThan(0); + expect(within(gemPanel).getByText("succeeded")).toBeInTheDocument(); }); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:8102/api/v1/status", + expect(startSpy).toHaveBeenCalledWith( + GEM_BASE, + expect.objectContaining({ force: true }), + expect.anything(), ); }); }); diff --git a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx index 6354c680..cb4f6895 100644 --- a/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx +++ b/deployments/stitch-frontend/src/pages/ResourceDetailPage.jsx @@ -4,11 +4,11 @@ import { useParams, useNavigate } from "react-router-dom"; import { useResourceDetail, useSourceDetail } from "../hooks/useResources"; import { createAuthenticatedFetcher } from "../auth/api"; import { useConfig } from "../config/useConfig"; -import { - createLLMSuggestion, - createMergeCandidate, - createResource, -} from "../queries/api"; +import { createMergeCandidate, createResource } from "../queries/api"; +import { useJobRunner } from "../hooks/useJobRunner"; +import JobTriggerButton from "../components/JobTriggerButton"; +import JobResultList from "../components/JobResultList"; +import LastUpdated from "../components/LastUpdated"; import SourceMixBar from "../components/SourceMixBar"; import SectionHeader from "../components/SectionHeader"; import { FieldCard, FieldGrid } from "../components/FieldCard"; @@ -182,45 +182,60 @@ function AISuggestionPanel({ endpoint, resourceId }) { const { getAccessTokenSilently } = useAuth0(); const fetcher = createAuthenticatedFetcher(config, getAccessTokenSilently); const [selectedField, setSelectedField] = useState(AI_SUGGESTION_FIELDS[0]); - const [result, setResult] = useState(null); - const [error, setError] = useState(""); - const [isLoading, setIsLoading] = useState(false); + const [forceRerun, setForceRerun] = useState(false); + const [revealed, setRevealed] = useState(false); const [isPersisting, setIsPersisting] = useState(false); const [persistState, setPersistState] = useState(null); + const [persistError, setPersistError] = useState(""); + + const job = useJobRunner({ + baseUrl: `${config.stitchLlmBaseUrl}/${endpoint}`, + fetcher, + lookupBody: { resource_id: resourceId, field: selectedField }, + }); + // Persist (and the value/citation rendering) act on the latest succeeded run. + const result = job.latestSucceeded?.result ?? null; const canPersist = result?.value != null; const isPersistedCurrentSuggestion = result && persistState?.status === "success" && persistState.suggestionKey === getSuggestionSubmissionKey(result); + const error = job.error || persistError; - async function handleGenerateSuggestion() { - setIsLoading(true); - setError(""); - setResult(null); + function handleFieldChange(event) { + setSelectedField(event.target.value); + setForceRerun(false); + setRevealed(false); setPersistState(null); + setPersistError(""); + } - try { - const suggestion = await createLLMSuggestion( - config, - resourceId, - selectedField, - fetcher, - endpoint, - ); - setResult(suggestion); - } catch (err) { - setError(err.message || "Failed to generate suggestion."); - } finally { - setIsLoading(false); + async function handleTrigger() { + setPersistState(null); + setPersistError(""); + + // A prior suggestion exists and we're not forcing a new one → just reveal + // it; no LLM call. + if (job.hasExisting && !forceRerun && !revealed) { + setRevealed(true); + return; } + + setRevealed(true); + await job.start({ + resource_id: resourceId, + field: selectedField, + force: forceRerun, + }); + setForceRerun(false); } async function handlePersistSuggestion() { if (!result || result.value == null) return; setIsPersisting(true); - setError(""); + setPersistError(""); const persistIntentId = createPersistIntentId(); const resourcePayload = buildLLMResourcePayload({ @@ -257,13 +272,13 @@ function AISuggestionPanel({ endpoint, resourceId }) { resourceId: createdResource.id, suggestionKey, }); - setError( + setPersistError( `Suggestion saved as resource ${createdResource.id}, but the merge draft was not created.`, ); } } catch (err) { setPersistState(null); - setError(err.message || "Failed to persist suggestion."); + setPersistError(err.message || "Failed to persist suggestion."); } finally { setIsPersisting(false); } @@ -278,11 +293,7 @@ function AISuggestionPanel({ endpoint, resourceId }) { Field - + + + +
+ +
{error && ( @@ -307,9 +338,24 @@ function AISuggestionPanel({ endpoint, resourceId }) { )} - {result && } + {revealed && ( + + record.state === "succeeded" ? ( + + ) : record.state === "failed" ? ( +

+ {record.error || "Suggestion job failed."} +

+ ) : ( +

Generating…

+ ) + } + /> + )} - {canPersist && ( + {revealed && canPersist && (