-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow AutoSemiDAIS to work without global variable #1665
Changes from 3 commits
896dc96
d211a73
84d8a7e
4f5d2d5
370e139
0d4451a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1250,13 +1250,22 @@ def local_model(theta): | |
during partial momentum refreshments in HMC. Defaults to 0.9. | ||
:param float init_scale: Initial scale for the standard deviation of the variational | ||
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1. | ||
:param str subsample_plate: Optional name of the subsample plate site. This is required | ||
when the model has a subsample plate without `subsample_size` specified or | ||
the model has a subsample plate with `subsample_size` equal to the plate size. | ||
:param bool use_global_dais_params: Whether parameters controlling DAIS dynamics | ||
(HMC step size, HMC mass matrix, etc.) should be global (i.e. common to all | ||
data points in the subsample plate) or local (i.e. each data point in the | ||
subsample plate has individual parameters). Note that if `local_guide` is None | ||
and this argument is True, we also use global parameters for the base | ||
distribution. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model, | ||
local_model, | ||
global_guide, | ||
global_guide=None, | ||
local_guide=None, | ||
*, | ||
prefix="auto", | ||
|
@@ -1265,6 +1274,8 @@ def __init__( | |
eta_max=0.1, | ||
gamma_init=0.9, | ||
init_scale=0.1, | ||
subsample_plate=None, | ||
use_global_dais_params=False, | ||
): | ||
# init_loc_fn is only used to inspect the model. | ||
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform) | ||
|
@@ -1289,6 +1300,8 @@ def __init__( | |
self.gamma_init = gamma_init | ||
self.K = K | ||
self.init_scale = init_scale | ||
self.subsample_plate = subsample_plate | ||
self.use_global_dais_params = use_global_dais_params | ||
|
||
def _setup_prototype(self, *args, **kwargs): | ||
super()._setup_prototype(*args, **kwargs) | ||
|
@@ -1301,6 +1314,17 @@ def _setup_prototype(self, *args, **kwargs): | |
and isinstance(site["args"][1], int) | ||
and site["args"][0] > site["args"][1] | ||
} | ||
if self.subsample_plate is not None: | ||
subsample_plates[self.subsample_plate] = self.prototype_trace[ | ||
self.subsample_plate | ||
] | ||
elif not subsample_plates: | ||
# Consider all plates as subsample plates. | ||
subsample_plates = { | ||
name: site | ||
for name, site in self.prototype_trace.items() | ||
if site["type"] == "plate" | ||
} | ||
num_plates = len(subsample_plates) | ||
assert ( | ||
num_plates == 1 | ||
|
@@ -1344,6 +1368,8 @@ def _setup_prototype(self, *args, **kwargs): | |
UnpackTransform(unpack_latent), out_axes=subsample_axes | ||
) | ||
plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"] | ||
if plate_subsample_size is None: | ||
plate_subsample_size = plate_full_size | ||
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size | ||
self._local_plate = (plate_name, plate_full_size, plate_subsample_size) | ||
|
||
|
@@ -1451,37 +1477,68 @@ def fn(x): | |
D, K = self._local_latent_dim, self.K | ||
|
||
with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
jnp.ones(N) * self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
event_dim=0, | ||
) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 | ||
) | ||
if self.use_global_dais_params: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
) | ||
eta0 = jnp.broadcast_to(eta0, idx.shape) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), | ||
0.0, | ||
) | ||
eta_coeff = jnp.broadcast_to(eta_coeff, idx.shape) | ||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
0.9, | ||
constraint=constraints.interval(0, 1), | ||
) | ||
gamma = jnp.broadcast_to(gamma, idx.shape) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones(K), | ||
constraint=constraints.positive, | ||
) | ||
betas = jnp.broadcast_to(betas, idx.shape + (K,)) | ||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones(D), | ||
constraint=constraints.positive, | ||
) | ||
mass_matrix = jnp.broadcast_to(mass_matrix, idx.shape + (D,)) | ||
else: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
jnp.ones(N) * self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
event_dim=0, | ||
) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 | ||
) | ||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
jnp.ones(N) * 0.9, | ||
constraint=constraints.interval(0, 1), | ||
event_dim=0, | ||
) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones((N, K)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones((N, D)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
|
||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
jnp.ones(N) * 0.9, | ||
constraint=constraints.interval(0, 1), | ||
event_dim=0, | ||
) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones((N, K)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
betas = jnp.cumsum(betas, axis=-1) | ||
betas = betas / betas[..., -1:] | ||
|
||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones((N, D)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
inv_mass_matrix = 0.5 / mass_matrix | ||
assert inv_mass_matrix.shape == (subsample_size, D) | ||
|
||
|
@@ -1527,17 +1584,32 @@ def base_z_dist_log_prob(z): | |
base_z_dist_log_prob(z_0) / subsample_size, | ||
) | ||
else: | ||
z_0_loc_init = jnp.zeros((N, D)) | ||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 | ||
) | ||
z_0_scale_init = jnp.ones((N, D)) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
if self.use_global_dais_params: | ||
z_0_loc_init = jnp.zeros(D) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i don't think this part makes sense. the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. hmm, I think amortized guide should have There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. but what if i just want non-amortized mean-field variational distributions for the locals? i would need to specify There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, in that case, you might want to set There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. what exactly are the three options? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. None is the current master behavior, or False in this PR. dynamic is your request. full is this PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sorry for being slow here (traveling). i think
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, that will imply that if users want to have global params for the base dist, they will need to use local_guide? I don't have a preference here. To me, base params play a similar role as betas; the dynamics will depend on model density etc. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i don't think there's any reason in practice why you'd really want global/shared params for the base dist q(z_0). having global params for e.g. beta makes sense because it's basically a "higher order quantity" and so harder to estimate and probably varies less from data point to data point. just in the same way that we'd probably generally be more comfortable in sharing scales/variances across data points than locs/means. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. makes sense to me, thanks! |
||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), | ||
z_0_loc_init, | ||
) | ||
z_0_loc = jnp.broadcast_to(z_0_loc, idx.shape + (D,)) | ||
z_0_scale_init = jnp.ones(D) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
) | ||
z_0_scale = jnp.broadcast_to(z_0_scale, idx.shape + (D,)) | ||
else: | ||
z_0_loc_init = jnp.zeros((N, D)) | ||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 | ||
) | ||
z_0_scale_init = jnp.ones((N, D)) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1) | ||
assert base_z_dist.shape() == (subsample_size, D) | ||
z_0 = numpyro.sample( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm a bit confused by these args/checks. afaik we should support the following cases:
any other scenario (e.g. 0 plates or > 1 plates) is not supported. is that right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC we allow multiple plates but only one subsample plate. I just added an `elif not subsample_plates: ...` branch to cover the case you mentioned above: