Skip to content

Commit c8f7e1d

Browse files
committed
tests for utils-latency and accompanying fixes
1 parent 0335dd6 commit c8f7e1d

9 files changed

+180
-45
lines changed

NAMESPACE

+6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ S3method(predict,epi_workflow)
5151
S3method(predict,flatline)
5252
S3method(prep,check_enough_train_data)
5353
S3method(prep,epi_recipe)
54+
S3method(prep,step_adjust_latency)
5455
S3method(prep,step_epi_ahead)
5556
S3method(prep,step_epi_lag)
5657
S3method(prep,step_growth_rate)
@@ -180,6 +181,7 @@ export(remove_frosting)
180181
export(remove_model)
181182
export(slather)
182183
export(smooth_quantile_reg)
184+
export(step_adjust_latency)
183185
export(step_epi_ahead)
184186
export(step_epi_lag)
185187
export(step_epi_naomit)
@@ -207,10 +209,13 @@ importFrom(checkmate,assert_number)
207209
importFrom(checkmate,assert_numeric)
208210
importFrom(checkmate,assert_scalar)
209211
importFrom(cli,cli_abort)
212+
importFrom(dplyr,"%>%")
210213
importFrom(dplyr,across)
211214
importFrom(dplyr,all_of)
212215
importFrom(dplyr,group_by)
213216
importFrom(dplyr,n)
217+
importFrom(dplyr,pull)
218+
importFrom(dplyr,rowwise)
214219
importFrom(dplyr,summarise)
215220
importFrom(dplyr,ungroup)
216221
importFrom(epiprocess,growth_rate)
@@ -244,6 +249,7 @@ importFrom(stats,predict)
244249
importFrom(stats,qnorm)
245250
importFrom(stats,quantile)
246251
importFrom(stats,residuals)
252+
importFrom(stringr,str_match)
247253
importFrom(tibble,as_tibble)
248254
importFrom(tibble,is_tibble)
249255
importFrom(tibble,tibble)

R/step_adjust_latency.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
160160
#' date, rather than relative to the last day of data
161161
#' @param new_data assumes that this already has lag/ahead columns that we need
162162
#' to adjust
163-
#' @importFrom dplyr %>%
164-
#' @keywords internal
165163
#' @importFrom dplyr %>% pull
164+
#' @keywords internal
166165
bake.step_adjust_latency <- function(object, new_data, ...) {
167166
sign_shift <- get_sign(object)
168167
# get the columns used, even if it's all of them
@@ -178,7 +177,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
178177
# infer the correct columns to be working with from the previous
179178
# transformations
180179
shift_cols <- get_shifted_column_tibble(
181-
object, new_data, terms_used, as_of,
180+
object$prefix, new_data, terms_used, as_of,
182181
sign_shift
183182
)
184183

R/utils-latency.R

+24-12
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ extend_either <- function(new_data, shift_cols, keys) {
2020
key_cols = keys
2121
)
2222
}) %>%
23+
map(\(x) na.trim(x)) %>% # TODO need to talk about this
2324
reduce(
2425
dplyr::full_join,
2526
by = keys
2627
)
28+
2729
return(new_data %>%
28-
select(-shift_cols$original_name) %>%
30+
select(-shift_cols$original_name) %>% # drop the original versions
2931
dplyr::full_join(shifted, by = keys) %>%
3032
dplyr::group_by(dplyr::across(dplyr::all_of(keys[-1]))) %>%
3133
dplyr::arrange(time_value) %>%
@@ -34,36 +36,44 @@ extend_either <- function(new_data, shift_cols, keys) {
3436

3537
#' find the columns added with the lags or aheads, and the amounts they have
3638
#' been changed
37-
#' @param object the step and its parameters
39+
#' @param prefix the prefix indicating if we are adjusting lags or aheads
3840
#' @param new_data the data transformed so far
3941
#' @return a tibble with columns `column` (relevant shifted names), `shift` (the
4042
#' amount that one is shifted), `latency` (original columns difference between
4143
#' max_time_value and as_of (on a per-initial column basis)),
4244
#' `effective_shift` (shifts+latency), and `new_name` (adjusted names with the
4345
#' effective_shift)
4446
#' @keywords internal
47+
#' @importFrom stringr str_match
48+
#' @importFrom dplyr rowwise %>%
4549
get_shifted_column_tibble <- function(
46-
object, new_data, terms_used, as_of, sign_shift) {
47-
prefix <- object$prefix
50+
prefix, new_data, terms_used, as_of, sign_shift, call = caller_env()) {
4851
relevant_columns <- names(new_data)[grepl(prefix, names(new_data))]
4952
to_keep <- rep(FALSE, length(relevant_columns))
5053
for (col_name in terms_used) {
5154
to_keep <- to_keep | grepl(col_name, relevant_columns)
5255
}
5356
relevant_columns <- relevant_columns[to_keep]
57+
if (length(relevant_columns) == 0) {
58+
cli::cli_abort("There is no column(s) {terms_used}.",
59+
current_column_names = names(new_data),
60+
class = "epipredict_adjust_latency_nonexistent_column_used",
61+
call = call
62+
)
63+
}
5464
# TODO ask about a less jank way to do this
55-
shift_amounts <- as.integer(str_match(
65+
shift_amounts <- as.integer(stringr::str_match(
5666
relevant_columns,
5767
"_\\d+_"
5868
) %>%
5969
`[`(, 1) %>%
60-
str_match("\\d+") %>%
70+
stringr::str_match("\\d+") %>%
6171
`[`(, 1))
6272
shift_cols <- dplyr::tibble(
6373
original_name = relevant_columns,
6474
shifts = shift_amounts
6575
)
66-
shift_cols %>%
76+
shift_cols %<>%
6777
rowwise() %>%
6878
# add the latencies to shift_cols
6979
mutate(latency = get_latency(
@@ -72,8 +82,10 @@ get_shifted_column_tibble <- function(
7282
ungroup() %>%
7383
# add the updated names to shift_cols
7484
mutate(
75-
effective_shift = shifts + latency,
76-
new_name = adjust_name(prefix, shifts, original_name, latency)
85+
effective_shift = shifts + abs(latency)
86+
) %>%
87+
mutate(
88+
new_name = adjust_name(prefix, original_name, effective_shift)
7789
)
7890
return(shift_cols)
7991
}
@@ -136,9 +148,9 @@ get_asof <- function(object, new_data) {
136148
#' adjust the shifts by latency for the names in `column`; assumes e.g.
137149
#' `"lag_6_case_rate"` and returns something like `"lag_10_case_rate"`
138150
#' @keywords internal
139-
adjust_name <- function(prefix, shifts, column, latency) {
151+
adjust_name <- function(prefix, column, effective_shift) {
140152
pattern <- paste0(prefix, "\\d+", "_")
141-
adjusted_shifts <- paste0(prefix, shifts + latency, "_")
153+
adjusted_shifts <- paste0(prefix, effective_shift, "_")
142154
stringi::stri_replace_all_regex(
143155
column,
144156
pattern, adjusted_shifts
@@ -154,5 +166,5 @@ get_latency <- function(new_data, as_of, column, shift_amount, sign_shift) {
154166
drop_na(column) %>%
155167
pull(time_value) %>%
156168
max()
157-
return(as.integer(as_of - (shift_max_date - sign_shift * shift_amount)))
169+
return(as.integer(sign_shift * (as_of - shift_max_date) + shift_amount))
158170
}

man/create_layer.Rd

+5-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_shift.Rd

-28
This file was deleted.

man/step_epi_shift.Rd

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/step_growth_rate.Rd

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/step_lag_difference.Rd

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-utils_latency.R

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
time_values <- as.Date("2021-01-01") + 0:199
2+
as_of <- max(time_values) + 5
3+
max_time <- max(time_values)
4+
old_data <- tibble(
5+
geo_value = rep("place", 200),
6+
time_value = as.Date("2021-01-01") + 0:199,
7+
case_rate = sqrt(1:200) + atan(0.1 * 1:200) + sin(5 * 1:200) + 1,
8+
tmp_death_rate = atan(0.1 * 1:200) + cos(5 * 1:200) + 1
9+
) %>%
10+
as_epi_df(as_of = as_of)
11+
old_data %>% tail()
12+
keys <- c("time_value", "geo_value")
13+
old_data %<>% full_join(epi_shift_single(
14+
old_data, "tmp_death_rate", 1, "death_rate", keys
15+
), by = keys) %>%
16+
select(-tmp_death_rate)
17+
# old data is created so that death rate has a latency of 4, while case_rate has
18+
# a latency of 5
19+
modified_data <-
20+
old_data %>%
21+
dplyr::full_join(
22+
epi_shift_single(old_data, "case_rate", -4, "ahead_4_case_rate", keys),
23+
by = keys
24+
) %>%
25+
dplyr::full_join(
26+
epi_shift_single(old_data, "case_rate", 3, "lag_3_case_rate", keys),
27+
by = keys
28+
) %>%
29+
dplyr::full_join(
30+
epi_shift_single(old_data, "death_rate", 7, "lag_7_death_rate", keys),
31+
by = keys
32+
) %>%
33+
arrange(time_value)
34+
modified_data %>% tail()
35+
as_of - (modified_data %>% filter(!is.na(ahead_4_case_rate)) %>% pull(time_value) %>% max())
36+
all_shift_cols <- tibble::tribble(
37+
~original_name, ~shifts, ~latency, ~effective_shift, ~new_name,
38+
"lag_3_case_rate", 3, 5, 8, "lag_8_case_rate",
39+
"lag_7_death_rate", 7, 4, 11, "lag_11_death_rate",
40+
"ahead_4_case_rate", 4, -5, 9, "ahead_9_case_rate"
41+
)
42+
43+
test_that("get_latency works", {
44+
expect_equal(get_latency(modified_data, as_of, "lag_7_death_rate", 7, 1), 4)
45+
expect_equal(get_latency(modified_data, as_of, "lag_3_case_rate", 3, 1), 5)
46+
# get_latency doesn't check the shift_amount
47+
expect_equal(get_latency(modified_data, as_of, "lag_3_case_rate", 4, 1), 6)
48+
# ahead works correctly
49+
expect_equal(get_latency(modified_data, as_of, "ahead_4_case_rate", 4, -1), -5)
50+
# setting the wrong sign doubles the shift and gets the sign wrong
51+
expect_equal(get_latency(modified_data, as_of, "ahead_4_case_rate", 4, 1), 5 + 4 * 2)
52+
})
53+
54+
test_that("adjust_name works", {
55+
expect_equal(
56+
adjust_name("lag_", "lag_5_case_rate_13", 10),
57+
"lag_10_case_rate_13"
58+
)
59+
# it won't change a column with the wrong prefix
60+
expect_equal(
61+
adjust_name("lag_", "ahead_5_case_rate", 10),
62+
"ahead_5_case_rate"
63+
)
64+
# it works on vectors of names
65+
expect_equal(
66+
adjust_name("lag_", c("lag_5_floop_35", "lag_2342352_case"), c(10, 7)),
67+
c("lag_10_floop_35", "lag_7_case")
68+
)
69+
})
70+
71+
test_that("get_asof works", {
72+
object <- list(info = tribble(
73+
~variable, ~type, ~role, ~source,
74+
"time_value", "date", "time_value", "original",
75+
"geo_value", "nominal", "geo_value", "original",
76+
"case_rate", "numeric", "raw", "original",
77+
"death_rate", "numeric", "raw", "original",
78+
"not_real", "numeric", "predictor", "derived"
79+
))
80+
expect_equal(get_asof(object, modified_data), as_of)
81+
})
82+
83+
test_that("get_shifted_column_tibble works", {
84+
case_lag <- get_shifted_column_tibble(
85+
"lag_", modified_data,
86+
"case_rate", as_of, 1
87+
)
88+
expect_equal(case_lag, all_shift_cols[1, ])
89+
90+
death_lag <- get_shifted_column_tibble(
91+
"lag_", modified_data,
92+
"death_rate", as_of, 1
93+
)
94+
expect_equal(death_lag, all_shift_cols[2, ])
95+
96+
both_lag <- get_shifted_column_tibble(
97+
"lag_", modified_data,
98+
c("case_rate", "death_rate"), as_of, 1
99+
)
100+
expect_equal(both_lag, all_shift_cols[1:2, ])
101+
102+
case_ahead <- get_shifted_column_tibble(
103+
"ahead_", modified_data,
104+
"case_rate", as_of, -1
105+
)
106+
expect_equal(case_ahead, all_shift_cols[3, ])
107+
})
108+
test_that("get_shifted_column_tibble objects to non-columns", {
109+
expect_error(
110+
get_shifted_column_tibble(
111+
"lag_", modified_data, "not_present", as_of, 1
112+
),
113+
class = "epipredict_adjust_latency_nonexistent_column_used"
114+
)
115+
})
116+
test_that("extend_either works", {
117+
keys <- c("geo_value", "time_value")
118+
# extend_either doesn't differentiate between the directions, it just moves
119+
# things
120+
expected_post_shift <-
121+
old_data %>%
122+
dplyr::full_join(
123+
epi_shift_single(old_data, "case_rate", 8, "lag_8_case_rate", keys),
124+
by = keys
125+
) %>%
126+
dplyr::full_join(
127+
epi_shift_single(old_data, "death_rate", 11, "lag_11_death_rate", keys),
128+
by = keys
129+
) %>%
130+
dplyr::full_join(
131+
epi_shift_single(old_data, "case_rate", -9, "ahead_9_case_rate", keys),
132+
by = keys
133+
) %>%
134+
arrange(time_value)
135+
expect_equal(
136+
extend_either(modified_data, all_shift_cols, keys) %>% arrange(time_value),
137+
expected_post_shift
138+
)
139+
})

0 commit comments

Comments
 (0)