
Commit bd1d91e

feat+fix: score covid prod

* covid prod now has two modes, prod and backtest, with new reports available
* fix prod pipelines for package updates
* add retry function for failing API calls
* add timestamps to targets output
* update forecast data Julia (ty David)
* add forecast data R code
* fix daily_to_weekly_archive for epiprocess update, add comments
* add tar_change to some data-dependent targets

fix: tests
fix: make flu_hosp_prod backtest_mode aware
fix: typo
fix: covid prod good to go
fix
fix: a few bugs and update exclusions
fix: data_substitutions
1 parent c91a53f

22 files changed: +973 -367 lines

Makefile (+2 -1)

@@ -63,9 +63,10 @@ submit: submit-covid submit-flu
 get_nwss:
 	mkdir -p aux_data/nwss_covid_data; \
 	mkdir -p aux_data/nwss_flu_data; \
+	. .venv/bin/activate; \
 	cd scripts/nwss_export_tool/; \
 	python nwss_covid_export.py; \
-	python nwss_covid_export.py
+	python nwss_influenza_export.py
 
 run-nohup:
 	nohup Rscript scripts/run.R &

R/aux_data_utils.R (+56 -21)

@@ -188,38 +188,56 @@ daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), day_of_week =
     select(-epiweek, -year)
 }
 
+#' Aggregate a daily archive to a weekly archive.
+#'
+#' @param epi_arch the archive to aggregate.
+#' @param agg_columns the columns to aggregate.
+#' @param agg_method the method to use to aggregate the data, one of "sum" or "mean".
+#' @param day_of_week the day of the week to use as the reference day.
+#' @param day_of_week_end the day of the week to use as the end of the week.
 daily_to_weekly_archive <- function(epi_arch,
                                     agg_columns,
                                     agg_method = c("sum", "mean"),
                                     day_of_week = 4L,
                                     day_of_week_end = 7L) {
+  # How to aggregate the windowed data.
   agg_method <- arg_match(agg_method)
+  # The columns we will later group by when aggregating.
   keys <- key_colnames(epi_arch, exclude = c("time_value", "version"))
+  # The versions we will slide over.
   ref_time_values <- epi_arch$DT$version %>%
     unique() %>%
     sort()
+  # Choose a fast function to use to slide and aggregate.
   if (agg_method == "sum") {
     slide_fun <- epi_slide_sum
   } else if (agg_method == "mean") {
     slide_fun <- epi_slide_mean
   }
-  too_many_tibbles <- epix_slide(
+  # Slide over the versions and aggregate.
  epix_slide(
     epi_arch,
-    .before = 99999999L,
     .versions = ref_time_values,
-    function(x, group, ref_time) {
+    function(x, group_keys, ref_time) {
+      # The end of the last full week before the reference time.
       ref_time_last_week_end <-
         floor_date(ref_time, "week", day_of_week_end - 1) # this is over by 1
+      # The latest day present in the data.
       max_time <- max(x$time_value)
+      # The days we will slide over.
       valid_slide_days <- seq.Date(
         from = ceiling_date(min(x$time_value), "week", week_start = day_of_week_end - 1),
         to = floor_date(max(x$time_value), "week", week_start = day_of_week_end - 1),
         by = 7L
       )
+      # If the last day of the data is not the end of a week, add it to the
+      # list of valid slide days (this will produce an incomplete slide, but
+      # that's fine for us, since it should only be 1 day, historically).
       if (wday(max_time) != day_of_week_end) {
         valid_slide_days <- c(valid_slide_days, max_time)
       }
-      slid_result <- x %>%
+      # Slide over the days and aggregate.
+      x %>%
         group_by(across(all_of(keys))) %>%
         slide_fun(
           agg_columns,
@@ -229,18 +247,13 @@ daily_to_weekly_archive <- function(epi_arch,
         ) %>%
         select(-all_of(agg_columns)) %>%
         rename_with(~ gsub("slide_value_", "", .x)) %>%
-        # only keep 1/week
-        # group_by week, keep the largest in each week
-        # alternatively
-        # switch time_value to the designated day of the week
+        rename_with(~ gsub("_7dsum", "", .x)) %>%
+        # Round all dates to the reference day of the week. These will get
+        # de-duplicated by compactify in as_epi_archive below.
         mutate(time_value = round_date(time_value, "week", day_of_week - 1)) %>%
         as_tibble()
     }
-  )
-  too_many_tibbles %>%
-    pull(time_value) %>%
-    max()
-  too_many_tibbles %>%
+  ) %>%
     as_epi_archive(compactify = TRUE)
 }
@@ -313,9 +326,8 @@ get_health_data <- function(as_of, disease = c("covid", "flu")) {
 
   most_recent_row <- meta_data %>%
     # update_date is actually a time, so we need to filter for the day after.
-    filter(update_date <= as_of + 1) %>%
-    arrange(desc(update_date)) %>%
-    slice(1)
+    filter(update_date <= as.Date(as_of) + 1) %>%
+    slice_max(update_date)
 
   if (nrow(most_recent_row) == 0) {
     cli::cli_abort("No data available for the given date.")
@@ -331,9 +343,7 @@ get_health_data <- function(as_of, disease = c("covid", "flu")) {
   if (disease == "covid") {
     data %<>% mutate(
       hhs = previous_day_admission_adult_covid_confirmed +
-        previous_day_admission_adult_covid_suspected +
-        previous_day_admission_pediatric_covid_confirmed +
-        previous_day_admission_pediatric_covid_suspected
+        previous_day_admission_pediatric_covid_confirmed
     )
   } else if (disease == "flu") {
     data %<>% mutate(hhs = previous_day_admission_influenza_confirmed)
@@ -709,10 +719,12 @@ create_nhsn_data_archive <- function(disease_name) {
     as_epi_archive(compactify = TRUE)
 }
 
-
 up_to_date_nssp_state_archive <- function(disease = c("covid", "influenza")) {
   disease <- arg_match(disease)
-  nssp_state <- pub_covidcast(
+  nssp_state <- retry_fn(
+    max_attempts = 10,
+    wait_seconds = 1,
+    fn = pub_covidcast,
     source = "nssp",
     signal = glue::glue("pct_ed_visits_{disease}"),
     time_type = "week",
@@ -728,3 +740,26 @@ up_to_date_nssp_state_archive <- function(disease = c("covid", "influenza")) {
     mutate(time_value = time_value + 3) %>%
     as_epi_archive(compactify = TRUE)
 }
+
+# Get the last time the signal was updated.
+get_covidcast_signal_last_update <- function(source, signal) {
+  pub_covidcast_meta() %>%
+    filter(source == !!source, signal == !!signal) %>%
+    pull(last_update) %>%
+    as.POSIXct()
+}
+
+# Get the last time the Socrata dataset was updated.
+get_socrata_updated_at <- function(dataset_url) {
+  httr::GET(dataset_url) %>%
+    httr::content() %>%
+    pluck("rowsUpdatedAt") %>%
+    as.POSIXct()
+}
+
+# Get the last time the S3 object was modified.
+get_s3_object_last_modified <- function(bucket, key) {
+  # Format looks like "Fri, 31 Jan 2025 22:01:16 GMT"
+  attr(aws.s3::head_object(key, bucket = bucket), "last-modified") %>%
+    str_replace_all(" GMT", "") %>%
+    as.POSIXct(format = "%a, %d %b %Y %H:%M:%S")
+}
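The `retry_fn` wrapper that `up_to_date_nssp_state_archive` now calls is the "retry function for failing API calls" from the commit message, but its definition is not part of the hunks shown here. A minimal sketch of what such a wrapper could look like, assuming the `fn`/`max_attempts`/`wait_seconds` interface used above (illustrative, not the repo's actual implementation):

# Illustrative sketch only: the repo's actual retry_fn is defined elsewhere.
# Calls fn(...) up to max_attempts times, sleeping between failed attempts.
retry_fn <- function(fn, ..., max_attempts = 10, wait_seconds = 1) {
  for (attempt in seq_len(max_attempts)) {
    result <- tryCatch(fn(...), error = identity)
    if (!inherits(result, "error")) {
      return(result)
    }
    if (attempt == max_attempts) {
      stop(result)
    }
    Sys.sleep(wait_seconds)
  }
}

The three new last-updated helpers (`get_covidcast_signal_last_update`, `get_socrata_updated_at`, `get_s3_object_last_modified`) line up with the commit message's "add tar_change to some data-dependent targets": `tarchetypes::tar_change()` rebuilds a target whenever the value of its `change` cue changes, so an upstream freshness timestamp can serve as the cue. A hypothetical wiring (target name and arguments are illustrative):

# Hypothetical target: re-fetch the NSSP archive only when the upstream
# signal reports a new last_update timestamp.
tarchetypes::tar_change(
  nssp_archive,
  up_to_date_nssp_state_archive("covid"),
  change = get_covidcast_signal_last_update("nssp", "pct_ed_visits_covid")
)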

R/forecasters/epipredict_utilities.R (+6 -3)

@@ -114,6 +114,11 @@ run_workflow_and_format <- function(preproc,
   if (is.null(as_of)) {
     as_of <- max(train_data$time_value)
   }
+
+  # Look at the train data (uncomment for debugging).
+  # df <- preproc %>% prep(train_data) %>% bake(train_data)
+  # browser()
+
   workflow <- epi_workflow(preproc, trainer) %>%
     fit(train_data) %>%
     add_frosting(postproc)
@@ -125,9 +130,7 @@ run_workflow_and_format <- function(preproc,
   # keeping only the last time_value for any given location/key
   pred %<>%
     group_by(across(all_of(key_colnames(train_data, exclude = "time_value")))) %>%
-    # TODO: slice_max(time_value)?
-    arrange(time_value) %>%
-    filter(row_number() == n()) %>%
+    slice_max(time_value) %>%
     ungroup()
   return(format_storage(pred, as_of))
 }
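The `slice_max(time_value)` change resolves the old TODO, with one nuance worth noting: `dplyr::slice_max()` keeps ties by default (`with_ties = TRUE`), while the old `arrange()` plus `filter(row_number() == n())` kept exactly one row per group. A toy illustration (data made up for the example):

library(dplyr)

preds <- tibble(
  geo_value = c("ca", "ca", "tx"),
  time_value = as.Date(c("2025-01-01", "2025-01-08", "2025-01-08"))
)

# New idiom: latest time_value per group; any ties are all kept by default.
preds %>%
  group_by(geo_value) %>%
  slice_max(time_value) %>%
  ungroup()

Since the grouping above covers every key column except `time_value`, ties at the latest `time_value` should not arise here, so the two idioms presumably agree in practice.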

R/imports.R (+1 -0)

@@ -29,6 +29,7 @@ library(parsnip)
 library(paws.storage)
 library(plotly)
 library(purrr)
+library(qs2)
 library(quantreg)
 library(readr)
 library(recipes)

R/scoring.R (+17 -13)

@@ -1,30 +1,34 @@
 # Scoring and Evaluation Functions
 
-evaluate_predictions <- function(predictions_cards, truth_data) {
-  checkmate::assert_data_frame(predictions_cards)
+evaluate_predictions <- function(forecasts, truth_data) {
+  checkmate::assert_data_frame(forecasts)
   checkmate::assert_data_frame(truth_data)
   checkmate::assert_names(
-    names(predictions_cards),
+    names(forecasts),
     must.include = c("model", "geo_value", "forecast_date", "target_end_date", "quantile", "prediction")
   )
   checkmate::assert_names(
     names(truth_data),
     must.include = c("geo_value", "target_end_date", "true_value")
   )
 
-  left_join(predictions_cards, truth_data, by = c("geo_value", "target_end_date")) %>%
-    scoringutils::score(metrics = c("interval_score", "ae_median", "coverage")) %>%
-    scoringutils::add_coverage(by = c("model", "geo_value", "forecast_date", "target_end_date"), ranges = c(80)) %>%
-    scoringutils::summarize_scores(by = c("model", "geo_value", "forecast_date", "target_end_date")) %>%
+  forecast_obj <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date")) %>%
+    scoringutils::as_forecast_quantile(
+      quantile_level = "quantile",
+      observed = "true_value",
+      predicted = "prediction",
+      forecast_unit = c("model", "geo_value", "forecast_date", "target_end_date")
+    )
+
+  scores <- forecast_obj %>%
+    scoringutils::score(metrics = get_metrics(.)) %>%
     as_tibble() %>%
     select(
-      model,
-      geo_value,
-      forecast_date,
-      target_end_date,
-      wis = interval_score,
+      model, geo_value, forecast_date, target_end_date,
+      wis,
       ae = ae_median,
-      coverage_80
+      coverage_50 = interval_coverage_50,
+      coverage_90 = interval_coverage_90
     ) %>%
     mutate(ahead = as.numeric(target_end_date - forecast_date))
 }
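This rewrite tracks the scoringutils 1.x to 2.x API migration: data frames are first validated into a forecast object via `as_forecast_quantile()`, `score()` now takes metric functions (from `get_metrics()`) rather than metric names, and the old `interval_score`/`coverage_80` outputs are replaced by `wis` and `interval_coverage_50`/`interval_coverage_90`. A minimal usage sketch with toy inputs (all values illustrative):

library(dplyr)

forecasts <- tidyr::expand_grid(
  model = "baseline",
  geo_value = c("ca", "tx"),
  forecast_date = as.Date("2025-01-04"),
  target_end_date = as.Date("2025-01-11"),
  quantile = c(0.05, 0.25, 0.5, 0.75, 0.95)
) %>%
  mutate(prediction = 100 * quantile)

truth_data <- tibble(
  geo_value = c("ca", "tx"),
  target_end_date = as.Date("2025-01-11"),
  true_value = c(55, 48)
)

# One row per model/geo_value/forecast_date/target_end_date, with columns
# wis, ae, coverage_50, coverage_90, and ahead.
evaluate_predictions(forecasts, truth_data)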
