diff --git a/deployments/api/alembic/versions/6de2b873bacb_baseline.py b/deployments/api/alembic/versions/f3fb36006ce6_baseline.py similarity index 71% rename from deployments/api/alembic/versions/6de2b873bacb_baseline.py rename to deployments/api/alembic/versions/f3fb36006ce6_baseline.py index 7f867a40..87801ac4 100644 --- a/deployments/api/alembic/versions/6de2b873bacb_baseline.py +++ b/deployments/api/alembic/versions/f3fb36006ce6_baseline.py @@ -1,22 +1,31 @@ """baseline -Revision ID: 6de2b873bacb +Revision ID: f3fb36006ce6 Revises: -Create Date: 2026-06-04 12:35:31.176312 +Create Date: 2026-06-17 19:08:25.103926 """ from __future__ import annotations from alembic import op import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +import stitch.api.db.model.types # revision identifiers, used by Alembic. -revision = "6de2b873bacb" +revision = "f3fb36006ce6" down_revision = None branch_labels = None depends_on = None +DEFAULT_PRIORITIES = [ + {"source": "rmi", "priority": 1}, + {"source": "gem", "priority": 2}, + {"source": "wm", "priority": 3}, + {"source": "llm", "priority": 4}, +] + def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### @@ -33,12 +42,7 @@ def upgrade() -> None: sa.column("source", sa.String), sa.column("priority", sa.Integer), ), - [ - {"source": "rmi", "priority": 1}, - {"source": "wm", "priority": 2}, - {"source": "gem", "priority": 3}, - {"source": "llm", "priority": 4}, - ], + DEFAULT_PRIORITIES, ) op.create_table( "users", @@ -112,57 +116,8 @@ def upgrade() -> None: sa.Enum("gem", "wm", "rmi", "llm", native_enum=False), nullable=False, ), - sa.Column("owners", sa.JSON(), nullable=True), - sa.Column("operators", sa.JSON(), nullable=True), - sa.Column("source_record", sa.JSON(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("country", sa.String(), nullable=True), - sa.Column("name_local", sa.String(), nullable=True), - sa.Column("state_province", sa.String(), nullable=True), - sa.Column("region", sa.String(), nullable=True), - sa.Column("basin", sa.String(), nullable=True), - sa.Column("reservoir_formation", sa.String(), nullable=True), - sa.Column("latitude", sa.Float(), nullable=True), - sa.Column("longitude", sa.Float(), nullable=True), - sa.Column("discovery_year", sa.Integer(), nullable=True), - sa.Column("production_start_year", sa.Integer(), nullable=True), - sa.Column("fid_year", sa.Integer(), nullable=True), - sa.Column( - "location_type", - sa.Enum("Onshore", "Offshore", "Unknown", native_enum=False), - nullable=True, - ), - sa.Column( - "production_conventionality", - sa.Enum( - "Conventional", "Unconventional", "Mixed", "Unknown", native_enum=False - ), - nullable=True, - ), sa.Column( - "primary_hydrocarbon_group", - sa.Enum( - "Ultra-Light Oil", - "Light Oil", - "Medium Oil", - "Heavy Oil", - "Extra-Heavy Oil", - "Dry Gas", - "Wet Gas", - "Acid Gas", - "Condensate", - "Mixed", - "Unknown", - native_enum=False, - ), - nullable=True, - ), - sa.Column( - "field_status", - sa.Enum( - "Producing", "Non-Producing", "Abandoned", "Planned", native_enum=False - ), - nullable=True, + "source_record", 
stitch.api.db.model.types.StitchJson(), nullable=False ), sa.Column( "created", @@ -309,6 +264,83 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) + op.create_table( + "og_field_resource_source_priority", + sa.Column( + "resource_id", + sa.BigInteger() + .with_variant(sa.BIGINT(), "postgresql") + .with_variant(sa.INTEGER(), "sqlite"), + nullable=False, + ), + sa.Column("source", sa.String(length=10), nullable=False), + sa.Column("priority", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["resource_id"], ["og_field_resources.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["source"], + ["og_field_source_priority.source"], + ), + sa.PrimaryKeyConstraint("resource_id", "source"), + ) + op.create_table( + "oil_gas_field_source_values", + sa.Column( + "id", + sa.BigInteger() + .with_variant(sa.BIGINT(), "postgresql") + .with_variant(sa.INTEGER(), "sqlite"), + autoincrement=True, + nullable=False, + ), + sa.Column( + "source_pk", + sa.BigInteger() + .with_variant(sa.BIGINT(), "postgresql") + .with_variant(sa.INTEGER(), "sqlite"), + nullable=False, + ), + sa.Column("colname", sa.String(length=50), nullable=False), + sa.Column("value_text", sa.String(), nullable=True), + sa.Column( + "value_num", + sa.Float().with_variant(sa.DOUBLE_PRECISION(), "postgresql"), + nullable=True, + ), + sa.Column( + "value_json", + sa.JSON(none_as_null=True).with_variant( + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), "postgresql" + ), + nullable=True, + ), + sa.CheckConstraint( + "colname IN ('name', 'country', 'name_local', 'state_province', 'region', 'basin', 'reservoir_formation', 'location_type', 'production_conventionality', 'primary_hydrocarbon_group', 'field_status', 'latitude', 'longitude', 'discovery_year', 'production_start_year', 'fid_year', 'owners', 'operators')", + name="ck_source_value_colname", + ), + sa.CheckConstraint( + "(CASE WHEN value_text IS NOT NULL THEN 1 ELSE 0 END + CASE WHEN value_num IS NOT NULL THEN 1 ELSE 0 END 
+ CASE WHEN value_json IS NOT NULL THEN 1 ELSE 0 END) = 1", + name="ck_source_value_exactly_one", + ), + sa.ForeignKeyConstraint( + ["source_pk"], ["oil_gas_field_sources.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("source_pk", "colname", name="uq_source_value_colname"), + ) + op.create_index( + "ix_source_value_colname_num", + "oil_gas_field_source_values", + ["colname", "value_num"], + unique=False, + ) + op.create_index( + "ix_source_value_colname_text", + "oil_gas_field_source_values", + ["colname", "value_text"], + unique=False, + ) op.create_table( "merge_candidate_items", sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), @@ -341,6 +373,11 @@ def upgrade() -> None: ), ) # ### end Alembic commands ### + # + # NOTE: substring search (ILIKE '%term%') currently relies on plain text + # matching with no trigram acceleration (avoids requiring pg_trgm); the + # (colname, value_text) B-tree index can only narrow by colname. See the deferred + # follow-up for adding a pg_trgm GIN index if substring search gets slow. def downgrade() -> None: diff --git a/deployments/api/src/stitch/api/db/coalesce_sql.py b/deployments/api/src/stitch/api/db/coalesce_sql.py new file mode 100644 index 00000000..a77dc0e7 --- /dev/null +++ b/deployments/api/src/stitch/api/db/coalesce_sql.py @@ -0,0 +1,234 @@ +"""SQL-side coalescing over the long source-value representation. + +For each ``(resource, colname)`` the highest-priority active source supplies the +value. Priority is ``COALESCE(per-resource override, default)``. Selection uses a +``ROW_NUMBER()`` window (portable to SQLite, unlike ``DISTINCT ON``); the winning +row carries its provenance (source + source_pk) for free. + +Two consumers share the ``coalesced_values`` CTE: + * the list endpoint pivots it back to one wide row per resource (in SQL); + * the detail path streams winning rows and pivots in Python.
+""" + +from __future__ import annotations + +from collections.abc import Collection + +from sqlalchemy import Text, and_, case, cast, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from stitch.ogsi.model.og_field import OilGasFieldBase +from stitch.ogsi.model.types import OGSISrcKey + +from .model import ( + MembershipModel, + MembershipStatus, + OGFieldResourceSourcePriority, + OGFieldSourcePriority, + OilGasFieldSourceModel, + OilGasFieldSourceValueModel, + ResourceModel, +) +from .model.oil_gas_field_source_value import ( + ATTRIBUTE_KINDS, + ATTRIBUTE_NAMES, + ValueKind, + materialize_value, + value_attr_for, +) + +PROVENANCE_SUFFIX = "__provenance_source" + + +def build_coalesced_values( + selected_sources: Collection[OGSISrcKey] | None = None, + licensed_sources: Collection[OGSISrcKey] | None = None, + resource_ids: Collection[int] | None = None, +): + """CTE of the priority-winning value per (resource_id, colname). + + Columns: ``resource_id, colname, value_text, value_num, value_json, + source, source_pk``. 
+ """ + s = OilGasFieldSourceModel + v = OilGasFieldSourceValueModel + m = MembershipModel + r = ResourceModel + p = OGFieldSourcePriority + o = OGFieldResourceSourcePriority + + priority = func.coalesce(o.priority, p.priority) + active_src = ( + select( + r.id.label("resource_id"), + m.source.label("source"), + m.source_pk.label("source_pk"), + priority.label("priority"), + ) + .select_from(m) + .join(r, r.id == m.resource_id) + .join(p, p.source == m.source) + .outerjoin(o, and_(o.resource_id == r.id, o.source == m.source)) + .where( + r.repointed_id.is_(None), + m.status == MembershipStatus.ACTIVE, + ) + ) + if selected_sources is not None: + active_src = active_src.where( + m.source.in_(list(dict.fromkeys(selected_sources))) + ) + if licensed_sources is not None: + active_src = active_src.where( + m.source.in_(list(dict.fromkeys(licensed_sources))) + ) + if resource_ids is not None: + active_src = active_src.where(r.id.in_(list(resource_ids))) + # Join the source header on both pk AND source key: membership.source is not + # FK-tied to the header's source, so matching only on source_pk could let a + # mismatched membership row participate in coalescing. + active_src = active_src.join( + s, and_(s.id == m.source_pk, s.source == m.source) + ).cte("active_src") + + ranked = ( + select( + active_src.c.resource_id, + active_src.c.source, + active_src.c.source_pk, + v.colname.label("colname"), + v.value_text, + v.value_num, + v.value_json, + func.row_number() + .over( + partition_by=(active_src.c.resource_id, v.colname), + # source_pk is the final tie-break so the winner is deterministic + # even when a resource has multiple records of the same source + # (same priority + same source key). 
+ order_by=( + active_src.c.priority.asc(), + active_src.c.source.asc(), + active_src.c.source_pk.asc(), + ), + ) + .label("rn"), + ) + .select_from(active_src) + .join(v, v.source_pk == active_src.c.source_pk) + ).cte("ranked") + + # rn == 1 keeps the single highest-priority row per (resource, colname). + return select(ranked).where(ranked.c.rn == 1).cte("coalesced_values") + + +def _when_col(values_cte, field_name: str, value_col): + """``value_col`` only on rows whose colname matches (NULL otherwise).""" + return case((values_cte.c.colname == field_name, value_col)) + + +def _pivot_value_column(values_cte, field_name: str): + """The coalesced value for ``field_name`` as a labeled column. + + Exactly one row exists per (resource, colname), so MAX just selects that + single non-null value. Postgres has no ``max(jsonb)`` aggregate, so JSON + values are maxed as text; the caller (``_list_item_from_row``) deserializes + those JSON-typed fields back to Python. + """ + if ATTRIBUTE_KINDS[field_name] is ValueKind.JSON: + return func.max( + _when_col(values_cte, field_name, cast(values_cte.c.value_json, Text)) + ).label(field_name) + value_col = getattr(values_cte.c, value_attr_for(field_name)) + return func.max(_when_col(values_cte, field_name, value_col)).label(field_name) + + +def _resource_spine(selected_sources: Collection[OGSISrcKey] | None): + """Resources that should appear in the list (membership-based existence). + + Gated by the requested ``source`` filter but NOT by licensing: an + unlicensed resource still appears, just with redacted (NULL) field values. 
+ """ + m = MembershipModel + r = ResourceModel + stmt = ( + select(r.id.label("id")) + .select_from(r) + .join(m, m.resource_id == r.id) + .where(r.repointed_id.is_(None), m.status == MembershipStatus.ACTIVE) + ) + if selected_sources is not None: + stmt = stmt.where(m.source.in_(list(dict.fromkeys(selected_sources)))) + return stmt.distinct() + + +def build_resource_list_cte( + selected_sources: Collection[OGSISrcKey] | None, + licensed_sources: Collection[OGSISrcKey] | None, +): + """One wide row per resource: ``id``, each field, each field+provenance. + + The resource spine (existence) is LEFT JOINed to the coalesced licensed + values, so resources with only unlicensed/absent data appear with NULLs. + """ + values_cte = build_coalesced_values( + selected_sources=selected_sources, licensed_sources=licensed_sources + ) + pivot = select(values_cte.c.resource_id.label("resource_id")) + for field_name in ATTRIBUTE_NAMES: + pivot = pivot.add_columns( + _pivot_value_column(values_cte, field_name), + func.max(_when_col(values_cte, field_name, values_cte.c.source)).label( + f"{field_name}{PROVENANCE_SUFFIX}" + ), + ) + pivot = pivot.group_by(values_cte.c.resource_id).cte("resource_value_pivot") + + spine = _resource_spine(selected_sources).cte("resource_spine") + coalesced = select(spine.c.id.label("id")) + for field_name in ATTRIBUTE_NAMES: + coalesced = coalesced.add_columns( + pivot.c[field_name], + pivot.c[f"{field_name}{PROVENANCE_SUFFIX}"], + ) + coalesced = coalesced.select_from( + spine.outerjoin(pivot, pivot.c.resource_id == spine.c.id) + ) + return coalesced.cte("licensed_resource_list") + + +async def coalesce_persisted_resource( + session: AsyncSession, + resource_id: int, + licensed_sources: Collection[OGSISrcKey] | None = None, +) -> tuple[OilGasFieldBase, dict[str, tuple | None]]: + """Coalesce a single persisted resource, pivoting the winning rows in Python.""" + values_cte = build_coalesced_values( + selected_sources=None, + 
licensed_sources=licensed_sources, + resource_ids=[resource_id], + ) + stmt = select( + values_cte.c.colname, + values_cte.c.value_text, + values_cte.c.value_num, + values_cte.c.value_json, + values_cte.c.source, + values_cte.c.source_pk, + ) + rows = (await session.execute(stmt)).mappings().all() + + view_data: dict[str, object] = {k: None for k in ATTRIBUTE_NAMES} + provenance: dict[str, tuple | None] = {k: None for k in ATTRIBUTE_NAMES} + for row in rows: + colname = row["colname"] + value = materialize_value( + colname, + value_text=row["value_text"], + value_num=row["value_num"], + value_json=row["value_json"], + ) + view_data[colname] = value + provenance[colname] = (value, row["source"], row["source_pk"]) + + return OilGasFieldBase(**view_data), provenance diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index 2d3a9eae..55cd212a 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,7 +1,8 @@ from .common import Base as StitchBase -from .og_field_query_mixin import OGFieldQueryMixin from .og_field_source_priority import OGFieldSourcePriority +from .og_field_resource_source_priority import OGFieldResourceSourcePriority from .oil_gas_field_source import OilGasFieldSourceModel +from .oil_gas_field_source_value import OilGasFieldSourceValueModel from .membership import MembershipModel, MembershipStatus from .resource import ResourceModel from .merge_candidate import MergeCandidateItemModel, MergeCandidateModel @@ -10,9 +11,10 @@ __all__ = [ "MembershipModel", "MembershipStatus", - "OGFieldQueryMixin", "OGFieldSourcePriority", + "OGFieldResourceSourcePriority", "OilGasFieldSourceModel", + "OilGasFieldSourceValueModel", "MergeCandidateItemModel", "MergeCandidateModel", "ResourceModel", diff --git a/deployments/api/src/stitch/api/db/model/og_field_query_mixin.py 
b/deployments/api/src/stitch/api/db/model/og_field_query_mixin.py deleted file mode 100644 index e39394e2..00000000 --- a/deployments/api/src/stitch/api/db/model/og_field_query_mixin.py +++ /dev/null @@ -1,214 +0,0 @@ -"""Declarative mixin: shared OG field columns + query classmethods.""" - -from __future__ import annotations - -from collections.abc import Collection, Sequence -from typing import Any, ClassVar, Self - -from sqlalchemy import ( - ColumnElement, - Float, - Integer, - Select, - String, - asc, - desc, - func, - or_, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column - -from stitch.api.entities import OGFieldQueryParams, OGSI_SOURCE_DEFAULT -from stitch.ogsi.model import LocationType -from stitch.ogsi.model.types import ( - FieldStatus, - OGSISrcKey, - PrimaryHydrocarbonGroup, - ProductionConventionality, -) - - -@declarative_mixin -class OGFieldQueryMixin: - """Shared OG field columns and query classmethods. - - Provides the full set of domain columns (aligned with OilGasFieldBase, - minus owners/operators) and classmethods for filtered, sorted, paginated - queries. Subclasses may override ``_base_query`` to customise the FROM - clause (e.g. adding joins) while inheriting conditions, sorting, and - pagination logic. 
- """ - - # ------------------------------------------------------------------ - # Query field configuration - # ------------------------------------------------------------------ - - _q_fields: ClassVar[tuple[str, ...]] = ( - "name", - "name_local", - "basin", - "state_province", - "region", - ) - - _exact_match_fields: ClassVar[tuple[str, ...]] = ( - *_q_fields, - "id", - "country", - "field_status", - "location_type", - "production_conventionality", - "primary_hydrocarbon_group", - ) - - # ------------------------------------------------------------------ - # Shared column declarations - # ------------------------------------------------------------------ - - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str | None] = mapped_column(String, nullable=True) - country: Mapped[str | None] = mapped_column(String, nullable=True) - name_local: Mapped[str | None] = mapped_column(String, nullable=True) - state_province: Mapped[str | None] = mapped_column(String, nullable=True) - region: Mapped[str | None] = mapped_column(String, nullable=True) - basin: Mapped[str | None] = mapped_column(String, nullable=True) - reservoir_formation: Mapped[str | None] = mapped_column(String, nullable=True) - latitude: Mapped[float | None] = mapped_column(Float, nullable=True) - longitude: Mapped[float | None] = mapped_column(Float, nullable=True) - discovery_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - production_start_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - fid_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - - # Enum/Literal columns - location_type: Mapped[LocationType | None] = mapped_column( - default=None, nullable=True - ) - production_conventionality: Mapped[ProductionConventionality | None] = ( - mapped_column(default=None, nullable=True) - ) - primary_hydrocarbon_group: Mapped[PrimaryHydrocarbonGroup | None] = mapped_column( - default=None, nullable=True - ) - field_status: Mapped[FieldStatus | 
None] = mapped_column( - default=None, nullable=True - ) - - # ------------------------------------------------------------------ - # Public query classmethods - # ------------------------------------------------------------------ - - @classmethod - async def query( - cls, - session: AsyncSession, - params: OGFieldQueryParams, - licensed_sources: Collection[OGSISrcKey] | None = None, - ) -> Sequence[Self]: - """Execute a filtered, sorted, paginated query and return (rows, total).""" - base = cls._base_query(params, licensed_sources=licensed_sources) - stmt = cls._apply_pagination(base, params) - rows = (await session.scalars(stmt)).all() - return rows - - @classmethod - async def count( - cls, - session: AsyncSession, - params: OGFieldQueryParams | None = None, - licensed_sources: Collection[OGSISrcKey] | None = None, - ) -> int: - """Return the total number of matching rows (unfiltered when params is None).""" - if params is None: - stmt = select(func.count()).select_from(cls) - else: - stmt = select(func.count()).select_from( - cls._base_query(params, licensed_sources=licensed_sources).subquery() - ) - return await session.scalar(stmt) or 0 - - # ------------------------------------------------------------------ - # Internal helpers (overridable) - # ------------------------------------------------------------------ - - __primary_sort_col__: ClassVar[str] = "id" - - @classmethod - def _base_query[QM: OGFieldQueryMixin]( - cls: type[QM], - params: OGFieldQueryParams, - licensed_sources: Collection[OGSISrcKey] | None = None, - ) -> Select[tuple[QM]]: - """Filtered + sorted SELECT with no pagination. - - Override this in subclasses to modify the FROM clause (e.g. add joins). 
- """ - stmt: Select[tuple[QM]] = select(cls).distinct() - for cond in cls._build_conditions(params, licensed_sources=licensed_sources): - stmt = stmt.where(cond) - return stmt.order_by(*cls._create_sort_clauses(params)) - - @classmethod - def _build_conditions( - cls, - params: OGFieldQueryParams, - licensed_sources: Collection[OGSISrcKey] | None = None, - ) -> list[ColumnElement[bool]]: - """Build WHERE conditions from filter params. - - ``params.source`` and ``licensed_sources`` are conceptually distinct: - the former is the user-requested existence filter; the latter is the - server-derived data-access filter. They are applied as separate - predicates so the intent stays explicit. - """ - conditions: list[ColumnElement[bool]] = [] - - if params.q: - q_term = f"%{params.q}%" - q_conds: list[ColumnElement[bool]] = [] - for field_name in cls._q_fields: - col: ColumnElement[bool] | None = getattr(cls, field_name, None) - if col is not None: - q_conds.append(col.ilike(q_term)) - if q_conds: - conditions.append(or_(*q_conds)) - - for field_name in cls._exact_match_fields: - value = getattr(params, field_name, None) - if value is not None: - col = getattr(cls, field_name, None) - if col is not None: - conditions.append(col == value) - - source_col = getattr(cls, "source", None) - if source_col is not None: - sources = list( - dict.fromkeys(getattr(params, "source", OGSI_SOURCE_DEFAULT)) - ) - conditions.append(source_col.in_(sources)) - if licensed_sources is not None: - conditions.append(source_col.in_(list(dict.fromkeys(licensed_sources)))) - - return conditions - - @classmethod - def _create_sort_clauses(cls, params: OGFieldQueryParams) -> list[Any]: - """Create ORDER BY clauses with a stable primary-key tie-breaker.""" - clauses: list[Any] = [] - sort_col = getattr(cls, params.sort_by, None) - if sort_col is not None: - direction = desc if params.sort_order == "desc" else asc - clauses.append(direction(sort_col).nulls_last()) - if params.sort_by != 
cls.__primary_sort_col__: - primary_sort_col = getattr(cls, cls.__primary_sort_col__, None) - if primary_sort_col is not None: - clauses.append(asc(primary_sort_col)) - return clauses - - @classmethod - def _apply_pagination[QM: OGFieldQueryMixin]( - cls: type[QM], stmt: Select[tuple[QM]], params: OGFieldQueryParams - ) -> Select[tuple[QM]]: - """Apply offset/limit for pagination.""" - return stmt.offset(params.offset).limit(params.limit) diff --git a/deployments/api/src/stitch/api/db/model/og_field_resource_source_priority.py b/deployments/api/src/stitch/api/db/model/og_field_resource_source_priority.py new file mode 100644 index 00000000..7fc9938a --- /dev/null +++ b/deployments/api/src/stitch/api/db/model/og_field_resource_source_priority.py @@ -0,0 +1,36 @@ +"""Per-resource source-priority overrides. + +Lets a resource re-rank its sources, overriding the global defaults in +``og_field_source_priority``. The effective priority used during coalescing is +``COALESCE(override.priority, default.priority)`` -- absent an override row, the +default applies, so behaviour is identical to having no overrides at all. +""" + +from sqlalchemy import ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column +from stitch.ogsi.model.types import OGSISrcKey + +from .common import Base +from .types import PORTABLE_BIGINT + + +class OGFieldResourceSourcePriority(Base): + __tablename__ = "og_field_resource_source_priority" + + resource_id: Mapped[int] = mapped_column( + PORTABLE_BIGINT, + ForeignKey("og_field_resources.id", ondelete="CASCADE"), + primary_key=True, + ) + # FUTURE: key the override on a specific source record (FK to + # oil_gas_field_sources.id) rather than the source *key*, so a resource with + # multiple records from the same source (e.g. two WoodMac records) can be + # ranked individually. Today priority is per source-class. 
+ source: Mapped[OGSISrcKey] = mapped_column( + String(10), + ForeignKey("og_field_source_priority.source"), + primary_key=True, + ) + # Not globally unique (unlike the default table): different resources may + # legitimately reuse the same priority value. + priority: Mapped[int] = mapped_column(Integer, nullable=False) diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py index 1b868d4c..6e109ff2 100644 --- a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py @@ -1,78 +1,82 @@ from __future__ import annotations -from collections.abc import Collection -from typing import Any, ClassVar, override +from typing import Any, ClassVar from pydantic import TypeAdapter -from sqlalchemy import ( - JSON, - inspect, - select, -) -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from stitch.ogsi.model import OGFieldSource from stitch.ogsi.model.types import OGSISrcKey -from stitch.api.db.model.types import PORTABLE_BIGINT -from stitch.api.entities import OGFieldQueryParams, User +from stitch.api.db.model.types import PORTABLE_BIGINT, StitchJson +from stitch.api.entities import User from .common import Base -from .membership import MembershipModel, MembershipStatus from .mixins import TimestampMixin, UserAuditMixin -from .og_field_query_mixin import OGFieldQueryMixin +from .oil_gas_field_source_value import ATTRIBUTE_NAMES, OilGasFieldSourceValueModel -class OilGasFieldSourceModel(OGFieldQueryMixin, TimestampMixin, UserAuditMixin, Base): - """A single OG field source record (canonicalized), feedable into a Resource.""" +class OilGasFieldSourceModel(TimestampMixin, UserAuditMixin, Base): + """Header for a single OG field source record. 
+ + Identity + raw payload only; the coalesced attributes live in long form on + ``oil_gas_field_source_values`` (``values`` relationship). + """ type_adapter: ClassVar[TypeAdapter[OGFieldSource]] = TypeAdapter(OGFieldSource) __tablename__: str = "oil_gas_field_sources" - id: Mapped[int] = mapped_column(PORTABLE_BIGINT, primary_key=True) + id: Mapped[int] = mapped_column( + PORTABLE_BIGINT, primary_key=True, autoincrement=True + ) - # SqlAlchemy will translate Literal types into Enums + # SqlAlchemy will translate the Literal type into an Enum. source: Mapped[OGSISrcKey] = mapped_column(nullable=False) - # JSON columns - owners: Mapped[list[dict[str, Any]] | None] = mapped_column(JSON, nullable=True) - operators: Mapped[list[dict[str, Any]] | None] = mapped_column(JSON, nullable=True) - source_record: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) + # Raw, non-coalesced original payload. + source_record: Mapped[dict[str, Any]] = mapped_column(StitchJson, nullable=False) - @classmethod - @override - def _base_query( - cls, - params: OGFieldQueryParams, - licensed_sources: Collection[OGSISrcKey] | None = None, - ): - """Filter to sources with at least one active membership.""" - - active_membership = ( - select(1) - .where(MembershipModel.source_pk == cls.id) - .where(MembershipModel.status == MembershipStatus.ACTIVE) - .exists() - ) - stmt = select(cls).where(active_membership) - for cond in cls._build_conditions(params, licensed_sources=licensed_sources): - stmt = stmt.where(cond) - return stmt.order_by(*cls._create_sort_clauses(params)) + # Long-form coalesced attributes; eager-loaded so as_entity() stays sync. 
+ values: Mapped[list[OilGasFieldSourceValueModel]] = relationship( + cascade="all, delete-orphan", + lazy="selectin", + ) @classmethod def create_from_entity(cls, ent: OGFieldSource, created_by: User): - cols = {col.key for col in inspect(cls).columns} - kwargs = {k: val for k, val in ent.model_dump(mode="json").items() if k in cols} - return cls( - **kwargs, created_by_id=created_by.id, last_updated_by_id=created_by.id - ) - - def as_entity(self) -> OGFieldSource: - return self.__class__.type_adapter.validate_python(self) + return cls._build(ent, created_by_id=created_by.id) @classmethod def from_entity(cls, entity: OGFieldSource): - mapper = inspect(cls) - column_keys = {col.key for col in mapper.columns} - filtered = {k: v for k, v in entity.model_dump().items() if k in column_keys} - return cls(**filtered) + return cls._build(entity) + + @classmethod + def _build(cls, ent: OGFieldSource, created_by_id: int | None = None): + dumped = ent.model_dump(mode="json") + kwargs: dict[str, Any] = { + "source": dumped["source"], + "source_record": dumped["source_record"], + } + if created_by_id is not None: + kwargs["created_by_id"] = created_by_id + kwargs["last_updated_by_id"] = created_by_id + header = cls(**kwargs) + header.values = [ + OilGasFieldSourceValueModel.from_attribute(colname, dumped[colname]) + for colname in ATTRIBUTE_NAMES + if dumped.get(colname) is not None + ] + return header + + def as_entity(self) -> OGFieldSource: + # Materialize absent attributes as None (only non-null values are stored).
+ data: dict[str, Any] = {colname: None for colname in ATTRIBUTE_NAMES} + data.update( + id=self.id, + source=self.source, + source_record=self.source_record, + ) + for value in self.values: + colname, py_value = value.to_attribute() + data[colname] = py_value + return self.__class__.type_adapter.validate_python(data) diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field_source_value.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source_value.py new file mode 100644 index 00000000..12feb840 --- /dev/null +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field_source_value.py @@ -0,0 +1,174 @@ +"""Long (EAV) storage for OG field source attributes. + +Each coalesced attribute of a source record is stored as one row here, instead +of as a column on a wide ``oil_gas_field_sources`` table. The value lives in one +of three typed columns (``value_text`` / ``value_num`` / ``value_json``) chosen +by ``ATTRIBUTE_KINDS`` -- the single source of truth mapping each +``OilGasFieldBase`` field to its physical storage. Typed columns keep Postgres's +own type system doing exact-match, substring, and numerically-correct ordering +without per-query casts. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from sqlalchemy import ( + CheckConstraint, + ForeignKey, + Index, + String, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column + +from stitch.ogsi.model.og_field import OilGasFieldBase + +from .common import Base +from .types import PORTABLE_BIGINT, PORTABLE_FLOAT, PORTABLE_JSON_NULL + + +class ValueKind(StrEnum): + TEXT = "text" + INT = "int" + FLOAT = "float" + JSON = "json" + + +# Single source of truth: which physical value column backs each coalesced +# attribute of ``OilGasFieldBase`` and how to materialize it back to Python. 
+ATTRIBUTE_KINDS: dict[str, ValueKind] = { + "name": ValueKind.TEXT, + "country": ValueKind.TEXT, + "name_local": ValueKind.TEXT, + "state_province": ValueKind.TEXT, + "region": ValueKind.TEXT, + "basin": ValueKind.TEXT, + "reservoir_formation": ValueKind.TEXT, + "location_type": ValueKind.TEXT, + "production_conventionality": ValueKind.TEXT, + "primary_hydrocarbon_group": ValueKind.TEXT, + "field_status": ValueKind.TEXT, + "latitude": ValueKind.FLOAT, + "longitude": ValueKind.FLOAT, + "discovery_year": ValueKind.INT, + "production_start_year": ValueKind.INT, + "fid_year": ValueKind.INT, + "owners": ValueKind.JSON, + "operators": ValueKind.JSON, +} + +# Fail fast if the registry drifts from the entity it mirrors. Use an explicit +# raise (not assert) so the check survives `python -O`. +if set(ATTRIBUTE_KINDS) != set(OilGasFieldBase.model_fields): + raise RuntimeError( + "ATTRIBUTE_KINDS out of sync with OilGasFieldBase fields: " + f"{set(ATTRIBUTE_KINDS) ^ set(OilGasFieldBase.model_fields)}" + ) + +ATTRIBUTE_NAMES: tuple[str, ...] = tuple(ATTRIBUTE_KINDS) + +# Attributes stored as JSON (owners/operators) -- emitted as text by the list +# coalescing pivot and deserialized in Python. 
+JSON_ATTRIBUTE_NAMES: frozenset[str] = frozenset( + name for name, kind in ATTRIBUTE_KINDS.items() if kind is ValueKind.JSON +) + +_NUM_KINDS = frozenset({ValueKind.INT, ValueKind.FLOAT}) + + +def value_attr_for(colname: str) -> str: + """Return the physical column attribute name backing ``colname``.""" + kind = ATTRIBUTE_KINDS[colname] + if kind in _NUM_KINDS: + return "value_num" + if kind is ValueKind.JSON: + return "value_json" + return "value_text" + + +def materialize_value( + colname: str, + *, + value_text: Any, + value_num: Any, + value_json: Any, +) -> Any: + """Pick + coerce the Python value for ``colname`` from typed columns.""" + kind = ATTRIBUTE_KINDS[colname] + if kind is ValueKind.INT: + return None if value_num is None else int(value_num) + if kind is ValueKind.FLOAT: + return value_num + if kind is ValueKind.JSON: + return value_json + return value_text + + +class OilGasFieldSourceValueModel(Base): + """A single (source-record, attribute) value in long form.""" + + __tablename__ = "oil_gas_field_source_values" + + id: Mapped[int] = mapped_column( + PORTABLE_BIGINT, primary_key=True, autoincrement=True + ) + source_pk: Mapped[int] = mapped_column( + PORTABLE_BIGINT, + ForeignKey("oil_gas_field_sources.id", ondelete="CASCADE"), + nullable=False, + ) + colname: Mapped[str] = mapped_column(String(50), nullable=False) + value_text: Mapped[str | None] = mapped_column(String, nullable=True) + value_num: Mapped[float | None] = mapped_column(PORTABLE_FLOAT, nullable=True) + value_json: Mapped[Any | None] = mapped_column(PORTABLE_JSON_NULL, nullable=True) + + __table_args__ = ( + # Dense table: at most one value per (record, attribute). + UniqueConstraint("source_pk", "colname", name="uq_source_value_colname"), + # Exactly one typed column populated -- a value row is never empty. 
+ CheckConstraint( + "(CASE WHEN value_text IS NOT NULL THEN 1 ELSE 0 END" + " + CASE WHEN value_num IS NOT NULL THEN 1 ELSE 0 END" + " + CASE WHEN value_json IS NOT NULL THEN 1 ELSE 0 END) = 1", + name="ck_source_value_exactly_one", + ), + # colname is a closed, code-defined set. + CheckConstraint( + "colname IN (" + ", ".join(f"'{n}'" for n in ATTRIBUTE_NAMES) + ")", + name="ck_source_value_colname", + ), + # Exact-match + DISTINCT listing across text attributes. + Index("ix_source_value_colname_text", "colname", "value_text"), + # Numerically-correct ordered scans (lat/long/years). + Index("ix_source_value_colname_num", "colname", "value_num"), + ) + + @classmethod + def from_attribute(cls, colname: str, value: Any) -> OilGasFieldSourceValueModel: + """Build a value row for ``colname`` routing ``value`` to its column.""" + kind = ATTRIBUTE_KINDS[colname] + if kind in _NUM_KINDS: + return cls(colname=colname, value_num=value) + if kind is ValueKind.JSON: + return cls(colname=colname, value_json=value) + return cls(colname=colname, value_text=value) + + def to_attribute(self) -> tuple[str, Any]: + """Return ``(colname, python_value)`` materialized to its declared type.""" + return self.colname, materialize_value( + self.colname, + value_text=self.value_text, + value_num=self.value_num, + value_json=self.value_json, + ) + + @classmethod + def value_col_for(cls, colname: str): + kind = ATTRIBUTE_KINDS[colname] + if kind in _NUM_KINDS: + return cls.value_num + if kind is ValueKind.JSON: + return cls.value_json + return cls.value_text diff --git a/deployments/api/src/stitch/api/db/model/types.py b/deployments/api/src/stitch/api/db/model/types.py index 9e901c5d..3a8adf49 100644 --- a/deployments/api/src/stitch/api/db/model/types.py +++ b/deployments/api/src/stitch/api/db/model/types.py @@ -1,4 +1,4 @@ -from sqlalchemy import JSON, BigInteger, Dialect, TypeDecorator +from sqlalchemy import JSON, BigInteger, Dialect, Float, TypeDecorator from sqlalchemy.dialects import 
postgresql, sqlite @@ -8,6 +8,18 @@ .with_variant(sqlite.INTEGER(), "sqlite") ) PORTABLE_JSON = JSON().with_variant(postgresql.JSONB(), "postgresql") +PORTABLE_FLOAT = ( + Float() + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") + .with_variant(sqlite.REAL(), "sqlite") +) + +# JSON column where a Python ``None`` binds to SQL NULL (not the JSON ``null`` +# literal). Needed so the long values table's "exactly one column populated" +# check sees an unset value column as truly NULL. +PORTABLE_JSON_NULL = JSON(none_as_null=True).with_variant( + postgresql.JSONB(none_as_null=True), "postgresql" +) class StitchJson(TypeDecorator): diff --git a/deployments/api/src/stitch/api/db/og_field_resource_actions.py b/deployments/api/src/stitch/api/db/og_field_resource_actions.py index c9bbf3a1..4976c39a 100644 --- a/deployments/api/src/stitch/api/db/og_field_resource_actions.py +++ b/deployments/api/src/stitch/api/db/og_field_resource_actions.py @@ -1,3 +1,4 @@ +import json from collections.abc import Collection, Sequence from typing import Any, get_args @@ -5,9 +6,7 @@ from sqlalchemy import ( ColumnElement, String, - and_, asc, - case, cast, desc, func, @@ -38,32 +37,20 @@ from stitch.ogsi.model.og_field import OilGasFieldBase from stitch.ogsi.model.types import OGSISrcKey +from .coalesce_sql import PROVENANCE_SUFFIX, build_resource_list_cte from .model import ( MembershipModel, MembershipStatus, - OGFieldSourcePriority, - OilGasFieldSourceModel, ResourceModel, ) -from .model.og_field_source_priority import DEFAULT_PRIORITIES +from .model.oil_gas_field_source_value import JSON_ATTRIBUTE_NAMES +from .queries import EXACT_MATCH_FIELDS, Q_FIELDS from .utils import resource_model_to_entity -_LIST_JSON_FIELDS = ("owners", "operators") -_LIST_SCALAR_FIELDS = tuple( - field_name - for field_name in OilGasFieldBase.model_fields - if field_name not in _LIST_JSON_FIELDS -) -_LIST_DATA_FIELDS = (*_LIST_SCALAR_FIELDS, *_LIST_JSON_FIELDS) -_PROVENANCE_SUFFIX = 
"__provenance_source" _FILTER_OPTION_FIELDS: frozenset[str] = frozenset(get_args(FilterOptionField)) -def _priority_values() -> tuple[int, ...]: - return tuple(int(priority["priority"]) for priority in DEFAULT_PRIORITIES) - - async def query( session: AsyncSession, params: OGFieldQueryParams, @@ -76,7 +63,7 @@ async def query( detail="sort_by=source is not supported for resource list queries.", ) - coalesced = _build_licensed_resource_list_cte(params, licensed_sources) + coalesced = build_resource_list_cte(params.source, licensed_sources) filtered = select(coalesced) for condition in _build_final_conditions(coalesced, params): filtered = filtered.where(condition) @@ -106,7 +93,7 @@ async def filter_options( detail=f"field={params.field} is not supported for resource filter options.", ) - coalesced = _build_licensed_resource_list_cte(params, licensed_sources) + coalesced = build_resource_list_cte(params.source, licensed_sources) col = _resource_list_column(coalesced, params.field) if col is None: raise HTTPException( @@ -125,114 +112,6 @@ async def filter_options( return list(values.all()) -def _build_licensed_resource_list_cte( - params: OGFieldQueryParams, - licensed_sources: Collection[OGSISrcKey] | None, -): - s = OilGasFieldSourceModel - m = MembershipModel - r = ResourceModel - p = OGFieldSourcePriority - - selected_sources = list(dict.fromkeys(params.source)) - source_join_conditions = [ - s.id == m.source_pk, - s.source == m.source, - ] - if licensed_sources is not None: - source_join_conditions.append( - s.source.in_(list(dict.fromkeys(licensed_sources))) - ) - - qualified = ( - select( - r.id.label("id"), - m.source.label("source"), - p.priority.label("priority"), - ) - .join(m, m.resource_id == r.id) - .join(p, p.source == m.source) - .outerjoin(s, and_(*source_join_conditions)) - .where( - r.repointed_id.is_(None), - m.status == MembershipStatus.ACTIVE, - m.source.in_(selected_sources), - ) - ) - - for field_name in _LIST_DATA_FIELDS: - qualified = 
qualified.add_columns(getattr(s, field_name).label(field_name)) - - qualified_cte = qualified.cte("qualified_resource_sources") - coalesced = select(qualified_cte.c.id.label("id")).group_by(qualified_cte.c.id) - - for field_name in _LIST_SCALAR_FIELDS: - field_col = getattr(qualified_cte.c, field_name) - value_by_priority = [ - func.max(case((qualified_cte.c.priority == priority, field_col))) - for priority in _priority_values() - ] - provenance_by_priority = [ - func.max( - case( - ( - and_( - qualified_cte.c.priority == priority, - field_col.is_not(None), - ), - qualified_cte.c.source, - ) - ) - ) - for priority in _priority_values() - ] - coalesced = coalesced.add_columns( - func.coalesce(*value_by_priority).label(field_name), - func.coalesce(*provenance_by_priority).label( - f"{field_name}{_PROVENANCE_SUFFIX}" - ), - ) - - for field_name in _LIST_JSON_FIELDS: - value_alias = qualified_cte.alias(f"{field_name}_value_source") - provenance_alias = qualified_cte.alias(f"{field_name}_provenance_source") - value_col = getattr(value_alias.c, field_name) - provenance_col = getattr(provenance_alias.c, field_name) - value_is_present = _json_value_is_present(value_col) - provenance_is_present = _json_value_is_present(provenance_col) - - value_subquery = ( - select(value_col) - .where( - value_alias.c.id == qualified_cte.c.id, - value_is_present, - ) - .order_by(value_alias.c.priority.asc()) - .limit(1) - .scalar_subquery() - ) - provenance_subquery = ( - select(provenance_alias.c.source) - .where( - provenance_alias.c.id == qualified_cte.c.id, - provenance_is_present, - ) - .order_by(provenance_alias.c.priority.asc()) - .limit(1) - .scalar_subquery() - ) - coalesced = coalesced.add_columns( - value_subquery.label(field_name), - provenance_subquery.label(f"{field_name}{_PROVENANCE_SUFFIX}"), - ) - - return coalesced.cte("licensed_resource_list") - - -def _json_value_is_present(col) -> ColumnElement[bool]: - return and_(col.is_not(None), cast(col, String) != "null") - - def 
_build_final_conditions( coalesced, params: OGFieldQueryParams, @@ -242,14 +121,14 @@ def _build_final_conditions( if params.q: q_term = f"%{params.q}%" q_conditions: list[ColumnElement[bool]] = [] - for field_name in OilGasFieldSourceModel._q_fields: + for field_name in Q_FIELDS: col = getattr(coalesced.c, field_name, None) if col is not None: q_conditions.append(col.ilike(q_term)) if q_conditions: conditions.append(or_(*q_conditions)) - for field_name in OilGasFieldSourceModel._exact_match_fields: + for field_name in EXACT_MATCH_FIELDS: value = getattr(params, field_name, None) if value is None: continue @@ -277,15 +156,23 @@ def _resource_list_column(coalesced, field_name: str): return getattr(coalesced.c, field_name, None) +def _row_field_value(row: RowMapping, field_name: str): + """Read a coalesced field, deserializing JSON-typed fields emitted as text.""" + value = row.get(field_name) + if field_name in JSON_ATTRIBUTE_NAMES and isinstance(value, str): + return json.loads(value) + return value + + def _list_item_from_row(row: RowMapping) -> OGFieldListItemView: data = OilGasFieldBase( **{ - field_name: row.get(field_name) + field_name: _row_field_value(row, field_name) for field_name in OilGasFieldBase.model_fields } ) provenance: dict[str, OGSISrcKey | None] = { - field_name: row.get(f"{field_name}{_PROVENANCE_SUFFIX}") + field_name: row.get(f"{field_name}{PROVENANCE_SUFFIX}") for field_name in OilGasFieldBase.model_fields } return OGFieldListItemView(id=row["id"], data=data, provenance=provenance) diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py index 2df39f0c..19a06eca 100644 --- a/deployments/api/src/stitch/api/db/og_field_source_actions.py +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -1,6 +1,6 @@ from collections.abc import Collection, Sequence -from sqlalchemy import select +from sqlalchemy import func, select from stitch.api.db.config import 
AsyncSession from stitch.api.db.errors import ( @@ -19,6 +19,10 @@ ResourceModel, MembershipModel, ) +from .queries import ( + construct_sources_count_statement, + construct_sources_query_statement, +) from .utils import resource_model_to_entity @@ -159,10 +163,23 @@ async def query( params: OGFieldQueryParams, licensed_sources: Collection[OGSISrcKey] | None = None, ) -> tuple[Sequence[OGFieldSource], int]: - models = await OilGasFieldSourceModel.query( - session, params, licensed_sources=licensed_sources - ) - total = await OilGasFieldSourceModel.count( - session, params, licensed_sources=licensed_sources + """Filtered/sorted/paginated source records (id-ordered) plus total count.""" + stmt = construct_sources_query_statement(params, licensed_sources) + ids = list((await session.scalars(stmt)).all()) + + count_stmt = construct_sources_count_statement(params, licensed_sources) + total = ( + await session.scalar(select(func.count()).select_from(count_stmt.subquery())) + or 0 ) - return tuple(m.as_entity() for m in models), total + + if not ids: + return (), total + + headers = ( + await session.scalars( + select(OilGasFieldSourceModel).where(OilGasFieldSourceModel.id.in_(ids)) + ) + ).all() + by_id = {h.id: h for h in headers} + return tuple(by_id[i].as_entity() for i in ids if i in by_id), total diff --git a/deployments/api/src/stitch/api/db/queries.py b/deployments/api/src/stitch/api/db/queries.py new file mode 100644 index 00000000..533f9f2e --- /dev/null +++ b/deployments/api/src/stitch/api/db/queries.py @@ -0,0 +1,216 @@ +"""Pure construction of the source-record query/count statements. + +A source record's attributes live in the long ``oil_gas_field_source_values`` +table, not as wide columns. These helpers pivot each active-membership record's +value rows back into a wide, one-row-per-record CTE -- narrowed to only the +attributes the current query filters or sorts on -- then build the filtered, +sorted, paginated id-``Select`` the endpoint needs. 
Construction is pure (no +session); execution + hydration live in ``og_field_source_actions``. +""" + +from __future__ import annotations + +from collections.abc import Collection +from typing import Any, Final + +from sqlalchemy import ( + CTE, + ColumnElement, + Select, + asc, + case, + desc, + func, + or_, + select, +) + +from stitch.api.db.model import ( + MembershipModel, + MembershipStatus, + OilGasFieldSourceModel, + OilGasFieldSourceValueModel, +) +from stitch.api.db.model.oil_gas_field_source_value import value_attr_for +from stitch.api.entities import OGSI_SOURCE_DEFAULT, OGFieldQueryParams +from stitch.ogsi.model.types import OGSISrcKey + +# Single source of truth for the source-list field metadata. This is a shared +# cross-module contract: the resource-list actions import these constants too. +Q_FIELDS: Final[tuple[str, ...]] = ( + "name", + "name_local", + "basin", + "state_province", + "region", +) + +EXACT_MATCH_FIELDS: Final[tuple[str, ...]] = ( + *Q_FIELDS, + "id", + "country", + "field_status", + "location_type", + "production_conventionality", + "primary_hydrocarbon_group", +) + +PRIMARY_SORT_COL: Final[str] = "id" + +# Sort targets resolved from a header column (``id``/``source``) or absent on +# the source path (``resource_id``); none of these need a pivoted value column. +_HEADER_SORT_FIELDS: Final[frozenset[str]] = frozenset({"id", "source", "resource_id"}) + + +def _participating_columns(params: OGFieldQueryParams) -> list[str]: + """Value attributes the query actually touches -- the columns to pivot. + + The sort field (unless it resolves from a header column), every *set* + exact-match field, and ``Q_FIELDS`` when a substring search is requested. + ``id`` is excluded: it is a header column on the base select, not pivoted. + Order is preserved and duplicates removed. 
+ """ + participating: list[str] = [] + if params.sort_by not in _HEADER_SORT_FIELDS: + participating.append(params.sort_by) + + for field in EXACT_MATCH_FIELDS: + if field != "id" and getattr(params, field, None) is not None: + participating.append(field) + + if params.q: + participating += Q_FIELDS + + return list(dict.fromkeys(participating)) + + +def base_source_query_statement( + params: OGFieldQueryParams, + licensed_sources: Collection[OGSISrcKey] | None = None, +) -> CTE: + """One wide row per active-membership source record, narrowed to ``params``. + + ``id``/``source`` come straight off the header; each participating value + attribute is pivoted out of its typed value column. JSON attributes + (owners/operators) are never filterable or sortable -- the closed Literal + param types keep them out of the participating set -- so no JSON branch is + needed here (this also sidesteps the missing ``max(jsonb)`` aggregate). + """ + s = OilGasFieldSourceModel + m = MembershipModel + v = OilGasFieldSourceValueModel + + # EXISTS rather than a join to memberships: a record can have several + # memberships, and a join would fan it out to one row per membership, + # inflating the GROUP BY pivot. Correlate on both pk AND source key + # (membership.source is not FK-tied to the header's source) so a mismatched + # row cannot mark the record active. 
+ active_membership = ( + select(1) + .where(m.source_pk == s.id) + .where(m.source == s.source) + .where(m.status == MembershipStatus.ACTIVE) + .exists() + ) + + stmt = ( + select(s.id.label("id"), s.source.label("source")) + .select_from(s) + .outerjoin(v, v.source_pk == s.id) + .where(active_membership) + ) + if licensed_sources is not None: + stmt = stmt.where(s.source.in_(list(dict.fromkeys(licensed_sources)))) + stmt = stmt.group_by(s.id, s.source) + + for field_name in _participating_columns(params): + value_col = getattr(v, value_attr_for(field_name)) + stmt = stmt.add_columns( + func.max(case((v.colname == field_name, value_col))).label(field_name) + ) + return stmt.cte("source_base") + + +def _require_column(cte: CTE, field_name: str) -> ColumnElement[Any]: + """Return ``cte.c.<field_name>`` or raise if the narrowing dropped it. + + Conditions and sort clauses are derived from the same + ``_participating_columns`` result as the pivot, so a missing column means + the narrowing drifted -- raise (per the repo convention) rather than + silently drop the filter/sort, which would return wrong rows. + """ + col = getattr(cte.c, field_name, None) + if col is None: + raise RuntimeError( + f"source query references {field_name!r}, absent from the narrowed " + "pivot; participating-columns narrowing is out of sync."
+ ) + return col + + +def _build_conditions( + cte: CTE, + params: OGFieldQueryParams, + licensed_sources: Collection[OGSISrcKey] | None = None, +) -> list[ColumnElement[bool]]: + conditions: list[ColumnElement[bool]] = [] + + if params.q: + q_term = f"%{params.q}%" + conditions.append( + or_(*(_require_column(cte, field).ilike(q_term) for field in Q_FIELDS)) + ) + + for field_name in EXACT_MATCH_FIELDS: + value = getattr(params, field_name, None) + if value is None: + continue + conditions.append(_require_column(cte, field_name) == value) + + sources = list(dict.fromkeys(getattr(params, "source", OGSI_SOURCE_DEFAULT))) + conditions.append(cte.c.source.in_(sources)) + if licensed_sources is not None: + conditions.append(cte.c.source.in_(list(dict.fromkeys(licensed_sources)))) + + return conditions + + +def _build_sort_clauses(cte: CTE, params: OGFieldQueryParams) -> list[Any]: + clauses: list[Any] = [] + + # resource_id does not exist on the source path; a sort by it degrades to + # the id tiebreak below (matching the prior mixin behavior). 
+ if params.sort_by != "resource_id": + sort_col = _require_column(cte, params.sort_by) + direction = desc if params.sort_order == "desc" else asc + clauses.append(direction(sort_col).nulls_last()) + + if params.sort_by != PRIMARY_SORT_COL: + clauses.append(asc(_require_column(cte, PRIMARY_SORT_COL))) + + return clauses + + +def construct_sources_query_statement( + params: OGFieldQueryParams, + licensed_sources: Collection[OGSISrcKey] | None = None, +) -> Select[tuple[int]]: + """Filtered + sorted + paginated id-``Select`` for the source-list query.""" + base = base_source_query_statement(params, licensed_sources=licensed_sources) + stmt = select(base.c.id) + for cond in _build_conditions(base, params, licensed_sources): + stmt = stmt.where(cond) + stmt = stmt.order_by(*_build_sort_clauses(base, params)) + return stmt.offset(params.offset).limit(params.limit) + + +def construct_sources_count_statement( + params: OGFieldQueryParams, + licensed_sources: Collection[OGSISrcKey] | None = None, +) -> Select[tuple[int]]: + """Filtered (unordered, unpaginated) id-``Select``; caller wraps in count().""" + base = base_source_query_statement(params, licensed_sources=licensed_sources) + stmt = select(base.c.id) + for cond in _build_conditions(base, params, licensed_sources): + stmt = stmt.where(cond) + return stmt diff --git a/deployments/api/src/stitch/api/db/utils.py b/deployments/api/src/stitch/api/db/utils.py index 06fb36f1..cde425cc 100644 --- a/deployments/api/src/stitch/api/db/utils.py +++ b/deployments/api/src/stitch/api/db/utils.py @@ -13,6 +13,7 @@ OGSISrcKey, ) from stitch.api.coalesce import coalesce_og_field_resource +from stitch.api.db.coalesce_sql import coalesce_persisted_resource from stitch.api.db.errors import ResourceIntegrityError from .model import ResourceModel @@ -60,7 +61,10 @@ async def resource_model_to_entity( rep_model = await session.get(ResourceModel, model.repointed_id) rep_res = rep_model.as_empty_entity() if rep_model else None - view, 
provenance = coalesce_og_field_resource(src_data) + # Coalesce in SQL (priority-resolved, override-aware) over the long values. + view, provenance = await coalesce_persisted_resource( + session, model.id, licensed_sources=licensed_sources + ) return OGFieldResource( id=model.id, diff --git a/deployments/api/tests/db/test_base_query.py b/deployments/api/tests/db/test_base_query.py index b64d8fd6..d5b79d75 100644 --- a/deployments/api/tests/db/test_base_query.py +++ b/deployments/api/tests/db/test_base_query.py @@ -1,16 +1,20 @@ -"""Integration tests for OGFieldQueryMixin query helpers against OilGasFieldSourceModel.""" +"""Integration tests for the source-list query path via og_field_source_actions.query.""" import pytest -from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from stitch.api.db.model import OilGasFieldSourceModel +from stitch.api.db import og_field_source_actions as source_actions +from stitch.api.db.model import ( + MembershipModel, + MembershipStatus, + ResourceModel, +) from stitch.api.entities import ( OGFieldQueryParams, OGFieldSortParams, User, ) -from tests.utils import make_source_record +from tests.utils import make_source_model @pytest.fixture @@ -18,23 +22,12 @@ async def seeded_sources( seeded_integration_session: AsyncSession, test_user: User, ): - """Seed 8 diverse source rows for query testing.""" + """Seed 8 diverse source rows (each with an active membership) for query testing.""" session = seeded_integration_session uid = test_user.id def with_record(source: str, **kwargs): - payload = { - "source": source, - "name": kwargs.get("name"), - "country": kwargs.get("country"), - } - return OilGasFieldSourceModel( - source=source, - source_record=make_source_record(payload=payload).model_dump(mode="json"), - created_by_id=uid, - last_updated_by_id=uid, - **kwargs, - ) + return make_source_model(source=source, created_by_id=uid, **kwargs) sources = [ with_record( @@ -103,22 +96,31 @@ def with_record(source: 
str, **kwargs): ] session.add_all(sources) await session.flush() + + # Each source needs an active membership to be visible to the source-list query. + resource = ResourceModel.create(created_by=test_user) + session.add(resource) + await session.flush() + session.add_all( + MembershipModel.create( + created_by=test_user, + resource_id=resource.id, + source=src.source, + source_pk=src.id, + status=MembershipStatus.ACTIVE, + ) + for src in sources + ) + await session.flush() return sources -async def _execute(session: AsyncSession, **overrides): - """Helper: build and execute a query using the mixin's building blocks directly.""" +async def _execute(session: AsyncSession, *, licensed_sources=None, **overrides): + """Run the long-aware source-list query and return coalesced entities + total.""" params = OGFieldQueryParams(**overrides) - M = OilGasFieldSourceModel - - base = select(M).distinct() - for cond in M._build_conditions(params): - base = base.where(cond) - base = base.order_by(*M._create_sort_clauses(params)) - - total = await session.scalar(select(func.count()).select_from(base.subquery())) or 0 - stmt = M._apply_pagination(base, params) - rows = (await session.scalars(stmt)).all() + rows, total = await source_actions.query( + session, params, licensed_sources=licensed_sources + ) return rows, total @@ -267,3 +269,161 @@ async def test_invalid_sort_field_raises(self): """OGFieldSortParams with invalid sort_by raises ValidationError.""" with pytest.raises(Exception): OGFieldSortParams(sort_by="owners") + + +class TestSourceFilter: + """`source` is a source-path filter (resources cannot filter by source).""" + + @pytest.mark.anyio + async def test_single_source(self, seeded_integration_session, seeded_sources): + """source=['gem'] returns only the 4 gem rows.""" + rows, total = await _execute(seeded_integration_session, source=["gem"]) + assert total == 4 + assert {r.source for r in rows} == {"gem"} + + @pytest.mark.anyio + async def test_multiple_sources(self, 
seeded_integration_session, seeded_sources): + """source=['gem','wm'] returns the 4 gem + 2 wm rows.""" + rows, total = await _execute( + seeded_integration_session, source=["gem", "wm"], page_size=200 + ) + assert total == 6 + assert {r.source for r in rows} == {"gem", "wm"} + + +class TestLicensedSourcesGating: + """licensed_sources hides rows whose source is not licensed.""" + + @pytest.mark.anyio + async def test_licensing_hides_unlicensed_rows( + self, seeded_integration_session, seeded_sources + ): + """Only gem rows survive even though params.source defaults to all four.""" + rows, total = await _execute( + seeded_integration_session, + licensed_sources=["gem"], + page_size=200, + ) + assert total == 4 + assert {r.source for r in rows} == {"gem"} + assert "Ghawar" not in {r.name for r in rows} # wm row is hidden + + +class TestNarrowingProofs: + """The narrowed pivot really materializes the fields it filters/sorts on.""" + + @pytest.mark.anyio + async def test_filter_by_basin(self, seeded_integration_session, seeded_sources): + """basin=Permian proves basin is pivoted (two Permian rows).""" + rows, total = await _execute(seeded_integration_session, basin="Permian") + assert total == 2 + assert {r.name for r in rows} == {"Permian Basin", "Permian Delaware"} + + @pytest.mark.anyio + async def test_sort_by_discovery_year_nulls_last( + self, seeded_integration_session, seeded_sources + ): + """discovery_year sorts numerically (from value_num) with NULLs last.""" + rows, total = await _execute( + seeded_integration_session, + sort_by="discovery_year", + sort_order="asc", + page_size=200, + ) + assert total == 8 + years = [r.discovery_year for r in rows] + assert years == [1920, 1948, 1959, 1968, 2000, None, None, None] + + @pytest.mark.anyio + async def test_empty_involved_orders_by_id( + self, seeded_integration_session, seeded_sources + ): + """sort_by=id, no q/filters: zero value columns, all active rows by id.""" + rows, total = await _execute( + 
seeded_integration_session, sort_by="id", page_size=200 + ) + assert total == 8 + ids = [r.id for r in rows] + assert ids == sorted(ids) + + +@pytest.fixture +async def tiebreak_sources( + seeded_integration_session: AsyncSession, + test_user: User, +): + """Six gem rows with duplicate (2000) and NULL discovery years + memberships.""" + session = seeded_integration_session + years = [2000, 2000, None, 1990, None, 2000] + sources = [ + make_source_model( + source="gem", + created_by_id=test_user.id, + name=f"S{i}", + discovery_year=year, + ) + for i, year in enumerate(years) + ] + session.add_all(sources) + await session.flush() + + resource = ResourceModel.create(created_by=test_user) + session.add(resource) + await session.flush() + session.add_all( + MembershipModel.create( + created_by=test_user, + resource_id=resource.id, + source=src.source, + source_pk=src.id, + status=MembershipStatus.ACTIVE, + ) + for src in sources + ) + await session.flush() + return sources, years + + +class TestDeterministicTiebreak: + """Equal/NULL sort values fall back to a deterministic asc(id) tiebreak.""" + + @pytest.mark.anyio + async def test_duplicate_and_null_sort_values( + self, seeded_integration_session, tiebreak_sources + ): + sources, years = tiebreak_sources + rows, total = await _execute( + seeded_integration_session, + sort_by="discovery_year", + sort_order="asc", + page_size=200, + ) + assert total == len(sources) + + # Expected: non-NULL years asc, NULLs last, ties broken by ascending id. 
+ id_year = list(zip((s.id for s in sources), years)) + expected_ids = [ + id_ + for id_, _ in sorted( + id_year, key=lambda iy: (iy[1] is None, iy[1] or 0, iy[0]) + ) + ] + assert [r.id for r in rows] == expected_ids + + +class TestSortBySource: + """The source path allows sort_by=source (the resource path forbids it).""" + + @pytest.mark.anyio + async def test_sort_by_source_orders_by_source( + self, seeded_integration_session, seeded_sources + ): + # sort_by=source is outside the SortableField Literal, so bypass + # validation via assignment (pydantic does not re-validate on set). + params = OGFieldQueryParams(page_size=200) + params.sort_by = "source" + params.sort_order = "asc" + rows, total = await source_actions.query(seeded_integration_session, params) + assert total == 8 + sources_in_order = [r.source for r in rows] + assert sources_in_order == sorted(sources_in_order) diff --git a/deployments/api/tests/db/test_resource_actions.py b/deployments/api/tests/db/test_resource_actions.py index 8719eb27..faa389b2 100644 --- a/deployments/api/tests/db/test_resource_actions.py +++ b/deployments/api/tests/db/test_resource_actions.py @@ -10,7 +10,7 @@ from stitch.api.db.model import ( MembershipModel, MembershipStatus, - OilGasFieldSourceModel, + OGFieldResourceSourcePriority, ResourceModel, ) from stitch.api.entities import ( @@ -19,7 +19,7 @@ User, ) from tests.factories import ResourceCreateFactory -from tests.utils import make_source_record +from tests.utils import make_source_model _QueryParams = OGFieldQueryParams @@ -36,16 +36,11 @@ async def _create_resource_with_sources( await session.flush() for row in source_rows: - payload = { - "source": row["source"], - "name": row.get("name"), - "country": row.get("country"), - } - source = OilGasFieldSourceModel( - **row, - source_record=make_source_record(payload=payload).model_dump(mode="json"), + attrs = {k: v for k, v in row.items() if k != "source"} + source = make_source_model( + source=row["source"], 
created_by_id=user.id, - last_updated_by_id=user.id, + **attrs, ) session.add(source) await session.flush() @@ -629,8 +624,8 @@ async def test_excludes_repointed_and_inactive_memberships( def test_postgres_distinct_query_orders_by_selected_value_alias(self): params = OGFieldFilterOptionsParams(field="basin") - coalesced = resource_actions._build_licensed_resource_list_cte( - params, + coalesced = resource_actions.build_resource_list_cte( + params.source, licensed_sources=frozenset({"gem", "wm", "rmi", "llm"}), ) col = resource_actions._resource_list_column(coalesced, params.field) @@ -825,3 +820,69 @@ async def test_repointed_resources_are_excluded( assert total == 1 assert [item.id for item in items] == [root_id] + + +class TestResourcePriorityOverride: + """A per-resource override re-ranks sources, flipping the coalesced winner.""" + + async def _seed(self, session, user) -> int: + # Default priority: gem(2) outranks wm(3), so gem wins by default. + return await _create_resource_with_sources( + session, + user, + {"source": "gem", "name": "GEM Name", "country": "USA"}, + {"source": "wm", "name": "WM Name", "country": "CAN"}, + ) + + @pytest.mark.anyio + async def test_override_flips_winner_in_detail_and_list( + self, + seeded_integration_session: AsyncSession, + test_user: User, + ): + session = seeded_integration_session + rid = await self._seed(session, test_user) + + # Default: gem wins. + before = await resource_actions.get(session, rid) + assert before.view.name == "GEM Name" + assert before.provenance["name"][1] == "gem" + + # Override wm to top priority for THIS resource only. + session.add( + OGFieldResourceSourcePriority(resource_id=rid, source="wm", priority=1) + ) + await session.flush() + + # Detail path reflects the override (value + provenance). + after = await resource_actions.get(session, rid) + assert after.view.name == "WM Name" + assert after.view.country == "CAN" + assert after.provenance["name"][1] == "wm" + + # List path reflects it too. 
+ items, _ = await resource_actions.query(session, _QueryParams()) + item = next(i for i in items if i.id == rid) + assert item.data.name == "WM Name" + assert item.provenance["name"] == "wm" + + @pytest.mark.anyio + async def test_override_is_scoped_to_its_resource( + self, + seeded_integration_session: AsyncSession, + test_user: User, + ): + session = seeded_integration_session + overridden = await self._seed(session, test_user) + untouched = await self._seed(session, test_user) + + session.add( + OGFieldResourceSourcePriority( + resource_id=overridden, source="wm", priority=1 + ) + ) + await session.flush() + + assert (await resource_actions.get(session, overridden)).view.name == "WM Name" + # The other resource keeps the default ranking. + assert (await resource_actions.get(session, untouched)).view.name == "GEM Name" diff --git a/deployments/api/tests/routers/test_licensed_sources_routes.py b/deployments/api/tests/routers/test_licensed_sources_routes.py index 5767ab61..3122fa7c 100644 --- a/deployments/api/tests/routers/test_licensed_sources_routes.py +++ b/deployments/api/tests/routers/test_licensed_sources_routes.py @@ -20,7 +20,7 @@ ) from stitch.api.entities import User from stitch.api.main import app -from tests.utils import make_source_record +from tests.utils import make_source_model, make_source_record def _gem_only_claims() -> TokenClaims: @@ -115,11 +115,12 @@ async def _seed_resource_with_sources( session.add(resource) await session.flush() for row in source_rows: - source = OilGasFieldSourceModel( - **row, - source_record=make_source_record(payload=row).model_dump(mode="json"), + attrs = {k: v for k, v in row.items() if k != "source"} + source = make_source_model( + source=row["source"], created_by_id=user.id, - last_updated_by_id=user.id, + source_record=make_source_record(payload=row).model_dump(mode="json"), + **attrs, ) session.add(source) await session.flush() diff --git a/deployments/api/tests/utils.py b/deployments/api/tests/utils.py index 
5b34fe67..57728bb4 100644 --- a/deployments/api/tests/utils.py +++ b/deployments/api/tests/utils.py @@ -40,6 +40,44 @@ def make_source_record( ) +def make_source_model( + *, + source: OGSISrcKey, + created_by_id: int, + source_record: dict[str, Any] | None = None, + **attrs: Any, +): + """Build a long-form source header ORM model + its value rows from kwargs. + + Mirrors how the app persists a source record: identity/raw payload on the + header, each non-null attribute as a typed row in the long values table. + """ + from stitch.api.db.model import OilGasFieldSourceModel + from stitch.api.db.model.oil_gas_field_source_value import ( + ATTRIBUTE_NAMES, + OilGasFieldSourceValueModel, + ) + + if source_record is None: + record_payload = {"source": source, **attrs} + source_record = make_source_record(payload=record_payload).model_dump( + mode="json" + ) + + model = OilGasFieldSourceModel( + source=source, + source_record=source_record, + created_by_id=created_by_id, + last_updated_by_id=created_by_id, + ) + model.values = [ + OilGasFieldSourceValueModel.from_attribute(colname, value) + for colname, value in attrs.items() + if colname in ATTRIBUTE_NAMES and value is not None + ] + return model + + def make_source( fact: OGFieldBaseFactory, managed: bool = True,