-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
51 lines (46 loc) · 1.95 KB
/
train.py
File metadata and controls
51 lines (46 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Run training as specified in a config file.
"""
import hydra
from omegaconf import DictConfig, OmegaConf
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import json
from data.curriculum import new_curriculum
from module import ToyDataFlow, ImageFlow, FlowMatchingModule
from util import read, DummyDataloader
@hydra.main(version_base=None, config_path="config", config_name="config")
def run_training(hparams: DictConfig) -> None:
    """Run a flow-matching training job as specified by the Hydra config.

    Selects the model class from ``hparams.data.dataset``, builds the data
    curriculum, a checkpointing callback, and a validation dataloader, then
    fits the model with a Lightning ``Trainer``.

    Args:
        hparams: composed Hydra config with ``data``, ``training`` and
            ``logging`` sections (schema per the project's ``config/`` tree).

    Raises:
        NotImplementedError: if ``hparams.data.dataset`` names an unknown dataset.
    """
    print(OmegaConf.to_yaml(hparams))
    # Dispatch on the dataset name: toy/simplex tasks vs. image tasks.
    toy_datasets = {"pinwheel", "simplex_stark", "simplex_toy", "gaussian_mixture", "coupled_binary"}
    image_datasets = {"mnist", "cityscapes"}
    dataset = hparams.data.dataset
    model: FlowMatchingModule
    if dataset in toy_datasets:
        model = ToyDataFlow(hparams)
    elif dataset in image_datasets:
        model = ImageFlow(hparams)
    else:
        # Name the offending value so a config typo is diagnosable from the traceback.
        raise NotImplementedError(f"Unknown dataset: {dataset!r}")
    curriculum = new_curriculum(hparams.data)
    # Checkpoint on a batch interval and/or an epoch interval, as configured
    # (both default to None, i.e. that trigger is disabled).
    checkpointing = ModelCheckpoint(
        every_n_train_steps=read("check_interval_batches", hparams.logging, default=None),
        save_top_k=read("checkpoint_topk", hparams.logging, default=1),
        every_n_epochs=read("check_interval_epochs", hparams.logging, default=None)
    )
    # simplex_stark validates over a configured number of dummy batches;
    # every other dataset uses DummyDataloader's default length.
    if dataset == "simplex_stark":
        val_dl = DummyDataloader(hparams.data.num_val_batches)
    else:
        val_dl = DummyDataloader()
    trainer = L.Trainer(
        accelerator="auto",
        fast_dev_run=read("debug", hparams, default=False),
        max_epochs=hparams.training["epochs"],
        max_steps=read("steps", hparams.training, default=-1),  # -1 = no step cap
        check_val_every_n_epoch=read("eval_interval_epochs", hparams.logging, default=None),
        val_check_interval=read("eval_interval_batches", hparams.logging, default=1.0),
        callbacks=[checkpointing]
    )
    train_dl = curriculum.dataloader(batch_size=hparams.training["batch_size"])
    trainer.fit(model, train_dl, val_dataloaders=val_dl)
if __name__ == "__main__":
    # Hydra parses CLI overrides and passes the composed config into run_training.
    run_training()