Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow AutoSemiDAIS to work without global variable #1665

Merged
Merged 6 commits on Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 98 additions & 39 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,13 +1250,16 @@ def local_model(theta):
during partial momentum refreshments in HMC. Defaults to 0.9.
:param float init_scale: Initial scale for the standard deviation of the variational
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1.
:param str subsample_plate: Optional name of the subsample plate site. This is required
    when the model has a subsample plate without `subsample_size` specified (like in VAE settings).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just revised it to clarify that this is required when the model has a subsample plate without subsample_size specified.

:param bool use_global_dais_params: Whether to use global parameters for DAIS dynamics.
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
model,
local_model,
global_guide,
global_guide=None,
local_guide=None,
*,
prefix="auto",
Expand All @@ -1265,6 +1268,8 @@ def __init__(
eta_max=0.1,
gamma_init=0.9,
init_scale=0.1,
subsample_plate=None,
use_global_dais_params=False,
):
# init_loc_fn is only used to inspect the model.
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform)
Expand All @@ -1289,6 +1294,8 @@ def __init__(
self.gamma_init = gamma_init
self.K = K
self.init_scale = init_scale
self.subsample_plate = subsample_plate
self.use_global_dais_params = use_global_dais_params

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
Expand All @@ -1301,6 +1308,10 @@ def _setup_prototype(self, *args, **kwargs):
and isinstance(site["args"][1], int)
and site["args"][0] > site["args"][1]
}
if self.subsample_plate is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm a bit confused by these args/checks. afaik we should support the following cases:

  • there is a single plate but there is no subsampling
  • there is a single plate and it is subsampled

any other scenario (e.g. 0 plates or > 1 plates) is not supported. is that right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC we allow multiple plates but only one subsample plate. I just added an elif not subsample_plates: ... branch to cover the case you mentioned above:

  • there is a single plate but there is no subsampling

subsample_plates[self.subsample_plate] = self.prototype_trace[
self.subsample_plate
]
num_plates = len(subsample_plates)
assert (
num_plates == 1
Expand Down Expand Up @@ -1344,6 +1355,8 @@ def _setup_prototype(self, *args, **kwargs):
UnpackTransform(unpack_latent), out_axes=subsample_axes
)
plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"]
if plate_subsample_size is None:
plate_subsample_size = plate_full_size
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

Expand Down Expand Up @@ -1451,37 +1464,68 @@ def fn(x):
D, K = self._local_latent_dim, self.K

with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
jnp.ones(N) * self.eta_init,
constraint=constraints.interval(0, self.eta_max),
event_dim=0,
)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0
)
if self.use_global_dais_params:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
self.eta_init,
constraint=constraints.interval(0, self.eta_max),
)
eta0 = jnp.broadcast_to(eta0, idx.shape)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix),
0.0,
)
eta_coeff = jnp.broadcast_to(eta_coeff, idx.shape)
gamma = numpyro.param(
"{}_gamma".format(self.prefix),
0.9,
constraint=constraints.interval(0, 1),
)
gamma = jnp.broadcast_to(gamma, idx.shape)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones(K),
constraint=constraints.positive,
)
betas = jnp.broadcast_to(betas, idx.shape + (K,))
mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones(D),
constraint=constraints.positive,
)
mass_matrix = jnp.broadcast_to(mass_matrix, idx.shape + (D,))
else:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
jnp.ones(N) * self.eta_init,
constraint=constraints.interval(0, self.eta_max),
event_dim=0,
)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0
)
gamma = numpyro.param(
"{}_gamma".format(self.prefix),
jnp.ones(N) * 0.9,
constraint=constraints.interval(0, 1),
event_dim=0,
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones((N, K)),
constraint=constraints.positive,
event_dim=1,
)
mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones((N, D)),
constraint=constraints.positive,
event_dim=1,
)

gamma = numpyro.param(
"{}_gamma".format(self.prefix),
jnp.ones(N) * 0.9,
constraint=constraints.interval(0, 1),
event_dim=0,
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones((N, K)),
constraint=constraints.positive,
event_dim=1,
)
betas = jnp.cumsum(betas, axis=-1)
betas = betas / betas[..., -1:]

mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones((N, D)),
constraint=constraints.positive,
event_dim=1,
)
inv_mass_matrix = 0.5 / mass_matrix
assert inv_mass_matrix.shape == (subsample_size, D)

Expand Down Expand Up @@ -1527,17 +1571,32 @@ def base_z_dist_log_prob(z):
base_z_dist_log_prob(z_0) / subsample_size,
)
else:
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
if self.use_global_dais_params:
z_0_loc_init = jnp.zeros(D)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think this part makes sense. the z params should always be local.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I think amortized guide should have local_guide specified.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what if i just want non-amortized mean-field variational distributions for the locals? i would need to specify local_guide as opposed to relying on a convenient default behavior?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in that case, you might want to set use_global_dais_params=False. We might also change the semantics of this flag to global_dais_params=None/"dynamic"/"full". wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what exactly are the three options?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None is the current master behavior, or False in this PR. dynamic is your request. full is this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for being slow here (traveling). i think use_global_dais_params which controls e.g. betas should be entirely separate from whatever controls the distribution over z_0. what about the following behavior:

  • if local_guide is provided use that.
  • otherwise if local_guide=None and there exist local variables instantiate a auto mean-field guide?

Copy link
Member Author

@fehiepsi fehiepsi Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that will imply that if users want to have global params for the base dist, they will need to use local_guide? I don't have a preference here. To me, base params play a similar role as betas; the dynamic will depend on model density etc. In the DAIS paper, the author even uses a fixed base dist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think there's any reason in practice why you'd really want global/shared params for the base dist q(z_0). having global params for e.g. beta makes sense because it's basically a "higher order quantity" and so harder to estimate and probably varies less from data point to data point. just in the same way that we'd probably generally be more comfortable in sharing scales/variances across data points than locs/means.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense to me, thanks!

z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix),
z_0_loc_init,
)
z_0_loc = jnp.broadcast_to(z_0_loc, idx.shape + (D,))
z_0_scale_init = jnp.ones(D) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
)
z_0_scale = jnp.broadcast_to(z_0_scale, idx.shape + (D,))
else:
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1)
assert base_z_dist.shape() == (subsample_size, D)
z_0 = numpyro.sample(
Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,27 @@ def model():
)
assert guide_samples["x"].shape == sample_shape + shape
assert guide_samples["x2"].shape == sample_shape + shape


@pytest.mark.parametrize("use_global_dais_params", [True, False])
def test_dais_vae(use_global_dais_params):
    # Toy model: a single plate of 10 latent sites whose true means are
    # -5..4.  The plate has no subsample_size, mimicking a VAE-style setup.
    def model():
        with numpyro.plate("N", 10):
            numpyro.sample("x", dist.Normal(jnp.arange(-5, 5), 2))

    # The model doubles as its own local model; subsample_plate must be named
    # explicitly because the plate is not subsampled.
    guide = AutoSemiDAIS(
        model, model, subsample_plate="N", use_global_dais_params=use_global_dais_params
    )
    svi = SVI(model, guide, optax.adam(0.02), Trace_ELBO())
    svi_results = svi.run(random.PRNGKey(0), 3000)
    samples = guide.sample_posterior(
        random.PRNGKey(1), svi_results.params, sample_shape=(1000,)
    )
    expected = jnp.arange(-5, 5)
    if use_global_dais_params:
        # Shared DAIS parameters: only the overall posterior mean is checked,
        # since per-site parameters are tied across the plate.
        assert_allclose(samples["x"].mean(), expected.mean(), atol=0.1, rtol=0.1)
    else:
        # Per-datapoint parameters: each site should recover its own mean.
        assert_allclose(samples["x"].mean(axis=0), expected, atol=0.2, rtol=0.1)