Skip to content

Commit 2dd1ae0

Browse files
dsweber2dshemetov
andauthored
extend to quantile_dist, exclude multi-output (#458)
* extend to quantile_dist, exclude multi-output * Drop by specification and infer from the epi_df * lint+test: test coverage, handle na lambda case, lint * fix: quantile_pred arithmetic * fix: rlang calls --------- Co-authored-by: Dmitry Shemetov <[email protected]>
1 parent 18612d4 commit 2dd1ae0

11 files changed

+324
-232
lines changed

NAMESPACE

+2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ S3method(update,layer)
122122
S3method(vec_arith,quantile_pred)
123123
S3method(vec_arith.numeric,quantile_pred)
124124
S3method(vec_arith.quantile_pred,numeric)
125+
S3method(vec_arith.quantile_pred,quantile_pred)
125126
S3method(vec_math,quantile_pred)
126127
S3method(vec_proxy_equal,quantile_pred)
127128
S3method(weighted_interval_score,quantile_pred)
@@ -233,6 +234,7 @@ import(epidatasets)
233234
import(epiprocess)
234235
import(parsnip)
235236
import(recipes)
237+
import(vctrs)
236238
importFrom(checkmate,assert_class)
237239
importFrom(checkmate,assert_numeric)
238240
importFrom(checkmate,test_character)

R/layer_yeo_johnson.R

+87-149
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
#' Unormalizing transformation
22
#'
3-
#' Will undo a step_epi_YeoJohnson transformation.
3+
#' Will undo a step_epi_YeoJohnson transformation. For practical reasons, if you
4+
#' are using this step on a column that will eventually become the outcome
5+
#' variable, you should make sure that the original name of that column is a
6+
#' subset of the outcome variable name. `ahead_7_cases` when `cases` is
7+
#' transformed will work well, while `ahead_7` will not.
48
#'
59
#' @inheritParams layer_population_scaling
6-
#' @param yj_params Internal. A data frame of parameters to be used for
7-
#' inverting the transformation.
8-
#' @param by A (possibly named) character vector of variables to join by.
10+
#' @param yj_params A data frame of parameters to be used for inverting the
11+
#' transformation. Typically set automatically. If you have done multiple
12+
#' transformations such that the outcome variable name no longer contains the
13+
#' column that this step transforms, then you should manually specify this to
14+
#' be the parameters fit in the corresponding `step_epi_YeoJohnson`. For an
15+
#' example where you wouldn't need to set this, if your output is
16+
#' `ahead_7_cases` and `step_epi_YeoJohnson` transformed cases (possibly with
17+
#' other columns), then you wouldn't need to set this. However if you have
18+
#' renamed your output column to `diff_7`, then you will need to extract the `yj_params` from the step.
919
#'
1020
#' @return an updated `frosting` postprocessor
1121
#' @export
@@ -37,65 +47,36 @@
3747
#' # Compare to the original data.
3848
#' jhu %>% filter(time_value == "2021-12-31")
3949
#' forecast(wf)
40-
layer_epi_YeoJohnson <- function(frosting, ..., yj_params = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
50+
layer_epi_YeoJohnson <- function(frosting, ..., yj_params = NULL, id = rand_id("epi_YeoJohnson")) {
4151
checkmate::assert_tibble(yj_params, min.rows = 1, null.ok = TRUE)
4252

4353
add_layer(
4454
frosting,
4555
layer_epi_YeoJohnson_new(
4656
yj_params = yj_params,
47-
by = by,
4857
terms = dplyr::enquos(...),
4958
id = id
5059
)
5160
)
5261
}
5362

54-
layer_epi_YeoJohnson_new <- function(yj_params, by, terms, id) {
55-
layer("epi_YeoJohnson", yj_params = yj_params, by = by, terms = terms, id = id)
63+
layer_epi_YeoJohnson_new <- function(yj_params, terms, id) {
64+
layer("epi_YeoJohnson", yj_params = yj_params, terms = terms, id = id)
5665
}
5766

5867
#' @export
5968
#' @importFrom workflows extract_preprocessor
6069
slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data, ...) {
6170
rlang::check_dots_empty()
6271

63-
# TODO: We will error if we don't have a workflow. Write a check later.
64-
65-
# Get the yj_params from the layer or from the workflow.
66-
yj_params <- object$yj_params %||% get_yj_params_in_layer(workflow)
67-
68-
# If the by is not specified, try to infer it from the yj_params.
69-
if (is.null(object$by)) {
70-
# Assume `layer_predict` has calculated the prediction keys and other
71-
# layers don't change the prediction key colnames:
72-
prediction_key_colnames <- names(components$keys)
73-
lhs_potential_keys <- prediction_key_colnames
74-
rhs_potential_keys <- colnames(select(yj_params, -starts_with(".yj_param_")))
75-
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
76-
suggested_min_keys <- setdiff(lhs_potential_keys, "time_value")
77-
if (!all(suggested_min_keys %in% object$by)) {
78-
cli_warn(
79-
c(
80-
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
81-
but {?wasn't/weren't} found in the population `df`.",
82-
"i" = "Defaulting to join by {object$by}",
83-
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
84-
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
85-
">" = "Manually specify `by =` to silence"
86-
),
87-
class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys"
88-
)
89-
}
90-
}
72+
# get the yj_params from the layer or from the workflow.
73+
yj_params <-
74+
object$yj_params %||%
75+
get_params_in_layer(workflow, "epi_YeoJohnson", "yj_params")
9176

9277
# Establish the join columns.
93-
object$by <- object$by %||%
94-
intersect(
95-
epi_keys_only(components$predictions),
96-
colnames(select(yj_params, -starts_with(".yj_param_")))
97-
)
98-
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
78+
join_by_columns <- key_colnames(new_data, exclude = "time_value") %>% sort()
79+
joinby <- list(x = join_by_columns, y = join_by_columns)
9980
hardhat::validate_column_names(components$predictions, joinby$x)
10081
hardhat::validate_column_names(yj_params, joinby$y)
10182

@@ -115,55 +96,15 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
11596
# The `object$terms` is where the user specifies the columns they want to
11697
# untransform. We need to match the outcomes with their yj_param columns in our
11798
# parameter table and then apply the inverse transformation.
118-
if (identical(col_names, ".pred")) {
119-
# In this case, we don't get a hint for the outcome column name, so we need
120-
# to infer it from the mold.
121-
if (length(components$mold$outcomes) > 1) {
122-
cli_abort("Only one outcome is allowed when specifying `.pred`.", call = rlang::caller_env())
123-
}
124-
# `outcomes` is a vector of objects like ahead_1_cases, ahead_7_cases, etc.
125-
# We want to extract the cases part.
126-
outcome_cols <- names(components$mold$outcomes) %>%
127-
stringr::str_match("ahead_\\d+_(.*)") %>%
128-
magrittr::extract(, 2)
129-
99+
if (length(col_names) == 0) {
100+
# not specified by the user, so just modify everything starting with `.pred`
130101
components$predictions <- components$predictions %>%
131-
mutate(.pred := yj_inverse(.pred, !!sym(paste0(".yj_param_", outcome_cols))))
132-
} else if (identical(col_names, character(0))) {
133-
# Wish I could suggest `all_outcomes()` here, but currently it's the same as
134-
# not specifying any terms. I don't want to spend time with dealing with
135-
# this case until someone asks for it.
136-
cli::cli_abort(
137-
"Not specifying columns to layer Yeo-Johnson is not implemented.
138-
If you had a single outcome, you can use `.pred` as a column name.
139-
If you had multiple outcomes, you'll need to specify them like
140-
`.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
141-
",
142-
call = rlang::caller_env()
143-
)
102+
mutate(across(starts_with(".pred"), \(.pred) yj_inverse(.pred, .lambda))) %>%
103+
select(-.lambda)
144104
} else {
145-
# In this case, we assume that the user has specified the columns they want
146-
# transformed here. We then need to determine the yj_param columns for each of
147-
# these columns. That is, we need to convert a vector of column names like
148-
# c(".pred_ahead_1_case_rate", ".pred_ahead_7_case_rate") to
149-
# c(".yj_param_ahead_1_case_rate", ".yj_param_ahead_7_case_rate").
150-
original_outcome_cols <- stringr::str_match(col_names, ".pred_ahead_\\d+_(.*)")[, 2]
151-
outcomes_wout_ahead <- stringr::str_match(names(components$mold$outcomes), "ahead_\\d+_(.*)")[, 2]
152-
if (any(original_outcome_cols %nin% outcomes_wout_ahead)) {
153-
cli_abort(
154-
"All columns specified in `...` must be outcome columns.
155-
They must be of the form `.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
156-
",
157-
call = rlang::caller_env()
158-
)
159-
}
160-
161-
for (i in seq_along(col_names)) {
162-
col <- col_names[i]
163-
yj_param_col <- paste0(".yj_param_", original_outcome_cols[i])
164-
components$predictions <- components$predictions %>%
165-
mutate(!!sym(col) := yj_inverse(!!sym(col), !!sym(yj_param_col)))
166-
}
105+
components$predictions <- components$predictions %>%
106+
mutate(across(all_of(col_names), \(.pred) yj_inverse(.pred, .lambda))) %>%
107+
select(-.lambda)
167108
}
168109

169110
# Remove the yj_param columns.
@@ -182,75 +123,72 @@ print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30),
182123
# Inverse Yeo-Johnson transformation
183124
#
184125
# Inverse of `yj_transform` in step_yeo_johnson.R.
185-
yj_inverse <- function(x, lambda, eps = 0.001) {
126+
yj_inverse <- function(x_in, lambda, eps = 0.001) {
186127
if (any(is.na(lambda))) {
187-
return(x)
188-
}
189-
if (length(x) > 1 && length(lambda) == 1) {
190-
lambda <- rep(lambda, length(x))
191-
} else if (length(x) != length(lambda)) {
192-
cli::cli_abort("Length of `x` must be equal to length of `lambda`.", call = rlang::caller_fn())
193-
}
194-
if (!inherits(x, "tbl_df") || is.data.frame(x)) {
195-
x <- unlist(x, use.names = FALSE)
196-
} else {
197-
if (!is.vector(x)) {
198-
x <- as.vector(x)
199-
}
200-
}
201-
202-
nn_inv_trans <- function(x, lambda) {
203-
out <- double(length(x))
204-
sm_lambdas <- abs(lambda) < eps
205-
if (length(sm_lambdas) > 0) {
206-
out[sm_lambdas] <- exp(x[sm_lambdas]) - 1
207-
}
208-
x <- x[!sm_lambdas]
209-
lambda <- lambda[!sm_lambdas]
210-
if (length(x) > 0) {
211-
out[!sm_lambdas] <- (lambda * x + 1)^(1 / lambda) - 1
212-
}
213-
out
128+
cli::cli_abort("`lambda` cannot be `NA`.", call = rlang::caller_call())
214129
}
215-
216-
ng_inv_trans <- function(x, lambda) {
217-
out <- double(length(x))
218-
near2_lambdas <- abs(lambda - 2) < eps
219-
if (length(near2_lambdas) > 0) {
220-
out[near2_lambdas] <- -(exp(-x[near2_lambdas]) - 1)
221-
}
222-
x <- x[!near2_lambdas]
223-
lambda <- lambda[!near2_lambdas]
224-
if (length(x) > 0) {
225-
out[!near2_lambdas] <- -(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
226-
}
227-
out
228-
}
229-
230-
dat_neg <- x < 0
231-
not_neg <- which(!dat_neg)
232-
is_neg <- which(dat_neg)
233-
234-
if (length(not_neg) > 0) {
235-
x[not_neg] <- nn_inv_trans(x[not_neg], lambda[not_neg])
236-
}
237-
238-
if (length(is_neg) > 0) {
239-
x[is_neg] <- ng_inv_trans(x[is_neg], lambda[is_neg])
130+
x_lambda <- yj_input_type_management(x_in, lambda)
131+
x <- x_lambda[[1]]
132+
lambda <- x_lambda[[2]]
133+
inv_x <- ifelse(
134+
x < 0,
135+
# negative values we test if lambda is ~2
136+
ifelse(
137+
abs(lambda - 2) < eps,
138+
-(exp(-x) - 1),
139+
-(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
140+
),
141+
# non-negative values we test if lambda is ~0
142+
ifelse(
143+
abs(lambda) < eps,
144+
(exp(x) - 1),
145+
(lambda * x + 1)^(1 / lambda) - 1
146+
)
147+
)
148+
if (x_in %>% inherits("quantile_pred")) {
149+
inv_x <- inv_x %>% quantile_pred(x_in %@% "quantile_levels")
240150
}
241-
x
151+
inv_x
242152
}
243153

244-
get_yj_params_in_layer <- function(workflow) {
154+
155+
#' get the parameters used in the initial step
156+
#'
157+
#' @param workflow the workflow to extract the parameters from
158+
#' @param step_name the name of the step to look for, as recognized by `detect_step`
159+
#' @param param_name the parameter to pull out of the step
160+
#' @keywords internal
161+
get_params_in_layer <- function(workflow, step_name = "epi_YeoJohnson", param_name = "yj_params") {
162+
full_step_name <- glue::glue("step_{step_name}")
245163
this_recipe <- hardhat::extract_recipe(workflow)
246-
if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
247-
cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
164+
if (!(this_recipe %>% recipes::detect_step(step_name))) {
165+
cli_abort("`layer_{step_name}` requires `step_{step_name}` in the recipe.", call = rlang::caller_call())
166+
}
167+
outcomes <-
168+
workflows::extract_recipe(workflow)$term_info %>%
169+
filter(role == "outcome") %>%
170+
pull(variable)
171+
if (length(outcomes) > 1) {
172+
cli_abort(
173+
"`layer_{step_name}` doesn't support multiple output columns.
174+
This workflow produces {outcomes} as output columns.",
175+
call = rlang::caller_call(),
176+
class = "epipredict__layer_yeo_johnson_multi_outcome_error"
177+
)
248178
}
249179
for (step in this_recipe$steps) {
250-
if (inherits(step, "step_epi_YeoJohnson")) {
251-
yj_params <- step$yj_params
180+
# if it's a `step_name` step that also transforms a column that is a subset
181+
# of the output column name
182+
is_outcome_subset <- map_lgl(step$columns, ~ grepl(.x, outcomes))
183+
if (inherits(step, full_step_name) && any(is_outcome_subset)) {
184+
params <- step[[param_name]] %>%
185+
select(
186+
key_colnames(workflow$original_data, exclude = "time_value"),
187+
contains(step$columns[is_outcome_subset])
188+
) %>%
189+
rename(.lambda = contains(step$columns))
252190
break
253191
}
254192
}
255-
yj_params
193+
params
256194
}

R/quantile_pred-methods.R

+19-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ vec_proxy_equal.quantile_pred <- function(x, ...) {
111111
dplyr::select(-.row)
112112
}
113113

114-
115114
# quantiles by treating quantile_pred like a distribution -----------------
116115

117116

@@ -287,13 +286,32 @@ vec_math.quantile_pred <- function(.fn, .x, ...) {
287286
quantile_pred(.fn(.x), quantile_levels)
288287
}
289288

289+
#' Internal vctrs methods
290+
#'
291+
#' @import vctrs
292+
#' @keywords internal
293+
#' @name epipredict-vctrs
294+
290295
#' @importFrom vctrs vec_arith vec_arith.numeric
291296
#' @export
292297
#' @method vec_arith quantile_pred
293298
vec_arith.quantile_pred <- function(op, x, y, ...) {
294299
UseMethod("vec_arith.quantile_pred", y)
295300
}
296301

302+
303+
#' @export
304+
#' @method vec_arith.quantile_pred quantile_pred
305+
vec_arith.quantile_pred.quantile_pred <- function(op, x, y, ...) {
306+
all_quantiles <- unique(c(x %@% "quantile_levels", y %@% "quantile_levels"))
307+
op_fn <- getExportedValue("base", op)
308+
# Interpolate/extrapolate to the same quantiles
309+
x <- quantile.quantile_pred(x, all_quantiles)
310+
y <- quantile.quantile_pred(y, all_quantiles)
311+
out <- op_fn(x, y, ...)
312+
quantile_pred(out, all_quantiles)
313+
}
314+
297315
#' @export
298316
#' @method vec_arith.quantile_pred numeric
299317
vec_arith.quantile_pred.numeric <- function(op, x, y, ...) {

0 commit comments

Comments
 (0)