Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions example/1_extract_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tqdm import tqdm
import multiprocessing as mp
from pathlib import Path
import argparse

# Define dictionary for three-letter to one-letter amino acid conversion
protein_letters_3to1 = {
Expand Down Expand Up @@ -57,15 +58,25 @@ def process_single_pdb(args):
print(f"\nError processing {input_pdb}: {str(e)}")
return None, None

def parse_args():
    """Build and evaluate the CLI for chain extraction.

    Returns:
        argparse.Namespace with `input_dir` and `output_dir` attributes.
    """
    cli = argparse.ArgumentParser(description='Extract chains from PDB/CIF files')
    cli.add_argument('-i', '--input_dir', required=True, type=str,
                     help='Input directory containing PDB or CIF files')
    cli.add_argument('-o', '--output_dir', required=True, type=str,
                     help='Output directory for individual chain CIF files')
    return cli.parse_args()

def main():
input_dir = "./pdb" # Input directory
output_dir_cif = "./complex_chain_cifs" # CIF output directory

cmd_args = parse_args()
input_dir = cmd_args.input_dir
output_dir_cif = cmd_args.output_dir

# Create output directory
os.makedirs(output_dir_cif, exist_ok=True)
# Get all PDB files
pdb_files = list(Path(input_dir).glob("*.pdb"))

# Get all PDB and CIF files
pdb_files = list(Path(input_dir).glob("*.pdb")) + list(Path(input_dir).glob("*.cif"))

# Prepare parameters for process pool
args = [(str(f), output_dir_cif) for f in pdb_files]
Expand Down Expand Up @@ -112,8 +123,9 @@ def chain_sort_key(chain_id):
df = df[cols]

# Save CSV file
df.to_csv('complex_chain_sequences.csv', index=False)
print("\nSequence information has been saved to complex_chain_sequences.csv")
csv_output_path = os.path.join(output_dir_cif, 'complex_chain_sequences.csv')
df.to_csv(csv_output_path, index=False)
print(f"\nSequence information has been saved to {csv_output_path}")

if __name__ == "__main__":
main()
17 changes: 14 additions & 3 deletions example/2_pdb2jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm
import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning
import argparse

# Disable Bio.PDB warnings
warnings.filterwarnings('ignore', category=PDBConstructionWarning)
Expand Down Expand Up @@ -306,11 +307,21 @@ def process_pdb_folder(
except Exception as e:
print(f"Error processing {os.path.basename(input_path)}: {str(e)}")

def parse_args():
    """Read the CLI options for the PDB-to-H5 conversion.

    Returns:
        argparse.Namespace with `input_dir` and `output_dir` attributes.
    """
    cli = argparse.ArgumentParser(
        description='Convert PDB files to JAX traced arrays in H5 format')
    cli.add_argument('-i', '--input_dir', required=True, type=str,
                     help='Input directory containing PDB files')
    cli.add_argument('-o', '--output_dir', required=True, type=str,
                     help='Output directory for H5 files')
    return cli.parse_args()

def main():
"""Main function"""
pdb_folder = "./pdb"
output_folder = "./complex_h5"

cmd_args = parse_args()
pdb_folder = cmd_args.input_dir
output_folder = cmd_args.output_dir

process_pdb_folder(
pdb_folder=pdb_folder,
output_folder=output_folder,
Expand Down
21 changes: 17 additions & 4 deletions example/3_generate_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import pandas as pd
import os
import argparse

def format_msa_sequence(sequence):
"""Format MSA sequence"""
Expand Down Expand Up @@ -93,9 +94,21 @@ def generate_json_files(csv_path, output_dir, cif_dir):

print(f"\nComplete, generated {json_count} JSON files")

def parse_args():
    """Collect the CLI options for JSON generation from a sequence CSV.

    Returns:
        argparse.Namespace with `csv_path`, `input_dir` and `output_dir`.
    """
    cli = argparse.ArgumentParser(description='Generate JSON files from CSV sequences')
    cli.add_argument('-c', '--csv_path', required=True, type=str,
                     help='Path to the CSV file containing chain sequences')
    cli.add_argument('-i', '--input_dir', required=True, type=str,
                     help='Input directory containing CIF files')
    cli.add_argument('-o', '--output_dir', required=True, type=str,
                     help='Output directory for JSON files')
    return cli.parse_args()

if __name__ == "__main__":
    # All paths come from the CLI; the previous hardcoded defaults (including a
    # machine-specific /lustre CIF directory) were dead stores immediately
    # overwritten by the parsed arguments, so they are removed.
    cmd_args = parse_args()
    csv_path = cmd_args.csv_path
    output_dir = cmd_args.output_dir
    cif_dir = cmd_args.input_dir

    generate_json_files(csv_path, output_dir, cif_dir)
36 changes: 28 additions & 8 deletions example/4_extract_iptm-ipae-pae-interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import partial
from tqdm import tqdm
from collections import defaultdict
import argparse

def get_interface_res_from_cif(cif_file, dist_cutoff=10):
"""Get interface residues from CIF file"""
Expand Down Expand Up @@ -117,8 +118,16 @@ def process_row(base_path, row):
metrics, error = get_metrics_from_json(base_path, target, description)
return {'idx': row.name, 'metrics': metrics, 'error': error}

def update_sc_file():
df = pd.read_csv("subset_data.csv", sep=',', low_memory=False)
def update_sc_file(input_dir, outfile):
"""Update CSV file with AF3Score metrics

Args:
input_dir: Directory containing score_results subdirectories
outfile: Output CSV file path
"""
# Assume the CSV file is in the input directory
csv_file = os.path.join(os.path.dirname(input_dir) if input_dir != "." else ".", "subset_data.csv")
df = pd.read_csv(csv_file, sep=',', low_memory=False)
print("Columns in file:", df.columns.tolist())

if 'target' not in df.columns or 'description' not in df.columns:
Expand All @@ -136,9 +145,9 @@ def update_sc_file():
df[metric] = np.nan

print(f"\nTotal {len(df)} records to process")

# Base path
base_path = "./score_results"
base_path = input_dir

# Set number of processes
num_processes = max(1, int(cpu_count() * 0.8))
Expand Down Expand Up @@ -166,23 +175,34 @@ def update_sc_file():
failed_records.append(f"line {result['idx']}: {result['error']}")

# Write failed records to file
with open('failed_records.txt', 'w') as f:
failed_records_path = os.path.join(os.path.dirname(outfile) if os.path.dirname(outfile) else '.', 'failed_records.txt')
with open(failed_records_path, 'w') as f:
f.write(f"Total processed rows: {len(df)}\n")
f.write(f"Failed rows: {len(failed_records)}\n\n")
f.write("Detailed failure records:\n")
for record in failed_records:
f.write(f"{record}\n")

# Save updated file
df.to_csv("subset_data_with_metrics.csv", sep=",", index=False)
df.to_csv(outfile, sep=",", index=False)

# Output statistics
print("\nProcessing completion statistics:")
print(f"Total entries: {len(df)}")
for metric in metrics_list:
print(f"Successfully updated {metric} count: {success_count[metric]}")
print(f"Failed entries: {len(failed_records)}")
print(f"Failed records written to: failed_records.txt")
print(f"Failed records written to: {failed_records_path}")

def parse_args():
    """Parse the CLI options for metric extraction.

    Returns:
        argparse.Namespace with `input_dir` and `output` attributes.
    """
    cli = argparse.ArgumentParser(
        description='Extract iPTM, iPAE, and PAE interaction metrics')
    cli.add_argument('-i', '--input_dir', required=True, type=str,
                     help='Input directory containing score_results')
    cli.add_argument('-o', '--output', required=True, type=str,
                     help='Output CSV file path')
    return cli.parse_args()

if __name__ == "__main__":
    # Entry point: forward CLI options to the metric-update routine. The stale
    # zero-argument `update_sc_file()` call left over from the old version is
    # removed — the function now requires (input_dir, outfile) and calling it
    # without arguments would raise TypeError.
    cmd_args = parse_args()
    update_sc_file(cmd_args.input_dir, cmd_args.output)
51 changes: 34 additions & 17 deletions run_af3score.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@
'Number of samples to generate for each prediction.',
)

_ALLOW_INTERCHAIN_TEMPLATES = flags.DEFINE_bool(
'allow_interchain_templates',
False,
'Allow inter-chain template information for complex scoring (AF2Rank mode). '
'When True, template embeddings include inter-chain distances. '
'When False (default), only intra-chain template features are used.',
)


class ConfigurableModel(Protocol):
"""A model with a nested config class."""
Expand Down Expand Up @@ -315,12 +323,13 @@ def _model(
assert isinstance(self._model_config, self._model_class.Config)

@hk.transform # Transform regular function into 1. init (for parameter initialization) 2. apply (for forward pass)
def forward_fn(batch, init_guess=True, path='', num_samples=5):
def forward_fn(batch, init_guess=True, path='', num_samples=5, allow_interchain_templates=False):
result = self._model_class(self._model_config)(
batch,
init_guess=init_guess,
batch,
init_guess=init_guess,
path=path,
num_samples=num_samples
num_samples=num_samples,
allow_interchain_templates=allow_interchain_templates,
)
result['__identifier__'] = self.model_params['__meta__']['__identifier__']
return result
Expand All @@ -329,18 +338,19 @@ def forward_fn(batch, init_guess=True, path='', num_samples=5):
jax.jit(
forward_fn.apply, # Function to compile
device=self._device, # jit parameters
static_argnames=('init_guess', 'path', 'num_samples') # jit parameters, both need to be marked as static
),
static_argnames=('init_guess', 'path', 'num_samples', 'allow_interchain_templates') # jit parameters, both need to be marked as static
),
self.model_params # partial sets self.model_params as first parameter of compiled_fn
)

def run_inference(
self,
featurised_example: features.BatchDict,
self,
featurised_example: features.BatchDict,
rng_key: jnp.ndarray,
init_guess: bool = True,
path: str = '',
num_samples: int = 5,
allow_interchain_templates: bool = False,
) -> base_model.ModelResult:
"""Computes a forward pass of the model on a featurised example."""
featurised_example = jax.device_put(
Expand All @@ -350,11 +360,12 @@ def run_inference(
self._device,
)
result = self._model(
rng_key,
featurised_example,
init_guess,
rng_key,
featurised_example,
init_guess,
path,
num_samples
num_samples,
allow_interchain_templates,
)
result = jax.tree.map(np.asarray, result)
result = jax.tree.map(
Expand Down Expand Up @@ -404,6 +415,7 @@ def predict_structure(
init_guess: bool = True,
path: str = '',
num_samples: int = 5,
allow_interchain_templates: bool = False,
global_ccd = None, # Add global CCD parameter
) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed."""
Expand All @@ -430,11 +442,12 @@ def predict_structure(
inference_start_time = time.time()
rng_key = jax.random.PRNGKey(seed)
result = model_runner.run_inference(
example,
rng_key,
init_guess=init_guess,
example,
rng_key,
init_guess=init_guess,
path=path,
num_samples=num_samples
num_samples=num_samples,
allow_interchain_templates=allow_interchain_templates,
)
print(
f'Running model inference for seed {seed} took '
Expand Down Expand Up @@ -571,6 +584,7 @@ def process_fold_input(
init_guess: bool = True,
path: str = '',
num_samples: int = 5,
allow_interchain_templates: bool = False,
global_ccd = None, # Add global CCD parameter
) -> folding_input.Input | Sequence[ResultsForSeed]:
"""Runs data pipeline and/or inference on a single fold input.
Expand Down Expand Up @@ -628,6 +642,7 @@ def process_fold_input(
init_guess=init_guess,
path=path,
num_samples=num_samples,
allow_interchain_templates=allow_interchain_templates,
global_ccd=global_ccd # Pass global CCD
)
print(
Expand Down Expand Up @@ -811,6 +826,7 @@ def main(_):
init_guess=_INIT_GUESS.value,
path=h5_path,
num_samples=_NUM_SAMPLES.value,
allow_interchain_templates=_ALLOW_INTERCHAIN_TEMPLATES.value,
global_ccd=global_ccd # Pass global CCD
)

Expand All @@ -829,7 +845,8 @@ def main(_):
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
init_guess=_INIT_GUESS.value,
path=_INIT_PATH.value,
num_samples=_NUM_SAMPLES.value
num_samples=_NUM_SAMPLES.value,
allow_interchain_templates=_ALLOW_INTERCHAIN_TEMPLATES.value,
)

print('All processing completed successfully.')
Expand Down
18 changes: 14 additions & 4 deletions src/alphafold3/model/diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,13 @@ def _sample_diffusion(


def __call__(
self,
batch: features.BatchDict,
key: jax.Array | None = None,
self,
batch: features.BatchDict,
key: jax.Array | None = None,
init_guess: bool = False,
num_samples: int = 5,
path: str = '',
allow_interchain_templates: bool = False,
) -> base_model.ModelResult:
if path:
print(f"Using path: {path}")
Expand All @@ -302,6 +303,7 @@ def recycle_body(_, args):
prev=prev,
target_feat=target_feat,
key=subkey,
allow_interchain_templates=allow_interchain_templates,
)
embeddings['pair'] = embeddings['pair'].astype(jnp.float32)
embeddings['single'] = embeddings['single'].astype(jnp.float32)
Expand Down Expand Up @@ -699,6 +701,7 @@ def _embed_template_pair(
pair_activations: jnp.ndarray,
pair_mask: jnp.ndarray,
key: jnp.ndarray,
allow_interchain_templates: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Embeds Templates and merges into pair activations."""
dtype = pair_activations.dtype
Expand All @@ -710,7 +713,12 @@ def _embed_template_pair(
asym_id = batch.token_features.asym_id
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype)
# When allow_interchain_templates=True, use all template information
# including inter-chain distances (AF2Rank-style scoring).
if allow_interchain_templates:
multichain_mask = jnp.ones((asym_id.shape[0], asym_id.shape[0]), dtype=dtype)
else:
multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype)

template_fn = functools.partial(template_module, key=subkey)
template_act = template_fn(
Expand Down Expand Up @@ -771,6 +779,7 @@ def __call__(
prev: dict[str, jnp.ndarray],
target_feat: jnp.ndarray,
key: jnp.ndarray,
allow_interchain_templates: bool = False,
) -> dict[str, jnp.ndarray]:

assert self.global_config.bfloat16 in {'all', 'none'}
Expand Down Expand Up @@ -808,6 +817,7 @@ def __call__(
pair_activations=pair_activations,
pair_mask=pair_mask,
key=key,
allow_interchain_templates=allow_interchain_templates,
)
pair_activations, key = self._embed_process_msa(
msa_batch=batch.msa,
Expand Down