The main issue is that this silently fails to sample from the posterior (or I'm not understanding the sample structure):
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # Plot the first 100 sampled intensities
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()
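One way to check whether the chain is actually moving is to look at the shape and the per-time-step spread of the stacked samples. The sketch below assumes that states.position is a dict of arrays stacked along a leading draw axis, which is what an inference loop built on jax.lax.scan would produce. The stand-in kernel toy_step and the sizes T and num_draws are hypothetical and exist only to create such a structure; this is not blackjax's kernel.

```python
import jax
import jax.numpy as jnp

T, num_draws = 100, 50  # hypothetical sizes


def toy_step(position, key):
    # Stand-in for an MCMC kernel (assumption): perturb the current position.
    new_position = {
        "intensity": position["intensity"] + 0.1 * jax.random.normal(key, (T,))
    }
    # scan stacks the second return value along a new leading axis.
    return new_position, new_position


keys = jax.random.split(jax.random.PRNGKey(0), num_draws)
_, positions = jax.lax.scan(toy_step, {"intensity": jnp.zeros(T)}, keys)

# Leading axis indexes draws, trailing axis indexes time.
print(positions["intensity"].shape)

# Spread across draws at each time point. A spread of exactly zero
# everywhere would mean the chain never left its initial position.
spread = positions["intensity"].std(axis=0)
print(float(spread.min()) > 0.0)
```

With this layout, indexing [i] on the leading axis selects one draw of the whole intensity path, which is what the plotting loop above relies on.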
Warnings about f32 conversion e.g.
UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
minval = minval + np.zeros([1] * final_rank, dtype=dtype)
Which suggests that the underlying Poisson distribution is struggling but...
Errors if enabling f64 conversion, e.g. if the conversion to double precision is allowed then the model fails at joint_sample with error:
TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with
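The warning above names the switch for double precision. The sketch below only illustrates that flag and JAX's dtype promotion, and shows one way to keep inputs in a single dtype. It makes no assumption about how oryx or the Poisson distribution handle dtypes internally, so it is not a confirmed fix for the TypeError. The array names are hypothetical.

```python
import jax

# Best set at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

default = jnp.zeros(3)                         # default float dtype is now float64
single = jnp.asarray(np.ones(3, np.float32))   # an explicit float32 is preserved
mixed = default + single                       # promoted to float64

# Casting every input to one dtype up front avoids mixed-precision arguments.
dtype = jnp.float64
counts = jnp.asarray([0, 2, 1], dtype=dtype)
log_rate = jnp.asarray(single, dtype=dtype)
print(default.dtype, single.dtype, mixed.dtype, counts.dtype, log_rate.dtype)
```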
Steps forward
I don't have a huge amount of JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is some kind of issue with joint_log_prob in combination with Poisson or with the way I've implemented the Poisson link.
Hi everyone,
I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.
Poisson process with exp random walk intensity
As an attempt to get into the structure of oryx I'm trying to sample from a probabilistic program which represents:
Exp transform on the random walk represents the intensity of a Poisson process. This is observed.
NUTS from blackjax
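For readers without oryx installed, the generative model described above can be sketched in plain JAX. This is not the oryx program from this issue. It assumes Gaussian increments for the random walk, and the names sample_model, num_steps and step_scale are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_model(key, num_steps=100, step_scale=0.1):
    """Draw a latent random walk, its exp-transformed intensity, and counts."""
    walk_key, obs_key = jax.random.split(key)
    # Assumption: Gaussian increments with a fixed scale.
    increments = step_scale * jax.random.normal(walk_key, (num_steps,))
    log_intensity = jnp.cumsum(increments)           # random walk
    intensity = jnp.exp(log_intensity)               # exp link keeps the rate positive
    counts = jax.random.poisson(obs_key, intensity)  # one count per time step
    return {
        "log_intensity": log_intensity,
        "intensity": intensity,
        "counts": counts,
    }


draw = sample_model(jax.random.PRNGKey(0))
print(draw["counts"].shape, draw["counts"].dtype)
```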
Code
Dependencies
Prob Program
Note that I've implemented the link as a vmap over intensity representing conditional independence of observations.
Sample some data from model
This looks reasonable.
Inference
I split out the observed data from the rest of the parameters
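A minimal plain-JAX sketch of that split: the observed counts are fixed with functools.partial, so the target density depends on the latent path only. This is not oryx's joint_log_prob. The Gaussian random-walk prior, step_scale, and the function names are assumptions for illustration. The per-observation Poisson term is mapped over time with vmap, mirroring the conditional-independence note above.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import norm


def poisson_log_pmf(count, rate):
    # log Poisson(count | rate) for a single observation
    return count * jnp.log(rate) - rate - gammaln(count + 1.0)


def joint_log_density(log_intensity, counts, step_scale=0.1):
    # Random-walk prior: the first increment is measured from 0.
    increments = jnp.concatenate([log_intensity[:1], jnp.diff(log_intensity)])
    prior = jnp.sum(norm.logpdf(increments, scale=step_scale))
    # vmap over time steps: observations are conditionally independent.
    rates = jnp.exp(log_intensity)
    likelihood = jnp.sum(jax.vmap(poisson_log_pmf)(counts, rates))
    return prior + likelihood


counts = jnp.array([1.0, 0.0, 2.0, 1.0])  # hypothetical observed data
# Fix the observed counts so the target is a function of the latent only.
target = functools.partial(joint_log_density, counts=counts)
value, grad = jax.value_and_grad(target)(jnp.zeros(4))
print(value, grad.shape)
```

A gradient-based sampler such as NUTS needs exactly this kind of function: a scalar log density of the unobserved parameters that is differentiable with respect to them.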
Then do the usual blackjax approach to sampling (based on their example of using oryx).