128 changes: 128 additions & 0 deletions finetune/README.md
# Evaluation Experiments

## Directory Structure and Required Files

In the project folder, please ensure you have the following directories and files:

1. **Checkpoints Directory**

This directory should contain the trained model checkpoints. In particular, the following path must exist:

```
checkpoint/esm2_t33_650M_UR50D_esm2_t30_150M_UR50D_None_1_0.5_0.5_loss_large_1.0_foldseek_gearnet_1_1_0.5_0.5_loss_large_0.8_foldseek_gearnet_3
```

2. **Special Model Checkpoints**

In addition to the general checkpoints, please ensure that you have the following special model directories:

- **ISM Model:**
```
checkpoint/ISM/ism_model
```
- **ESM-s Model:**
```
checkpoint/ESM-s/esm_s_model
```

3. **Data Directory**

Create a directory named `saprot_data` to store the necessary Hugging Face datasets.
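
Before launching any of the scripts below, it may help to sanity-check the layout. A minimal sketch (the helper is ours, not part of the repo; the long trained-checkpoint directory from step 1 is omitted for brevity):

```python
from pathlib import Path

# Paths the experiments expect, per the list above.
REQUIRED_PATHS = [
    "checkpoint/ISM/ism_model",
    "checkpoint/ESM-s/esm_s_model",
    "saprot_data",
]

def missing_paths(root="."):
    """Return the expected paths that do not yet exist under `root`."""
    root = Path(root)
    return [p for p in REQUIRED_PATHS if not (root / p).exists()]

if __name__ == "__main__":
    for p in missing_paths():
        print("missing:", p)
```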


## ESM-2

```bash
bash slurm/ESM2_TEST.sh
```

## ISM and ESM-s

```bash
bash slurm/ESM2_O_TEST.sh
```

## AMPLIFY

```bash
bash slurm/AMPLIFY_TEST.sh
```





# Training Experiments

## Script Dependencies (Using 8 GPUs)

- `Our.sh` and `Ablation1.sh` depend on `Reference.sh`; run `Reference.sh` first.
- `Ablation3.sh` depends on `Ablation3r.sh`.
- `Ablation4.sh` depends on `Ablation4r.sh`.

## Reference Models

- **AMGEN1:** Train `"chandar-lab/AMPLIFY_120M"` on the validation set.
- **AMGEN2:** Train `"facebook/esm2_t30_150M_UR50D"` on the validation set.

Run:
```bash
sbatch slurm/Reference.sh
```


## Proposed Method

- **AMGEN3:** Train `"chandar-lab/AMPLIFY_350M"` using our proposed method.
- **AMGEN4:** Train `"facebook/esm2_t33_650M_UR50D"` using our proposed method.

Run:
```bash
sbatch slurm/Our.sh
```

## Ablation Studies

### Ablation 1: Loss Weight


- Change `loss_weight` from `[1, 0.5, 0.5]` to `[1, 0.0, 0.5]`.
- Change `loss_weight` from `[1, 0.5, 0.5]` to `[1, 0.5, 0.0]`.
- Change `loss_weight` from `[1, 0.5, 0.5]` to `[1, 0.0, 0.0]`.

Run:
```bash
sbatch slurm/Ablation1.sh
```
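
The three entries of `loss_weight` weight three training loss terms, and setting an entry to `0.0` ablates that term. A sketch of the weighted sum (the ordering of the terms is an assumption for illustration; the actual terms are defined in the training code):

```python
def combine_losses(losses, weights=(1, 0.5, 0.5)):
    """Weighted sum of the training loss terms; a zero weight drops its term."""
    assert len(losses) == len(weights)
    return sum(w * l for w, l in zip(weights, losses))

# e.g. weights [1, 0.0, 0.5] keep only the first and third terms:
# combine_losses([l1, l2, l3], (1, 0.0, 0.5)) == l1 + 0.5 * l3
```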

### Ablation 2: Sample Mode

- Remove the reference model and use `loss_large`.
- Remove the reference model and use `loss_small`.
- Remove the reference model and set `ratio=1.0`.

Run:
```bash
sbatch slurm/Ablation2.sh
```

### Ablation 3: Structural Token Type

- Change `struc_token_type` from `foldseek` to `protoken`.
- Change `struc_token_type` from `foldseek` to `aido`.

Run:
```bash
sbatch slurm/Ablation3r.sh
sbatch slurm/Ablation3.sh
```

### Ablation 4: Structural Embedding Type

- Change `struc_embed_type` from `gearnet` to `af2`.

Run:
```bash
sbatch slurm/Ablation4r.sh
sbatch slurm/Ablation4.sh
```
2 changes: 2 additions & 0 deletions finetune/config/config.yaml
defaults:
- experiments: esm2_t30_150M_UR50D
27 changes: 27 additions & 0 deletions finetune/config/experiments/AMPLIFY_120M.yaml
seed: 1

mode: train
prt_model_name: chandar-lab/AMPLIFY_120M
loss_weight: [1, 0.5, 0.5]
struc_token_type: foldseek
struc_embed_type: gearnet

reference_model: null
reference_prt_model_name: null

train_data_type: valid_train
valid_data_type: valid_valid
sample_mode: loss_large
ratio: 1.0

n_epochs: 200
batch_size: 8
opt_interval: 64 # 8 * 8 * 64 = 4096
eval_steps: 15
precision: bf16

prefix_path: /network/scratch/c/can.chen/datasets/pdb_data
# prefix_path has three folders: 1. important_data to save key_name related; 2. af2_embedding; 3. gearnet_embedding

save_steps: 15
resume: True
26 changes: 26 additions & 0 deletions finetune/config/experiments/AMPLIFY_350M.yaml
seed: 1

mode: train
prt_model_name: chandar-lab/AMPLIFY_350M
loss_weight: [1, 0.5, 0.5]
struc_token_type: foldseek
struc_embed_type: gearnet

reference_model: null
reference_prt_model_name: null

train_data_type: train
sample_mode: loss_large
ratio: 0.8

n_epochs: 25
batch_size: 8
opt_interval: 80 # 8 * 8 * 80 * 0.8 = 4096
eval_steps: 15
precision: bf16

prefix_path: /network/scratch/c/can.chen/datasets/pdb_data
# prefix_path has three folders: 1. important_data to save key_name related; 2. af2_embedding; 3. gearnet_embedding

save_steps: 15
resume: True
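
The `opt_interval` comments in these configs encode a constant effective batch of 4096 sequences per optimizer step: per-GPU batch size × number of GPUs × accumulation steps, scaled by the sampling `ratio` where the data is subsampled. A quick check of that arithmetic (assuming the 8-GPU setup noted in the README):

```python
def effective_batch(batch_size, n_gpus, opt_interval, ratio=1.0):
    """Sequences contributing to one optimizer step under gradient accumulation."""
    return batch_size * n_gpus * opt_interval * ratio

# Matches the config comments:
#   8 * 8 * 64       = 4096  (120M/150M configs, ratio 1.0)
#   8 * 8 * 80 * 0.8 = 4096  (350M/650M configs, ratio 0.8)
```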
15 changes: 15 additions & 0 deletions finetune/config/experiments/AMPLIFY_TEST.yaml
seed: 42

mode: downstream
task_names: ["saprot_data/HumanPPI", "saprot_data/MetalIonBinding_AF2", "saprot_data/Thermostability", "saprot_data/GO_AF2_MF", "saprot_data/GO_AF2_CC", "saprot_data/GO_AF2_BP", "saprot_data/EC_AF2", "saprot_data/DeepLoc_cls10", "Bo1015/ssp_q3", "Bo1015/fold_prediction", "Bo1015/contact_prediction_binary", "Bo1015/enzyme_catalytic_efficiency", "Bo1015/fitness_prediction", "Bo1015/fluorescence_prediction", "Bo1015/stability_prediction"]

prt_model_name: chandar-lab/AMPLIFY_350M
ft_model_path: null


n_epochs: 200
batch_size: 8
opt_interval: 16 # 8 * 16 = 128
precision: 'bf16'

device: 0
15 changes: 15 additions & 0 deletions finetune/config/experiments/ESM2_TEST.yaml
seed: 42

mode: downstream
task_names: ["saprot_data/HumanPPI", "saprot_data/MetalIonBinding_AF2", "saprot_data/Thermostability", "saprot_data/GO_AF2_MF", "saprot_data/GO_AF2_CC", "saprot_data/GO_AF2_BP", "saprot_data/EC_AF2", "saprot_data/DeepLoc_cls10", "Bo1015/ssp_q3", "Bo1015/fold_prediction", "Bo1015/contact_prediction_binary", "Bo1015/enzyme_catalytic_efficiency", "Bo1015/fitness_prediction", "Bo1015/fluorescence_prediction", "Bo1015/stability_prediction"]

prt_model_name: facebook/esm2_t33_650M_UR50D
ft_model_path: null


n_epochs: 200
batch_size: 8
opt_interval: 16 # 8 * 16 = 128
precision: 'bf16'

device: 0
27 changes: 27 additions & 0 deletions finetune/config/experiments/esm2_t30_150M_UR50D.yaml
seed: 1

mode: train
prt_model_name: facebook/esm2_t30_150M_UR50D
loss_weight: [1, 0.5, 0.5]
struc_token_type: foldseek
struc_embed_type: gearnet

reference_model: null
reference_prt_model_name: null

train_data_type: valid_train
valid_data_type: valid_valid
sample_mode: loss_large
ratio: 1.0

n_epochs: 200
batch_size: 8
opt_interval: 64 # 8 * 8 * 64 = 4096
eval_steps: 15
precision: bf16

prefix_path: /network/scratch/c/can.chen/datasets/pdb_data
# prefix_path has three folders: 1. important_data to save key_name related; 2. af2_embedding; 3. gearnet_embedding

save_steps: 15
resume: True
26 changes: 26 additions & 0 deletions finetune/config/experiments/esm2_t33_650M_UR50D.yaml
seed: 1

mode: train
prt_model_name: facebook/esm2_t33_650M_UR50D
loss_weight: [1, 0.5, 0.5]
struc_token_type: foldseek
struc_embed_type: gearnet

reference_model: null
reference_prt_model_name: null

train_data_type: train
sample_mode: loss_large
ratio: 0.8

n_epochs: 25
batch_size: 8
opt_interval: 80 # 8 * 8 * 80 * 0.8 = 4096
eval_steps: 15
precision: bf16

prefix_path: /network/scratch/c/can.chen/datasets/pdb_data
# prefix_path has three folders: 1. important_data to save key_name related; 2. af2_embedding; 3. gearnet_embedding

save_steps: 15
resume: True
28 changes: 28 additions & 0 deletions finetune/config/task.yaml
"Bo1015/contact_prediction_binary": {"loss_type": "classification", "output_type": "residue", "num_labels": 2, "metric": "long_range_precision_at_L"}
"Bo1015/fold_prediction": {"loss_type": "classification", "output_type": "protein", "num_labels": 1195, "metric": "accuracy"}
"Bo1015/ssp_q3": {"loss_type": "classification", "output_type": "residue", "num_labels": 3, "metric": "accuracy"}
"Bo1015/solubility_prediction": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "accuracy"}
"Bo1015/stability_prediction": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "spearman"}
"Bo1015/temperature_stability": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "mcc"}
"Bo1015/optimal_temperature": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "srcc"}
"Bo1015/optimal_ph": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "srcc"}
"Bo1015/cloning_clf": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "auc"}
"Bo1015/material_production": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "auc"}
"Bo1015/metal_ion_binding": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "accuracy"}
"Bo1015/enzyme_catalytic_efficiency": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "pcc"}
"Bo1015/peptide_HLA_MHC_affinity": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "auc"}
"Bo1015/tcr_pmhc_affinity": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "auc"}
"Bo1015/antibiotic_resistance": {"loss_type": "classification", "output_type": "protein", "num_labels": 19, "metric": "accuracy"}
"Bo1015/fluorescence_prediction": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "spearman"}
"Bo1015/fitness_prediction": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "spearman"}
"Bo1015/localization_prediction": {"loss_type": "classification", "output_type": "protein", "num_labels": 10, "metric": "accuracy"}

"saprot_data/DeepLoc_cls10": {"loss_type": "classification", "output_type": "protein", "num_labels": 10, "metric": "accuracy"}
"saprot_data/DeepLoc_cls2": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "accuracy"}
"saprot_data/EC_AF2": {"loss_type": "multi_classification", "output_type": "protein", "num_labels": 585, "metric": "fmax"}
"saprot_data/GO_AF2_MF": {"loss_type": "multi_classification", "output_type": "protein", "num_labels": 489, "metric": "fmax"}
"saprot_data/GO_AF2_CC": {"loss_type": "multi_classification", "output_type": "protein", "num_labels": 320, "metric": "fmax"}
"saprot_data/GO_AF2_BP": {"loss_type": "multi_classification", "output_type": "protein", "num_labels": 1943, "metric": "fmax"}
"saprot_data/MetalIonBinding_AF2": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "accuracy"}
"saprot_data/Thermostability": {"loss_type": "regression", "output_type": "protein", "num_labels": 1, "metric": "spearman"}
"saprot_data/HumanPPI": {"loss_type": "classification", "output_type": "protein", "num_labels": 2, "metric": "accuracy"}
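Each entry above maps a dataset to its prediction-head configuration (`loss_type`, `output_type`, `num_labels`) and evaluation metric. A hypothetical dispatcher consuming two of these entries (the loss-name mapping below is an assumption for illustration, not taken from this repo):

```python
# Two entries mirrored from task.yaml.
TASKS = {
    "Bo1015/fold_prediction": {"loss_type": "classification", "output_type": "protein",
                               "num_labels": 1195, "metric": "accuracy"},
    "saprot_data/Thermostability": {"loss_type": "regression", "output_type": "protein",
                                    "num_labels": 1, "metric": "spearman"},
}

LOSS_FN = {  # assumed mapping from loss_type to a training objective
    "classification": "cross_entropy",
    "multi_classification": "bce_with_logits",
    "regression": "mse",
}

def head_config(task_name):
    """Return (output_dim, objective) for building a downstream prediction head."""
    spec = TASKS[task_name]
    return spec["num_labels"], LOSS_FN[spec["loss_type"]]
```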
4 changes: 4 additions & 0 deletions finetune/constants.py
"""Constants used throughout the finetuning code."""

MEAN_SEQ_LEN = 240.50554219467827
MEAN_MASK_SEQ_LEN = MEAN_SEQ_LEN * 0.15
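
The `0.15` factor matches the standard 15% masking rate of BERT/ESM-style masked language modeling, so `MEAN_MASK_SEQ_LEN` is the expected number of masked residues in an average-length sequence. A minimal sketch of drawing such a mask (hypothetical helper, not from this repo):

```python
import random

MASK_RATE = 0.15

def sample_mask_indices(seq_len, rate=MASK_RATE, rng=random):
    """Independently select each residue for masking with probability `rate`."""
    return [i for i in range(seq_len) if rng.random() < rate]
```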