From cc58d245dc6f64f291d8ef9b584c663600d8a490 Mon Sep 17 00:00:00 2001
From: julienroyd
Date: Thu, 21 Mar 2024 15:33:15 +0000
Subject: [PATCH] wip

---
 src/gflownet/algo/envelope_q_learning.py | 97 +++++++++++++++---------
 src/gflownet/envs/frag_mol_env.py        |  2 -
 src/gflownet/envs/graph_building_env.py  | 23 +++---
 src/gflownet/tasks/make_rings.py         |  4 +-
 src/gflownet/tasks/qm9.py                |  4 +-
 src/gflownet/tasks/qm9_moo.py            |  4 +-
 src/gflownet/tasks/seh_frag.py           |  4 +-
 src/gflownet/tasks/seh_frag_moo.py       |  6 +-
 src/gflownet/tasks/tmp_run_moql.py       | 37 +++++++++
 src/gflownet/tasks/toy_seq.py            |  4 +-
 10 files changed, 121 insertions(+), 64 deletions(-)
 create mode 100644 src/gflownet/tasks/tmp_run_moql.py

diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py
index 7adfd68c..4d5cdef7 100644
--- a/src/gflownet/algo/envelope_q_learning.py
+++ b/src/gflownet/algo/envelope_q_learning.py
@@ -42,65 +42,80 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
         )
         num_final = num_emb * 2
         num_mlp_layers = 0
-        self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
+        num_edge_emb = num_emb if env_ctx.edges_are_unordered else num_emb * 2
+        self.emb2add_node = mlp(num_emb, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
         # Edge attr logits are "sided", so we will compute both sides independently
         self.emb2set_edge_attr = mlp(
             num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers
         )
-        self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers)
-        self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers)
-        self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers)
+        self.emb2stop = mlp(num_final, num_emb, num_objectives, num_mlp_layers)
+        self.emb2reward = mlp(num_final, num_emb, 1, num_mlp_layers)
+        self.edge2emb = mlp(num_edge_emb, num_emb, num_final, num_mlp_layers)
         self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
         self.action_type_order = env_ctx.action_type_order
         self.mask_value = -10
         self.num_objectives = num_objectives
+        self.edges_are_duplicated = env_ctx.edges_are_duplicated
+        self.edges_are_unordered = env_ctx.edges_are_unordered
 
     def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
         """See `GraphTransformer` for argument values"""
         node_embeddings, graph_embeddings = self.transf(g, cond)
-        # On `::2`, edges are duplicated to make graphs undirected, only take the even ones
-        e_row, e_col = g.edge_index[:, ::2]
-        edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col])
-        src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_row]], 1))
-        dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_col]], 1))
+        if self.edges_are_duplicated:
+            # Edges are duplicated to make graphs undirected; only take the even-indexed ones (`::2`)
+            e_row, e_col = g.edge_index[:, ::2]
+        else:
+            e_row, e_col = g.edge_index
+        if self.edges_are_unordered:
+            edge_embeddings = node_embeddings[e_row] + node_embeddings[e_col]
+        else:
+            edge_embeddings = torch.cat([node_embeddings[e_row], node_embeddings[e_col]], 1)
+        edge_embeddings = self.edge2emb(edge_embeddings)
+
+        src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_embeddings, node_embeddings[e_row]], 1))
+        dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_embeddings, node_embeddings[e_col]], 1))
 
         def _mask(x, m):
             # mask logit vector x with binary mask m
             return x * m + self.mask_value * (1 - m)
 
-        def _mask_obj(x, m):
-            # mask logit vector x with binary mask m
-            return (
-                x.reshape(x.shape[0], x.shape[1] // self.num_objectives, self.num_objectives) * m[:, :, None]
-                + self.mask_value * (1 - m[:, :, None])
-            ).reshape(x.shape)
-
+        # Envelope Q-learning outputs Q-values for each action and each objective,
+        # so we tile the masks to match the (action, objective) layout of the logits
+        add_node_masks = g.add_node_mask.repeat_interleave(self.num_objectives, dim=1)
+        set_edge_attr_mask = g.set_edge_attr_mask.repeat_interleave(self.num_objectives, dim=1)
         cat = GraphActionCategorical(
             g,
             logits=[
                 F.relu(self.emb2stop(graph_embeddings)),
-                _mask(F.relu(self.emb2add_node(node_embeddings)), g.add_node_mask),
-                _mask_obj(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), g.set_edge_attr_mask),
+                _mask(F.relu(self.emb2add_node(node_embeddings)), add_node_masks),
+                _mask(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), set_edge_attr_mask),
             ],
             keys=[None, "x", "edge_index"],
             types=self.action_type_order,
         )
         r_pred = self.emb2reward(graph_embeddings)
         if output_Qs:
+            # we output the full set of Q-values when the model is used in training mode
+            return cat, r_pred
+
+        else:
+            # otherwise (e.g. when sampling new trajectories), we get the Q-values for the current omega
+            # and we need a single set of masks
+            cat.masks = [1, g.add_node_mask, g.set_edge_attr_mask]
+
+            # Compute the greedy policy
+            # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
+            # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
+            w = cond[:, -self.num_objectives :]
+            w_dot_Q = [
+                (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
+                for qi, b in zip(cat.logits, cat.batch)
+            ]
+            # Set the softmax distribution to a very low temperature to make sure only the max gets
+            # sampled (and we get random argmax tie breaking for free!):
+            cat.logits = [i * 100 for i in w_dot_Q]
+
+            return cat, r_pred
-        cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
-        # Compute the greedy policy
-        # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
-        # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
-        w = cond[:, -self.num_objectives :]
-        w_dot_Q = [
-            (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
-            for qi, b in zip(cat.logits, cat.batch)
-        ]
-        # Set the softmax distribution to a very low temperature to make sure only the max gets
-        # sampled (and we get random argmax tie breaking for free!):
-        cat.logits = [i * 100 for i in w_dot_Q]
-        return cat, r_pred
 
 
 class GraphTransformerEnvelopeQL(nn.Module):
@@ -205,10 +220,21 @@ def __init__(
         self._num_updates = 0
         assert self.gamma == 1
         self.bootstrap_own_reward = False
+        self.global_cfg = cfg
         # Experimental flags
         self.sample_temp = 1
         self.do_q_prime_correction = False
         self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)
+
+    def set_is_eval(self, is_eval: bool):
+        self.is_eval = is_eval
+
+    def get_random_action_prob(self, it: int):
+        if self.is_eval:
+            return self.global_cfg.algo.valid_random_action_prob
+        if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after:
+            return self.global_cfg.algo.train_random_action_prob
+        return 0
 
     def create_training_data_from_own_samples(
         self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
@@ -280,9 +306,10 @@ def construct_batch(self, trajs, cond_info, log_rewards):
         batch.log_rewards = log_rewards
         batch.cond_info = cond_info
         batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float()
+        batch.preferences = torch.stack([i["cond_info"]["preferences"] for i in trajs])
 
         # Now we create a duplicate/repeated batch for Q(s,a,w')
-        omega_prime = self.task.sample_conditional_information(self.num_omega_samples * batch.num_graphs)
+        omega_prime = self.task.sample_conditional_information(self.num_omega_samples * batch.num_graphs)
         torch_graphs = [i for i in torch_graphs for j in range(self.num_omega_samples)]
         actions = [i for i in actions for j in range(self.num_omega_samples)]
         batch_prime = self.ctx.collate(torch_graphs)
@@ -298,7 +325,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
 
         Parameters
         ----------
-        model: TrajectoryBalanceModel
+        model: GraphTransformerFragEnvelopeQL
            A GNN taking in a batch of graphs as input as per constructed by `self.construct_batch`.
            Must have a `logZ` attribute, itself a model, which predicts log of Z(cond_info)
         batch: gd.Batch
@@ -370,9 +397,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
         # same (and thus the max is over all of the repeats as well).
         # Since the batch slices we will later index to get Q[:, argmax a, argmax omega'] are those
         # of Q_omega_prime, we need to use fwd_cat_prime.
-        argmax = fwd_cat_prime.argmax(
-            x=w_dot_Q, batch=[b.repeat_interleave(self.num_omega_samples) for b in fwd_cat.batch], dim_size=num_states
-        )
+        argmax = fwd_cat_prime.argmax(x=w_dot_Q, batch=[b.repeat_interleave(self.num_omega_samples) for b in fwd_cat.batch], dim_size=num_states)
         # Now what we want, for each state, is the vector prediction made by Q(s, a, w') for the
         # argmax a,w'. Let's again reuse GraphActionCategorical methods to do the indexing for us.
         # We must again use fwd_cat_prime to use the right slices.
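# A minimal standalone sketch (not part of the patch) of the greedy-policy trick used in
# forward() above: Q-logits for one action type are stored flat as
# [num_rows, num_actions * num_objectives] with the objective axis fastest, and the
# conditioning vector ends with the preference weights w. Shapes below are made up;
# in the patched code, `w[b]` maps each graph's preference vector to its rows via the
# batch index.
import torch

num_rows, num_actions, num_objectives = 4, 3, 2
q_flat = torch.randn(num_rows, num_actions * num_objectives)
w = torch.softmax(torch.randn(num_rows, num_objectives), dim=1)  # one preference vector per row

# Recover the per-objective axis, then take the w-weighted combination of objectives
q = q_flat.reshape(num_rows, num_actions, num_objectives)
w_dot_q = (q * w[:, None, :]).sum(2)  # [num_rows, num_actions]

# Multiplying by 100 turns sampling from softmax(logits) into a near-argmax
# (very low temperature), with random tie-breaking for free
greedy_logits = w_dot_q * 100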
diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py
index 71a89f0c..0739e8ab 100644
--- a/src/gflownet/envs/frag_mol_env.py
+++ b/src/gflownet/envs/frag_mol_env.py
@@ -77,8 +77,6 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
         # The + 1 is for an extra dimension to indicate when the attribute isn't yet set
         self.num_edge_dim = (most_stems + 1) * 2
         self.num_cond_dim = num_cond_dim
-        self.edges_are_duplicated = True
-        self.edges_are_unordered = False
         self.fail_on_missing_attr = True
 
         # Order in which models have to output logits
diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py
index 5812ea49..c9f93229 100644
--- a/src/gflownet/envs/graph_building_env.py
+++ b/src/gflownet/envs/graph_building_env.py
@@ -701,19 +701,14 @@ def sample(self) -> List[Tuple[int, int, int]]:
         gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)]
 
         if self.masks is not None:
-            gumbel_safe = [
-                torch.where(
-                    mask == 1,
-                    torch.maximum(
-                        x,
-                        torch.nextafter(
-                            torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype)
-                        ).to(x.device),
-                    ),
-                    torch.finfo(x.dtype).min,
-                )
-                for x, mask in zip(gumbel, self.masks)
-            ]
+            gumbel_safe = []
+            for x, mask in zip(gumbel, self.masks):
+                small = torch.nextafter(torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype),
+                                        torch.tensor(0.0, dtype=x.dtype)).to(x.device)
+                if not isinstance(mask, torch.Tensor):
+                    mask = torch.tensor(mask, device=x.device)
+                g = torch.where(mask == 1, torch.maximum(x, small), torch.finfo(x.dtype).min)
+                gumbel_safe.append(g)
         else:
             gumbel_safe = gumbel
         # Take the argmax
@@ -764,7 +759,7 @@ def argmax(
         # Now we can return the indices of where the actions occured
         # in the form List[(type, row, column)]
-        assert dim_size == type_max_idx.shape[0]
+        assert dim_size == type_max_idx.shape[0], f"dim_size {dim_size} != type_max_idx.shape[0] {type_max_idx.shape[0]}"
         argmaxes = []
         for i in range(type_max_idx.shape[0]):
             t = type_max_idx[i]
diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py
index 34f47924..99c36fe4 100644
--- a/src/gflownet/tasks/make_rings.py
+++ b/src/gflownet/tasks/make_rings.py
@@ -1,5 +1,5 @@
 import socket
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -25,7 +25,7 @@ def __init__(
     def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
         return FlatRewards(y)
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py
index 5f489938..8b57f66e 100644
--- a/src/gflownet/tasks/qm9.py
+++ b/src/gflownet/tasks/qm9.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -79,7 +79,7 @@ def load_task_models(self, path):
         gap_model = self._wrap_model(gap_model)
         return {"mxmnet_gap": gap_model}
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py
index 51029e3a..c1c511d2 100644
--- a/src/gflownet/tasks/qm9_moo.py
+++ b/src/gflownet/tasks/qm9_moo.py
@@ -1,5 +1,5 @@
 import pathlib
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -70,7 +70,7 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
     def inverse_flat_reward_transform(self, rp):
         return rp
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         cond_info = super().sample_conditional_information(n, train_it)
         pref_ci = self.pref_cond.sample(n)
         focus_ci = (
diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py
index e64d642d..8ccb3712 100644
--- a/src/gflownet/tasks/seh_frag.py
+++ b/src/gflownet/tasks/seh_frag.py
@@ -1,5 +1,5 @@
 import socket
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -57,7 +57,7 @@ def _load_task_models(self):
         model = self._wrap_model(model)
         return {"seh": model}
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py
index ef8def85..93d3ed9c 100644
--- a/src/gflownet/tasks/seh_frag_moo.py
+++ b/src/gflownet/tasks/seh_frag_moo.py
@@ -1,5 +1,5 @@
 import pathlib
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -92,7 +92,7 @@ def __init__(
     def inverse_flat_reward_transform(self, rp):
         return rp
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         cond_info = super().sample_conditional_information(n, train_it)
         pref_ci = self.pref_cond.sample(n)
         focus_ci = (
@@ -250,6 +250,7 @@ def setup_model(self):
                 num_layers=self.cfg.model.num_layers,
                 num_heads=self.cfg.model.graph_transformer.num_heads,
                 num_objectives=len(self.cfg.task.seh_moo.objectives),
+            )
         else:
             super().setup_model()
@@ -395,6 +396,7 @@ def main():
     config.device = "cuda" if torch.cuda.is_available() else "cpu"
     config.print_every = 1
     config.validate_every = 1
+    config.algo.num_from_policy = 13
     config.num_final_gen_steps = 5
     config.num_training_steps = 3
     config.pickle_mp_messages = True
diff --git a/src/gflownet/tasks/tmp_run_moql.py b/src/gflownet/tasks/tmp_run_moql.py
new file mode 100644
index 00000000..3ae3cd73
--- /dev/null
+++ b/src/gflownet/tasks/tmp_run_moql.py
@@ -0,0 +1,37 @@
+import torch
+from gflownet.config import Config, init_empty
+from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer
+
+
+def main():
+    """Example of how this model can be run."""
+    config = init_empty(Config())
+    config.algo.method = "MOQL"
+    config.desc = "debug_seh_frag_moo"
+    config.log_dir = "./logs/debug_run_sfm"
+    config.device = "cuda" if torch.cuda.is_available() else "cpu"
+    config.num_workers = 0
+    config.print_every = 1
+    config.algo.num_from_policy = 2
+    config.validate_every = 1
+    config.num_final_gen_steps = 5
+    config.num_training_steps = 3
+    config.pickle_mp_messages = True
+    config.overwrite_existing_exp = True
+    config.algo.sampling_tau = 0.95
+    config.algo.train_random_action_prob = 0.01
+    config.task.seh_moo.objectives = ["seh", "qed"]
+    config.cond.temperature.sample_dist = "constant"
+    config.cond.temperature.dist_params = [60.0]
+    config.cond.weighted_prefs.preference_type = "dirichlet"
+    config.cond.focus_region.focus_type = None
+    config.replay.use = False
+    config.task.seh_moo.n_valid = 15
+    config.task.seh_moo.n_valid_repeats = 2
+
+    trial = SEHMOOFragTrainer(config)
+    trial.run()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py
index f2c75f60..7f417ef2 100644
--- a/src/gflownet/tasks/toy_seq.py
+++ b/src/gflownet/tasks/toy_seq.py
@@ -1,5 +1,5 @@
 import socket
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -29,7 +29,7 @@ def __init__(
         self.num_cond_dim = self.temperature_conditional.encoding_size()
         self.norm = cfg.algo.max_len / min(map(len, seqs))
 
-    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+    def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: