Skip to content

Commit 6b8a9fb

Browse files
committed
initial nssp covid forecast
1 parent 965fce9 commit 6b8a9fb

File tree

4 files changed

+232
-48
lines changed

4 files changed

+232
-48
lines changed

R/aux_data_utils.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,10 +687,22 @@ up_to_date_nssp_state_archive <- function(disease = c("covid", "influenza")) {
687687
issues = "*"
688688
)
689689
nssp_state %>%
690-
select(geo_value, time_value, issue, nssp = value) %>%
690+
select(geo_value, time_value, version = issue, nssp = value) %>%
691+
bind_rows(get_nssp_github()) %>%
691692
as_epi_archive(compactify = TRUE) %>%
692693
extract2("DT") %>%
693694
# End of week to midweek correction.
694-
mutate(time_value = time_value + 3) %>%
695+
mutate(time_value = floor_date(time_value, "week", week_start = 4)-1) %>%
695696
as_epi_archive(compactify = TRUE)
696697
}
698+
699+
# Read the latest NSSP CSV from the CDC covid19-forecast-hub GitHub repo and
# reshape it into (geo_value, time_value, nssp, version) rows so it can be
# row-bound onto the API-sourced NSSP data.
#
# Takes no arguments. Returns a data frame with columns:
#   geo_value  - `state_id` looked up from `get_population_data()` by state name
#   time_value - `week_end`, shifted by the week-flooring step below
#   nssp       - `percent_visits_covid`
#   version    - the date this function was run (`Sys.Date()`)
get_nssp_github <- function() {
700+
raw_file <- read_csv("https://raw.githubusercontent.com/CDCgov/covid19-forecast-hub/refs/heads/main/auxiliary-data/nssp-raw-data/latest.csv")
701+
# Name -> id lookup, with the "usa" row removed before joining.
state_map <- get_population_data() %>% filter(state_id !="usa")
702+
raw_file %>%
703+
# Keep only the rows whose `county` is "All".
filter(county == "All") %>%
704+
# NOTE(review): left_join keeps rows whose `geography` has no match in
# state_map; those rows end up with geo_value = NA. Confirm that is handled
# downstream.
left_join(state_map, by = join_by(geography == state_name)) %>%
705+
select(geo_value = state_id, time_value = week_end, nssp = percent_visits_covid) %>%
706+
# Floor to the start of a Thursday-based week (week_start = 4), then step back
# one day.
# NOTE(review): up_to_date_nssp_state_archive() applies this same expression
# again after bind_rows(). The shift is not idempotent (a date that is already
# the day before a Thursday moves back a further 7 days), so confirm the
# double application is intended.
mutate(time_value = floor_date(time_value, "week", week_start = 4)-1) %>%
707+
# Every row is stamped with today's date as its version.
mutate(version = Sys.Date())
708+
}

R/forecasters/forecaster_scaled_pop_seasonal.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ scaled_pop_seasonal <- function(
5656
clip_lower = TRUE,
5757
...
5858
) {
59+
browser()
5960
scale_method <- arg_match(scale_method)
6061
center_method <- arg_match(center_method)
6162
nonlin_method <- arg_match(nonlin_method)

scripts/covid_hosp_prod.R

Lines changed: 172 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ g_climate_geo_agged <- function(epi_data, ahead, extra_data, ...) {
6262
)
6363
}
6464
g_windowed_seasonal <- function(epi_data, ahead, extra_data, ...) {
65+
epi_data %>%
66+
mutate(wkday = wday(time_value)) %>%
67+
distinct(wkday, .keep_all = TRUE) %>%
68+
select(time_value, wkday)
6569
fcst <-
6670
epi_data %>%
6771
scaled_pop_seasonal(
@@ -148,6 +152,10 @@ parameters_and_date_targets <- rlang::list2(
148152
covid_geo_exclusions,
149153
command = "scripts/covid_geo_exclusions.csv"
150154
),
155+
# File target pointing at the NSSP geo-exclusions CSV. It is passed to
# parse_prod_weights() by the geo_nssp_forecasters_weights target.
tar_file(
156+
covid_nssp_geo_exclusions,
157+
command = "scripts/covid_nssp_geo_exclusions.csv"
158+
),
151159
tar_file(
152160
covid_data_substitutions,
153161
command = "scripts/covid_data_substitutions.csv"
@@ -178,7 +186,8 @@ parameters_and_date_targets <- rlang::list2(
178186
nssp_latest_data,
179187
command = {
180188
nssp_archive_data %>%
181-
epix_as_of(min(Sys.Date(), nssp_archive_data$versions_end))
189+
epix_as_of(min(Sys.Date(), nssp_archive_data$versions_end)) %>%
190+
arrange(desc(time_value))
182191
}
183192
)
184193
)
@@ -196,7 +205,7 @@ forecast_targets <- tar_map(
196205
),
197206
names = c("id", "forecast_date_chr"),
198207
tar_target(
199-
name = forecast_res,
208+
name = forecast_nhsn,
200209
command = {
201210
# if the forecaster is named latest, it should use the most up to date
202211
# version of the data
@@ -217,7 +226,7 @@ forecast_targets <- tar_map(
217226
add_season_info() %>%
218227
mutate(
219228
geo_value = ifelse(geo_value == "usa", "us", geo_value),
220-
time_value = time_value - 3
229+
time_value = floor_date(time_value, "week", week_start = 4) - 1
221230
) %>%
222231
filter(geo_value %nin% g_insufficient_data_geos)
223232
if (!grepl("latest", id)) {
@@ -235,11 +244,66 @@ forecast_targets <- tar_map(
235244
)
236245
},
237246
pattern = map(aheads)
247+
),
248+
# forecast_nssp: run the forecaster selected by `forecaster`/`params` with the
# NSSP signal as the primary data and NHSN passed as `extra_data`.
# Branches over `aheads` (pattern = map(aheads)). The result has `forecaster`
# set to `id` and `geo_value` converted to a factor.
tar_target(
249+
name = forecast_nssp,
250+
command = {
251+
# if the forecaster is named latest, it should use the most up to date
252+
# version of the data
253+
if (grepl("latest", id)) {
254+
# "latest" ids: newest archive version, keeping only rows dated before the
# forecast date.
nhsn_data <- nhsn_archive_data %>%
255+
epix_as_of(nhsn_archive_data$versions_end) %>%
256+
filter(time_value < as.Date(forecast_date_int))
257+
nssp_data <- nssp_archive_data %>%
258+
epix_as_of(nssp_archive_data$versions_end) %>%
259+
filter(time_value < as.Date(forecast_date_int))
260+
} else {
261+
# Otherwise: snapshot as of the forecast date, capped at the archive's last
# available version.
nhsn_data <- nhsn_archive_data %>%
262+
epix_as_of(min(as.Date(forecast_date_int), nhsn_archive_data$versions_end))
263+
nssp_data <- nssp_archive_data %>%
264+
epix_as_of(min(as.Date(forecast_date_int), nssp_archive_data$versions_end))
265+
}
266+
# Rename the NSSP column to `value`, add season info, map "usa" to "us", and
# drop the geos listed in g_insufficient_data_geos.
nssp_data <- nssp_data %>%
267+
rename(value = nssp) %>%
268+
add_season_info() %>%
269+
mutate(
270+
geo_value = ifelse(geo_value == "usa", "us", geo_value),
271+
# NOTE(review): if nssp_archive_data comes from
# up_to_date_nssp_state_archive(), this same shift has already been applied
# there. It is not idempotent, so applying it again moves dates a further week
# earlier. Confirm this is intended.
time_value = floor_date(time_value, "week", week_start = 4) - 1
272+
) %>%
273+
filter(geo_value %nin% g_insufficient_data_geos)
274+
275+
# Data substitutions are applied to the NHSN data only, and only for
# non-"latest" ids.
if (!grepl("latest", id)) {
276+
nhsn_data %<>%
277+
data_substitutions(covid_data_substitutions, as.Date(forecast_generation_date_int))
278+
}
279+
280+
# jank renaming to avoid hard-coded variable name problems
281+
# After this line the column called `nssp` in nhsn_data holds NHSN values,
# while the NSSP values live in nssp_data$value.
nhsn_data %<>% rename(nssp = value)
282+
# Overwrite the as_of metadata on the primary data with the forecast date.
attributes(nssp_data)$metadata$as_of <- as.Date(forecast_date_int)
283+
284+
forecaster_fn <- get_partially_applied_forecaster(forecaster, aheads, params, param_names)
285+
286+
forecaster_fn(nssp_data, extra_data = nhsn_data) %>%
287+
mutate(
288+
forecaster = id,
289+
geo_value = as.factor(geo_value)
290+
)
291+
},
292+
pattern = map(aheads)
238293
)
239294
)
240-
combined_forecasts <- tar_combine(
241-
name = forecast_full,
242-
forecast_targets[["forecast_res"]],
295+
296+
# Combine every `forecast_nhsn` target generated by `forecast_targets` into one
# target, `forecast_nhsn_full`, by row-binding the upstream results (`.x`).
combined_nhsn_forecasts <- tar_combine(
297+
name = forecast_nhsn_full,
298+
forecast_targets[["forecast_nhsn"]],
299+
command = {
300+
dplyr::bind_rows(!!!.x)
301+
}
302+
)
303+
304+
combined_nssp_forecasts <- tar_combine(
305+
name = forecast_nssp_full,
306+
forecast_targets[["forecast_nssp"]],
243307
command = {
244308
dplyr::bind_rows(!!!.x)
245309
}
@@ -254,9 +318,16 @@ ensemble_targets <- tar_map(
254318
),
255319
names = "forecast_date_chr",
256320
tar_target(
257-
name = forecast_full_filtered,
321+
name = forecast_nhsn_full_filtered,
258322
command = {
259-
forecast_full %>%
323+
forecast_nhsn_full %>%
324+
filter(forecast_date == as.Date(forecast_date_int))
325+
}
326+
),
327+
# NSSP forecasts restricted to rows whose forecast_date equals this branch's
# forecast date (`forecast_date_int`, converted with as.Date()).
tar_target(
328+
name = forecast_nssp_full_filtered,
329+
command = {
330+
forecast_nssp_full %>%
260331
filter(forecast_date == as.Date(forecast_date_int))
261332
}
262333
),
@@ -274,6 +345,17 @@ ensemble_targets <- tar_map(
274345
geo_forecasters_weights
275346
},
276347
),
348+
# Per-geo forecaster weights for the NSSP ensemble, parsed from the
# covid_nssp_geo_exclusions file for the ids in g_forecaster_params_grid.
# Aborts when the parsed weights contain no row for this forecast date;
# otherwise returns the weights unchanged.
tar_target(
349+
name = geo_nssp_forecasters_weights,
350+
command = {
351+
geo_nssp_forecasters_weights <-
352+
parse_prod_weights(covid_nssp_geo_exclusions, forecast_date_int, g_forecaster_params_grid$id)
353+
if (nrow(geo_nssp_forecasters_weights %>% filter(forecast_date == as.Date(forecast_date_int))) == 0) {
354+
# Build the message from forecast_date_int directly: `forecast_date` is only a
# column name inside filter() above, so interpolating "{forecast_date}" in the
# message would fail to find a variable on this error path.
cli_abort(paste("there are no weights for the forecast date", as.Date(forecast_date_int)))
355+
}
356+
geo_nssp_forecasters_weights
357+
},
358+
),
277359
tar_target(
278360
name = geo_exclusions,
279361
command = {
@@ -283,7 +365,7 @@ ensemble_targets <- tar_map(
283365
tar_target(
284366
name = ensemble_clim_lin,
285367
command = {
286-
forecast_full_filtered %>%
368+
forecast_nhsn_full_filtered %>%
287369
ensemble_climate_linear(
288370
aheads,
289371
other_weights = geo_forecasters_weights,
@@ -299,7 +381,7 @@ ensemble_targets <- tar_map(
299381
tar_target(
300382
name = ens_ar_only,
301383
command = {
302-
forecast_full_filtered %>%
384+
forecast_nhsn_full_filtered %>%
303385
filter(forecaster %in% c("windowed_seasonal", "windowed_seasonal_extra_sources")) %>%
304386
group_by(geo_value, forecast_date, target_end_date, quantile) %>%
305387
summarize(value = mean(value, na.rm = TRUE), .groups = "drop") %>%
@@ -312,25 +394,47 @@ ensemble_targets <- tar_map(
312394
command = {
313395
ensemble_clim_lin %>%
314396
bind_rows(
315-
forecast_full_filtered %>%
397+
forecast_nhsn_full_filtered %>%
316398
filter(forecaster %in% c("windowed_seasonal", "windowed_seasonal_extra_sources")) %>%
317399
filter(forecast_date < target_end_date) # don't use for neg aheads
318400
) %>%
319401
ensemble_weighted(geo_forecasters_weights) %>%
320402
mutate(forecaster = "ensemble_mix")
321403
},
322404
),
405+
# NSSP mixture ensemble: stack ensemble_clim_lin with the two windowed-seasonal
# NSSP forecasters, combine them with ensemble_weighted() using the NSSP geo
# weights, and label the result "ensemble_mix".
# NOTE(review): ensemble_clim_lin is computed from forecast_nhsn_full_filtered,
# so NHSN-based forecasts are being mixed with NSSP forecasts here. Confirm
# this is intended.
# NOTE(review): the forecaster label "ensemble_mix" is the same one used by
# ensemble_mixture_res. Confirm the two are never bound into one table.
tar_target(
406+
name = ensemble_nssp_mixture_res,
407+
command = {
408+
ensemble_clim_lin %>%
409+
bind_rows(
410+
forecast_nssp_full_filtered %>%
411+
filter(forecaster %in% c("windowed_seasonal", "windowed_seasonal_extra_sources")) %>%
412+
filter(forecast_date < target_end_date) # don't use for neg aheads
413+
) %>%
414+
ensemble_weighted(geo_nssp_forecasters_weights) %>%
415+
mutate(forecaster = "ensemble_mix")
416+
},
417+
),
323418
tar_target(
324419
name = forecasts_and_ensembles,
325420
command = {
326421
bind_rows(
327-
forecast_full_filtered,
422+
forecast_nhsn_full_filtered,
328423
ensemble_clim_lin,
329424
ensemble_mixture_res,
330425
ens_ar_only
331426
)
332427
}
333428
),
429+
# All NSSP outputs for this forecast date in one table: the filtered individual
# forecasts plus the NSSP mixture ensemble. Passed to the notebook target as
# the `forecast_nssp` parameter.
tar_target(
430+
name = forecasts_and_ensembles_nssp,
431+
command = {
432+
bind_rows(
433+
forecast_nssp_full_filtered,
434+
ensemble_nssp_mixture_res,
435+
)
436+
}
437+
),
334438
tar_target(
335439
name = make_submission_csv,
336440
command = {
@@ -347,11 +451,24 @@ ensemble_targets <- tar_map(
347451
}
348452
},
349453
),
454+
# Write the NSSP mixture ensemble as a submission file.
# Runs only when not in backtest mode and the submission directory is not
# "cache"; otherwise it just emits an informational message.
tar_target(
455+
name = make_nssp_submission_csv,
456+
command = {
457+
if (!g_backtest_mode && g_submission_directory != "cache") {
458+
forecast_reference_date <- get_forecast_reference_date(forecast_date_int)
459+
ensemble_nssp_mixture_res %>%
460+
format_flusight(disease = "covid") %>%
461+
# NOTE(review): if make_submission_csv writes to this same
# "model-output/CMU-TimeSeries" directory for the same reference date, one file
# may overwrite the other. Confirm the output names differ.
write_submission_file(forecast_reference_date, file.path(g_submission_directory, "model-output/CMU-TimeSeries"))
462+
} else {
463+
cli_alert_info("Not making submission csv because we're in backtest mode or submission directory is cache")
464+
}
465+
},
466+
),
350467
tar_target(
351468
name = make_climate_submission_csv,
352469
command = {
353470
if (!g_backtest_mode && g_submission_directory != "cache") {
354-
forecast_full_filtered %>%
471+
forecast_nhsn_full_filtered %>%
355472
filter(forecaster %in% c("climate_base", "climate_geo_agged")) %>%
356473
group_by(geo_value, target_end_date, quantile) %>%
357474
summarize(forecast_date = as.Date(forecast_date_int), value = mean(value, na.rm = TRUE), .groups = "drop") %>%
@@ -406,19 +523,27 @@ ensemble_targets <- tar_map(
406523
},
407524
),
408525
tar_target(
409-
name = truth_data,
526+
name = truth_data_pre_process,
410527
command = {
411528
# Plot both as_of and latest data to compare
412529
nhsn_data <- nhsn_archive_data %>%
413530
epix_as_of(min(as.Date(forecast_generation_date_int), nhsn_archive_data$versions_end)) %>%
414531
mutate(source = "nhsn as_of forecast") %>%
415532
bind_rows(nhsn_latest_data %>% mutate(source = "nhsn")) %>%
416-
select(geo_value, target_end_date = time_value, value) %>%
533+
select(geo_value, target_end_date = time_value, value, source) %>%
417534
filter(target_end_date > g_truth_data_date, geo_value %nin% g_insufficient_data_geos)
418535
nssp_data <- nssp_latest_data %>%
419536
select(geo_value, target_end_date = time_value, value = nssp) %>%
420537
filter(target_end_date > g_truth_data_date, geo_value %nin% g_insufficient_data_geos) %>%
421538
mutate(target_end_date = target_end_date + 3, source = "nssp")
539+
list(nhsn_data, nssp_data)
540+
}
541+
),
542+
tar_target(
543+
name = truth_data_nhsn,
544+
command = {
545+
nhsn_data <- truth_data_pre_process[[1]]
546+
nssp_data <- truth_data_pre_process[[2]]
422547
nssp_renormalized <-
423548
nssp_data %>%
424549
left_join(
@@ -436,11 +561,36 @@ ensemble_targets <- tar_map(
436561
mutate(value = value * rel_max_value) %>%
437562
select(-rel_max_value)
438563
nhsn_data %>% bind_rows(nssp_renormalized)
439-
},
564+
}
565+
),
566+
# Truth data for NSSP plots: the NSSP series as-is, plus the NHSN series
# rescaled per geo so that its maximum matches the NSSP maximum.
# truth_data_pre_process is list(nhsn_data, nssp_data), in that order.
tar_target(
567+
name = truth_data_nssp,
568+
command = {
569+
nhsn_data <- truth_data_pre_process[[1]]
570+
nssp_data <- truth_data_pre_process[[2]]
571+
nhsn_renormalized <-
572+
nhsn_data %>%
573+
left_join(
574+
# Build the per-geo scale factor. The names are inverted inside this
# sub-pipeline: the NHSN values are renamed to `nssp`, and `value` is the NSSP
# column brought in by the full_join.
nhsn_data %>%
575+
rename(nssp = value) %>%
576+
full_join(
577+
nssp_data %>%
578+
select(geo_value, target_end_date, value),
579+
by = join_by(geo_value, target_end_date)
580+
) %>%
581+
group_by(geo_value) %>%
582+
# rel_max_value = (max NSSP value) / (max NHSN value) for each geo.
summarise(rel_max_value = max(value, na.rm = TRUE) / max(nssp, na.rm = TRUE)),
583+
by = join_by(geo_value)
584+
) %>%
585+
# Scale the NHSN values by that factor, then drop the helper column.
mutate(value = value * rel_max_value) %>%
586+
select(-rel_max_value)
587+
nssp_data %>% bind_rows(nhsn_renormalized)
588+
}
440589
),
441590
tar_target(
442591
notebook,
443592
command = {
593+
browser()
444594
# Only render the report if there is only one forecast date
445595
# i.e. we're running this in prod on schedule
446596
if (!g_backtest_mode) {
@@ -453,9 +603,11 @@ ensemble_targets <- tar_map(
453603
),
454604
params = list(
455605
disease = "covid",
456-
forecast_res = forecasts_and_ensembles %>% ungroup() %>% filter(forecaster != "climate_geo_agged"),
606+
forecast_nhsn = forecasts_and_ensembles %>% ungroup() %>% filter(forecaster != "climate_geo_agged"),
607+
forecast_nssp = forecasts_and_ensembles_nssp,
457608
forecast_date = as.Date(forecast_date_int),
458-
truth_data = truth_data
609+
truth_data_nhsn = truth_data_nhsn,
610+
truth_data_nssp = truth_data_nssp
459611
)
460612
)
461613
}
@@ -506,6 +658,7 @@ list2(
506658
parameters_and_date_targets,
507659
forecast_targets,
508660
ensemble_targets,
509-
combined_forecasts,
661+
combined_nhsn_forecasts,
662+
combined_nssp_forecasts,
510663
score_targets
511664
)

0 commit comments

Comments
 (0)