From f63517f3c198ddd697ee1cbc1511c00193339f60 Mon Sep 17 00:00:00 2001 From: archiblesherman Date: Thu, 5 Mar 2026 17:53:52 -0500 Subject: [PATCH 1/5] Add offline trace reader scaffolding for issue #41 --- tests/test_archive_reader_contract.py | 29 +++ tests/test_legacy_txt_reader.py | 98 ++++++++++ vizfold/__init__.py | 1 + vizfold/offline/__init__.py | 19 ++ vizfold/offline/archive_reader.py | 80 ++++++++ vizfold/offline/exceptions.py | 14 ++ vizfold/offline/legacy_txt_reader.py | 271 ++++++++++++++++++++++++++ vizfold/offline/models.py | 71 +++++++ vizfold/offline/paths.py | 68 +++++++ vizfold/offline/trace_reader.py | 68 +++++++ 10 files changed, 719 insertions(+) create mode 100644 tests/test_archive_reader_contract.py create mode 100644 tests/test_legacy_txt_reader.py create mode 100644 vizfold/__init__.py create mode 100644 vizfold/offline/__init__.py create mode 100644 vizfold/offline/archive_reader.py create mode 100644 vizfold/offline/exceptions.py create mode 100644 vizfold/offline/legacy_txt_reader.py create mode 100644 vizfold/offline/models.py create mode 100644 vizfold/offline/paths.py create mode 100644 vizfold/offline/trace_reader.py diff --git a/tests/test_archive_reader_contract.py b/tests/test_archive_reader_contract.py new file mode 100644 index 0000000..f81b0b5 --- /dev/null +++ b/tests/test_archive_reader_contract.py @@ -0,0 +1,29 @@ +import pytest + +from vizfold.offline import ArchiveReader + + +def test_archive_reader_constructor() -> None: + reader = ArchiveReader("dummy_archive") + assert str(reader.archive_root).endswith("dummy_archive") + + +@pytest.mark.parametrize( + "method_name,args,kwargs", + [ + ("metadata", tuple(), {}), + ("list_attention_types", tuple(), {}), + ("list_layers", ("msa_row",), {}), + ("list_heads", ("msa_row", 47), {}), + ("list_residue_indices", ("triangle_start", 47), {}), + ("load_attention", ("msa_row", 47, 0), {}), + ("load_attention_heads", ("msa_row", 47), {}), + ("load_structure", tuple(), {}), + ], +) +def test_archive_reader_methods_raise_not_implemented(method_name, args, kwargs) -> None: + reader = ArchiveReader("dummy_archive") + method = getattr(reader, method_name) + + with pytest.raises(NotImplementedError): + method(*args, **kwargs) \ No newline at end of file diff --git a/tests/test_legacy_txt_reader.py b/tests/test_legacy_txt_reader.py new file mode 100644 index 0000000..fc25fb7 --- /dev/null +++ b/tests/test_legacy_txt_reader.py @@ -0,0 +1,98 @@ +from pathlib import Path + +from vizfold.offline import LegacyTxtReader + + +def _write_text(path: Path, text: str) -> None: + path.write_text(text.strip() + "\n", encoding="utf-8") + + +def test_legacy_txt_reader_metadata_and_loading(tmp_path: Path) -> None: + attention_dir = tmp_path / "attn" + attention_dir.mkdir() + + fasta_path = tmp_path / "toy.fasta" + pdb_path = tmp_path / "toy_unrelaxed.pdb" + + _write_text( + fasta_path, + """ + >toy + ACDEFG + """, + ) + + _write_text( + pdb_path, + """ + HEADER TOY PDB + ATOM 1 N ALA A 1 11.104 13.207 9.947 1.00 50.00 N + END + """, + ) + + _write_text( + attention_dir / "msa_row_attn_layer47.txt", + """ + Layer 47, Head 0 + 0 1 0.90 + 1 3 0.70 + + Layer 47, Head 2 + 2 5 0.95 + 0 4 0.60 + """, + ) + + _write_text( + attention_dir / "triangle_start_attn_layer47_residue_idx_18.txt", + """ + Layer 47, Head 0 + 18 20 0.80 + 18 21 0.50 + + Layer 47, Head 1 + 18 30 0.92 + """, + ) + + reader = LegacyTxtReader( + attention_dir=attention_dir, + fasta_path=fasta_path, + pdb_path=pdb_path, + protein_id="toy", + ) + + meta = reader.metadata() + assert meta.protein_id == "toy" + assert meta.sequence == "ACDEFG" + assert set(meta.attention_types) == {"msa_row", "triangle_start"} + assert meta.layers_by_type["msa_row"] == [47] + assert meta.layers_by_type["triangle_start"] == [47] + assert meta.residue_indices_by_type["triangle_start"][47] == [18] + + msa_heads = reader.list_heads("msa_row", 47) + assert msa_heads == [0, 2] + + tri_heads = reader.list_heads("triangle_start", 47, residue_idx=18) + assert tri_heads == [0, 1] + + msa_slice = reader.load_attention("msa_row", layer=47, head=2) + assert msa_slice.as_triplets()[0] == (2, 5, 0.95) + assert msa_slice.as_triplets()[1] == (0, 4, 0.60) + + tri_slice = reader.load_attention( + "triangle_start", + layer=47, + head=0, + residue_idx=18, + top_k=1, + ) + assert tri_slice.residue_idx == 18 + assert tri_slice.as_triplets() == [(18, 20, 0.80)] + + structure = reader.load_structure() + assert structure.protein_id == "toy" + assert structure.sequence == "ACDEFG" + assert structure.pdb_text is not None + assert "HEADER TOY PDB" in structure.pdb_text \ No newline at end of file diff --git a/vizfold/__init__.py b/vizfold/__init__.py new file mode 100644 index 0000000..71c87dd --- /dev/null +++ b/vizfold/__init__.py @@ -0,0 +1 @@ +"""VizFold utilities.""" \ No newline at end of file diff --git a/vizfold/offline/__init__.py b/vizfold/offline/__init__.py new file mode 100644 index 0000000..87d8ba4 --- /dev/null +++ b/vizfold/offline/__init__.py @@ -0,0 +1,19 @@ +from .archive_reader import ArchiveReader +from .legacy_txt_reader import LegacyTxtReader +from .models import ( + AttentionConnection, + AttentionSlice, + StructureData, + TraceMetadata, +) +from .trace_reader import TraceReader + +__all__ = [ + "ArchiveReader", + "LegacyTxtReader", + "AttentionConnection", + "AttentionSlice", + "StructureData", + "TraceMetadata", + "TraceReader", +] \ No newline at end of file diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py new file mode 100644 index 0000000..4507183 --- /dev/null +++ b/vizfold/offline/archive_reader.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from pathlib import Path + +from .models import AttentionSlice, StructureData, TraceMetadata +from .trace_reader import TraceReader + + +class ArchiveReader(TraceReader): + """ + Placeholder reader for the future standardized archive from issue #39. + + The point of adding this now is to lock the interface that both the + frontend and the archive writer will target. + """ + + def __init__(self, archive_root: str | Path) -> None: + self.archive_root = Path(archive_root) + + def metadata(self) -> TraceMetadata: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def list_attention_types(self) -> list[str]: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def list_layers(self, attention_type: str) -> list[int]: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def list_residue_indices( + self, + attention_type: str, + layer: int, + ) -> list[int]: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) + + def load_structure(self) -> StructureData: + raise NotImplementedError( + "ArchiveReader is waiting on the finalized archive schema from issue #39." + ) \ No newline at end of file diff --git a/vizfold/offline/exceptions.py b/vizfold/offline/exceptions.py new file mode 100644 index 0000000..a0a8ffc --- /dev/null +++ b/vizfold/offline/exceptions.py @@ -0,0 +1,14 @@ +class TraceReaderError(Exception): + """Base exception for offline trace reading.""" + + +class TraceFormatError(TraceReaderError): + """Raised when a trace file exists but does not match the expected format.""" + + +class TraceNotFoundError(TraceReaderError): + """Raised when a requested trace file or resource cannot be found.""" + + +class UnsupportedAttentionTypeError(TraceReaderError): + """Raised when an attention type is not supported by the reader.""" \ No newline at end of file diff --git a/vizfold/offline/legacy_txt_reader.py b/vizfold/offline/legacy_txt_reader.py new file mode 100644 index 0000000..c6a2418 --- /dev/null +++ b/vizfold/offline/legacy_txt_reader.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import re +from pathlib import Path + +from .exceptions import TraceFormatError, TraceNotFoundError +from .models import AttentionConnection, AttentionSlice, StructureData, TraceMetadata +from .paths import parse_legacy_attention_filename, resolve_legacy_attention_path +from .trace_reader import TraceReader + + +class LegacyTxtReader(TraceReader): + """ + Reader for VizFold's current legacy text-dump attention format. + + This wraps the existing file naming and parsing conventions behind a stable API + so the frontend can stop depending on hardcoded filenames. + """ + + def __init__( + self, + attention_dir: str | Path, + fasta_path: str | Path | None = None, + pdb_path: str | Path | None = None, + protein_id: str | None = None, + ) -> None: + self.attention_dir = Path(attention_dir) + if not self.attention_dir.exists(): + raise TraceNotFoundError(f"attention_dir does not exist: {self.attention_dir}") + + self.fasta_path = Path(fasta_path) if fasta_path is not None else None + self.pdb_path = Path(pdb_path) if pdb_path is not None else None + self.protein_id = protein_id or self._infer_protein_id() + + self._metadata_cache: TraceMetadata | None = None + + def metadata(self) -> TraceMetadata: + if self._metadata_cache is None: + self._metadata_cache = self._build_metadata() + return self._metadata_cache + + def list_attention_types(self) -> list[str]: + return self.metadata().attention_types + + def list_layers(self, attention_type: str) -> list[int]: + return self.metadata().layers_by_type.get(attention_type, []) + + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + path = resolve_legacy_attention_path( + self.attention_dir, + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + ) + heads = self._parse_heads_file(path) + return sorted(heads.keys()) + + def list_residue_indices( + self, + attention_type: str, + layer: int, + ) -> list[int]: + return self.metadata().residue_indices_by_type.get(attention_type, {}).get(layer, []) + + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + all_heads = self.load_attention_heads( + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + top_k=top_k, + ) + + if head not in all_heads: + available = sorted(all_heads.keys()) + raise TraceNotFoundError( + f"Head {head} not found for attention_type={attention_type}, " + f"layer={layer}, residue_idx={residue_idx}. Available heads: {available}" + ) + + return all_heads[head] + + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + path = resolve_legacy_attention_path( + self.attention_dir, + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + ) + parsed = self._parse_heads_file(path) + + slices: dict[int, AttentionSlice] = {} + for head_idx, connections in parsed.items(): + if top_k is not None: + connections = connections[:top_k] + + slices[head_idx] = AttentionSlice( + attention_type=attention_type, + layer=layer, + head=head_idx, + residue_idx=residue_idx, + connections=connections, + ) + + return slices + + def load_structure(self) -> StructureData: + pdb_text = None + if self.pdb_path is not None: + if not self.pdb_path.exists(): + raise TraceNotFoundError(f"pdb_path does not exist: {self.pdb_path}") + pdb_text = self.pdb_path.read_text(encoding="utf-8") + + sequence = self._read_fasta_sequence(self.fasta_path) if self.fasta_path else None + + return StructureData( + protein_id=self.protein_id, + pdb_path=self.pdb_path, + pdb_text=pdb_text, + sequence=sequence, + ) + + def _build_metadata(self) -> TraceMetadata: + layers_by_type: dict[str, set[int]] = {} + heads_by_type_sets: dict[str, dict[int, set[int]]] = {} + residue_indices_by_type: dict[str, dict[int, set[int]]] = {} + + for path in self.attention_dir.iterdir(): + if not path.is_file(): + continue + + parsed = parse_legacy_attention_filename(path.name) + if parsed is None: + continue + + attention_type, layer, residue_idx = parsed + + layers_by_type.setdefault(attention_type, set()).add(layer) + heads_by_type_sets.setdefault(attention_type, {}).setdefault(layer, set()) + + file_heads = self._parse_heads_file(path).keys() + heads_by_type_sets[attention_type][layer].update(file_heads) + + if residue_idx is not None: + residue_indices_by_type.setdefault(attention_type, {}).setdefault(layer, set()).add( + residue_idx + ) + + attention_types = sorted(layers_by_type.keys()) + sequence = self._read_fasta_sequence(self.fasta_path) if self.fasta_path else None + + heads_by_type: dict[str, dict[int, list[int]]] = { + attn_type: { + layer: sorted(heads) + for layer, heads in layer_map.items() + } + for attn_type, layer_map in heads_by_type_sets.items() + } + + residue_indices_by_type_sorted: dict[str, dict[int, list[int]]] = { + attn_type: { + layer: sorted(residue_indices) + for layer, residue_indices in layer_map.items() + } + for attn_type, layer_map in residue_indices_by_type.items() + } + + return TraceMetadata( + protein_id=self.protein_id, + source_root=self.attention_dir, + fasta_path=self.fasta_path, + pdb_path=self.pdb_path, + sequence=sequence, + attention_types=attention_types, + layers_by_type={ + attn_type: sorted(layers) + for attn_type, layers in layers_by_type.items() + }, + heads_by_type=heads_by_type, + residue_indices_by_type=residue_indices_by_type_sorted, + extras={ + "format": "legacy_txt", + }, + ) + + def _parse_heads_file(self, path: Path) -> dict[int, list[AttentionConnection]]: + """ + Expected format: + + Layer 47, Head 0 + 1 5 0.91 + 2 9 0.88 + + Layer 47, Head 1 + 3 7 0.95 + """ + heads: dict[int, list[AttentionConnection]] = {} + current_head: int | None = None + + with path.open("r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + if not line: + continue + + if line.lower().startswith("layer"): + numbers = re.findall(r"-?\d+", line) + if not numbers: + raise TraceFormatError(f"Could not parse head header line: {line}") + current_head = int(numbers[-1]) + heads[current_head] = [] + continue + + if current_head is None: + raise TraceFormatError( + f"Found attention row before any head header in file: {path}" + ) + + parts = line.split() + if len(parts) != 3: + raise TraceFormatError( + f"Expected 3 columns in attention row, got {len(parts)}: {line}" + ) + + src = int(float(parts[0])) + dst = int(float(parts[1])) + weight = float(parts[2]) + + heads[current_head].append( + AttentionConnection(src=src, dst=dst, weight=weight) + ) + + for head_idx, conns in heads.items(): + conns.sort(key=lambda x: x.weight, reverse=True) + + return heads + + def _infer_protein_id(self) -> str: + if self.pdb_path is not None: + return self.pdb_path.stem + if self.fasta_path is not None: + return self.fasta_path.stem + return self.attention_dir.name + + @staticmethod + def _read_fasta_sequence(fasta_path: Path | None) -> str | None: + if fasta_path is None: + return None + if not fasta_path.exists(): + raise TraceNotFoundError(f"FASTA path does not exist: {fasta_path}") + + lines = fasta_path.read_text(encoding="utf-8").splitlines() + seq_lines = [line.strip() for line in lines if line and not line.startswith(">")] + return "".join(seq_lines) \ No newline at end of file diff --git a/vizfold/offline/models.py b/vizfold/offline/models.py new file mode 100644 index 0000000..3187141 --- /dev/null +++ b/vizfold/offline/models.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from pathlib import Path +from typing import Any + + +@dataclass(frozen=True) +class AttentionConnection: + """One residue-residue edge with a scalar attention weight.""" + src: int + dst: int + weight: float + + +@dataclass(frozen=True) +class AttentionSlice: + """ + One logical attention slice. + + Examples: + - MSA row attention at one layer/head + - Triangle-start attention at one layer/residue/head + """ + attention_type: str + layer: int + head: int + residue_idx: int | None + connections: list[AttentionConnection] + + def top_k(self, k: int) -> "AttentionSlice": + if k < 0: + raise ValueError("k must be >= 0") + return replace(self, connections=self.connections[:k]) + + def as_triplets(self) -> list[tuple[int, int, float]]: + return [(c.src, c.dst, c.weight) for c in self.connections] + + +@dataclass(frozen=True) +class StructureData: + """ + Minimal structure payload for offline visualization. + """ + protein_id: str + pdb_path: Path | None + pdb_text: str | None + sequence: str | None = None + + +@dataclass(frozen=True) +class TraceMetadata: + """ + Metadata for one trace source. + + heads_by_type maps: + attention_type -> layer -> list[head_idx] + + residue_indices_by_type maps: + attention_type -> layer -> list[residue_idx] + """ + protein_id: str + source_root: Path + fasta_path: Path | None = None + pdb_path: Path | None = None + sequence: str | None = None + attention_types: list[str] = field(default_factory=list) + layers_by_type: dict[str, list[int]] = field(default_factory=dict) + heads_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) + residue_indices_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) + extras: dict[str, Any] = field(default_factory=dict) \ No newline at end of file diff --git a/vizfold/offline/paths.py b/vizfold/offline/paths.py new file mode 100644 index 0000000..ee630b0 --- /dev/null +++ b/vizfold/offline/paths.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import re +from pathlib import Path + +from .exceptions import TraceNotFoundError, UnsupportedAttentionTypeError + +_MSA_ROW_RE = re.compile(r"^msa_row_attn_layer(?P\d+)\.txt$") +_TRIANGLE_START_RE = re.compile( + r"^triangle_start_attn_layer(?P\d+)_residue_idx_(?P\d+)\.txt$" +) + + +def legacy_msa_row_path(attention_dir: str | Path, layer: int) -> Path: + return Path(attention_dir) / f"msa_row_attn_layer{layer}.txt" + + +def legacy_triangle_start_path( + attention_dir: str | Path, + layer: int, + residue_idx: int, +) -> Path: + return Path(attention_dir) / f"triangle_start_attn_layer{layer}_residue_idx_{residue_idx}.txt" + + +def resolve_legacy_attention_path( + attention_dir: str | Path, + attention_type: str, + layer: int, + residue_idx: int | None = None, +) -> Path: + if attention_type == "msa_row": + path = legacy_msa_row_path(attention_dir, layer) + elif attention_type == "triangle_start": + if residue_idx is None: + raise ValueError("residue_idx is required for triangle_start attention") + path = legacy_triangle_start_path(attention_dir, layer, residue_idx) + else: + raise UnsupportedAttentionTypeError(f"Unsupported attention_type: {attention_type}") + + if not path.exists(): + raise TraceNotFoundError(f"Attention file not found: {path}") + + return path + + +def parse_legacy_attention_filename(filename: str) -> tuple[str, int, int | None] | None: + """ + Returns: + (attention_type, layer, residue_idx) + + Examples: + msa_row_attn_layer47.txt -> ("msa_row", 47, None) + triangle_start_attn_layer47_residue_idx_18.txt -> ("triangle_start", 47, 18) + """ + msa_match = _MSA_ROW_RE.match(filename) + if msa_match: + return ("msa_row", int(msa_match.group("layer")), None) + + tri_match = _TRIANGLE_START_RE.match(filename) + if tri_match: + return ( + "triangle_start", + int(tri_match.group("layer")), + int(tri_match.group("residue")), + ) + + return None \ No newline at end of file diff --git a/vizfold/offline/trace_reader.py b/vizfold/offline/trace_reader.py new file mode 100644 index 0000000..5a8f35f --- /dev/null +++ b/vizfold/offline/trace_reader.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from .models import AttentionSlice, StructureData, TraceMetadata + + +class TraceReader(ABC): + """ + Abstract reader interface for offline inference traces. + + Both legacy text-based readers and future archive readers should implement + this exact API so downstream notebooks and Streamlit apps can remain stable. + """ + + @abstractmethod + def metadata(self) -> TraceMetadata: + raise NotImplementedError + + @abstractmethod + def list_attention_types(self) -> list[str]: + raise NotImplementedError + + @abstractmethod + def list_layers(self, attention_type: str) -> list[int]: + raise NotImplementedError + + @abstractmethod + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def list_residue_indices( + self, + attention_type: str, + layer: int, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + raise NotImplementedError + + @abstractmethod + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + raise NotImplementedError + + @abstractmethod + def load_structure(self) -> StructureData: + raise NotImplementedError \ No newline at end of file From 86d2794e06297ae229d6a6a175329830e6a545c2 Mon Sep 17 00:00:00 2001 From: archiblesherman Date: Sun, 22 Mar 2026 21:01:14 -0400 Subject: [PATCH 2/5] ArchiveReader: add schema-aware metadata probing and optional Zarr support --- vizfold/offline/archive_reader.py | 346 ++++++++++++++++++++++++++++-- vizfold/offline/models.py | 21 +- 2 files changed, 339 insertions(+), 28 deletions(-) diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py index 4507183..f0144fd 100644 --- a/vizfold/offline/archive_reader.py +++ b/vizfold/offline/archive_reader.py @@ -1,36 +1,63 @@ from __future__ import annotations +import json from pathlib import Path +from typing import Any from .models import AttentionSlice, StructureData, TraceMetadata from .trace_reader import TraceReader +try: + import zarr # type: ignore +except Exception: + zarr = None + class ArchiveReader(TraceReader): """ - Placeholder reader for the future standardized archive from issue #39. + Schema-aware scaffolding for the future standardized archive from issue #39. - The point of adding this now is to lock the interface that both the - frontend and the archive writer will target. + This reader intentionally does NOT hard-code the current issue-39 prototype + layouts. Instead, it: + - detects archive kind (currently only probes Zarr when available) + - normalizes metadata into the TraceMetadata contract + - exposes capability discovery now + - defers actual tensor-path loading until the VizFold protein schema is final """ def __init__(self, archive_root: str | Path) -> None: self.archive_root = Path(archive_root) + self._zarr_root = None + self._probe = self._probe_archive() def metadata(self) -> TraceMetadata: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." + return TraceMetadata( + protein_id=self._probe["protein_id"], + source_root=self.archive_root, + fasta_path=self._path_or_none(self._probe.get("fasta_path")), + pdb_path=self._path_or_none(self._probe.get("pdb_path")), + sequence=self._probe.get("sequence"), + attention_types=self._probe["attention_types"], + layers_by_type=self._probe["layers_by_type"], + heads_by_type=self._probe["heads_by_type"], + residue_indices_by_type=self._probe["residue_indices_by_type"], + schema_version=self._probe.get("schema_version"), + archive_kind=self._probe.get("archive_kind"), + model_family=self._probe.get("model_family"), + model_version=self._probe.get("model_version"), + structure_available="structure" in self._probe["capabilities"], + capabilities=self._probe["capabilities"], + extras={ + "sequence_length": self._probe.get("sequence_length"), + "raw_metadata": self._probe.get("raw_metadata", {}), + }, ) def list_attention_types(self) -> list[str]: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." - ) + return list(self._probe["attention_types"]) def list_layers(self, attention_type: str) -> list[int]: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." - ) + return list(self._probe["layers_by_type"].get(attention_type, [])) def list_heads( self, @@ -38,17 +65,17 @@ def list_heads( layer: int, residue_idx: int | None = None, ) -> list[int]: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." - ) + # Today we ignore residue_idx at the indexing layer unless the finalized + # schema eventually needs a residue-specific head map. + return list(self._probe["heads_by_type"].get(attention_type, {}).get(layer, [])) def list_residue_indices( self, attention_type: str, layer: int, ) -> list[int]: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." + return list( + self._probe["residue_indices_by_type"].get(attention_type, {}).get(layer, []) ) def load_attention( @@ -59,8 +86,13 @@ def load_attention( residue_idx: int | None = None, top_k: int | None = None, ) -> AttentionSlice: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." + raise self._not_ready( + "load_attention", + attention_type=attention_type, + layer=layer, + head=head, + residue_idx=residue_idx, + top_k=top_k, ) def load_attention_heads( @@ -70,11 +102,279 @@ def load_attention_heads( residue_idx: int | None = None, top_k: int | None = None, ) -> dict[int, AttentionSlice]: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." + raise self._not_ready( + "load_attention_heads", + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + top_k=top_k, ) def load_structure(self) -> StructureData: - raise NotImplementedError( - "ArchiveReader is waiting on the finalized archive schema from issue #39." - ) \ No newline at end of file + raise self._not_ready("load_structure") + + def _probe_archive(self) -> dict[str, Any]: + if not self.archive_root.exists(): + raise FileNotFoundError(f"Archive path does not exist: {self.archive_root}") + + probe: dict[str, Any] = { + "protein_id": self.archive_root.stem, + "schema_version": None, + "archive_kind": None, + "model_family": None, + "model_version": None, + "sequence": None, + "sequence_length": None, + "fasta_path": None, + "pdb_path": None, + "attention_types": [], + "layers_by_type": {}, + "heads_by_type": {}, + "residue_indices_by_type": {}, + "capabilities": set(), + "raw_metadata": {}, + } + + # Sidecar metadata.json support is a safe, schema-neutral bridge. + metadata_path = self.archive_root / "metadata.json" + if metadata_path.exists(): + payload = self._load_json(metadata_path) + probe["raw_metadata"]["metadata.json"] = payload + self._merge_probe(probe, self._normalize_metadata_payload(payload)) + + # Probe Zarr archives when zarr is installed. + if self.archive_root.suffix == ".zarr": + probe["archive_kind"] = "zarr" + probe["capabilities"].add("partial_loading") + + if zarr is not None: + root = zarr.open(str(self.archive_root), mode="r") + self._zarr_root = root + + attrs = dict(getattr(root, "attrs", {})) + probe["raw_metadata"]["zarr_attrs"] = attrs + self._merge_probe(probe, self._normalize_metadata_payload(attrs)) + + self._infer_capabilities_from_zarr_root(probe, root) + + probe["attention_types"] = sorted(set(probe["attention_types"])) + probe["layers_by_type"] = { + attn_type: sorted(set(layers)) + for attn_type, layers in probe["layers_by_type"].items() + } + probe["heads_by_type"] = { + attn_type: { + int(layer): sorted(set(heads)) + for layer, heads in layer_map.items() + } + for attn_type, layer_map in probe["heads_by_type"].items() + } + probe["residue_indices_by_type"] = { + attn_type: { + int(layer): sorted(set(indices)) + for layer, indices in layer_map.items() + } + for attn_type, layer_map in probe["residue_indices_by_type"].items() + } + probe["capabilities"] = sorted(probe["capabilities"]) + + return probe + + def _normalize_metadata_payload(self, payload: dict[str, Any]) -> dict[str, Any]: + normalized: dict[str, Any] = { + "schema_version": self._first_non_empty( + payload.get("schema_version"), + payload.get("archive_version"), + payload.get("spec_version"), + ), + "archive_kind": payload.get("archive_kind"), + "model_family": self._first_non_empty( + payload.get("model_family"), + payload.get("model_name"), + ), + "model_version": payload.get("model_version"), + "protein_id": self._first_non_empty( + payload.get("protein_id"), + payload.get("trace_id"), + payload.get("sample_id"), + ), + "sequence": payload.get("sequence"), + "sequence_length": payload.get("sequence_length"), + "fasta_path": payload.get("fasta_path"), + "pdb_path": payload.get("pdb_path"), + } + + capabilities = payload.get("capabilities") + if isinstance(capabilities, (list, tuple, set)): + normalized["capabilities"] = {str(x) for x in capabilities} + + attention_index = payload.get("attention_index") + if not isinstance(attention_index, dict): + attention_index = payload + + normalized.update(self._normalize_attention_index(attention_index)) + return normalized + + def _normalize_attention_index(self, raw: dict[str, Any]) -> dict[str, Any]: + attention_types = [str(x) for x in raw.get("attention_types", [])] + + layers_by_type = self._normalize_layers_by_type(raw.get("layers_by_type", {})) + heads_by_type = self._normalize_nested_index(raw.get("heads_by_type", {})) + residue_indices_by_type = self._normalize_nested_index( + raw.get("residue_indices_by_type", {}) + ) + + if not attention_types: + attention_types = sorted( + set(layers_by_type.keys()) + | set(heads_by_type.keys()) + | set(residue_indices_by_type.keys()) + ) + + normalized: dict[str, Any] = { + "attention_types": attention_types, + "layers_by_type": layers_by_type, + "heads_by_type": heads_by_type, + "residue_indices_by_type": residue_indices_by_type, + } + + if attention_types: + normalized["capabilities"] = {"attention_index"} + + if any(residue_indices_by_type.values()): + normalized.setdefault("capabilities", set()).add("residue_indexed_attention") + + return normalized + + @staticmethod + def _normalize_layers_by_type(raw: Any) -> dict[str, list[int]]: + if not isinstance(raw, dict): + return {} + + out: dict[str, list[int]] = {} + for attention_type, layers in raw.items(): + if not isinstance(layers, (list, tuple)): + continue + out[str(attention_type)] = [int(layer) for layer in layers] + return out + + @staticmethod + def _normalize_nested_index(raw: Any) -> dict[str, dict[int, list[int]]]: + if not isinstance(raw, dict): + return {} + + out: dict[str, dict[int, list[int]]] = {} + for attention_type, layer_map in raw.items(): + if not isinstance(layer_map, dict): + continue + + normalized_layer_map: dict[int, list[int]] = {} + for layer, values in layer_map.items(): + if not isinstance(values, (list, tuple)): + continue + normalized_layer_map[int(layer)] = [int(v) for v in values] + + out[str(attention_type)] = normalized_layer_map + + return out + + def _infer_capabilities_from_zarr_root(self, probe: dict[str, Any], root: Any) -> None: + names = set() + + try: + names.update(root.group_keys()) + except Exception: + pass + + try: + names.update(root.array_keys()) + except Exception: + pass + + lowered = {name.lower() for name in names} + + if any(name in lowered for name in ("attention", "attn", "representations")): + probe["capabilities"].add("attention") + + if any(name in lowered for name in ("structure", "structures", "pdb")): + probe["capabilities"].add("structure") + + if any(name in lowered for name in ("metadata", "meta")): + probe["capabilities"].add("metadata") + + @staticmethod + def _merge_probe(base: dict[str, Any], new: dict[str, Any]) -> None: + scalar_keys = ( + "protein_id", + "schema_version", + "archive_kind", + "model_family", + "model_version", + "sequence", + "sequence_length", + "fasta_path", + "pdb_path", + ) + for key in scalar_keys: + value = new.get(key) + if value is not None: + base[key] = value + + if new.get("attention_types"): + base["attention_types"] = list( + set(base["attention_types"]) | set(new["attention_types"]) + ) + + for key in ("layers_by_type", "heads_by_type", "residue_indices_by_type"): + incoming = new.get(key, {}) + if not incoming: + continue + + if key == "layers_by_type": + for attention_type, layers in incoming.items(): + base[key].setdefault(attention_type, []) + base[key][attention_type].extend(layers) + else: + for attention_type, layer_map in incoming.items(): + base[key].setdefault(attention_type, {}) + for layer, values in layer_map.items(): + base[key][attention_type].setdefault(int(layer), []) + base[key][attention_type][int(layer)].extend(values) + + for cap in new.get("capabilities", set()): + base["capabilities"].add(cap) + + def _not_ready(self, method_name: str, **kwargs: Any) -> NotImplementedError: + details = { + "archive_kind": self._probe.get("archive_kind"), + "schema_version": self._probe.get("schema_version"), + "capabilities": self._probe.get("capabilities"), + } + if kwargs: + details["request"] = kwargs + + return NotImplementedError( + f"{method_name} is intentionally deferred until issue #39 finalizes " + f"the VizFold protein archive schema. Probe summary: {details}" + ) + + @staticmethod + def _first_non_empty(*values: Any) -> Any: + for value in values: + if value is not None and value != "": + return value + return None + + @staticmethod + def _load_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"Expected JSON object in {path}, got {type(data).__name__}") + return data + + @staticmethod + def _path_or_none(value: Any) -> Path | None: + if value is None: + return None + return Path(value) \ No newline at end of file diff --git a/vizfold/offline/models.py b/vizfold/offline/models.py index 3187141..330aa43 100644 --- a/vizfold/offline/models.py +++ b/vizfold/offline/models.py @@ -8,6 +8,7 @@ @dataclass(frozen=True) class AttentionConnection: """One residue-residue edge with a scalar attention weight.""" + src: int dst: int weight: float @@ -22,6 +23,7 @@ class AttentionSlice: - MSA row attention at one layer/head - Triangle-start attention at one layer/residue/head """ + attention_type: str layer: int head: int @@ -42,6 +44,7 @@ class StructureData: """ Minimal structure payload for offline visualization. """ + protein_id: str pdb_path: Path | None pdb_text: str | None @@ -53,19 +56,27 @@ class TraceMetadata: """ Metadata for one trace source. - heads_by_type maps: - attention_type -> layer -> list[head_idx] - - residue_indices_by_type maps: - attention_type -> layer -> list[residue_idx] + heads_by_type maps: attention_type -> layer -> list[head_idx] + residue_indices_by_type maps: attention_type -> layer -> list[residue_idx] """ + protein_id: str source_root: Path fasta_path: Path | None = None pdb_path: Path | None = None sequence: str | None = None + attention_types: list[str] = field(default_factory=list) layers_by_type: dict[str, list[int]] = field(default_factory=dict) heads_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) residue_indices_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) + + # New normalized archive fields + schema_version: str | None = None + archive_kind: str | None = None + model_family: str | None = None + model_version: str | None = None + structure_available: bool = False + capabilities: list[str] = field(default_factory=list) + extras: dict[str, Any] = field(default_factory=dict) \ No newline at end of file From 9c819c6730702b61fdcceaafc5210d9bf09e2c59 Mon Sep 17 00:00:00 2001 From: archiblesherman Date: Sun, 26 Apr 2026 18:48:34 -0400 Subject: [PATCH 3/5] Implement working Zarr ArchiveReader with tests. Needs to be tested with real data. --- environment.yml | 3 + tests/test_archive_reader_contract.py | 228 ++++++- vizfold/offline/archive_reader.py | 919 ++++++++++++++++++-------- 3 files changed, 856 insertions(+), 294 deletions(-) diff --git a/environment.yml b/environment.yml index e02d1b4..c9097a6 100644 --- a/environment.yml +++ b/environment.yml @@ -33,6 +33,9 @@ dependencies: - bioconda::kalign2 - pytorch::pytorch=2.5 - pytorch::pytorch-cuda=12.4 + - zarr=2.16 + - numcodecs + - fsspec - pip: - deepspeed==0.14.5 - dm-tree==0.1.6 diff --git a/tests/test_archive_reader_contract.py b/tests/test_archive_reader_contract.py index f81b0b5..ce5e246 100644 --- a/tests/test_archive_reader_contract.py +++ b/tests/test_archive_reader_contract.py @@ -1,29 +1,209 @@ +from __future__ import annotations + +import numpy as np import pytest from vizfold.offline import ArchiveReader +zarr = pytest.importorskip("zarr") + + +def test_archive_reader_loads_issue39_style_zarr(tmp_path): + archive_path = tmp_path / "toy.vizfold.zarr" + + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["model_version"] = "openfold-test" + metadata.attrs["config_version"] = "model_1" + metadata.attrs["sequence"] = "ACDE" + metadata.attrs["num_residues"] = 4 + metadata.attrs["num_recycles"] = 1 + metadata.create_dataset( + "residue_index", + data=np.arange(4), + shape=(4,), + ) + + attention = root.require_group("attention").require_group("triangle_start") + + # Issue-39 documented shape for triangle attention: + # (num_residues, num_residues, num_heads) + arr = np.zeros((4, 4, 2), dtype=np.float32) + arr[0, 2, 1] = 0.90 + arr[3, 1, 1] = 0.40 + arr[1, 0, 0] = 0.75 + + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reps = root.require_group("representations") + single_arr = np.ones((4, 8), dtype=np.float32) + reps.require_group("single").create_dataset( + "layer_00", + data=single_arr, + shape=single_arr.shape, + ) + pair_arr = np.ones((4, 4, 16), dtype=np.float32) + reps.require_group("pair").create_dataset( + "layer_00", + data=pair_arr, + shape=pair_arr.shape, + ) + + structure = root.require_group("structure") + coords = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [3.0, 0.0, 0.0], + ], + dtype=np.float32, + ) + + structure.create_dataset( + "atom_positions", + data=coords, + shape=coords.shape, + ) + + reader = ArchiveReader(archive_path) + + meta = reader.metadata() + assert meta.archive_kind == "zarr" + assert meta.model_version == "openfold-test" + assert meta.sequence == "ACDE" + assert meta.structure_available is True + + assert reader.list_attention_types() == ["triangle_start"] + assert reader.list_layers("triangle_start") == [0] + assert reader.list_heads("triangle_start", 0) == [0, 1] + + loaded = reader.load_attention( + attention_type="triangle_start", + layer=0, + head=1, + top_k=2, + ) + + assert loaded.attention_type == "triangle_start" + assert loaded.layer == 0 + assert loaded.head == 1 + assert loaded.as_triplets()[0] == (0, 2, pytest.approx(0.90)) + assert loaded.as_triplets()[1] == (3, 1, pytest.approx(0.40)) + + all_heads = reader.load_attention_heads("triangle_start", 0, top_k=1) + assert sorted(all_heads.keys()) == [0, 1] + + single = reader.load_single_representation(0) + pair = reader.load_pair_representation(0) + + assert single.shape == (4, 8) + assert pair.shape == (4, 4, 16) + + structure_data = reader.load_structure() + assert structure_data.sequence == "ACDE" + assert structure_data.pdb_text is not None + assert "ATOM" in structure_data.pdb_text + +def test_archive_reader_missing_structure_does_not_crash(tmp_path): + archive_path = tmp_path / "no_structure.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + metadata.attrs["model_version"] = "openfold-test" + + attention = root.require_group("attention").require_group("triangle_start") + + arr = np.ones((4, 4, 1), dtype=np.float32) + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reader = ArchiveReader(archive_path) + + meta = reader.metadata() + assert meta.structure_available is False + assert "attention" in meta.capabilities + assert "structure" not in meta.capabilities + + structure = reader.load_structure() + assert structure.sequence == "ACDE" + assert structure.pdb_text is None + + +def test_archive_reader_rejects_bad_head(tmp_path): + archive_path = tmp_path / "bad_head.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + + attention = root.require_group("attention").require_group("triangle_start") + + arr = np.ones((4, 4, 2), dtype=np.float32) + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reader = ArchiveReader(archive_path) + + assert reader.list_heads("triangle_start", 0) == [0, 1] + + with pytest.raises(IndexError): + reader.load_attention("triangle_start", layer=0, head=5) + + +def test_archive_reader_handles_residue_specific_4d_attention(tmp_path): + archive_path = tmp_path / "residue_attention.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + + attention = root.require_group("attention").require_group("triangle_start") + + # Shape: (heads, residue_index, src_residue, dst_residue) + arr = np.zeros((2, 4, 4, 4), dtype=np.float32) + arr[1, 2, 0, 3] = 0.95 + arr[1, 2, 1, 1] = 0.50 + arr[0, 0, 2, 2] = 0.80 + + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(1, 1, 4, 4), + ) + + reader = ArchiveReader(archive_path) + + assert reader.list_attention_types() == ["triangle_start"] + assert reader.list_layers("triangle_start") == [0] + assert reader.list_heads("triangle_start", 0) == [0, 1] + assert reader.list_residue_indices("triangle_start", 0) == [0, 1, 2, 3] + + loaded = reader.load_attention( + attention_type="triangle_start", + layer=0, + head=1, + residue_idx=2, + top_k=2, + ) -def test_archive_reader_constructor() -> None: - reader = ArchiveReader("dummy_archive") - assert str(reader.archive_root).endswith("dummy_archive") - - -@pytest.mark.parametrize( - "method_name,args,kwargs", - [ - ("metadata", tuple(), {}), - ("list_attention_types", tuple(), {}), - ("list_layers", ("msa_row",), {}), - ("list_heads", ("msa_row", 47), {}), - ("list_residue_indices", ("triangle_start", 47), {}), - ("load_attention", ("msa_row", 47, 0), {}), - ("load_attention_heads", ("msa_row", 47), {}), - ("load_structure", tuple(), {}), - ], -) -def test_archive_reader_methods_raise_not_implemented(method_name, args, kwargs) -> None: - reader = ArchiveReader("dummy_archive") - method = getattr(reader, method_name) - - with pytest.raises(NotImplementedError): - method(*args, **kwargs) \ No newline at end of file + assert loaded.residue_idx == 2 + assert loaded.as_triplets()[0] == (0, 3, pytest.approx(0.95)) + assert loaded.as_triplets()[1] == (1, 1, pytest.approx(0.50)) + \ No newline at end of file diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py index f0144fd..fcb1e23 100644 --- a/vizfold/offline/archive_reader.py +++ b/vizfold/offline/archive_reader.py @@ -1,63 +1,149 @@ from __future__ import annotations -import json +import re from pathlib import Path from typing import Any -from .models import AttentionSlice, StructureData, TraceMetadata +import numpy as np + +from .models import AttentionConnection, AttentionSlice, StructureData, TraceMetadata from .trace_reader import TraceReader try: import zarr # type: ignore -except Exception: +except Exception: # pragma: no cover zarr = None +_LAYER_RE = re.compile(r"(?:^|[/_\-])layer[_\-]?(\d+)$|^(\d+)$") + + class ArchiveReader(TraceReader): """ - Schema-aware scaffolding for the future standardized archive from issue #39. - - This reader intentionally does NOT hard-code the current issue-39 prototype - layouts. Instead, it: - - detects archive kind (currently only probes Zarr when available) - - normalizes metadata into the TraceMetadata contract - - exposes capability discovery now - - defers actual tensor-path loading until the VizFold protein schema is final + Working reader for standardized VizFold/OpenFold Zarr trace archives. + + Primary supported issue-39 layout: + + metadata/ + representations/single/layer_00 + representations/pair/layer_00 + attention/triangle_start/layer_00 + structure/atom_positions + + Also tolerates prototype layouts like: + + attention/layer_0 + attention # shape [layers, heads, N, N] + inputs/sequence + outputs/coordinates + structure_pdb """ def __init__(self, archive_root: str | Path) -> None: self.archive_root = Path(archive_root) - self._zarr_root = None - self._probe = self._probe_archive() + self._store: Any | None = None + self._root = self._open_root(self.archive_root) + self._metadata_cache: TraceMetadata | None = None def metadata(self) -> TraceMetadata: - return TraceMetadata( - protein_id=self._probe["protein_id"], + if self._metadata_cache is not None: + return self._metadata_cache + + sequence = self.get_sequence() + attention_types = self.list_attention_types() + layers_by_type = { + attn_type: self.list_layers(attn_type) + for attn_type in attention_types + } + heads_by_type = { + attn_type: { + layer: self.list_heads(attn_type, layer) + for layer in layers + } + for attn_type, layers in layers_by_type.items() + } + + residue_indices_by_type: dict[str, dict[int, list[int]]] = {} + for attn_type, layers in layers_by_type.items(): + for layer in layers: + indices = self.list_residue_indices(attn_type, layer) + if indices: + residue_indices_by_type.setdefault(attn_type, {})[layer] = indices + + has_structure = ( + self.get_pdb_string() is not None + or self._find_first_array( + "structure/atom_positions", + "outputs/coordinates", + "coordinates", + ) + is not None + ) + + capabilities = ["metadata", "partial_loading"] + if attention_types: + capabilities.append("attention") + if has_structure: + capabilities.append("structure") + + self._metadata_cache = TraceMetadata( + protein_id=str(self._root.attrs.get("protein_id", self.archive_root.stem)), source_root=self.archive_root, - fasta_path=self._path_or_none(self._probe.get("fasta_path")), - pdb_path=self._path_or_none(self._probe.get("pdb_path")), - sequence=self._probe.get("sequence"), - attention_types=self._probe["attention_types"], - layers_by_type=self._probe["layers_by_type"], - heads_by_type=self._probe["heads_by_type"], - residue_indices_by_type=self._probe["residue_indices_by_type"], - schema_version=self._probe.get("schema_version"), - archive_kind=self._probe.get("archive_kind"), - model_family=self._probe.get("model_family"), - model_version=self._probe.get("model_version"), - structure_available="structure" in self._probe["capabilities"], - capabilities=self._probe["capabilities"], + sequence=sequence, + attention_types=attention_types, + layers_by_type=layers_by_type, + heads_by_type=heads_by_type, + residue_indices_by_type=residue_indices_by_type, + schema_version=self._read_text_metadata( + "schema_version", + "archive_version", + "spec_version", + ), + archive_kind="zarr", + model_family=self._read_text_metadata("model_family", "model_name"), + model_version=self._read_text_metadata("model_version", "config_version"), + structure_available=has_structure, + capabilities=capabilities, extras={ - "sequence_length": self._probe.get("sequence_length"), - "raw_metadata": self._probe.get("raw_metadata", {}), + "num_residues": self._read_int_metadata("num_residues"), + "num_recycles": self._read_int_metadata("num_recycles"), + "arrays": self.list_all_arrays(), }, ) + return self._metadata_cache def list_attention_types(self) -> list[str]: - return list(self._probe["attention_types"]) + types: set[str] = set() + + if "attention" in self._root: + node = self._root["attention"] + + if self._is_group(node): + for key in self._keys(node): + child = node[key] + if self._is_group(child): + types.add(str(key)) + + direct_layer_arrays = [ + key + for key in self._keys(node) + if self._is_array(node[key]) + and self._parse_layer_key(key) is not None + ] + if direct_layer_arrays: + types.add("triangle_start") + + elif self._is_array(node): + types.add("attention") + + for name, shape in self.list_attention_arrays().items(): + if name != "attention" and not name.startswith("attention/"): + types.add(name) + + return sorted(types) def list_layers(self, attention_type: str) -> list[int]: - return list(self._probe["layers_by_type"].get(attention_type, [])) + return sorted(self._discover_layer_numbers(attention_type)) def list_heads( self, @@ -65,18 +151,16 @@ def list_heads( layer: int, residue_idx: int | None = None, ) -> list[int]: - # Today we ignore residue_idx at the indexing layer unless the finalized - # schema eventually needs a residue-specific head map. - return list(self._probe["heads_by_type"].get(attention_type, {}).get(layer, [])) + arr = self._load_attention_tensor(attention_type, layer) + return list(range(self._num_heads_from_shape(arr.shape))) - def list_residue_indices( - self, - attention_type: str, - layer: int, - ) -> list[int]: - return list( - self._probe["residue_indices_by_type"].get(attention_type, {}).get(layer, []) - ) + def list_residue_indices(self, attention_type: str, layer: int) -> list[int]: + arr = self._load_attention_tensor(attention_type, layer) + if arr.ndim != 4: + return [] + + n = self._residue_count_from_attention_shape(arr.shape) + return list(range(n)) if n is not None else [] def load_attention( self, @@ -86,13 +170,16 @@ def load_attention( residue_idx: int | None = None, top_k: int | None = None, ) -> AttentionSlice: - raise self._not_ready( - "load_attention", + arr = self._load_attention_tensor(attention_type, layer) + matrix = self._slice_head_matrix(arr, head=head, residue_idx=residue_idx) + connections = self._matrix_to_connections(matrix, top_k=top_k) + + return AttentionSlice( attention_type=attention_type, layer=layer, head=head, residue_idx=residue_idx, - top_k=top_k, + connections=connections, ) def load_attention_heads( @@ -102,279 +189,571 @@ def load_attention_heads( residue_idx: int | None = None, top_k: int | None = None, ) -> dict[int, AttentionSlice]: - raise self._not_ready( - "load_attention_heads", - attention_type=attention_type, - layer=layer, - residue_idx=residue_idx, - top_k=top_k, - ) + return { + head: self.load_attention( + attention_type=attention_type, + layer=layer, + head=head, + residue_idx=residue_idx, + top_k=top_k, + ) + for head in self.list_heads(attention_type, layer, residue_idx) + } def load_structure(self) -> StructureData: - raise self._not_ready("load_structure") - - def _probe_archive(self) -> dict[str, Any]: - if not self.archive_root.exists(): - raise FileNotFoundError(f"Archive path does not exist: {self.archive_root}") - - probe: dict[str, Any] = { - "protein_id": self.archive_root.stem, - "schema_version": None, - "archive_kind": None, - "model_family": None, - "model_version": None, - "sequence": None, - "sequence_length": None, - "fasta_path": None, - "pdb_path": None, - "attention_types": [], - "layers_by_type": {}, - "heads_by_type": {}, - "residue_indices_by_type": {}, - "capabilities": set(), - "raw_metadata": {}, + pdb_text = self.get_pdb_string() + sequence = self.get_sequence() + + if pdb_text is None: + coords = self._find_first_array( + "structure/atom_positions", + "outputs/coordinates", + "coordinates", + ) + if coords is not None: + pdb_text = self._coords_to_pdb(np.asarray(coords), sequence) + + protein_id = ( + self._metadata_cache.protein_id + if self._metadata_cache is not None + else str(self._root.attrs.get("protein_id", self.archive_root.stem)) + ) + + return StructureData( + protein_id=protein_id, + pdb_path=None, + pdb_text=pdb_text, + sequence=sequence, + ) + + def load_single_representation(self, layer: int) -> np.ndarray: + node = self._resolve_layered_array( + layer, + "representations/single", + "single", + "activations", + dataset_names=("single_repr", "activation", "values"), + ) + if node is None: + raise KeyError(f"Single representation not found for layer {layer}") + return self._to_numpy(node) + + def load_pair_representation(self, layer: int) -> np.ndarray: + node = self._resolve_layered_array( + layer, + "representations/pair", + "pair", + "activations", + dataset_names=("pair_repr", "values"), + ) + if node is None: + raise KeyError(f"Pair representation not found for layer {layer}") + return self._to_numpy(node) + + def list_all_arrays(self) -> dict[str, tuple[int, ...]]: + return { + name: tuple(int(x) for x in array.shape) + for name, array in self._walk_arrays(self._root) } - # Sidecar metadata.json support is a safe, schema-neutral bridge. - metadata_path = self.archive_root / "metadata.json" - if metadata_path.exists(): - payload = self._load_json(metadata_path) - probe["raw_metadata"]["metadata.json"] = payload - self._merge_probe(probe, self._normalize_metadata_payload(payload)) + def list_attention_arrays(self) -> dict[str, tuple[int, ...]]: + out: dict[str, tuple[int, ...]] = {} + for name, shape in self.list_all_arrays().items(): + lowered = name.lower() + if self._is_attention_shape(shape) and ( + "att" in lowered or "triangle" in lowered or name == "attention" + ): + out[name] = shape + return out - # Probe Zarr archives when zarr is installed. - if self.archive_root.suffix == ".zarr": - probe["archive_kind"] = "zarr" - probe["capabilities"].add("partial_loading") + def get_sequence(self) -> str | None: + value = self._read_text_metadata("sequence") + if value: + return value + + for path in ("inputs/sequence", "sequence"): + arr = self._get_array(path) + if arr is not None: + text = self._array_to_text(arr) + if text: + return text + return None - if zarr is not None: - root = zarr.open(str(self.archive_root), mode="r") - self._zarr_root = root + def get_pdb_string(self) -> str | None: + for path in ( + "structure_pdb", + "structure/pdb", + "structure/pdb_text", + "outputs/structure_pdb", + ): + arr = self._get_array(path) + if arr is not None: + text = self._array_to_text(arr) + if text: + return text + return None - attrs = dict(getattr(root, "attrs", {})) - probe["raw_metadata"]["zarr_attrs"] = attrs - self._merge_probe(probe, self._normalize_metadata_payload(attrs)) + def _open_root(self, path: Path) -> Any: + if zarr is None: + raise ImportError("ArchiveReader requires zarr. Install with `pip install zarr`.") + + if not path.exists(): + raise FileNotFoundError(f"Archive path does not exist: {path}") + + if path.suffix == ".zip": + self._store = zarr.ZipStore(str(path), mode="r") + return zarr.open_group(store=self._store, mode="r") + + return zarr.open_group(str(path), mode="r") + + def _discover_layer_numbers(self, attention_type: str) -> set[int]: + layers: set[int] = set() + + group = self._attention_group_for_type(attention_type) + if group is not None: + for key in self._keys(group): + child = group[key] + layer = self._parse_layer_key(key) + if layer is not None and self._is_array_or_group(child): + layers.add(layer) + + arr = self._attention_array_for_type(attention_type) + if arr is not None and arr.ndim == 4 and self._looks_like_layered_attention(arr.shape): + layers.update(range(int(arr.shape[0]))) + + if layers: + return layers + + for name in self.list_attention_arrays(): + layer = self._parse_layer_key(name.split("/")[-1]) + if layer is not None: + layers.add(layer) + + return layers + + def _load_attention_tensor(self, attention_type: str, layer: int) -> np.ndarray: + layered = self._attention_array_for_type(attention_type) + if ( + layered is not None + and layered.ndim == 4 + and self._looks_like_layered_attention(layered.shape) + ): + if layer >= layered.shape[0]: + raise IndexError(f"Layer {layer} out of range for {attention_type}") + return np.asarray(layered[layer]) + + group = self._attention_group_for_type(attention_type) + if group is not None: + for layer_key in self._layer_key_candidates(layer): + if layer_key not in group: + continue - self._infer_capabilities_from_zarr_root(probe, root) + node = group[layer_key] + if self._is_array(node): + return self._to_numpy(node) - probe["attention_types"] = sorted(set(probe["attention_types"])) - probe["layers_by_type"] = { - attn_type: sorted(set(layers)) - for attn_type, layers in probe["layers_by_type"].items() - } - probe["heads_by_type"] = { - attn_type: { - int(layer): sorted(set(heads)) - for layer, heads in layer_map.items() - } - for attn_type, layer_map in probe["heads_by_type"].items() - } - probe["residue_indices_by_type"] = { - attn_type: { - int(layer): sorted(set(indices)) - for layer, indices in layer_map.items() - } - for attn_type, layer_map in probe["residue_indices_by_type"].items() - } - probe["capabilities"] = sorted(probe["capabilities"]) + if self._is_group(node): + for dataset_name in ("attention", "values", "heads"): + if dataset_name in node and self._is_array(node[dataset_name]): + return np.asarray(node[dataset_name]) - return probe + for layer_key in self._layer_key_candidates(layer): + arr = self._get_array(f"attention/{layer_key}") + if arr is not None: + return np.asarray(arr) - def _normalize_metadata_payload(self, payload: dict[str, Any]) -> dict[str, Any]: - normalized: dict[str, Any] = { - "schema_version": self._first_non_empty( - payload.get("schema_version"), - payload.get("archive_version"), - payload.get("spec_version"), - ), - "archive_kind": payload.get("archive_kind"), - "model_family": self._first_non_empty( - payload.get("model_family"), - payload.get("model_name"), - ), - "model_version": payload.get("model_version"), - "protein_id": self._first_non_empty( - payload.get("protein_id"), - payload.get("trace_id"), - payload.get("sample_id"), - ), - "sequence": payload.get("sequence"), - "sequence_length": payload.get("sequence_length"), - "fasta_path": payload.get("fasta_path"), - "pdb_path": payload.get("pdb_path"), - } + arr = self._get_array(attention_type) + if arr is not None: + data = np.asarray(arr) + if data.ndim == 4 and self._looks_like_layered_attention(data.shape): + return data[layer] + return data - capabilities = payload.get("capabilities") - if isinstance(capabilities, (list, tuple, set)): - normalized["capabilities"] = {str(x) for x in capabilities} + raise KeyError(f"Attention not found for type={attention_type}, layer={layer}") - attention_index = payload.get("attention_index") - if not isinstance(attention_index, dict): - attention_index = payload + def _attention_group_for_type(self, attention_type: str) -> Any | None: + candidates = [] - normalized.update(self._normalize_attention_index(attention_index)) - return normalized + if attention_type != "attention": + candidates.append(f"attention/{attention_type}") - def _normalize_attention_index(self, raw: dict[str, Any]) -> dict[str, Any]: - attention_types = [str(x) for x in raw.get("attention_types", [])] + candidates.append("attention") - layers_by_type = self._normalize_layers_by_type(raw.get("layers_by_type", {})) - heads_by_type = self._normalize_nested_index(raw.get("heads_by_type", {})) - residue_indices_by_type = self._normalize_nested_index( - raw.get("residue_indices_by_type", {}) - ) + for path in candidates: + node = self._get_node(path) + if node is not None and self._is_group(node): + return node - if not attention_types: - attention_types = sorted( - set(layers_by_type.keys()) - | set(heads_by_type.keys()) - | set(residue_indices_by_type.keys()) - ) + return None - normalized: dict[str, Any] = { - "attention_types": attention_types, - "layers_by_type": layers_by_type, - "heads_by_type": heads_by_type, - "residue_indices_by_type": residue_indices_by_type, - } + def _attention_array_for_type(self, attention_type: str) -> Any | None: + candidates = [] - if attention_types: - normalized["capabilities"] = {"attention_index"} + if attention_type != "attention": + candidates.append(f"attention/{attention_type}") - if any(residue_indices_by_type.values()): - normalized.setdefault("capabilities", set()).add("residue_indexed_attention") + candidates.extend((attention_type, "attention")) - return normalized + for path in candidates: + node = self._get_node(path) + if node is not None and self._is_array(node): + return node - @staticmethod - def _normalize_layers_by_type(raw: Any) -> dict[str, list[int]]: - if not isinstance(raw, dict): - return {} + return None - out: dict[str, list[int]] = {} - for attention_type, layers in raw.items(): - if not isinstance(layers, (list, tuple)): + def _resolve_layered_array( + self, + layer: int, + *groups: str, + dataset_names: tuple[str, ...], + ) -> Any | None: + for group_path in groups: + group = self._get_node(group_path) + if group is None or not self._is_group(group): continue - out[str(attention_type)] = [int(layer) for layer in layers] - return out + + for layer_key in self._layer_key_candidates(layer): + if layer_key not in group: + continue + + node = group[layer_key] + + if self._is_array(node): + return node + + if self._is_group(node): + for name in dataset_names: + if name in node and self._is_array(node[name]): + return node[name] + + return None + + @staticmethod + def _slice_head_matrix( + arr: np.ndarray, + head: int, + residue_idx: int | None, + ) -> np.ndarray: + if arr.ndim == 2: + if head != 0: + raise IndexError("2-D attention matrix only has head 0") + return arr + + if arr.ndim == 3: + axis = ArchiveReader._infer_head_axis(arr.shape) + + if axis == 0: + return arr[head, :, :] + + if axis == 2: + return arr[:, :, head] + + raise ValueError(f"Cannot infer head axis from attention shape {arr.shape}") + + if arr.ndim == 4: + axis = ArchiveReader._infer_head_axis(arr.shape) + + if axis == 0: + cube = arr[head] + + elif axis == 3: + cube = arr[:, :, :, head] + + else: + raise ValueError(f"Cannot infer head axis from attention shape {arr.shape}") + + if residue_idx is None: + return np.asarray(cube).mean(axis=0) + + return cube[residue_idx, :, :] + + raise ValueError(f"Unsupported attention tensor shape {arr.shape}") @staticmethod - def _normalize_nested_index(raw: Any) -> dict[str, dict[int, list[int]]]: - if not isinstance(raw, dict): - return {} + def _matrix_to_connections( + matrix: np.ndarray, + top_k: int | None, + ) -> list[AttentionConnection]: + mat = np.asarray(matrix, dtype=float) + + if mat.ndim != 2: + raise ValueError(f"Expected 2-D attention matrix, got {mat.shape}") + + flat = mat.reshape(-1) + + if flat.size == 0 or top_k == 0: + return [] - out: dict[str, dict[int, list[int]]] = {} - for attention_type, layer_map in raw.items(): - if not isinstance(layer_map, dict): + k = flat.size if top_k is None else min(int(top_k), flat.size) + + if k < flat.size: + idx = np.argpartition(-flat, k - 1)[:k] + idx = idx[np.argsort(-flat[idx])] + else: + idx = np.argsort(-flat) + + n_cols = mat.shape[1] + connections: list[AttentionConnection] = [] + + for flat_idx in idx: + weight = float(flat[flat_idx]) + + if np.isnan(weight): continue - normalized_layer_map: dict[int, list[int]] = {} - for layer, values in layer_map.items(): - if not isinstance(values, (list, tuple)): - continue - normalized_layer_map[int(layer)] = [int(v) for v in values] + src = int(flat_idx // n_cols) + dst = int(flat_idx % n_cols) - out[str(attention_type)] = normalized_layer_map + connections.append( + AttentionConnection( + src=src, + dst=dst, + weight=weight, + ) + ) - return out + return connections - def _infer_capabilities_from_zarr_root(self, probe: dict[str, Any], root: Any) -> None: - names = set() + @staticmethod + def _infer_head_axis(shape: tuple[int, ...]) -> int: + if len(shape) == 3: + if shape[0] <= 64 and shape[1] == shape[2]: + return 0 + if shape[2] <= 64 and shape[0] == shape[1]: + return 2 - try: - names.update(root.group_keys()) - except Exception: - pass + if len(shape) == 4: + if shape[0] <= 64 and shape[1] == shape[2] == shape[3]: + return 0 + if shape[3] <= 64 and shape[0] == shape[1] == shape[2]: + return 3 - try: - names.update(root.array_keys()) - except Exception: - pass + raise ValueError(f"Cannot infer head axis from shape {shape}") - lowered = {name.lower() for name in names} + @staticmethod + def _num_heads_from_shape(shape: tuple[int, ...]) -> int: + if len(shape) == 2: + return 1 + + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[axis]) + + @staticmethod + def _residue_count_from_attention_shape(shape: tuple[int, ...]) -> int | None: + if len(shape) == 2 and shape[0] == shape[1]: + return int(shape[0]) - if any(name in lowered for name in ("attention", "attn", "representations")): - probe["capabilities"].add("attention") + if len(shape) == 3: + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[1] if axis == 0 else shape[0]) - if any(name in lowered for name in ("structure", "structures", "pdb")): - probe["capabilities"].add("structure") + if len(shape) == 4: + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[1] if axis == 0 else shape[0]) - if any(name in lowered for name in ("metadata", "meta")): - probe["capabilities"].add("metadata") + return None @staticmethod - def _merge_probe(base: dict[str, Any], new: dict[str, Any]) -> None: - scalar_keys = ( - "protein_id", - "schema_version", - "archive_kind", - "model_family", - "model_version", - "sequence", - "sequence_length", - "fasta_path", - "pdb_path", - ) - for key in scalar_keys: - value = new.get(key) - if value is not None: - base[key] = value - - if new.get("attention_types"): - base["attention_types"] = list( - set(base["attention_types"]) | set(new["attention_types"]) + def _looks_like_layered_attention(shape: tuple[int, ...]) -> bool: + return len(shape) == 4 and shape[1] <= 64 and shape[2] == shape[3] + + @staticmethod + def _is_attention_shape(shape: tuple[int, ...]) -> bool: + if len(shape) == 2: + return shape[0] == shape[1] + + if len(shape) == 3: + return ( + shape[0] <= 64 + and shape[1] == shape[2] + ) or ( + shape[2] <= 64 + and shape[0] == shape[1] ) - for key in ("layers_by_type", "heads_by_type", "residue_indices_by_type"): - incoming = new.get(key, {}) - if not incoming: - continue + if len(shape) == 4: + return ( + shape[0] <= 64 + and shape[1] == shape[2] == shape[3] + ) or ( + shape[1] <= 64 + and shape[2] == shape[3] + ) or ( + shape[3] <= 64 + and shape[0] == shape[1] == shape[2] + ) - if key == "layers_by_type": - for attention_type, layers in incoming.items(): - base[key].setdefault(attention_type, []) - base[key][attention_type].extend(layers) - else: - for attention_type, layer_map in incoming.items(): - base[key].setdefault(attention_type, {}) - for layer, values in layer_map.items(): - base[key][attention_type].setdefault(int(layer), []) - base[key][attention_type][int(layer)].extend(values) - - for cap in new.get("capabilities", set()): - base["capabilities"].add(cap) - - def _not_ready(self, method_name: str, **kwargs: Any) -> NotImplementedError: - details = { - "archive_kind": self._probe.get("archive_kind"), - "schema_version": self._probe.get("schema_version"), - "capabilities": self._probe.get("capabilities"), - } - if kwargs: - details["request"] = kwargs + return False + + def _read_text_metadata(self, *keys: str) -> str | None: + for key in keys: + if key in self._root.attrs and self._root.attrs[key] is not None: + return str(self._root.attrs[key]) + + arr = self._get_array(f"metadata/{key}") + if arr is not None: + text = self._array_to_text(arr) + if text: + return text + + meta = self._get_node("metadata") + if meta is not None and self._is_group(meta) and key in meta.attrs: + return str(meta.attrs[key]) + + return None + + def _read_int_metadata(self, key: str) -> int | None: + text = self._read_text_metadata(key) + + if text is None: + return None + + try: + return int(float(text)) + except ValueError: + return None + + def _find_first_array(self, *paths: str) -> np.ndarray | None: + for path in paths: + arr = self._get_array(path) + if arr is not None: + return self._to_numpy(arr) + + return None + + def _get_array(self, path: str) -> Any | None: + node = self._get_node(path) + + if node is not None and self._is_array(node): + return node + + return None + + def _get_node(self, path: str) -> Any | None: + node = self._root + + if not path: + return node + + for part in path.strip("/").split("/"): + if not self._is_group(node) or part not in node: + return None + + node = node[part] + + return node + + @staticmethod + def _keys(group: Any) -> list[str]: + return sorted(str(k) for k in group.keys()) + + @staticmethod + def _is_array(node: Any) -> bool: + return hasattr(node, "shape") and hasattr(node, "dtype") + + @staticmethod + def _is_group(node: Any) -> bool: + return hasattr(node, "keys") and not ArchiveReader._is_array(node) + + @staticmethod + def _is_array_or_group(node: Any) -> bool: + return ArchiveReader._is_array(node) or ArchiveReader._is_group(node) + + def _walk_arrays(self, group: Any, prefix: str = ""): + for key in self._keys(group): + node = group[key] + name = f"{prefix}/{key}" if prefix else key + + if self._is_array(node): + yield name, node + + elif self._is_group(node): + yield from self._walk_arrays(node, name) + + @staticmethod + def _parse_layer_key(key: str) -> int | None: + match = _LAYER_RE.search(key) - return NotImplementedError( - f"{method_name} is intentionally deferred until issue #39 finalizes " - f"the VizFold protein archive schema. Probe summary: {details}" + if not match: + return None + + value = match.group(1) or match.group(2) + return int(value) + + @staticmethod + def _layer_key_candidates(layer: int) -> tuple[str, ...]: + return ( + f"layer_{layer:02d}", + f"layer_{layer}", + str(layer), ) @staticmethod - def _first_non_empty(*values: Any) -> Any: - for value in values: - if value is not None and value != "": - return value + def _array_to_text(array: Any) -> str | None: + value = ArchiveReader._to_numpy(array) + + if value.shape == (): + item = value.item() + + if isinstance(item, bytes): + return item.decode("utf-8") + + return str(item) + + if value.dtype.kind in {"U", "S", "O"}: + parts = [] + + for item in value.reshape(-1): + if isinstance(item, bytes): + parts.append(item.decode("utf-8")) + else: + parts.append(str(item)) + + return "".join(parts) + + if value.dtype == np.uint8: + try: + return bytes(value.reshape(-1)).decode("utf-8") + except UnicodeDecodeError: + return None + return None @staticmethod - def _load_json(path: Path) -> dict[str, Any]: - with path.open("r", encoding="utf-8") as f: - data = json.load(f) - if not isinstance(data, dict): - raise ValueError(f"Expected JSON object in {path}, got {type(data).__name__}") - return data + def _coords_to_pdb(coords: np.ndarray, sequence: str | None = None) -> str: + arr = np.asarray(coords, dtype=float) + + if arr.ndim == 3 and arr.shape[-1] == 3: + arr = arr[:, 1, :] + + elif ( + arr.ndim == 2 + and arr.shape[1] == 3 + and sequence + and arr.shape[0] == len(sequence) * 37 + ): + arr = arr.reshape(len(sequence), 37, 3)[:, 1, :] + + elif arr.ndim != 2 or arr.shape[1] != 3: + raise ValueError(f"Cannot convert coordinates with shape {coords.shape} to PDB") + + lines = [] + + for idx, (x, y, z) in enumerate(arr, start=1): + res = sequence[idx - 1] if sequence and idx - 1 < len(sequence) else "X" + lines.append( + f"ATOM {idx:5d} CA {res:>3s} A{idx:4d} " + f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C" + ) + lines.append("END") + return "\n".join(lines) + "\n" + @staticmethod - def _path_or_none(value: Any) -> Path | None: - if value is None: - return None - return Path(value) \ No newline at end of file + def _to_numpy(array: Any) -> np.ndarray: + if hasattr(array, "shape") and hasattr(array, "__getitem__"): + if tuple(array.shape) == (): + return np.asarray(array[()]) + return np.asarray(array[:]) + + return np.asarray(array) \ No newline at end of file From 06d0361b3fc4cc08f3c208c7db803e2a52378baf Mon Sep 17 00:00:00 2001 From: archiblesherman Date: Tue, 28 Apr 2026 22:37:10 -0400 Subject: [PATCH 4/5] Merged Frontend UI, Integration Layer, and Backend Reader. --- .gitignore | 7 + .streamlit/config.toml | 2 + vizfold/offline/archive_reader.py | 30 +- webui/app.py | 418 ++++++++++++++++++++++++++ webui/components/__init__.py | 0 webui/components/arc_diagram.py | 76 +++++ webui/components/attention_heatmap.py | 69 +++++ webui/components/structure_viewer.py | 86 ++++++ webui/make_sample_trace.py | 135 +++++++++ webui/trace_reader.py | 294 ++++++++++++++++++ webui/visualization_adapter.py | 108 +++++++ 11 files changed, 1223 insertions(+), 2 deletions(-) create mode 100644 .streamlit/config.toml create mode 100644 webui/app.py create mode 100644 webui/components/__init__.py create mode 100644 webui/components/arc_diagram.py create mode 100644 webui/components/attention_heatmap.py create mode 100644 webui/components/structure_viewer.py create mode 100644 webui/make_sample_trace.py create mode 100644 webui/trace_reader.py create mode 100644 webui/visualization_adapter.py diff --git a/.gitignore b/.gitignore index 3f1d838..3681434 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,10 @@ cutlass/ *.sto *.a3m *.hhr + +# generated archive test outputs +*.zarr/ +*.zip +vizfold_archive.zarr/ +vizfold_archive.zip +create_archive.py \ No newline at end of file diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 0000000..3fa35e6 --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,2 @@ +[server] +maxUploadSize = 1000 \ No newline at end of file diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py index fcb1e23..fd695b0 100644 --- a/vizfold/offline/archive_reader.py +++ b/vizfold/offline/archive_reader.py @@ -1,5 +1,9 @@ from __future__ import annotations +import shutil +import tempfile +import zipfile + import re from pathlib import Path from typing import Any @@ -42,6 +46,7 @@ class ArchiveReader(TraceReader): def __init__(self, archive_root: str | Path) -> None: self.archive_root = Path(archive_root) self._store: Any | None = None + self._extracted_archive_root: Path | None = None self._root = self._open_root(self.archive_root) self._metadata_cache: TraceMetadata | None = None @@ -301,8 +306,29 @@ def _open_root(self, path: Path) -> Any: raise FileNotFoundError(f"Archive path does not exist: {path}") if path.suffix == ".zip": - self._store = zarr.ZipStore(str(path), mode="r") - return zarr.open_group(store=self._store, mode="r") + extracted_dir = Path(tempfile.mkdtemp(prefix="vizfold_zarr_")) + + with zipfile.ZipFile(path, "r") as zf: + zf.extractall(extracted_dir) + + # Case 1: user zipped the contents of the .zarr folder + if (extracted_dir / "zarr.json").exists() or (extracted_dir / ".zgroup").exists(): + self._extracted_archive_root = extracted_dir + return zarr.open_group(str(extracted_dir), mode="r") + + # Case 2: user zipped the .zarr folder itself + zarr_dirs = list(extracted_dir.glob("*.zarr")) + if zarr_dirs: + self._extracted_archive_root = zarr_dirs[0] + return zarr.open_group(str(zarr_dirs[0]), mode="r") + + # Case 3: zip contains one top-level folder + child_dirs = [p for p in extracted_dir.iterdir() if p.is_dir()] + if len(child_dirs) == 1: + self._extracted_archive_root = child_dirs[0] + return zarr.open_group(str(child_dirs[0]), mode="r") + + raise ValueError(f"Could not find Zarr archive root inside zip file: {path}") return zarr.open_group(str(path), mode="r") diff --git a/webui/app.py b/webui/app.py new file mode 100644 index 0000000..868501a --- /dev/null +++ b/webui/app.py @@ -0,0 +1,418 @@ +""" +VizFold — Streamlit UI for offline protein model internals exploration. + +Run: + streamlit run webui/app.py +""" + +import hashlib +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +import streamlit as st + +import tempfile + +from components.structure_viewer import render_structure +from components.attention_heatmap import render_heatmap +from components.arc_diagram import render_arc_diagram + +from trace_reader import TraceReader +from vizfold.offline import ArchiveReader +from visualization_adapter import ( + flatten_attention_heads, + build_visualization_payload, +) +# ── Page config ─────────────────────────────────────────────────────────────── + +st.set_page_config( + page_title="VizFold", + page_icon="🔬", + layout="wide", + initial_sidebar_state="expanded", +) + +st.markdown( + """ + + """, + unsafe_allow_html=True, +) + +# ── Zarr session-state cache ────────────────────────────────────────────────── + +def _get_archive_reader(uploaded_file) -> ArchiveReader: + """Cache backend ArchiveReader in session state, keyed by uploaded file hash.""" + file_bytes = uploaded_file.getvalue() + file_hash = hashlib.md5(file_bytes).hexdigest() + + if st.session_state.get("_zarr_hash") != file_hash: + tmp_path = os.path.join( + tempfile.gettempdir(), + f"vizfold_uploaded_{file_hash}.zip", + ) + + with open(tmp_path, "wb") as f: + f.write(file_bytes) + + st.session_state["_zarr_hash"] = file_hash + st.session_state["_archive_reader"] = ArchiveReader(tmp_path) + + return st.session_state["_archive_reader"] + +# ── Sidebar ─────────────────────────────────────────────────────────────────── + +# Shared rendering outputs — populated by whichever source branch runs +connections: list = [] +fasta_seq: str = "" +n_residues: int = 0 +pdb_path: str | None = None +pdb_string: str | None = None +protein: str = "" +head_label: str = "" +attn_badge: str = "" +residue_idx: int | None = None + +with st.sidebar: + st.markdown("## 🔬 VizFold") + st.caption("Protein model internals explorer") + st.divider() + + source = st.radio( + "Data source", + ["Trace directory", "Zarr archive"], + help="Load from a local trace folder, or upload a Zarr ZipStore (.zip).", + ) + st.divider() + + # ── Branch A: Trace directory ───────────────────────────────────────────── + if source == "Trace directory": + default_trace = os.path.join(os.path.dirname(__file__), "sample_trace") + trace_dir = st.text_input( + "Trace directory", + value=default_trace, + help="Root folder containing per-protein trace subdirectories.", + ) + + reader = TraceReader(trace_dir) + proteins = reader.list_proteins() + + if not proteins: + st.error( + "No proteins found. Check the path, or run " + "`python webui/make_sample_trace.py` first." + ) + st.stop() + + protein = st.selectbox("Protein", proteins) + fasta_seq = reader.get_fasta_sequence(protein) + n_residues = len(fasta_seq) + if n_residues: + st.caption(f"{n_residues} residues") + + st.divider() + st.markdown("**Attention**") + + attn_type = st.radio( + "Type", + ["msa_row", "triangle_start"], + format_func=lambda x: "MSA Row" if x == "msa_row" else "Triangle Start", + ) + + layers = reader.list_layers(protein, attn_type) + if not layers: + st.warning(f"No `{attn_type}` attention files found for **{protein}**.") + st.stop() + + layer_idx = st.select_slider("Layer", options=layers, value=layers[-1]) + + if attn_type == "triangle_start": + available_res = reader.list_triangle_residues(protein, layer_idx) + if not available_res: + st.warning("No triangle attention files found for this layer.") + st.stop() + residue_idx = st.select_slider( + "Source residue", options=available_res, value=available_res[0] + ) + + top_k = st.slider("Top-K connections", 10, 200, 50, step=10) + + if attn_type == "triangle_start" and residue_idx is not None: + attn_data = reader.load_triangle_attention(protein, layer_idx, residue_idx, top_k) + else: + attn_data = reader.load_attention(protein, attn_type, layer_idx, top_k) + + n_heads = len(attn_data) + head_sel = st.selectbox( + "Head", + ["Average"] + list(range(n_heads)), + format_func=lambda x: "All heads averaged" if x == "Average" else f"Head {x}", + ) + + if head_sel == "Average": + connections = flatten_attention_heads(attn_data) + head_label = "All heads averaged" + else: + connections = attn_data.get(int(head_sel), []) + head_label = f"Head {head_sel}" + + pdb_path = reader.get_pdb_path(protein) + attn_badge = ( + "MSA Row" if attn_type == "msa_row" + else f"Triangle Start (res {residue_idx})" + ) + + # ── Branch B: Zarr archive ──────────────────────────────────────────────── + else: + uploaded = st.file_uploader( + "Upload Zarr archive", + type=["zip"], + help=( + "Upload a zipped Zarr archive containing attention arrays, " + "representations, metadata, and optional structure outputs." + ), + ) + + if not uploaded: + st.info("Upload a `.zip` Zarr archive to continue.") + st.stop() + + with st.spinner("Opening Zarr archive..."): + archive_reader = _get_archive_reader(uploaded) + + meta = archive_reader.metadata() + + with st.expander("Archive metadata", expanded=False): + st.write("Archive kind:", meta.archive_kind) + st.write("Model:", meta.model_family) + st.write("Model version:", meta.model_version) + st.write("Attention types:", meta.attention_types) + st.write("Capabilities:", meta.capabilities) + st.write("Available arrays:", meta.extras.get("arrays", {})) + + attention_types = archive_reader.list_attention_types() + + if not attention_types: + st.error("No attention arrays found in this archive.") + st.stop() + + attn_type = st.selectbox("Attention type", attention_types) + + layers = archive_reader.list_layers(attn_type) + + if not layers: + st.error(f"No layers found for attention type `{attn_type}`.") + st.stop() + + layer_idx = st.select_slider("Layer", options=layers, value=layers[-1]) + + residue_options = archive_reader.list_residue_indices(attn_type, layer_idx) + residue_idx = None + + if residue_options: + residue_idx = st.select_slider( + "Source residue", + options=residue_options, + value=residue_options[0], + ) + + heads = archive_reader.list_heads(attn_type, layer_idx, residue_idx) + + if not heads: + st.error(f"No heads found for `{attn_type}` layer {layer_idx}.") + st.stop() + + head_sel_zarr = st.selectbox( + "Head", + ["Average"] + heads, + format_func=lambda x: "All heads averaged" if x == "Average" else f"Head {x}", + ) + + top_k = st.slider("Top-K connections", 10, 200, 50, step=10) + + if head_sel_zarr == "Average": + loaded_heads = archive_reader.load_attention_heads( + attention_type=attn_type, + layer=layer_idx, + residue_idx=residue_idx, + top_k=top_k, + ) + + connections = [] + for attention_slice in loaded_heads.values(): + connections.extend(attention_slice.as_triplets()) + + connections = sorted(connections, key=lambda x: x[2], reverse=True) + head_label = "All heads averaged" + + else: + attention_slice = archive_reader.load_attention( + attention_type=attn_type, + layer=layer_idx, + head=int(head_sel_zarr), + residue_idx=residue_idx, + top_k=top_k, + ) + + connections = attention_slice.as_triplets() + head_label = f"Head {head_sel_zarr}" + + structure_data = archive_reader.load_structure() + + pdb_path = None + pdb_string = structure_data.pdb_text + + n_residues = ( + meta.extras.get("num_residues") + or len(structure_data.sequence or "") + or len(meta.sequence or "") + ) + + if not n_residues and connections: + n_residues = max(max(src, dst) for src, dst, _ in connections) + 1 + + fasta_seq = structure_data.sequence or meta.sequence or ("X" * int(n_residues)) + protein = meta.protein_id + attn_badge = ( + f"{attn_type}" + if residue_idx is None + else f"{attn_type} (res {residue_idx})" + ) + + st.caption( + f"{n_residues} residues · {len(layers)} layers · {len(heads)} heads" + ) + + # ── Shared display toggles ──────────────────────────────────────────────── + st.divider() + st.markdown("**Display**") + show_structure = st.toggle("3D Structure", value=True) + show_heatmap = st.toggle("Attention Heatmap", value=True) + show_arc = st.toggle("Arc Diagram", value=True) + +# ── Header ──────────────────────────────────────────────────────────────────── + +st.markdown(f"# {protein}") + +c1, c2, c3, c4 = st.columns(4) +c1.metric("Residues", n_residues) +c2.metric("Layer", layer_idx) +c3.metric("Head", head_label) +c4.metric("Connections", len(connections)) + +viz_payload = build_visualization_payload( + fasta_seq=fasta_seq, + pdb_path=pdb_path, + connections=connections, + layer_idx=layer_idx, + head_label=head_label, +) + +with st.expander("Visualization Integration Output", expanded=False): + st.write("Stored trace data has been converted into visualization-ready format.") + st.write("Format: `(source_residue, target_residue, attention_weight)`") + + st.write("Layer:", viz_payload["layer_idx"]) + st.write("Head:", viz_payload["head_label"]) + st.write("Residues:", viz_payload["n_residues"]) + st.write("Connections:", len(viz_payload["connections"])) + + if viz_payload["connections"]: + st.dataframe( + [ + { + "source_residue": r1, + "target_residue": r2, + "attention_weight": weight, + } + for r1, r2, weight in viz_payload["connections"][:10] + ], + use_container_width=True, + ) + + +st.caption(f"Attention: **{attn_badge}** · top-{top_k} per head") + +if not connections: + st.warning("No connections loaded — check that the trace files exist.") + +st.divider() + +# ── Main layout ─────────────────────────────────────────────────────────────── + +left_col, right_col = st.columns([1, 1], gap="large") + +with left_col: + if show_structure: + with st.container(border=True): + st.markdown("#### 3D Structure") + has_structure = (pdb_path and os.path.exists(pdb_path)) or pdb_string + if has_structure: + render_structure( + pdb_path=pdb_path, + connections=connections, + n_residues=n_residues, + pdb_string=pdb_string, + ) + else: + st.info( + "No structure file found. For Zarr archives, include a " + "`structure_pdb` array with PDB text content." + ) + + if show_arc: + with st.container(border=True): + st.markdown("#### Arc Diagram") + render_arc_diagram(connections, fasta_seq, highlight_residue=residue_idx) + +with right_col: + if show_heatmap: + with st.container(border=True): + st.markdown("#### Attention Heatmap") + render_heatmap(connections, fasta_seq, head_label) + + with st.container(border=True): + st.markdown("#### Residue Scores") + if connections and fasta_seq: + import numpy as np + import plotly.graph_objects as go + + scores = np.zeros(n_residues) + for r1, r2, w in connections: + if r1 < n_residues: + scores[r1] += w + if r2 < n_residues: + scores[r2] += w + + tick_step = max(1, n_residues // 20) + fig = go.Figure( + go.Bar( + x=list(range(n_residues)), + y=scores, + marker=dict(color=scores, colorscale="Reds", showscale=False), + hovertemplate=( + "Residue %{x} (%{customdata})
" + "Score: %{y:.4f}" + ), + customdata=list(fasta_seq), + ) + ) + fig.update_layout( + xaxis=dict( + title="Residue index", + tickvals=list(range(0, n_residues, tick_step)), + ticktext=[fasta_seq[i] for i in range(0, n_residues, tick_step)], + ), + yaxis_title="Aggregated attention", + margin=dict(l=40, r=10, t=10, b=50), + height=240, + ) + st.plotly_chart(fig, use_container_width=True) + else: + st.info("Load attention data to see per-residue scores.") diff --git a/webui/components/__init__.py b/webui/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/webui/components/arc_diagram.py b/webui/components/arc_diagram.py new file mode 100644 index 0000000..d01b4a3 --- /dev/null +++ b/webui/components/arc_diagram.py @@ -0,0 +1,76 @@ +""" +Arc diagram — residue-to-residue attention arcs drawn with Matplotlib. +Adapted from visualize_attention_arc_diagram_demo_utils.py but returns +a Figure instead of saving to disk, for use inside Streamlit. +""" + +from __future__ import annotations +from typing import List, Optional, Tuple + +import numpy as np +import matplotlib.pyplot as plt +import streamlit as st + + +def render_arc_diagram( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + highlight_residue: Optional[int] = None, +) -> None: + if not connections or not fasta_seq: + st.info("No arc data to display.") + return + fig = _build_figure(connections, fasta_seq, highlight_residue) + st.pyplot(fig, use_container_width=True) + plt.close(fig) + + +# ── Internal ────────────────────────────────────────────────────────────────── + +def _build_figure( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + highlight_residue: Optional[int], +) -> plt.Figure: + n = len(fasta_seq) + weights = [w for _, _, w in connections] + w_min, w_max = min(weights), max(weights) + w_range = (w_max - w_min) or 1.0 + + fig_w = max(14, n // 7) + fig, ax = plt.subplots(figsize=(fig_w, 4)) + fig.patch.set_facecolor("#0e1117") + ax.set_facecolor("#0e1117") + + for r1, r2, w in connections: + x1, x2 = r1 + 0.5, r2 + 0.5 + height = abs(x2 - x1) / 2 + norm_w = (w - w_min) / w_range + lw = 0.3 + norm_w * 2.0 + intensity = 0.35 + 0.65 * norm_w + color = (0.15, 0.45 * (1 - norm_w * 0.5), intensity) + + xs = np.linspace(x1, x2, 80) + ys = height * np.sin(np.linspace(0, np.pi, 80)) + ax.plot(xs, ys, color=color, linewidth=lw, alpha=0.85, solid_capstyle="round") + + ax.set_xlim(0, n) + ax.set_ylim(0, None) + ax.set_xticks(np.arange(n) + 0.5) + labels = ax.set_xticklabels( + list(fasta_seq), fontsize=max(5, min(8, 120 // n)), + color="#aaaaaa", ha="center", + ) + + if highlight_residue is not None and 0 <= highlight_residue < len(labels): + labels[highlight_residue].set_color("#ff4b4b") + labels[highlight_residue].set_fontweight("bold") + + ax.tick_params(axis="x", length=0) + ax.set_yticks([]) + ax.spines[:].set_visible(False) + ax.set_ylabel("Attention strength", color="#aaaaaa", fontsize=9) + ax.yaxis.label.set_color("#aaaaaa") + + plt.tight_layout(pad=0.5) + return fig diff --git a/webui/components/attention_heatmap.py b/webui/components/attention_heatmap.py new file mode 100644 index 0000000..1b0696f --- /dev/null +++ b/webui/components/attention_heatmap.py @@ -0,0 +1,69 @@ +""" +Interactive N×N attention heatmap rendered with Plotly. +Builds a dense matrix from sparse (res1, res2, weight) connections. +""" + +from __future__ import annotations +from typing import List, Tuple + +import numpy as np +import plotly.graph_objects as go +import streamlit as st + + +def render_heatmap( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + head_label: str, +) -> None: + n = len(fasta_seq) + if n == 0 or not connections: + st.info("No attention data to display.") + return + + matrix = _build_matrix(connections, n) + + tick_step = max(1, n // 20) + tick_vals = list(range(0, n, tick_step)) + tick_text = [fasta_seq[i] for i in tick_vals] + + fig = go.Figure( + data=go.Heatmap( + z=matrix, + colorscale="Blues", + colorbar=dict(title="Attention", thickness=14), + hovertemplate="Source %{y} → Target %{x}
Score: %{z:.5f}", + ) + ) + fig.update_layout( + title=dict(text=f"Attention Map — {head_label}", font=dict(size=14)), + xaxis=dict( + title="Target residue", + tickvals=tick_vals, + ticktext=tick_text, + tickfont=dict(size=10), + ), + yaxis=dict( + title="Source residue", + tickvals=tick_vals, + ticktext=tick_text, + tickfont=dict(size=10), + autorange="reversed", + ), + margin=dict(l=60, r=20, t=50, b=60), + height=440, + ) + st.plotly_chart(fig, use_container_width=True) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _build_matrix( + connections: List[Tuple[int, int, float]], n: int +) -> np.ndarray: + matrix = np.zeros((n, n)) + for r1, r2, w in connections: + if r1 < n and r2 < n: + matrix[r1, r2] += w + matrix[r2, r1] += w + return matrix diff --git a/webui/components/structure_viewer.py b/webui/components/structure_viewer.py new file mode 100644 index 0000000..56808fa --- /dev/null +++ b/webui/components/structure_viewer.py @@ -0,0 +1,86 @@ +""" +3D protein structure viewer powered by py3Dmol. +Residues are colored by aggregated attention score (white → red gradient). +""" + +from __future__ import annotations +from typing import List, Tuple + +import os + +import numpy as np +import py3Dmol +import streamlit as st +import streamlit.components.v1 as components + + +def render_structure( + pdb_path: str | None, + connections: List[Tuple[int, int, float]], + n_residues: int, + height: int = 480, + pdb_string: str | None = None, +) -> None: + if pdb_string is not None: + pdb_str = pdb_string + elif pdb_path and os.path.exists(pdb_path): + with open(pdb_path) as f: + pdb_str = f.read() + else: + st.info("No structure file available.") + return + + scores = _residue_scores(connections, n_residues) + + view = py3Dmol.view(width="100%", height=height) + view.addModel(pdb_str, "pdb") + view.setStyle({"cartoon": {"color": "#cccccc"}}) + view.setBackgroundColor("#0e1117") + + if scores.max() > 0: + normed = scores / scores.max() + for resi_0, t in enumerate(normed): + if t > 0.02: + color = _score_to_hex(float(t)) + view.addStyle( + {"resi": resi_0 + 1}, + {"cartoon": {"color": color}}, + ) + + view.zoomTo() + + # stmol is optional; fall back to raw HTML embed + try: + from stmol import showmol # type: ignore + showmol(view, height=height + 20) + except ImportError: + components.html(view._make_html(), height=height + 20, scrolling=False) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _residue_scores( + connections: List[Tuple[int, int, float]], n_residues: int +) -> np.ndarray: + scores = np.zeros(n_residues) + for r1, r2, w in connections: + if r1 < n_residues: + scores[r1] += w + if r2 < n_residues: + scores[r2] += w + return scores + + +def _score_to_hex(t: float) -> str: + """Map t ∈ [0, 1] to a white→orange→red gradient.""" + if t < 0.5: + s = t * 2 + r = int(255) + g = int(255 * (1 - s * 0.5)) + b = int(255 * (1 - s)) + else: + s = (t - 0.5) * 2 + r = int(255) + g = int(255 * 0.5 * (1 - s)) + b = 0 + return f"#{r:02x}{g:02x}{b:02x}" diff --git a/webui/make_sample_trace.py b/webui/make_sample_trace.py new file mode 100644 index 0000000..52ce0d1 --- /dev/null +++ b/webui/make_sample_trace.py @@ -0,0 +1,135 @@ +""" +make_sample_trace.py — one-time setup script for the VizFold sample trace. + +Run from the repo root: + python webui/make_sample_trace.py + +Copies the 6KWC PDB and FASTA from the examples/ directory into +webui/sample_trace/6KWC/ and writes synthetic attention files so the +Streamlit app can run immediately without a real inference trace. +""" + +import os +import shutil +import random +import math + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +WEBUI_DIR = os.path.dirname(os.path.abspath(__file__)) +OUT_DIR = os.path.join(WEBUI_DIR, "sample_trace", "6KWC") + +PDB_SRC = os.path.join(REPO_ROOT, "examples", "monomer", "sample_predictions", + "6KWC_1_model_1_ptm_relaxed.pdb") +FASTA_SRC = os.path.join(REPO_ROOT, "examples", "monomer", "fasta_dir", "6kwc.fasta") + +LAYERS = [0, 10, 20, 30, 40, 47] +N_HEADS = 8 +N_RESIDUES = 190 +TOP_K = 100 +TRIANGLE_RESIDUES = [18, 39, 51, 79, 138, 159] + +random.seed(42) + + +def softmax_weights(n: int, bias_center: int | None = None) -> list[float]: + """Generate plausible attention-like weights (peaked distribution).""" + logits = [random.gauss(0, 1) for _ in range(n)] + if bias_center is not None: + for i in range(n): + dist = abs(i - bias_center) + logits[i] += max(0, 3.0 - dist * 0.15) + max_l = max(logits) + exps = [math.exp(l - max_l) for l in logits] + total = sum(exps) + return [e / total for e in exps] + + +def generate_connections( + n_residues: int, + n_connections: int, + bias_residue: int | None = None, +) -> list[tuple[int, int, float]]: + row_weights = softmax_weights(n_residues, bias_residue) + connections: list[tuple[int, int, float]] = [] + seen: set[tuple[int, int]] = set() + + while len(connections) < n_connections: + r1 = random.choices(range(n_residues), weights=row_weights)[0] + col_weights = softmax_weights(n_residues, r1) + r2 = random.choices(range(n_residues), weights=col_weights)[0] + if r1 == r2: + continue + key = (min(r1, r2), max(r1, r2)) + if key in seen: + continue + seen.add(key) + weight = random.uniform(0.001, 0.05) + random.random() * 0.02 + connections.append((r1, r2, weight)) + + connections.sort(key=lambda x: x[2], reverse=True) + return connections + + +def write_heads_file(path: str, layer_idx: int, n_heads: int, + n_residues: int, top_k: int, + bias_residue: int | None = None) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as f: + for head in range(n_heads): + conns = generate_connections(n_residues, top_k, bias_residue) + f.write(f"Layer {layer_idx}, Head {head}\n") + for r1, r2, w in conns: + f.write(f"{r1} {r2} {w:.6f}\n") + print(f" Written: {os.path.relpath(path, REPO_ROOT)}") + + +def main() -> None: + print("VizFold — setting up sample trace for 6KWC\n") + + attn_dir = os.path.join(OUT_DIR, "attention") + os.makedirs(attn_dir, exist_ok=True) + + # --- PDB --- + pdb_dst = os.path.join(OUT_DIR, "6KWC.pdb") + if os.path.exists(PDB_SRC): + shutil.copy2(PDB_SRC, pdb_dst) + print(f" Copied: {os.path.relpath(pdb_dst, REPO_ROOT)}") + else: + print(f" [WARN] PDB source not found: {PDB_SRC}") + + # --- FASTA --- + fasta_dst = os.path.join(OUT_DIR, "6KWC.fasta") + if os.path.exists(FASTA_SRC): + shutil.copy2(FASTA_SRC, fasta_dst) + print(f" Copied: {os.path.relpath(fasta_dst, REPO_ROOT)}") + else: + print(f" [WARN] FASTA source not found: {FASTA_SRC}") + + print() + + # --- MSA row attention --- + print("Generating MSA row attention files...") + for layer in LAYERS: + path = os.path.join(attn_dir, f"msa_row_attn_layer{layer}.txt") + write_heads_file(path, layer, N_HEADS, N_RESIDUES, TOP_K) + + print() + + # --- Triangle start attention --- + print("Generating triangle start attention files...") + for layer in [47]: + for res_idx in TRIANGLE_RESIDUES: + path = os.path.join( + attn_dir, + f"triangle_start_attn_layer{layer}_residue_idx_{res_idx}.txt", + ) + write_heads_file(path, layer, N_HEADS, N_RESIDUES, TOP_K, + bias_residue=res_idx) + + print() + print("Done. Start the app with:") + print(" streamlit run webui/app.py") + + +if __name__ == "__main__": + main() diff --git a/webui/trace_reader.py b/webui/trace_reader.py new file mode 100644 index 0000000..87922a9 --- /dev/null +++ b/webui/trace_reader.py @@ -0,0 +1,294 @@ +""" +TraceReader — loads offline VizFold inference traces from disk. + +Expected directory layout: + + {trace_root}/ + └── {PROTEIN_ID}/ + ├── *.pdb + ├── *.fasta + └── attention/ + ├── msa_row_attn_layer{N}.txt + └── triangle_start_attn_layer{N}_residue_idx_{R}.txt + +Attention files follow the format used by the existing OpenFold viz pipeline: + Layer {N}, Head {H} + res1 res2 weight + ... +""" + +import os +import glob +import re +import tempfile +from typing import Dict, List, Optional, Tuple + +import numpy as np + + +AttentionMap = Dict[int, List[Tuple[int, int, float]]] + + +class TraceReader: + def __init__(self, trace_root: str): + self.root = os.path.expanduser(trace_root) + + # ── Discovery ──────────────────────────────────────────────────────────── + + def list_proteins(self) -> List[str]: + if not os.path.isdir(self.root): + return [] + return sorted( + d for d in os.listdir(self.root) + if os.path.isdir(os.path.join(self.root, d)) + ) + + def get_pdb_path(self, protein: str) -> Optional[str]: + hits = glob.glob(os.path.join(self.root, protein, "*.pdb")) + return hits[0] if hits else None + + def get_fasta_sequence(self, protein: str) -> str: + hits = glob.glob(os.path.join(self.root, protein, "*.fasta")) + if not hits: + return "" + with open(hits[0]) as f: + lines = f.readlines() + return "".join(l.strip() for l in lines if not l.startswith(">")) + + def list_layers(self, protein: str, attention_type: str) -> List[int]: + attn_dir = os.path.join(self.root, protein, "attention") + if not os.path.isdir(attn_dir): + return [] + if attention_type == "msa_row": + pat = re.compile(r"msa_row_attn_layer(\d+)\.txt$") + else: + pat = re.compile(r"triangle_start_attn_layer(\d+)_residue_idx_\d+\.txt$") + layers: set = set() + for fname in os.listdir(attn_dir): + m = pat.match(fname) + if m: + layers.add(int(m.group(1))) + return sorted(layers) + + def list_triangle_residues(self, protein: str, layer_idx: int) -> List[int]: + attn_dir = os.path.join(self.root, protein, "attention") + if not os.path.isdir(attn_dir): + return [] + pat = re.compile( + rf"triangle_start_attn_layer{layer_idx}_residue_idx_(\d+)\.txt$" + ) + residues: set = set() + for fname in os.listdir(attn_dir): + m = pat.match(fname) + if m: + residues.add(int(m.group(1))) + return sorted(residues) + + # ── Loading ─────────────────────────────────────────────────────────────── + + def load_attention( + self, + protein: str, + attention_type: str, + layer_idx: int, + top_k: Optional[int] = None, + ) -> AttentionMap: + if attention_type != "msa_row": + return {} + path = os.path.join( + self.root, protein, "attention", + f"msa_row_attn_layer{layer_idx}.txt", + ) + return self._parse_heads_file(path, top_k) if os.path.exists(path) else {} + + def load_triangle_attention( + self, + protein: str, + layer_idx: int, + residue_idx: int, + top_k: Optional[int] = None, + ) -> AttentionMap: + path = os.path.join( + self.root, protein, "attention", + f"triangle_start_attn_layer{layer_idx}_residue_idx_{residue_idx}.txt", + ) + return self._parse_heads_file(path, top_k) if os.path.exists(path) else {} + + # ── Internal ────────────────────────────────────────────────────────────── + + @staticmethod + def _parse_heads_file(path: str, top_k: Optional[int]) -> AttentionMap: + heads: AttentionMap = {} + current: Optional[int] = None + with open(path) as f: + for line in f: + line = line.strip() + if not line: + continue + if line.lower().startswith("layer"): + parts = line.replace(",", "").split() + current = int(parts[-1]) + heads[current] = [] + elif current is not None: + r1, r2, w = map(float, line.split()) + heads[current].append((int(r1), int(r2), w)) + for h in heads: + heads[h].sort(key=lambda x: x[2], reverse=True) + if top_k is not None: + heads[h] = heads[h][:top_k] + return heads + + +# ── Zarr reader ─────────────────────────────────────────────────────────────── + +Connections = List[Tuple[int, int, float]] + + +class ZarrTraceReader: + """ + Reads attention data from a Zarr ZipStore archive (.zip). + + Auto-detects attention arrays by shape: any array with ndim >= 2 + whose last two dimensions are equal (i.e. N×N residue–residue matrices). + + Supported shapes: + 4D [n_layers, n_heads, N, N] — standard; layer + head indexable + 3D [n_heads, N, N] — single-layer store + 2D [N, N] — single head + single layer + + Optional metadata arrays (auto-detected by name): + sequence / seq / fasta — amino acid string (bytes or str array) + structure_pdb / pdb — PDB file content (bytes or str array) + """ + + def __init__(self, file_bytes: bytes) -> None: + import zarr + import zarr.storage + + self._tmp_path = tempfile.mktemp(suffix=".zip") + with open(self._tmp_path, "wb") as f: + f.write(file_bytes) + + self._store = zarr.storage.ZipStore(self._tmp_path, mode="r") + self._root = zarr.open_group(self._store, mode="r") + self._arrays: Dict[str, object] = {} + self._collect_arrays() + + def _collect_arrays(self) -> None: + import zarr + for name, item in self._root.members(max_depth=None): + if isinstance(item, zarr.Array): + self._arrays[name] = item + + # ── Discovery ───────────────────────────────────────────────────────────── + + def list_all_arrays(self) -> Dict[str, tuple]: + return {name: arr.shape for name, arr in self._arrays.items()} # type: ignore[union-attr] + + def list_attention_arrays(self) -> Dict[str, tuple]: + """Arrays whose last two dims are equal and ≥ 10 (likely N×N attention).""" + out = {} + for name, arr in self._arrays.items(): + shape = arr.shape # type: ignore[union-attr] + if len(shape) >= 2 and shape[-1] == shape[-2] and shape[-1] >= 10: + out[name] = shape + return out + + def n_layers(self, array_name: str) -> int: + shape = self._arrays[array_name].shape # type: ignore[union-attr] + return shape[0] if len(shape) >= 4 else 1 + + def n_heads(self, array_name: str) -> int: + shape = self._arrays[array_name].shape # type: ignore[union-attr] + if len(shape) >= 4: + return shape[1] + if len(shape) == 3: + return shape[0] + return 1 + + def n_residues(self, array_name: str) -> int: + return self._arrays[array_name].shape[-1] # type: ignore[union-attr] + + # ── Loading ─────────────────────────────────────────────────────────────── + + def load_attention( + self, + array_name: str, + layer_idx: int, + head_idx: Optional[int], + top_k: int = 50, + ) -> Connections: + arr = self._arrays[array_name] + shape = arr.shape # type: ignore[union-attr] + + if len(shape) == 4: + # [n_layers, n_heads, N, N] + layer_data = np.array(arr[layer_idx]) # type: ignore[index] + matrix = layer_data.mean(axis=0) if head_idx is None else layer_data[head_idx] + elif len(shape) == 3: + # [n_heads, N, N] + data = np.array(arr[:]) # type: ignore[index] + matrix = data.mean(axis=0) if head_idx is None else data[head_idx] + else: + # [N, N] + matrix = np.array(arr[:]) # type: ignore[index] + + return _dense_to_topk_connections(matrix.astype(float), top_k) + + def get_sequence(self) -> str: + for key in ("sequence", "seq", "fasta", "metadata/sequence"): + if key in self._arrays: + raw = np.array(self._arrays[key][()]) + if raw.dtype.kind in ("S", "U", "O"): + val = raw.flat[0] + return val.decode() if isinstance(val, bytes) else str(val) + return "" + + def get_pdb_string(self) -> Optional[str]: + for key in ("structure_pdb", "pdb", "structure", "structure/pdb"): + if key in self._arrays: + raw = np.array(self._arrays[key][()]) + if raw.dtype.kind in ("S", "U", "O"): + val = raw.flat[0] + text = val.decode() if isinstance(val, bytes) else str(val) + if "ATOM" in text or "HETATM" in text: + return text + return None + + def __del__(self) -> None: + try: + self._store.close() # type: ignore[attr-defined] + except Exception: + pass + try: + os.unlink(self._tmp_path) + except Exception: + pass + + +# ── Shared helper ───────────────────────────────────────────────────────────── + +def _dense_to_topk_connections(matrix: np.ndarray, top_k: int) -> Connections: + """Convert a dense N×N attention matrix to a sorted top-k connections list.""" + n = matrix.shape[0] + rows, cols = np.triu_indices(n, k=1) + weights = (matrix[rows, cols] + matrix[cols, rows]) / 2.0 + idx = np.argsort(weights)[::-1][:top_k] + return [(int(rows[i]), int(cols[i]), float(weights[i])) for i in idx] + + +def flatten_attention_heads(attn_data): + """ + Converts attention data from: + {head: [(res1, res2, weight), ...]} + + into + [(res1, res2, weight), ...] + + """ + connections = [] + + for head_connections in attn_data.values(): + connections.extend(head_connections) + + return sorted(connections, key=lambda x: x[2], reverse=True) diff --git a/webui/visualization_adapter.py b/webui/visualization_adapter.py new file mode 100644 index 0000000..181ab0b --- /dev/null +++ b/webui/visualization_adapter.py @@ -0,0 +1,108 @@ +""" +visualization_adapter.py + +Middle-layer adapter for Issue #41. + +This file converts stored offline trace data from the archive/reader layer +into visualization-ready formats for the Streamlit UI. +""" + +from typing import Dict, List, Tuple, Optional +import numpy as np + +Connection = Tuple[int, int, float] +AttentionMap = Dict[int, List[Connection]] + + +def flatten_attention_heads(attn_data: AttentionMap) -> List[Connection]: + """ + Converts: + {head: [(res1, res2, weight), ...]} + + into: + [(res1, res2, weight), ...] + + This is the standardized format expected by Dev's visualization components. + """ + connections: List[Connection] = [] + + for head_connections in attn_data.values(): + connections.extend(head_connections) + + return sorted(connections, key=lambda x: x[2], reverse=True) + + +def dense_attention_to_connections( + matrix: np.ndarray, + top_k: int = 50, + symmetric: bool = True, +) -> List[Connection]: + """ + Converts a dense N x N attention matrix into top-k residue connections. + + This supports archive outputs that store full tensors instead of sparse files. + """ + if matrix.ndim != 2: + raise ValueError("Expected a 2D attention matrix.") + + if matrix.shape[0] != matrix.shape[1]: + raise ValueError("Attention matrix must be square.") + + n = matrix.shape[0] + + if symmetric: + rows, cols = np.triu_indices(n, k=1) + weights = (matrix[rows, cols] + matrix[cols, rows]) / 2.0 + else: + rows, cols = np.where(~np.eye(n, dtype=bool)) + weights = matrix[rows, cols] + + top_indices = np.argsort(weights)[::-1][:top_k] + + return [ + (int(rows[i]), int(cols[i]), float(weights[i])) + for i in top_indices + ] + + +def residue_attention_scores( + connections: List[Connection], + n_residues: int, +) -> np.ndarray: + """ + Aggregates connection weights into one score per residue. + + Used for coloring protein structure or plotting residue-level importance. + """ + scores = np.zeros(n_residues) + + for r1, r2, weight in connections: + if 0 <= r1 < n_residues: + scores[r1] += weight + if 0 <= r2 < n_residues: + scores[r2] += weight + + return scores + + +def build_visualization_payload( + fasta_seq: str, + pdb_path: Optional[str], + connections: List[Connection], + layer_idx: int, + head_label: str, +) -> dict: + """ + Creates one clean payload Dev's UI can consume. + """ + n_residues = len(fasta_seq) + + return { + "fasta_seq": fasta_seq, + "pdb_path": pdb_path, + "connections": connections, + "n_residues": n_residues, + "layer_idx": layer_idx, + "head_label": head_label, + "residue_scores": residue_attention_scores(connections, n_residues), + } \ No newline at end of file From 5a954b79bd49c4e534bafbfe521aa2574e9f197e Mon Sep 17 00:00:00 2001 From: archiblesherman Date: Wed, 29 Apr 2026 21:46:08 -0400 Subject: [PATCH 5/5] Clean up offline visualization reader integration --- .gitignore | 5 +++-- .streamlit/config.toml | 1 + vizfold/offline/archive_reader.py | 12 ++++++++++-- webui/app.py | 3 +++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 3681434..ce2f1a1 100644 --- a/.gitignore +++ b/.gitignore @@ -19,9 +19,10 @@ cutlass/ *.a3m *.hhr -# generated archive test outputs +# generated local archive test outputs *.zarr/ -*.zip vizfold_archive.zarr/ vizfold_archive.zip + +# local script from archive team, used for manual smoke testing only create_archive.py \ No newline at end of file diff --git a/.streamlit/config.toml b/.streamlit/config.toml index 3fa35e6..a30ad44 100644 --- a/.streamlit/config.toml +++ b/.streamlit/config.toml @@ -1,2 +1,3 @@ [server] +# Large Zarr archives may exceed Streamlit's default 200 MB upload limit. maxUploadSize = 1000 \ No newline at end of file diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py index fd695b0..6d7f9a5 100644 --- a/vizfold/offline/archive_reader.py +++ b/vizfold/offline/archive_reader.py @@ -1,6 +1,5 @@ from __future__ import annotations -import shutil import tempfile import zipfile @@ -744,7 +743,16 @@ def _array_to_text(array: Any) -> str | None: return None return None - + + # Lightweight visualization fallback. + # + # Some archive writers store structure as coordinates instead of full PDB text. + # The Streamlit viewer can render PDB strings, so this converts coordinate-only + # archives into a simple CA-only PDB-like representation. + # + # This is not intended to reconstruct a complete biologically accurate PDB. + # A final archive schema should ideally include atom names, residue names, + # chain IDs, residue indices, and/or a real structure_pdb field. @staticmethod def _coords_to_pdb(coords: np.ndarray, sequence: str | None = None) -> str: arr = np.asarray(coords, dtype=float) diff --git a/webui/app.py b/webui/app.py index 868501a..021db52 100644 --- a/webui/app.py +++ b/webui/app.py @@ -67,6 +67,9 @@ def _get_archive_reader(uploaded_file) -> ArchiveReader: return st.session_state["_archive_reader"] # ── Sidebar ─────────────────────────────────────────────────────────────────── +# The UI supports two data paths: +# 1. Trace directory: legacy/sample text-output workflow used by the existing UI. +# 2. Zarr archive: standardized archive workflow backed by vizfold.offline.ArchiveReader. # Shared rendering outputs — populated by whichever source branch runs connections: list = []