diff --git a/README.md b/README.md
index ec15a886..4bf94356 100644
--- a/README.md
+++ b/README.md
@@ -17,3 +17,8 @@ This project is licensed under the [Apache License 2.0](https://www.apache.org/l
See the [LICENSE](./LICENSE) file for details.
---
+
+This fork attempts to combine the features provided in the previous attention-related heatmap visualization pull requests (namely #16, #21, and #28) into a single pull request.
+
+Sample Output
+
diff --git a/outputs/attention_images_6KWC_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html b/outputs/attention_images_6KWC_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html
new file mode 100644
index 00000000..c9212a99
--- /dev/null
+++ b/outputs/attention_images_6KWC_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html
@@ -0,0 +1,3888 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/outputs/attention_images_dummy_protein_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html b/outputs/attention_images_dummy_protein_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html
new file mode 100644
index 00000000..1b3999b8
--- /dev/null
+++ b/outputs/attention_images_dummy_protein_demo_tri_18/heatmaps/html/msa_row_layer47_heatmap_grid.html
@@ -0,0 +1,3888 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/scripts/generate_attention_heatmaps.py b/scripts/generate_attention_heatmaps.py
new file mode 100644
index 00000000..272a87d8
--- /dev/null
+++ b/scripts/generate_attention_heatmaps.py
@@ -0,0 +1,239 @@
+#!/usr/bin/env python3
+"""
+CLI script for generating attention heatmap visualizations.
+
+This script generates heatmap visualizations of OpenFold attention mechanisms,
+enabling cross-head comparison and pattern recognition that complements
+existing arc diagrams and PyMOL overlays.
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+# Add the parent directory to the path so we can import the visualization modules
+sys.path.append(str(Path(__file__).parent.parent))
+
+from visualize_attention_heatmap_utils import (
+ plot_all_heads_heatmap,
+ plot_combined_attention_heatmap
+)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Generate attention heatmap visualizations",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Generate MSA Row attention heatmap
+ python scripts/generate_attention_heatmaps.py \\
+ --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \\
+ --output_dir ./outputs/heatmap_visualizations \\
+ --protein 6KWC \\
+ --layer 47 \\
+ --attention_type msa_row
+
+ # Generate Triangle Start attention heatmap
+ python scripts/generate_attention_heatmaps.py \\
+ --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \\
+ --output_dir ./outputs/heatmap_visualizations \\
+ --protein 6KWC \\
+ --layer 47 \\
+ --attention_type triangle_start
+
+ # Generate combined heatmap (both MSA Row and Triangle Start)
+ python scripts/generate_attention_heatmaps.py \\
+ --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \\
+ --output_dir ./outputs/heatmap_visualizations \\
+ --protein 6KWC \\
+ --layer 47 \\
+ --attention_type combined
+
+ # Generate heatmaps for multiple layers
+ python scripts/generate_attention_heatmaps.py \\
+ --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \\
+ --output_dir ./outputs/heatmap_visualizations \\
+ --protein 6KWC \\
+ --layers 40 45 47 50 \\
+ --attention_type msa_row
+ """
+ )
+
+ # Required arguments
+ parser.add_argument(
+ "--attention_dir",
+ type=str,
+ required=True,
+ help="Directory containing attention text files"
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Directory to save output PNG files"
+ )
+
+ parser.add_argument(
+ "--protein",
+ type=str,
+ required=True,
+ help="Protein identifier (e.g., '6KWC')"
+ )
+
+ # Optional arguments
+ parser.add_argument(
+ "--attention_type",
+ type=str,
+ choices=["msa_row", "triangle_start", "combined"],
+ default="combined",
+ help="Type of attention to visualize (default: combined)"
+ )
+
+ parser.add_argument(
+ "--layer",
+ type=int,
+ default=47,
+ help="Layer number to visualize (default: 47)"
+ )
+
+ parser.add_argument(
+ "--layers",
+ type=int,
+ nargs="+",
+ help="Multiple layer numbers to visualize (overrides --layer)"
+ )
+
+ parser.add_argument(
+ "--seq_length",
+ type=int,
+ help="Sequence length (auto-detect if not specified)"
+ )
+
+ parser.add_argument(
+ "--fasta_path",
+ type=str,
+ default="./examples/monomer/fasta_dir_6KWC/6KWC.fasta",
+ help="Path to FASTA file for sequence length detection"
+ )
+
+ parser.add_argument(
+ "--normalization",
+ type=str,
+ choices=["global", "per_head"],
+ default="global",
+ help="Normalization method (default: global)"
+ )
+
+ parser.add_argument(
+ "--colormap",
+ type=str,
+ default="viridis",
+ help="Matplotlib colormap name (default: viridis)"
+ )
+
+ parser.add_argument(
+ "--figsize_per_head",
+ type=float,
+ nargs=2,
+ default=[2.0, 2.0],
+ metavar=("WIDTH", "HEIGHT"),
+ help="Size of each subplot in inches (default: 2.0 2.0)"
+ )
+
+ parser.add_argument(
+ "--dpi",
+ type=int,
+ default=300,
+ help="Output resolution in DPI (default: 300)"
+ )
+
+ parser.add_argument(
+ "--residue_indices",
+ nargs='+',
+ type=int,
+ help="Residue indices for triangle_start attention (default: [18])"
+ )
+
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ help="Enable verbose output"
+ )
+
+ args = parser.parse_args()
+
+ # Validate inputs
+ if not os.path.exists(args.attention_dir):
+ print(f"Error: Attention directory not found: {args.attention_dir}")
+ sys.exit(1)
+
+ # Determine layers to process
+ if args.layers:
+ layers = args.layers
+ else:
+ layers = [args.layer]
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Process each layer
+ for layer_idx in layers:
+ print(f"\nProcessing layer {layer_idx}...")
+
+ try:
+ if args.attention_type == "combined":
+ # Generate combined heatmap
+ output_path = plot_combined_attention_heatmap(
+ attention_dir=args.attention_dir,
+ output_dir=args.output_dir,
+ protein=args.protein,
+ layer_idx=layer_idx,
+ seq_length=args.seq_length,
+ fasta_path=args.fasta_path,
+ normalization=args.normalization,
+ colormap=args.colormap,
+ figsize_per_head=tuple(args.figsize_per_head),
+ dpi=args.dpi,
+ save_to_png=True,
+ residue_indices=args.residue_indices
+ )
+
+ if args.verbose:
+ print(f"Generated combined heatmap: {output_path}")
+
+ else:
+ # Generate individual attention type heatmap
+ output_path = plot_all_heads_heatmap(
+ attention_dir=args.attention_dir,
+ output_dir=args.output_dir,
+ protein=args.protein,
+ attention_type=args.attention_type,
+ layer_idx=layer_idx,
+ seq_length=args.seq_length,
+ fasta_path=args.fasta_path,
+ normalization=args.normalization,
+ colormap=args.colormap,
+ figsize_per_head=tuple(args.figsize_per_head),
+ dpi=args.dpi,
+ save_to_png=True,
+ residue_indices=args.residue_indices
+ )
+
+ if args.verbose:
+ print(f"Generated {args.attention_type} heatmap: {output_path}")
+
+ except FileNotFoundError as e:
+ print(f"Error: {e}")
+ continue
+ except Exception as e:
+ print(f"Error processing layer {layer_idx}: {e}")
+ continue
+
+ print(f"\nHeatmap generation complete! Outputs saved to: {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/visualize_attention_heatmap_utils.py b/visualize_attention_heatmap_utils.py
new file mode 100644
index 00000000..41799b0f
--- /dev/null
+++ b/visualize_attention_heatmap_utils.py
@@ -0,0 +1,698 @@
+"""
+Utilities for generating attention heatmap visualizations.
+
+This module consolidates the heatmap functionality from the heatmap PRs:
+
+1. Static Matplotlib PNG heatmaps for all heads.
+2. Static Matplotlib PNG combined heatmaps for MSA Row + Triangle Start.
+3. Interactive Plotly HTML heatmap grids.
+
+The heatmaps operate on exported attention text files and complement the
+existing arc diagram / PyMOL visualization workflow. This module does not
+import or invoke PyMOL.
+"""
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from visualize_attention_arc_diagram_demo_utils import load_all_heads
+
+
+# =============================================================================
+# Shared helpers
+# =============================================================================
+
+
+def reconstruct_attention_matrix(connections, seq_length):
+ """
+ Reconstruct a full attention matrix from sparse top-K connections.
+ """
+ matrix = np.zeros((seq_length, seq_length))
+
+ for res1, res2, weight in connections:
+ if 0 <= res1 < seq_length and 0 <= res2 < seq_length:
+ matrix[res1, res2] = weight
+
+ return matrix
+
+
+def get_sequence_length_from_fasta(fasta_path):
+ """
+ Get sequence length from FASTA file.
+ """
+ with open(fasta_path, "r") as f:
+ lines = f.readlines()
+
+ seq_lines = [line.strip() for line in lines if not line.startswith(">")]
+ sequence = "".join(seq_lines)
+ return len(sequence)
+
+
+def _resolve_seq_length(seq_length=None, fasta_path=None):
+ """
+ Resolve sequence length from an explicit value or FASTA path.
+
+ This intentionally avoids a hard-coded fallback length, because using the
+ wrong sequence length can produce misleading heatmaps.
+ """
+ if seq_length is not None:
+ return seq_length
+
+ if fasta_path and os.path.exists(fasta_path):
+ return get_sequence_length_from_fasta(fasta_path)
+
+ raise ValueError("seq_length is required when fasta_path is not provided or does not exist")
+
+
+def _get_grid_shape(num_heads):
+ """
+ Choose subplot grid dimensions.
+ """
+ if num_heads <= 4:
+ return 1, num_heads
+ if num_heads <= 8:
+ return 2, 4
+ if num_heads <= 12:
+ return 3, 4
+ if num_heads <= 16:
+ return 4, 4
+
+ cols = min(4, int(np.ceil(np.sqrt(num_heads))))
+ rows = (num_heads + cols - 1) // cols
+ return rows, cols
+
+
+def _normalize_matrix(matrix, normalization, global_min=None, global_max=None):
+ """
+ Normalize an attention matrix using global or per-head normalization.
+ """
+ if normalization == "global":
+ min_val = global_min
+ max_val = global_max
+ elif normalization == "per_head":
+ min_val = np.nanmin(matrix)
+ max_val = np.nanmax(matrix)
+ else:
+ raise ValueError(f"Invalid normalization: {normalization}")
+
+ if max_val > min_val:
+ return (matrix - min_val) / (max_val - min_val)
+
+ return matrix
+
+
+def _get_attention_file(attention_dir, attention_type, layer_idx, residue_idx=None):
+ """
+ Build the attention file path for a given attention type.
+ """
+ if attention_type == "msa_row":
+ return os.path.join(attention_dir, f"msa_row_attn_layer{layer_idx}.txt")
+
+ if attention_type == "triangle_start":
+ if residue_idx is None:
+ raise ValueError("residue_idx is required for triangle_start attention")
+
+ return os.path.join(
+ attention_dir,
+ f"triangle_start_attn_layer{layer_idx}_residue_idx_{residue_idx}.txt",
+ )
+
+ raise ValueError(f"Unknown attention_type: {attention_type}")
+
+
+def _load_heads_for_attention_type(attention_dir, attention_type, layer_idx, residue_indices=None):
+ """
+ Load heads for msa_row or triangle_start attention.
+
+ For triangle_start, residue_indices is a list of candidate residue indices.
+ The first existing file with valid data is used, matching the behavior from
+ the second PR.
+ """
+ if attention_type == "msa_row":
+ file_path = _get_attention_file(attention_dir, attention_type, layer_idx)
+
+ if not os.path.exists(file_path):
+ print(f"[Warning] Missing file: {file_path}")
+ return {}, None
+
+ return load_all_heads(file_path, top_k=None), None
+
+ if attention_type == "triangle_start":
+ if residue_indices is None:
+ raise ValueError("residue_indices required for triangle_start attention")
+
+ for residue_idx in residue_indices:
+ file_path = _get_attention_file(
+ attention_dir,
+ attention_type,
+ layer_idx,
+ residue_idx=residue_idx,
+ )
+
+ if not os.path.exists(file_path):
+ print(f"[Warning] Missing file for residue {residue_idx}: {file_path}")
+ continue
+
+ heads = load_all_heads(file_path, top_k=None)
+
+ if heads:
+ return heads, residue_idx
+
+ print(f"[Warning] No attention data found in {file_path}")
+
+ print(f"[Warning] No valid attention data found for residues {residue_indices}")
+ return {}, None
+
+ raise ValueError(f"Invalid attention_type: {attention_type}")
+
+
+def _build_attention_matrices(heads, seq_length):
+ """
+ Convert loaded attention heads into dense matrices.
+ """
+ attention_matrices = {}
+
+ for head_idx, connections in heads.items():
+ if not connections:
+ continue
+
+ matrix = reconstruct_attention_matrix(connections, seq_length)
+ attention_matrices[head_idx] = matrix
+
+ return attention_matrices
+
+
+# =============================================================================
+# Static Matplotlib PNG heatmaps from PR 2
+# =============================================================================
+
+
+def plot_all_heads_heatmap(
+ attention_dir,
+ output_dir,
+ protein,
+ attention_type="msa_row",
+ layer_idx=47,
+ seq_length=None,
+ fasta_path=None,
+ normalization="global",
+ colormap="viridis",
+ figsize_per_head=(2.0, 2.0),
+ dpi=300,
+ save_to_png=True,
+ residue_indices=None,
+):
+ """
+ Generate heatmap grid showing all attention heads for a layer.
+
+ Args:
+ attention_dir: Directory containing attention text files.
+ output_dir: Directory to save output PNG.
+ protein: Protein identifier, e.g. "6KWC".
+ attention_type: "msa_row" or "triangle_start".
+ layer_idx: Layer number to visualize.
+ seq_length: Sequence length. If None, infer from fasta_path.
+ fasta_path: Path to FASTA file for sequence length detection.
+ normalization: "global" or "per_head".
+ colormap: Matplotlib colormap name.
+ figsize_per_head: Size of each subplot, in inches.
+ dpi: Output resolution.
+ save_to_png: Whether to save to PNG file.
+ residue_indices: List of residue indices for triangle_start.
+
+ Returns:
+ Output path if save_to_png=True, otherwise the Matplotlib figure.
+ Returns None if no valid heads are found.
+ """
+ if attention_type not in ["msa_row", "triangle_start"]:
+ raise ValueError(f"Invalid attention_type: {attention_type}")
+
+ if normalization not in ["global", "per_head"]:
+ raise ValueError(f"Invalid normalization: {normalization}")
+
+ if attention_type == "triangle_start" and residue_indices is None:
+ raise ValueError("residue_indices required for triangle_start attention")
+
+ seq_length = _resolve_seq_length(seq_length, fasta_path)
+
+ heads, selected_residue_idx = _load_heads_for_attention_type(
+ attention_dir=attention_dir,
+ attention_type=attention_type,
+ layer_idx=layer_idx,
+ residue_indices=residue_indices,
+ )
+
+ attention_matrices = _build_attention_matrices(heads, seq_length)
+ num_heads = len(attention_matrices)
+
+ if num_heads == 0:
+ print("[Warning] No valid attention heads to visualize")
+ return None
+
+ rows, cols = _get_grid_shape(num_heads)
+
+ fig_width = cols * figsize_per_head[0]
+ fig_height = rows * figsize_per_head[1]
+
+ fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
+
+ if num_heads == 1:
+ axes = [axes]
+ else:
+ axes = np.array(axes).flatten()
+
+ all_values = []
+ for matrix in attention_matrices.values():
+ all_values.extend(matrix.flatten())
+
+ all_values = np.array(all_values)
+ global_min = np.min(all_values)
+ global_max = np.max(all_values)
+
+ for i, (head_idx, matrix) in enumerate(sorted(attention_matrices.items())):
+ ax = axes[i]
+
+ normalized_matrix = _normalize_matrix(
+ matrix=matrix,
+ normalization=normalization,
+ global_min=global_min,
+ global_max=global_max,
+ )
+
+ im = ax.imshow(
+ normalized_matrix,
+ cmap=colormap,
+ aspect="auto",
+ interpolation="nearest",
+ )
+
+ ax.set_title(f"Head {head_idx}", fontsize=10, weight="bold")
+ ax.set_xlabel("Residue Position", fontsize=8)
+ ax.set_ylabel("Residue Position", fontsize=8)
+ ax.tick_params(axis="both", which="major", labelsize=6)
+ ax.set_xlim(0, 33)
+ ax.set_ylim(33, 0)
+
+ cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+ cbar.ax.tick_params(labelsize=6)
+
+ for i in range(num_heads, len(axes)):
+ axes[i].axis("off")
+
+ plt.tight_layout()
+
+ if rows == 1:
+ plt.subplots_adjust(top=0.75)
+ title_y = 0.95
+ elif rows <= 2:
+ plt.subplots_adjust(top=0.85)
+ title_y = 0.98
+ else:
+ plt.subplots_adjust(top=0.90)
+ title_y = 0.99
+
+ title = f"{protein} {attention_type.replace('_', ' ').title()} Attention - Layer {layer_idx}"
+ if selected_residue_idx is not None:
+ title += f" - Residue {selected_residue_idx}"
+
+ fig.suptitle(title, fontsize=14, weight="bold", y=title_y)
+
+ if save_to_png:
+ os.makedirs(output_dir, exist_ok=True)
+
+ output_filename = f"{attention_type}_heatmap_layer_{layer_idx}_{protein}"
+ if selected_residue_idx is not None:
+ output_filename += f"_residue_{selected_residue_idx}"
+ output_filename += ".png"
+
+ output_path = os.path.join(output_dir, output_filename)
+
+ plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
+ plt.close(fig)
+
+ print(f"[Saved] Heatmap visualization to {output_path}")
+ return output_path
+
+ return fig
+
+
+def plot_combined_attention_heatmap(
+ attention_dir,
+ output_dir,
+ protein,
+ layer_idx=47,
+ seq_length=None,
+ fasta_path=None,
+ normalization="global",
+ colormap="viridis",
+ figsize_per_head=(1.5, 1.5),
+ dpi=300,
+ save_to_png=True,
+ residue_indices=None,
+):
+ """
+ Generate combined heatmap showing both MSA Row and Triangle Start attention.
+
+ Args:
+ attention_dir: Directory containing attention text files.
+ output_dir: Directory to save output PNG.
+ protein: Protein identifier, e.g. "6KWC".
+ layer_idx: Layer number to visualize.
+ seq_length: Sequence length. If None, infer from fasta_path.
+ fasta_path: Path to FASTA file for sequence length detection.
+ normalization: "global" or "per_head".
+ colormap: Matplotlib colormap name.
+ figsize_per_head: Size of each subplot, in inches.
+ dpi: Output resolution.
+ save_to_png: Whether to save to PNG file.
+ residue_indices: List of residue indices for triangle_start.
+
+ Returns:
+ Output path if save_to_png=True, otherwise the Matplotlib figure.
+ Returns None if no valid heads are found.
+ """
+ if normalization not in ["global", "per_head"]:
+ raise ValueError(f"Invalid normalization: {normalization}")
+
+ seq_length = _resolve_seq_length(seq_length, fasta_path)
+
+ msa_heads, _ = _load_heads_for_attention_type(
+ attention_dir=attention_dir,
+ attention_type="msa_row",
+ layer_idx=layer_idx,
+ )
+
+ if residue_indices is None:
+ residue_indices = [18]
+
+ tri_heads, selected_residue_idx = _load_heads_for_attention_type(
+ attention_dir=attention_dir,
+ attention_type="triangle_start",
+ layer_idx=layer_idx,
+ residue_indices=residue_indices,
+ )
+
+ msa_matrices = _build_attention_matrices(msa_heads, seq_length)
+ tri_matrices = _build_attention_matrices(tri_heads, seq_length)
+
+ plot_items = []
+
+ for head_idx, matrix in sorted(msa_matrices.items()):
+ plot_items.append(("MSA", head_idx, matrix))
+
+ for head_idx, matrix in sorted(tri_matrices.items()):
+ plot_items.append(("Tri", head_idx, matrix))
+
+ total_heads = len(plot_items)
+
+ if total_heads == 0:
+ print("[Warning] No valid attention heads to visualize")
+ return None
+
+ rows, cols = _get_grid_shape(total_heads)
+
+ fig_width = cols * figsize_per_head[0]
+ fig_height = rows * figsize_per_head[1]
+
+ fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
+
+ if total_heads == 1:
+ axes = [axes]
+ else:
+ axes = np.array(axes).flatten()
+
+ all_values = []
+ for _, _, matrix in plot_items:
+ all_values.extend(matrix.flatten())
+
+ all_values = np.array(all_values)
+ global_min = np.min(all_values)
+ global_max = np.max(all_values)
+
+ for i, (label, head_idx, matrix) in enumerate(plot_items):
+ ax = axes[i]
+
+ normalized_matrix = _normalize_matrix(
+ matrix=matrix,
+ normalization=normalization,
+ global_min=global_min,
+ global_max=global_max,
+ )
+
+ im = ax.imshow(
+ normalized_matrix,
+ cmap=colormap,
+ aspect="auto",
+ interpolation="nearest",
+ )
+
+ ax.set_title(f"{label} Head {head_idx}", fontsize=10, weight="bold")
+ ax.set_xlabel("Residue", fontsize=8)
+ ax.set_ylabel("Residue", fontsize=8)
+ ax.tick_params(axis="both", which="major", labelsize=6)
+
+ cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+ cbar.ax.tick_params(labelsize=6)
+
+ for i in range(total_heads, len(axes)):
+ axes[i].axis("off")
+
+ plt.tight_layout()
+
+ if rows == 1:
+ plt.subplots_adjust(top=0.75)
+ title_y = 0.95
+ elif rows <= 2:
+ plt.subplots_adjust(top=0.85)
+ title_y = 0.98
+ else:
+ plt.subplots_adjust(top=0.90)
+ title_y = 0.99
+
+ title = f"{protein} Combined Attention Heatmaps - Layer {layer_idx}"
+ if selected_residue_idx is not None:
+ title += f" - Triangle Residue {selected_residue_idx}"
+
+ fig.suptitle(title, fontsize=14, weight="bold", y=title_y)
+
+ if save_to_png:
+ os.makedirs(output_dir, exist_ok=True)
+
+ output_filename = f"combined_attention_heatmap_layer_{layer_idx}_{protein}"
+ if selected_residue_idx is not None:
+ output_filename += f"_triangle_residue_{selected_residue_idx}"
+ output_filename += ".png"
+
+ output_path = os.path.join(output_dir, output_filename)
+
+ plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
+ plt.close(fig)
+
+ print(f"[Saved] Combined heatmap visualization to {output_path}")
+ return output_path
+
+ return fig
+
+
+# =============================================================================
+# Interactive Plotly HTML heatmaps from PR 1
+# =============================================================================
+
+
+def create_heatmap_grid(
+ attention_file,
+ seq_len,
+ layer_idx=47,
+ attention_type="msa_row",
+ output_html="heatmap_grid.html",
+ threshold=None,
+):
+ """
+ Create an interactive Plotly heatmap grid for all heads in one attention file.
+
+ This preserves the first PR's HTML heatmap behavior.
+ """
+ heads = load_all_heads(attention_file, top_k=None)
+ num_heads = len(heads)
+
+ if num_heads == 0:
+ print(f"No heads found in {attention_file}")
+ return None
+
+ cols = min(4, num_heads)
+ rows = (num_heads + cols - 1) // cols
+
+ all_weights = [
+ weight
+ for head_idx in sorted(heads.keys())
+ for _, _, weight in heads[head_idx]
+ ]
+
+ if threshold is not None:
+ all_weights = [weight for weight in all_weights if weight >= threshold]
+
+ global_min = min(all_weights) if all_weights else 0
+ global_max = max(all_weights) if all_weights else 1
+
+ per_head_mins = []
+ per_head_maxs = []
+
+ fig = make_subplots(
+ rows=rows,
+ cols=cols,
+ subplot_titles=[f"Head {i}" for i in sorted(heads.keys())],
+ horizontal_spacing=0.05,
+ vertical_spacing=0.15,
+ )
+
+ for idx, head_idx in enumerate(sorted(heads.keys())):
+ row = idx // cols + 1
+ col = idx % cols + 1
+
+ matrix = reconstruct_attention_matrix(heads[head_idx], seq_len)
+
+ if threshold is not None:
+ matrix[matrix < threshold] = np.nan
+
+ head_weights = [weight for _, _, weight in heads[head_idx]]
+
+ if threshold is not None:
+ head_weights = [weight for weight in head_weights if weight >= threshold]
+
+ head_min = min(head_weights) if head_weights else 0
+ head_max = max(head_weights) if head_weights else 1
+
+ per_head_mins.append(head_min)
+ per_head_maxs.append(head_max)
+
+ fig.add_trace(
+ go.Heatmap(
+ z=matrix,
+ colorscale="Blues",
+ zmin=global_min,
+ zmax=global_max,
+ showscale=(idx == 0),
+ colorbar=dict(x=1.02, len=0.3, title="Weight") if idx == 0 else None,
+ ),
+ row=row,
+ col=col,
+ )
+
+ fig.update_xaxes(
+ title_text="Residue",
+ row=row,
+ col=col,
+ showticklabels=False,
+ )
+ fig.update_yaxes(
+ title_text="Residue",
+ row=row,
+ col=col,
+ showticklabels=False,
+ )
+
+ title_text = f"{attention_type.upper()} Layer {layer_idx} - All Heads"
+ if threshold is not None:
+ title_text += f" (Threshold > {threshold})"
+
+ fig.update_layout(
+ title_text=title_text,
+ title_x=0.5,
+ height=350 * rows,
+ width=1200,
+ showlegend=False,
+ updatemenus=[
+ dict(
+ type="buttons",
+ direction="right",
+ x=0.6,
+ xanchor="left",
+ y=1.15,
+ yanchor="top",
+ buttons=[
+ dict(
+ label="Global Norm",
+ method="restyle",
+ args=[
+ {
+ "zmin": [global_min],
+ "zmax": [global_max],
+ "showscale": [True] + [False] * (num_heads - 1),
+ }
+ ],
+ ),
+ dict(
+ label="Per-Head Norm",
+ method="restyle",
+ args=[
+ {
+ "zmin": per_head_mins,
+ "zmax": per_head_maxs,
+ "showscale": [False] * num_heads,
+ }
+ ],
+ ),
+ ],
+ )
+ ],
+ )
+
+ fig.write_html(output_html)
+ print(f"Saved: {output_html}")
+
+ return fig
+
+
+def visualize_layer_attention(
+ attention_dir,
+ seq_len,
+ layer_idx=47,
+ attention_type="msa_row",
+ residue_idx=None,
+ output_dir="./outputs/attention_heatmaps",
+ threshold=None,
+):
+ """
+ Visualize layer-specific attention as an interactive Plotly heatmap grid.
+
+ This preserves the first PR's notebook-facing helper.
+ """
+ os.makedirs(output_dir, exist_ok=True)
+
+ if attention_type == "msa_row":
+ attention_file = os.path.join(attention_dir, f"msa_row_attn_layer{layer_idx}.txt")
+ output_html = os.path.join(output_dir, f"msa_row_layer{layer_idx}_heatmap_grid.html")
+
+ elif attention_type == "triangle_start":
+ if residue_idx is None:
+ raise ValueError("residue_idx required for triangle_start")
+
+ attention_file = os.path.join(
+ attention_dir,
+ f"triangle_start_attn_layer{layer_idx}_residue_idx_{residue_idx}.txt",
+ )
+ output_html = os.path.join(
+ output_dir,
+ f"triangle_start_layer{layer_idx}_res{residue_idx}_heatmap_grid.html",
+ )
+
+ else:
+ raise ValueError(f"Unknown attention_type: {attention_type}")
+
+ if not os.path.exists(attention_file):
+ print(f"File not found: {attention_file}")
+ return None
+
+ print(f"Processing: {attention_file}")
+
+ return create_heatmap_grid(
+ attention_file=attention_file,
+ seq_len=seq_len,
+ layer_idx=layer_idx,
+ attention_type=attention_type,
+ output_html=output_html,
+ threshold=threshold,
+ )
diff --git a/viz_attention_demo_base.ipynb b/viz_attention_demo_base.ipynb
index d10f0b1b..ecb35bff 100644
--- a/viz_attention_demo_base.ipynb
+++ b/viz_attention_demo_base.ipynb
@@ -29,6 +29,85 @@
"# In this case, the code will compute MSAs and alignments, which can take several hours\n"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4d0089b0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate heatmap visualizations.\n",
+ "# These complement the existing arc diagram / PyMOL visualizations.\n",
+ "from visualize_attention_heatmap_utils import (\n",
+ " plot_all_heads_heatmap,\n",
+ " plot_combined_attention_heatmap,\n",
+ " visualize_layer_attention,\n",
+ ")\n",
+ "\n",
+ "# Setup heatmap output directory\n",
+ "output_dir_heatmaps = os.path.join(IMAGE_OUTPUT_DIR, \"heatmaps\")\n",
+ "\n",
+ "# Generate MSA Row attention heatmap PNG\n",
+ "plot_all_heads_heatmap(\n",
+ " attention_dir=ATTN_MAP_DIR,\n",
+ " output_dir=output_dir_heatmaps,\n",
+ " protein=PROT,\n",
+ " attention_type=\"msa_row\",\n",
+ " layer_idx=LAYER_IDX,\n",
+ " fasta_path=FASTA_PATH,\n",
+ " normalization=\"per_head\",\n",
+ " colormap=\"viridis\",\n",
+ " figsize_per_head=(2.0, 2.0),\n",
+ " dpi=300,\n",
+ " save_to_png=True,\n",
+ " residue_indices=None,\n",
+ ")\n",
+ "\n",
+ "# Generate Triangle Start attention heatmap PNG\n",
+ "plot_all_heads_heatmap(\n",
+ " attention_dir=ATTN_MAP_DIR,\n",
+ " output_dir=output_dir_heatmaps,\n",
+ " protein=PROT,\n",
+ " attention_type=\"triangle_start\",\n",
+ " layer_idx=LAYER_IDX,\n",
+ " fasta_path=FASTA_PATH,\n",
+ " normalization=\"per_head\",\n",
+ " colormap=\"viridis\",\n",
+ " figsize_per_head=(2.0, 2.0),\n",
+ " dpi=300,\n",
+ " save_to_png=True,\n",
+ " residue_indices=[TRI_RESIDUE_IDX],\n",
+ ")\n",
+ "\n",
+ "# Generate combined MSA Row + Triangle Start heatmap PNG\n",
+ "plot_combined_attention_heatmap(\n",
+ " attention_dir=ATTN_MAP_DIR,\n",
+ " output_dir=output_dir_heatmaps,\n",
+ " protein=PROT,\n",
+ " layer_idx=LAYER_IDX,\n",
+ " fasta_path=FASTA_PATH,\n",
+ " normalization=\"per_head\",\n",
+ " colormap=\"viridis\",\n",
+ " figsize_per_head=(1.5, 1.5),\n",
+ " dpi=300,\n",
+ " save_to_png=True,\n",
+ " residue_indices=[TRI_RESIDUE_IDX],\n",
+ ")\n",
+ "\n",
+ "# Generate interactive Plotly heatmap grid inline\n",
+ "interactive_heatmap = visualize_layer_attention(\n",
+ " attention_dir=ATTN_MAP_DIR,\n",
+ " seq_len=len(residue_seq),\n",
+ " layer_idx=LAYER_IDX,\n",
+ " attention_type=\"msa_row\",\n",
+ " output_dir=os.path.join(output_dir_heatmaps, \"html\"),\n",
+ " threshold=None,\n",
+ ")\n",
+ "\n",
+ "if interactive_heatmap is not None:\n",
+ " interactive_heatmap.show()"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,