From bcf5d13617ac14250b3e6dcc95f6b8b4165e09a5 Mon Sep 17 00:00:00 2001 From: Michael Shvartsman <70196+mshvartsman@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:42:00 -0700 Subject: [PATCH] Introducing Response Time Modeling (#361) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/361 Implements https://arxiv.org/abs/2306.06296 (Response Time Improves Choice Prediction and Function Estimation for Gaussian Process Models of Perception and Preferences, UAI 2024). Code to make figures etc in the paper will be available at https://github.com/facebookresearch/response-time-gps. Differential Revision: D59561138 --- aepsych/__init__.py | 6 + aepsych/distributions.py | 706 ++++++++++++++++++++++++++++++ aepsych/kernels/__init__.py | 4 + aepsych/kernels/pairwisekernel.py | 85 ++++ aepsych/likelihoods/__init__.py | 4 + aepsych/likelihoods/ddm.py | 152 +++++++ tests/test_ddm_distr.py | 181 ++++++++ tests/test_pairwise_kernel.py | 151 +++++++ 8 files changed, 1289 insertions(+) create mode 100644 aepsych/distributions.py create mode 100644 aepsych/kernels/pairwisekernel.py create mode 100644 aepsych/likelihoods/ddm.py create mode 100644 tests/test_ddm_distr.py create mode 100644 tests/test_pairwise_kernel.py diff --git a/aepsych/__init__.py b/aepsych/__init__.py index ee68af6d5..1d00e0c69 100644 --- a/aepsych/__init__.py +++ b/aepsych/__init__.py @@ -11,6 +11,7 @@ from . import acquisition, config, factory, generators, models, strategy, utils from .config import Config +from .distributions import RTDistWithUniformLapseRate, LogNormalDDMDistribution, ShiftedGammaDDMDistribution, ShiftedInverseGammaDDMDistribution, ShiftedLogNormalDDMDistribution from .likelihoods import BernoulliObjectiveLikelihood from .models import GPClassificationModel from .strategy import SequentialStrategy, Strategy @@ -31,6 +32,11 @@ "BernoulliObjectiveLikelihood", "BernoulliLikelihood", "GaussianLikelihood", + "RTDistWithUniformLapseRate", + "LogNormalDDMDistribution", + "ShiftedGammaDDMDistribution", + "ShiftedInverseGammaDDMDistribution", + "ShiftedLogNormalDDMDistribution", ] try: diff --git a/aepsych/distributions.py b/aepsych/distributions.py new file mode 100644 index 000000000..f97da32ab --- /dev/null +++ b/aepsych/distributions.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings + + +from logging import getLogger +from numbers import Number + +import torch +from torch.distributions import Bernoulli, Exponential, LogNormal, Normal +from gpytorch import constraints +from gpytorch.distributions import Distribution +from torch.distributions import ( + Gamma, + TransformedDistribution, + Uniform, +) +from torch.distributions.transforms import AffineTransform, ExpTransform, PowerTransform +from torch.distributions.utils import broadcast_all + +logger = getLogger() + + +class RTDistWithUniformLapseRate(Distribution): + arg_constraints = { + "lapse_rate": constraints.Interval(1e-5, 0.2), + "max_rt": constraints.Positive(), + } + + def __init__( + self, lapse_rate, max_rt, base_dist, validate_args=False, **kwargs + ): + self.lapse_rate = lapse_rate + self.base_dist = base_dist + self.max_rt = max_rt + self.lapse_dist = Uniform(-self.max_rt, self.max_rt) + self.p_lapse_dist = Bernoulli(self.lapse_rate) + super().__init__(**kwargs, validate_args=validate_args) + + @property + def mean(self): + return ( + self.lapse_rate * self.lapse_dist.mean + + (1 - self.lapse_rate) * self.base_dist.mean + ) + + def log_prob(self, rts): + # rt whose p=0 will have logp=nan, replace with -1000 which will exp() to 0 anyway + # in logsumexp + rt_logp = torch.nan_to_num(self.base_dist.log_prob(rts), nan=-1000) + lapse_logp = self.lapse_dist.log_prob(rts) + + [*batch_shape, rt_shape] = rt_logp.shape + assert rt_shape == lapse_logp.shape[0] + lapse_logp = lapse_logp.expand(*batch_shape, -1) + + mix_logps = torch.stack( + ( + lapse_logp + torch.log(self.lapse_rate), + rt_logp + torch.log(1 - self.lapse_rate), + ), + dim=-1, + ) + return torch.logsumexp(mix_logps, dim=-1) + + def sample(self, sample_shape=torch.Size([])): # noqa B008 + rt_samps = self.base_dist.sample(sample_shape=sample_shape) + unif_samps = self.lapse_dist.rsample(sample_shape=rt_samps.shape) + coinflips = self.p_lapse_dist.sample(sample_shape=rt_samps.shape).int()[..., 0] + return torch.where(coinflips == 1, unif_samps, rt_samps) + + +class ExGaussian(Distribution): + def __init__(self, mean, stddev, lam, validate_args=False, *args, **kwargs): + self.mean = mean + self.stddev = stddev + self.lam = lam + + super().__init__(**kwargs, validate_args=validate_args) + + def log_prob(self, x): + """ + Same as PyMC + """ + res = torch.where( + self.lam > 0.05 * self.stddev, + -torch.log(self.lam) + + (self.mean - x) / self.lam + + 0.5 * (self.stddev / self.lam) ** 2 + + torch.log( + Normal(loc=self.mean + (self.stddev**2) / self.lam, scale=self.stddev**2).cdf(x) + ), + LogNormal(loc=self.mean, scale=self.stddev**2).log_prob(x), + ) + return res + + def rsample(self, sample_shape=torch.Size()): # noqa B008 + return Normal(loc=self.mean, scale=self.stddev).rsample( + sample_shape=sample_shape + ) + Exponential(rate=self.lam).rsample(sample_shape=sample_shape) + + +class ShiftedGamma(TransformedDistribution): + r""" + Creates a shifted log-normal distribution parameterized by + :attr:`shift` and, :attr:`loc`, and :attr:`scale` where:: + + """ + arg_constraints = { + "concentration": constraints.Positive(), + "rate": constraints.Positive(), + } + support = constraints.Positive() + has_rsample = True + + def __init__( + self, shift, concentration, rate, validate_args=False, **kwargs + ): + base_dist = Gamma(concentration, rate, validate_args=validate_args) + self.shift = shift + super().__init__( + base_dist, + [AffineTransform(loc=shift, scale=torch.tensor(1.0))], + validate_args=validate_args, + **kwargs, + ) + + @property + def concentration(self): + return self.base_dist.concentration + + @property + def rate(self): + return self.base_dist.rate + + def log_prob(self, X): + return torch.where(X < self.shift, torch.nan, super().log_prob(X)) + + +class ShiftedInverseGamma(TransformedDistribution): + r""" + Creates a shifted log-normal distribution parameterized by + :attr:`shift` and, :attr:`loc`, and :attr:`scale` where:: + + """ + arg_constraints = { + "concentration": constraints.Positive(), + "rate": constraints.Positive(), + } + support = constraints.Positive() + has_rsample = True + + def __init__( + self, shift, concentration, rate, validate_args=False, **kwargs + ): + base_dist = Gamma(concentration, rate, validate_args=validate_args) + self.shift = shift + super(ShiftedInverseGamma, self).__init__( + base_dist, + [ + PowerTransform(exponent=torch.tensor(-1.0)), + AffineTransform(loc=shift, scale=torch.tensor(1.0)), + ], + validate_args=validate_args, + **kwargs, + ) + + @property + def concentration(self): + return self.base_dist.concentration + + @property + def rate(self): + return self.base_dist.rate + + +class ShiftedLognormal(TransformedDistribution): + r""" + Creates a shifted log-normal distribution parameterized by + :attr:`shift` and, :attr:`loc`, and :attr:`scale` where:: + + """ + arg_constraints = { + "scale": constraints.Positive(), + } + support = constraints.Positive() + has_rsample = True + + def __init__(self, shift, loc, scale, validate_args=False): + base_dist = Normal(loc, scale, validate_args=validate_args) + self.shift = shift + super(ShiftedLognormal, self).__init__( + base_dist, + [ExpTransform(), AffineTransform(loc=shift, scale=torch.tensor(1.0))], + validate_args=validate_args, + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShiftedLognormal, _instance) + return super(ShiftedLognormal, self).expand(batch_shape, _instance=new) + + @property + def loc(self): + return self.base_dist.loc + + @property + def scale(self): + return self.base_dist.scale + + @property + def mean(self): + return (self.loc + self.scale.pow(2) / 2).exp() + self.shift + + @property + def variance(self): + return (self.scale.pow(2).exp() - 1) * (2 * self.loc + self.scale.pow(2)).exp() + + +def coth(x): + # probably not numerically terrific + return torch.cosh(x) / torch.sinh(x) + + +def csch(x): + # probably not numerically terrific + return 1 / torch.sinh(x) + +class DDMMomentMatchDistribution(Distribution): + """ + Distribution over [choices, rts]. However, since rts are always positive, we use the sign of the RTs + to track choice information (so rt>0 means yes choice, rt<0 means no choice) which means as far as gpytorch knows, + we still have a univariate outcome. + There's basically 2 steps here: + 1. Use the DDM parameters to compute moments of the conditional RT distributions + and choice probabilities using the expressions in https://mae.princeton.edu/sites/default/files/SrivastHSimen-JMatPsy16.pdf. + This is what the base class does. + 2. Moment-match the moments to some nicer distribution, and pretend that's our likelihood. That + is what subclasses do. + """ + + SMALL_DRIFT_CUTOFF = ( + 1e-2 # use this as cutoff to use asymptotic drift -> 0 expressions + ) + arg_constraints = { + "threshold": constraints.Positive(), + "relative_x0": constraints.Interval(0.2, 0.8), + "t0": constraints.Positive(), + } + support = constraints.Positive() + + def __init__( + self, drift, threshold, relative_x0, t0, restrict_skew=False, max_shift=None + ): + + self.drift = drift + self.threshold = threshold + self.max_shift = max_shift + + # relative x0 is scaled 0 to 1, x0 is -thresh to thresh + # boundarySep = thresh * 2 + # relativeInitCond = (x0+z) / boundarySep + # boundarySep * relativeInitCond = x0+z + + self.x0 = threshold * (2 * relative_x0 - 1) + + kz = drift * threshold + kx = drift * self.x0 + + near_zero_drift = drift.abs() < self.SMALL_DRIFT_CUTOFF + + # as abs(drift) -> 0, use different expressions (expr 30 and 32) + rt_mean_yes0 = (4 * threshold**2 - (threshold + self.x0) ** 2) / 3 + rt_mean_no0 = (4 * threshold**2 - (threshold - self.x0) ** 2) / 3 + rt_var_yes0 = (32 * threshold**4 - 2 * (threshold + self.x0) ** 4) / 45 + rt_var_no0 = (32 * threshold**4 - 2 * (threshold - self.x0) ** 4) / 45 + + # for nonzero drift, expr 29 and 31 + self.rt_mean_yes = ( + torch.where( + near_zero_drift, + rt_mean_yes0, + drift ** (-2) * ((2 * kz * coth(2 * kz)) - (kx + kz) * coth(kx + kz)), + ) + + t0 + ) + self.rt_mean_no = ( + torch.where( + near_zero_drift, + rt_mean_no0, + drift ** (-2) * ((2 * kz * coth(2 * kz)) - (-kx + kz) * coth(-kx + kz)), + ) + + t0 + ) + self.rt_var_yes = torch.where( + near_zero_drift, + rt_var_yes0, + drift ** (-4) + * ( + 4 * kz**2 * csch(2 * kz) ** 2 + + 2 * kz * coth(2 * kz) + - (kx + kz) ** 2 * csch(kx + kz) ** 2 + - (kx + kz) * coth(kx + kz) + ), + ) + self.rt_var_no = torch.where( + near_zero_drift, + rt_var_no0, + drift ** (-4) + * ( + 4 * kz**2 * csch(2 * kz) ** 2 + + 2 * kz * coth(2 * kz) + - (-kx + kz) ** 2 * csch(-kx + kz) ** 2 + - (-kx + kz) * coth(-kx + kz) + ), + ) + + # expr 36 + rt_3rd_moment_yes = drift ** (-6) * ( + 12 * kz**2 * csch(2 * kz) ** 2 + + 16 * kz**3 * coth(2 * kz) * csch(2 * kz) ** 2 + + 6 * kz * coth(2 * kz) + - 3 * (kz + kx) ** 2 * csch(kx + kz) ** 2 + - 2 * (kx + kz) ** 3 * coth(kz + kx) * csch(kz + kx) ** 2 + - 3 * (kx + kz) * coth(kx + kz) + ) + rt_3rd_moment_no = drift ** (-6) * ( + 12 * kz**2 * csch(2 * kz) ** 2 + + 16 * kz**3 * coth(2 * kz) * csch(2 * kz) ** 2 + + 6 * kz * coth(2 * kz) + - 3 * (kz - kx) ** 2 * csch(kz - kx) ** 2 + - 2 * (-kx + kz) ** 3 * coth(kz - kx) * csch(kz - kx) ** 2 + - 3 * (-kx + kz) * coth(-kx + kz) + ) + rt_skew_yes = rt_3rd_moment_yes / self.rt_var_yes ** (3 / 2) + rt_skew_no = rt_3rd_moment_no / self.rt_var_no ** (3 / 2) + + # expr 37 + # np.sqrt(45/2) = 4.743416490252569 + SQRT45_2 = 4.743416490252569 + rt_skew_yes0 = SQRT45_2 * ( + (8 * (64 * threshold**6 - (threshold + self.x0) ** 6)) + / (21 * (16 * threshold**4 - (threshold + self.x0) ** 4) ** (3 / 2)) + ) + rt_skew_no0 = SQRT45_2 * ( + (8 * (64 * threshold**6 - (threshold - self.x0) ** 6)) + / (21 * (16 * threshold**4 - (threshold - self.x0) ** 4) ** (3 / 2)) + ) + + self.rt_skew_yes = torch.where(near_zero_drift, rt_skew_yes0, rt_skew_yes) + self.rt_skew_no = torch.where(near_zero_drift, rt_skew_no0, rt_skew_no) + + # expr 6 and 9 + self.response_prob = torch.where( + near_zero_drift, + (threshold - self.x0) / (2 * threshold), + 1 + - (torch.exp(-2 * kx) - torch.exp(-2 * kz)) + / (torch.exp(2 * kz) - torch.exp(-2 * kz)), + ) + + # these will fail if numerical stability is bad, clamp them + self.response_prob = self.response_prob.clamp(min=1e-5, max=1 - 1e-5) + self.rt_var_yes = self.rt_var_yes.clamp(min=1e-5) + self.rt_var_no = self.rt_var_no.clamp(min=1e-5) + self.rt_mean_yes = self.rt_mean_yes.clamp(min=1e-5) + self.rt_mean_no = self.rt_mean_no.clamp(min=1e-5) + if restrict_skew: + self.rt_skew_yes = self.rt_skew_yes.clamp(min=0.01, max=10) + self.rt_skew_no = self.rt_skew_no.clamp(min=0.01, max=10) + + self._make_moment_matched_likelihood() + + def _make_moment_matched_likelihood(self): + raise NotImplementedError + + @property + def mean(self): + return ( + self.response_prob * self.rt_mean_yes + + (1 - self.response_prob) * self.rt_mean_no + ) + + def rsample(self, sample_shape=torch.Size()): # noqa B008 + choices = self.choice_dist.sample(sample_shape=sample_shape) + rt_yes = self.rt_yes_dist.rsample(sample_shape=sample_shape) + rt_no = self.rt_no_dist.rsample(sample_shape=sample_shape) + return torch.where(choices > 0, rt_yes, -rt_no) + + def log_prob(self, signed_rts): + # log p(rt, choice | theta) =log p(rt|choice, theta) + log p(choice | theta) + # p(rt|choice) is our conditional lognormal, p(choice) is bernoulli. + choices = signed_rts > 0 + yes_log_probs = self.rt_yes_dist.log_prob(torch.abs(signed_rts)) + no_log_probs = self.rt_no_dist.log_prob(torch.abs(signed_rts)) + rt_log_probs = torch.where(choices, yes_log_probs, no_log_probs) + + return rt_log_probs + self.choice_dist.log_prob(choices.float()) + + +class LogNormalDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + + # moment match to lognormal (from https://en.wikipedia.org/wiki/Log-normal_distribution) + lognormal_mu_yes = torch.log( + self.rt_mean_yes / torch.sqrt(self.rt_var_yes / self.rt_mean_yes**2 + 1) + ) + lognormal_sigma_yes = torch.sqrt( + torch.log(self.rt_var_yes / self.rt_mean_yes**2 + 1) + ) + lognormal_mu_no = torch.log( + self.rt_mean_no / torch.sqrt(self.rt_var_no / self.rt_mean_no**2 + 1) + ) + lognormal_sigma_no = torch.sqrt( + torch.log(self.rt_var_no / self.rt_mean_no**2 + 1) + ) + + assert (lognormal_sigma_yes > 0.0).all(), lognormal_sigma_yes.min() + assert (lognormal_sigma_no > 0.0).all(), lognormal_sigma_no.min() + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = torch.distributions.LogNormal( + loc=lognormal_mu_yes, scale=lognormal_sigma_yes + ) + self.rt_no_dist = torch.distributions.LogNormal( + loc=lognormal_mu_no, scale=lognormal_sigma_no + ) + + +class ShiftedLogNormalDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + + # moment match to shifted lognormal (from https://jod.pm-research.com/content/21/4/103, + # doi:10.3905/jod.2014.21.4.103, lemma 8 + + B_yes = 0.5 * ( + self.rt_skew_yes.square() + + 2 + - torch.sqrt(self.rt_skew_yes**4 + 4 * self.rt_skew_yes.square()) + ) + + shifted_lognormal_shift_yes = self.rt_mean_yes - ( + self.rt_var_yes.sqrt() / self.rt_skew_yes + ) * (1 + B_yes ** (1 / 3) + B_yes ** (-1 / 3)) + + shifted_lognormal_var_yes = torch.log( + 1 + + self.rt_var_yes / ((self.rt_mean_yes - shifted_lognormal_shift_yes) ** 2) + ) + shifted_lognormal_mean_yes = ( + torch.log(self.rt_mean_yes - shifted_lognormal_shift_yes) + - shifted_lognormal_var_yes**2 / 2 + ) + + B_no = 0.5 * ( + self.rt_skew_no.square() + + 2 + - torch.sqrt(self.rt_skew_no**4 + 4 * self.rt_skew_no.square()) + ) + + shifted_lognormal_shift_no = self.rt_mean_no - ( + self.rt_var_no.sqrt() / self.rt_skew_no + ) * (1 + B_no ** (1 / 3) + B_no ** (-1 / 3)) + + shifted_lognormal_var_no = torch.log( + 1 + self.rt_var_no / ((self.rt_mean_no - shifted_lognormal_shift_no) ** 2) + ) + shifted_lognormal_mean_no = ( + torch.log(self.rt_mean_no - shifted_lognormal_shift_no) + - shifted_lognormal_var_no**2 / 2 + ) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + + shifted_lognormal_shift_yes = shifted_lognormal_shift_yes.clamp( + min=0, max=self.max_shift + ) + shifted_lognormal_shift_no = shifted_lognormal_shift_no.clamp( + min=0, max=self.max_shift + ) + + self.rt_yes_dist = ShiftedLognormal( + shift=shifted_lognormal_shift_yes, + loc=shifted_lognormal_mean_yes, + scale=shifted_lognormal_var_yes.sqrt(), + ) + self.rt_no_dist = ShiftedLognormal( + shift=shifted_lognormal_shift_no, + loc=shifted_lognormal_mean_no, + scale=shifted_lognormal_var_no.sqrt(), + ) + + +class ExGaussianDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # moment match to exgaussian (from https://en.wikipedia.org/wiki/Exponentially_modified_Gaussian_distribution#Parameter_estimation) + + # exgaussian is restricted to skew <= 2, so we clamp + + clamped_yes_skew = torch.clamp(self.rt_skew_yes, max=torch.tensor(2.0)) + clamped_no_skew = torch.clamp(self.rt_skew_no, max=torch.tensor(2.0)) + tau_yes = self.rt_var_yes.sqrt() * (clamped_yes_skew / 2) ** (1 / 3) + mu_yes = self.rt_mean_yes - tau_yes + var_yes = self.rt_var_yes * (1 - (clamped_yes_skew / 2) ** (2 / 3)) + + tau_no = self.rt_var_no.sqrt() * (clamped_no_skew / 2) ** (1 / 3) + mu_no = self.rt_mean_no - tau_no + var_no = self.rt_var_no * (1 - (clamped_no_skew / 2) ** (2 / 3)) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ExGaussian(m=mu_yes, s=var_yes.sqrt(), l=1 / tau_yes) + self.rt_no_dist = ExGaussian(m=mu_no, s=var_no.sqrt(), l=1 / tau_no) + + +class ShiftedGammaDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # doi:10.1016/j.insmatheco.2020.12.002, + # section 4.2. + + a_yes = 4 / self.rt_skew_yes**2 + scale_yes = (self.rt_var_yes / a_yes).sqrt() + shift_yes = self.rt_mean_yes - a_yes * scale_yes + + a_no = 4 / self.rt_skew_no**2 + scale_no = (self.rt_var_no / a_no).sqrt() + shift_no = self.rt_mean_no - a_no * scale_no + + shift_yes = shift_yes.clamp(min=0, max=self.max_shift) + shift_no = shift_no.clamp(min=0, max=self.max_shift) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ShiftedGamma( + shift=shift_yes, concentration=a_yes, rate=1 / scale_yes + ) + self.rt_no_dist = ShiftedGamma( + shift=shift_no, concentration=a_no, rate=1 / scale_no + ) + + +class ShiftedInverseGammaDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # doi:10.1016/j.insmatheco.2020.12.002, + # section 4.3. + + shift_yes = self.rt_mean_yes - self.rt_var_yes.sqrt() / self.rt_skew_yes * ( + 2 + (4 + self.rt_skew_yes.square()).sqrt() + ) + a_yes = 2 + (self.rt_mean_yes - shift_yes).square() / self.rt_var_yes + b_yes = (self.rt_mean_yes - shift_yes) * (a_yes - 1) + + shift_no = self.rt_mean_no - self.rt_var_no.sqrt() / self.rt_skew_no * ( + 2 + (4 + self.rt_skew_no.square()).sqrt() + ) + a_no = 2 + (self.rt_mean_no - shift_no).square() / self.rt_var_no + b_no = (self.rt_mean_no - shift_no) * (a_no - 1) + + shift_yes = shift_yes.clamp(min=0, max=self.max_shift) + shift_no = shift_no.clamp(min=0, max=self.max_shift) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ShiftedInverseGamma( + shift=shift_yes, concentration=a_yes, rate=b_yes + ) + self.rt_no_dist = ShiftedInverseGamma( + shift=shift_no, concentration=a_no, rate=b_no + ) + + +class DDMDistribution(Distribution): + + arg_constraints = { + "z": constraints.Positive(), + "relative_x0": constraints.Interval(0.0, 1.0), + "t0": constraints.Positive(), + } + def __init__(self, a, z, relative_x0, t0, eps=1e-10, validate_args=True): + + self.a, self.z, self.relative_x0, self.t0 = broadcast_all(a, z, relative_x0, t0) + + if ( + isinstance(a, Number) + and isinstance(z, Number) + and isinstance(relative_x0, Number) + and isinstance(t0, Number) + ): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + + self.eps = eps + super().__init__(batch_shape=batch_shape, validate_args=validate_args) + + def _standardized_WFPT_large_time(self, t, w, nterms): + # large time expansion from navarro & fuss + + piSqOv2 = 4.93480220054 + # use nterms that's max over the batch. This guarantees + # we'll hit our target precision and enable batched + # computation, but will incur extra cost for the extra + # terms if not needed. + k = torch.arange(1, nterms + 1) + k = k.expand(*w.shape, *k.shape) # match batch shape to params + w = w[:, None] # broadcast an extra dim for w we can reduce sum over + + terms = ( + torch.pi + * k + * torch.exp(-(k**2) * t * piSqOv2) + * torch.sin(k * torch.pi * w) + ) + assert terms.shape == (*t.shape[:-1], *self.batch_shape, nterms) + return terms.sum(-1) + + def _standardized_WFPT_small_time(self, t, w, nterms): + # small time expansion navarro & fuss + + fr = math.floor(-(nterms - 1) / 2) + to = math.ceil((nterms - 2) // 2) + k = torch.arange(fr, to + 1) + k = k.expand(*w.shape, *k.shape) + w = w[:, None] # broadcast an extra dim for w we can reduce sum over + + terms = ( + 1 + / torch.sqrt(2 * torch.pi * t**3) + * (w + 2 * k) + * torch.exp(-((w + 2 * k) ** 2) / (2 * t)) + ) + assert terms.shape == (*t.shape[:-1], *self.batch_shape, nterms) + return terms.sum(0) + + def log_prob(self, signed_rt): + """ + Log probability of first passage time of double-threshold wiener process + (aka "pure DDM" of Bogacz et al.). Uses series truncation of Navarro & Fuss 2009 + """ + + shifted_t = signed_rt.abs() - self.t0 # correct for the shift + # normalize time (this also implicitly broadcasts) + normT = shifted_t / (self.relative_x0**2) + + # if t is below NDT, return -inf + t_below_ndt = normT <= 0 + + # by default return hit of lower bound, so if resp is correct flip + # signflip based on choice as needed + driftsign = torch.where(signed_rt > 0, -1, 1) + a = self.a * driftsign + relative_x0 = torch.where(signed_rt > 0, 1 - self.relative_x0, self.relative_x0) + + largeK = torch.ceil( + torch.sqrt( + (-2 * torch.log(torch.pi * normT * self.eps)) / (torch.pi**2 * normT) + ) + ) + smallK = torch.ceil( + 2 + + torch.sqrt( + -2 * normT * torch.log(2 * self.eps * torch.sqrt(2 * torch.pi * normT)) + ) + ) + + # if eps is too big for bound to be valid, adjust + smallK[self.eps > (1 / (2 * torch.sqrt(2 * torch.pi * normT)))] = 2 + bound_invalid = self.eps > (1 / (torch.pi * torch.sqrt(normT))) + largeK[bound_invalid] = torch.ceil( + (1 / (torch.pi * torch.sqrt(normT[bound_invalid]))) + ) + + # pick the smaller of large and small k options, then + # take the max so we can batch properly without needing ragged arrays + nterms = torch.min(largeK, smallK)[torch.logical_not(t_below_ndt)] + if nterms.max() - nterms.min() > 100: + warnings.warn( + "Number of series terms over a batch varies by more than 100, compute costs may be increased", + RuntimeWarning, + stacklevel=2 + ) + + nterms = nterms.max() + + use_large_time = largeK >= smallK + + prob = torch.zeros_like(normT) + prob[t_below_ndt] = -torch.inf + + large_time_approx = self._standardized_WFPT_large_time(normT, relative_x0, nterms) + small_time_approx = self._standardized_WFPT_small_time(normT, relative_x0, nterms) + prob[use_large_time] = large_time_approx[use_large_time.squeeze()] + prob[torch.logical_not(use_large_time)] = small_time_approx[ + torch.logical_not(use_large_time).squeeze() + ] + + boundarySep = 2 * self.z + + # scale from the std case to whatever is our actual + scaler = (1 / relative_x0**2) * torch.exp( + -a * boundarySep * relative_x0 - (a**2 * shifted_t / 2) + ) + + return torch.log(scaler * prob) diff --git a/aepsych/kernels/__init__.py b/aepsych/kernels/__init__.py index 8b2df349c..6c59dc1c7 100644 --- a/aepsych/kernels/__init__.py +++ b/aepsych/kernels/__init__.py @@ -4,3 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from .pairwisekernel import PairwiseKernel +from .rbf_partial_grad import RBFKernelPartialObsGrad + +__all__ = ["PairwiseKernel", "RBFKernelPartialObsGrad"] diff --git a/aepsych/kernels/pairwisekernel.py b/aepsych/kernels/pairwisekernel.py new file mode 100644 index 000000000..a76b8c75c --- /dev/null +++ b/aepsych/kernels/pairwisekernel.py @@ -0,0 +1,85 @@ +import torch +from gpytorch.kernels import Kernel +from gpytorch.lazy import lazify + + +class PairwiseKernel(Kernel): + """ + Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling + functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K). + + Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K') + where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d). + + """ + + def __init__(self, latent_kernel, is_partial_obs=False, **kwargs): + super(PairwiseKernel, self).__init__(**kwargs) + + self.latent_kernel = latent_kernel + self.is_partial_obs = is_partial_obs + + def forward(self, x1, x2, diag=False, **params): + r""" + TODO: make last_batch_dim work properly + + d must be 2*k for integer k, k is the dimension of the latent space + Args: + :attr:`x1` (Tensor `n x d` or `b x n x d`): + First set of data + :attr:`x2` (Tensor `m x d` or `b x m x d`): + Second set of data + :attr:`diag` (bool): + Should the Kernel compute the whole kernel, or just the diag? + + Returns: + :class:`Tensor` or :class:`gpytorch.lazy.LazyTensor`. + The exact size depends on the kernel's evaluation mode: + + * `full_covar`: `n x m` or `b x n x m` + * `diag`: `n` or `b x n` + """ + if self.is_partial_obs: + d = x1.shape[-1] - 1 + assert d == x2.shape[-1] - 1, "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + # special handling for kernels that (also) do funky + # things with the input dimension + deriv_idx_1 = x1[..., -1][:, None] + deriv_idx_2 = x2[..., -1][:, None] + + a = torch.cat((x1[..., :k], deriv_idx_1), dim=1) + b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1) + c = torch.cat((x2[..., :k], deriv_idx_2), dim=1) + d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1) + + else: + d = x1.shape[-1] + + assert d == x2.shape[-1], "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + a = x1[..., :k] + b = x1[..., k:] + c = x2[..., :k] + d = x2[..., k:] + + if not diag: + return ( + lazify(self.latent_kernel(a, c, diag=diag, **params)) + + lazify(self.latent_kernel(b, d, diag=diag, **params)) + - lazify(self.latent_kernel(b, c, diag=diag, **params)) + - lazify(self.latent_kernel(a, d, diag=diag, **params)) + ) + else: + return ( + self.latent_kernel(a, c, diag=diag, **params) + + self.latent_kernel(b, d, diag=diag, **params) + - self.latent_kernel(b, c, diag=diag, **params) + - self.latent_kernel(a, d, diag=diag, **params) + ) diff --git a/aepsych/likelihoods/__init__.py b/aepsych/likelihoods/__init__.py index dfe839d4a..824d78d35 100644 --- a/aepsych/likelihoods/__init__.py +++ b/aepsych/likelihoods/__init__.py @@ -9,13 +9,17 @@ from ..config import Config from .bernoulli import BernoulliObjectiveLikelihood +from .ddm import DDMLikelihood, LapseRateRTLikelihood from .ordinal import OrdinalLikelihood from .semi_p import LinearBernoulliLikelihood + __all__ = [ "BernoulliObjectiveLikelihood", "OrdinalLikelihood", "LinearBernoulliLikelihood", + "DDMLikelihood", + "LapseRateRTLikelihood" ] Config.register_module(sys.modules[__name__]) diff --git a/aepsych/likelihoods/ddm.py b/aepsych/likelihoods/ddm.py new file mode 100644 index 000000000..851142037 --- /dev/null +++ b/aepsych/likelihoods/ddm.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gpytorch + +import torch +from gpytorch.likelihoods import _OneDimensionalLikelihood + +from aepsych.distributions import RTDistWithUniformLapseRate + +class DDMLikelihood(_OneDimensionalLikelihood): + """ """ + + def __init__(self, distribution, max_shift = None, restrict_skew = False): + super().__init__() + self.distribution = distribution + self.register_parameter( + name="raw_relative_x0", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_relative_x0", gpytorch.constraints.Interval(0, 1)) + + self.register_parameter( + name="raw_t0", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_t0", gpytorch.constraints.Interval(0., 1.0)) + + self.register_parameter( + name="raw_threshold", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_threshold", gpytorch.constraints.Positive()) + + self.max_shift = max_shift + self.restrict_skew = restrict_skew + + def _set_relative_x0(self, value): + value = self.raw_relative_x0_constraint.inverse_transform(value) + self.initialize(raw_relative_x0=value) + + def _set_threshold(self, value): + value = self.raw_threshold_constraint.inverse_transform(value) + self.initialize(raw_threshold=value) + + def _set_t0(self, value): + value = self.raw_t0_constraint.inverse_transform(value) + self.initialize(raw_t0=value) + + @property + def relative_x0(self): + return self.raw_relative_x0_constraint.transform(self.raw_relative_x0) + + @relative_x0.setter + def relative_x0(self, value): + self._set_relative_x0(value) + + @property + def x0(self): + return self.threshold * (2*self.relative_x0 - 1) + + @property + def t0(self): + return self.raw_t0_constraint.transform(self.raw_t0) + + @t0.setter + def t0(self, value): + self._set_t0(value) + + @property + def threshold(self): + return self.raw_threshold_constraint.transform(self.raw_threshold) + + @threshold.setter + def threshold(self, value): + self._set_threshold(value) + + def forward(self, function_samples, *params, **kwargs): + return self.distribution( + drift=function_samples, threshold=self.threshold, relative_x0=self.relative_x0, t0=self.t0, max_shift = self.max_shift, restrict_skew = self.restrict_skew + ) + + @classmethod + def from_config(cls, config): + classname = cls.__name__ + max_shift = config.getfloat(classname, "max_shift", fallback=None) + restrict_skew = config.getboolean(classname, "restrict_skew", fallback=None) + + distribution = config.getobj(classname, "distribution") + + + return cls(distribution=distribution, max_shift=max_shift, restrict_skew=restrict_skew) + + # def log_marginal(self, observations, function_dist, *args, **kwargs): + # """ + # here we need the expectation of logp(r,c|f) w.r.t f + # p(r, c|f) = p(r|c,f)p(c|f), so we can factorize + # the log marginal as E_f log p(r|c,f) + E_f log p(c|f). + # and to the integrals separately + # """ + # choices = observations > 0 + # # rt_log_probs = torch.where(choices, yes_log_probs, no_log_probs) + + # def choice_prob_sampler(function_samples): + # ddmdist = self.forward(function_samples) + # return ddmdist.choice_dist.log_prob(choices.float()).exp() + + # choice_marginal = self.quadrature(choice_prob_sampler, function_dist) + + # def rt_prob_sampler(function_samples): + # ddmdist = self.forward(function_samples) + # yes_probs = ddmdist.rt_yes_dist.log_prob(torch.abs(observations)).exp() + # no_probs = ddmdist.rt_no_dist.log_prob(torch.abs(observations)).exp() + # return torch.where(choices, yes_probs, no_probs) + + # rt_marginal = self.quadrature(rt_prob_sampler, function_dist) + + # return choice_marginal.log() + rt_marginal.log() + + +class LapseRateRTLikelihood(_OneDimensionalLikelihood): + def __init__(self, base_likelihood, max_rt=10.0): + super().__init__() + self.max_rt = max_rt + self.base_likelihood = base_likelihood + self.register_parameter( + name="raw_lapse_rate", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint( + "raw_lapse_rate", gpytorch.constraints.Interval(1e-5, 0.2) + ) # any greater than that and the model is really bad anyway + + @property + def lapse_rate(self): + return self.raw_lapse_rate_constraint.transform(self.raw_lapse_rate) + + def forward(self, function_samples, *args, **kwargs): + base_dist = self.base_likelihood(function_samples, *args, **kwargs) + return RTDistWithUniformLapseRate( + lapse_rate=self.lapse_rate, base_dist=base_dist, max_rt=self.max_rt + ) + + @classmethod + def from_config(cls, config): + classname = cls.__name__ + max_rt = config.getfloat(classname, "max_rt", fallback=10.) + + base_lik_class = config.getobj(classname, "base_likelihood") + + base_lik = base_lik_class.from_config(config) + return cls(base_likelihood = base_lik, max_rt = max_rt) diff --git a/tests/test_ddm_distr.py b/tests/test_ddm_distr.py new file mode 100644 index 000000000..a2e27ea22 --- /dev/null +++ b/tests/test_ddm_distr.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +import unittest + +import numpy as np + +import torch +from functools import partial +from torch.func import grad +from torch import vmap +from numbers import Number + +from aepsych.distributions import DDMMomentMatchDistribution + + +class TestDDMDistr(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + pass + +global_atol = 1e-3 + +def ddm_mgf(alpha, drift, x0, threshold, response=1): + """ + Moment-generating function of the Wiener First Passage time distribution (DDM) + """ + if response == 0: + drift = -drift.clone() + threshold = -threshold.clone() + return torch.exp(drift * (threshold - x0)) * ( + torch.sinh((threshold + x0) * torch.sqrt(drift**2 - 2 * alpha)) + / torch.sinh(2 * threshold * torch.sqrt(drift**2 - 2 * alpha)) + ) + + +def ddm_cgf(alpha, drift, x0, threshold, response=1): + """ + Cumulant-generating function of the Wiener First Passage time distribution (DDM) + """ + + return torch.log(ddm_mgf(alpha, drift, x0, threshold, response=response)) + + +def ddm_moment_cumulant(n, drift, x0, threshold, fun="cumulant", response=1): + """ + Function to generate arbitrary moments or cumulants of DDM by autodiff, + vectorized over drift (but not other arguments currently, TODO). + """ + assert fun in ("moment", "cumulant") + if isinstance(drift, Number): + drift = torch.Tensor([drift]) + if fun == "moment": + deriv_fun = ddm_mgf + elif fun == "cumulant": + deriv_fun = ddm_cgf + else: + raise RuntimeError(f"fun should be moment or cumulant, got {fun}") + for _ in range(n): + deriv_fun = grad(deriv_fun) + moment_fun = partial(deriv_fun, torch.tensor(0.0), response=response) + + moment_fun_vmap = vmap(moment_fun, in_dims=(0, None, None)) + + return moment_fun_vmap(drift, x0, threshold) + + + +class DDMMomemtnMatchTest(unittest.TestCase): + def setUp(self): + np.random.seed(1) + torch.manual_seed(1) + self.f = torch.randn(100) + # things are numerically unstable as drift -> 0. + # in the momentmatch expressions we use limiting expressions + # if drift is too small but we don't have that in moment/cumulant + # so exclude from tests. TODO: can probably improve numerical stability. + self.f = self.f + torch.sign(self.f) * 0.05 + self.relative_x0 = torch.tensor(0.1) + self.t0 = torch.tensor(0.15) + self.rt_dist = TestDDMDistr( + drift=self.f, + threshold=torch.tensor(0.5), + relative_x0=self.relative_x0, + t0=self.t0, + ) + self.x0 = torch.tensor(0.5 * (2 * self.relative_x0- 1)) + + def test_mean(self): + # sanity check mean + expected_yes_mean = ddm_moment_cumulant( + n=1, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_mean = ddm_moment_cumulant( + n=1, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_mean_yes, + expected_yes_mean + self.t0, + atol=global_atol, + ) + ) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_mean_no, + expected_no_mean + self.t0, + atol=global_atol, + ) + ) + + def test_var(self): + # sanity check var + expected_yes_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_var_yes, expected_yes_var, atol=global_atol) + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_var_no, expected_no_var, atol=global_atol) + ) + + def test_skew(self): + # sanity check skew + expected_yes_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + expected_yes_skew = ddm_moment_cumulant( + n=3, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) / (expected_yes_var ** (3 / 2)) + expected_no_skew = ddm_moment_cumulant( + n=3, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) / (expected_no_var ** (3 / 2)) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_skew_yes, expected_yes_skew, atol=global_atol + ) + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_skew_no, expected_no_skew, atol=global_atol) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pairwise_kernel.py b/tests/test_pairwise_kernel.py new file mode 100644 index 000000000..fbb408d89 --- /dev/null +++ b/tests/test_pairwise_kernel.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +import unittest + +import numpy as np +import numpy.testing as npt +import torch +from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad +from aepsych.kernels.pairwisekernel import PairwiseKernel +from gpytorch.kernels import RBFKernel + + +class PairwiseKernelTest(unittest.TestCase): + """ + Basic tests that PairwiseKernel is working + """ + + def setUp(self): + self.latent_kernel = RBFKernel() + self.kernel = PairwiseKernel(self.latent_kernel) + + def test_kernelgrad_pairwise(self): + kernel = PairwiseKernel(RBFKernelPartialObsGrad(), is_partial_obs=True) + x1 = torch.rand(torch.Size([2, 4])) + x2 = torch.rand(torch.Size([2, 4])) + + x1 = torch.cat((x1, torch.zeros(2, 1)), dim=1) + x2 = torch.cat((x2, torch.zeros(2, 1)), dim=1) + + deriv_idx_1 = x1[..., -1][:, None] + deriv_idx_2 = x2[..., -1][:, None] + + a = torch.cat((x1[..., :2], deriv_idx_1), dim=1) + b = torch.cat((x1[..., 2:-1], deriv_idx_1), dim=1) + c = torch.cat((x2[..., :2], deriv_idx_2), dim=1) + d = torch.cat((x2[..., 2:-1], deriv_idx_2), dim=1) + + c12 = kernel.forward(x1, x2).evaluate().detach().numpy() + pwc = ( + ( + kernel.latent_kernel.forward(a, c) + - kernel.latent_kernel.forward(a, d) + - kernel.latent_kernel.forward(b, c) + + kernel.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c12, pwc, atol=1e-6) + + def test_dim_check(self): + """ + Test that we get expected errors. + """ + x1 = torch.zeros(torch.Size([3])) + x2 = torch.zeros(torch.Size([3])) + x3 = torch.zeros(torch.Size([2])) + x4 = torch.zeros(torch.Size([4])) + + self.assertRaises(AssertionError, self.kernel.forward, x1=x1, x2=x2) + + self.assertRaises(AssertionError, self.kernel.forward, x1=x3, x2=x4) + + def test_covar(self): + """ + Test that we get expected covariances + """ + np.random.seed(1) + torch.manual_seed(1) + + x1 = torch.rand(torch.Size([2, 4])) + x2 = torch.rand(torch.Size([2, 4])) + a = x1[..., :2] + b = x1[..., 2:] + c = x2[..., :2] + d = x2[..., 2:] + c12 = self.kernel.forward(x1, x2).evaluate().detach().numpy() + pwc = ( + ( + self.latent_kernel.forward(a, c) + - self.latent_kernel.forward(a, d) + - self.latent_kernel.forward(b, c) + + self.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c12, pwc, atol=1e-6) + + shape = np.array(c12.shape) + npt.assert_equal(shape, np.array([2, 2])) + + x3 = torch.rand(torch.Size([3, 4])) + x4 = torch.rand(torch.Size([6, 4])) + a = x3[..., :2] + b = x3[..., 2:] + c = x4[..., :2] + d = x4[..., 2:] + c34 = self.kernel.forward(x3, x4).evaluate().detach().numpy() + pwc = ( + ( + self.latent_kernel.forward(a, c) + - self.latent_kernel.forward(a, d) + - self.latent_kernel.forward(b, c) + + self.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c34, pwc, atol=1e-6) + + shape = np.array(c34.shape) + npt.assert_equal(shape, np.array([3, 6])) + + def test_latent_diag(self): + """ + g(a, a) = 0 for all a, so K((a, a), (a, a)) = 0 + """ + + np.random.seed(1) + torch.manual_seed(1) + a = torch.rand(torch.Size([2, 2])) + + # should get 0 variance on pairs (a,a) + diag = torch.cat((a, a), dim=1) + diagv = self.kernel.forward(diag, diag).evaluate().detach().numpy() + npt.assert_allclose(diagv, 0.0) + + def test_diag(self): + """ + make sure the diagonal is the right shape + """ + np.random.seed(1) + torch.manual_seed(1) + + x1 = torch.rand(torch.Size([2, 2, 4])) + x2 = torch.rand(torch.Size([2, 2, 4])) + + diag = self.kernel(x1, x2, diag=True) + shape = np.array(diag.shape) + npt.assert_equal(shape, np.array([2, 2])) + + x1 = torch.rand(torch.Size([2, 4])) + x2 = torch.rand(torch.Size([2, 4])) + + diag = self.kernel(x1, x2, diag=True) + shape = np.array(diag.shape) + npt.assert_equal(shape, np.array([2])) + + +if __name__ == "__main__": + unittest.main()