Merge pull request #216 from GFNOrg/hyeok9855/get_logprobs
Refactor log probability calculations into separate utility functions
hyeok9855 authored Nov 15, 2024

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 8259a21 + b506769 commit b5d909e
Showing 5 changed files with 259 additions and 150 deletions.
106 changes: 4 additions & 102 deletions src/gfn/gflownet/base.py
@@ -11,11 +11,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -145,6 +141,7 @@ def get_pfs_and_pbs(
trajectories: Trajectories to evaluate.
fill_value: Value to use for invalid states (i.e. $s_f$ that is added to
shorter trajectories).
recalculate_all_logprobs: Whether to re-evaluate all logprobs.
Returns: A tuple of float tensors of shape (max_length, n_trajectories) containing
the log_pf and log_pb for each action in each trajectory. The first one can be None.
@@ -153,103 +150,9 @@ def get_pfs_and_pbs(
ValueError: if the trajectories are backward.
AssertionError: when actions and states dimensions mismatch.
"""
# fill value is the value used for invalid states (sink state usually)
if trajectories.is_backward:
raise ValueError("Backward trajectories are not supported")

valid_states = trajectories.states[~trajectories.states.is_sink_state]
valid_actions = trajectories.actions[~trajectories.actions.is_dummy]

# uncomment next line for debugging
# assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if (
trajectories.estimator_outputs is not None
and not recalculate_all_logprobs
):
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
else:
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(
valid_actions.tensor
) # Using the actions sampled off-policy.
log_pf_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
)
log_pf_trajectories[~trajectories.actions.is_dummy] = valid_log_pf_actions

non_initial_valid_states = valid_states[~valid_states.is_initial_state]
non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:
# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)

log_pb_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
return get_trajectory_pfs_and_pbs(
self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs
)
log_pb_trajectories_slice = torch.full_like(
valid_actions.tensor[..., 0], fill_value=fill_value, dtype=torch.float
)
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
self,
@@ -265,7 +168,6 @@ def get_trajectories_scores(
Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
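With the refactor above, the body of get_pfs_and_pbs in base.py reduces to a single delegation to the new utility. Below is a minimal sketch of a caller built on that utility; score_sampled_trajectories and the estimator/container arguments are assumptions for illustration, while get_trajectory_pfs_and_pbs and its signature come from this diff.

from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs

def score_sampled_trajectories(pf_estimator, pb_estimator, trajectories):
    """Hypothetical caller: per-trajectory totals of log P_F and log P_B."""
    log_pfs, log_pbs = get_trajectory_pfs_and_pbs(
        pf_estimator,                    # forward-policy GFNModule
        pb_estimator,                    # backward-policy GFNModule
        trajectories,                    # forward Trajectories container
        fill_value=0.0,                  # value written at dummy/sink positions
        recalculate_all_logprobs=False,  # reuse cached log_probs when available
    )
    # Both tensors have shape (max_length, n_trajectories), per the helper's asserts,
    # so summing over the step dimension yields one total per trajectory.
    return log_pfs.sum(dim=0), log_pbs.sum(dim=0)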
67 changes: 21 additions & 46 deletions src/gfn/gflownet/detailed_balance.py
@@ -12,6 +12,7 @@
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_transition_pfs_and_pbs


def check_compatibility(states, actions, transitions):
@@ -78,6 +79,13 @@ def logF_parameters(self):
)
)

def get_pfs_and_pbs(
self, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
return get_transition_pfs_and_pbs(
self.pf, self.pb, transitions, recalculate_all_logprobs
)

def get_scores(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -101,70 +109,39 @@ def get_scores(
"""
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

states = transitions.states
actions = transitions.actions

# uncomment next line for debugging
# assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
check_compatibility(states, actions, transitions)

if has_log_probs(transitions) and not recalculate_all_logprobs:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions, with optional conditioning.
# TODO: Inefficient duplication in case of tempered policy
# The Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states, transitions.conditioning)
else:
with no_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states)

valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
).log_prob(actions.tensor)
log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
transitions, recalculate_all_logprobs
)

# LogF is potentially a conditional computation.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
else:
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states).squeeze(-1)
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states) # TODO: RM unsqueeze(-1) ?
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
valid_log_F_s = valid_log_F_s + log_rewards
log_F_s = log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
targets = torch.zeros_like(preds)
preds = log_pf_actions + log_F_s

# uncomment next line for debugging
# assert transitions.next_states.is_sink_state.equal(transitions.is_done)

# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_done]
non_exit_actions = actions[~actions.is_exit]

# Evaluate the log PB of the actions, with optional conditioning.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(
valid_next_states, transitions.conditioning[~transitions.is_done]
)
else:
with no_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(valid_next_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
valid_next_states, module_output
).log_prob(non_exit_actions.tensor)

valid_transitions_is_done = transitions.is_done[
~transitions.states.is_sink_state
]
@@ -179,23 +156,21 @@ def get_scores(
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)

targets[~valid_transitions_is_done] = valid_log_pb_actions
log_pb_actions = targets.clone()
targets[~valid_transitions_is_done] += valid_log_F_s_next
log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_done] = valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
targets[valid_transitions_is_done] = valid_transitions_log_rewards[
log_F_s_next[valid_transitions_is_done] = valid_transitions_log_rewards[
valid_transitions_is_done
]
targets = log_pb_actions + log_F_s_next

scores = preds - targets

assert valid_log_pf_actions.shape == (transitions.n_transitions,)
assert log_pb_actions.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
return valid_log_pf_actions, log_pb_actions, scores
return log_pf_actions, log_pb_actions, scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
"""Detailed balance loss.
2 changes: 1 addition & 1 deletion src/gfn/modules.py
@@ -352,7 +352,7 @@ def _forward_trunk(

return out

-def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
+def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor:
"""Forward pass of the module.
Args:
2 changes: 1 addition & 1 deletion src/gfn/utils/distributions.py
@@ -3,7 +3,7 @@


class UnsqueezedCategorical(Categorical):
"""Samples froma categorical distribution with an unsqueezed final dimension.
"""Samples from a categorical distribution with an unsqueezed final dimension.
Samples are unsqueezed to be of shape (batch_size, 1) instead of (batch_size,).
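The corrected docstring states the shape contract. A small illustration follows, assuming UnsqueezedCategorical is constructed like a standard torch.distributions.Categorical; the probabilities are arbitrary.

import torch
from gfn.utils.distributions import UnsqueezedCategorical

probs = torch.tensor([[0.25, 0.75], [0.5, 0.5]])  # arbitrary example probabilities
dist = UnsqueezedCategorical(probs=probs)
sample = dist.sample()
print(sample.shape)  # torch.Size([2, 1]) rather than torch.Size([2]), per the docstring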
232 changes: 232 additions & 0 deletions src/gfn/utils/prob_calculations.py
@@ -0,0 +1,232 @@
from typing import Tuple

import torch

from gfn.containers import Trajectories, Transitions
from gfn.modules import GFNModule
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)


def check_cond_forward(
module: GFNModule,
module_name: str,
states: States,
condition: torch.Tensor | None = None,
) -> torch.Tensor:
if condition is not None:
with has_conditioning_exception_handler(module_name, module):
return module(states, condition)
else:
with no_conditioning_exception_handler(module_name, module):
return module(states)


#########################
##### Trajectories #####
#########################


def get_trajectory_pfs_and_pbs(
pf: GFNModule,
pb: GFNModule,
trajectories: Trajectories,
fill_value: float = 0.0,
recalculate_all_logprobs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# fill value is the value used for invalid states (sink state usually)
if trajectories.is_backward:
raise ValueError("Backward trajectories are not supported")

# uncomment next line for debugging
# assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

log_pf_trajectories = get_trajectory_pfs(
pf,
trajectories,
fill_value=fill_value,
recalculate_all_logprobs=recalculate_all_logprobs,
)
log_pb_trajectories = get_trajectory_pbs(pb, trajectories, fill_value=fill_value)

assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)

return log_pf_trajectories, log_pb_trajectories


def get_trajectory_pfs(
pf: GFNModule,
trajectories: Trajectories,
fill_value: float = 0.0,
recalculate_all_logprobs: bool = False,
) -> torch.Tensor:
state_mask = ~trajectories.states.is_sink_state
action_mask = ~trajectories.actions.is_dummy

valid_states = trajectories.states[state_mask]
valid_actions = trajectories.actions[action_mask]

if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if trajectories.estimator_outputs is not None and not recalculate_all_logprobs:
estimator_outputs = trajectories.estimator_outputs[action_mask]
else:
masked_cond = None
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[state_mask]

estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = pf.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(
valid_actions.tensor
) # Using the actions sampled off-policy.

log_pf_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
)
log_pf_trajectories[action_mask] = valid_log_pf_actions

return log_pf_trajectories


def get_trajectory_pbs(
pb: GFNModule, trajectories: Trajectories, fill_value: float = 0.0
) -> torch.Tensor:
# Note the different mask for valid states and actions compared to the pf case.
state_mask = (
~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state
)
action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit

valid_states = trajectories.states[state_mask]
valid_actions = trajectories.actions[action_mask]

if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
masked_cond = None
if trajectories.conditioning is not None:
# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[state_mask]

estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond)

valid_log_pb_actions = pb.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(valid_actions.tensor)

log_pb_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
)
log_pb_trajectories[action_mask] = valid_log_pb_actions

return log_pb_trajectories


########################
##### Transitions #####
########################


def get_transition_pfs_and_pbs(
pf: GFNModule,
pb: GFNModule,
transitions: Transitions,
recalculate_all_logprobs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

log_pf_transitions = get_transition_pfs(pf, transitions, recalculate_all_logprobs)
log_pb_transitions = get_transition_pbs(pb, transitions)

assert log_pf_transitions.shape == (transitions.n_transitions,)
assert log_pb_transitions.shape == (transitions.n_transitions,)

return log_pf_transitions, log_pb_transitions


def get_transition_pfs(
pf: GFNModule, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> torch.Tensor:
states = transitions.states
actions = transitions.actions

if has_log_probs(transitions) and not recalculate_all_logprobs:
log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions, with optional conditioning.
# TODO: Inefficient duplication in case of tempered policy
# The Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
estimator_outputs = check_cond_forward(
pf, "pf", states, transitions.conditioning
)

log_pf_actions = pf.to_probability_distribution(
states, estimator_outputs
).log_prob(actions.tensor)

return log_pf_actions


def get_transition_pbs(pb: GFNModule, transitions: Transitions) -> torch.Tensor:
# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_done]
non_exit_actions = transitions.actions[~transitions.actions.is_exit]

# Evaluate the log PB of the actions, with optional conditioning.
masked_cond = (
transitions.conditioning[~transitions.is_done]
if transitions.conditioning is not None
else None
)
estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond)

valid_log_pb_actions = pb.to_probability_distribution(
valid_next_states, estimator_outputs
).log_prob(non_exit_actions.tensor)

# Evaluate the log PB of the actions.
log_pb_actions = torch.zeros(
(transitions.n_transitions,),
dtype=torch.float,
device=valid_log_pb_actions.device,
)

log_pb_actions[~transitions.is_done] = valid_log_pb_actions

return log_pb_actions
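Taken together, the new module splits the computation into per-object helpers that the losses above delegate to. A hedged composition sketch follows; only the imported names and their signatures come from this file, and recompute_everything with its arguments is assumed for illustration.

from gfn.utils.prob_calculations import (
    get_trajectory_pfs,
    get_trajectory_pbs,
    get_transition_pfs_and_pbs,
)

def recompute_everything(pf_estimator, pb_estimator, trajectories, transitions):
    """Hypothetical helper forcing a fresh evaluation of every log-probability."""
    # Forward per-step log-probs, ignoring any cached log_probs or estimator_outputs.
    log_pf = get_trajectory_pfs(pf_estimator, trajectories, recalculate_all_logprobs=True)
    # Backward per-step log-probs; the pf-style caching path does not exist for pb.
    log_pb = get_trajectory_pbs(pb_estimator, trajectories)
    # Transition-level variants used by the detailed-balance loss.
    log_pf_t, log_pb_t = get_transition_pfs_and_pbs(
        pf_estimator, pb_estimator, transitions, recalculate_all_logprobs=True
    )
    return log_pf, log_pb, log_pf_t, log_pb_t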
