diff --git a/example/1_extract_chains.py b/example/1_extract_chains.py index 4346c9f..89ea632 100644 --- a/example/1_extract_chains.py +++ b/example/1_extract_chains.py @@ -6,6 +6,7 @@ from tqdm import tqdm import multiprocessing as mp from pathlib import Path +import argparse # Define dictionary for three-letter to one-letter amino acid conversion protein_letters_3to1 = { @@ -57,15 +58,25 @@ def process_single_pdb(args): print(f"\nError processing {input_pdb}: {str(e)}") return None, None +def parse_args(): + """Parse command-line arguments""" + parser = argparse.ArgumentParser(description='Extract chains from PDB/CIF files') + parser.add_argument('-i', '--input_dir', type=str, required=True, + help='Input directory containing PDB or CIF files') + parser.add_argument('-o', '--output_dir', type=str, required=True, + help='Output directory for individual chain CIF files') + return parser.parse_args() + def main(): - input_dir = "./pdb" # Input directory - output_dir_cif = "./complex_chain_cifs" # CIF output directory - + cmd_args = parse_args() + input_dir = cmd_args.input_dir + output_dir_cif = cmd_args.output_dir + # Create output directory os.makedirs(output_dir_cif, exist_ok=True) - - # Get all PDB files - pdb_files = list(Path(input_dir).glob("*.pdb")) + + # Get all PDB and CIF files + pdb_files = list(Path(input_dir).glob("*.pdb")) + list(Path(input_dir).glob("*.cif")) # Prepare parameters for process pool args = [(str(f), output_dir_cif) for f in pdb_files] @@ -112,8 +123,9 @@ def chain_sort_key(chain_id): df = df[cols] # Save CSV file - df.to_csv('complex_chain_sequences.csv', index=False) - print("\nSequence information has been saved to complex_chain_sequences.csv") + csv_output_path = os.path.join(output_dir_cif, 'complex_chain_sequences.csv') + df.to_csv(csv_output_path, index=False) + print(f"\nSequence information has been saved to {csv_output_path}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/example/2_pdb2jax.py b/example/2_pdb2jax.py index c1f78c5..99220b2 100755 --- a/example/2_pdb2jax.py +++ b/example/2_pdb2jax.py @@ -11,6 +11,7 @@ from tqdm import tqdm import warnings from Bio.PDB.PDBExceptions import PDBConstructionWarning +import argparse # Disable Bio.PDB warnings warnings.filterwarnings('ignore', category=PDBConstructionWarning) @@ -306,11 +307,21 @@ def process_pdb_folder( except Exception as e: print(f"Error processing {os.path.basename(input_path)}: {str(e)}") +def parse_args(): + """Parse command-line arguments""" + parser = argparse.ArgumentParser(description='Convert PDB files to JAX traced arrays in H5 format') + parser.add_argument('-i', '--input_dir', type=str, required=True, + help='Input directory containing PDB files') + parser.add_argument('-o', '--output_dir', type=str, required=True, + help='Output directory for H5 files') + return parser.parse_args() + def main(): """Main function""" - pdb_folder = "./pdb" - output_folder = "./complex_h5" - + cmd_args = parse_args() + pdb_folder = cmd_args.input_dir + output_folder = cmd_args.output_dir + process_pdb_folder( pdb_folder=pdb_folder, output_folder=output_folder, diff --git a/example/3_generate_json.py b/example/3_generate_json.py index 5ff061a..59f765c 100755 --- a/example/3_generate_json.py +++ b/example/3_generate_json.py @@ -1,6 +1,7 @@ import json import pandas as pd import os +import argparse def format_msa_sequence(sequence): """Format MSA sequence""" @@ -93,9 +94,21 @@ def generate_json_files(csv_path, output_dir, cif_dir): print(f"\nComplete, generated {json_count} JSON files") +def parse_args(): + """Parse command-line arguments""" + parser = argparse.ArgumentParser(description='Generate JSON files from CSV sequences') + parser.add_argument('-c', '--csv_path', type=str, required=True, + help='Path to the CSV file containing chain sequences') + parser.add_argument('-i', '--input_dir', type=str, required=True, + help='Input directory containing CIF files') + parser.add_argument('-o', '--output_dir', type=str, required=True, + help='Output directory for JSON files') + return parser.parse_args() + if __name__ == "__main__": - csv_path = "./complex_chain_sequences.csv" # Path to the CSV file just generated - output_dir = "./complex_json_files" # Output directory for JSON files - cif_dir = "/lustre/grp/cmclab/liuyu/design/AF3Score/example/complex_chain_cifs" # Directory where CIF files are located - + cmd_args = parse_args() + csv_path = cmd_args.csv_path + output_dir = cmd_args.output_dir + cif_dir = cmd_args.input_dir + generate_json_files(csv_path, output_dir, cif_dir) \ No newline at end of file diff --git a/example/4_extract_iptm-ipae-pae-interaction.py b/example/4_extract_iptm-ipae-pae-interaction.py index 4bf8a41..7aafcb4 100644 --- a/example/4_extract_iptm-ipae-pae-interaction.py +++ b/example/4_extract_iptm-ipae-pae-interaction.py @@ -7,6 +7,7 @@ from functools import partial from tqdm import tqdm from collections import defaultdict +import argparse def get_interface_res_from_cif(cif_file, dist_cutoff=10): """Get interface residues from CIF file""" @@ -117,8 +118,16 @@ def process_row(base_path, row): metrics, error = get_metrics_from_json(base_path, target, description) return {'idx': row.name, 'metrics': metrics, 'error': error} -def update_sc_file(): - df = pd.read_csv("subset_data.csv", sep=',', low_memory=False) +def update_sc_file(input_dir, outfile): + """Update CSV file with AF3Score metrics + + Args: + input_dir: Directory containing score_results subdirectories + outfile: Output CSV file path + """ + # Assume the CSV file is in the input directory + csv_file = os.path.join(os.path.dirname(input_dir) if input_dir != "." else ".", "subset_data.csv") + df = pd.read_csv(csv_file, sep=',', low_memory=False) print("Columns in file:", df.columns.tolist()) if 'target' not in df.columns or 'description' not in df.columns: @@ -136,9 +145,9 @@ def update_sc_file(): df[metric] = np.nan print(f"\nTotal {len(df)} records to process") - + # Base path - base_path = "./score_results" + base_path = input_dir # Set number of processes num_processes = max(1, int(cpu_count() * 0.8)) @@ -166,7 +175,8 @@ def update_sc_file(): failed_records.append(f"line {result['idx']}: {result['error']}") # Write failed records to file - with open('failed_records.txt', 'w') as f: + failed_records_path = os.path.join(os.path.dirname(outfile) if os.path.dirname(outfile) else '.', 'failed_records.txt') + with open(failed_records_path, 'w') as f: f.write(f"Total processed rows: {len(df)}\n") f.write(f"Failed rows: {len(failed_records)}\n\n") f.write("Detailed failure records:\n") @@ -174,7 +184,7 @@ def update_sc_file(): f.write(f"{record}\n") # Save updated file - df.to_csv("subset_data_with_metrics.csv", sep=",", index=False) + df.to_csv(outfile, sep=",", index=False) # Output statistics print("\nProcessing completion statistics:") @@ -182,7 +192,17 @@ def update_sc_file(): for metric in metrics_list: print(f"Successfully updated {metric} count: {success_count[metric]}") print(f"Failed entries: {len(failed_records)}") - print(f"Failed records written to: failed_records.txt") + print(f"Failed records written to: {failed_records_path}") + +def parse_args(): + """Parse command-line arguments""" + parser = argparse.ArgumentParser(description='Extract iPTM, iPAE, and PAE interaction metrics') + parser.add_argument('-i', '--input_dir', type=str, required=True, + help='Input directory containing score_results') + parser.add_argument('-o', '--output', type=str, required=True, + help='Output CSV file path') + return parser.parse_args() if __name__ == "__main__": - update_sc_file() \ No newline at end of file + cmd_args = parse_args() + update_sc_file(cmd_args.input_dir, cmd_args.output) \ No newline at end of file diff --git a/run_af3score.py b/run_af3score.py index 1bb4a8d..cfdf16c 100644 --- a/run_af3score.py +++ b/run_af3score.py @@ -251,6 +251,14 @@ 'Number of samples to generate for each prediction.', ) +_ALLOW_INTERCHAIN_TEMPLATES = flags.DEFINE_bool( + 'allow_interchain_templates', + False, + 'Allow inter-chain template information for complex scoring (AF2Rank mode). ' + 'When True, template embeddings include inter-chain distances. ' + 'When False (default), only intra-chain template features are used.', +) + class ConfigurableModel(Protocol): """A model with a nested config class.""" @@ -315,12 +323,13 @@ def _model( assert isinstance(self._model_config, self._model_class.Config) @hk.transform # Transform regular function into 1. init (for parameter initialization) 2. apply (for forward pass) - def forward_fn(batch, init_guess=True, path='', num_samples=5): + def forward_fn(batch, init_guess=True, path='', num_samples=5, allow_interchain_templates=False): result = self._model_class(self._model_config)( - batch, - init_guess=init_guess, + batch, + init_guess=init_guess, path=path, - num_samples=num_samples + num_samples=num_samples, + allow_interchain_templates=allow_interchain_templates, ) result['__identifier__'] = self.model_params['__meta__']['__identifier__'] return result @@ -329,18 +338,19 @@ def forward_fn(batch, init_guess=True, path='', num_samples=5): jax.jit( forward_fn.apply, # Function to compile device=self._device, # jit parameters - static_argnames=('init_guess', 'path', 'num_samples') # jit parameters, both need to be marked as static - ), + static_argnames=('init_guess', 'path', 'num_samples', 'allow_interchain_templates') # jit parameters, both need to be marked as static + ), self.model_params # partial sets self.model_params as first parameter of compiled_fn ) def run_inference( - self, - featurised_example: features.BatchDict, + self, + featurised_example: features.BatchDict, rng_key: jnp.ndarray, init_guess: bool = True, path: str = '', num_samples: int = 5, + allow_interchain_templates: bool = False, ) -> base_model.ModelResult: """Computes a forward pass of the model on a featurised example.""" featurised_example = jax.device_put( @@ -350,11 +360,12 @@ def run_inference( self._device, ) result = self._model( - rng_key, - featurised_example, - init_guess, + rng_key, + featurised_example, + init_guess, path, - num_samples + num_samples, + allow_interchain_templates, ) result = jax.tree.map(np.asarray, result) result = jax.tree.map( @@ -404,6 +415,7 @@ def predict_structure( init_guess: bool = True, path: str = '', num_samples: int = 5, + allow_interchain_templates: bool = False, global_ccd = None, # Add global CCD parameter ) -> Sequence[ResultsForSeed]: """Runs the full inference pipeline to predict structures for each seed.""" @@ -430,11 +442,12 @@ def predict_structure( inference_start_time = time.time() rng_key = jax.random.PRNGKey(seed) result = model_runner.run_inference( - example, - rng_key, - init_guess=init_guess, + example, + rng_key, + init_guess=init_guess, path=path, - num_samples=num_samples + num_samples=num_samples, + allow_interchain_templates=allow_interchain_templates, ) print( f'Running model inference for seed {seed} took ' @@ -571,6 +584,7 @@ def process_fold_input( init_guess: bool = True, path: str = '', num_samples: int = 5, + allow_interchain_templates: bool = False, global_ccd = None, # Add global CCD parameter ) -> folding_input.Input | Sequence[ResultsForSeed]: """Runs data pipeline and/or inference on a single fold input. @@ -628,6 +642,7 @@ def process_fold_input( init_guess=init_guess, path=path, num_samples=num_samples, + allow_interchain_templates=allow_interchain_templates, global_ccd=global_ccd # Pass global CCD ) print( @@ -811,6 +826,7 @@ def main(_): init_guess=_INIT_GUESS.value, path=h5_path, num_samples=_NUM_SAMPLES.value, + allow_interchain_templates=_ALLOW_INTERCHAIN_TEMPLATES.value, global_ccd=global_ccd # Pass global CCD ) @@ -829,7 +845,8 @@ def main(_): buckets=tuple(int(bucket) for bucket in _BUCKETS.value), init_guess=_INIT_GUESS.value, path=_INIT_PATH.value, - num_samples=_NUM_SAMPLES.value + num_samples=_NUM_SAMPLES.value, + allow_interchain_templates=_ALLOW_INTERCHAIN_TEMPLATES.value, ) print('All processing completed successfully.') diff --git a/src/alphafold3/model/diffusion/model.py b/src/alphafold3/model/diffusion/model.py index a0db067..c70407f 100755 --- a/src/alphafold3/model/diffusion/model.py +++ b/src/alphafold3/model/diffusion/model.py @@ -274,12 +274,13 @@ def _sample_diffusion( def __call__( - self, - batch: features.BatchDict, - key: jax.Array | None = None, + self, + batch: features.BatchDict, + key: jax.Array | None = None, init_guess: bool = False, num_samples: int = 5, path: str = '', + allow_interchain_templates: bool = False, ) -> base_model.ModelResult: if path: print(f"Using path: {path}") @@ -302,6 +303,7 @@ def recycle_body(_, args): prev=prev, target_feat=target_feat, key=subkey, + allow_interchain_templates=allow_interchain_templates, ) embeddings['pair'] = embeddings['pair'].astype(jnp.float32) embeddings['single'] = embeddings['single'].astype(jnp.float32) @@ -699,6 +701,7 @@ def _embed_template_pair( pair_activations: jnp.ndarray, pair_mask: jnp.ndarray, key: jnp.ndarray, + allow_interchain_templates: bool = False, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Embeds Templates and merges into pair activations.""" dtype = pair_activations.dtype @@ -710,7 +713,12 @@ def _embed_template_pair( asym_id = batch.token_features.asym_id # Construct a mask such that only intra-chain template features are # computed, since all templates are for each chain individually. - multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype) + # When allow_interchain_templates=True, use all template information + # including inter-chain distances (AF2Rank-style scoring). + if allow_interchain_templates: + multichain_mask = jnp.ones((asym_id.shape[0], asym_id.shape[0]), dtype=dtype) + else: + multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype) template_fn = functools.partial(template_module, key=subkey) template_act = template_fn( @@ -771,6 +779,7 @@ def __call__( prev: dict[str, jnp.ndarray], target_feat: jnp.ndarray, key: jnp.ndarray, + allow_interchain_templates: bool = False, ) -> dict[str, jnp.ndarray]: assert self.global_config.bfloat16 in {'all', 'none'} @@ -808,6 +817,7 @@ def __call__( pair_activations=pair_activations, pair_mask=pair_mask, key=key, + allow_interchain_templates=allow_interchain_templates, ) pair_activations, key = self._embed_process_msa( msa_batch=batch.msa,