-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automate distribution testing #389
Changes from all commits
0f904b9
3ad5c79
4fbc801
ecad0ee
0c2c0fb
56ace75
bace658
c38d560
040f417
4f40aaf
fb75a15
e6dcabd
8be24c6
50f0343
6f69e62
a06744e
e01de62
5a8fd42
127b12c
80750b0
3ceadfd
151ca2b
06f6d06
d1e8af5
39ced3a
35d283e
34919a4
bc0a518
81f8a37
bdff5d4
3d2ff83
f46b22b
e35ad5c
44e7783
4de34b6
4f32686
a0dfc66
ce6465a
c408c71
7227e19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import functools | ||
import numbers | ||
from typing import Tuple, Union | ||
|
||
import pyro.distributions as dist | ||
|
@@ -40,7 +41,7 @@ | |
from funsor.domains import Real, Reals | ||
import funsor.ops as ops | ||
from funsor.tensor import Tensor, dummy_numeric_array | ||
from funsor.terms import Binary, Funsor, Variable, eager, to_funsor | ||
from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor | ||
from funsor.util import methodof | ||
|
||
|
||
|
@@ -153,6 +154,19 @@ def _infer_param_domain(cls, name, raw_shape): | |
return Real | ||
|
||
|
||
########################################################### | ||
# Converting distribution funsors to PyTorch distributions | ||
########################################################### | ||
|
||
@to_data.register(Multinomial) # noqa: F821 | ||
def multinomial_to_data(funsor_dist, name_to_dim=None): | ||
probs = to_data(funsor_dist.probs, name_to_dim) | ||
total_count = to_data(funsor_dist.total_count, name_to_dim) | ||
if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to worry about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made this change because This is also a good reminder to add some generic JIT tests for distribution wrappers in a followup PR. |
||
return dist.Multinomial(int(total_count), probs=probs) | ||
raise NotImplementedError("inhomogeneous total_count not supported") | ||
|
||
|
||
############################################### | ||
# Converting PyTorch Distributions to funsors | ||
############################################### | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to update
funsor.distribution.Distribution.unscaled_sample
to useto_funsor
andto_data
throughout to get some tests to pass.