-
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add combine_roulette function (#555)
* Add combine_roulette function * check weights are positive
- Loading branch information
1 parent
5578b81
commit c556b46
Showing
5 changed files
with
177 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
from numpy.testing import assert_almost_equal | ||
from preliz import combine_roulette | ||
from preliz.distributions import BetaScaled, LogNormal, StudentT | ||
|
||
# Hand-built roulette responses used as test inputs. Each tuple follows the
# layout combine_roulette reads: x values, probabilities, number of chips,
# x_min, x_max, followed by the remaining grid settings.
_PDF_3_BINS = (0.32142857142857145, 0.35714285714285715, 0.32142857142857145)
_GRID = (0, 10, 10, 11)

# Each response gets its own list objects so none of them share mutable state.
response0 = ([1.5, 2.5, 3.5], list(_PDF_3_BINS), 28, *_GRID)
response1 = ([7.5, 8.5, 9.5], list(_PDF_3_BINS), 28, *_GRID)
response2 = ([9.5], [1], 10, *_GRID)
# Same as response2 except for the last grid entry (14 instead of 11).
response3 = ([9.5], [1], 10, 0, 10, 10, 14)
|
||
|
||
@pytest.mark.parametrize(
    ("responses", "weights", "dist_names", "params", "result"),
    [
        # Default candidate distributions, equal weights.
        ([response0, response1], [0.5, 0.5], None, None, BetaScaled(1.2, 1, 0, 10)),
        # Restricted candidate list plus extra fixed parameters.
        (
            [response0, response1],
            [0.5, 0.5],
            ["Beta", "StudentT"],
            "TruncatedNormal(lower=0), StudentT(nu=1000)",
            StudentT(1000, 5.5, 3.1),
        ),
        # Weights that do not sum to one.
        ([response0, response2], [1, 1], None, None, LogNormal(1.1, 0.6)),
    ],
)
def test_combine_roulette(responses, weights, dist_names, params, result):
    """Check the combined fit matches the expected parameters to one decimal."""
    fitted = combine_roulette(
        responses,
        weights=weights,
        dist_names=dist_names,
        params=params,
    )
    assert_almost_equal(fitted.params, result.params, decimal=1)
|
||
|
||
def test_combine_roulette_error():
    """Combining responses whose grid entries differ must raise ValueError."""
    mismatched_grids = [response0, response3]
    with pytest.raises(ValueError):
        combine_roulette(mismatched_grids)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from .beta_mode import beta_mode | ||
from .combine_roulette import combine_roulette | ||
from .maxent import maxent | ||
from .mle import mle | ||
from .quartile import quartile | ||
from .quartile_int import quartile_int | ||
from .roulette import Roulette | ||
|
||
__all__ = ["beta_mode", "maxent", "mle", "Roulette", "quartile", "quartile_int"] | ||
__all__ = ["beta_mode", "combine_roulette", "maxent", "mle", "Roulette", "quartile", "quartile_int"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
|
||
from preliz.internal.distribution_helper import process_extra | ||
from preliz.internal.optimization import fit_to_epdf, get_distributions | ||
|
||
|
||
def combine_roulette(responses, weights=None, dist_names=None, params=None):
    """
    Combine multiple elicited distributions into a single distribution.

    Parameters
    ----------
    responses : list of tuples
        Typically, each tuple comes from the ``.inputs`` attribute of a ``Roulette`` object and
        represents a single elicited distribution. At least one response is required and all
        of them must share the same grid.
    weights : array-like, optional
        Weights for each elicited distribution. Defaults to None, i.e. equal weights.
        There must be exactly one strictly positive weight per response.
        The sum of the weights must be equal to 1, otherwise it will be normalized.
    dist_names: list
        List of distributions names to be used in the elicitation.
        Defaults to ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"].
    params : str, optional
        Extra parameters to be passed to the distributions. The format is a string with the
        PreliZ's distribution name followed by the argument to fix.
        For example: "TruncatedNormal(lower=0), StudentT(nu=8)".

    Returns
    -------
    PreliZ distribution

    Raises
    ------
    ValueError
        If ``responses`` is empty, if the number of weights differs from the number of
        responses, if any weight is not strictly positive, or if the responses were not
        elicited on the same grid.
    """
    if params is not None:
        extra_pros = process_extra(params)
    else:
        extra_pros = []

    n_responses = len(responses)
    if n_responses == 0:
        raise ValueError("At least one response is required.")

    if weights is None:
        weights = np.full(n_responses, 1 / n_responses)
    else:
        weights = np.array(weights, dtype=float)
        # zip() below would silently drop unmatched responses or weights, so a
        # length mismatch is rejected explicitly.
        if weights.shape != (n_responses,):
            raise ValueError("The number of weights must match the number of responses.")

    # Written as "not all > 0" so that NaN weights are rejected as well.
    if not np.all(weights > 0):
        raise ValueError("The weights must be positive.")

    weights /= weights.sum()

    # Everything after the chip count describes the grid; compare as tuples so
    # that list and tuple responses with equal entries are treated alike.
    reference_grid = tuple(responses[0][3:])
    if not all(tuple(records[3:]) == reference_grid for records in responses):
        raise ValueError(
            "To combine single elicitation instances, the grid should be the same for all of them."
        )

    if dist_names is None:
        dist_names = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

    # Accumulate the weighted mass per x value; x values shared by several
    # responses are added together.
    new_pdf = {}
    for records, weight in zip(responses, weights):
        chips = records[2]
        for x_i, pdf_i in zip(records[0], records[1]):
            new_pdf[x_i] = new_pdf.get(x_i, 0) + pdf_i * weight * chips

    # Normalize to a proper pmf and compute its mean in the same pass.
    total = sum(new_pdf.values())
    mean = 0
    for x_i, pdf_i in new_pdf.items():
        val = pdf_i / total
        mean += x_i * val
        new_pdf[x_i] = val

    var = 0
    for x_i, pdf_i in new_pdf.items():
        var += pdf_i * (x_i - mean) ** 2
    std = var**0.5

    # The grid check above guarantees every response shares x_min and x_max.
    x_min = responses[0][3]
    x_max = responses[0][4]

    fitted_dist = fit_to_epdf(
        get_distributions(dist_names),
        list(new_pdf.keys()),
        list(new_pdf.values()),
        mean,
        std,
        x_min,
        x_max,
        extra_pros,
    )

    return fitted_dist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters