Shared pinned buffers #120

Open

wants to merge 25 commits into base: trunk

25 commits
7dbca12
first throw at refactoring SamplingIterator
bengioe Feb 28, 2024
939cb56
Merge branch 'trunk' into bengioe-better-iterators
bengioe Feb 28, 2024
dfba1ca
changed all iterators to DataSource
bengioe Feb 29, 2024
e5239fb
lots of little fixes, tested all tasks, better device management
bengioe Feb 29, 2024
43dfc2b
style
bengioe Mar 1, 2024
279ecfc
change batch size hyperparameters + fix nested dataclasses
bengioe Mar 7, 2024
2ba251a
Merge branch 'trunk' into bengioe-better-iterators
bengioe Mar 7, 2024
282bbfb
move things around & prevent circular import
bengioe Mar 7, 2024
c3bc6d0
tox
bengioe Mar 7, 2024
b1c5630
fix imports
bengioe Mar 7, 2024
a64a639
replace device references with get_worker_device
bengioe Mar 7, 2024
28bcc59
little fixes
bengioe Mar 7, 2024
4811e7c
a few more stragglers
bengioe Mar 7, 2024
7d32ac1
proof of concept of using shared pinned buffers
bengioe Feb 23, 2024
d4a2a7d
32mb buffer
bengioe Feb 23, 2024
27dfc23
add to DataSource
bengioe Mar 7, 2024
e9f1dc1
various fixes
bengioe Mar 8, 2024
c048e77
major simplification by reusing pickling mechanisms
bengioe Mar 8, 2024
acfe070
memory copy + fixes and doc
bengioe Mar 11, 2024
9454da8
Merge branch 'trunk' into bengioe-mp-with-batch-buffers
bengioe Mar 11, 2024
2b9da70
Merge branch 'trunk' into bengioe-mp-with-batch-buffers
bengioe May 8, 2024
907ffcd
fix global_cfg + opt_Z when there's no Z
bengioe May 8, 2024
60722a7
fix entropy when masks are used
bengioe May 9, 2024
f859640
small fixes
bengioe May 9, 2024
d536233
removing timing prints
bengioe May 9, 2024
16 changes: 15 additions & 1 deletion docs/implementation_notes.md
@@ -51,4 +51,18 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo

`DataSource` also covers validation sets, including cases such as:
- Generating new trajectories (w.r.t a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset

## Multiprocessing

We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization, by setting the `DataLoader`'s `num_workers` parameter (via `cfg.num_workers`) to a value greater than 0. Because worker processes cannot (easily) use a CUDA handle, we have to resort to a number of tricks.

Because training models involves sampling from them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly a wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main process, where the model lives (e.g. on CUDA), and that the returned values are properly serialized and sent back to the worker process. The wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that only method calls are supported on these wrappers; direct attribute access is not.
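The wrappers in `gflownet.utils.multiprocessing_proxy` are the authoritative implementation; as a rough mental model only, the routing they perform looks something like the sketch below. Every name in it (`ModelProxy`, `serve`, the queue plumbing) is illustrative and not the PR's actual API.

```python
import torch

class ModelProxy:
    """Illustrative stand-in for the proxy wrappers: a call made in a worker is
    shipped over a queue to the main process, where the real (CUDA) model runs."""

    def __init__(self, request_queue, result_queue):
        self._req, self._res = request_queue, result_queue

    def __call__(self, *args, **kwargs):
        return self._remote_call("__call__", args, kwargs)

    def __getattr__(self, name):
        # Only method calls are forwarded; plain attribute access is not supported,
        # mirroring the limitation described above.
        return lambda *args, **kwargs: self._remote_call(name, args, kwargs)

    def _remote_call(self, name, args, kwargs):
        self._req.put((name, args, kwargs))
        return self._res.get()  # blocks until the main process replies

def serve(model, device, request_queue, result_queue):
    """Runs in the main process: executes forwarded calls and ships results back."""
    while True:
        name, args, kwargs = request_queue.get()
        fn = model if name == "__call__" else getattr(model, name)
        out = fn(*[a.to(device) if torch.is_tensor(a) else a for a in args], **kwargs)
        result_queue.put(out.detach().cpu() if torch.is_tensor(out) else out)
```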

Note that the workers do not use CUDA and therefore have to work entirely on CPU, but the code is designed to be largely agnostic to this fact. By using `get_worker_device`, code can be written without assuming which device it will run on; again, calls such as `model(input)` will work as expected.
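For instance, batch-building code can be written once and run either in the main process or in a worker; a minimal usage sketch, where the `featurize` helper and its `ctx` argument are illustrative:

```python
from gflownet.utils.misc import get_worker_device

def featurize(ctx, graphs):
    # Resolves to the CUDA device in the main process and to CPU inside a
    # DataLoader worker, so the same code runs unchanged in both.
    dev = get_worker_device()
    return ctx.collate(graphs).to(dev)
```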

On message serialization: naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and moves tensors sent through queues into shared memory, which is unfortunately very slow here: creating the shared-memory files has a significant per-tensor cost, and `Data` `Batch`es tend to contain lots of small tensors, a poor fit for this mechanism.

We implement two solutions to this problem (in order of preference):
- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), allocated once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value is known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages.
- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle.
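Both modes are selected purely through the configuration. A sketch of what that might look like; the field names match the PR's `Config`, but the bare `Config()` construction and the chosen values are illustrative:

```python
from gflownet.config import Config

cfg = Config()
cfg.num_workers = 4                    # enable DataLoader worker processes

# Preferred: fixed-size shared pinned buffers; the largest batch/return value
# must fit within this many bytes (32 MiB here, as in the PR's proof-of-concept commit).
cfg.mp_buffer_size = 32 * 1024 * 1024

# Fallback: plain pickling of inter-process messages.
# cfg.pickle_mp_messages = True
```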
3 changes: 3 additions & 0 deletions src/gflownet/__init__.py
@@ -23,6 +23,9 @@ class GFNAlgorithm:
def step(self):
self.updates += 1 # This isn't used anywhere?

def set_is_eval(self, is_eval: bool):
self.is_eval = is_eval

def compute_batch_losses(
self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0
) -> Tuple[Tensor, Dict[str, Tensor]]:
7 changes: 5 additions & 2 deletions src/gflownet/algo/advantage_actor_critic.py
@@ -3,14 +3,15 @@
import torch_geometric.data as gd
from torch import Tensor

from gflownet import GFNAlgorithm
from gflownet.config import Config
from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory
from gflownet.utils.misc import get_worker_device

from .graph_sampling import GraphSampler


class A2C:
class A2C(GFNAlgorithm):
def __init__(
self,
env: GraphBuildingEnv,
@@ -36,6 +37,7 @@ def __init__(
The experiment configuration

"""
self.global_cfg = cfg # TODO: this belongs in the base class
self.ctx = ctx
self.env = env
self.max_len = cfg.algo.max_len
@@ -149,7 +151,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:

# Forward pass of the model, returns a GraphActionCategorical and per graph predictions
# Here we will interpret the logits of the fwd_cat as Q values
policy, per_state_preds = model(batch, cond_info[batch_idx])
batch.cond_info = cond_info[batch_idx]
policy, per_state_preds = model(batch)
V = per_state_preds[:, 0]
G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1
G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object
27 changes: 16 additions & 11 deletions src/gflownet/algo/envelope_q_learning.py
@@ -5,6 +5,7 @@
from torch import Tensor
from torch_scatter import scatter

from gflownet import GFNAlgorithm
from gflownet.config import Config
from gflownet.envs.graph_building_env import (
GraphActionCategorical,
@@ -39,24 +40,24 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
num_layers=num_layers,
num_heads=num_heads,
)
num_final = num_emb * 2
num_final = num_emb
num_mlp_layers = 0
self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers)
# Edge attr logits are "sided", so we will compute both sides independently
self.emb2set_edge_attr = mlp(
num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers
)
self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers)
self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers)
self.emb2stop = mlp(num_emb * 2, num_emb, num_objectives, num_mlp_layers)
self.emb2reward = mlp(num_emb * 2, num_emb, 1, num_mlp_layers)
self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers)
self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
self.action_type_order = env_ctx.action_type_order
self.mask_value = -10
self.num_objectives = num_objectives

def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
def forward(self, g: gd.Batch, output_Qs=False):
"""See `GraphTransformer` for argument values"""
node_embeddings, graph_embeddings = self.transf(g, cond)
node_embeddings, graph_embeddings = self.transf(g)
# On `::2`, edges are duplicated to make graphs undirected, only take the even ones
e_row, e_col = g.edge_index[:, ::2]
edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col])
@@ -86,7 +87,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
# Compute the greedy policy
# See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
# TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
w = cond[:, -self.num_objectives :]
w = g.cond_info[:, -self.num_objectives :]
w_dot_Q = [
(qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
for qi, b in zip(cat.logits, cat.batch)
@@ -122,8 +123,9 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective
self.action_type_order = env_ctx.action_type_order
self.num_objectives = num_objectives

def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
node_embeddings, graph_embeddings = self.transf(g, cond)
def forward(self, g: gd.Batch, output_Qs=False):
cond = g.cond_info
node_embeddings, graph_embeddings = self.transf(g)
ne_row, ne_col = g.non_edge_index
# On `::2`, edges are duplicated to make graphs undirected, only take the even ones
e_row, e_col = g.edge_index[:, ::2]
@@ -156,7 +158,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
return cat, r_pred


class EnvelopeQLearning:
class EnvelopeQLearning(GFNAlgorithm):
def __init__(
self,
env: GraphBuildingEnv,
@@ -182,6 +184,7 @@ def __init__(
cfg: Config
The experiment configuration
"""
self.global_cfg = cfg
self.ctx = ctx
self.env = env
self.task = task
@@ -314,7 +317,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
# Forward pass of the model, returns a GraphActionCategorical and per graph predictions
# Here we will interpret the logits of the fwd_cat as Q values
# Q(s,a,omega)
fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True)
batch.cond_info = cond_info[batch_idx]
fwd_cat, per_state_preds = model(batch, output_Qs=True)
Q_omega = fwd_cat.logits
# reshape to List[shape: (num <T> in all graphs, num actions on T, num_objectives) | for all types T]
Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega]
@@ -323,7 +327,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
batchp = batch.batch_prime
batchp_num_trajs = int(batchp.traj_lens.shape[0])
batchp_batch_idx = torch.arange(batchp_num_trajs, device=dev).repeat_interleave(batchp.traj_lens)
fwd_cat_prime, per_state_preds = model(batchp, batchp.cond_info[batchp_batch_idx], output_Qs=True)
batchp.cond_info = batchp.cond_info[batchp_batch_idx]
fwd_cat_prime, per_state_preds = model(batchp, output_Qs=True)
Q_omega_prime = fwd_cat_prime.logits
# We've repeated everything N_omega times, so we can reshape the same way as above but with
# an extra N_omega first dimension
5 changes: 3 additions & 2 deletions src/gflownet/algo/flow_matching.py
@@ -46,7 +46,7 @@ def __init__(
# in a number of settings the regular loss is more stable.
self.fm_balanced_loss = cfg.algo.fm.balanced_loss
self.fm_leaf_coef = cfg.algo.fm.leaf_coef
self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent
self.correct_idempotent: bool = cfg.algo.fm.correct_idempotent

def construct_batch(self, trajs, cond_info, log_rewards):
"""Construct a batch from a list of trajectories and their information
@@ -149,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
# Query the model for Fsa. The model will output a GraphActionCategorical, but we will
# simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of
# log F(s,a) so we don't have to change anything in the sampling routines.
cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)])
batch.cond_info = batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]
cat, graph_out = model(batch)
# We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the
# parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs
# tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We
5 changes: 3 additions & 2 deletions src/gflownet/algo/graph_sampling.py
@@ -119,8 +119,9 @@ def not_done(lst):
# Forward pass to get GraphActionCategorical
# Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does.
# TODO: compute bck_cat.log_prob(bck_a) when relevant
ci = cond_info[not_done_mask] if cond_info is not None else None
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci)
batch = self.ctx.collate(torch_graphs)
batch.cond_info = cond_info[not_done_mask] if cond_info is not None else None
fwd_cat, *_, log_reward_preds = model(batch.to(dev))
if random_action_prob > 0:
# Device which graphs in the minibatch will get their action randomized
is_random_action = torch.tensor(
3 changes: 2 additions & 1 deletion src/gflownet/algo/multiobjective_reinforce.py
@@ -34,7 +34,8 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)

# Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions
fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx])
batch.cond_info = cond_info[batch_idx]
fwd_cat, log_reward_preds = model(batch)

# This is the log prob of each action in the trajectory
log_prob = fwd_cat.log_prob(batch.actions)
7 changes: 5 additions & 2 deletions src/gflownet/algo/soft_q_learning.py
@@ -4,13 +4,14 @@
from torch import Tensor
from torch_scatter import scatter

from gflownet import GFNAlgorithm
from gflownet.algo.graph_sampling import GraphSampler
from gflownet.config import Config
from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory
from gflownet.utils.misc import get_worker_device


class SoftQLearning:
class SoftQLearning(GFNAlgorithm):
def __init__(
self,
env: GraphBuildingEnv,
@@ -33,6 +34,7 @@ def __init__(
cfg: Config
The experiment configuration
"""
self.global_cfg = cfg
self.ctx = ctx
self.env = env
self.max_len = cfg.algo.max_len
@@ -147,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:

# Forward pass of the model, returns a GraphActionCategorical and per object predictions
# Here we will interpret the logits of the fwd_cat as Q values
Q, per_state_preds = model(batch, cond_info[batch_idx])
batch.cond_info = cond_info[batch_idx]
Q, per_state_preds = model(batch)

if self.do_q_prime_correction:
# First we need to estimate V_soft. We will use q_a' = \pi
13 changes: 6 additions & 7 deletions src/gflownet/algo/trajectory_balance.py
@@ -109,7 +109,7 @@
"""
self.ctx = ctx
self.env = env
self.global_cfg = cfg
self.global_cfg = cfg # TODO: this belongs in the base class
self.cfg = cfg.algo.tb
self.max_len = cfg.algo.max_len
self.max_nodes = cfg.algo.max_nodes
@@ -147,9 +147,6 @@ def __init__(
self._subtb_max_len = self.global_cfg.algo.max_len + 2
self._init_subtb(get_worker_device())

def set_is_eval(self, is_eval: bool):
self.is_eval = is_eval

def create_training_data_from_own_samples(
self,
model: TrajectoryBalanceModel,
@@ -402,12 +399,14 @@ def compute_batch_losses(
# Forward pass of the model, returns a GraphActionCategorical representing the forward
# policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB).
if self.cfg.do_parameterize_p_b:
fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info)
batch.cond_info = batched_cond_info
fwd_cat, bck_cat, per_graph_out = model(batch)
else:
if self.model_is_autoregressive:
fwd_cat, per_graph_out = model(batch, cond_info, batched=True)
fwd_cat, per_graph_out = model(batch, batched=True)
else:
fwd_cat, per_graph_out = model(batch, batched_cond_info)
batch.cond_info = batched_cond_info
fwd_cat, per_graph_out = model(batch)
# Retrieve the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
5 changes: 5 additions & 0 deletions src/gflownet/config.py
@@ -79,6 +79,10 @@ class Config(StrictDataClass):
The hostname of the machine on which the experiment is run
pickle_mp_messages : bool
Whether to pickle messages sent between processes (only relevant if num_workers > 0)
mp_buffer_size : Optional[int]
If specified, use a buffer of this size in bytes for passing tensors between processes.
Note that this is only relevant if num_workers > 0.
Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers.
git_hash : Optional[str]
The git hash of the current commit
overwrite_existing_exp : bool
@@ -102,6 +106,7 @@ class Config(StrictDataClass):
pickle_mp_messages: bool = False
git_hash: Optional[str] = None
overwrite_existing_exp: bool = False
mp_buffer_size: Optional[int] = None
algo: AlgoConfig = field(default_factory=AlgoConfig)
model: ModelConfig = field(default_factory=ModelConfig)
opt: OptimizerConfig = field(default_factory=OptimizerConfig)
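For a concrete sense of scale, taking the `mp_buffer_size` docstring above at face value: a run with `num_workers = 4` and two wrapped objects (say, a model and a replay buffer) would allocate 4 + 2 × 2 = 8 buffers of `mp_buffer_size` bytes each, i.e. 256 MiB of shared pinned memory at the 32 MiB buffer size used elsewhere in this PR.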
23 changes: 22 additions & 1 deletion src/gflownet/data/data_source.py
@@ -4,12 +4,15 @@
import numpy as np
import torch
from torch.utils.data import IterableDataset
from torch_geometric.data import Batch

from gflownet import GFNAlgorithm, GFNTask
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu
from gflownet.envs.graph_building_env import GraphBuildingEnvContext
from gflownet.envs.seq_building_env import SeqBatch
from gflownet.utils.misc import get_worker_rng
from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer


def cycle_call(it):
@@ -44,6 +47,7 @@ def __init__(
self.global_step_count.share_memory_()
self.global_step_count_lock = torch.multiprocessing.Lock()
self.current_iter = start_at_step
self.setup_mp_buffers()

def add_sampling_hook(self, hook: Callable):
"""Add a hook that is called when sampling new trajectories.
@@ -231,7 +235,7 @@ def create_batch(self, trajs, batch_info):
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
batch.obj_props = torch.stack([t["obj_props"] for t in trajs])
return batch
return self._maybe_put_in_mp_buffer(batch)

def compute_properties(self, trajs, mark_as_online=False):
"""Sets trajs' obj_props and is_valid keys by querying the task."""
@@ -319,3 +323,20 @@ def iterate_indices(self, n, num_samples):
yield np.arange(i, i + num_samples)
if i + num_samples < end:
yield np.arange(i + num_samples, end)

def setup_mp_buffers(self):
if self.cfg.num_workers > 0:
self.mp_buffer_size = self.cfg.mp_buffer_size
if self.mp_buffer_size:
self.result_buffer = [SharedPinnedBuffer(self.mp_buffer_size) for _ in range(self.cfg.num_workers)]
else:
self.mp_buffer_size = None

def _maybe_put_in_mp_buffer(self, batch):
if self.mp_buffer_size:
if not (isinstance(batch, (Batch, SeqBatch))):
warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.")
return batch
return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid)
else:
return batch
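`SharedPinnedBuffer` and `BufferPickler` (imported above from `gflownet.utils.multiprocessing_proxy`) implement the fast path used here. Stripped of pinning, locking, and the custom pickler, the underlying idea is roughly the following; every name in this sketch is illustrative and not the repo's API:

```python
import pickle
import torch

# One fixed-size buffer, created once up front and visible to both processes.
BUF_SIZE = 32 * 1024 * 1024
buf = torch.empty(BUF_SIZE, dtype=torch.uint8).share_memory_()

def put(obj, buf):
    """Worker side: serialize into the preallocated buffer, return only the length."""
    payload = pickle.dumps(obj)
    n = len(payload)
    if n > buf.numel():
        raise ValueError(f"message of {n} bytes does not fit in the {buf.numel()}-byte buffer")
    buf[:n] = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    return n  # this small int is all that travels through the queue

def get(n, buf):
    """Main-process side: rebuild the object from the shared bytes."""
    return pickle.loads(buf[:n].numpy().tobytes())
```

The point is that no new shared-memory files are created per message; only a small handle plus the worker id has to travel through the queue, mirroring the `(payload, self._wid)` tuple returned above.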
6 changes: 4 additions & 2 deletions src/gflownet/envs/graph_building_env.py
@@ -887,10 +887,12 @@ def entropy(self, logprobs=None):
"""
if logprobs is None:
logprobs = self.logsoftmax()
masks = self.action_masks if self.action_masks is not None else [None] * len(logprobs)
entropy = -sum(
[
scatter(i * i.exp(), b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1)
for i, b in zip(logprobs, self.batch)
scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1)
for i, b, m in zip(logprobs, self.batch, masks)
for im in [i.masked_fill(m == 0.0, 0) if m is not None else i]
]
)
return entropy
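The masking added above presumably guards against invalid actions: masked logits are `-inf`, so their log-probabilities are `-inf`, and the `logp * p` product in the entropy sum becomes `-inf * 0 = NaN` without it. A standalone illustration of the failure mode (not the repo's code):

```python
import torch

logits = torch.tensor([1.0, float("-inf")])            # second action masked out
logp = torch.log_softmax(logits, dim=0)                 # tensor([0., -inf])
print(logp * logp.exp())                                # tensor([0., nan]): -inf * 0 is NaN
print(logp.masked_fill(logp.isinf(), 0) * logp.exp())   # tensor([0., 0.]), finite as desired
```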
1 change: 1 addition & 0 deletions src/gflownet/envs/seq_building_env.py
@@ -70,6 +70,7 @@ def __init__(self, seqs: List[torch.Tensor], pad: int):
# Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this
# is the total number of timesteps.
self.num_graphs = self.lens.sum().item()
self.cond_info: torch.Tensor # May be set later

def to(self, device):
for name in dir(self):