97 changes: 61 additions & 36 deletions src/gflownet/algo/envelope_q_learning.py
@@ -42,65 +42,80 @@ def __init__(env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
)
num_final = num_emb * 2
num_mlp_layers = 0
self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
num_edge_emb = num_emb if env_ctx.edges_are_unordered else num_emb * 2
self.emb2add_node = mlp(num_emb, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
# Edge attr logits are "sided", so we will compute both sides independently
self.emb2set_edge_attr = mlp(
num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers
)
self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers)
self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers)
self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers)
self.emb2stop = mlp(num_final, num_emb, num_objectives, num_mlp_layers)
self.emb2reward = mlp(num_final, num_emb, 1, num_mlp_layers)
self.edge2emb = mlp(num_final, num_emb, num_edge_emb, num_mlp_layers)
self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
self.action_type_order = env_ctx.action_type_order
self.mask_value = -10
self.num_objectives = num_objectives
self.edges_are_duplicated = env_ctx.edges_are_duplicated
self.edges_are_unordered = env_ctx.edges_are_unordered

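The new `num_edge_emb` tracks how edge embeddings are built in `forward`: with `edges_are_unordered` the two endpoint embeddings are summed (symmetric in the endpoints, width `num_emb`), otherwise they are concatenated (direction-aware, width `2 * num_emb`). A minimal standalone sketch with made-up dimensions:

```python
import torch

num_emb = 4
h_u, h_v = torch.randn(num_emb), torch.randn(num_emb)

# Unordered edges: summing endpoint embeddings is symmetric in (u, v) -> width num_emb
sym = h_u + h_v
assert torch.allclose(sym, h_v + h_u)

# Ordered edges: concatenation keeps the direction -> width 2 * num_emb
ordered = torch.cat([h_u, h_v])
print(sym.shape, ordered.shape)  # torch.Size([4]) torch.Size([8])
```
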
def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
"""See `GraphTransformer` for argument values"""
node_embeddings, graph_embeddings = self.transf(g, cond)
# On `::2`, edges are duplicated to make graphs undirected, only take the even ones
e_row, e_col = g.edge_index[:, ::2]
edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col])
src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_row]], 1))
dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_col]], 1))
if self.edges_are_duplicated:
# On `::2`, edges are typically duplicated to make graphs undirected, only take the even ones
e_row, e_col = g.edge_index[:, ::2]
else:
e_row, e_col = g.edge_index
if self.edges_are_unordered:
edge_embeddings = node_embeddings[e_row] + node_embeddings[e_col]
else:
edge_embeddings = torch.cat([node_embeddings[e_row], node_embeddings[e_col]], 1)
edge_embeddings = self.edge2emb(edge_embeddings)

src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_embeddings, node_embeddings[e_row]], 1))
dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_embeddings, node_embeddings[e_col]], 1))

def _mask(x, m):
# mask logit vector x with binary mask m
return x * m + self.mask_value * (1 - m)

def _mask_obj(x, m):
# mask logit vector x with binary mask m
return (
x.reshape(x.shape[0], x.shape[1] // self.num_objectives, self.num_objectives) * m[:, :, None]
+ self.mask_value * (1 - m[:, :, None])
).reshape(x.shape)

# envelope Q-learning outputs Q-values for each action and each objective
# we duplicate the masks for each objective
add_node_masks = g.add_node_mask.repeat(1, self.num_objectives)
set_edge_attr_mask = g.set_edge_attr_mask.repeat(1, self.num_objectives)
cat = GraphActionCategorical(
g,
logits=[
F.relu(self.emb2stop(graph_embeddings)),
_mask(F.relu(self.emb2add_node(node_embeddings)), g.add_node_mask),
_mask_obj(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), g.set_edge_attr_mask),
_mask(F.relu(self.emb2add_node(node_embeddings)), add_node_masks),
_mask(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), set_edge_attr_mask),
],
keys=[None, "x", "edge_index"],
types=self.action_type_order,
)
r_pred = self.emb2reward(graph_embeddings)
if output_Qs:
# we output the full set of Q-values when the model is used in training mode
return cat, r_pred

else:
# if we don't output Qs (e.g. when sampling new trajectories), we compute the Q-values for the current omega
# and we need a single set of masks
cat.masks = [1, g.add_node_mask, g.set_edge_attr_mask]

# Compute the greedy policy
# See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
# TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
w = cond[:, -self.num_objectives :]
w_dot_Q = [
(qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
for qi, b in zip(cat.logits, cat.batch)
]
# Set the softmax distribution to a very low temperature to make sure only the max gets
# sampled (and we get random argmax tie breaking for free!):
cat.logits = [i * 100 for i in w_dot_Q]

return cat, r_pred
cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
# Compute the greedy policy
# See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
# TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
w = cond[:, -self.num_objectives :]
w_dot_Q = [
(qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
for qi, b in zip(cat.logits, cat.batch)
]
# Set the softmax distribution to a very low temperature to make sure only the max gets
# sampled (and we get random argmax tie breaking for free!):
cat.logits = [i * 100 for i in w_dot_Q]
return cat, r_pred

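For reference, the scalarization in the non-`output_Qs` branch can be sketched with dense tensors (hypothetical shapes; the real code indexes ragged per-graph slices through `GraphActionCategorical`): per-objective Q-values are reshaped to `[N, A, num_objectives]`, contracted with the preference weights `w`, and multiplied by 100 so the downstream softmax sample behaves like an argmax with random tie-breaking.

```python
import torch

N, A, num_objectives = 3, 5, 2                  # hypothetical sizes
q = torch.randn(N, A * num_objectives)          # flat per-objective Q-values from one head
w = torch.softmax(torch.randn(N, num_objectives), dim=1)  # one preference vector per graph

# w.Q for every action: reshape to [N, A, num_objectives], contract the objective axis
w_dot_q = (q.reshape(N, A, num_objectives) * w[:, None, :]).sum(-1)  # [N, A]

# Scaling by 100 makes the softmax effectively one-hot at the argmax,
# with ties broken at random by the sampler's noise
greedy_logits = w_dot_q * 100
probs = torch.softmax(greedy_logits, dim=1)
assert torch.equal(probs.argmax(1), w_dot_q.argmax(1))
```
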

class GraphTransformerEnvelopeQL(nn.Module):
@@ -205,10 +220,21 @@ def __init__(
self._num_updates = 0
assert self.gamma == 1
self.bootstrap_own_reward = False
self.global_cfg = cfg
# Experimental flags
self.sample_temp = 1
self.do_q_prime_correction = False
self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)

def set_is_eval(self, is_eval: bool):
self.is_eval = is_eval

def get_random_action_prob(self, it: int):
if self.is_eval:
return self.global_cfg.algo.valid_random_action_prob
if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after:
return self.global_cfg.algo.train_random_action_prob
return 0

def create_training_data_from_own_samples(
self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
@@ -280,9 +306,10 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch.log_rewards = log_rewards
batch.cond_info = cond_info
batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float()
batch.preferences = torch.stack([i["cond_info"]["preferences"] for i in trajs])

# Now we create a duplicate/repeated batch for Q(s,a,w')
omega_prime = self.task.sample_conditional_information(self.num_omega_samples * batch.num_graphs)
torch_graphs = [i for i in torch_graphs for j in range(self.num_omega_samples)]
actions = [i for i in actions for j in range(self.num_omega_samples)]
batch_prime = self.ctx.collate(torch_graphs)
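Each graph and action is repeated `num_omega_samples` times so that every state gets paired with a freshly sampled ω′ for the Q(s, a, w′) term. A toy illustration of the repetition pattern, with plain lists standing in for the torch_geometric objects:

```python
num_omega_samples = 3
torch_graphs = ["g0", "g1"]            # stand-ins for torch_geometric Data objects
actions = [(0, 0, 0), (1, 2, 0)]       # stand-ins for (type, row, col) action indices

# Repeat every graph/action num_omega_samples times, keeping the pairing intact;
# omega_prime then supplies num_omega_samples * len(torch_graphs) preference vectors.
graphs_rep = [g for g in torch_graphs for _ in range(num_omega_samples)]
actions_rep = [a for a in actions for _ in range(num_omega_samples)]
assert graphs_rep == ["g0", "g0", "g0", "g1", "g1", "g1"]
assert len(actions_rep) == len(graphs_rep)
```
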
@@ -298,7 +325,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:

Parameters
----------
model: TrajectoryBalanceModel
model: GraphTransformerFragEnvelopeQL
A GNN taking in a batch of graphs as input, as constructed by `self.construct_batch`.
Must have a `logZ` attribute, itself a model, which predicts log of Z(cond_info)
batch: gd.Batch
@@ -370,9 +397,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
# same (and thus the max is over all of the repeats as well).
# Since the batch slices we will later index to get Q[:, argmax a, argmax omega'] are those
# of Q_omega_prime, we need to use fwd_cat_prime.
argmax = fwd_cat_prime.argmax(
x=w_dot_Q, batch=[b.repeat_interleave(self.num_omega_samples) for b in fwd_cat.batch], dim_size=num_states
)
argmax = fwd_cat_prime.argmax(x=w_dot_Q, batch=[b.repeat_interleave(self.num_omega_samples) for b in fwd_cat.batch], dim_size=num_states)
# Now what we want, for each state, is the vector prediction made by Q(s, a, w') for the
# argmax a,w'. Let's again reuse GraphActionCategorical methods to do the indexing for us.
# We must again use fwd_cat_prime to use the right slices.
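The double maximization referred to above, a max over both actions and sampled ω′ of w·Q(s′, a, ω′) followed by reading off the full objective vector at that argmax, can be sketched with dense tensors (hypothetical shapes; the real code goes through `GraphActionCategorical.argmax` because actions live in ragged per-graph slices):

```python
import torch

num_states, num_omega, A, num_obj = 2, 3, 4, 2    # hypothetical sizes
q_prime = torch.randn(num_states, num_omega, A, num_obj)   # Q(s', a, w') for every repeat
w = torch.softmax(torch.randn(num_states, num_obj), dim=1) # the behavior preference w

# Scalarize with w, then maximize jointly over (omega', a)
w_dot_q = (q_prime * w[:, None, None, :]).sum(-1)          # [num_states, num_omega, A]
flat_idx = w_dot_q.flatten(1).argmax(1)                    # joint argmax per state
omega_idx, a_idx = flat_idx // A, flat_idx % A

# The TD target uses the full vector Q at that argmax, not the scalarized value
q_vec = q_prime[torch.arange(num_states), omega_idx, a_idx]  # [num_states, num_obj]
print(q_vec.shape)
```
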
2 changes: 0 additions & 2 deletions src/gflownet/envs/frag_mol_env.py
@@ -77,8 +77,6 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
# The + 1 is for an extra dimension to indicate when the attribute isn't yet set
self.num_edge_dim = (most_stems + 1) * 2
self.num_cond_dim = num_cond_dim
self.edges_are_duplicated = True
self.edges_are_unordered = False
self.fail_on_missing_attr = True

# Order in which models have to output logits
23 changes: 9 additions & 14 deletions src/gflownet/envs/graph_building_env.py
@@ -701,19 +701,14 @@ def sample(self) -> List[Tuple[int, int, int]]:
gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)]

if self.masks is not None:
gumbel_safe = [
torch.where(
mask == 1,
torch.maximum(
x,
torch.nextafter(
torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype)
).to(x.device),
),
torch.finfo(x.dtype).min,
)
for x, mask in zip(gumbel, self.masks)
]
gumbel_safe = []
for x, mask in zip(gumbel, self.masks):
small = torch.nextafter(torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype),
torch.tensor(0.0, dtype=x.dtype)).to(x.device)
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask, device=x.device)
g = torch.where(mask == 1, torch.maximum(x, small), torch.finfo(x.dtype).min)
gumbel_safe.append(g)
else:
gumbel_safe = gumbel
# Take the argmax
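The loop above is a masked Gumbel-max: Gumbel noise is added to the logits, masked entries are pinned to the dtype minimum, and valid entries are clamped one ULP above it so they can never tie with a masked one. A standalone sketch of the same idea on a dense logit vector:

```python
import torch

logits = torch.tensor([1.0, 2.0, 0.5, 3.0])
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])   # entry 2 is invalid

u = torch.rand_like(logits)
gumbel = logits - (-u.log()).log()          # add Gumbel(0, 1) noise to each logit

fmin = torch.finfo(logits.dtype).min
small = torch.nextafter(torch.tensor(fmin), torch.tensor(0.0))  # one ULP above fmin
safe = torch.where(mask == 1, torch.maximum(gumbel, small), torch.tensor(fmin))

assert safe.argmax().item() != 2            # the masked entry can never win
```
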
@@ -764,7 +759,7 @@ def argmax(

# Now we can return the indices of where the actions occurred
# in the form List[(type, row, column)]
assert dim_size == type_max_idx.shape[0]
assert dim_size == type_max_idx.shape[0], f"dim_size {dim_size} != type_max_idx.shape[0] {type_max_idx.shape[0]}"
argmaxes = []
for i in range(type_max_idx.shape[0]):
t = type_max_idx[i]
4 changes: 2 additions & 2 deletions src/gflownet/tasks/make_rings.py
@@ -1,5 +1,5 @@
import socket
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple, Union, Optional

import numpy as np
import torch
@@ -25,7 +25,7 @@ def __init__(
def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
return FlatRewards(y)

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
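Making `train_it` optional matches how `EnvelopeQLearning.construct_batch` samples ω′ without passing a training iteration. A minimal sketch of the relaxed signature (shown as a free function for illustration, mirroring the `MakeRingsTask` version above):

```python
from typing import Dict, Optional

import torch
from torch import Tensor


def sample_conditional_information(n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
    # train_it only matters for schedules that anneal with training progress; callers that
    # just need n conditioning vectors (e.g. when sampling omega') can now omit it.
    return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}


print(sample_conditional_information(4)["beta"].shape)                 # torch.Size([4])
print(sample_conditional_information(4, train_it=100)["beta"].shape)   # torch.Size([4])
```
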
4 changes: 2 additions & 2 deletions src/gflownet/tasks/qm9.py
@@ -1,4 +1,4 @@
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple, Union, Optional

import numpy as np
import torch
@@ -79,7 +79,7 @@ def load_task_models(self, path):
gap_model = self._wrap_model(gap_model)
return {"mxmnet_gap": gap_model}

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
4 changes: 2 additions & 2 deletions src/gflownet/tasks/qm9_moo.py
@@ -1,5 +1,5 @@
import pathlib
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union, Optional

import numpy as np
import torch
@@ -70,7 +70,7 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
def inverse_flat_reward_transform(self, rp):
return rp

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
cond_info = super().sample_conditional_information(n, train_it)
pref_ci = self.pref_cond.sample(n)
focus_ci = (
4 changes: 2 additions & 2 deletions src/gflownet/tasks/seh_frag.py
@@ -1,5 +1,5 @@
import socket
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple, Union, Optional

import numpy as np
import torch
@@ -57,7 +57,7 @@ def _load_task_models(self):
model = self._wrap_model(model)
return {"seh": model}

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
6 changes: 4 additions & 2 deletions src/gflownet/tasks/seh_frag_moo.py
@@ -1,5 +1,5 @@
import pathlib
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple, Optional

import numpy as np
import torch
@@ -92,7 +92,7 @@ def __init__(
def inverse_flat_reward_transform(self, rp):
return rp

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
cond_info = super().sample_conditional_information(n, train_it)
pref_ci = self.pref_cond.sample(n)
focus_ci = (
@@ -250,6 +250,7 @@ def setup_model(self):
num_layers=self.cfg.model.num_layers,
num_heads=self.cfg.model.graph_transformer.num_heads,
num_objectives=len(self.cfg.task.seh_moo.objectives),

)
else:
super().setup_model()
@@ -395,6 +396,7 @@ def main():
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.print_every = 1
config.validate_every = 1
config.algo.num_from_policy = 13
config.num_final_gen_steps = 5
config.num_training_steps = 3
config.pickle_mp_messages = True
37 changes: 37 additions & 0 deletions src/gflownet/tasks/tmp_run_moql.py
@@ -0,0 +1,37 @@
import torch
from gflownet.config import Config, init_empty
from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer


def main():
"""Example of how this model can be run."""
config = init_empty(Config())
config.algo.method = "MOQL"
config.desc = "debug_seh_frag_moo"
config.log_dir = "./logs/debug_run_sfm"
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.num_workers = 0
config.print_every = 1
config.algo.num_from_policy = 2
config.validate_every = 1
config.num_final_gen_steps = 5
config.num_training_steps = 3
config.pickle_mp_messages = True
config.overwrite_existing_exp = True
config.algo.sampling_tau = 0.95
config.algo.train_random_action_prob = 0.01
config.task.seh_moo.objectives = ["seh", "qed"]
config.cond.temperature.sample_dist = "constant"
config.cond.temperature.dist_params = [60.0]
config.cond.weighted_prefs.preference_type = "dirichlet"
config.cond.focus_region.focus_type = None
config.replay.use = False
config.task.seh_moo.n_valid = 15
config.task.seh_moo.n_valid_repeats = 2

trial = SEHMOOFragTrainer(config)
trial.run()


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions src/gflownet/tasks/toy_seq.py
@@ -1,5 +1,5 @@
import socket
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
@@ -29,7 +29,7 @@ def __init__(
self.num_cond_dim = self.temperature_conditional.encoding_size()
self.norm = cfg.algo.max_len / min(map(len, seqs))

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
def sample_conditional_information(self, n: int, train_it: Optional[int] = None) -> Dict[str, Tensor]:
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: