From d1a6d664af6fff3e9b979826c56a1b2a0c9bd16c Mon Sep 17 00:00:00 2001 From: Tim Taylor Date: Fri, 14 Apr 2023 11:07:14 +0100 Subject: [PATCH] fix `fit.list()` --- DESCRIPTION | 2 +- NEWS.md | 5 +++++ R/fit-list.R | 19 ++++++++++++++++--- tests/testthat/test-lists.R | 5 +++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index ff07f8a..40d79c3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: trending Title: Model Temporal Trends -Version: 0.1.0 +Version: 0.1.0.9000 Authors@R: c( person( diff --git a/NEWS.md b/NEWS.md index 82b18dd..4f31b5b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# trending (development version) + +* Fix a bug in `fit.list()` that was causing it to sometimes fail when used + within other functions. + # trending 0.1.0 ## Breaking changes diff --git a/R/fit-list.R b/R/fit-list.R index 1829f35..052e8f3 100644 --- a/R/fit-list.R +++ b/R/fit-list.R @@ -40,9 +40,22 @@ fit.list <- function(x, data, ...) { if (!all(vapply(x, inherits, logical(1), "trending_model"))) { stop("list entries should be `trending_model` objects", call. = FALSE) } - qfun <- bquote(lapply(x, fit, data = .(substitute(data)), as_tibble = FALSE)) - res <- eval(qfun) - nms <- names(x) + + # Fix for https://github.com/reconverse/trending/issues/22 + # TODO - improve this as very, very hacky + original___x <- x + tmp <- as.character(substitute(data)) + if (length(tmp) == 1L) { + assign(tmp[1L], data) + } + res <- vector("list", length(original___x)) + for (i in seq_along(res)) { + x_model <- original___x[[i]] + qfun <- bquote(fit(x_model, data = .(substitute(data)), as_tibble = FALSE)) + res[[i]] <- eval(qfun) + } + nms <- names(original___x) + if (!is.null(nms)) names(res) <- nms res <- lapply(seq_along(res[[1]]), function(i) lapply(res, "[[", i)) res <- tibble(result = res[[1]], warnings = res[[2]], errors = res[[3]]) diff --git a/tests/testthat/test-lists.R b/tests/testthat/test-lists.R index edf916f..ab7b0d5 100644 --- a/tests/testthat/test-lists.R +++ b/tests/testthat/test-lists.R @@ -17,6 +17,11 @@ test_that("lm_model", { list_pred_nmd <- predict(list_fit_nmd, mtcars) list_pred_from_model <- predict(list(l=l, nb=nb), mtcars) + # test for #22 fix + f <- function(y) fit(list(l, nb), y) + expect_identical(f(mtcars)$errors, list(NULL, NULL)) + + expect_equal(get_warnings(list_fit), list(NULL, NULL)) expect_equal(get_warnings(list_pred_from_model), list(l=NULL, nb=NULL)) expect_equal(get_errors(list_fit), list(NULL, NULL))