diff --git a/.gitignore b/.gitignore index 3f1d8382..ce2f1a1a 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,11 @@ cutlass/ *.sto *.a3m *.hhr + +# generated local archive test outputs +*.zarr/ +vizfold_archive.zarr/ +vizfold_archive.zip + +# local script from archive team, used for manual smoke testing only +create_archive.py \ No newline at end of file diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 00000000..a30ad446 --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,3 @@ +[server] +# Large Zarr archives may exceed Streamlit's default 200 MB upload limit. +maxUploadSize = 1000 \ No newline at end of file diff --git a/environment.yml b/environment.yml index e02d1b4d..c9097a61 100644 --- a/environment.yml +++ b/environment.yml @@ -33,6 +33,9 @@ dependencies: - bioconda::kalign2 - pytorch::pytorch=2.5 - pytorch::pytorch-cuda=12.4 + - zarr=2.16 + - numcodecs + - fsspec - pip: - deepspeed==0.14.5 - dm-tree==0.1.6 diff --git a/tests/test_archive_reader_contract.py b/tests/test_archive_reader_contract.py new file mode 100644 index 00000000..ce5e2461 --- /dev/null +++ b/tests/test_archive_reader_contract.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from vizfold.offline import ArchiveReader + +zarr = pytest.importorskip("zarr") + + +def test_archive_reader_loads_issue39_style_zarr(tmp_path): + archive_path = tmp_path / "toy.vizfold.zarr" + + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["model_version"] = "openfold-test" + metadata.attrs["config_version"] = "model_1" + metadata.attrs["sequence"] = "ACDE" + metadata.attrs["num_residues"] = 4 + metadata.attrs["num_recycles"] = 1 + metadata.create_dataset( + "residue_index", + data=np.arange(4), + shape=(4,), + ) + + attention = root.require_group("attention").require_group("triangle_start") + + # Issue-39 documented shape for triangle attention: + # 
(num_residues, num_residues, num_heads) + arr = np.zeros((4, 4, 2), dtype=np.float32) + arr[0, 2, 1] = 0.90 + arr[3, 1, 1] = 0.40 + arr[1, 0, 0] = 0.75 + + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reps = root.require_group("representations") + single_arr = np.ones((4, 8), dtype=np.float32) + reps.require_group("single").create_dataset( + "layer_00", + data=single_arr, + shape=single_arr.shape, + ) + pair_arr = np.ones((4, 4, 16), dtype=np.float32) + reps.require_group("pair").create_dataset( + "layer_00", + data=pair_arr, + shape=pair_arr.shape, + ) + + structure = root.require_group("structure") + coords = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [3.0, 0.0, 0.0], + ], + dtype=np.float32, + ) + + structure.create_dataset( + "atom_positions", + data=coords, + shape=coords.shape, + ) + + reader = ArchiveReader(archive_path) + + meta = reader.metadata() + assert meta.archive_kind == "zarr" + assert meta.model_version == "openfold-test" + assert meta.sequence == "ACDE" + assert meta.structure_available is True + + assert reader.list_attention_types() == ["triangle_start"] + assert reader.list_layers("triangle_start") == [0] + assert reader.list_heads("triangle_start", 0) == [0, 1] + + loaded = reader.load_attention( + attention_type="triangle_start", + layer=0, + head=1, + top_k=2, + ) + + assert loaded.attention_type == "triangle_start" + assert loaded.layer == 0 + assert loaded.head == 1 + assert loaded.as_triplets()[0] == (0, 2, pytest.approx(0.90)) + assert loaded.as_triplets()[1] == (3, 1, pytest.approx(0.40)) + + all_heads = reader.load_attention_heads("triangle_start", 0, top_k=1) + assert sorted(all_heads.keys()) == [0, 1] + + single = reader.load_single_representation(0) + pair = reader.load_pair_representation(0) + + assert single.shape == (4, 8) + assert pair.shape == (4, 4, 16) + + structure_data = reader.load_structure() + assert structure_data.sequence == "ACDE" + 
assert structure_data.pdb_text is not None + assert "ATOM" in structure_data.pdb_text + +def test_archive_reader_missing_structure_does_not_crash(tmp_path): + archive_path = tmp_path / "no_structure.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + metadata.attrs["model_version"] = "openfold-test" + + attention = root.require_group("attention").require_group("triangle_start") + + arr = np.ones((4, 4, 1), dtype=np.float32) + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reader = ArchiveReader(archive_path) + + meta = reader.metadata() + assert meta.structure_available is False + assert "attention" in meta.capabilities + assert "structure" not in meta.capabilities + + structure = reader.load_structure() + assert structure.sequence == "ACDE" + assert structure.pdb_text is None + + +def test_archive_reader_rejects_bad_head(tmp_path): + archive_path = tmp_path / "bad_head.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + + attention = root.require_group("attention").require_group("triangle_start") + + arr = np.ones((4, 4, 2), dtype=np.float32) + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(4, 4, 1), + ) + + reader = ArchiveReader(archive_path) + + assert reader.list_heads("triangle_start", 0) == [0, 1] + + with pytest.raises(IndexError): + reader.load_attention("triangle_start", layer=0, head=5) + + +def test_archive_reader_handles_residue_specific_4d_attention(tmp_path): + archive_path = tmp_path / "residue_attention.zarr" + root = zarr.open_group(str(archive_path), mode="w") + + metadata = root.require_group("metadata") + metadata.attrs["sequence"] = "ACDE" + + attention = root.require_group("attention").require_group("triangle_start") + + # Shape: (heads, residue_index, src_residue, 
dst_residue) + arr = np.zeros((2, 4, 4, 4), dtype=np.float32) + arr[1, 2, 0, 3] = 0.95 + arr[1, 2, 1, 1] = 0.50 + arr[0, 0, 2, 2] = 0.80 + + attention.create_dataset( + "layer_00", + data=arr, + shape=arr.shape, + chunks=(1, 1, 4, 4), + ) + + reader = ArchiveReader(archive_path) + + assert reader.list_attention_types() == ["triangle_start"] + assert reader.list_layers("triangle_start") == [0] + assert reader.list_heads("triangle_start", 0) == [0, 1] + assert reader.list_residue_indices("triangle_start", 0) == [0, 1, 2, 3] + + loaded = reader.load_attention( + attention_type="triangle_start", + layer=0, + head=1, + residue_idx=2, + top_k=2, + ) + + assert loaded.residue_idx == 2 + assert loaded.as_triplets()[0] == (0, 3, pytest.approx(0.95)) + assert loaded.as_triplets()[1] == (1, 1, pytest.approx(0.50)) + \ No newline at end of file diff --git a/tests/test_legacy_txt_reader.py b/tests/test_legacy_txt_reader.py new file mode 100644 index 00000000..fc25fb70 --- /dev/null +++ b/tests/test_legacy_txt_reader.py @@ -0,0 +1,98 @@ +from pathlib import Path + +from vizfold.offline import LegacyTxtReader + + +def _write_text(path: Path, text: str) -> None: + path.write_text(text.strip() + "\n", encoding="utf-8") + + +def test_legacy_txt_reader_metadata_and_loading(tmp_path: Path) -> None: + attention_dir = tmp_path / "attn" + attention_dir.mkdir() + + fasta_path = tmp_path / "toy.fasta" + pdb_path = tmp_path / "toy_unrelaxed.pdb" + + _write_text( + fasta_path, + """ + >toy + ACDEFG + """, + ) + + _write_text( + pdb_path, + """ + HEADER TOY PDB + ATOM 1 N ALA A 1 11.104 13.207 9.947 1.00 50.00 N + END + """, + ) + + _write_text( + attention_dir / "msa_row_attn_layer47.txt", + """ + Layer 47, Head 0 + 0 1 0.90 + 1 3 0.70 + + Layer 47, Head 2 + 2 5 0.95 + 0 4 0.60 + """, + ) + + _write_text( + attention_dir / "triangle_start_attn_layer47_residue_idx_18.txt", + """ + Layer 47, Head 0 + 18 20 0.80 + 18 21 0.50 + + Layer 47, Head 1 + 18 30 0.92 + """, + ) + + reader = 
LegacyTxtReader( + attention_dir=attention_dir, + fasta_path=fasta_path, + pdb_path=pdb_path, + protein_id="toy", + ) + + meta = reader.metadata() + assert meta.protein_id == "toy" + assert meta.sequence == "ACDEFG" + assert set(meta.attention_types) == {"msa_row", "triangle_start"} + assert meta.layers_by_type["msa_row"] == [47] + assert meta.layers_by_type["triangle_start"] == [47] + assert meta.residue_indices_by_type["triangle_start"][47] == [18] + + msa_heads = reader.list_heads("msa_row", 47) + assert msa_heads == [0, 2] + + tri_heads = reader.list_heads("triangle_start", 47, residue_idx=18) + assert tri_heads == [0, 1] + + msa_slice = reader.load_attention("msa_row", layer=47, head=2) + assert msa_slice.as_triplets()[0] == (2, 5, 0.95) + assert msa_slice.as_triplets()[1] == (0, 4, 0.60) + + tri_slice = reader.load_attention( + "triangle_start", + layer=47, + head=0, + residue_idx=18, + top_k=1, + ) + assert tri_slice.residue_idx == 18 + assert tri_slice.as_triplets() == [(18, 20, 0.80)] + + structure = reader.load_structure() + assert structure.protein_id == "toy" + assert structure.sequence == "ACDEFG" + assert structure.pdb_text is not None + assert "HEADER TOY PDB" in structure.pdb_text \ No newline at end of file diff --git a/vizfold/__init__.py b/vizfold/__init__.py new file mode 100644 index 00000000..71c87dd7 --- /dev/null +++ b/vizfold/__init__.py @@ -0,0 +1 @@ +"""VizFold utilities.""" \ No newline at end of file diff --git a/vizfold/offline/__init__.py b/vizfold/offline/__init__.py new file mode 100644 index 00000000..87d8ba48 --- /dev/null +++ b/vizfold/offline/__init__.py @@ -0,0 +1,19 @@ +from .archive_reader import ArchiveReader +from .legacy_txt_reader import LegacyTxtReader +from .models import ( + AttentionConnection, + AttentionSlice, + StructureData, + TraceMetadata, +) +from .trace_reader import TraceReader + +__all__ = [ + "ArchiveReader", + "LegacyTxtReader", + "AttentionConnection", + "AttentionSlice", + "StructureData", + 
"TraceMetadata", + "TraceReader", +] \ No newline at end of file diff --git a/vizfold/offline/archive_reader.py b/vizfold/offline/archive_reader.py new file mode 100644 index 00000000..6d7f9a53 --- /dev/null +++ b/vizfold/offline/archive_reader.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +import tempfile +import zipfile + +import re +from pathlib import Path +from typing import Any + +import numpy as np + +from .models import AttentionConnection, AttentionSlice, StructureData, TraceMetadata +from .trace_reader import TraceReader + +try: + import zarr # type: ignore +except Exception: # pragma: no cover + zarr = None + + +_LAYER_RE = re.compile(r"(?:^|[/_\-])layer[_\-]?(\d+)$|^(\d+)$") + + +class ArchiveReader(TraceReader): + """ + Working reader for standardized VizFold/OpenFold Zarr trace archives. + + Primary supported issue-39 layout: + + metadata/ + representations/single/layer_00 + representations/pair/layer_00 + attention/triangle_start/layer_00 + structure/atom_positions + + Also tolerates prototype layouts like: + + attention/layer_0 + attention # shape [layers, heads, N, N] + inputs/sequence + outputs/coordinates + structure_pdb + """ + + def __init__(self, archive_root: str | Path) -> None: + self.archive_root = Path(archive_root) + self._store: Any | None = None + self._extracted_archive_root: Path | None = None + self._root = self._open_root(self.archive_root) + self._metadata_cache: TraceMetadata | None = None + + def metadata(self) -> TraceMetadata: + if self._metadata_cache is not None: + return self._metadata_cache + + sequence = self.get_sequence() + attention_types = self.list_attention_types() + layers_by_type = { + attn_type: self.list_layers(attn_type) + for attn_type in attention_types + } + heads_by_type = { + attn_type: { + layer: self.list_heads(attn_type, layer) + for layer in layers + } + for attn_type, layers in layers_by_type.items() + } + + residue_indices_by_type: dict[str, dict[int, list[int]]] = {} + for attn_type, 
layers in layers_by_type.items(): + for layer in layers: + indices = self.list_residue_indices(attn_type, layer) + if indices: + residue_indices_by_type.setdefault(attn_type, {})[layer] = indices + + has_structure = ( + self.get_pdb_string() is not None + or self._find_first_array( + "structure/atom_positions", + "outputs/coordinates", + "coordinates", + ) + is not None + ) + + capabilities = ["metadata", "partial_loading"] + if attention_types: + capabilities.append("attention") + if has_structure: + capabilities.append("structure") + + self._metadata_cache = TraceMetadata( + protein_id=str(self._root.attrs.get("protein_id", self.archive_root.stem)), + source_root=self.archive_root, + sequence=sequence, + attention_types=attention_types, + layers_by_type=layers_by_type, + heads_by_type=heads_by_type, + residue_indices_by_type=residue_indices_by_type, + schema_version=self._read_text_metadata( + "schema_version", + "archive_version", + "spec_version", + ), + archive_kind="zarr", + model_family=self._read_text_metadata("model_family", "model_name"), + model_version=self._read_text_metadata("model_version", "config_version"), + structure_available=has_structure, + capabilities=capabilities, + extras={ + "num_residues": self._read_int_metadata("num_residues"), + "num_recycles": self._read_int_metadata("num_recycles"), + "arrays": self.list_all_arrays(), + }, + ) + return self._metadata_cache + + def list_attention_types(self) -> list[str]: + types: set[str] = set() + + if "attention" in self._root: + node = self._root["attention"] + + if self._is_group(node): + for key in self._keys(node): + child = node[key] + if self._is_group(child): + types.add(str(key)) + + direct_layer_arrays = [ + key + for key in self._keys(node) + if self._is_array(node[key]) + and self._parse_layer_key(key) is not None + ] + if direct_layer_arrays: + types.add("triangle_start") + + elif self._is_array(node): + types.add("attention") + + for name, shape in 
self.list_attention_arrays().items(): + if name != "attention" and not name.startswith("attention/"): + types.add(name) + + return sorted(types) + + def list_layers(self, attention_type: str) -> list[int]: + return sorted(self._discover_layer_numbers(attention_type)) + + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + arr = self._load_attention_tensor(attention_type, layer) + return list(range(self._num_heads_from_shape(arr.shape))) + + def list_residue_indices(self, attention_type: str, layer: int) -> list[int]: + arr = self._load_attention_tensor(attention_type, layer) + if arr.ndim != 4: + return [] + + n = self._residue_count_from_attention_shape(arr.shape) + return list(range(n)) if n is not None else [] + + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + arr = self._load_attention_tensor(attention_type, layer) + matrix = self._slice_head_matrix(arr, head=head, residue_idx=residue_idx) + connections = self._matrix_to_connections(matrix, top_k=top_k) + + return AttentionSlice( + attention_type=attention_type, + layer=layer, + head=head, + residue_idx=residue_idx, + connections=connections, + ) + + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + return { + head: self.load_attention( + attention_type=attention_type, + layer=layer, + head=head, + residue_idx=residue_idx, + top_k=top_k, + ) + for head in self.list_heads(attention_type, layer, residue_idx) + } + + def load_structure(self) -> StructureData: + pdb_text = self.get_pdb_string() + sequence = self.get_sequence() + + if pdb_text is None: + coords = self._find_first_array( + "structure/atom_positions", + "outputs/coordinates", + "coordinates", + ) + if coords is not None: + pdb_text = 
self._coords_to_pdb(np.asarray(coords), sequence) + + protein_id = ( + self._metadata_cache.protein_id + if self._metadata_cache is not None + else str(self._root.attrs.get("protein_id", self.archive_root.stem)) + ) + + return StructureData( + protein_id=protein_id, + pdb_path=None, + pdb_text=pdb_text, + sequence=sequence, + ) + + def load_single_representation(self, layer: int) -> np.ndarray: + node = self._resolve_layered_array( + layer, + "representations/single", + "single", + "activations", + dataset_names=("single_repr", "activation", "values"), + ) + if node is None: + raise KeyError(f"Single representation not found for layer {layer}") + return self._to_numpy(node) + + def load_pair_representation(self, layer: int) -> np.ndarray: + node = self._resolve_layered_array( + layer, + "representations/pair", + "pair", + "activations", + dataset_names=("pair_repr", "values"), + ) + if node is None: + raise KeyError(f"Pair representation not found for layer {layer}") + return self._to_numpy(node) + + def list_all_arrays(self) -> dict[str, tuple[int, ...]]: + return { + name: tuple(int(x) for x in array.shape) + for name, array in self._walk_arrays(self._root) + } + + def list_attention_arrays(self) -> dict[str, tuple[int, ...]]: + out: dict[str, tuple[int, ...]] = {} + for name, shape in self.list_all_arrays().items(): + lowered = name.lower() + if self._is_attention_shape(shape) and ( + "att" in lowered or "triangle" in lowered or name == "attention" + ): + out[name] = shape + return out + + def get_sequence(self) -> str | None: + value = self._read_text_metadata("sequence") + if value: + return value + + for path in ("inputs/sequence", "sequence"): + arr = self._get_array(path) + if arr is not None: + text = self._array_to_text(arr) + if text: + return text + return None + + def get_pdb_string(self) -> str | None: + for path in ( + "structure_pdb", + "structure/pdb", + "structure/pdb_text", + "outputs/structure_pdb", + ): + arr = self._get_array(path) + if arr is 
not None: + text = self._array_to_text(arr) + if text: + return text + return None + + def _open_root(self, path: Path) -> Any: + if zarr is None: + raise ImportError("ArchiveReader requires zarr. Install with `pip install zarr`.") + + if not path.exists(): + raise FileNotFoundError(f"Archive path does not exist: {path}") + + if path.suffix == ".zip": + extracted_dir = Path(tempfile.mkdtemp(prefix="vizfold_zarr_")) + + with zipfile.ZipFile(path, "r") as zf: + zf.extractall(extracted_dir) + + # Case 1: user zipped the contents of the .zarr folder + if (extracted_dir / "zarr.json").exists() or (extracted_dir / ".zgroup").exists(): + self._extracted_archive_root = extracted_dir + return zarr.open_group(str(extracted_dir), mode="r") + + # Case 2: user zipped the .zarr folder itself + zarr_dirs = list(extracted_dir.glob("*.zarr")) + if zarr_dirs: + self._extracted_archive_root = zarr_dirs[0] + return zarr.open_group(str(zarr_dirs[0]), mode="r") + + # Case 3: zip contains one top-level folder + child_dirs = [p for p in extracted_dir.iterdir() if p.is_dir()] + if len(child_dirs) == 1: + self._extracted_archive_root = child_dirs[0] + return zarr.open_group(str(child_dirs[0]), mode="r") + + raise ValueError(f"Could not find Zarr archive root inside zip file: {path}") + + return zarr.open_group(str(path), mode="r") + + def _discover_layer_numbers(self, attention_type: str) -> set[int]: + layers: set[int] = set() + + group = self._attention_group_for_type(attention_type) + if group is not None: + for key in self._keys(group): + child = group[key] + layer = self._parse_layer_key(key) + if layer is not None and self._is_array_or_group(child): + layers.add(layer) + + arr = self._attention_array_for_type(attention_type) + if arr is not None and arr.ndim == 4 and self._looks_like_layered_attention(arr.shape): + layers.update(range(int(arr.shape[0]))) + + if layers: + return layers + + for name in self.list_attention_arrays(): + layer = self._parse_layer_key(name.split("/")[-1]) + 
if layer is not None: + layers.add(layer) + + return layers + + def _load_attention_tensor(self, attention_type: str, layer: int) -> np.ndarray: + layered = self._attention_array_for_type(attention_type) + if ( + layered is not None + and layered.ndim == 4 + and self._looks_like_layered_attention(layered.shape) + ): + if layer >= layered.shape[0]: + raise IndexError(f"Layer {layer} out of range for {attention_type}") + return np.asarray(layered[layer]) + + group = self._attention_group_for_type(attention_type) + if group is not None: + for layer_key in self._layer_key_candidates(layer): + if layer_key not in group: + continue + + node = group[layer_key] + if self._is_array(node): + return self._to_numpy(node) + + if self._is_group(node): + for dataset_name in ("attention", "values", "heads"): + if dataset_name in node and self._is_array(node[dataset_name]): + return np.asarray(node[dataset_name]) + + for layer_key in self._layer_key_candidates(layer): + arr = self._get_array(f"attention/{layer_key}") + if arr is not None: + return np.asarray(arr) + + arr = self._get_array(attention_type) + if arr is not None: + data = np.asarray(arr) + if data.ndim == 4 and self._looks_like_layered_attention(data.shape): + return data[layer] + return data + + raise KeyError(f"Attention not found for type={attention_type}, layer={layer}") + + def _attention_group_for_type(self, attention_type: str) -> Any | None: + candidates = [] + + if attention_type != "attention": + candidates.append(f"attention/{attention_type}") + + candidates.append("attention") + + for path in candidates: + node = self._get_node(path) + if node is not None and self._is_group(node): + return node + + return None + + def _attention_array_for_type(self, attention_type: str) -> Any | None: + candidates = [] + + if attention_type != "attention": + candidates.append(f"attention/{attention_type}") + + candidates.extend((attention_type, "attention")) + + for path in candidates: + node = self._get_node(path) + if 
node is not None and self._is_array(node): + return node + + return None + + def _resolve_layered_array( + self, + layer: int, + *groups: str, + dataset_names: tuple[str, ...], + ) -> Any | None: + for group_path in groups: + group = self._get_node(group_path) + if group is None or not self._is_group(group): + continue + + for layer_key in self._layer_key_candidates(layer): + if layer_key not in group: + continue + + node = group[layer_key] + + if self._is_array(node): + return node + + if self._is_group(node): + for name in dataset_names: + if name in node and self._is_array(node[name]): + return node[name] + + return None + + @staticmethod + def _slice_head_matrix( + arr: np.ndarray, + head: int, + residue_idx: int | None, + ) -> np.ndarray: + if arr.ndim == 2: + if head != 0: + raise IndexError("2-D attention matrix only has head 0") + return arr + + if arr.ndim == 3: + axis = ArchiveReader._infer_head_axis(arr.shape) + + if axis == 0: + return arr[head, :, :] + + if axis == 2: + return arr[:, :, head] + + raise ValueError(f"Cannot infer head axis from attention shape {arr.shape}") + + if arr.ndim == 4: + axis = ArchiveReader._infer_head_axis(arr.shape) + + if axis == 0: + cube = arr[head] + + elif axis == 3: + cube = arr[:, :, :, head] + + else: + raise ValueError(f"Cannot infer head axis from attention shape {arr.shape}") + + if residue_idx is None: + return np.asarray(cube).mean(axis=0) + + return cube[residue_idx, :, :] + + raise ValueError(f"Unsupported attention tensor shape {arr.shape}") + + @staticmethod + def _matrix_to_connections( + matrix: np.ndarray, + top_k: int | None, + ) -> list[AttentionConnection]: + mat = np.asarray(matrix, dtype=float) + + if mat.ndim != 2: + raise ValueError(f"Expected 2-D attention matrix, got {mat.shape}") + + flat = mat.reshape(-1) + + if flat.size == 0 or top_k == 0: + return [] + + k = flat.size if top_k is None else min(int(top_k), flat.size) + + if k < flat.size: + idx = np.argpartition(-flat, k - 1)[:k] + idx = 
idx[np.argsort(-flat[idx])] + else: + idx = np.argsort(-flat) + + n_cols = mat.shape[1] + connections: list[AttentionConnection] = [] + + for flat_idx in idx: + weight = float(flat[flat_idx]) + + if np.isnan(weight): + continue + + src = int(flat_idx // n_cols) + dst = int(flat_idx % n_cols) + + connections.append( + AttentionConnection( + src=src, + dst=dst, + weight=weight, + ) + ) + + return connections + + @staticmethod + def _infer_head_axis(shape: tuple[int, ...]) -> int: + if len(shape) == 3: + if shape[0] <= 64 and shape[1] == shape[2]: + return 0 + if shape[2] <= 64 and shape[0] == shape[1]: + return 2 + + if len(shape) == 4: + if shape[0] <= 64 and shape[1] == shape[2] == shape[3]: + return 0 + if shape[3] <= 64 and shape[0] == shape[1] == shape[2]: + return 3 + + raise ValueError(f"Cannot infer head axis from shape {shape}") + + @staticmethod + def _num_heads_from_shape(shape: tuple[int, ...]) -> int: + if len(shape) == 2: + return 1 + + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[axis]) + + @staticmethod + def _residue_count_from_attention_shape(shape: tuple[int, ...]) -> int | None: + if len(shape) == 2 and shape[0] == shape[1]: + return int(shape[0]) + + if len(shape) == 3: + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[1] if axis == 0 else shape[0]) + + if len(shape) == 4: + axis = ArchiveReader._infer_head_axis(shape) + return int(shape[1] if axis == 0 else shape[0]) + + return None + + @staticmethod + def _looks_like_layered_attention(shape: tuple[int, ...]) -> bool: + return len(shape) == 4 and shape[1] <= 64 and shape[2] == shape[3] + + @staticmethod + def _is_attention_shape(shape: tuple[int, ...]) -> bool: + if len(shape) == 2: + return shape[0] == shape[1] + + if len(shape) == 3: + return ( + shape[0] <= 64 + and shape[1] == shape[2] + ) or ( + shape[2] <= 64 + and shape[0] == shape[1] + ) + + if len(shape) == 4: + return ( + shape[0] <= 64 + and shape[1] == shape[2] == shape[3] + ) or ( + shape[1] <= 64 
+ and shape[2] == shape[3] + ) or ( + shape[3] <= 64 + and shape[0] == shape[1] == shape[2] + ) + + return False + + def _read_text_metadata(self, *keys: str) -> str | None: + for key in keys: + if key in self._root.attrs and self._root.attrs[key] is not None: + return str(self._root.attrs[key]) + + arr = self._get_array(f"metadata/{key}") + if arr is not None: + text = self._array_to_text(arr) + if text: + return text + + meta = self._get_node("metadata") + if meta is not None and self._is_group(meta) and key in meta.attrs: + return str(meta.attrs[key]) + + return None + + def _read_int_metadata(self, key: str) -> int | None: + text = self._read_text_metadata(key) + + if text is None: + return None + + try: + return int(float(text)) + except ValueError: + return None + + def _find_first_array(self, *paths: str) -> np.ndarray | None: + for path in paths: + arr = self._get_array(path) + if arr is not None: + return self._to_numpy(arr) + + return None + + def _get_array(self, path: str) -> Any | None: + node = self._get_node(path) + + if node is not None and self._is_array(node): + return node + + return None + + def _get_node(self, path: str) -> Any | None: + node = self._root + + if not path: + return node + + for part in path.strip("/").split("/"): + if not self._is_group(node) or part not in node: + return None + + node = node[part] + + return node + + @staticmethod + def _keys(group: Any) -> list[str]: + return sorted(str(k) for k in group.keys()) + + @staticmethod + def _is_array(node: Any) -> bool: + return hasattr(node, "shape") and hasattr(node, "dtype") + + @staticmethod + def _is_group(node: Any) -> bool: + return hasattr(node, "keys") and not ArchiveReader._is_array(node) + + @staticmethod + def _is_array_or_group(node: Any) -> bool: + return ArchiveReader._is_array(node) or ArchiveReader._is_group(node) + + def _walk_arrays(self, group: Any, prefix: str = ""): + for key in self._keys(group): + node = group[key] + name = f"{prefix}/{key}" if prefix else 
key + + if self._is_array(node): + yield name, node + + elif self._is_group(node): + yield from self._walk_arrays(node, name) + + @staticmethod + def _parse_layer_key(key: str) -> int | None: + match = _LAYER_RE.search(key) + + if not match: + return None + + value = match.group(1) or match.group(2) + return int(value) + + @staticmethod + def _layer_key_candidates(layer: int) -> tuple[str, ...]: + return ( + f"layer_{layer:02d}", + f"layer_{layer}", + str(layer), + ) + + @staticmethod + def _array_to_text(array: Any) -> str | None: + value = ArchiveReader._to_numpy(array) + + if value.shape == (): + item = value.item() + + if isinstance(item, bytes): + return item.decode("utf-8") + + return str(item) + + if value.dtype.kind in {"U", "S", "O"}: + parts = [] + + for item in value.reshape(-1): + if isinstance(item, bytes): + parts.append(item.decode("utf-8")) + else: + parts.append(str(item)) + + return "".join(parts) + + if value.dtype == np.uint8: + try: + return bytes(value.reshape(-1)).decode("utf-8") + except UnicodeDecodeError: + return None + + return None + + # Lightweight visualization fallback. + # + # Some archive writers store structure as coordinates instead of full PDB text. + # The Streamlit viewer can render PDB strings, so this converts coordinate-only + # archives into a simple CA-only PDB-like representation. + # + # This is not intended to reconstruct a complete biologically accurate PDB. + # A final archive schema should ideally include atom names, residue names, + # chain IDs, residue indices, and/or a real structure_pdb field. 
+ @staticmethod + def _coords_to_pdb(coords: np.ndarray, sequence: str | None = None) -> str: + arr = np.asarray(coords, dtype=float) + + if arr.ndim == 3 and arr.shape[-1] == 3: + arr = arr[:, 1, :] + + elif ( + arr.ndim == 2 + and arr.shape[1] == 3 + and sequence + and arr.shape[0] == len(sequence) * 37 + ): + arr = arr.reshape(len(sequence), 37, 3)[:, 1, :] + + elif arr.ndim != 2 or arr.shape[1] != 3: + raise ValueError(f"Cannot convert coordinates with shape {coords.shape} to PDB") + + lines = [] + + for idx, (x, y, z) in enumerate(arr, start=1): + res = sequence[idx - 1] if sequence and idx - 1 < len(sequence) else "X" + lines.append( + f"ATOM {idx:5d} CA {res:>3s} A{idx:4d} " + f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C" + ) + + lines.append("END") + return "\n".join(lines) + "\n" + + @staticmethod + def _to_numpy(array: Any) -> np.ndarray: + if hasattr(array, "shape") and hasattr(array, "__getitem__"): + if tuple(array.shape) == (): + return np.asarray(array[()]) + return np.asarray(array[:]) + + return np.asarray(array) \ No newline at end of file diff --git a/vizfold/offline/exceptions.py b/vizfold/offline/exceptions.py new file mode 100644 index 00000000..a0a8ffca --- /dev/null +++ b/vizfold/offline/exceptions.py @@ -0,0 +1,14 @@ +class TraceReaderError(Exception): + """Base exception for offline trace reading.""" + + +class TraceFormatError(TraceReaderError): + """Raised when a trace file exists but does not match the expected format.""" + + +class TraceNotFoundError(TraceReaderError): + """Raised when a requested trace file or resource cannot be found.""" + + +class UnsupportedAttentionTypeError(TraceReaderError): + """Raised when an attention type is not supported by the reader.""" \ No newline at end of file diff --git a/vizfold/offline/legacy_txt_reader.py b/vizfold/offline/legacy_txt_reader.py new file mode 100644 index 00000000..c6a2418a --- /dev/null +++ b/vizfold/offline/legacy_txt_reader.py @@ -0,0 +1,271 @@ +from __future__ import annotations + 
+import re +from pathlib import Path + +from .exceptions import TraceFormatError, TraceNotFoundError +from .models import AttentionConnection, AttentionSlice, StructureData, TraceMetadata +from .paths import parse_legacy_attention_filename, resolve_legacy_attention_path +from .trace_reader import TraceReader + + +class LegacyTxtReader(TraceReader): + """ + Reader for VizFold's current legacy text-dump attention format. + + This wraps the existing file naming and parsing conventions behind a stable API + so the frontend can stop depending on hardcoded filenames. + """ + + def __init__( + self, + attention_dir: str | Path, + fasta_path: str | Path | None = None, + pdb_path: str | Path | None = None, + protein_id: str | None = None, + ) -> None: + self.attention_dir = Path(attention_dir) + if not self.attention_dir.exists(): + raise TraceNotFoundError(f"attention_dir does not exist: {self.attention_dir}") + + self.fasta_path = Path(fasta_path) if fasta_path is not None else None + self.pdb_path = Path(pdb_path) if pdb_path is not None else None + self.protein_id = protein_id or self._infer_protein_id() + + self._metadata_cache: TraceMetadata | None = None + + def metadata(self) -> TraceMetadata: + if self._metadata_cache is None: + self._metadata_cache = self._build_metadata() + return self._metadata_cache + + def list_attention_types(self) -> list[str]: + return self.metadata().attention_types + + def list_layers(self, attention_type: str) -> list[int]: + return self.metadata().layers_by_type.get(attention_type, []) + + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + path = resolve_legacy_attention_path( + self.attention_dir, + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + ) + heads = self._parse_heads_file(path) + return sorted(heads.keys()) + + def list_residue_indices( + self, + attention_type: str, + layer: int, + ) -> list[int]: + return 
self.metadata().residue_indices_by_type.get(attention_type, {}).get(layer, []) + + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + all_heads = self.load_attention_heads( + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + top_k=top_k, + ) + + if head not in all_heads: + available = sorted(all_heads.keys()) + raise TraceNotFoundError( + f"Head {head} not found for attention_type={attention_type}, " + f"layer={layer}, residue_idx={residue_idx}. Available heads: {available}" + ) + + return all_heads[head] + + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + path = resolve_legacy_attention_path( + self.attention_dir, + attention_type=attention_type, + layer=layer, + residue_idx=residue_idx, + ) + parsed = self._parse_heads_file(path) + + slices: dict[int, AttentionSlice] = {} + for head_idx, connections in parsed.items(): + if top_k is not None: + connections = connections[:top_k] + + slices[head_idx] = AttentionSlice( + attention_type=attention_type, + layer=layer, + head=head_idx, + residue_idx=residue_idx, + connections=connections, + ) + + return slices + + def load_structure(self) -> StructureData: + pdb_text = None + if self.pdb_path is not None: + if not self.pdb_path.exists(): + raise TraceNotFoundError(f"pdb_path does not exist: {self.pdb_path}") + pdb_text = self.pdb_path.read_text(encoding="utf-8") + + sequence = self._read_fasta_sequence(self.fasta_path) if self.fasta_path else None + + return StructureData( + protein_id=self.protein_id, + pdb_path=self.pdb_path, + pdb_text=pdb_text, + sequence=sequence, + ) + + def _build_metadata(self) -> TraceMetadata: + layers_by_type: dict[str, set[int]] = {} + heads_by_type_sets: dict[str, dict[int, set[int]]] = {} + residue_indices_by_type: dict[str, dict[int, 
set[int]]] = {} + + for path in self.attention_dir.iterdir(): + if not path.is_file(): + continue + + parsed = parse_legacy_attention_filename(path.name) + if parsed is None: + continue + + attention_type, layer, residue_idx = parsed + + layers_by_type.setdefault(attention_type, set()).add(layer) + heads_by_type_sets.setdefault(attention_type, {}).setdefault(layer, set()) + + file_heads = self._parse_heads_file(path).keys() + heads_by_type_sets[attention_type][layer].update(file_heads) + + if residue_idx is not None: + residue_indices_by_type.setdefault(attention_type, {}).setdefault(layer, set()).add( + residue_idx + ) + + attention_types = sorted(layers_by_type.keys()) + sequence = self._read_fasta_sequence(self.fasta_path) if self.fasta_path else None + + heads_by_type: dict[str, dict[int, list[int]]] = { + attn_type: { + layer: sorted(heads) + for layer, heads in layer_map.items() + } + for attn_type, layer_map in heads_by_type_sets.items() + } + + residue_indices_by_type_sorted: dict[str, dict[int, list[int]]] = { + attn_type: { + layer: sorted(residue_indices) + for layer, residue_indices in layer_map.items() + } + for attn_type, layer_map in residue_indices_by_type.items() + } + + return TraceMetadata( + protein_id=self.protein_id, + source_root=self.attention_dir, + fasta_path=self.fasta_path, + pdb_path=self.pdb_path, + sequence=sequence, + attention_types=attention_types, + layers_by_type={ + attn_type: sorted(layers) + for attn_type, layers in layers_by_type.items() + }, + heads_by_type=heads_by_type, + residue_indices_by_type=residue_indices_by_type_sorted, + extras={ + "format": "legacy_txt", + }, + ) + + def _parse_heads_file(self, path: Path) -> dict[int, list[AttentionConnection]]: + """ + Expected format: + + Layer 47, Head 0 + 1 5 0.91 + 2 9 0.88 + + Layer 47, Head 1 + 3 7 0.95 + """ + heads: dict[int, list[AttentionConnection]] = {} + current_head: int | None = None + + with path.open("r", encoding="utf-8") as f: + for raw_line in f: + line = 
raw_line.strip() + if not line: + continue + + if line.lower().startswith("layer"): + numbers = re.findall(r"-?\d+", line) + if not numbers: + raise TraceFormatError(f"Could not parse head header line: {line}") + current_head = int(numbers[-1]) + heads[current_head] = [] + continue + + if current_head is None: + raise TraceFormatError( + f"Found attention row before any head header in file: {path}" + ) + + parts = line.split() + if len(parts) != 3: + raise TraceFormatError( + f"Expected 3 columns in attention row, got {len(parts)}: {line}" + ) + + src = int(float(parts[0])) + dst = int(float(parts[1])) + weight = float(parts[2]) + + heads[current_head].append( + AttentionConnection(src=src, dst=dst, weight=weight) + ) + + for head_idx, conns in heads.items(): + conns.sort(key=lambda x: x.weight, reverse=True) + + return heads + + def _infer_protein_id(self) -> str: + if self.pdb_path is not None: + return self.pdb_path.stem + if self.fasta_path is not None: + return self.fasta_path.stem + return self.attention_dir.name + + @staticmethod + def _read_fasta_sequence(fasta_path: Path | None) -> str | None: + if fasta_path is None: + return None + if not fasta_path.exists(): + raise TraceNotFoundError(f"FASTA path does not exist: {fasta_path}") + + lines = fasta_path.read_text(encoding="utf-8").splitlines() + seq_lines = [line.strip() for line in lines if line and not line.startswith(">")] + return "".join(seq_lines) \ No newline at end of file diff --git a/vizfold/offline/models.py b/vizfold/offline/models.py new file mode 100644 index 00000000..330aa430 --- /dev/null +++ b/vizfold/offline/models.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from pathlib import Path +from typing import Any + + +@dataclass(frozen=True) +class AttentionConnection: + """One residue-residue edge with a scalar attention weight.""" + + src: int + dst: int + weight: float + + +@dataclass(frozen=True) +class AttentionSlice: + """ 
+ One logical attention slice. + + Examples: + - MSA row attention at one layer/head + - Triangle-start attention at one layer/residue/head + """ + + attention_type: str + layer: int + head: int + residue_idx: int | None + connections: list[AttentionConnection] + + def top_k(self, k: int) -> "AttentionSlice": + if k < 0: + raise ValueError("k must be >= 0") + return replace(self, connections=self.connections[:k]) + + def as_triplets(self) -> list[tuple[int, int, float]]: + return [(c.src, c.dst, c.weight) for c in self.connections] + + +@dataclass(frozen=True) +class StructureData: + """ + Minimal structure payload for offline visualization. + """ + + protein_id: str + pdb_path: Path | None + pdb_text: str | None + sequence: str | None = None + + +@dataclass(frozen=True) +class TraceMetadata: + """ + Metadata for one trace source. + + heads_by_type maps: attention_type -> layer -> list[head_idx] + residue_indices_by_type maps: attention_type -> layer -> list[residue_idx] + """ + + protein_id: str + source_root: Path + fasta_path: Path | None = None + pdb_path: Path | None = None + sequence: str | None = None + + attention_types: list[str] = field(default_factory=list) + layers_by_type: dict[str, list[int]] = field(default_factory=dict) + heads_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) + residue_indices_by_type: dict[str, dict[int, list[int]]] = field(default_factory=dict) + + # New normalized archive fields + schema_version: str | None = None + archive_kind: str | None = None + model_family: str | None = None + model_version: str | None = None + structure_available: bool = False + capabilities: list[str] = field(default_factory=list) + + extras: dict[str, Any] = field(default_factory=dict) \ No newline at end of file diff --git a/vizfold/offline/paths.py b/vizfold/offline/paths.py new file mode 100644 index 00000000..ee630b08 --- /dev/null +++ b/vizfold/offline/paths.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import re 
+from pathlib import Path + +from .exceptions import TraceNotFoundError, UnsupportedAttentionTypeError + +_MSA_ROW_RE = re.compile(r"^msa_row_attn_layer(?P\d+)\.txt$") +_TRIANGLE_START_RE = re.compile( + r"^triangle_start_attn_layer(?P\d+)_residue_idx_(?P\d+)\.txt$" +) + + +def legacy_msa_row_path(attention_dir: str | Path, layer: int) -> Path: + return Path(attention_dir) / f"msa_row_attn_layer{layer}.txt" + + +def legacy_triangle_start_path( + attention_dir: str | Path, + layer: int, + residue_idx: int, +) -> Path: + return Path(attention_dir) / f"triangle_start_attn_layer{layer}_residue_idx_{residue_idx}.txt" + + +def resolve_legacy_attention_path( + attention_dir: str | Path, + attention_type: str, + layer: int, + residue_idx: int | None = None, +) -> Path: + if attention_type == "msa_row": + path = legacy_msa_row_path(attention_dir, layer) + elif attention_type == "triangle_start": + if residue_idx is None: + raise ValueError("residue_idx is required for triangle_start attention") + path = legacy_triangle_start_path(attention_dir, layer, residue_idx) + else: + raise UnsupportedAttentionTypeError(f"Unsupported attention_type: {attention_type}") + + if not path.exists(): + raise TraceNotFoundError(f"Attention file not found: {path}") + + return path + + +def parse_legacy_attention_filename(filename: str) -> tuple[str, int, int | None] | None: + """ + Returns: + (attention_type, layer, residue_idx) + + Examples: + msa_row_attn_layer47.txt -> ("msa_row", 47, None) + triangle_start_attn_layer47_residue_idx_18.txt -> ("triangle_start", 47, 18) + """ + msa_match = _MSA_ROW_RE.match(filename) + if msa_match: + return ("msa_row", int(msa_match.group("layer")), None) + + tri_match = _TRIANGLE_START_RE.match(filename) + if tri_match: + return ( + "triangle_start", + int(tri_match.group("layer")), + int(tri_match.group("residue")), + ) + + return None \ No newline at end of file diff --git a/vizfold/offline/trace_reader.py b/vizfold/offline/trace_reader.py new file mode 
100644 index 00000000..5a8f35ff --- /dev/null +++ b/vizfold/offline/trace_reader.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from .models import AttentionSlice, StructureData, TraceMetadata + + +class TraceReader(ABC): + """ + Abstract reader interface for offline inference traces. + + Both legacy text-based readers and future archive readers should implement + this exact API so downstream notebooks and Streamlit apps can remain stable. + """ + + @abstractmethod + def metadata(self) -> TraceMetadata: + raise NotImplementedError + + @abstractmethod + def list_attention_types(self) -> list[str]: + raise NotImplementedError + + @abstractmethod + def list_layers(self, attention_type: str) -> list[int]: + raise NotImplementedError + + @abstractmethod + def list_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def list_residue_indices( + self, + attention_type: str, + layer: int, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def load_attention( + self, + attention_type: str, + layer: int, + head: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> AttentionSlice: + raise NotImplementedError + + @abstractmethod + def load_attention_heads( + self, + attention_type: str, + layer: int, + residue_idx: int | None = None, + top_k: int | None = None, + ) -> dict[int, AttentionSlice]: + raise NotImplementedError + + @abstractmethod + def load_structure(self) -> StructureData: + raise NotImplementedError \ No newline at end of file diff --git a/webui/app.py b/webui/app.py new file mode 100644 index 00000000..021db521 --- /dev/null +++ b/webui/app.py @@ -0,0 +1,421 @@ +""" +VizFold — Streamlit UI for offline protein model internals exploration. 
+ +Run: + streamlit run webui/app.py +""" + +import hashlib +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +import streamlit as st + +import tempfile + +from components.structure_viewer import render_structure +from components.attention_heatmap import render_heatmap +from components.arc_diagram import render_arc_diagram + +from trace_reader import TraceReader +from vizfold.offline import ArchiveReader +from visualization_adapter import ( + flatten_attention_heads, + build_visualization_payload, +) +# ── Page config ─────────────────────────────────────────────────────────────── + +st.set_page_config( + page_title="VizFold", + page_icon="🔬", + layout="wide", + initial_sidebar_state="expanded", +) + +st.markdown( + """ + + """, + unsafe_allow_html=True, +) + +# ── Zarr session-state cache ────────────────────────────────────────────────── + +def _get_archive_reader(uploaded_file) -> ArchiveReader: + """Cache backend ArchiveReader in session state, keyed by uploaded file hash.""" + file_bytes = uploaded_file.getvalue() + file_hash = hashlib.md5(file_bytes).hexdigest() + + if st.session_state.get("_zarr_hash") != file_hash: + tmp_path = os.path.join( + tempfile.gettempdir(), + f"vizfold_uploaded_{file_hash}.zip", + ) + + with open(tmp_path, "wb") as f: + f.write(file_bytes) + + st.session_state["_zarr_hash"] = file_hash + st.session_state["_archive_reader"] = ArchiveReader(tmp_path) + + return st.session_state["_archive_reader"] + +# ── Sidebar ─────────────────────────────────────────────────────────────────── +# The UI supports two data paths: +# 1. Trace directory: legacy/sample text-output workflow used by the existing UI. +# 2. Zarr archive: standardized archive workflow backed by vizfold.offline.ArchiveReader. 
+ +# Shared rendering outputs — populated by whichever source branch runs +connections: list = [] +fasta_seq: str = "" +n_residues: int = 0 +pdb_path: str | None = None +pdb_string: str | None = None +protein: str = "" +head_label: str = "" +attn_badge: str = "" +residue_idx: int | None = None + +with st.sidebar: + st.markdown("## 🔬 VizFold") + st.caption("Protein model internals explorer") + st.divider() + + source = st.radio( + "Data source", + ["Trace directory", "Zarr archive"], + help="Load from a local trace folder, or upload a Zarr ZipStore (.zip).", + ) + st.divider() + + # ── Branch A: Trace directory ───────────────────────────────────────────── + if source == "Trace directory": + default_trace = os.path.join(os.path.dirname(__file__), "sample_trace") + trace_dir = st.text_input( + "Trace directory", + value=default_trace, + help="Root folder containing per-protein trace subdirectories.", + ) + + reader = TraceReader(trace_dir) + proteins = reader.list_proteins() + + if not proteins: + st.error( + "No proteins found. Check the path, or run " + "`python webui/make_sample_trace.py` first." 
+ ) + st.stop() + + protein = st.selectbox("Protein", proteins) + fasta_seq = reader.get_fasta_sequence(protein) + n_residues = len(fasta_seq) + if n_residues: + st.caption(f"{n_residues} residues") + + st.divider() + st.markdown("**Attention**") + + attn_type = st.radio( + "Type", + ["msa_row", "triangle_start"], + format_func=lambda x: "MSA Row" if x == "msa_row" else "Triangle Start", + ) + + layers = reader.list_layers(protein, attn_type) + if not layers: + st.warning(f"No `{attn_type}` attention files found for **{protein}**.") + st.stop() + + layer_idx = st.select_slider("Layer", options=layers, value=layers[-1]) + + if attn_type == "triangle_start": + available_res = reader.list_triangle_residues(protein, layer_idx) + if not available_res: + st.warning("No triangle attention files found for this layer.") + st.stop() + residue_idx = st.select_slider( + "Source residue", options=available_res, value=available_res[0] + ) + + top_k = st.slider("Top-K connections", 10, 200, 50, step=10) + + if attn_type == "triangle_start" and residue_idx is not None: + attn_data = reader.load_triangle_attention(protein, layer_idx, residue_idx, top_k) + else: + attn_data = reader.load_attention(protein, attn_type, layer_idx, top_k) + + n_heads = len(attn_data) + head_sel = st.selectbox( + "Head", + ["Average"] + list(range(n_heads)), + format_func=lambda x: "All heads averaged" if x == "Average" else f"Head {x}", + ) + + if head_sel == "Average": + connections = flatten_attention_heads(attn_data) + head_label = "All heads averaged" + else: + connections = attn_data.get(int(head_sel), []) + head_label = f"Head {head_sel}" + + pdb_path = reader.get_pdb_path(protein) + attn_badge = ( + "MSA Row" if attn_type == "msa_row" + else f"Triangle Start (res {residue_idx})" + ) + + # ── Branch B: Zarr archive ──────────────────────────────────────────────── + else: + uploaded = st.file_uploader( + "Upload Zarr archive", + type=["zip"], + help=( + "Upload a zipped Zarr archive containing 
attention arrays, " + "representations, metadata, and optional structure outputs." + ), + ) + + if not uploaded: + st.info("Upload a `.zip` Zarr archive to continue.") + st.stop() + + with st.spinner("Opening Zarr archive..."): + archive_reader = _get_archive_reader(uploaded) + + meta = archive_reader.metadata() + + with st.expander("Archive metadata", expanded=False): + st.write("Archive kind:", meta.archive_kind) + st.write("Model:", meta.model_family) + st.write("Model version:", meta.model_version) + st.write("Attention types:", meta.attention_types) + st.write("Capabilities:", meta.capabilities) + st.write("Available arrays:", meta.extras.get("arrays", {})) + + attention_types = archive_reader.list_attention_types() + + if not attention_types: + st.error("No attention arrays found in this archive.") + st.stop() + + attn_type = st.selectbox("Attention type", attention_types) + + layers = archive_reader.list_layers(attn_type) + + if not layers: + st.error(f"No layers found for attention type `{attn_type}`.") + st.stop() + + layer_idx = st.select_slider("Layer", options=layers, value=layers[-1]) + + residue_options = archive_reader.list_residue_indices(attn_type, layer_idx) + residue_idx = None + + if residue_options: + residue_idx = st.select_slider( + "Source residue", + options=residue_options, + value=residue_options[0], + ) + + heads = archive_reader.list_heads(attn_type, layer_idx, residue_idx) + + if not heads: + st.error(f"No heads found for `{attn_type}` layer {layer_idx}.") + st.stop() + + head_sel_zarr = st.selectbox( + "Head", + ["Average"] + heads, + format_func=lambda x: "All heads averaged" if x == "Average" else f"Head {x}", + ) + + top_k = st.slider("Top-K connections", 10, 200, 50, step=10) + + if head_sel_zarr == "Average": + loaded_heads = archive_reader.load_attention_heads( + attention_type=attn_type, + layer=layer_idx, + residue_idx=residue_idx, + top_k=top_k, + ) + + connections = [] + for attention_slice in loaded_heads.values(): + 
connections.extend(attention_slice.as_triplets()) + + connections = sorted(connections, key=lambda x: x[2], reverse=True) + head_label = "All heads averaged" + + else: + attention_slice = archive_reader.load_attention( + attention_type=attn_type, + layer=layer_idx, + head=int(head_sel_zarr), + residue_idx=residue_idx, + top_k=top_k, + ) + + connections = attention_slice.as_triplets() + head_label = f"Head {head_sel_zarr}" + + structure_data = archive_reader.load_structure() + + pdb_path = None + pdb_string = structure_data.pdb_text + + n_residues = ( + meta.extras.get("num_residues") + or len(structure_data.sequence or "") + or len(meta.sequence or "") + ) + + if not n_residues and connections: + n_residues = max(max(src, dst) for src, dst, _ in connections) + 1 + + fasta_seq = structure_data.sequence or meta.sequence or ("X" * int(n_residues)) + protein = meta.protein_id + attn_badge = ( + f"{attn_type}" + if residue_idx is None + else f"{attn_type} (res {residue_idx})" + ) + + st.caption( + f"{n_residues} residues · {len(layers)} layers · {len(heads)} heads" + ) + + # ── Shared display toggles ──────────────────────────────────────────────── + st.divider() + st.markdown("**Display**") + show_structure = st.toggle("3D Structure", value=True) + show_heatmap = st.toggle("Attention Heatmap", value=True) + show_arc = st.toggle("Arc Diagram", value=True) + +# ── Header ──────────────────────────────────────────────────────────────────── + +st.markdown(f"# {protein}") + +c1, c2, c3, c4 = st.columns(4) +c1.metric("Residues", n_residues) +c2.metric("Layer", layer_idx) +c3.metric("Head", head_label) +c4.metric("Connections", len(connections)) + +viz_payload = build_visualization_payload( + fasta_seq=fasta_seq, + pdb_path=pdb_path, + connections=connections, + layer_idx=layer_idx, + head_label=head_label, +) + +with st.expander("Visualization Integration Output", expanded=False): + st.write("Stored trace data has been converted into visualization-ready format.") + 
st.write("Format: `(source_residue, target_residue, attention_weight)`") + + st.write("Layer:", viz_payload["layer_idx"]) + st.write("Head:", viz_payload["head_label"]) + st.write("Residues:", viz_payload["n_residues"]) + st.write("Connections:", len(viz_payload["connections"])) + + if viz_payload["connections"]: + st.dataframe( + [ + { + "source_residue": r1, + "target_residue": r2, + "attention_weight": weight, + } + for r1, r2, weight in viz_payload["connections"][:10] + ], + use_container_width=True, + ) + + +st.caption(f"Attention: **{attn_badge}** · top-{top_k} per head") + +if not connections: + st.warning("No connections loaded — check that the trace files exist.") + +st.divider() + +# ── Main layout ─────────────────────────────────────────────────────────────── + +left_col, right_col = st.columns([1, 1], gap="large") + +with left_col: + if show_structure: + with st.container(border=True): + st.markdown("#### 3D Structure") + has_structure = (pdb_path and os.path.exists(pdb_path)) or pdb_string + if has_structure: + render_structure( + pdb_path=pdb_path, + connections=connections, + n_residues=n_residues, + pdb_string=pdb_string, + ) + else: + st.info( + "No structure file found. For Zarr archives, include a " + "`structure_pdb` array with PDB text content." 
+ ) + + if show_arc: + with st.container(border=True): + st.markdown("#### Arc Diagram") + render_arc_diagram(connections, fasta_seq, highlight_residue=residue_idx) + +with right_col: + if show_heatmap: + with st.container(border=True): + st.markdown("#### Attention Heatmap") + render_heatmap(connections, fasta_seq, head_label) + + with st.container(border=True): + st.markdown("#### Residue Scores") + if connections and fasta_seq: + import numpy as np + import plotly.graph_objects as go + + scores = np.zeros(n_residues) + for r1, r2, w in connections: + if r1 < n_residues: + scores[r1] += w + if r2 < n_residues: + scores[r2] += w + + tick_step = max(1, n_residues // 20) + fig = go.Figure( + go.Bar( + x=list(range(n_residues)), + y=scores, + marker=dict(color=scores, colorscale="Reds", showscale=False), + hovertemplate=( + "Residue %{x} (%{customdata})
" + "Score: %{y:.4f}" + ), + customdata=list(fasta_seq), + ) + ) + fig.update_layout( + xaxis=dict( + title="Residue index", + tickvals=list(range(0, n_residues, tick_step)), + ticktext=[fasta_seq[i] for i in range(0, n_residues, tick_step)], + ), + yaxis_title="Aggregated attention", + margin=dict(l=40, r=10, t=10, b=50), + height=240, + ) + st.plotly_chart(fig, use_container_width=True) + else: + st.info("Load attention data to see per-residue scores.") diff --git a/webui/components/__init__.py b/webui/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/webui/components/arc_diagram.py b/webui/components/arc_diagram.py new file mode 100644 index 00000000..d01b4a35 --- /dev/null +++ b/webui/components/arc_diagram.py @@ -0,0 +1,76 @@ +""" +Arc diagram — residue-to-residue attention arcs drawn with Matplotlib. +Adapted from visualize_attention_arc_diagram_demo_utils.py but returns +a Figure instead of saving to disk, for use inside Streamlit. +""" + +from __future__ import annotations +from typing import List, Optional, Tuple + +import numpy as np +import matplotlib.pyplot as plt +import streamlit as st + + +def render_arc_diagram( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + highlight_residue: Optional[int] = None, +) -> None: + if not connections or not fasta_seq: + st.info("No arc data to display.") + return + fig = _build_figure(connections, fasta_seq, highlight_residue) + st.pyplot(fig, use_container_width=True) + plt.close(fig) + + +# ── Internal ────────────────────────────────────────────────────────────────── + +def _build_figure( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + highlight_residue: Optional[int], +) -> plt.Figure: + n = len(fasta_seq) + weights = [w for _, _, w in connections] + w_min, w_max = min(weights), max(weights) + w_range = (w_max - w_min) or 1.0 + + fig_w = max(14, n // 7) + fig, ax = plt.subplots(figsize=(fig_w, 4)) + fig.patch.set_facecolor("#0e1117") + 
ax.set_facecolor("#0e1117") + + for r1, r2, w in connections: + x1, x2 = r1 + 0.5, r2 + 0.5 + height = abs(x2 - x1) / 2 + norm_w = (w - w_min) / w_range + lw = 0.3 + norm_w * 2.0 + intensity = 0.35 + 0.65 * norm_w + color = (0.15, 0.45 * (1 - norm_w * 0.5), intensity) + + xs = np.linspace(x1, x2, 80) + ys = height * np.sin(np.linspace(0, np.pi, 80)) + ax.plot(xs, ys, color=color, linewidth=lw, alpha=0.85, solid_capstyle="round") + + ax.set_xlim(0, n) + ax.set_ylim(0, None) + ax.set_xticks(np.arange(n) + 0.5) + labels = ax.set_xticklabels( + list(fasta_seq), fontsize=max(5, min(8, 120 // n)), + color="#aaaaaa", ha="center", + ) + + if highlight_residue is not None and 0 <= highlight_residue < len(labels): + labels[highlight_residue].set_color("#ff4b4b") + labels[highlight_residue].set_fontweight("bold") + + ax.tick_params(axis="x", length=0) + ax.set_yticks([]) + ax.spines[:].set_visible(False) + ax.set_ylabel("Attention strength", color="#aaaaaa", fontsize=9) + ax.yaxis.label.set_color("#aaaaaa") + + plt.tight_layout(pad=0.5) + return fig diff --git a/webui/components/attention_heatmap.py b/webui/components/attention_heatmap.py new file mode 100644 index 00000000..1b0696ff --- /dev/null +++ b/webui/components/attention_heatmap.py @@ -0,0 +1,69 @@ +""" +Interactive N×N attention heatmap rendered with Plotly. +Builds a dense matrix from sparse (res1, res2, weight) connections. 
+""" + +from __future__ import annotations +from typing import List, Tuple + +import numpy as np +import plotly.graph_objects as go +import streamlit as st + + +def render_heatmap( + connections: List[Tuple[int, int, float]], + fasta_seq: str, + head_label: str, +) -> None: + n = len(fasta_seq) + if n == 0 or not connections: + st.info("No attention data to display.") + return + + matrix = _build_matrix(connections, n) + + tick_step = max(1, n // 20) + tick_vals = list(range(0, n, tick_step)) + tick_text = [fasta_seq[i] for i in tick_vals] + + fig = go.Figure( + data=go.Heatmap( + z=matrix, + colorscale="Blues", + colorbar=dict(title="Attention", thickness=14), + hovertemplate="Source %{y} → Target %{x}
Score: %{z:.5f}", + ) + ) + fig.update_layout( + title=dict(text=f"Attention Map — {head_label}", font=dict(size=14)), + xaxis=dict( + title="Target residue", + tickvals=tick_vals, + ticktext=tick_text, + tickfont=dict(size=10), + ), + yaxis=dict( + title="Source residue", + tickvals=tick_vals, + ticktext=tick_text, + tickfont=dict(size=10), + autorange="reversed", + ), + margin=dict(l=60, r=20, t=50, b=60), + height=440, + ) + st.plotly_chart(fig, use_container_width=True) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _build_matrix( + connections: List[Tuple[int, int, float]], n: int +) -> np.ndarray: + matrix = np.zeros((n, n)) + for r1, r2, w in connections: + if r1 < n and r2 < n: + matrix[r1, r2] += w + matrix[r2, r1] += w + return matrix diff --git a/webui/components/structure_viewer.py b/webui/components/structure_viewer.py new file mode 100644 index 00000000..56808fad --- /dev/null +++ b/webui/components/structure_viewer.py @@ -0,0 +1,86 @@ +""" +3D protein structure viewer powered by py3Dmol. +Residues are colored by aggregated attention score (white → red gradient). 
+""" + +from __future__ import annotations +from typing import List, Tuple + +import os + +import numpy as np +import py3Dmol +import streamlit as st +import streamlit.components.v1 as components + + +def render_structure( + pdb_path: str | None, + connections: List[Tuple[int, int, float]], + n_residues: int, + height: int = 480, + pdb_string: str | None = None, +) -> None: + if pdb_string is not None: + pdb_str = pdb_string + elif pdb_path and os.path.exists(pdb_path): + with open(pdb_path) as f: + pdb_str = f.read() + else: + st.info("No structure file available.") + return + + scores = _residue_scores(connections, n_residues) + + view = py3Dmol.view(width="100%", height=height) + view.addModel(pdb_str, "pdb") + view.setStyle({"cartoon": {"color": "#cccccc"}}) + view.setBackgroundColor("#0e1117") + + if scores.max() > 0: + normed = scores / scores.max() + for resi_0, t in enumerate(normed): + if t > 0.02: + color = _score_to_hex(float(t)) + view.addStyle( + {"resi": resi_0 + 1}, + {"cartoon": {"color": color}}, + ) + + view.zoomTo() + + # stmol is optional; fall back to raw HTML embed + try: + from stmol import showmol # type: ignore + showmol(view, height=height + 20) + except ImportError: + components.html(view._make_html(), height=height + 20, scrolling=False) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _residue_scores( + connections: List[Tuple[int, int, float]], n_residues: int +) -> np.ndarray: + scores = np.zeros(n_residues) + for r1, r2, w in connections: + if r1 < n_residues: + scores[r1] += w + if r2 < n_residues: + scores[r2] += w + return scores + + +def _score_to_hex(t: float) -> str: + """Map t ∈ [0, 1] to a white→orange→red gradient.""" + if t < 0.5: + s = t * 2 + r = int(255) + g = int(255 * (1 - s * 0.5)) + b = int(255 * (1 - s)) + else: + s = (t - 0.5) * 2 + r = int(255) + g = int(255 * 0.5 * (1 - s)) + b = 0 + return f"#{r:02x}{g:02x}{b:02x}" diff --git a/webui/make_sample_trace.py 
b/webui/make_sample_trace.py new file mode 100644 index 00000000..52ce0d13 --- /dev/null +++ b/webui/make_sample_trace.py @@ -0,0 +1,135 @@ +""" +make_sample_trace.py — one-time setup script for the VizFold sample trace. + +Run from the repo root: + python webui/make_sample_trace.py + +Copies the 6KWC PDB and FASTA from the examples/ directory into +webui/sample_trace/6KWC/ and writes synthetic attention files so the +Streamlit app can run immediately without a real inference trace. +""" + +import os +import shutil +import random +import math + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +WEBUI_DIR = os.path.dirname(os.path.abspath(__file__)) +OUT_DIR = os.path.join(WEBUI_DIR, "sample_trace", "6KWC") + +PDB_SRC = os.path.join(REPO_ROOT, "examples", "monomer", "sample_predictions", + "6KWC_1_model_1_ptm_relaxed.pdb") +FASTA_SRC = os.path.join(REPO_ROOT, "examples", "monomer", "fasta_dir", "6kwc.fasta") + +LAYERS = [0, 10, 20, 30, 40, 47] +N_HEADS = 8 +N_RESIDUES = 190 +TOP_K = 100 +TRIANGLE_RESIDUES = [18, 39, 51, 79, 138, 159] + +random.seed(42) + + +def softmax_weights(n: int, bias_center: int | None = None) -> list[float]: + """Generate plausible attention-like weights (peaked distribution).""" + logits = [random.gauss(0, 1) for _ in range(n)] + if bias_center is not None: + for i in range(n): + dist = abs(i - bias_center) + logits[i] += max(0, 3.0 - dist * 0.15) + max_l = max(logits) + exps = [math.exp(l - max_l) for l in logits] + total = sum(exps) + return [e / total for e in exps] + + +def generate_connections( + n_residues: int, + n_connections: int, + bias_residue: int | None = None, +) -> list[tuple[int, int, float]]: + row_weights = softmax_weights(n_residues, bias_residue) + connections: list[tuple[int, int, float]] = [] + seen: set[tuple[int, int]] = set() + + while len(connections) < n_connections: + r1 = random.choices(range(n_residues), weights=row_weights)[0] + col_weights = softmax_weights(n_residues, r1) + r2 = 
random.choices(range(n_residues), weights=col_weights)[0] + if r1 == r2: + continue + key = (min(r1, r2), max(r1, r2)) + if key in seen: + continue + seen.add(key) + weight = random.uniform(0.001, 0.05) + random.random() * 0.02 + connections.append((r1, r2, weight)) + + connections.sort(key=lambda x: x[2], reverse=True) + return connections + + +def write_heads_file(path: str, layer_idx: int, n_heads: int, + n_residues: int, top_k: int, + bias_residue: int | None = None) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as f: + for head in range(n_heads): + conns = generate_connections(n_residues, top_k, bias_residue) + f.write(f"Layer {layer_idx}, Head {head}\n") + for r1, r2, w in conns: + f.write(f"{r1} {r2} {w:.6f}\n") + print(f" Written: {os.path.relpath(path, REPO_ROOT)}") + + +def main() -> None: + print("VizFold — setting up sample trace for 6KWC\n") + + attn_dir = os.path.join(OUT_DIR, "attention") + os.makedirs(attn_dir, exist_ok=True) + + # --- PDB --- + pdb_dst = os.path.join(OUT_DIR, "6KWC.pdb") + if os.path.exists(PDB_SRC): + shutil.copy2(PDB_SRC, pdb_dst) + print(f" Copied: {os.path.relpath(pdb_dst, REPO_ROOT)}") + else: + print(f" [WARN] PDB source not found: {PDB_SRC}") + + # --- FASTA --- + fasta_dst = os.path.join(OUT_DIR, "6KWC.fasta") + if os.path.exists(FASTA_SRC): + shutil.copy2(FASTA_SRC, fasta_dst) + print(f" Copied: {os.path.relpath(fasta_dst, REPO_ROOT)}") + else: + print(f" [WARN] FASTA source not found: {FASTA_SRC}") + + print() + + # --- MSA row attention --- + print("Generating MSA row attention files...") + for layer in LAYERS: + path = os.path.join(attn_dir, f"msa_row_attn_layer{layer}.txt") + write_heads_file(path, layer, N_HEADS, N_RESIDUES, TOP_K) + + print() + + # --- Triangle start attention --- + print("Generating triangle start attention files...") + for layer in [47]: + for res_idx in TRIANGLE_RESIDUES: + path = os.path.join( + attn_dir, + 
AttentionMap = Dict[int, List[Tuple[int, int, float]]]


class TraceReader:
    """Read offline attention traces laid out as ``{root}/{protein}/...``.

    Every lookup degrades gracefully: missing directories or files yield an
    empty list, empty string, empty dict or ``None`` rather than raising.
    """

    def __init__(self, trace_root: str):
        """Store the trace root, expanding a leading ``~``."""
        self.root = os.path.expanduser(trace_root)

    # ── Discovery ────────────────────────────────────────────────────────────

    def list_proteins(self) -> List[str]:
        """Return the sorted names of sub-directories under the trace root."""
        if not os.path.isdir(self.root):
            return []
        return sorted(
            d for d in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, d))
        )

    def get_pdb_path(self, protein: str) -> Optional[str]:
        """Return the first ``*.pdb`` in the protein directory, or ``None``.

        ``glob`` returns matches in arbitrary order, so the hits are sorted to
        make the choice deterministic when several files are present.
        """
        hits = sorted(glob.glob(os.path.join(self.root, protein, "*.pdb")))
        return hits[0] if hits else None

    def get_fasta_sequence(self, protein: str) -> str:
        """Return the sequence from the first ``*.fasta`` file, or ``""``.

        Header lines (starting with ``>``) are dropped and all remaining
        lines are stripped and concatenated.
        """
        hits = sorted(glob.glob(os.path.join(self.root, protein, "*.fasta")))
        if not hits:
            return ""
        with open(hits[0]) as f:
            return "".join(
                line.strip() for line in f if not line.startswith(">")
            )

    def list_layers(self, protein: str, attention_type: str) -> List[int]:
        """Return the sorted layer indices that have attention files.

        ``"msa_row"`` selects the MSA-row files; any other value selects the
        triangle-start files.
        """
        attn_dir = os.path.join(self.root, protein, "attention")
        if not os.path.isdir(attn_dir):
            return []
        if attention_type == "msa_row":
            pat = re.compile(r"msa_row_attn_layer(\d+)\.txt$")
        else:
            pat = re.compile(r"triangle_start_attn_layer(\d+)_residue_idx_\d+\.txt$")
        matches = (pat.match(fname) for fname in os.listdir(attn_dir))
        return sorted({int(m.group(1)) for m in matches if m})

    def list_triangle_residues(self, protein: str, layer_idx: int) -> List[int]:
        """Return the sorted residue indices with triangle files for a layer."""
        attn_dir = os.path.join(self.root, protein, "attention")
        if not os.path.isdir(attn_dir):
            return []
        pat = re.compile(
            rf"triangle_start_attn_layer{layer_idx}_residue_idx_(\d+)\.txt$"
        )
        matches = (pat.match(fname) for fname in os.listdir(attn_dir))
        return sorted({int(m.group(1)) for m in matches if m})

    # ── Loading ───────────────────────────────────────────────────────────────

    def load_attention(
        self,
        protein: str,
        attention_type: str,
        layer_idx: int,
        top_k: Optional[int] = None,
    ) -> AttentionMap:
        """Load MSA-row attention for one layer as ``{head: connections}``.

        Returns an empty dict for any ``attention_type`` other than
        ``"msa_row"`` or when the file does not exist.
        """
        if attention_type != "msa_row":
            return {}
        path = os.path.join(
            self.root, protein, "attention",
            f"msa_row_attn_layer{layer_idx}.txt",
        )
        return self._parse_heads_file(path, top_k) if os.path.exists(path) else {}

    def load_triangle_attention(
        self,
        protein: str,
        layer_idx: int,
        residue_idx: int,
        top_k: Optional[int] = None,
    ) -> AttentionMap:
        """Load triangle-start attention for one layer and residue.

        Returns an empty dict when the file does not exist.
        """
        path = os.path.join(
            self.root, protein, "attention",
            f"triangle_start_attn_layer{layer_idx}_residue_idx_{residue_idx}.txt",
        )
        return self._parse_heads_file(path, top_k) if os.path.exists(path) else {}

    # ── Internal ──────────────────────────────────────────────────────────────

    @staticmethod
    def _parse_heads_file(path: str, top_k: Optional[int]) -> AttentionMap:
        """Parse a ``Layer N, Head H`` sectioned file into per-head lists.

        A line starting with "layer" (case-insensitive) opens a new head whose
        id is the last token of that line; following lines are read as
        ``res1 res2 weight``. Data lines before the first header are ignored.
        Each head's list is sorted by descending weight and truncated to
        ``top_k`` when given.

        Raises:
            ValueError: If a header or data line cannot be parsed.
        """
        heads: AttentionMap = {}
        current: Optional[int] = None
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.lower().startswith("layer"):
                    parts = line.replace(",", "").split()
                    current = int(parts[-1])
                    heads[current] = []
                elif current is not None:
                    r1, r2, w = map(float, line.split())
                    heads[current].append((int(r1), int(r2), w))
        for h in heads:
            heads[h].sort(key=lambda x: x[2], reverse=True)
            if top_k is not None:
                heads[h] = heads[h][:top_k]
        return heads
Connections = List[Tuple[int, int, float]]


class ZarrTraceReader:
    """
    Reads attention data from a Zarr ZipStore archive (.zip).

    The uploaded bytes are spooled to a temporary file because ZipStore is
    opened from a path; call ``close()`` (or let the object be collected) to
    release the store and delete that file.

    Attention arrays are auto-detected by shape: any array with ndim >= 2
    whose last two dimensions are equal and at least 10 (i.e. likely N×N
    residue–residue matrices).

    Supported shapes:
        4D  [n_layers, n_heads, N, N]  — standard; layer + head indexable
        3D  [n_heads, N, N]            — single-layer store
        2D  [N, N]                     — single head + single layer

    Optional metadata arrays (looked up by name):
        sequence / seq / fasta / metadata/sequence      — amino acid string
        structure_pdb / pdb / structure / structure/pdb — PDB file content
    """

    def __init__(self, file_bytes: bytes) -> None:
        """Write ``file_bytes`` to a temp ``.zip`` and open it as a Zarr group."""
        import zarr
        import zarr.storage

        # mkstemp creates the file atomically; mktemp only returns a name and
        # is open to a race between choosing the name and creating the file.
        fd, self._tmp_path = tempfile.mkstemp(suffix=".zip")
        with os.fdopen(fd, "wb") as f:
            f.write(file_bytes)

        self._store = zarr.storage.ZipStore(self._tmp_path, mode="r")
        self._root = zarr.open_group(self._store, mode="r")
        self._arrays: Dict[str, object] = {}
        self._collect_arrays()

    def _collect_arrays(self) -> None:
        """Index every array found at any depth of the group by its name."""
        import zarr
        for name, item in self._root.members(max_depth=None):
            if isinstance(item, zarr.Array):
                self._arrays[name] = item

    # ── Discovery ─────────────────────────────────────────────────────────────

    def list_all_arrays(self) -> Dict[str, tuple]:
        """Return ``{name: shape}`` for every indexed array."""
        return {name: arr.shape for name, arr in self._arrays.items()}  # type: ignore[union-attr]

    def list_attention_arrays(self) -> Dict[str, tuple]:
        """Arrays whose last two dims are equal and ≥ 10 (likely N×N attention)."""
        out = {}
        for name, arr in self._arrays.items():
            shape = arr.shape  # type: ignore[union-attr]
            if len(shape) >= 2 and shape[-1] == shape[-2] and shape[-1] >= 10:
                out[name] = shape
        return out

    def n_layers(self, array_name: str) -> int:
        """Return the leading dimension for arrays of rank ≥ 4, otherwise 1."""
        shape = self._arrays[array_name].shape  # type: ignore[union-attr]
        return shape[0] if len(shape) >= 4 else 1

    def n_heads(self, array_name: str) -> int:
        """Return the head count: dim 1 for rank ≥ 4, dim 0 for rank 3, else 1."""
        shape = self._arrays[array_name].shape  # type: ignore[union-attr]
        if len(shape) >= 4:
            return shape[1]
        if len(shape) == 3:
            return shape[0]
        return 1

    def n_residues(self, array_name: str) -> int:
        """Return the size of the array's last dimension."""
        return self._arrays[array_name].shape[-1]  # type: ignore[union-attr]

    # ── Loading ───────────────────────────────────────────────────────────────

    def load_attention(
        self,
        array_name: str,
        layer_idx: int,
        head_idx: Optional[int],
        top_k: int = 50,
    ) -> Connections:
        """Return the top-k connections for one layer/head of an array.

        Args:
            array_name: Key of the array to read.
            layer_idx: Layer to select; only used for 4D arrays.
            head_idx: Head to select, or ``None`` to average over all heads;
                ignored for 2D arrays.
            top_k: Number of connections to return.

        Raises:
            KeyError: If ``array_name`` is unknown.
            ValueError: If the array is not 2D, 3D or 4D.
        """
        arr = self._arrays[array_name]
        shape = arr.shape  # type: ignore[union-attr]

        if len(shape) == 4:
            # [n_layers, n_heads, N, N]
            layer_data = np.array(arr[layer_idx])  # type: ignore[index]
            matrix = layer_data.mean(axis=0) if head_idx is None else layer_data[head_idx]
        elif len(shape) == 3:
            # [n_heads, N, N]
            data = np.array(arr[:])  # type: ignore[index]
            matrix = data.mean(axis=0) if head_idx is None else data[head_idx]
        elif len(shape) == 2:
            # [N, N]
            matrix = np.array(arr[:])  # type: ignore[index]
        else:
            # Previously any other rank fell through to the 2D path and failed
            # obscurely downstream; reject it up front instead.
            raise ValueError(
                f"Unsupported attention array shape {tuple(shape)} for "
                f"{array_name!r}; expected 2, 3 or 4 dimensions."
            )

        return _dense_to_topk_connections(matrix.astype(float), top_k)

    def get_sequence(self) -> str:
        """Return the first string-typed sequence array's value, or ``""``."""
        for key in ("sequence", "seq", "fasta", "metadata/sequence"):
            if key in self._arrays:
                raw = np.array(self._arrays[key][()])
                if raw.dtype.kind in ("S", "U", "O"):
                    val = raw.flat[0]
                    return val.decode() if isinstance(val, bytes) else str(val)
        return ""

    def get_pdb_string(self) -> Optional[str]:
        """Return stored PDB text, or ``None`` if no candidate looks like PDB.

        A candidate is accepted only when its text contains ``ATOM`` or
        ``HETATM``.
        """
        for key in ("structure_pdb", "pdb", "structure", "structure/pdb"):
            if key in self._arrays:
                raw = np.array(self._arrays[key][()])
                if raw.dtype.kind in ("S", "U", "O"):
                    val = raw.flat[0]
                    text = val.decode() if isinstance(val, bytes) else str(val)
                    if "ATOM" in text or "HETATM" in text:
                        return text
        return None

    def close(self) -> None:
        """Close the store and delete the temporary archive (idempotent)."""
        # getattr guards against instances whose __init__ failed part-way.
        store = getattr(self, "_store", None)
        if store is not None:
            self._store = None
            try:
                store.close()  # type: ignore[attr-defined]
            except Exception:
                # Best-effort cleanup: still try to remove the temp file.
                pass
        tmp_path = getattr(self, "_tmp_path", None)
        if tmp_path is not None:
            self._tmp_path = None
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    def __del__(self) -> None:
        # A finalizer must never raise, whatever state the object is in.
        try:
            self.close()
        except Exception:
            pass
str(val) + if "ATOM" in text or "HETATM" in text: + return text + return None + + def __del__(self) -> None: + try: + self._store.close() # type: ignore[attr-defined] + except Exception: + pass + try: + os.unlink(self._tmp_path) + except Exception: + pass + + +# ── Shared helper ───────────────────────────────────────────────────────────── + +def _dense_to_topk_connections(matrix: np.ndarray, top_k: int) -> Connections: + """Convert a dense N×N attention matrix to a sorted top-k connections list.""" + n = matrix.shape[0] + rows, cols = np.triu_indices(n, k=1) + weights = (matrix[rows, cols] + matrix[cols, rows]) / 2.0 + idx = np.argsort(weights)[::-1][:top_k] + return [(int(rows[i]), int(cols[i]), float(weights[i])) for i in idx] + + +def flatten_attention_heads(attn_data): + """ + Converts attention data from: + {head: [(res1, res2, weight), ...]} + + into + [(res1, res2, weight), ...] + + """ + connections = [] + + for head_connections in attn_data.values(): + connections.extend(head_connections) + + return sorted(connections, key=lambda x: x[2], reverse=True) diff --git a/webui/visualization_adapter.py b/webui/visualization_adapter.py new file mode 100644 index 00000000..181ab0b5 --- /dev/null +++ b/webui/visualization_adapter.py @@ -0,0 +1,108 @@ +""" +visualization_adapter.py + +Middle-layer adapter for Issue #41. + +This file converts stored offline trace data from the archive/reader layer +into visualization-ready formats for the Streamlit UI. +""" + +from typing import Dict, List, Tuple, Optional +import numpy as np + +Connection = Tuple[int, int, float] +AttentionMap = Dict[int, List[Connection]] + + +def flatten_attention_heads(attn_data: AttentionMap) -> List[Connection]: + """ + Converts: + {head: [(res1, res2, weight), ...]} + + into: + [(res1, res2, weight), ...] + + This is the standardized format expected by Dev's visualization components. 
Connection = Tuple[int, int, float]
AttentionMap = Dict[int, List[Connection]]


def flatten_attention_heads(attn_data: AttentionMap) -> List[Connection]:
    """
    Converts:
        {head: [(res1, res2, weight), ...]}

    into:
        [(res1, res2, weight), ...]

    The merged list is sorted by descending weight. This is the standardized
    format expected by Dev's visualization components.
    """
    connections: List[Connection] = [
        connection
        for head_connections in attn_data.values()
        for connection in head_connections
    ]
    return sorted(connections, key=lambda x: x[2], reverse=True)


def dense_attention_to_connections(
    matrix: np.ndarray,
    top_k: int = 50,
    symmetric: bool = True,
) -> List[Connection]:
    """
    Converts a dense N x N attention matrix into top-k residue connections.

    This supports archive outputs that store full tensors instead of sparse files.

    Args:
        matrix: Square 2D attention weights (any array-like is accepted).
        top_k: Maximum number of connections to return; zero or a negative
            value yields an empty list.
        symmetric: When True, each unordered pair ``i < j`` is scored with
            the mean of both directions. When False, every off-diagonal
            ``(i, j)`` entry is a separate directed connection.

    Returns:
        ``(row, col, weight)`` tuples in descending weight order.

    Raises:
        ValueError: If ``matrix`` is not 2D or not square.
    """
    matrix = np.asarray(matrix)

    if matrix.ndim != 2:
        raise ValueError("Expected a 2D attention matrix.")

    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Attention matrix must be square.")

    # A negative slice bound would silently mean "all but the last k".
    if top_k is not None and top_k <= 0:
        return []

    n = matrix.shape[0]

    if symmetric:
        rows, cols = np.triu_indices(n, k=1)
        weights = (matrix[rows, cols] + matrix[cols, rows]) / 2.0
    else:
        rows, cols = np.where(~np.eye(n, dtype=bool))
        weights = matrix[rows, cols]

    top_indices = np.argsort(weights)[::-1][:top_k]

    return [
        (int(rows[i]), int(cols[i]), float(weights[i]))
        for i in top_indices
    ]


def residue_attention_scores(
    connections: List[Connection],
    n_residues: int,
) -> np.ndarray:
    """
    Aggregates connection weights into one score per residue.

    Each connection adds its weight to both endpoints; indices outside
    ``[0, n_residues)`` are ignored.

    Used for coloring protein structure or plotting residue-level importance.
    """
    scores = np.zeros(n_residues)

    for r1, r2, weight in connections:
        if 0 <= r1 < n_residues:
            scores[r1] += weight
        if 0 <= r2 < n_residues:
            scores[r2] += weight

    return scores


def build_visualization_payload(
    fasta_seq: str,
    pdb_path: Optional[str],
    connections: List[Connection],
    layer_idx: int,
    head_label: str,
) -> dict:
    """
    Creates one clean payload Dev's UI can consume.

    The residue count is taken from ``len(fasta_seq)`` and the per-residue
    scores are computed from ``connections``; all other inputs are passed
    through unchanged under keys of the same name.
    """
    n_residues = len(fasta_seq)

    return {
        "fasta_seq": fasta_seq,
        "pdb_path": pdb_path,
        "connections": connections,
        "n_residues": n_residues,
        "layer_idx": layer_idx,
        "head_label": head_label,
        "residue_scores": residue_attention_scores(connections, n_residues),
    }