Skip to content

Commit ad1f753

Browse files
committed
refactor: hoist some epi_slide_opt pre-processing to helpers
For re-use in an epix_epi_slide_opt
1 parent 5bd90fd commit ad1f753

4 files changed

+259
-118
lines changed

R/slide.R

+174-118
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,164 @@ get_before_after_from_window <- function(window_size, align, time_type) {
551551
return(list(before = before, after = after))
552552
}
553553

554+
555+
#' Information about upstream (`{data.table}`/`{slider}`) slide functions
556+
#'
557+
#' Underlies [`upstream_slide_f_info`].
558+
#'
559+
#' @keywords internal
560+
upstream_slide_f_possibilities <- tibble::tribble(
561+
~f, ~package, ~namer,
562+
frollmean, "data.table", ~ if (is.logical(.x)) "prop" else "av",
563+
frollsum, "data.table", ~ if (is.logical(.x)) "count" else "sum",
564+
frollapply, "data.table", ~"slide",
565+
slide_sum, "slider", ~ if (is.logical(.x)) "count" else "sum",
566+
slide_prod, "slider", ~"prod",
567+
slide_mean, "slider", ~ if (is.logical(.x)) "prop" else "av",
568+
slide_min, "slider", ~"min",
569+
slide_max, "slider", ~"max",
570+
slide_all, "slider", ~"all",
571+
slide_any, "slider", ~"any",
572+
)
573+
574+
#' Validate & get information about an upstream slide function
575+
#'
576+
#' @param .f function such as `data.table::frollmean` or `slider::slide_mean`;
577+
#' must appear in [`upstream_slide_f_possibilities`]
578+
#' @return named list with two elements: `from_package`, a string containing the
579+
#' upstream package name ("data.table" or "slider"), and `namer`, a function
580+
#' that takes a column to call `.f` on and outputs a basic name or
581+
#' abbreviation for what operation `.f` represents on that kind of column
582+
#' (e.g., "sum", "av", "count").
583+
#'
584+
#' @keywords internal
585+
upstream_slide_f_info <- function(.f) {
586+
# Check that slide function `.f` is one of those short-listed from
587+
# `data.table` and `slider` (or a function that has the exact same definition,
588+
# e.g. if the function has been reexported or defined locally). Extract some
589+
# metadata. `namer` will be mapped over columns (.x will be a column, not the
590+
# entire edf).
591+
f_info_row <- upstream_slide_f_possibilities %>%
592+
filter(map_lgl(.data$f, ~ identical(.f, .x)))
593+
if (nrow(f_info_row) == 0L) {
594+
# `f` is from somewhere else and not supported
595+
cli_abort(
596+
c(
597+
"problem with {rlang::expr_label(rlang::caller_arg(f))}",
598+
"i" = "`f` must be one of `data.table`'s rolling functions (`frollmean`,
599+
`frollsum`, `frollapply`. See `?data.table::roll`) or one of
600+
`slider`'s specialized sliding functions (`slide_mean`, `slide_sum`,
601+
etc. See `?slider::\`summary-slide\`` for more options)."
602+
),
603+
class = "epiprocess__epi_slide_opt__unsupported_slide_function",
604+
epiprocess__f = .f
605+
)
606+
}
607+
if (nrow(f_info_row) > 1L) {
608+
cli_abort('epiprocess internal error: looking up `.f` in table of possible
609+
functions yielded multiple matches. Please report it using "New
610+
issue" at https://github.com/cmu-delphi/epiprocess/issues, using
611+
reprex::reprex to provide a minimal reproducible example.')
612+
}
613+
f_from_package <- f_info_row$package
614+
list(
615+
from_package = f_from_package,
616+
namer = unwrap(f_info_row$namer)
617+
)
618+
}
619+
620+
#' Calculate input and output column names for an `{epiprocess}` [`dplyr::across`]-like operations
621+
#'
622+
#' @param .x data.frame to perform input column tidyselection on
623+
#' @param time_type as in [`new_epi_df`]
624+
#' @param col_names_quo enquosed input column tidyselect expression
625+
#' @param .f_namer function taking an input column object and outputting a name
626+
#' for a corresponding output column; see [`upstream_slide_f_info`]
627+
#' @param .window_size as in [`epi_slide_opt`]
628+
#' @param .align as in [`epi_slide_opt`]
629+
#' @param .prefix as in [`epi_slide_opt`]
630+
#' @param .suffix as in [`epi_slide_opt`]
631+
#' @param .new_col_names as in [`epi_slide_opt`]
632+
#' @return named list with two elements: `input_col_names`, chr, subset of
633+
#' `names(.x)`; and `output_colnames`, chr, same length as `input_col_names`
634+
#'
635+
#' @keywords internal
636+
across_ish_names_info <- function(.x, time_type, col_names_quo, .f_namer, .window_size, .align, .prefix, .suffix, .new_col_names) {
637+
# The position of a given column can be differ between input `.x` and
638+
# `.data_group` since the grouping step by default drops grouping columns.
639+
# To avoid rerunning `eval_select` for every `.data_group`, convert
640+
# positions of user-provided `col_names` into string column names. We avoid
641+
# using `names(pos)` directly for robustness and in case we later want to
642+
# allow users to rename fields via tidyselection.
643+
pos <- eval_select(col_names_quo, data = .x, allow_rename = FALSE)
644+
input_col_names <- names(.x)[pos]
645+
646+
# Handle output naming
647+
if ((!is.null(.prefix) || !is.null(.suffix)) && !is.null(.new_col_names)) {
648+
cli_abort(
649+
"Can't use both .prefix/.suffix and .new_col_names at the same time.",
650+
class = "epiprocess__epi_slide_opt_incompatible_naming_args"
651+
)
652+
}
653+
assert_string(.prefix, null.ok = TRUE)
654+
assert_string(.suffix, null.ok = TRUE)
655+
assert_character(.new_col_names, len = length(input_col_names), null.ok = TRUE)
656+
if (is.null(.prefix) && is.null(.suffix) && is.null(.new_col_names)) {
657+
.suffix <- "_{.n}{.time_unit_abbr}{.align_abbr}{.f_abbr}"
658+
# ^ does not account for any arguments specified to underlying functions via
659+
# `...` such as `na.rm =`, nor does it distinguish between functions from
660+
# different packages accomplishing the same type of computation. Those are
661+
# probably only set one way per task, so this probably produces cleaner
662+
# names without clashes (though maybe some confusion if switching between
663+
# code with different settings).
664+
}
665+
if (!is.null(.prefix) || !is.null(.suffix)) {
666+
.prefix <- .prefix %||% ""
667+
.suffix <- .suffix %||% ""
668+
if (identical(.window_size, Inf)) {
669+
n <- "running_"
670+
time_unit_abbr <- ""
671+
align_abbr <- ""
672+
} else {
673+
n <- time_delta_to_n_steps(.window_size, time_type)
674+
time_unit_abbr <- time_type_unit_abbr(time_type)
675+
align_abbr <- c(right = "", center = "c", left = "l")[[.align]]
676+
}
677+
glue_env <- rlang::env(
678+
.n = n,
679+
.time_unit_abbr = time_unit_abbr,
680+
.align_abbr = align_abbr,
681+
.f_abbr = purrr::map_chr(.x[, c(input_col_names)], .f_namer), # compat between DT and tbl selection
682+
quo_get_env(col_names_quo)
683+
)
684+
.new_col_names <- unclass(
685+
glue(.prefix, .envir = glue_env) +
686+
input_col_names +
687+
glue(.suffix, .envir = glue_env)
688+
)
689+
} else {
690+
# `.new_col_names` was provided by user; we don't need to do anything.
691+
}
692+
if (any(.new_col_names %in% names(.x))) {
693+
cli_abort(c(
694+
"Naming conflict between new columns and existing columns",
695+
"x" = "Overlapping names: {format_varnames(intersect(.new_col_names, names(.x)))}"
696+
), class = "epiprocess__epi_slide_opt_old_new_name_conflict")
697+
}
698+
if (anyDuplicated(.new_col_names)) {
699+
cli_abort(c(
700+
"New column names contain duplicates",
701+
"x" = "Duplicated names: {format_varnames(unique(.new_col_names[duplicated(.new_col_names)]))}"
702+
), class = "epiprocess__epi_slide_opt_new_name_duplicated")
703+
}
704+
output_col_names <- .new_col_names
705+
706+
return(list(
707+
input_col_names = input_col_names,
708+
output_col_names = output_col_names
709+
))
710+
}
711+
554712
#' Optimized slide functions for common cases
555713
#'
556714
#' @description `epi_slide_opt` allows sliding an n-timestep [data.table::froll]
@@ -750,59 +908,12 @@ epi_slide_opt <- function(
750908
# Check for duplicated time values within groups
751909
assert(check_ukey_unique(ungroup(.x), c(group_vars(.x), "time_value")))
752910

753-
# The position of a given column can be differ between input `.x` and
754-
# `.data_group` since the grouping step by default drops grouping columns.
755-
# To avoid rerunning `eval_select` for every `.data_group`, convert
756-
# positions of user-provided `col_names` into string column names. We avoid
757-
# using `names(pos)` directly for robustness and in case we later want to
758-
# allow users to rename fields via tidyselection.
911+
# Validate/process .col_names, .f:
759912
col_names_quo <- enquo(.col_names)
760-
pos <- eval_select(col_names_quo, data = .x, allow_rename = FALSE)
761-
col_names_chr <- names(.x)[pos]
762-
763-
# Check that slide function `.f` is one of those short-listed from
764-
# `data.table` and `slider` (or a function that has the exact same definition,
765-
# e.g. if the function has been reexported or defined locally). Extract some
766-
# metadata. `namer` will be mapped over columns (.x will be a column, not the
767-
# entire edf).
768-
f_possibilities <-
769-
tibble::tribble(
770-
~f, ~package, ~namer,
771-
frollmean, "data.table", ~ if (is.logical(.x)) "prop" else "av",
772-
frollsum, "data.table", ~ if (is.logical(.x)) "count" else "sum",
773-
frollapply, "data.table", ~"slide",
774-
slide_sum, "slider", ~ if (is.logical(.x)) "count" else "sum",
775-
slide_prod, "slider", ~"prod",
776-
slide_mean, "slider", ~ if (is.logical(.x)) "prop" else "av",
777-
slide_min, "slider", ~"min",
778-
slide_max, "slider", ~"max",
779-
slide_all, "slider", ~"all",
780-
slide_any, "slider", ~"any",
781-
)
782-
f_info <- f_possibilities %>%
783-
filter(map_lgl(.data$f, ~ identical(.f, .x)))
784-
if (nrow(f_info) == 0L) {
785-
# `f` is from somewhere else and not supported
786-
cli_abort(
787-
c(
788-
"problem with {rlang::expr_label(rlang::caller_arg(f))}",
789-
"i" = "`f` must be one of `data.table`'s rolling functions (`frollmean`,
790-
`frollsum`, `frollapply`. See `?data.table::roll`) or one of
791-
`slider`'s specialized sliding functions (`slide_mean`, `slide_sum`,
792-
etc. See `?slider::\`summary-slide\`` for more options)."
793-
),
794-
class = "epiprocess__epi_slide_opt__unsupported_slide_function",
795-
epiprocess__f = .f
796-
)
797-
}
798-
if (nrow(f_info) > 1L) {
799-
cli_abort('epiprocess internal error: looking up `.f` in table of possible
800-
functions yielded multiple matches. Please report it using "New
801-
issue" at https://github.com/cmu-delphi/epiprocess/issues, using
802-
reprex::reprex to provide a minimal reproducible example.')
803-
}
804-
f_from_package <- f_info$package
913+
f_info <- upstream_slide_f_info(.f)
914+
f_from_package <- f_info$from_package
805915

916+
# Validate/process .ref_time_values:
806917
user_provided_rtvs <- !is.null(.ref_time_values)
807918
if (!user_provided_rtvs) {
808919
.ref_time_values <- unique(.x$time_value)
@@ -832,65 +943,10 @@ epi_slide_opt <- function(
832943
validate_slide_window_arg(.window_size, time_type)
833944
window_args <- get_before_after_from_window(.window_size, .align, time_type)
834945

835-
# Handle output naming
836-
if ((!is.null(.prefix) || !is.null(.suffix)) && !is.null(.new_col_names)) {
837-
cli_abort(
838-
"Can't use both .prefix/.suffix and .new_col_names at the same time.",
839-
class = "epiprocess__epi_slide_opt_incompatible_naming_args"
840-
)
841-
}
842-
assert_string(.prefix, null.ok = TRUE)
843-
assert_string(.suffix, null.ok = TRUE)
844-
assert_character(.new_col_names, len = length(col_names_chr), null.ok = TRUE)
845-
if (is.null(.prefix) && is.null(.suffix) && is.null(.new_col_names)) {
846-
.suffix <- "_{.n}{.time_unit_abbr}{.align_abbr}{.f_abbr}"
847-
# ^ does not account for any arguments specified to underlying functions via
848-
# `...` such as `na.rm =`, nor does it distinguish between functions from
849-
# different packages accomplishing the same type of computation. Those are
850-
# probably only set one way per task, so this probably produces cleaner
851-
# names without clashes (though maybe some confusion if switching between
852-
# code with different settings).
853-
}
854-
if (!is.null(.prefix) || !is.null(.suffix)) {
855-
.prefix <- .prefix %||% ""
856-
.suffix <- .suffix %||% ""
857-
if (identical(.window_size, Inf)) {
858-
n <- "running_"
859-
time_unit_abbr <- ""
860-
align_abbr <- ""
861-
} else {
862-
n <- time_delta_to_n_steps(.window_size, time_type)
863-
time_unit_abbr <- time_type_unit_abbr(time_type)
864-
align_abbr <- c(right = "", center = "c", left = "l")[[.align]]
865-
}
866-
glue_env <- rlang::env(
867-
.n = n,
868-
.time_unit_abbr = time_unit_abbr,
869-
.align_abbr = align_abbr,
870-
.f_abbr = purrr::map_chr(.x[col_names_chr], unwrap(f_info$namer)),
871-
quo_get_env(col_names_quo)
872-
)
873-
.new_col_names <- unclass(
874-
glue(.prefix, .envir = glue_env) +
875-
col_names_chr +
876-
glue(.suffix, .envir = glue_env)
877-
)
878-
} else {
879-
# `.new_col_names` was provided by user; we don't need to do anything.
880-
}
881-
if (any(.new_col_names %in% names(.x))) {
882-
cli_abort(c(
883-
"Naming conflict between new columns and existing columns",
884-
"x" = "Overlapping names: {format_varnames(intersect(.new_col_names, names(.x)))}"
885-
), class = "epiprocess__epi_slide_opt_old_new_name_conflict")
886-
}
887-
if (anyDuplicated(.new_col_names)) {
888-
cli_abort(c(
889-
"New column names contain duplicates",
890-
"x" = "Duplicated names: {format_varnames(unique(.new_col_names[duplicated(.new_col_names)]))}"
891-
), class = "epiprocess__epi_slide_opt_new_name_duplicated")
892-
}
893-
result_col_names <- .new_col_names
946+
# Handle output naming:
947+
names_info <- across_ish_names_info(.x, time_type, col_names_quo, f_info$namer, .window_size, .align, .prefix, .suffix, .new_col_names)
948+
input_col_names <- names_info$input_col_names
949+
output_col_names <- names_info$output_col_names
894950

895951
# Make a complete date sequence between min(.x$time_value) and max(.x$time_value).
896952
date_seq_list <- full_date_seq(.x, window_args$before, window_args$after, time_type)
@@ -935,23 +991,23 @@ epi_slide_opt <- function(
935991
# be; shift results to the left by `after` timesteps.
936992
if (window_args$before != Inf) {
937993
window_size <- window_args$before + window_args$after + 1L
938-
roll_output <- .f(x = .data_group[, col_names_chr], n = window_size, ...)
994+
roll_output <- .f(x = .data_group[, input_col_names], n = window_size, ...)
939995
} else {
940996
window_size <- list(seq_along(.data_group$time_value))
941-
roll_output <- .f(x = .data_group[, col_names_chr], n = window_size, adaptive = TRUE, ...)
997+
roll_output <- .f(x = .data_group[, input_col_names], n = window_size, adaptive = TRUE, ...)
942998
}
943999
if (window_args$after >= 1) {
944-
.data_group[, result_col_names] <- purrr::map(roll_output, function(.x) {
1000+
.data_group[, output_col_names] <- purrr::map(roll_output, function(.x) {
9451001
c(.x[(window_args$after + 1L):length(.x)], rep(NA, window_args$after))
9461002
})
9471003
} else {
948-
.data_group[, result_col_names] <- roll_output
1004+
.data_group[, output_col_names] <- roll_output
9491005
}
9501006
}
9511007
if (f_from_package == "slider") {
952-
for (i in seq_along(col_names_chr)) {
953-
.data_group[, result_col_names[i]] <- .f(
954-
x = .data_group[[col_names_chr[i]]],
1008+
for (i in seq_along(input_col_names)) {
1009+
.data_group[, output_col_names[i]] <- .f(
1010+
x = .data_group[[input_col_names[i]]],
9551011
before = as.numeric(window_args$before),
9561012
after = as.numeric(window_args$after),
9571013
...
@@ -970,7 +1026,7 @@ epi_slide_opt <- function(
9701026
group_by(!!!.x_orig_groups)
9711027

9721028
if (.all_rows) {
973-
result[!vec_in(result$time_value, ref_time_values), result_col_names] <- NA
1029+
result[!vec_in(result$time_value, ref_time_values), output_col_names] <- NA
9741030
} else if (user_provided_rtvs) {
9751031
result <- result[vec_in(result$time_value, ref_time_values), ]
9761032
}

man/across_ish_names_info.Rd

+46
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)