Skip to content

Commit 58b51a6

Browse files
authored
Merge branch 'main' into doctor_visits_refactor_for_speed
2 parents b52d80a + 84d0597 commit 58b51a6

File tree

8 files changed

+80
-43
lines changed

8 files changed

+80
-43
lines changed

_delphi_utils_python/delphi_utils/weekday.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,19 @@ class Weekday:
1212
"""Class to handle weekday effects."""
1313

1414
@staticmethod
15-
def get_params(data, denominator_col, numerator_cols, date_col, scales, logger):
15+
def get_params(data, denominator_col, numerator_cols, date_col, scales, logger, solver_override=None):
1616
r"""Fit weekday correction for each col in numerator_cols.
1717
1818
Return a matrix of parameters: the entire vector of betas, for each time
1919
series column in the data.
20+
21+
solver: Historically used "ECOS" but due to numerical stability issues, "CLARABEL"
22+
(introduced in cvxpy 1.3)is now the default solver in cvxpy 1.5.
2023
"""
24+
if solver_override is None:
25+
solver = cp.CLARABEL
26+
else:
27+
solver = solver_override
2128
tmp = data.reset_index()
2229
denoms = tmp.groupby(date_col).sum()[denominator_col]
2330
nums = tmp.groupby(date_col).sum()[numerator_cols]
@@ -35,7 +42,7 @@ def get_params(data, denominator_col, numerator_cols, date_col, scales, logger):
3542

3643
# Loop over the available numerator columns and smooth each separately.
3744
for i in range(nums.shape[1]):
38-
result = Weekday._fit(X, scales, npnums[:, i], npdenoms)
45+
result = Weekday._fit(X, scales, npnums[:, i], npdenoms, solver)
3946
if result is None:
4047
logger.error("Unable to calculate weekday correction")
4148
else:
@@ -44,7 +51,18 @@ def get_params(data, denominator_col, numerator_cols, date_col, scales, logger):
4451
return params
4552

4653
@staticmethod
47-
def _fit(X, scales, npnums, npdenoms):
54+
def get_params_legacy(data, denominator_col, numerator_cols, date_col, scales, logger):
55+
r"""
56+
Preserves older default behavior of using the ECOS solver.
57+
58+
NOTE: "ECOS" solver will not be installed by default as of cvxpy 1.6
59+
"""
60+
return Weekday.get_params(
61+
data, denominator_col, numerator_cols, date_col, scales, logger, solver_override=cp.ECOS
62+
)
63+
64+
@staticmethod
65+
def _fit(X, scales, npnums, npdenoms, solver):
4866
r"""Correct a signal estimated as numerator/denominator for weekday effects.
4967
5068
The ordinary estimate would be numerator_t/denominator_t for each time point
@@ -78,6 +96,8 @@ def _fit(X, scales, npnums, npdenoms):
7896
7997
ll = (numerator * (X*b + log(denominator)) - sum(exp(X*b) + log(denominator)))
8098
/ num_days
99+
100+
solver: Historically use "ECOS" but due to numerical issues, "CLARABEL" is now default.
81101
"""
82102
b = cp.Variable((X.shape[1]))
83103

@@ -93,7 +113,7 @@ def _fit(X, scales, npnums, npdenoms):
93113
for scale in scales:
94114
try:
95115
prob = cp.Problem(cp.Minimize((-ll + lmbda * penalty) / scale))
96-
_ = prob.solve(solver=cp.CLARABEL)
116+
_ = prob.solve(solver=solver)
97117
return b.value
98118
except SolverError:
99119
# If the magnitude of the objective function is too large, an error is

_delphi_utils_python/tests/test_weekday.py

+25
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@ def test_get_params(self):
2525
-0.81464459, -0.76322013, -0.7667211,-0.8251475]])
2626
assert np.allclose(result, expected_result)
2727

28+
def test_get_params_legacy(self):
29+
TEST_LOGGER = logging.getLogger()
30+
31+
result = Weekday.get_params_legacy(self.TEST_DATA, "den", ["num"], "date", [1], TEST_LOGGER)
32+
print(result)
33+
expected_result = [
34+
-0.05993665,
35+
-0.0727396,
36+
-0.05618517,
37+
0.0343405,
38+
0.12534997,
39+
0.04561813,
40+
-2.27669028,
41+
-1.89564374,
42+
-1.5695407,
43+
-1.29838116,
44+
-1.08216513,
45+
-0.92089259,
46+
-0.81456355,
47+
-0.76317802,
48+
-0.76673598,
49+
-0.82523745,
50+
]
51+
assert np.allclose(result, expected_result)
52+
2853
def test_calc_adjustment_with_zero_parameters(self):
2954
params = np.array([[0, 0, 0, 0, 0, 0, 0]])
3055

ansible/templates/sir_complainsalot-params-prod.json.j2

-9
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
1313
"retired-signals": ["smoothed_covid19","smoothed_adj_covid19"]
1414
},
15-
"chng": {
16-
"max_age": 6,
17-
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
18-
"retired-signals": ["7dav_outpatient_covid","7dav_inpatient_covid"]
19-
},
2015
"google-symptoms": {
2116
"max_age": 6,
2217
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
@@ -47,10 +42,6 @@
4742
"max_age":19,
4843
"maintainers": []
4944
},
50-
"hhs": {
51-
"max_age":15,
52-
"maintainers": []
53-
},
5445
"nssp": {
5546
"max_age":13,
5647
"maintainers": []

claims_hosp/delphi_claims_hosp/update_indicator.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# third party
1414
import numpy as np
1515
import pandas as pd
16-
from delphi_utils import GeoMapper
1716

1817
# first party
19-
from delphi_utils import Weekday
18+
from delphi_utils import GeoMapper, Weekday
19+
2020
from .config import Config, GeoConstants
21-
from .load_data import load_data
2221
from .indicator import ClaimsHospIndicator
22+
from .load_data import load_data
2323

2424

2525
class ClaimsHospIndicatorUpdater:
@@ -152,15 +152,18 @@ def update_indicator(self, input_filepath, outpath, logger):
152152
data_frame = self.geo_reindex(data)
153153

154154
# handle if we need to adjust by weekday
155-
wd_params = Weekday.get_params(
156-
data_frame,
157-
"den",
158-
["num"],
159-
Config.DATE_COL,
160-
[1, 1e5],
161-
logger,
162-
) if self.weekday else None
163-
155+
wd_params = (
156+
Weekday.get_params_legacy(
157+
data_frame,
158+
"den",
159+
["num"],
160+
Config.DATE_COL,
161+
[1, 1e5],
162+
logger,
163+
)
164+
if self.weekday
165+
else None
166+
)
164167
# run fitting code (maybe in parallel)
165168
rates = {}
166169
std_errs = {}

claims_hosp/setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"pylint==2.8.3",
1414
"pytest-cov",
1515
"pytest",
16+
"cvxpy<1.6",
1617
]
1718

1819
setup(

doctor_visits/delphi_doctor_visits/update_sensor.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# first party
2020
from delphi_utils import Weekday
21+
2122
from .config import Config
2223
from .geo_maps import GeoMaps
2324
from .sensor import DoctorVisitsSensor
@@ -51,15 +52,19 @@ def update_sensor(
5152
(burn_in_dates >= startdate) & (burn_in_dates <= enddate))[0][:len(sensor_dates)]
5253

5354
# handle if we need to adjust by weekday
54-
params = Weekday.get_params(
55-
data,
56-
"Denominator",
57-
Config.CLI_COLS + Config.FLU1_COL,
58-
Config.DATE_COL,
59-
[1, 1e5, 1e10, 1e15],
60-
logger,
61-
) if weekday else None
62-
if weekday and np.any(np.all(params == 0,axis=1)):
55+
params = (
56+
Weekday.get_params(
57+
data,
58+
"Denominator",
59+
Config.CLI_COLS + Config.FLU1_COL,
60+
Config.DATE_COL,
61+
[1, 1e5, 1e10, 1e15],
62+
logger,
63+
)
64+
if weekday
65+
else None
66+
)
67+
if weekday and np.any(np.all(params == 0, axis=1)):
6368
# Weekday correction failed for at least one count type
6469
return None
6570

doctor_visits/setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"pytest",
1313
"scikit-learn",
1414
"dask",
15+
"cvxpy>=1.5",
1516
]
1617

1718
setup(

sir_complainsalot/params.json.template

-9
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
1313
"retired-signals": ["smoothed_covid19","smoothed_adj_covid19"]
1414
},
15-
"chng": {
16-
"max_age": 6,
17-
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
18-
"retired-signals": ["7dav_outpatient_covid","7dav_inpatient_covid"]
19-
},
2015
"google-symptoms": {
2116
"max_age": 6,
2217
"maintainers": ["U01AP8GSWG3","U01069KCRS7"],
@@ -47,10 +42,6 @@
4742
"max_age":19,
4843
"maintainers": []
4944
},
50-
"hhs": {
51-
"max_age":15,
52-
"maintainers": []
53-
},
5445
"nssp": {
5546
"max_age":13,
5647
"maintainers": []

0 commit comments

Comments
 (0)