-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathstep_epi_slide_mean.R
188 lines (177 loc) · 5.57 KB
/
step_epi_slide_mean.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#' Calculate a rolling mean
#'
#' `step_epi_slide_mean()` creates a *specification* of a recipe step that will
#' generate one or more new columns of derived data by computing a sliding
#' mean along existing data.
#'
#'
#' @inheritParams step_epi_lag
#' @param before,after non-negative integers.
#' How far `before` and `after` each `time_value` should
#' the sliding window extend? Any value provided for either
#' argument must be a single, non-`NA`, non-negative,
#' [integer-compatible][vctrs::vec_cast] number of time steps. Endpoints of
#' the window are inclusive. Common settings:
#' * For trailing/right-aligned windows from `time_value - time_step(k)` to
#' `time_value`, use `before=k, after=0`. This is the most likely use case
#' for the purposes of forecasting.
#' * For center-aligned windows from `time_value - time_step(k)` to
#' `time_value + time_step(k)`, use `before=k, after=k`.
#' * For leading/left-aligned windows from `time_value` to
#' `time_value + time_step(k)`, use `before=0, after=k`.
#'
#' You may also pass a [lubridate::period], like `lubridate::weeks(1)` or a
#' character string that is coercible to a [lubridate::period], like
#' `"2 weeks"`.
#' @template step-return
#'
#' @export
#' @examples
#' library(dplyr)
#' jhu <- case_death_rate_subset %>%
#' filter(time_value >= as.Date("2021-01-01"), geo_value %in% c("ca", "ny"))
#' rec <- epi_recipe(jhu) %>%
#'   step_epi_slide_mean(case_rate, death_rate, before = 6L)
#' bake(prep(rec, jhu), new_data = NULL)
step_epi_slide_mean <- function(recipe,
                                ...,
                                before = 0L,
                                after = 0L,
                                role = "predictor",
                                prefix = "epi_slide_mean_",
                                skip = FALSE,
                                id = rand_id("epi_slide_mean")) {
  # The step stores `epi_keys(recipe)`, so anything that is not an
  # `epi_recipe` is rejected before any other work is done.
  if (!is_epi_recipe(recipe)) {
    rlang::abort("This recipe step can only operate on an `epi_recipe`.")
  }

  # Window bounds must be scalars; each is then passed through `try_period()`
  # (defined elsewhere -- presumably it coerces period-like input such as
  # "2 weeks" and leaves plain numbers alone; confirm against its definition).
  arg_is_scalar(before, after)
  before <- try_period(before)
  after <- try_period(after)

  # Remaining bookkeeping arguments are checked for type and length.
  arg_is_chr_scalar(role, prefix, id)
  arg_is_lgl_scalar(skip)

  # Build the untrained step; `columns` stays NULL until `prep()` resolves
  # the selectors captured from `...`.
  untrained_step <- step_epi_slide_mean_new(
    terms = enquos(...),
    before = before,
    after = after,
    role = role,
    trained = FALSE,
    prefix = prefix,
    keys = epi_keys(recipe),
    columns = NULL,
    skip = skip,
    id = id
  )

  add_step(recipe, untrained_step)
}
step_epi_slide_mean_new <- function(terms, before, after, role, trained,
                                    prefix, keys, columns, skip, id) {
  # Low-level constructor: hands every field, unchanged and in this order, to
  # `step()` under the subclass name "epi_slide_mean". No validation is done
  # here; callers are expected to have checked their inputs already.
  step(
    subclass = "epi_slide_mean",
    terms = terms, before = before, after = after,
    role = role, trained = trained, prefix = prefix,
    keys = keys, columns = columns,
    skip = skip, id = id
  )
}
#' @export
prep.step_epi_slide_mean <- function(x, training, info = NULL, ...) {
  # Turn the captured selectors into concrete column names and make sure
  # every selected column is numeric (double or integer).
  selected <- recipes::recipes_eval_select(x$terms, data = training, info = info)
  check_type(training[, selected], types = c("double", "integer"))

  # The time type is read from the `metadata` attribute of the training data;
  # it decides how a lubridate period is turned into a number of time steps.
  time_type <- attributes(training)$metadata$time_type

  # Rebuild the step as trained, with the resolved columns and with the
  # window bounds run through `lubridate_period_to_integer()`.
  step_epi_slide_mean_new(
    terms = x$terms,
    before = lubridate_period_to_integer(x$before, time_type),
    after = lubridate_period_to_integer(x$after, time_type),
    role = x$role,
    trained = TRUE,
    prefix = x$prefix,
    keys = x$keys,
    columns = selected,
    skip = x$skip,
    id = x$id
  )
}
#' Convert a lubridate period into a whole number of time steps
#'
#' lubridate converts periods to seconds by default, and `as.integer()` does
#' not throw an error when its input is not actually an integer. The period is
#' therefore measured explicitly in units of `time_type` and checked to be a
#' whole number before it is cast.
#'
#' @param value The window size to convert. Anything that is not a lubridate
#'   period is returned unchanged.
#' @param time_type The time type of the data. Periods can only be converted
#'   when this is `"day"` or `"week"`.
#' @return `value` itself when it is not a period; otherwise a single integer
#'   giving the length of the period in units of `time_type`. Aborts with
#'   class `epipredict__step_epi_slide_mean__unsupported_error` when
#'   `time_type` is unsupported or the converted length is not a whole number.
#' @importFrom lubridate time_length is.period
#' @keywords internal
lubridate_period_to_integer <- function(value, time_type) {
  # Non-period input (e.g. a plain integer) needs no conversion.
  if (!is.period(value)) {
    return(value)
  }

  # `isTRUE()` also covers a missing (NULL) time type, which would otherwise
  # make the condition zero-length and crash with an unrelated error.
  if (!isTRUE(time_type %in% c("day", "week"))) {
    cli_abort(
      "unsupported time type of {time_type}. Use integer instead of lubridate period.",
      class = "epipredict__step_epi_slide_mean__unsupported_error"
    )
  }

  # Both supported time types are also valid `time_length()` units.
  value <- time_length(value, unit = time_type)

  # Refuse fractional windows rather than letting `as.integer()` truncate.
  if (value %% 1 != 0) {
    cli_abort(
      "Converted period value of {value} is not an integer.",
      class = "epipredict__step_epi_slide_mean__unsupported_error"
    )
  }

  as.integer(value)
}
#' @export
bake.step_epi_slide_mean <- function(object, new_data, ...) {
  # Make sure the columns selected at prep time are present in `new_data`.
  recipes::check_new_data(names(object$columns), object, new_data)

  col_names <- as.vector(object$columns)
  name_prefix <- object$prefix
  new_names <- glue::glue("{name_prefix}{col_names}")

  ## ensure no name clashes
  new_data_names <- colnames(new_data)
  intersection <- new_data_names %in% new_names
  if (any(intersection)) {
    # `nms` is interpolated by the message below.
    nms <- new_data_names[intersection]
    cli_abort(
      c("In `step_epi_slide_mean()` a name collision occurred. The following variable names already exist:",
        `*` = "{.var {nms}}"
      ),
      call = caller_env(),
      class = "epipredict__step__name_collision_error"
    )
  }

  # Lookup for `rename()`: names are the final (prefixed) column names, values
  # are the "slide_value_<col>" names the slide output is renamed from. A plain
  # character vector is used so `all_of()` receives an ordinary string vector.
  renaming <- as.character(glue::glue("slide_value_{col_names}"))
  names(renaming) <- new_names

  # Slide within every key except the first one in `object$keys`, rename the
  # outputs, and drop the grouping so later steps do not inherit it.
  new_data %>%
    group_by(across(all_of(object$keys[-1]))) %>%
    epi_slide_mean(col_names, before = object$before, after = object$after) %>%
    rename(all_of(renaming)) %>%
    dplyr::ungroup()
}
#' @export
print.step_epi_slide_mean <- function(x, width = max(20, options()$width - 30), ...) {
  # Describe the sliding window. This step stores `before` and `after` (there
  # is no function-name field to show), and `format()` is used because the
  # bounds may not yet be plain integers before the step is prepped.
  window_text <- paste0(
    "before = ", format(x$before),
    ", after = ", format(x$after)
  )
  print_epi_step(
    x$columns, x$terms, x$trained,
    title = "Calculating epi_slide_mean for ",
    conjunction = "with", extra_text = window_text
  )
  # Printing is a side effect; hand the step back invisibly.
  invisible(x)
}