Skip to content

Add ESMFold backend smoke test and reproducibility documentation#58

Open
Mose-Kim02 wants to merge 19 commits into
AI2Science:mainfrom
Mose-Kim02:feature/esmfold-backend
Open

Add ESMFold backend smoke test and reproducibility documentation#58
Mose-Kim02 wants to merge 19 commits into
AI2Science:mainfrom
Mose-Kim02:feature/esmfold-backend

Conversation

@Mose-Kim02

Copy link
Copy Markdown

Summary

  • Adds smoke test validating ESMFold backend trace extraction pipeline
  • Adds reproducibility documentation for local and ICE execution

Validated locally

  • structure-only inference runs successfully
  • attention trace extraction runs successfully
  • activation trace extraction runs successfully
  • 36 attention tensors exported
  • 36 activation tensors exported
  • 72 total trace tensors
  • predicted.pdb and meta.json generated correctly
  • archive structure matches VizFold schema

Purpose
Supports integration testing and reproducibility for Issue #43 shared backend branch.

…pipeline

- Rewrite hooks.py: target HF encoder.layer[i].attention.self directly
  instead of broad name-matching; monkey-patch forward to force
  output_attentions=True so real attention weights [B,H,N,N] are captured
  (not hidden states); separate attention/activation hooks with correct
  layer indices; slice out <cls>/<eos> tokens from attention maps
- Fix _coords_to_minimal_pdb: use 3-letter residue codes (valid PDB)
- Remove dead code: try_use_outputs() path, shared mutable counter
- Extract structure logic into _extract_structure() method
- Unify FASTA reading into single read_fasta() returning (seq, id, hash)
- Wire --dtype through CLI (float32/float16 model loading)
- Log runner.run() result (attention/activation layer counts)
- Fix trace_adapter: correct head-slicing axis for 3D vs 4D tensors;
  fix entropy calculation (per-row, not per-matrix)
Comment thread tests/test_esmf_backend_smoke.py Outdated


def test_esmfold_backend_smoke():
output_dir = "outputs/test_trace_ci"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the output_dir is hardcoded, where if you run this twice it will pass on stale data from the first run. what i would suggest is using tmp_path (pytest fixture) or tempfile.mkdtemp() so each run starts clean

Comment thread tests/test_esmf_backend_smoke.py Outdated
output_dir = "outputs/test_trace_ci"

cmd = [
"python",

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this line, use sys.executable instead of "python" here. otherwise it might pick up the wrong python on some machines

Comment thread tests/test_esmf_backend_smoke.py Outdated
for root, _, files in os.walk(f"{output_dir}/trace"):
trace_files += [f for f in files if f.endswith(".pt")]

assert len(trace_files) >= 36

@jayvenn21 jayvenn21 Mar 22, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking file count is good but it would also be nice to load one of the .pt files and assert the shape is [1, H, N, N] for attention so we know the tensors are actually correct and not just empty files. specifically, we want to know that the hook actually captured real attention weights with the right dimensions through this additional shape check.

Comment thread tests/test_esmf_backend_smoke.py Outdated
@@ -0,0 +1,35 @@
import os

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

general comment for this file: there's already a test_esmf_smoke.py with a few test cases, might make sense to add these into that file instead of a separate one so we dont end up with two test files for the same thing

```bash
python3 -m venv .venv
source .venv/bin/activate

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the markdown is broken here since the code fence never gets closed. as a result, everything after this renders as one big block. this just needs the closing triple backticks after each code section

@jayvenn21 jayvenn21 left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review #1 for smoke test and repro docs

jayvenn21 and others added 11 commits March 22, 2026 21:47
…re, trace relpaths, summary logging, layer_count
…ence#2)

Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…I2Science#1)

* Add VizFold text-file attention export compatible with existing visualization tools

* Bug fix: override the positional arg in-place instead of adding to kwargs

* Fix: trace_formats missing from meta.json

* Robust attention saving & forward signature handling

hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index

trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]
Co-authored-by: Mose Kim <kimmose2002@gmail.com>
…cience#3)

* Extract s_s folding trunk activations and enforce safetensors

* Update backend pipeline

* Capture s_s and s_z at every recycling iteration via trunk hook

* Remove test output artifacts

---------

Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
* Add VizFold text-file attention export compatible with existing visualization tools

* Bug fix: override the positional arg in-place instead of adding to kwargs

* Fix: trace_formats missing from meta.json

* Robust attention saving & forward signature handling

hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index

trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]

* Capture and save evoformer trunk intermediates

Add per-block evoformer tracing and output saving for ESMFold.

- hooks.py: introduce register_trunk_hooks and _make_trunk_block_hook to register forward hooks on model.trunk.blocks (EsmFoldTriangularSelfAttentionBlock). Captured per-block sequence_state and pairwise_state are stored in collector.trunk_blocks; clear() updated and warnings added when trunk/blocks are missing.
- inference.py: register the new trunk hooks in ESMFoldRunner, extract and save final folding trunk pair representations (out.s_z), and write per-block evoformer intermediates to trace/trunk/*.pt while recording shapes. Logging messages adjusted.
- trace_adapter.py: update trace layout to include trunk/ files (block_{idx}_seq/pair, s_s, s_z).

* ESMFold: save trunk tensors, CPU attention

Ensure attention tensors are moved to CPU in hooks (detach().cpu()) to avoid GPU tensor serialization. Stop extracting final trunk outputs from model.out and instead collect final s_s/s_z from collector.recycled_s_s/recycled_s_z (avoids redundant copies) and save per-block trunk tensors plus final s_s/s_z into trace/trunk/.

* Squeeze batch dim in hooks; drop recycling archive

Fix tensor shape handling in ESMFoldTraceCollector hooks by squeezing the leading batch dimension before detaching and moving seq and pair states to CPU, preventing stored activations from containing an extra batch axis. Also remove the prior archival of recycled s_s/s_z tensors in the ESMFoldRunner inference flow to avoid redundant/memory-heavy activation copies and logging related to those recycled tensors.
@Mose-Kim02 Mose-Kim02 force-pushed the feature/esmfold-backend branch from 44ab130 to 26ae087 Compare April 3, 2026 20:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants