-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow AutoSemiDAIS to work without global variable #1665
Changes from 1 commit
896dc96
d211a73
84d8a7e
4f5d2d5
370e139
0d4451a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1250,13 +1250,16 @@ def local_model(theta): | |
during partial momentum refreshments in HMC. Defaults to 0.9. | ||
:param float init_scale: Initial scale for the standard deviation of the variational | ||
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1. | ||
:param str subsample_plate: Optional name of the subsample plate site. This is required | ||
when the model has a subsample plate without `subsample_size` specified (like in VAE settings). | ||
:param bool use_global_dais_params: Whether to share a single set of DAIS dynamics parameters across all data points instead of learning per-data-point ones. Defaults to False. | ||
fehiepsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
model, | ||
local_model, | ||
global_guide, | ||
global_guide=None, | ||
local_guide=None, | ||
*, | ||
prefix="auto", | ||
|
@@ -1265,6 +1268,8 @@ def __init__( | |
eta_max=0.1, | ||
gamma_init=0.9, | ||
init_scale=0.1, | ||
subsample_plate=None, | ||
use_global_dais_params=False, | ||
): | ||
# init_loc_fn is only used to inspect the model. | ||
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform) | ||
|
@@ -1289,6 +1294,8 @@ def __init__( | |
self.gamma_init = gamma_init | ||
self.K = K | ||
self.init_scale = init_scale | ||
self.subsample_plate = subsample_plate | ||
self.use_global_dais_params = use_global_dais_params | ||
|
||
def _setup_prototype(self, *args, **kwargs): | ||
super()._setup_prototype(*args, **kwargs) | ||
|
@@ -1301,6 +1308,10 @@ def _setup_prototype(self, *args, **kwargs): | |
and isinstance(site["args"][1], int) | ||
and site["args"][0] > site["args"][1] | ||
} | ||
if self.subsample_plate is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm a bit confused by these args/checks. AFAIK we should support the following cases:
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IIRC we allow multiple plates but only one subsample plate. I just added an
|
||
subsample_plates[self.subsample_plate] = self.prototype_trace[ | ||
self.subsample_plate | ||
] | ||
num_plates = len(subsample_plates) | ||
assert ( | ||
num_plates == 1 | ||
|
@@ -1344,6 +1355,8 @@ def _setup_prototype(self, *args, **kwargs): | |
UnpackTransform(unpack_latent), out_axes=subsample_axes | ||
) | ||
plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"] | ||
if plate_subsample_size is None: | ||
plate_subsample_size = plate_full_size | ||
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size | ||
self._local_plate = (plate_name, plate_full_size, plate_subsample_size) | ||
|
||
|
@@ -1451,37 +1464,68 @@ def fn(x): | |
D, K = self._local_latent_dim, self.K | ||
|
||
with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
jnp.ones(N) * self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
event_dim=0, | ||
) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 | ||
) | ||
if self.use_global_dais_params: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
) | ||
eta0 = jnp.broadcast_to(eta0, idx.shape) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), | ||
0.0, | ||
) | ||
eta_coeff = jnp.broadcast_to(eta_coeff, idx.shape) | ||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
0.9, | ||
constraint=constraints.interval(0, 1), | ||
) | ||
gamma = jnp.broadcast_to(gamma, idx.shape) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones(K), | ||
constraint=constraints.positive, | ||
) | ||
betas = jnp.broadcast_to(betas, idx.shape + (K,)) | ||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones(D), | ||
constraint=constraints.positive, | ||
) | ||
mass_matrix = jnp.broadcast_to(mass_matrix, idx.shape + (D,)) | ||
else: | ||
eta0 = numpyro.param( | ||
"{}_eta0".format(self.prefix), | ||
jnp.ones(N) * self.eta_init, | ||
constraint=constraints.interval(0, self.eta_max), | ||
event_dim=0, | ||
) | ||
eta_coeff = numpyro.param( | ||
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 | ||
) | ||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
jnp.ones(N) * 0.9, | ||
constraint=constraints.interval(0, 1), | ||
event_dim=0, | ||
) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones((N, K)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones((N, D)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
|
||
gamma = numpyro.param( | ||
"{}_gamma".format(self.prefix), | ||
jnp.ones(N) * 0.9, | ||
constraint=constraints.interval(0, 1), | ||
event_dim=0, | ||
) | ||
betas = numpyro.param( | ||
"{}_beta_increments".format(self.prefix), | ||
jnp.ones((N, K)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
betas = jnp.cumsum(betas, axis=-1) | ||
betas = betas / betas[..., -1:] | ||
|
||
mass_matrix = numpyro.param( | ||
"{}_mass_matrix".format(self.prefix), | ||
jnp.ones((N, D)), | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
inv_mass_matrix = 0.5 / mass_matrix | ||
assert inv_mass_matrix.shape == (subsample_size, D) | ||
|
||
|
@@ -1527,17 +1571,32 @@ def base_z_dist_log_prob(z): | |
base_z_dist_log_prob(z_0) / subsample_size, | ||
) | ||
else: | ||
z_0_loc_init = jnp.zeros((N, D)) | ||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 | ||
) | ||
z_0_scale_init = jnp.ones((N, D)) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
if self.use_global_dais_params: | ||
z_0_loc_init = jnp.zeros(D) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think this part makes sense. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, I think amortized guide should have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but what if i just want non-amortized mean-field variational distributions for the locals? i would need to specify There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in that case, you might want to set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what exactly are the three options? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. None is the current master behavior, or False in this PR. dynamic is your request. full is this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry for being slow here (traveling). i think
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, that will imply that if users want to have global params for the base dist, they will need to use local_guide? I don't have a preference here. To me, base params play a similar role as betas; the dynamics will depend on model density etc. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think there's any reason in practice why you'd really want global/shared params for the base dist q(z_0). Having global params for e.g. beta makes sense because it's basically a "higher order quantity" and so harder to estimate and probably varies less from data point to data point, just in the same way that we'd probably generally be more comfortable sharing scales/variances across data points than locs/means. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense to me, thanks! |
||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), | ||
z_0_loc_init, | ||
) | ||
z_0_loc = jnp.broadcast_to(z_0_loc, idx.shape + (D,)) | ||
z_0_scale_init = jnp.ones(D) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
) | ||
z_0_scale = jnp.broadcast_to(z_0_scale, idx.shape + (D,)) | ||
else: | ||
z_0_loc_init = jnp.zeros((N, D)) | ||
z_0_loc = numpyro.param( | ||
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 | ||
) | ||
z_0_scale_init = jnp.ones((N, D)) * self.init_scale | ||
z_0_scale = numpyro.param( | ||
"{}_z_0_scale".format(self.prefix), | ||
z_0_scale_init, | ||
constraint=constraints.positive, | ||
event_dim=1, | ||
) | ||
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1) | ||
assert base_z_dist.shape() == (subsample_size, D) | ||
z_0 = numpyro.sample( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just revised it to clarify that this is required when the model has a subsample plate without `subsample_size` specified.