Skip to content

Commit 986b657

Browse files
committed
rebase fixes, error classes, unskip latency tests
1 parent a28ad82 commit 986b657

11 files changed

+189
-116
lines changed

NAMESPACE

+2-1
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ importFrom(quantreg,rq)
251251
importFrom(recipes,bake)
252252
importFrom(recipes,detect_step)
253253
importFrom(recipes,prep)
254-
importFrom(rlang,"!!!")
255254
importFrom(recipes,recipes_eval_select)
255+
importFrom(rlang,"!!!")
256256
importFrom(rlang,"!!")
257257
importFrom(rlang,"%@%")
258258
importFrom(rlang,"%||%")
@@ -262,6 +262,7 @@ importFrom(rlang,caller_env)
262262
importFrom(rlang,enquos)
263263
importFrom(rlang,global_env)
264264
importFrom(rlang,inject)
265+
importFrom(rlang,is_empty)
265266
importFrom(rlang,is_logical)
266267
importFrom(rlang,is_null)
267268
importFrom(rlang,is_true)

R/canned-epipred.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ print.canned_epipred <- function(x, name, ...) {
133133
purrr::map("columns") %>%
134134
reduce(c)
135135
latency_per_base_col <- latency_step$latency_table %>%
136-
filter(col_name %in% valid_columns) %>% mutate(latency = abs(latency))
136+
filter(col_name %in% valid_columns) %>%
137+
mutate(latency = abs(latency))
137138
if (latency_step$method != "locf" && nrow(latency_per_base_col) > 1) {
138139
intro_text <- glue::glue("{type_str} adjusted per column: ")
139140
} else if (latency_step$method != "locf") {

R/cdc_baseline_forecaster.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ cdc_baseline_forecaster <- function(
7979

8080

8181
latest <- get_test_data(
82-
epi_recipe(epi_data), epi_data)
82+
epi_recipe(epi_data), epi_data
83+
)
8384

8485
f <- frosting() %>%
8586
layer_predict() %>%

R/epi_shift.R

+4-3
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,17 @@ add_shifted_columns <- function(new_data, object, amount) {
4545
shift_sign_lat <- attributes(new_data)$metadata$shift_sign
4646
if (!is.null(latency_table) &&
4747
shift_sign_lat == sign_shift) {
48-
#TODO this doesn't work on lags of transforms
48+
# TODO this doesn't work on lags of transforms
4949
rel_latency <- latency_table %>% filter(col_name %in% object$columns)
5050
} else {
5151
rel_latency <- tibble(col_name = object$columns, latency = 0L)
5252
}
53-
grid <- expand_grid(col = object$columns, amount = sign_shift *amount) %>%
53+
grid <- expand_grid(col = object$columns, amount = sign_shift * amount) %>%
5454
left_join(rel_latency, by = join_by(col == col_name), ) %>%
5555
tidyr::replace_na(list(latency = 0)) %>%
5656
dplyr::mutate(
57-
shift_val = amount + latency) %>%
57+
shift_val = amount + latency
58+
) %>%
5859
mutate(
5960
newname = glue::glue("{object$prefix}{abs(shift_val)}_{col}"), # name is always positive
6061
amount = NULL,

R/get_test_data.R

+5-13
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,9 @@
1111
#' used if growth rate calculations are requested by the recipe. This is
1212
#' calculated internally.
1313
#'
14-
#' It also optionally fills missing values
15-
#' using the last-observation-carried-forward (LOCF) method. If this
16-
#' is not possible (say because there would be only `NA`'s in some location),
17-
#' it will produce an error suggesting alternative options to handle missing
18-
#' values with more advanced techniques.
19-
#'
2014
#' @param recipe A recipe object.
2115
#' @param x An epi_df. The typical usage is to
2216
#' pass the same data as that used for fitting the recipe.
23-
#' @param forecast_date By default, this is set to the maximum
24-
#' `time_value` in `x`. But if there is data latency such that recent `NA`'s
25-
#' should be filled, this may be _after_ the last available `time_value`.
2617
#'
2718
#' @return An object of the same type as `x` with columns `geo_value`, `time_value`, any additional
2819
#' keys, as well other variables in the original dataset.
@@ -36,9 +27,7 @@
3627
#' @importFrom rlang %@%
3728
#' @export
3829

39-
get_test_data <- function(
40-
recipe,
41-
x) {
30+
get_test_data <- function(recipe, x) {
4231
if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")
4332

4433
check <- hardhat::check_column_names(x, colnames(recipe$template))
@@ -64,7 +53,10 @@ get_test_data <- function(
6453
"!" = "but `x` contains only {avail_recent}."
6554
))
6655
}
67-
max_time_value <- x %>% na.omit %>% pull(time_value) %>% max
56+
max_time_value <- x %>%
57+
na.omit() %>%
58+
pull(time_value) %>%
59+
max()
6860
x <- arrange(x, time_value)
6961
groups <- kill_time_value(epi_keys(recipe))
7062

R/step_adjust_latency.R

+15-11
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
#' jhu_fit
8787
#'
8888
#' @importFrom recipes detect_step
89-
#' @importFrom rlang enquos
89+
#' @importFrom rlang enquos is_empty
9090
step_adjust_latency <-
9191
function(recipe,
9292
...,
@@ -106,39 +106,43 @@ step_adjust_latency <-
106106
id = recipes::rand_id("adjust_latency")) {
107107
arg_is_chr_scalar(id, method)
108108
if (!is_epi_recipe(recipe)) {
109-
cli::cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
109+
cli::cli_abort("This recipe step can only operate on an {.cls epi_recipe}.", class = "epipredict__step_adjust_latency__epi_recipe_only")
110110
}
111111
if (!is.null(columns)) {
112112
cli::cli_abort(c("The `columns` argument must be `NULL`.",
113113
i = "Use `tidyselect` methods to choose columns to lag."
114-
))
114+
), class = "epipredict__step_adjust_latency__cols_not_null")
115115
}
116116
if ((method == "extend_ahead") && (detect_step(recipe, "epi_ahead"))) {
117117
cli::cli_warn(
118-
"If `method` is {.val extend_ahead}, then the previous `step_epi_ahead` won't be modified."
118+
"If `method` is {.val extend_ahead}, then the previous `step_epi_ahead` won't be modified.",
119+
class = "epipredict__step_adjust_latency__misordered_step_warning"
119120
)
120121
} else if ((method == "extend_lags") && detect_step(recipe, "epi_lag")) {
121122
cli::cli_warn(
122123
"If `method` is {.val extend_lags} or {.val locf},
123-
then the previous `step_epi_lag`s won't work with modified data."
124+
then the previous `step_epi_lag`s won't work with modified data.",
125+
class = "epipredict__step_adjust_latency__misordered_step_warning"
124126
)
125127
} else if ((method == "locf") && (length(recipe$steps) > 0)) {
126-
cli::cli_warn("There are steps before `step_adjust_latency`. With the method {.val locf}, it is recommended to include this step before any others")
128+
cli::cli_warn("There are steps before `step_adjust_latency`. With the method {.val locf}, it is recommended to include this step before any others",
129+
class = "epipredict__step_adjust_latency__misordered_step_warning"
130+
)
127131
}
128132
if (detect_step(recipe, "naomit")) {
129133
cli::cli_abort("adjust_latency needs to occur before any `NA` removal,
130-
as columns may be moved around")
134+
as columns may be moved around", class = "epipredict__step_adjust_latency__post_NA_error")
131135
}
132136
if (!is.null(fixed_latency) && !is.null(fixed_forecast_date)) {
133137
cli::cli_abort("Only one of `fixed_latency` and `fixed_forecast_date`
134-
can be non-`NULL` at a time!")
138+
can be non-`NULL` at a time!", class = "epipredict__step_adjust_latency__too_many_args_error")
135139
}
136140
if (length(fixed_latency > 1)) {
137141
template <- recipe$template
138142
data_names <- names(template)[!names(template) %in% epi_keys(template)]
139143
wrong_names <- names(fixed_latency)[!names(fixed_latency) %in% data_names]
140144
if (length(wrong_names) > 0) {
141-
cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}")
145+
cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}", class = "epipredict__step_adjust_latency__undefined_names_error")
142146
}
143147
}
144148

@@ -258,8 +262,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
258262
"{time_type}."
259263
),
260264
"i" = "latency: {latency_table$latency[[i_latency]]}",
261-
"i" = "`max_time` = {max_time} -> `forecast_date` = {forecast_date}"
262-
))
265+
"i" = "`max_time` = {max(training$time_value)} -> `forecast_date` = {forecast_date}"
266+
), class = "epipredict__prep.step_latency__very_large_latency")
263267
}
264268

265269
step_adjust_latency_new(

R/utils-latency.R

+32-20
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@ set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
3838
pull(variable)
3939
# make sure that there's enough column names
4040
if (length(original_columns) < 3) {
41-
cli::cli_abort(glue::glue(
42-
"The original columns of `time_value`, ",
43-
"`geo_value` and at least one signal. The current colums are \n",
44-
paste(capture.output(object$info), collapse = "\n\n")
45-
))
41+
cli::cli_abort(
42+
glue::glue(
43+
"The original columns of `time_value`, ",
44+
"`geo_value` and at least one signal. The current colums are \n",
45+
paste(capture.output(object$info), collapse = "\n\n")
46+
),
47+
class = "epipredict__set_forecast_date__too_few_data_columns"
48+
)
4649
}
4750
# the source data determines the actual time_values
4851
# these are the non-na time_values;
@@ -65,25 +68,34 @@ set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
6568
}
6669
# make sure the as_of is sane
6770
if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
68-
cli::cli_abort(paste(
69-
"the data matrix `forecast_date` value is {forecast_date}, ",
70-
"and not a valid `time_type` with type ",
71-
"matching `time_value`'s type of ",
72-
"{class(max_time)}."
73-
))
71+
cli::cli_abort(
72+
paste(
73+
"the data matrix `forecast_date` value is {forecast_date}, ",
74+
"and not a valid `time_type` with type ",
75+
"matching `time_value`'s type of ",
76+
"{class(max_time)}."
77+
),
78+
class = "epipredict__set_forecast_date__wrong_time_value_type_error"
79+
)
7480
}
7581
if (is.null(forecast_date) || is.na(forecast_date)) {
76-
cli::cli_warn(paste(
77-
"epi_data's `forecast_date` was {forecast_date}, setting to ",
78-
"the latest time value, {max_time}."
79-
))
82+
cli::cli_warn(
83+
paste(
84+
"epi_data's `forecast_date` was {forecast_date}, setting to ",
85+
"the latest time value, {max_time}."
86+
),
87+
class = "epipredict__set_forecast_date__max_time_warning"
88+
)
8089
forecast_date <- max_time
8190
} else if (forecast_date < max_time) {
82-
cli::cli_abort(paste(
83-
"`forecast_date` ({(forecast_date)}) is before the most ",
84-
"recent data ({max_time}). Remove before ",
85-
"predicting."
86-
))
91+
cli::cli_abort(
92+
paste(
93+
"`forecast_date` ({(forecast_date)}) is before the most ",
94+
"recent data ({max_time}). Remove before ",
95+
"predicting."
96+
),
97+
class = "epipredict__set_forecast_date__misordered_forecast_date_error"
98+
)
8799
}
88100
# TODO cover the rest of the possible types for as_of and max_time...
89101
if (inherits(max_time, "Date")) {

man/get_test_data.Rd

-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-snapshots.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ test_that("arx_forecaster snapshots", {
9898

9999
test_that("arx_forecaster output format snapshots", {
100100
jhu <- case_death_rate_subset %>%
101-
dplyr::filter(time_value >= as.Date("2021-12-01"))
101+
dplyr::filter(time_value >= as.Date("2021-12-01"))
102102
out1 <- arx_forecaster(
103103
jhu, "death_rate",
104104
c("case_rate", "death_rate")

0 commit comments

Comments
 (0)