Skip to content

Commit 4a348c8

Browse files
committed
shifting must happen before joining to avoid duplicated rows
1 parent c1cf5ba commit 4a348c8

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

Diff for: R/step_climate.R

+8-9
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ step_climate <-
165165
arg_is_lgl_scalar(skip)
166166

167167
time_aggr <- switch(time_type,
168-
epiweek = lubridate::epiweek, week = lubridate::week,
168+
epiweek = lubridate::epiweek, week = lubridate::isoweek,
169169
month = lubridate::month, day = lubridate::yday)
170170

171171
recipes::add_step(
@@ -243,11 +243,15 @@ prep.step_climate <- function(x, training, info = NULL, ...) {
243243
modulus <- switch(x$time_type, epiweek = 53L, week = 53L, month = 12L, day = 365L)
244244

245245
fn <- switch(x$center_method,
246-
mean = function(x, w) weighted.mean(x, w, na.rm = TRUE),
246+
mean = function(x, w) stats::weighted.mean(x, w, na.rm = TRUE),
247247
median = function(x, w) median(x, na.rm = TRUE))
248248

249249
climate_table <- training %>%
250-
mutate(.idx = x$time_aggr(time_value), .weights = wts) %>%
250+
mutate(
251+
.idx = x$time_aggr(time_value), .weights = wts,
252+
.idx = (.idx - x$forecast_ahead) %% modulus,
253+
.idx = dplyr::case_when(.idx == 0 ~ modulus, TRUE ~ .idx)
254+
) %>%
251255
select(.idx, .weights, all_of(c(col_names, x$epi_keys))) %>%
252256
tidyr::pivot_longer(all_of(unname(col_names))) %>%
253257
dplyr::reframe(
@@ -279,14 +283,9 @@ prep.step_climate <- function(x, training, info = NULL, ...) {
279283
}
280284

281285

282-
283286
#' @export
284287
bake.step_climate <- function(object, new_data, ...) {
285-
climate_table <- object$climate_table %>%
286-
mutate(
287-
.idx = (.idx - object$forecast_ahead) %% object$modulus,
288-
.idx = dplyr::case_when(.idx == 0 ~ object$modulus, TRUE ~ .idx)
289-
)
288+
climate_table <- object$climate_table
290289
new_data %>%
291290
mutate(.idx = object$time_aggr(time_value)) %>%
292291
left_join(climate_table, by = c(".idx", object$epi_keys)) %>%

Diff for: tests/testthat/test-step_climate.R

+7-2
Original file line number | Diff line number | Diff line change
@@ -99,15 +99,20 @@ test_that("leading the climate predictor works as expected", {
9999
step_epi_naomit()
100100
p <- prep(r, x)
101101

102-
expected_res <- tibble(.idx = 1:53, climate_y = c(2, 2:25, 25, 25, 25:2, 2, 2))
102+
expected_res <- tibble(.idx = 1:53, climate_y = c(2, 2:25, 25, 25, 25:2, 2, 2)) %>%
103+
mutate(
104+
.idx = (.idx - 2L) %% 53,
105+
.idx = dplyr::case_when(.idx == 0 ~ 53L, TRUE ~ .idx)
106+
) %>%
107+
arrange(.idx)
103108
expect_equal(p$steps[[3]]$climate_table, expected_res)
104109

105110
b <- bake(p, new_data = NULL)
106111
expect_identical(max(b$time_value), as.Date("2021-12-17")) # last date with no NAs
107112
# expected climate predictor should be shifted forward by 2 weeks
108113
expected_climate_pred <- x %>%
109114
mutate(
110-
.idx = (lubridate::epiweek(time_value) + 2) %% 53,
115+
.idx = lubridate::epiweek(time_value) %% 53,
111116
.idx = dplyr::case_when(.idx == 0 ~ 53, TRUE ~ .idx)
112117
) %>%
113118
left_join(expected_res, by = join_by(.idx)) %>%

0 commit comments

Comments
 (0)