
Commit fc2c58d

Merge pull request #1981 from cmu-delphi/optimize_with_dask
Optimize with dask
2 parents 9740899 + d1ee4ce commit fc2c58d

10 files changed: +148 -96 lines changed

doctor_visits/delphi_doctor_visits/config.py

Lines changed: 6 additions & 2 deletions
@@ -19,13 +19,17 @@ class Config:
     # data columns
     CLI_COLS = ["Covid_like", "Flu_like", "Mixed"]
     FLU1_COL = ["Flu1"]
-    COUNT_COLS = ["Denominator"] + FLU1_COL + CLI_COLS
+    COUNT_COLS = CLI_COLS + FLU1_COL + ["Denominator"]
     DATE_COL = "ServiceDate"
     GEO_COL = "PatCountyFIPS"
     AGE_COL = "PatAgeGroup"
     HRR_COLS = ["Pat HRR Name", "Pat HRR ID"]
+    # as of 2020-05-11, input file expected to have 10 columns
+    # id cols: ServiceDate, PatCountyFIPS, PatAgeGroup, Pat HRR ID/Pat HRR Name
+    # value cols: Denominator, Covid_like, Flu_like, Flu1, Mixed
     ID_COLS = [DATE_COL] + [GEO_COL] + HRR_COLS + [AGE_COL]
-    FILT_COLS = ID_COLS + COUNT_COLS
+    # drop HRR columns - unused for now since we assign HRRs by FIPS
+    FILT_COLS = [DATE_COL] + [GEO_COL] + [AGE_COL] + COUNT_COLS
     DTYPES = {
         "ServiceDate": str,
         "PatCountyFIPS": str,
doctor_visits/delphi_doctor_visits/process_data.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import dask.dataframe as dd
+from datetime import datetime
+import numpy as np
+import pandas as pd
+from pathlib import Path
+
+from .config import Config
+
+
+def write_to_csv(output_df: pd.DataFrame, geo_level: str, se:bool, out_name: str, logger, output_path="."):
+    """Write sensor values to csv.
+
+    Args:
+        output_dict: dictionary containing sensor rates, se, unique dates, and unique geo_id
+        geo_level: geographic resolution, one of ["county", "state", "msa", "hrr", "nation", "hhs"]
+        se: boolean to write out standard errors, if true, use an obfuscated name
+        out_name: name of the output file
+        output_path: outfile path to write the csv (default is current directory)
+    """
+    if se:
+        logger.info(f"========= WARNING: WRITING SEs TO {out_name} =========")
+
+    out_n = 0
+    for d in set(output_df["date"]):
+        filename = "%s/%s_%s_%s.csv" % (output_path,
+                                        (d + Config.DAY_SHIFT).strftime("%Y%m%d"),
+                                        geo_level,
+                                        out_name)
+        single_date_df = output_df[output_df["date"] == d]
+        with open(filename, "w") as outfile:
+            outfile.write("geo_id,val,se,direction,sample_size\n")
+
+            for line in single_date_df.itertuples():
+                geo_id = line.geo_id
+                sensor = 100 * line.val  # report percentages
+                se_val = 100 * line.se
+                assert not np.isnan(sensor), "sensor value is nan, check pipeline"
+                assert sensor < 90, f"strangely high percentage {geo_id, sensor}"
+                if not np.isnan(se_val):
+                    assert se_val < 5, f"standard error suspiciously high! investigate {geo_id}"
+
+                if se:
+                    assert sensor > 0 and se_val > 0, "p=0, std_err=0 invalid"
+                    outfile.write(
+                        "%s,%f,%s,%s,%s\n" % (geo_id, sensor, se_val, "NA", "NA"))
+                else:
+                    # for privacy reasons we will not report the standard error
+                    outfile.write(
+                        "%s,%f,%s,%s,%s\n" % (geo_id, sensor, "NA", "NA", "NA"))
+                out_n += 1
+    logger.debug(f"wrote {out_n} rows for {geo_level}")
+
+
+def csv_to_df(filepath: str, startdate: datetime, enddate: datetime, dropdate: datetime, logger) -> pd.DataFrame:
+    '''
+    Reads csv using Dask and filters out based on date range and currently unused column,
+    then converts back into pandas dataframe.
+    Parameters
+    ----------
+    filepath: path to the aggregated doctor-visits data
+    startdate: first sensor date (YYYY-mm-dd)
+    enddate: last sensor date (YYYY-mm-dd)
+    dropdate: data drop date (YYYY-mm-dd)
+
+    -------
+    '''
+    filepath = Path(filepath)
+    logger.info(f"Processing {filepath}")
+
+    ddata = dd.read_csv(
+        filepath,
+        compression="gzip",
+        dtype=Config.DTYPES,
+        blocksize=None,
+    )
+
+    ddata = ddata.dropna()
+    # rename inconsistent column names to match config column names
+    ddata = ddata.rename(columns=Config.DEVIANT_COLS_MAP)
+
+    ddata = ddata[Config.FILT_COLS]
+    ddata[Config.DATE_COL] = dd.to_datetime(ddata[Config.DATE_COL])
+
+    # restrict to training start and end date
+    startdate = startdate - Config.DAY_SHIFT
+
+    assert startdate > Config.FIRST_DATA_DATE, "Start date <= first day of data"
+    assert startdate < enddate, "Start date >= end date"
+    assert enddate <= dropdate, "End date > drop date"
+
+    date_filter = ((ddata[Config.DATE_COL] >= Config.FIRST_DATA_DATE) & (ddata[Config.DATE_COL] < dropdate))
+
+    df = ddata[date_filter].compute()
+
+    # aggregate age groups (so data is unique by service date and FIPS)
+    df = df.groupby([Config.DATE_COL, Config.GEO_COL]).sum(numeric_only=True).reset_index()
+    assert np.sum(df.duplicated()) == 0, "Duplicates after age group aggregation"
+    assert (df[Config.COUNT_COLS] >= 0).all().all(), "Counts must be nonnegative"
+
+    logger.info(f"Done processing {filepath}")
+    return df
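
The new process_data module moves all file handling out of update_sensor: csv_to_df builds a lazy Dask task graph for the CSV parse, dropna, column selection, and date filter, and only the already-filtered result is materialized into pandas by compute(). blocksize=None is needed because a gzip file cannot be split into independent blocks, so Dask reads it as a single partition. A standalone sketch of that pattern follows, with a hypothetical file name and a reduced column set (not the indicator's real input):

# Standalone sketch of the lazy-read-then-compute pattern used by csv_to_df.
# "claims.csv.gz" and the columns here are hypothetical stand-ins.
import dask.dataframe as dd
import pandas as pd

ddata = dd.read_csv(
    "claims.csv.gz",
    compression="gzip",
    blocksize=None,  # gzip is not splittable, so the file is read as one partition
    dtype={"ServiceDate": str, "PatCountyFIPS": str},
)

# Each step below only extends the task graph; nothing has been read yet.
ddata = ddata.dropna()
ddata = ddata[["ServiceDate", "PatCountyFIPS", "Denominator"]]
ddata["ServiceDate"] = dd.to_datetime(ddata["ServiceDate"])
ddata = ddata[ddata["ServiceDate"] < pd.Timestamp("2020-02-06")]

# compute() runs the graph and hands back a plain pandas DataFrame,
# which is what the downstream sensor code expects.
df = ddata.compute()
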

doctor_visits/delphi_doctor_visits/run.py

Lines changed: 7 additions & 5 deletions
@@ -14,7 +14,8 @@
 from delphi_utils import get_structured_logger
 
 # first party
-from .update_sensor import update_sensor, write_to_csv
+from .update_sensor import update_sensor
+from .process_data import csv_to_df, write_to_csv
 from .download_claims_ftp_files import download
 from .get_latest_claims_name import get_latest_filename
 
@@ -85,6 +86,7 @@ def run_module(params): # pylint: disable=too-many-statements
     ## geographies
     geos = ["state", "msa", "hrr", "county", "hhs", "nation"]
 
+    claims_df = csv_to_df(claims_file, startdate_dt, enddate_dt, dropdate_dt, logger)
 
     ## print out other vars
     logger.info("outpath:\t\t%s", export_dir)
@@ -103,10 +105,10 @@ def run_module(params): # pylint: disable=too-many-statements
         else:
             logger.info("starting %s, no adj", geo)
         sensor = update_sensor(
-            filepath=claims_file,
-            startdate=startdate,
-            enddate=enddate,
-            dropdate=dropdate,
+            data=claims_df,
+            startdate=startdate_dt,
+            enddate=enddate_dt,
+            dropdate=dropdate_dt,
             geo=geo,
             parallel=params["indicator"]["parallel"],
             weekday=weekday,
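
This run.py change is the heart of the optimization: the claims file is parsed and cleaned once, before the geography loop, and the same DataFrame is handed to update_sensor for every geo, whereas previously each of the six geographies triggered a fresh read and re-processing of the raw CSV inside update_sensor. A condensed sketch of the new control flow (adjustment branches, exception handling, and the export step of the real run_module are omitted):

# Condensed sketch of the reshaped run_module flow; not the full module.
geos = ["state", "msa", "hrr", "county", "hhs", "nation"]

# parse, filter, and aggregate the raw claims file exactly once
claims_df = csv_to_df(claims_file, startdate_dt, enddate_dt, dropdate_dt, logger)

for geo in geos:
    sensor = update_sensor(
        data=claims_df,           # shared, already-cleaned DataFrame
        startdate=startdate_dt,
        enddate=enddate_dt,
        dropdate=dropdate_dt,
        geo=geo,
        parallel=params["indicator"]["parallel"],
        weekday=weekday,
        se=se,
        logger=logger,
    )
    # each sensor frame is then written out via write_to_csv (now imported from process_data)
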

doctor_visits/delphi_doctor_visits/update_sensor.py

Lines changed: 6 additions & 85 deletions
@@ -9,9 +9,8 @@
 """
 
 # standard packages
-from datetime import timedelta
+from datetime import timedelta, datetime
 from multiprocessing import Pool, cpu_count
-from pathlib import Path
 
 # third party
 import numpy as np
@@ -24,57 +23,14 @@
 from .sensor import DoctorVisitsSensor
 
 
-def write_to_csv(output_df: pd.DataFrame, geo_level, se, out_name, logger, output_path="."):
-    """Write sensor values to csv.
-
-    Args:
-        output_dict: dictionary containing sensor rates, se, unique dates, and unique geo_id
-        se: boolean to write out standard errors, if true, use an obfuscated name
-        out_name: name of the output file
-        output_path: outfile path to write the csv (default is current directory)
-    """
-    if se:
-        logger.info(f"========= WARNING: WRITING SEs TO {out_name} =========")
-
-    out_n = 0
-    for d in set(output_df["date"]):
-        filename = "%s/%s_%s_%s.csv" % (output_path,
-                                        (d + Config.DAY_SHIFT).strftime("%Y%m%d"),
-                                        geo_level,
-                                        out_name)
-        single_date_df = output_df[output_df["date"] == d]
-        with open(filename, "w") as outfile:
-            outfile.write("geo_id,val,se,direction,sample_size\n")
-
-            for line in single_date_df.itertuples():
-                geo_id = line.geo_id
-                sensor = 100 * line.val  # report percentages
-                se_val = 100 * line.se
-                assert not np.isnan(sensor), "sensor value is nan, check pipeline"
-                assert sensor < 90, f"strangely high percentage {geo_id, sensor}"
-                if not np.isnan(se_val):
-                    assert se_val < 5, f"standard error suspiciously high! investigate {geo_id}"
-
-                if se:
-                    assert sensor > 0 and se_val > 0, "p=0, std_err=0 invalid"
-                    outfile.write(
-                        "%s,%f,%s,%s,%s\n" % (geo_id, sensor, se_val, "NA", "NA"))
-                else:
-                    # for privacy reasons we will not report the standard error
-                    outfile.write(
-                        "%s,%f,%s,%s,%s\n" % (geo_id, sensor, "NA", "NA", "NA"))
-                out_n += 1
-    logger.debug(f"wrote {out_n} rows for {geo_level}")
-
-
 def update_sensor(
-    filepath, startdate, enddate, dropdate, geo, parallel,
-    weekday, se, logger
+    data:pd.DataFrame, startdate:datetime, enddate:datetime, dropdate:datetime, geo:str, parallel: bool,
+    weekday:bool, se:bool, logger
 ):
     """Generate sensor values.
 
     Args:
-        filepath: path to the aggregated doctor-visits data
+        data: dataframe of the cleaned claims file
         startdate: first sensor date (YYYY-mm-dd)
         enddate: last sensor date (YYYY-mm-dd)
         dropdate: data drop date (YYYY-mm-dd)
@@ -84,45 +40,10 @@ def update_sensor(
         se: boolean to write out standard errors, if true, use an obfuscated name
        logger: the structured logger
     """
-    # as of 2020-05-11, input file expected to have 10 columns
-    # id cols: ServiceDate, PatCountyFIPS, PatAgeGroup, Pat HRR ID/Pat HRR Name
-    # value cols: Denominator, Covid_like, Flu_like, Flu1, Mixed
-    filename = Path(filepath).name
-    data = pd.read_csv(
-        filepath,
-        dtype=Config.DTYPES,
-    )
-    logger.info(f"Starting processing {filename} ")
-    data.rename(columns=Config.DEVIANT_COLS_MAP, inplace=True)
-    data = data[Config.FILT_COLS]
-    data[Config.DATE_COL] = data[Config.DATE_COL].apply(pd.to_datetime)
-    logger.info(f"finished processing {filename} ")
-    assert (
-        np.sum(data.duplicated(subset=Config.ID_COLS)) == 0
-    ), "Duplicated data! Check the input file"
-
-    # drop HRR columns - unused for now since we assign HRRs by FIPS
-    data.drop(columns=Config.HRR_COLS, inplace=True)
-    data.dropna(inplace=True)  # drop rows with any missing entries
-
-    # aggregate age groups (so data is unique by service date and FIPS)
-    data = data.groupby([Config.DATE_COL, Config.GEO_COL]).sum(numeric_only=True).reset_index()
-    assert np.sum(data.duplicated()) == 0, "Duplicates after age group aggregation"
-    assert (data[Config.COUNT_COLS] >= 0).all().all(), "Counts must be nonnegative"
-
-    ## collect dates
-    # restrict to training start and end date
+
     drange = lambda s, e: np.array([s + timedelta(days=x) for x in range((e - s).days)])
-    startdate = pd.to_datetime(startdate) - Config.DAY_SHIFT
-    burnindate = startdate - Config.DAY_SHIFT
-    enddate = pd.to_datetime(enddate)
-    dropdate = pd.to_datetime(dropdate)
-    assert startdate > Config.FIRST_DATA_DATE, "Start date <= first day of data"
-    assert startdate < enddate, "Start date >= end date"
-    assert enddate <= dropdate, "End date > drop date"
-    data = data[(data[Config.DATE_COL] >= Config.FIRST_DATA_DATE) & \
-                (data[Config.DATE_COL] < dropdate)]
     fit_dates = drange(Config.FIRST_DATA_DATE, dropdate)
+    burnindate = startdate - Config.DAY_SHIFT
     burn_in_dates = drange(burnindate, dropdate)
     sensor_dates = drange(startdate, enddate)
     # The ordering of sensor dates corresponds to the order of burn-in dates
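
With reading, renaming, filtering, and age-group aggregation all relocated to process_data.py, update_sensor now takes an already-cleaned DataFrame and real datetime objects, so the pd.to_datetime conversions and input assertions disappear from this function; only the date-range construction remains. The retained drange helper produces a day-by-day array that excludes its end date, which is why sensor_dates = drange(startdate, enddate) stops the day before enddate. A small worked example with arbitrary dates:

# Worked example of the drange helper kept in update_sensor; dates are arbitrary.
from datetime import datetime, timedelta
import numpy as np

drange = lambda s, e: np.array([s + timedelta(days=x) for x in range((e - s).days)])

sensor_dates = drange(datetime(2020, 2, 4), datetime(2020, 2, 7))
# -> three entries: 2020-02-04, 2020-02-05, 2020-02-06 (the end date 2020-02-07 is excluded)
print(len(sensor_dates))  # 3
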

doctor_visits/setup.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
     "pytest-cov",
     "pytest",
     "scikit-learn",
+    "dask",
 ]
 
 setup(
Binary file not shown.
doctor_visits/tests/test_process_data.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+"""Tests for update_sensor.py."""
+from datetime import datetime
+import logging
+import pandas as pd
+
+from delphi_doctor_visits.process_data import csv_to_df
+
+TEST_LOGGER = logging.getLogger()
+
+class TestProcessData:
+    def test_csv_to_df(self):
+        actual = csv_to_df(
+            filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
+            startdate=datetime(2020, 2, 4),
+            enddate=datetime(2020, 2, 5),
+            dropdate=datetime(2020, 2,6),
+            logger=TEST_LOGGER,
+        )
+
+        comparison = pd.read_pickle("./comparison/process_data/main_after_date_SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl")
+        pd.testing.assert_frame_equal(actual.reset_index(drop=True), comparison)
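
The new test is a regression check: it runs csv_to_df over the small gzipped fixture and compares the result, row for row, against a pickled frame produced by a known-good run. When the processing logic changes on purpose, the comparison pickle has to be regenerated; a hypothetical regeneration snippet (not part of the repository) could look like this:

# Hypothetical helper for regenerating the comparison fixture after an intended
# change to csv_to_df; this script is not part of the repository.
from datetime import datetime
import logging

from delphi_doctor_visits.process_data import csv_to_df

df = csv_to_df(
    filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
    startdate=datetime(2020, 2, 4),
    enddate=datetime(2020, 2, 5),
    dropdate=datetime(2020, 2, 6),
    logger=logging.getLogger(),
)
df.reset_index(drop=True).to_pickle(
    "./comparison/process_data/main_after_date_SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl"
)
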

doctor_visits/tests/test_update_sensor.py

Lines changed: 6 additions & 4 deletions
@@ -1,4 +1,5 @@
 """Tests for update_sensor.py."""
+from datetime import datetime
 import logging
 import pandas as pd
 
@@ -8,11 +9,12 @@
 
 class TestUpdateSensor:
     def test_update_sensor(self):
+        df = pd.read_pickle("./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl")
         actual = update_sensor(
-            filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
-            startdate="2020-02-04",
-            enddate="2020-02-05",
-            dropdate="2020-02-06",
+            data=df,
+            startdate=datetime(2020, 2, 4),
+            enddate=datetime(2020, 2, 5),
+            dropdate=datetime(2020, 2,6),
             geo="state",
             parallel=False,
             weekday=False,
