
Problem conditioning on vmap Poisson random variables #96

Open
SamuelBrand1 opened this issue Feb 3, 2025 · 0 comments

Hi everyone,

I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.

Poisson process with exp random walk intensity

As an attempt to get into the structure of oryx, I'm trying to sample from a probabilistic program which represents:

  1. A hierarchical random walk (that is, a random walk whose parameters are themselves random variables).
  2. A further Exp transform on the random walk, representing the intensity of a Poisson process. This is observed.
  3. Inference done with NUTS from blackjax.

Code

Dependencies

import jax
# jax.config.update("jax_enable_x64", True)

import oryx.core.ppl as ppl
import oryx.bijectors as bijectors
import oryx.distributions as tfd
import blackjax


import jax.numpy as jnp
import jax.random as random

from functools import partial
import matplotlib.pyplot as plt

Probabilistic program
Note that I've implemented the observation model as a vmap over the intensity, representing conditional independence of the observations.

@partial(jax.jit, static_argnames=["n"])
def hierarchical_random_walk_dist(n, init, step_scale):
    # Chain applies right-to-left: cumsum the iid normals, scale the walk, then shift by init
    rw_transformation = bijectors.Chain([bijectors.Shift(init), bijectors.Scale(step_scale), bijectors.Cumsum()])
    return tfd.TransformedDistribution(tfd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n)), rw_transformation)

def poisson_process(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    key_poi, key_intensity, key_init, key_step = random.split(key, 4)
    # Hyperpriors on the random walk's initial value and step scale
    init = ppl.random_variable(tfd.Normal(init_prior_loc, init_prior_scale), name="init")(key_init)
    step_scale = ppl.random_variable(tfd.HalfNormal(step_scale_prior), name="step_scale")(key_step)
    # Exp transform of the random walk gives a positive Poisson intensity
    intensity = ppl.random_variable(tfd.TransformedDistribution(
        hierarchical_random_walk_dist(n, init, step_scale),
        bijectors.Exp()),
        name="intensity")(key_intensity)
    # Observations: one Poisson draw per time point, vmapped over the intensity
    poi_keys = random.split(key_poi, n)
    poi = jax.vmap(lambda ky, x: ppl.random_variable(tfd.Poisson(x), name="poi")(ky))(poi_keys, intensity)
    return poi
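
As an aside, an alternative I considered (untested, my own suggestion; I haven't checked how oryx tags it) would be to drop the vmap and push the conditional independence into the distribution itself with tfd.Independent, so all counts live under a single "poi" name:

# Alternative sketch, not what I ran: replace the last two lines of
# poisson_process, tagging all n counts under one "poi" name.
poi = ppl.random_variable(
    tfd.Independent(tfd.Poisson(intensity), reinterpreted_batch_ndims=1),
    name="poi")(key_poi)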

Sample some data from the model

sampler = ppl.joint_sample(poisson_process)
key_rn = random.PRNGKey(1234)
true_params = sampler(key_rn, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

plt.plot(true_params['intensity'])
plt.scatter(range(len(true_params['poi'])), true_params['poi'], color='red')
plt.xlabel('time')
plt.ylabel('Intensity')
plt.title('Intensity Random Variable')
plt.show()

[Image: plot of the sampled intensity over time, with the Poisson observations in red]

This looks reasonable.

Inference

I split out the observed data from the rest of the parameters:

true_data = true_params.pop('poi')
true_params

{'init': Array(0.01620068, dtype=float32),
'intensity': Array([2.21524286e+00, 1.02979910e+00, 3.47778827e-01, 1.59668833e-01,
1.43431634e-01, 1.89372957e-01, 1.48013055e-01, 8.88996869e-02,
1.40973523e-01, 1.33820206e-01, 6.49942383e-02, 4.85427566e-02,
1.31719252e-02, 1.61807686e-02, 1.93237811e-02, 7.30752666e-03,
3.75548634e-03, 3.23897717e-03, 4.46824962e-03, 3.59713589e-03,
3.48433293e-03, 4.54167370e-03, 8.35305359e-03, 7.45324651e-03,
1.69865470e-02, 2.82925181e-03, 4.80814092e-03, 5.73506765e-03,
1.29247606e-02, 2.23501474e-02, 2.60949116e-02, 2.78504174e-02,
2.95239929e-02, 3.01535334e-02, 3.17600109e-02, 5.96645549e-02,
1.58876508e-01, 4.15319920e-01, 2.83102959e-01, 3.94434422e-01,
6.21528685e-01, 9.56910610e-01, 4.71480668e-01, 3.51778269e-01,
3.12051624e-01, 3.87135684e-01, 4.41913813e-01, 8.34466696e-01,
1.12293482e+00, 9.62291718e-01, 6.46639347e-01, 1.22468376e+00,
1.33461881e+00, 9.76860523e-01, 1.60133433e+00, 4.31086159e+00,
3.78359699e+00, 4.50091076e+00, 7.61642456e+00, 9.94997692e+00,
1.83034401e+01, 1.86841354e+01, 1.93865471e+01, 4.10644569e+01,
5.11959839e+01, 3.74023285e+01, 1.46664228e+01, 2.73789101e+01,
5.09101982e+01, 2.61694183e+01, 2.90790100e+01, 1.15916996e+01,
1.44228182e+01, 8.16150761e+00, 1.21826038e+01, 8.52718925e+00,
8.82525539e+00, 1.47077036e+01, 1.31940975e+01, 8.21146393e+00,
5.06118011e+00, 3.73972368e+00, 8.66150951e+00, 8.86765766e+00,
1.82184372e+01, 2.03960953e+01, 1.50705147e+01, 3.58565903e+01,
3.94253273e+01, 1.51045656e+01, 1.46200066e+01, 1.30218935e+01,
8.60846615e+00, 6.86474085e+00, 5.52572966e+00, 6.91005898e+00,
4.49717140e+00, 2.01037908e+00, 2.75382376e+00, 3.28753996e+00], dtype=float32),
'step_scale': Array(0.5200044, dtype=float32)}

Then I do the usual blackjax approach to sampling (based on their example of using oryx):

def logdensity_fn(params):
    # Condition on the observed counts by merging them back into the trace
    theta = dict(params, poi=true_data)
    return ppl.joint_log_prob(poisson_process)(theta, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

ll = logdensity_fn(true_params)
# Array(-50.80758, dtype=float32)
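
# (Added sanity check, not in my original run: NUTS relies on gradients of
# logdensity_fn, so it's worth confirming they're finite at the truth.)
grads = jax.grad(logdensity_fn)(true_params)
print(jax.tree_util.tree_map(lambda g: bool(jnp.any(jnp.isnan(g))), grads))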

# Warmup
inference_key = jax.random.PRNGKey(12)
rng_key, warmup_key = jax.random.split(inference_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, true_params, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
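
# (Added sanity check on the adapted parameters; the key names below are
# what window adaptation returns in recent blackjax releases.)
print("adapted step size:", parameters["step_size"])
print("inverse mass matrix shape:", jax.tree_util.tree_map(jnp.shape, parameters["inverse_mass_matrix"]))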

# Sampling

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos

rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 2000)
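
Before digging into the draws, a quick diagnostics sketch (my addition; NUTS info and state field names have changed across blackjax versions, so is_divergent and logdensity are assumptions for a recent release):

print("fraction of divergent transitions:", jnp.mean(infos.is_divergent))
print("log density over last 5 draws:", states.logdensity[-5:])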

Issues

  1. The main issue is that this silently fails to sample from the posterior (or I'm not understanding the sample structure):
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # Plot the first 100 sampled intensities
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()

[Image: true intensity overlaid with 100 sampled intensity paths]

  2. Warnings about f32 conversion, e.g.

UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
minval = minval + np.zeros([1] * final_rank, dtype=dtype)

This suggests that the underlying Poisson distribution is struggling at single precision, but...

  3. Errors if f64 is enabled, e.g. if the conversion to double precision is allowed (uncommenting the jax_enable_x64 line above) then the model fails at joint_sample with the error:

TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with
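
My guess (untested) is that TFP infers float32 from the Python-float hyperparameters while the traced arrays are float64 once x64 is on, so one thing that might be worth trying is casting the hyperparameters explicitly before building the distributions:

# Workaround sketch, untested: with jax_enable_x64 on, make the
# hyperparameter dtypes explicit so the distributions don't default
# to float32.
init_prior_loc = jnp.asarray(init_prior_loc, dtype=jnp.float64)
init_prior_scale = jnp.asarray(init_prior_scale, dtype=jnp.float64)
step_scale_prior = jnp.asarray(step_scale_prior, dtype=jnp.float64)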

Steps forward

I don't have a huge amount of JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is some issue with joint_log_prob in combination with Poisson, or with the way I've implemented the Poisson link.
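
In case it helps, here is a minimal repro I'd start from (my own addition): it isolates the vmap-over-random_variable pattern from the rest of the model, so if joint_log_prob mishandles the shared "poi" name this should show it directly.

# Minimal repro sketch: three conditionally independent Poisson draws,
# all tagged "poi" via vmap, checked round-trip through joint_log_prob.
def tiny_model(key):
    keys = random.split(key, 3)
    rates = jnp.array([1.0, 2.0, 3.0])
    return jax.vmap(lambda k, r: ppl.random_variable(tfd.Poisson(r), name="poi")(k))(keys, rates)

draw = ppl.joint_sample(tiny_model)(random.PRNGKey(0))
print(draw)
print(ppl.joint_log_prob(tiny_model)(draw))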
