Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,7 @@ ggg_data/
ggg_utils/
saved_models
src/analysis/orca/orca
src/analysis/orca/tmp*
src/analysis/orca/tmp*
sync.sh
.vscode
venv
2 changes: 1 addition & 1 deletion configs/experiment/comm20.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ general:
log_every_steps: 50
number_chain_steps: 50 # Number of frames in each gif
final_model_samples_to_generate: 20
final_model_samples_to_save:
final_model_samples_to_save: 20
final_model_chains_to_save: 10

train:
Expand Down
12 changes: 7 additions & 5 deletions src/analysis/orca/orca.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,7 @@ int writeResults(int g, const char* output_filename) {
fout << endl;
}
fout.close();
return 0;
}

string writeResultsString(int g) {
Expand Down Expand Up @@ -1385,6 +1386,7 @@ int writeEdgeResults(int g, const char* output_filename) {
fout << endl;
}
fout.close();
return 0;
}

string writeEdgeResultsString(int g) {
Expand All @@ -1401,7 +1403,7 @@ string writeEdgeResultsString(int g) {
}

int motif_counts(const char* orbit_type, int graphlet_size,
const char* input_filename, const char* output_filename, string &out_str) {
const char* input_filename, const char* output_filename, int &out_code) {
fstream fin; // input and output files
// open input, output files
if (strcmp(orbit_type, "node")!=0 && strcmp(orbit_type, "edge")!=0) {
Expand Down Expand Up @@ -1487,7 +1489,7 @@ int motif_counts(const char* orbit_type, int graphlet_size,
if (strcmp(output_filename, "std") == 0) {
cout << "orbit counts: \n" << writeResultsString(graphlet_size) << endl;
} else {
out_str = writeResults(graphlet_size, output_filename);
out_code = writeResults(graphlet_size, output_filename);
}
} else {
printf("Counting EDGE orbits of graphlets on %d nodes.\n\n",graphlet_size);
Expand All @@ -1496,7 +1498,7 @@ int motif_counts(const char* orbit_type, int graphlet_size,
if (strcmp(output_filename, "std") == 0) {
cout << "orbit counts: \n" << writeEdgeResultsString(graphlet_size) << endl;
} else {
out_str = writeEdgeResults(graphlet_size, output_filename);
out_code = writeEdgeResults(graphlet_size, output_filename);
}
}

Expand All @@ -1511,8 +1513,8 @@ int init(int argc, char *argv[]) {
}
int graphlet_size;
sscanf(argv[2],"%d", &graphlet_size);
string out;
motif_counts(argv[1], graphlet_size, argv[3], argv[4], out);
int result;
motif_counts(argv[1], graphlet_size, argv[3], argv[4], result);

return 1;
}
Expand Down
8 changes: 5 additions & 3 deletions src/analysis/spectre_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,11 +772,15 @@ def forward(self, generated_graphs: list, name, current_epoch, val_counter, loca

np.savez('generated_adjs.npz', *adjacency_matrices)

to_log = {}

if 'degree' in self.metrics_list:
if local_rank == 0:
print("Computing degree stats..")
degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=True,
compute_emd=self.compute_emd)

to_log['degree'] = degree
if wandb.run:
wandb.run.summary['degree'] = degree

Expand All @@ -786,8 +790,6 @@ def forward(self, generated_graphs: list, name, current_epoch, val_counter, loca
# eigval_stats(eig_ref_list, eig_pred_list, max_eig=20, is_parallel=True, compute_emd=False)
# spectral_filter_stats(eigvec_ref_list, eigval_ref_list, eigvec_pred_list, eigval_pred_list, is_parallel=False,
# compute_emd=False) # This is the one called wavelet
to_log = {}

if 'spectre' in self.metrics_list:
if local_rank == 0:
print("Computing spectre stats...")
Expand All @@ -796,7 +798,7 @@ def forward(self, generated_graphs: list, name, current_epoch, val_counter, loca

to_log['spectre'] = spectre
if wandb.run:
wandb.run.summary['spectre'] = spectre
wandb.run.summary['spectre'] = spectre

if 'clustering' in self.metrics_list:
if local_rank == 0:
Expand Down
66 changes: 13 additions & 53 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,10 @@


def get_resume(cfg, model_kwargs):
""" Resumes a run. It loads previous config without allowing to update keys (used for testing). """
saved_cfg = cfg.copy()
name = cfg.general.name + '_resume'
resume = cfg.general.test_only
if cfg.model.type == 'discrete':
model = DiscreteDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs)
else:
model = LiftedDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs)
cfg = model.cfg
cfg.general.test_only = resume
cfg.general.name = name
cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
return cfg, model


def get_resume_adaptive(cfg, model_kwargs):
""" Resumes a run. It loads previous config but allows to make some changes (used for resuming training)."""
saved_cfg = cfg.copy()
# Fetch path to this file to get base path
current_path = os.path.dirname(os.path.realpath(__file__))
root_dir = current_path.split('outputs')[0]

resume_path = os.path.join(root_dir, cfg.general.resume)

if cfg.model.type == 'discrete':
model = DiscreteDenoisingDiffusion.load_from_checkpoint(resume_path, **model_kwargs)
else:
model = LiftedDenoisingDiffusion.load_from_checkpoint(resume_path, **model_kwargs)
new_cfg = model.cfg

for category in cfg:
for arg in cfg[category]:
new_cfg[category][arg] = cfg[category][arg]

new_cfg.general.resume = resume_path
new_cfg.general.name = new_cfg.general.name + '_resume'

new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
return new_cfg, model
""" Resumes a run from current config (assume the core configs like model arch keep the same). """
cfg.general.name += "_resume"
model_class = DiscreteDenoisingDiffusion if cfg.model.type == 'discrete' else LiftedDenoisingDiffusion
return cfg, model_class(cfg, **model_kwargs)



Expand Down Expand Up @@ -150,22 +115,17 @@ def main(cfg: DictConfig):
else:
raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))

if cfg.general.test_only:
# When testing, previous configuration is fully loaded
cfg, _ = get_resume(cfg, model_kwargs)
os.chdir(cfg.general.test_only.split('checkpoints')[0])
elif cfg.general.resume is not None:
# When resuming, we can override some parts of previous configuration
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
os.chdir(cfg.general.resume.split('checkpoints')[0])
current_ckpt = cfg.general.test_only or cfg.general.resume
if current_ckpt:
cfg, model = get_resume(cfg, model_kwargs)
ckpt_dir = current_ckpt.split('checkpoints')[0]
os.chdir(ckpt_dir)
else:
model_class = DiscreteDenoisingDiffusion if cfg.model.type == 'discrete' else LiftedDenoisingDiffusion
model = model_class(cfg=cfg, **model_kwargs)

utils.create_folders(cfg)

if cfg.model.type == 'discrete':
model = DiscreteDenoisingDiffusion(cfg=cfg, **model_kwargs)
else:
model = LiftedDenoisingDiffusion(cfg=cfg, **model_kwargs)

callbacks = []
if cfg.train.save_model:
checkpoint_callback = ModelCheckpoint(dirpath=f"checkpoints/{cfg.general.name}",
Expand All @@ -188,7 +148,7 @@ def main(cfg: DictConfig):

use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available()
trainer = Trainer(gradient_clip_val=cfg.train.clip_grad,
strategy="ddp_find_unused_parameters_true", # Needed to load old checkpoints
# strategy="ddp" if cfg.general.gpus > 1 else "auto",
accelerator='gpu' if use_gpu else 'cpu',
devices=cfg.general.gpus if use_gpu else 1,
max_epochs=cfg.train.n_epochs,
Expand Down