Skip to content

Commit

Permalink
NF: Allow retrieval of marginal posterior PDFs
Browse files Browse the repository at this point in the history
Closes GH-22.
  • Loading branch information
hoechenberger committed Jul 21, 2019
1 parent aef5747 commit 1f34148
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Current development branch
--------------------------
* Allow retrieval of marginal posterior PDFs via `QuestPlus.marginal_posterior`
22 changes: 22 additions & 0 deletions questplus/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,28 @@ def param_estimate(self) -> dict:

return param_estimates

@property
def marginal_posterior(self) -> dict:
"""
Retrieve the a dictionary of marginal posterior probability
density functions (PDFs).
Returns
-------
A dictionary of marginal PDFs, where the dictionary keys correspond to
the parameter names.
"""
marginal_posterior = dict()
for param_name in self.param_domain.keys():
marginalized_out_params = list(self.param_domain.keys())
marginalized_out_params.remove(param_name)
marginal_posterior[param_name] = (self.posterior
.sum(dim=marginalized_out_params)
.values)

return marginal_posterior

def to_json(self) -> str:
"""
Dump this `QuestPlus` instance as a JSON string which can be loaded
Expand Down
32 changes: 32 additions & 0 deletions questplus/tests/test_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,37 @@ def test_json():
q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct'))


def test_marginal_posterior():
contrasts = np.arange(-40, 0 + 1)
slope = np.arange(2, 5 + 1)
lower_asymptote = (0.5,)
lapse_rate = np.arange(0, 0.04 + 0.01, 0.01)

stim_domain = dict(intensity=contrasts)
param_domain = dict(threshold=contrasts, slope=slope,
lower_asymptote=lower_asymptote, lapse_rate=lapse_rate)
outcome_domain = dict(response=['Correct', 'Incorrect'])

func = 'weibull'
stim_scale = 'dB'

q = QuestPlus(stim_domain=stim_domain,
param_domain=param_domain,
outcome_domain=outcome_domain,
func=func, stim_scale=stim_scale)

marginal_posterior = q.marginal_posterior

assert np.allclose(marginal_posterior['threshold'],
np.ones(len(contrasts)) / len(contrasts))
assert np.allclose(marginal_posterior['slope'],
np.ones(len(slope)) / len(slope))
assert np.allclose(marginal_posterior['lower_asymptote'],
np.ones(len(lower_asymptote)) / len(lower_asymptote))
assert np.allclose(marginal_posterior['lapse_rate'],
np.ones(len(lapse_rate)) / len(lapse_rate))


if __name__ == '__main__':
test_threshold()
test_threshold_slope()
Expand All @@ -500,3 +531,4 @@ def test_json():
test_weibull()
test_eq()
test_json()
test_marginal_posterior()

0 comments on commit 1f34148

Please sign in to comment.