Skip to content

Commit

Permalink
Replace ordered-RV workaround by pt.sorting
Browse files Browse the repository at this point in the history
This fixes a problem in the gradient of our previous workaround.
Also, the previous workaround unnecessarily introduced a
`Normal` just so we could apply the ordered transform.

The minor version number was increased, because
this changes the structure of multi-peak models.
  • Loading branch information
michaelosthege committed May 3, 2024
1 parent a5125fb commit a5c7d44
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
21 changes: 12 additions & 9 deletions peak_performance/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def define_model_normal(time: np.ndarray, intensity: np.ndarray) -> pm.Model:

def double_model_mean_prior(time):
"""
Function creating prior probability distributions for double peaks using a ZeroSumNormal distribution.
Function creating prior probability distributions for multi-peaks using a ZeroSumNormal distribution.
Parameters
----------
Expand All @@ -203,23 +203,26 @@ def double_model_mean_prior(time):
Returns
-------
mean
Normally distributed prior for the ordered means of the double peak model.
Normally distributed prior for the ordered means of the multi-peak model.
diff
Difference between meanmean and mean.
Difference between the group mean and peak-wise mean.
meanmean
Normally distributed prior for the mean of the double peak means.
Normally distributed prior for the group mean of the peak means.
"""
pmodel = pm.modelcontext(None)
meanmean = pm.Normal("meanmean", mu=np.min(time) + np.ptp(time) / 2, sigma=np.ptp(time) / 6)
diff = pm.ZeroSumNormal(
"diff",
sigma=1,
shape=(2,), # currently no dims due to bug with ordered transformation
# Support arbitrary number of subpeaks
shape=len(pmodel.coords["subpeak"]),
# NOTE: As of PyMC v5.13, the OrderedTransform and ZeroSumTransform are incompatible.
# See https://github.com/pymc-devs/pymc/issues/6975.
# As a workaround we'll call pt.sort a few lines below.
)
mean = pm.Normal(
mean = pm.Deterministic(
"mean",
mu=meanmean + diff,
sigma=1,
transform=pm.distributions.transforms.ordered,
meanmean + pt.sort(diff),
dims=("subpeak",),
)
return mean, diff, meanmean
Expand Down
38 changes: 38 additions & 0 deletions peak_performance/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
import scipy.integrate
import scipy.stats as st
Expand All @@ -26,6 +27,43 @@ def test_initial_guesses():
pass


def test_zsn_sorting():
"""This tests a workaround that we rely on for multi-peak models."""
coords = {
"thing": ["left", "right"],
}
with pm.Model(coords=coords) as pmodel:
hyper = pm.Normal("hyper", mu=0, sigma=3)
diff = pm.ZeroSumNormal(
"diff",
sigma=1,
shape=2,
)
# Create a sorted deterministic without using transforms
diff_sorted = pm.Deterministic("diff_sorted", pt.sort(diff), dims="thing")
pos = pm.Deterministic(
"pos",
hyper + diff_sorted,
dims="thing",
)
# Observe the two things in incorrect order to provoke the model 😈
dat = pm.Data("dat", [0.2, -0.3], dims="thing")
pm.Normal("L", pos, observed=dat, dims="thing")

# Check draws from the prior
drawn = pm.draw(diff_sorted, draws=69)
np.testing.assert_array_less(drawn[:, 0], drawn[:, 1])

# And check MCMC draws too
with pmodel:
idata = pm.sample(
chains=1, tune=10, draws=69, step=pm.Metropolis(), compute_convergence_checks=False
)
sampled = idata.posterior["diff_sorted"].stack(sample=("chain", "draw")).values.T
np.testing.assert_array_less(sampled[:, 0], sampled[:, 1])
pass


class TestDistributions:
def test_normal_posterior(self):
x = np.linspace(-5, 10, 10000)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "peak_performance"
version = "0.6.5"
version = "0.7.0"
authors = [
{name = "Jochen Nießer", email = "[email protected]"},
{name = "Michael Osthege", email = "[email protected]"},
Expand Down

0 comments on commit a5c7d44

Please sign in to comment.