
Error from implementing MaskedCouplingInferenceFunnel #46

Open
Jice-Zeng opened this issue Aug 15, 2024 · 2 comments
@Jice-Zeng

I was trying to implement an affine (surjective) masked coupling flow, with a bijective layer of MaskedCoupling and a surjective layer of MaskedCouplingInferenceFunnel. I named the function make_mcf; it is similar to make_spf, except that I use ScalarAffine as the bijective function.

def make_mcf(
    n_dimension: int,
    n_layers: Optional[int] = 5,
    n_layer_dimensions: Optional[Iterable[int]] = None,
    hidden_sizes: Iterable[int] = (64, 64),
    activation: Callable = jax.nn.tanh,
) -> hk.Transformed:
 
    if isinstance(n_layers, int) and n_layer_dimensions is not None:
        assert n_layers == len(list(n_layer_dimensions))
    elif isinstance(n_layers, int):
        n_layer_dimensions = [n_dimension] * n_layers

    return _make_mcf(
        n_dimension=n_dimension,
        n_layer_dimensions=n_layer_dimensions,
        hidden_sizes=hidden_sizes,
        activation=activation,
    )

def _make_mcf(
    n_dimension,
    n_layer_dimensions,
    hidden_sizes,
    activation,
):

    def _bijector_fn(params):
        shift, log_scale = jnp.split(params, 2, axis=-1)
        return distrax.ScalarAffine(shift, jnp.exp(log_scale))

    def _decoder_fn(dims):
        def fn(z):
            params = surjectors_mlp(dims, activation=activation)(z)
            mu, log_scale = jnp.split(params, 2, -1)
            return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
        return fn
        
    def _conditioner(n_dim):
        return hk.Sequential(
            [
                surjectors_mlp(
                    list(hidden_sizes) + [2 * n_dim],
                    activation=activation,
                ),
                hk.Reshape((n_dim, 2)),
            ]
        )

    @hk.transform
    def _flow(method, **kwargs):
        layers = []
        order = jnp.arange(n_dimension)
        curr_dim = n_dimension
        for i, n_dim_curr_layer in enumerate(n_layer_dimensions):
            # layer is dimensionality preserving
            if n_dim_curr_layer == curr_dim:
                layer = MaskedCoupling(
                    mask=make_alternating_binary_mask(curr_dim, i % 2 == 0),
                    conditioner=_conditioner(curr_dim),
                    bijector_fn=_bijector_fn,
                )
                
                order = order[::-1]
                
            elif n_dim_curr_layer < curr_dim:

                n_latent = n_dim_curr_layer
                layer = MaskedCouplingInferenceFunnel(
                    n_keep=n_latent,
                    decoder=_decoder_fn(
                        list(hidden_sizes) + [2 * (curr_dim - n_latent)]
                    ),
                    conditioner=surjectors_mlp(
                        list(hidden_sizes) + [2 * curr_dim],
                        activation=activation,
                    ),
                    bijector_fn=_bijector_fn,
                )
                curr_dim = n_latent
                
                order = order[::-1]
                order = order[:curr_dim] - jnp.min(order[:curr_dim])
            else:
                raise ValueError(
                    f"n_dimension at layer {i} is layer than the dimension of"
                    f" the following layer {i + 1}"
                )
            layers.append(layer)
            layers.append(Permutation(order, 1))
        chain = Chain(layers[:-1])  # drop the trailing permutation

        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(curr_dim), jnp.ones(curr_dim)),
            1,
        )
        td = TransformedDistribution(base_distribution, chain)
        return td(method, **kwargs)

    return _flow

This raised an error: the shapes (100, 8, 1) and (100, 8) are not compatible. So I changed

    def _bijector_fn(params):
        shift, log_scale = jnp.split(params, 2, axis=-1)
        return distrax.ScalarAffine(shift, jnp.exp(log_scale))

to

    def _bijector_fn(params):
        shift, log_scale = jnp.split(params, 2, axis=-1)
        if shift.shape[-1] == 1:
            shift = jnp.squeeze(shift, axis=-1)
        if log_scale.shape[-1] == 1:
            log_scale = jnp.squeeze(log_scale, axis=-1)
        return distrax.ScalarAffine(shift, jnp.exp(log_scale))
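The mismatch comes from the conditioner ending in hk.Reshape((n_dim, 2)): splitting that (..., n_dim, 2) output on the last axis leaves a trailing singleton axis, which jnp.squeeze removes. A minimal reproduction, assuming a hypothetical batch of 100 and event dimension 8:

```python
import jax.numpy as jnp

# Hypothetical conditioner output, shaped by hk.Reshape((n_dim, 2)):
# batch of 100, event dimension 8, two parameters (shift, log_scale) per dim.
params = jnp.zeros((100, 8, 2))

# Splitting on the last axis keeps a trailing singleton axis ...
shift, log_scale = jnp.split(params, 2, axis=-1)
# shift.shape == (100, 8, 1), which does not broadcast against (100, 8) inputs

# ... which jnp.squeeze removes, matching the (100, 8) event shape.
shift = jnp.squeeze(shift, axis=-1)
log_scale = jnp.squeeze(log_scale, axis=-1)
```

An alternative that avoids the squeeze entirely would be to drop hk.Reshape from the conditioner and split its flat (..., 2 * n_dim) output directly, as the funnel's conditioner already does.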

Now it works. I am not sure this is the best way to resolve the issue, though. If you know a better way, please let me know.
Thanks!

@Jice-Zeng Jice-Zeng changed the title Error implementing MaskedCouplingInferenceFunnel Error from implementing MaskedCouplingInferenceFunnel Aug 15, 2024
@dirmeier
Owner

Hello, I think the construction is not right. Is this what you are looking for: https://github.com/dirmeier/surjectors/blob/main/examples/coupling_inference_surjection.py?

@Jice-Zeng
Author

> Hello, I think the construction is not right. Is this what you are looking for: https://github.com/dirmeier/surjectors/blob/main/examples/coupling_inference_surjection.py?

Actually, this implementation is based on the function make_spf, which uses a spline network as the bijective function together with MaskedCouplingInferenceFunnel. I modified make_spf by changing the spline network to ScalarAffine.
