Skip to content

Commit 39f4d2d

Browse files
authored
Merge pull request #318 from cmu-delphi/ndefries/epi-slide-rtv
Allow `epi_slide` to access `ref_time_value`
2 parents a4c19f7 + d22b88b commit 39f4d2d

File tree

4 files changed

+300
-15
lines changed

4 files changed

+300
-15
lines changed

NAMESPACE

+2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,15 @@ importFrom(data.table,key)
6969
importFrom(data.table,set)
7070
importFrom(data.table,setkeyv)
7171
importFrom(dplyr,arrange)
72+
importFrom(dplyr,bind_rows)
7273
importFrom(dplyr,dplyr_col_modify)
7374
importFrom(dplyr,dplyr_reconstruct)
7475
importFrom(dplyr,dplyr_row_slice)
7576
importFrom(dplyr,filter)
7677
importFrom(dplyr,group_by)
7778
importFrom(dplyr,group_by_drop_default)
7879
importFrom(dplyr,group_modify)
80+
importFrom(dplyr,group_vars)
7981
importFrom(dplyr,groups)
8082
importFrom(dplyr,mutate)
8183
importFrom(dplyr,relocate)

R/slide.R

+71-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
#' If `f` is missing, then `...` will specify the computation.
2424
#' @param ... Additional arguments to pass to the function or formula specified
2525
#' via `f`. Alternatively, if `f` is missing, then the `...` is interpreted as
26-
#' an expression for tidy evaluation. See details.
26+
#' an expression for tidy evaluation; in addition to referring to columns
27+
#' directly by name, the expression has access to `.data` and `.env` pronouns
28+
#' as in `dplyr` verbs, and can also refer to `.x`, `.group_key`, and
29+
#' `.ref_time_value`. See details.
2730
#' @param before,after How far `before` and `after` each `ref_time_value` should
2831
#' the sliding window extend? At least one of these two arguments must be
2932
#' provided; the other's default will be 0. Any value provided for either
@@ -119,7 +122,8 @@
119122
#' through the `new_col_name` argument.
120123
#'
121124
#' @importFrom lubridate days weeks
122-
#' @importFrom rlang .data .env !! enquo enquos sym
125+
#' @importFrom dplyr bind_rows group_vars filter select
126+
#' @importFrom rlang .data .env !! enquo enquos sym env
123127
#' @export
124128
#' @examples
125129
#' # slide a 7-day trailing average formula on cases
@@ -166,11 +170,8 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
166170

167171
# Check that `f` takes enough args
168172
if (!missing(f) && is.function(f)) {
169-
assert_sufficient_f_args(f, ...)
173+
assert_sufficient_f_args(f, ..., n_mandatory_f_args = 3L)
170174
}
171-
172-
# Arrange by increasing time_value
173-
x = arrange(x, time_value)
174175

175176
if (missing(ref_time_values)) {
176177
ref_time_values = unique(x$time_value)
@@ -231,6 +232,35 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
231232
after <- time_step(after)
232233
}
233234

235+
min_ref_time_values = ref_time_values - before
236+
min_ref_time_values_not_in_x <- min_ref_time_values[!(min_ref_time_values %in% unique(x$time_value))]
237+
238+
# Set things up so that we can recover `ref_time_value`s later.
239+
# A helper column marking real observations.
240+
x$.real = TRUE
241+
242+
# Create df containing phony data. Df has the same columns and attributes as
243+
# `x`, but filled with `NA`s aside from grouping columns. Number of rows is
244+
# equal to the number of `min_ref_time_values_not_in_x` we have * the
245+
# number of unique levels seen in the grouping columns.
246+
before_time_values_df = data.frame(time_value=min_ref_time_values_not_in_x)
247+
if (length(group_vars(x)) != 0) {
248+
before_time_values_df = dplyr::cross_join(
249+
# Get unique combinations of grouping columns seen in real data.
250+
unique(x[, group_vars(x)]),
251+
before_time_values_df
252+
)
253+
}
254+
# Automatically fill in all other columns from `x` with `NA`s, and carry
255+
# attributes over to new df.
256+
before_time_values_df <- bind_rows(x[0,], before_time_values_df)
257+
before_time_values_df$.real <- FALSE
258+
259+
x <- bind_rows(before_time_values_df, x)
260+
261+
# Arrange by increasing time_value
262+
x = arrange(x, time_value)
263+
234264
# Now set up starts and stops for sliding/hopping
235265
time_range = range(unique(x$time_value))
236266
starts = in_range(ref_time_values - before, time_range)
@@ -272,7 +302,9 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
272302
o = .data_group$time_value %in% time_values
273303
num_ref_rows = sum(o)
274304

275-
# Count the number of appearances of each reference time value
305+
# Count the number of appearances of each reference time value (these
306+
# appearances should all be real for now, but this may change if we allow
307+
# ref time values outside of .data_group's time values):
276308
counts = .data_group %>%
277309
dplyr::filter(.data$time_value %in% time_values) %>%
278310
dplyr::count(.data$time_value) %>%
@@ -282,7 +314,7 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
282314
!all(purrr::map_lgl(slide_values_list, is.data.frame))) {
283315
Abort("The slide computations must return always atomic vectors or data frames (and not a mix of these two structures).")
284316
}
285-
317+
286318
# Unlist if appropriate:
287319
slide_values =
288320
if (as_list_col) {
@@ -318,16 +350,24 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
318350
# fills with NA equivalent.
319351
vctrs::vec_slice(slide_values, o) = orig_values
320352
} else {
353+
# This implicitly removes phony (`.real` == FALSE) observations.
321354
.data_group = filter(.data_group, o)
322355
}
323356
return(mutate(.data_group, !!new_col := slide_values))
324357
}
325358

326359
# If f is not missing, then just go ahead, slide by group
327360
if (!missing(f)) {
361+
if (rlang::is_formula(f)) f = as_slide_computation(f)
362+
f_rtv_wrapper = function(x, g, ...) {
363+
ref_time_value = min(x$time_value) + before
364+
x <- x[x$.real,]
365+
x$.real <- NULL
366+
f(x, g, ref_time_value, ...)
367+
}
328368
x = x %>%
329369
group_modify(slide_one_grp,
330-
f = f, ...,
370+
f = f_rtv_wrapper, ...,
331371
starts = starts,
332372
stops = stops,
333373
time_values = ref_time_values,
@@ -347,7 +387,18 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
347387
}
348388

349389
quo = quos[[1]]
350-
f = function(x, quo, ...) rlang::eval_tidy(quo, x)
390+
f = function(.x, .group_key, quo, ...) {
391+
.ref_time_value = min(.x$time_value) + before
392+
.x <- .x[.x$.real,]
393+
.x$.real <- NULL
394+
data_mask = rlang::as_data_mask(.x)
395+
# We'll also install `.x` directly, not as an `rlang_data_pronoun`, so
396+
# that we can, e.g., use more dplyr and epiprocess operations.
397+
data_mask$.x = .x
398+
data_mask$.group_key = .group_key
399+
data_mask$.ref_time_value = .ref_time_value
400+
rlang::eval_tidy(quo, data_mask)
401+
}
351402
new_col = sym(names(rlang::quos_auto_name(quos)))
352403

353404
x = x %>%
@@ -365,5 +416,15 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
365416
if (!as_list_col) {
366417
x = unnest(x, !!new_col, names_sep = names_sep)
367418
}
419+
420+
# Remove any remaining phony observations. When `all_rows` is TRUE, phony
421+
# observations aren't necessarily removed in `slide_one_grp`.
422+
if (all_rows) {
423+
x <- x[x$.real,]
424+
}
425+
426+
# Drop helper column `.real`.
427+
x$.real <- NULL
428+
368429
return(x)
369430
}

man/epi_slide.Rd

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)