Skip to content

Commit

Permalink
Merge pull request #2 from ihmeuw-msca/feature/objective-functions
Browse files Browse the repository at this point in the history
Feature/objective functions
  • Loading branch information
mbi6245 authored Aug 20, 2024
2 parents ac0b758 + a0f2567 commit 7d39aa7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
18 changes: 14 additions & 4 deletions src/ensemble/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from ensemble.distributions import distribution_dict

# from jaxopt import ScipyBoundedMinimize


class EnsembleModel:
"""Ensemble distribution object that provides limited functionality similar
Expand Down Expand Up @@ -277,9 +279,13 @@ def objective_func(self, vec: np.ndarray, objective: str) -> float:
NotImplementedError
because the other ones havent been implemented yet lol
"""
if objective is not None:
raise NotImplementedError
return linalg.norm(vec, 2) ** 2
match objective:
case "L1":
return linalg.norm(vec, 1)
case "L2":
return linalg.norm(vec, 2) ** 2
case "KS":
return np.max(np.abs(vec))

def ensemble_func(
self, weights: List[float], ecdf: np.ndarray, cdfs: np.ndarray
Expand Down Expand Up @@ -342,12 +348,16 @@ def fit(self, data: npt.ArrayLike) -> EnsembleResult:
# initialize equal weights for all dists and optimize
# NOTE(review): weights start uniform (1/k each) and each is bounded to
# [0, 1], but nothing visible here constrains them to sum to 1 —
# presumably ensemble_func or downstream normalization handles that;
# verify against the rest of fit().
initial_guess = np.zeros(num_distributions) + 1 / num_distributions
bounds = tuple((0, 1) for i in range(num_distributions))
# TODO: IMPLEMENT WITH JAX INSTEAD
# minimizer_result = ScipyBoundedMinimize(
# fun=self.ensemble_func, args=(ecdf, cdfs), method="l-bfgs-b"
# ).run(initial_guess, bounds=bounds)
# fitted_weights = minimizer_result.params
# scipy bounded minimization of the ensemble objective; (ecdf, cdfs) are
# forwarded to self.ensemble_func as extra arguments.
minimizer_result = opt.minimize(
    fun=self.ensemble_func,
    x0=initial_guess,
    args=(ecdf, cdfs),
    bounds=bounds,
    # options={"disp": True},
)
# optimal weight vector found by the minimizer
fitted_weights = minimizer_result.x

Expand Down
37 changes: 32 additions & 5 deletions tests/test_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
variance=1,
).rvs(size=100)

# NOTE(review): unlike the other *_DRAWS fixtures, this binds the
# EnsembleModel object itself rather than samples — presumably a missing
# .rvs(size=...) call; it is also unused by the tests below. The weights
# sum to 0.9, not 1 — confirm whether EnsembleModel normalizes weights.
ENSEMBLE_POS_DRAWS2 = EnsembleModel(
    distributions=["exponential", "lognormal", "fisk"],
    weights=[0.3, 0.5, 0.1],
    mean=40,
    variance=5,
)
# ENSEMBLE_POS_DRAWS = EnsembleModel(
# distributions=["exponential", "lognormal"],
# weights=[0.5, 0.5],
Expand All @@ -25,26 +31,47 @@


def test_1_dist():
    """A lone correct candidate must absorb all of the ensemble weight."""
    fitter = EnsembleFitter(["normal"], "L2")
    result = fitter.fit(STD_NORMAL_DRAWS)
    print(result.weights)
    assert np.isclose(result.weights[0], 1)

    # Even alongside a spurious second candidate, the true distribution
    # should still win essentially all of the weight.
    two_candidate_fitter = EnsembleFitter(["normal", "gumbel"], "L2")
    result = two_candidate_fitter.fit(STD_NORMAL_DRAWS)
    print(result.weights)
    assert np.allclose(result.weights, [1, 0])


def test_2_real_line_dists():
    """Recover the generating weights for two real-line distributions."""
    fitter = EnsembleFitter(["normal", "gumbel"], "L2")
    fitted = fitter.fit(ENSEMBLE_RL_DRAWS)
    print(fitted.weights)
    assert np.allclose(fitted.weights, [0.7, 0.3])


def test_2_positive_dists_L1():
    """L1 objective recovers equal weights for two positive distributions."""
    fitter = EnsembleFitter(["exponential", "lognormal"], "L1")
    fitted = fitter.fit(ENSEMBLE_POS_DRAWS)
    print(fitted.weights)
    assert np.allclose(fitted.weights, [0.5, 0.5])


def test_2_positive_dists_L2():
    """L2 objective recovers equal weights for two positive distributions."""
    fitter = EnsembleFitter(["exponential", "lognormal"], "L2")
    fitted = fitter.fit(ENSEMBLE_POS_DRAWS)
    print(fitted.weights)
    assert np.allclose(fitted.weights, [0.5, 0.5])


def test_2_positive_dists_KS():
    """KS objective recovers equal weights for two positive distributions."""
    fitter = EnsembleFitter(["exponential", "lognormal"], "KS")
    fitted = fitter.fit(ENSEMBLE_POS_DRAWS)
    print(fitted.weights)
    assert np.allclose(fitted.weights, [0.5, 0.5])


def test_3_positive_dists_KS():
    """KS objective assigns near-zero weight to a superfluous third
    candidate (fisk) when two distributions generated the data."""
    fitter = EnsembleFitter(["exponential", "lognormal", "fisk"], "KS")
    fitted = fitter.fit(ENSEMBLE_POS_DRAWS)
    print(fitted.weights)
    assert np.allclose(fitted.weights, [0.9, 0.05, 0.05])

0 comments on commit 7d39aa7

Please sign in to comment.