Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix growth rate #437

Merged
merged 8 commits into from
Feb 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.9
Version: 0.1.10
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -25,7 +25,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
epidatasets,
epiprocess (>= 0.9.0),
epiprocess (>= 0.10.4),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
@@ -73,7 +73,6 @@ Remotes:
cmu-delphi/epidatasets,
cmu-delphi/epidatr,
cmu-delphi/epiprocess,
cmu-delphi/epidatasets,
dajmcdon/smoothqr
Config/Needs/website: cmu-delphi/delphidocs
Config/testthat/edition: 3
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
`data(<dataset name>)`, but can be accessed with
`data(<dataset name>, package = "epidatasets")`, `epidatasets::<dataset name>`
or, after loading the package, the name of the dataset alone (#382).
- Addresses upstream breaking changes from cmu-delphi/epiprocess#595 (`growth_rate()`).
`step_growth_rate()` has lost its `additional_gr_args_list` argument and now
has an `na_rm` argument.

## Improvements

38 changes: 14 additions & 24 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
@@ -27,8 +27,9 @@
#'
#' @examples
#' library(dplyr)
#' tiny_geos <- c("as", "mp", "vi", "gu", "pr")
#' jhu <- covid_case_death_rates %>%
#' filter(time_value >= as.Date("2021-11-01"))
#' filter(time_value >= as.Date("2021-11-01"), !(geo_value %in% tiny_geos))
#'
#' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate"))
#'
@@ -58,7 +59,10 @@ arx_classifier <- function(
if (args_list$adjust_latency == "none") {
forecast_date_default <- max(epi_data$time_value)
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}.")
cli_warn(
"The specified forecast date {args_list$forecast_date} doesn't match the
date from which the forecast is occurring {forecast_date}."
)
}
} else {
forecast_date_default <- attributes(epi_data)$metadata$as_of
@@ -101,7 +105,7 @@ arx_classifier <- function(
#'
#' @return An unfit `epi_workflow`.
#' @export
#' @seealso [arx_classifier()]
#' @seealso [arx_classifier()] [arx_class_args_list()]
#' @examples
#' library(dplyr)
#' jhu <- covid_case_death_rates %>%
@@ -154,12 +158,13 @@ arx_class_epi_workflow <- function(
role = "grp",
horizon = args_list$horizon,
method = args_list$method,
log_scale = args_list$log_scale,
additional_gr_args_list = args_list$additional_gr_args
log_scale = args_list$log_scale
)
for (l in seq_along(lags)) {
pred_names <- predictors[l]
pred_names <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{pred_names}"))
pred_names <- as.character(glue::glue_data(
args_list, "gr_{horizon}_{method}_{pred_names}"
))
r <- step_epi_lag(r, !!pred_names, lag = lags[[l]])
}
# ------- outcome
@@ -185,8 +190,7 @@ arx_class_epi_workflow <- function(
role = "pre-outcome",
horizon = args_list$horizon,
method = args_list$method,
log_scale = args_list$log_scale,
additional_gr_args_list = args_list$additional_gr_args
log_scale = args_list$log_scale
)
}
}
@@ -270,9 +274,6 @@ arx_class_epi_workflow <- function(
#' @param method Character. Options available for growth rate calculation.
#' @param log_scale Scalar logical. Whether to compute growth rates on the
#' log scale.
#' @param additional_gr_args List. Optional arguments controlling growth rate
#' calculation. See [epiprocess::growth_rate()] and the related Vignette for
#' more details.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
@@ -301,7 +302,6 @@ arx_class_args_list <- function(
horizon = 7L,
method = c("rel_change", "linear_reg"),
log_scale = FALSE,
additional_gr_args = list(),
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
@@ -320,23 +320,14 @@ arx_class_args_list <- function(
arg_is_lgl(log_scale)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (!is.list(additional_gr_args)) {
cli_abort(c(
"`additional_gr_args` must be a {.cls list}.",
"!" = "This is a {.cls {class(additional_gr_args)}}.",
i = "See `?epiprocess::growth_rate` for available arguments."
))
}
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

if (!is.null(forecast_date) && !is.null(target_date)) {
if (forecast_date + ahead != target_date) {
cli_warn(
paste0(
"`forecast_date` {.val {forecast_date}} +",
" `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}."
),
"`forecast_date` {.val {forecast_date}} +
`ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
)
}
@@ -362,7 +353,6 @@ arx_class_args_list <- function(
horizon,
method,
log_scale,
additional_gr_args,
check_enough_data_n,
check_enough_data_epi_keys
),
51 changes: 24 additions & 27 deletions R/step_growth_rate.R
Original file line number Diff line number Diff line change
@@ -22,22 +22,25 @@
#' being removed from the data. Alternatively, you could specify arbitrary
#' large values, or perhaps zero. Setting this argument to `NULL` will result
#' in no replacement.
#' @param additional_gr_args_list A list of additional arguments used by
#' [epiprocess::growth_rate()]. All `...` arguments may be passed here along
#' with `dup_rm` and `na_rm`.
#' @inheritParams epiprocess::growth_rate
#' @template step-return
#'
#'
#' @family row operation steps
#' @importFrom epiprocess growth_rate
#' @export
#' @examples
#' r <- epi_recipe(covid_case_death_rates) %>%
#' library(dplyr)
#' tiny_geos <- c("as", "mp", "vi", "gu", "pr")
#' rates <- covid_case_death_rates %>%
#' filter(time_value >= as.Date("2021-11-01"), !(geo_value %in% tiny_geos))
#'
#' r <- epi_recipe(rates) %>%
#' step_growth_rate(case_rate, death_rate)
#' r
#'
#' r %>%
#' prep(covid_case_death_rates) %>%
#' prep(rates) %>%
#' bake(new_data = NULL)
step_growth_rate <-
function(recipe,
@@ -46,11 +49,11 @@ step_growth_rate <-
horizon = 7,
method = c("rel_change", "linear_reg"),
log_scale = FALSE,
na_rm = TRUE,
replace_Inf = NA,
prefix = "gr_",
skip = FALSE,
id = rand_id("growth_rate"),
additional_gr_args_list = list()) {
id = rand_id("growth_rate")) {
if (!is_epi_recipe(recipe)) {
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
}
@@ -63,15 +66,7 @@ step_growth_rate <-
}
arg_is_chr(role)
arg_is_chr_scalar(prefix, id)
arg_is_lgl_scalar(log_scale, skip)


if (!is.list(additional_gr_args_list)) {
cli_abort(c(
"`additional_gr_args_list` must be a {.cls list}.",
i = "See `?epiprocess::growth_rate` for available options."
))
}
arg_is_lgl_scalar(log_scale, skip, na_rm)

recipes::add_step(
recipe,
@@ -82,13 +77,13 @@ step_growth_rate <-
horizon = horizon,
method = method,
log_scale = log_scale,
na_rm = na_rm,
replace_Inf = replace_Inf,
prefix = prefix,
keys = key_colnames(recipe),
columns = NULL,
skip = skip,
id = id,
additional_gr_args_list = additional_gr_args_list
id = id
)
)
}
@@ -101,13 +96,13 @@ step_growth_rate_new <-
horizon,
method,
log_scale,
na_rm,
replace_Inf,
prefix,
keys,
columns,
skip,
id,
additional_gr_args_list) {
id) {
recipes::step(
subclass = "growth_rate",
terms = terms,
@@ -116,13 +111,13 @@ step_growth_rate_new <-
horizon = horizon,
method = method,
log_scale = log_scale,
na_rm = na_rm,
replace_Inf = replace_Inf,
prefix = prefix,
keys = keys,
columns = columns,
skip = skip,
id = id,
additional_gr_args_list = additional_gr_args_list
id = id
)
}

@@ -137,13 +132,13 @@ prep.step_growth_rate <- function(x, training, info = NULL, ...) {
horizon = x$horizon,
method = x$method,
log_scale = x$log_scale,
na_rm = x$na_rm,
replace_Inf = x$replace_Inf,
prefix = x$prefix,
keys = x$keys,
columns = recipes::recipes_eval_select(x$terms, training, info),
skip = x$skip,
id = x$id,
additional_gr_args_list = x$additional_gr_args_list
id = x$id
)
}

@@ -177,10 +172,12 @@ bake.step_growth_rate <- function(object, new_data, ...) {
across(
all_of(object$columns),
~ epiprocess::growth_rate(
time_value, .x,
.x,
x = time_value,
method = object$method,
h = object$horizon, log_scale = object$log_scale,
!!!object$additional_gr_args_list
h = object$horizon,
log_scale = object$log_scale,
na_rm = object$na_rm
),
.names = "{object$prefix}{object$horizon}_{object$method}_{.col}"
)
5 changes: 0 additions & 5 deletions man/arx_class_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/arx_class_epi_workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/arx_classifier.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions man/epi_recipe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading