diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md
index 8e6c2474..ba63e708 100644
--- a/docs/implementation_notes.md
+++ b/docs/implementation_notes.md
@@ -51,4 +51,18 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo
 `DataSource` also covers validation sets, including cases such as:
 - Generating new trajectories (w.r.t a fixed dataset of conditioning goals)
-- Evaluating the model's likelihood on trajectories from a fixed, offline dataset
\ No newline at end of file
+- Evaluating the model's likelihood on trajectories from a fixed, offline dataset
+
+## Multiprocessing
+
+We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` parameter of the `DataLoader` (via `cfg.num_workers`) to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks.
+
+Because training models involves sampling from them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly a wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main process, where the model lives (e.g. in CUDA), and that the returned values are properly serialized and sent back to the worker process. The wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that it is only possible to call methods on these wrappers; direct attribute access is not supported.
+
+Note that the workers do not use CUDA and therefore have to work entirely on CPU, but the code is designed to be somewhat agnostic to this fact. By using `get_worker_device`, code can be written without assuming too much; again, calls such as `model(input)` will work as expected.
+
+Regarding message serialization: naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and uses shared memory for tensors sent through queues, which unfortunately is very slow, both because creating these shared memory files is slow and because `Batch`es of `Data` tend to contain lots of small tensors, which are a poor fit for shared memory.
+
+We implement two solutions to this problem (in order of preference):
+- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), initialized once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value be known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages (see the sketch below).
+- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle.
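+
+As a rough sketch of the shared-buffer round trip (shown in a single process for illustration; in the real setup the pickling and unpickling happen on opposite sides of a queue, and pinning the buffer assumes a CUDA-capable machine):
+
+```python
+import torch
+
+from gflownet.utils.multiprocessing_proxy import BufferPickler, BufferUnpickler, SharedPinnedBuffer
+
+buf = SharedPinnedBuffer(32 * 1024**2)  # 32 MB, allocated, shared, and pinned once
+msg = {"obs": torch.randn(128, 16)}  # any picklable object containing tensors
+data = BufferPickler(buf).dumps(msg)  # tensor data goes into `buf`; returns (pickle bytes, buffer bytes used)
+out = BufferUnpickler(buf, data, torch.device("cpu")).load()  # rebuilds `msg`, releasing the buffer's lock
+```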
diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py
index 6cb8f979..5415ecd8 100644
--- a/src/gflownet/__init__.py
+++ b/src/gflownet/__init__.py
@@ -23,6 +23,9 @@ class GFNAlgorithm:
     def step(self):
         self.updates += 1  # This isn't used anywhere?

+    def set_is_eval(self, is_eval: bool):
+        self.is_eval = is_eval
+
     def compute_batch_losses(
         self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py
index 7077a9d1..7ce05a81 100644
--- a/src/gflownet/algo/advantage_actor_critic.py
+++ b/src/gflownet/algo/advantage_actor_critic.py
@@ -3,6 +3,7 @@
 import torch_geometric.data as gd
 from torch import Tensor

+from gflownet import GFNAlgorithm
 from gflownet.config import Config
 from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory
 from gflownet.utils.misc import get_worker_device
@@ -10,7 +11,7 @@
 from .graph_sampling import GraphSampler

-class A2C:
+class A2C(GFNAlgorithm):
     def __init__(
         self,
         env: GraphBuildingEnv,
@@ -36,6 +37,7 @@ def __init__(
             The experiment configuration
         """
+        self.global_cfg = cfg  # TODO: this belongs in the base class
         self.ctx = ctx
         self.env = env
         self.max_len = cfg.algo.max_len
@@ -149,7 +151,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
         # Forward pass of the model, returns a GraphActionCategorical and per graph predictions
         # Here we will interpret the logits of the fwd_cat as Q values
-        policy, per_state_preds = model(batch, cond_info[batch_idx])
+        batch.cond_info = cond_info[batch_idx]
+        policy, per_state_preds = model(batch)
         V = per_state_preds[:, 0]
         G = rewards[batch_idx]  # The return is the terminal reward everywhere, we're using gamma==1
         G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty  # Add in penalty for invalid object
diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py
index 9bfc3345..4798600b 100644
--- a/src/gflownet/algo/envelope_q_learning.py
+++ b/src/gflownet/algo/envelope_q_learning.py
@@ -5,6 +5,7 @@
 from torch import Tensor
 from torch_scatter import scatter

+from gflownet import GFNAlgorithm
 from gflownet.config import Config
 from gflownet.envs.graph_building_env import (
     GraphActionCategorical,
@@ -39,24 +40,24 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
             num_layers=num_layers,
             num_heads=num_heads,
         )
-        num_final = num_emb * 2
+        num_final = num_emb
         num_mlp_layers = 0
         self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
         # Edge attr logits are "sided", so we will compute both sides independently
         self.emb2set_edge_attr = mlp(
             num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers
         )
-        self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers)
-        self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers)
+        self.emb2stop = mlp(num_emb * 2, num_emb, num_objectives, num_mlp_layers)
+        self.emb2reward = mlp(num_emb * 2, num_emb, 1, num_mlp_layers)
         self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers)
         self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
         self.action_type_order = env_ctx.action_type_order
         self.mask_value = -10
         self.num_objectives = num_objectives

-    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
+    def forward(self, g: gd.Batch, output_Qs=False):
         """See `GraphTransformer` for argument values"""
-        node_embeddings, graph_embeddings = self.transf(g, cond)
+        node_embeddings, graph_embeddings = self.transf(g)
         # On `::2`, edges are duplicated to make graphs undirected, only take the even ones
         e_row, e_col = g.edge_index[:, ::2]
         edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col])
@@ -86,7 +87,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
         # Compute the greedy policy
         # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
         # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
-        w = cond[:, -self.num_objectives :]
+        w = g.cond_info[:, -self.num_objectives :]
         w_dot_Q = [
             (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
             for qi, b in zip(cat.logits, cat.batch)
@@ -122,8 +123,9 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
         self.action_type_order = env_ctx.action_type_order
         self.num_objectives = num_objectives

-    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
-        node_embeddings, graph_embeddings = self.transf(g, cond)
+    def forward(self, g: gd.Batch, output_Qs=False):
+        cond = g.cond_info
+        node_embeddings, graph_embeddings = self.transf(g)
         ne_row, ne_col = g.non_edge_index
         # On `::2`, edges are duplicated to make graphs undirected, only take the even ones
         e_row, e_col = g.edge_index[:, ::2]
@@ -156,7 +158,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
         return cat, r_pred

-class EnvelopeQLearning:
+class EnvelopeQLearning(GFNAlgorithm):
     def __init__(
         self,
         env: GraphBuildingEnv,
@@ -182,6 +184,7 @@ def __init__(
         cfg: Config
             The experiment configuration
         """
+        self.global_cfg = cfg
         self.ctx = ctx
         self.env = env
         self.task = task
@@ -314,7 +317,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
         # Forward pass of the model, returns a GraphActionCategorical and per graph predictions
         # Here we will interpret the logits of the fwd_cat as Q values
         # Q(s,a,omega)
-        fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True)
+        batch.cond_info = cond_info[batch_idx]
+        fwd_cat, per_state_preds = model(batch, output_Qs=True)
         Q_omega = fwd_cat.logits
         # reshape to List[shape: (num in all graphs, num actions on T, num_objectives) | for all types T]
         Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega]
@@ -323,7 +327,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
         batchp = batch.batch_prime
         batchp_num_trajs = int(batchp.traj_lens.shape[0])
         batchp_batch_idx = torch.arange(batchp_num_trajs, device=dev).repeat_interleave(batchp.traj_lens)
-        fwd_cat_prime, per_state_preds = model(batchp, batchp.cond_info[batchp_batch_idx], output_Qs=True)
+        batchp.cond_info = batchp.cond_info[batchp_batch_idx]
+        fwd_cat_prime, per_state_preds = model(batchp, output_Qs=True)
         Q_omega_prime = fwd_cat_prime.logits
         # We've repeated everything N_omega times, so we can reshape the same way as above but with
         # an extra N_omega first dimension
diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py
index 33c436bf..c75c1ce4 100644
--- a/src/gflownet/algo/flow_matching.py
+++ b/src/gflownet/algo/flow_matching.py
@@ -46,7 +46,7 @@
         # in a number of settings the regular loss is more stable.
self.fm_balanced_loss = cfg.algo.fm.balanced_loss self.fm_leaf_coef = cfg.algo.fm.leaf_coef - self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent + self.correct_idempotent: bool = cfg.algo.fm.correct_idempotent def construct_batch(self, trajs, cond_info, log_rewards): """Construct a batch from a list of trajectories and their information @@ -149,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Query the model for Fsa. The model will output a GraphActionCategorical, but we will # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of # log F(s,a) so we don't have to change anything in the sampling routines. - cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + batch.cond_info = batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)] + cat, graph_out = model(batch) # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 9cf9faeb..1199dd31 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -119,8 +119,9 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - ci = cond_info[not_done_mask] if cond_info is not None else None - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) + batch = self.ctx.collate(torch_graphs) + batch.cond_info = cond_info[not_done_mask] if cond_info is not None else None + fwd_cat, *_, log_reward_preds = model(batch.to(dev)) if random_action_prob > 0: # Device which graphs in the minibatch will get their action randomized is_random_action = torch.tensor( diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index b1a636de..52314d03 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -34,7 +34,8 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions - fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + fwd_cat, log_reward_preds = model(batch) # This is the log prob of each action in the trajectory log_prob = fwd_cat.log_prob(batch.actions) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index a9d61aaa..99730279 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -4,13 +4,14 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device -class SoftQLearning: +class SoftQLearning(GFNAlgorithm): def __init__( self, env: 
GraphBuildingEnv, @@ -33,6 +34,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -147,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per object predictions # Here we will interpret the logits of the fwd_cat as Q values - Q, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + Q, per_state_preds = model(batch) if self.do_q_prime_correction: # First we need to estimate V_soft. We will use q_a' = \pi diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index eac57cc6..8713cde3 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -109,7 +109,7 @@ def __init__( """ self.ctx = ctx self.env = env - self.global_cfg = cfg + self.global_cfg = cfg # TODO: this belongs in the base class self.cfg = cfg.algo.tb self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes @@ -147,9 +147,6 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(get_worker_device()) - def set_is_eval(self, is_eval: bool): - self.is_eval = is_eval - def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, @@ -402,12 +399,14 @@ def compute_batch_losses( # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, bck_cat, per_graph_out = model(batch) else: if self.model_is_autoregressive: - fwd_cat, per_graph_out = model(batch, cond_info, batched=True) + fwd_cat, per_graph_out = model(batch, batched=True) else: - fwd_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, per_graph_out = model(batch) # Retreive the reward predictions for the full graphs, # i.e. the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] diff --git a/src/gflownet/config.py b/src/gflownet/config.py index b66f238d..86b225f0 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -79,6 +79,10 @@ class Config(StrictDataClass): The hostname of the machine on which the experiment is run pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) + mp_buffer_size : Optional[int] + If specified, use a buffer of this size in bytes for passing tensors between processes. + Note that this is only relevant if num_workers > 0. + Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. 
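+        As a reference point, `SEHFragTrainer.set_default_hps` sets this to `32 * 1024**2` (32 MB).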
git_hash : Optional[str] The git hash of the current commit overwrite_existing_exp : bool @@ -102,6 +106,7 @@ class Config(StrictDataClass): pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False + mp_buffer_size: Optional[int] = None algo: AlgoConfig = field(default_factory=AlgoConfig) model: ModelConfig = field(default_factory=ModelConfig) opt: OptimizerConfig = field(default_factory=OptimizerConfig) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 85ede753..a6fdecb7 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -4,12 +4,15 @@ import numpy as np import torch from torch.utils.data import IterableDataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import get_worker_rng +from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer def cycle_call(it): @@ -44,6 +47,7 @@ def __init__( self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step + self.setup_mp_buffers() def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -231,7 +235,7 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.obj_props = torch.stack([t["obj_props"] for t in trajs]) - return batch + return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' obj_props and is_valid keys by querying the task.""" @@ -319,3 +323,20 @@ def iterate_indices(self, n, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: yield np.arange(i + num_samples, end) + + def setup_mp_buffers(self): + if self.cfg.num_workers > 0: + self.mp_buffer_size = self.cfg.mp_buffer_size + if self.mp_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.mp_buffer_size) for _ in range(self.cfg.num_workers)] + else: + self.mp_buffer_size = None + + def _maybe_put_in_mp_buffer(self, batch): + if self.mp_buffer_size: + if not (isinstance(batch, (Batch, SeqBatch))): + warnings.warn(f"Expected a Batch object, but got {type(batch)}. 
Not using mp buffers.") + return batch + return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) + else: + return batch diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index b10b228d..601d7bd5 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -887,10 +887,12 @@ def entropy(self, logprobs=None): """ if logprobs is None: logprobs = self.logsoftmax() + masks = self.action_masks if self.action_masks is not None else [None] * len(logprobs) entropy = -sum( [ - scatter(i * i.exp(), b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) - for i, b in zip(logprobs, self.batch) + scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) + for i, b, m in zip(logprobs, self.batch, masks) + for im in [i.masked_fill(m == 0.0, 0) if m is not None else i] ] ) return entropy diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index 0e0281a9..f47fa6ca 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -70,6 +70,7 @@ def __init__(self, seqs: List[torch.Tensor], pad: int): # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this # is the total number of timesteps. self.num_graphs = self.lens.sum().item() + self.cond_info: torch.Tensor # May be set later def to(self, device): for name in dir(self): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 366e4390..b84980dc 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -90,7 +90,7 @@ def __init__( ) ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): + def forward(self, g: gd.Batch): """Forward pass Parameters @@ -112,7 +112,7 @@ def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] # Augment the edges with a new edge to the conditioning # information node. This new node is connected to every node @@ -255,8 +255,8 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap types=action_types, ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch): + node_embeddings, graph_embeddings = self.transf(g) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): ne_row, ne_col = g.non_edge_index diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 54557922..8ecb8919 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -65,7 +65,7 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) return self._logZ(cond_info) - def forward(self, xs: SeqBatch, cond, batched=False): + def forward(self, xs: SeqBatch, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. 
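+
+        Conditioning information is read from `xs.cond_info`, which the caller sets
+        on the batch, rather than being passed in as a separate `cond` argument.
+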
Parameters @@ -83,6 +83,7 @@ def forward(self, xs: SeqBatch, cond, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) + cond = xs.cond_info if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 103acc95..13dd0e48 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -82,12 +82,17 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) + self.opt = self._opt(non_Z_params) - self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) + + if Z_params: + self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + else: + self.opt_Z = None self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: @@ -124,10 +129,11 @@ def step(self, loss: Tensor): g1 = model_grad_norm(self.model) self.opt.step() self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() self.lr_sched.step() - self.lr_sched_Z.step() + if self.opt_Z is not None: + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched_Z.step() if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 2adca4f4..de4b72f0 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -129,6 +129,7 @@ class SEHFragTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False + cfg.mp_buffer_size = 32 * 1024**2 # 32 MB cfg.num_workers = 8 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 386c0494..4d90fa3a 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -4,7 +4,7 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, List, Optional, Protocol, Union import numpy as np import torch @@ -16,6 +16,7 @@ from rdkit import RDLogger from torch import Tensor from torch.utils.data import DataLoader, Dataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource @@ -23,7 +24,7 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device, set_worker_rng_seed -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper +from gflownet.utils.multiprocessing_proxy import BufferUnpickler, mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook from .config import Config @@ -132,6 +133,7 @@ def _wrap_for_mp(self, obj): 
self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, + sb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -181,8 +183,6 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -219,6 +219,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: tick = time.time() self.model.train() try: + loss = info = None loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") @@ -247,6 +248,16 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + def _maybe_resolve_shared_buffer( + self, batch: Union[Batch, SeqBatch, tuple, list], dl: DataLoader + ) -> Union[Batch, SeqBatch]: + if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)): + batch, wid = batch + batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() + elif isinstance(batch, (Batch, SeqBatch)): + batch = batch.to(self.device) + return batch + def run(self, logger=None): """Trains the GFN for `num_training_steps` minibatches, performing validation every `validate_every` minibatches. @@ -275,6 +286,7 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() + batch = self._maybe_resolve_shared_buffer(batch, train_dl) epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -282,7 +294,7 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) + info = self.train_batch(batch, epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() self.log(info, it, "train") @@ -291,6 +303,7 @@ def run(self, logger=None): if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: + batch = self._maybe_resolve_shared_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -311,6 +324,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): + batch = self._maybe_resolve_shared_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index df13b565..7ed734bd 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,22 +1,143 @@ +import io import pickle import queue import threading import traceback +from pickle import Pickler, Unpickler, UnpicklingError import torch import torch.multiprocessing as mp +class 
SharedPinnedBuffer: + def __init__(self, size): + self.size = size + self.buffer = torch.empty(size, dtype=torch.uint8) + self.buffer.share_memory_() + self.lock = mp.Lock() + self.do_unreg = False + + if not self.buffer.is_pinned(): + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # pin it again; doing so will raise a CUDA error + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + self.do_unreg = True # But then we need to unregister it later + assert self.buffer.is_shared() + assert self.buffer.is_pinned() + + def __del__(self): + if torch.utils.data.get_worker_info() is None: + if self.do_unreg: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 + + +class _BufferPicklerSentinel: + pass + + +class BufferPickler(Pickler): + def __init__(self, buf: SharedPinnedBuffer): + self._f = io.BytesIO() + super().__init__(self._f) + self.buf = buf + # The lock will be released by the consumer (BufferUnpickler) of this buffer once + # the memory has been transferred to the device and copied + self.buf.lock.acquire() + self.buf_offset = 0 + + def persistent_id(self, v): + if not isinstance(v, torch.Tensor): + return None + numel = v.numel() * v.element_size() + if self.buf_offset + numel > self.buf.size: + raise RuntimeError( + f"Tried to allocate {self.buf_offset + numel} bytes in a buffer of size {self.buf.size}. " + "Consider increasing cfg.mp_buffer_size" + ) + start = self.buf_offset + shape = tuple(v.shape) + if v.ndim > 0 and v.stride(-1) != 1 or not v.is_contiguous(): + v = v.contiguous().reshape(-1) + if v.ndim > 0 and v.stride(-1) != 1: + # We're still not contiguous, this unfortunately happens occasionally, e.g.: + # x = torch.arange(10).reshape((10, 1)) + # y = x.T[::2].T + # y.stride(), y.is_contiguous(), y.contiguous().stride() + # -> (1, 2), True, (1, 2) + v = v.flatten() + 0 + # I don't know if this comes from my misunderstanding of strides or if it's a bug in torch + # but either way torch will refuse to view this tensor as a uint8 tensor, so we have to + 0 + # to force torch to materialize it into a new tensor (it may otherwise be lazy and not materialize) + if numel > 0: + self.buf.buffer[start : start + numel] = v.flatten().view(torch.uint8) + self.buf_offset += numel + self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes + return (_BufferPicklerSentinel, (start, shape, v.dtype)) + + def dumps(self, obj): + self.dump(obj) + return (self._f.getvalue(), self.buf_offset) + + +class BufferUnpickler(Unpickler): + def __init__(self, buf: SharedPinnedBuffer, data, device): + self._f, total_size = io.BytesIO(data[0]), data[1] + super().__init__(self._f) + self.buf = buf + self.target_buf = buf.buffer[:total_size].to(device) + 0 + # Why the `+ 0`? Unfortunately, we have no way to know exactly when the consumer of the object we're + # unpickling will be done using the buffer underlying the tensor, so we have to create a copy. + # If we don't and another consumer starts using the buffer, and this consumer transfers this pinned + # buffer to the GPU, the first consumer's tensors will be corrupted, because (depending on the CUDA + # memory manager) the pinned buffer will transfer to the same GPU location. + # Hopefully, especially if the target device is the GPU, the copy will be fast and/or async. 
+        # Note that this could be fixed by using one buffer for each worker, but that would be significantly
+        # more memory usage.
+
+    def load_tensor(self, offset, shape, dtype):
+        numel = prod(shape) * dtype.itemsize
+        tensor: torch.Tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape)
+        return tensor
+
+    def persistent_load(self, pid):
+        if isinstance(pid, tuple):
+            sentinel, (offset, shape, dtype) = pid
+            if sentinel is _BufferPicklerSentinel:
+                return self.load_tensor(offset, shape, dtype)
+        raise UnpicklingError("Invalid persistent id")
+
+    def load(self):
+        r = super().load()
+        # We're done with this buffer, release it for the next consumer
+        self.buf.lock.release()
+        return r
+
+
+def prod(ns):
+    p = 1
+    for i in ns:
+        p *= i
+    return p
+
+
 class MPObjectPlaceholder:
     """This class can be used for example as a model or dataset placeholder
     in a worker process, and translates calls to the object-placeholder into
     queries for the main process to execute on the real object."""

-    def __init__(self, in_queues, out_queues, pickle_messages=False):
+    def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_size=None):
         self.qs = in_queues, out_queues
         self.device = torch.device("cpu")
         self.pickle_messages = pickle_messages
         self._is_init = False
+        self.shared_buffer_size = shared_buffer_size
+        if shared_buffer_size:
+            self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size)
+            self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size)

     def _check_init(self):
         if self._is_init:
@@ -31,11 +152,15 @@ def _check_init(self):
         self._is_init = True

     def encode(self, m):
+        if self.shared_buffer_size:
+            return BufferPickler(self._buffer_to_main).dumps(m)
         if self.pickle_messages:
             return pickle.dumps(m)
         return m

     def decode(self, m):
+        if self.shared_buffer_size:
+            m = BufferUnpickler(self._buffer_from_main, m, self.device).load()
         if self.pickle_messages:
             m = pickle.loads(m)
         if isinstance(m, Exception):
@@ -75,7 +200,7 @@ class MPObjectProxy:
     Always passes CPU tensors between processes.
     """

-    def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False):
+    def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, sb_size=None):
         """Construct a multiprocessing object proxy.

         Parameters
         ----------
@@ -91,11 +216,14 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo
             If True, pickle messages sent between processes. This reduces load on shared memory,
             but increases load on CPU. It is recommended to activate this flag if encountering
             "Too many open files"-type errors.
+ sb_size: Optional[int] + shared buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.use_shared_buffer = bool(sb_size) + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -107,11 +235,16 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.thread.start() def encode(self, m): + if self.use_shared_buffer: + return BufferPickler(self.placeholder._buffer_from_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) return m def decode(self, m): + if self.use_shared_buffer: + return BufferUnpickler(self.placeholder._buffer_to_main, m, self.device).load() + if self.pickle_messages: return pickle.loads(m) return m @@ -121,8 +254,7 @@ def to_cpu(self, i): def run(self): timeouts = 0 - - while not self.stop.is_set() or timeouts < 500: + while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: r = self.decode(q.get(True, 1e-5)) @@ -143,6 +275,7 @@ def run(self): except Exception as e: result = e exc_str = traceback.format_exc() + print(exc_str) try: pickle.dumps(e) except Exception: @@ -159,34 +292,30 @@ def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes cuda calls by forwarding data through the model, or a replay buffer such that the new data is pushed in from the worker processes but only the main process has to hold the full buffer in memory. - self.out_queues[qi].put(self.encode(msg)) - elif isinstance(result, dict): - msg = {k: self.to_cpu(i) for k, i in result.items()} - self.out_queues[qi].put(self.encode(msg)) - else: - msg = self.to_cpu(result) - self.out_queues[qi].put(self.encode(msg)) Parameters ---------- obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) - Lives in the main process to which method calls are passed + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple Types that will be cast to cuda when received as arguments of method calls. torch.Tensor is cast by default. pickle_messages: bool - If True, pickle messages sent between processes. This reduces load on shared - memory, but increases load on CPU. It is recommended to activate this flag if - encountering "Too many open files"-type errors. + If True, pickle messages sent between processes. This reduces load on shared + memory, but increases load on CPU. It is recommended to activate this flag if + encountering "Too many open files"-type errors. + sb_size: Optional[int] + If not None, creates a shared buffer of this size for sending tensors between processes. + Note, this will allocate two buffers of this size (one for sending, the other for receiving). 
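+
+    For example, the trainer wraps its model roughly as
+    `mp_object_wrapper(model, num_workers, (gd.Batch, GraphActionCategorical, SeqBatch), sb_size=cfg.mp_buffer_size)`,
+    and worker processes then call the returned placeholder as if it were the model itself.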
Returns ------- @@ -194,4 +323,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, sb_size=sb_size) diff --git a/tests/test_graph_building_env.py b/tests/test_graph_building_env.py index e9184cbd..adf41120 100644 --- a/tests/test_graph_building_env.py +++ b/tests/test_graph_building_env.py @@ -123,4 +123,24 @@ def test_log_prob(): def test_entropy(): cat = make_test_cat() - cat.entropy() + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + cat.action_masks = [ + torch.tensor([[0], [1], [1.0]]), + ((torch.arange(cat.logits[1].numel()) % 2) == 0).float().reshape(cat.logits[1].shape), + torch.tensor([[1, 0, 1], [0, 1, 1.0]]), + ] + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + +def test_entropy_grad(): + # Purposefully large values to test extremal behaviors + logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float() + logits.requires_grad_(True) + batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"]) + cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None]) + cat._epsilon = 0 + (grad_gac,) = torch.autograd.grad(cat.entropy(), logits, retain_graph=True) + assert torch.isfinite(grad_gac).all()