-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor log probability calculations into separate utility functions #216
Conversation
…into util functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an awesome PR, I think having these functions factored out will come in handy down the line as well for novel methods which might calculate these elements slightly differently. I think we should rename the function calls but this is a minor nit. I'll also wait for @younik before approving.
@@ -78,6 +79,13 @@ def logF_parameters(self): | |||
) | |||
) | |||
|
|||
def get_pfs_and_pbs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, I misunderstood. I quite like this factorization.
valid_log_pf_actions = self.pf.to_probability_distribution( | ||
states, module_output | ||
).log_prob(actions.tensor) | ||
log_pf_actions, log_pb_actions = self.get_pfs_and_pbs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I really like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good, it makes code more readable!
Thank you :)
src/gfn/gflownet/detailed_balance.py
Outdated
log_pb_actions = targets.clone() | ||
targets[~valid_transitions_is_done] += valid_log_F_s_next | ||
log_F_s_next = torch.zeros_like(log_pb_actions) | ||
log_F_s_next[~valid_transitions_is_done] += valid_log_F_s_next |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with "=" is more clear
See the issue #211.
To address the need for calculating log probabilities in
Sampler
, we decided to separate the logic for computinglog_probs
from theGFlowNet
classes and instead create utility functions for these calculations.Naming (of variables, files, etc.) suggestions are welcomed!