diff --git a/NEWS.md b/NEWS.md index 34d0dca..8cde291 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,9 +15,9 @@ The out-of-sample algorithm works as follows: 1. Impute univariately all relevant columns by randomly drawing values from the original, unimputed data. This step will only impact "hard case" rows. 2. Replace univariate imputations by predictions of random forests. This is done - sequentially over variablse in descending order of number of missings - (to minimize the impact of univariate imputations). Optionally, this is followed - by predictive mean matching (PMM). + sequentially over variables in decreasing order of missings in "hard case" + rows (to minimize the impact of univariate imputations). + Optionally, this is followed by predictive mean matching (PMM). 3. Repeat Step 2 for "hard case" rows multiple times. ### Possibly breaking changes diff --git a/R/methods.R b/R/methods.R index 4976b3f..e38f3ff 100644 --- a/R/methods.R +++ b/R/methods.R @@ -61,7 +61,7 @@ summary.missRanger <- function(object, ...) { #' 1. Impute univariately all relevant columns by randomly drawing values #' from the original, unimputed data. This step will only impact "hard case" rows. #' 2. Replace univariate imputations by predictions of random forests. This is done -#' sequentially over variables in descending order of number of missings +#' sequentially over variables in decreasing order of missings in "hard case" rows #' (to minimize the impact of univariate imputations). Optionally, this is followed #' by predictive mean matching (PMM). #' 3. Repeat Step 2 for "hard case" rows multiple times. @@ -104,11 +104,10 @@ predict.missRanger <- function( # (a) Only those in newdata to_impute <- intersect(object$to_impute, colnames(newdata)) - # (b) Only those with missings, and in decreasing order - # to minimize impact of univariate imputations + # (b) Only those with missings to_fill <- is.na(newdata[, to_impute, drop = FALSE]) - m <- sort(colSums(to_fill), decreasing = TRUE) - to_impute <- names(m[m > 0]) + missing_counts <- colSums(to_fill) + to_impute <- to_impute[missing_counts > 0L] to_fill <- to_fill[, to_impute, drop = FALSE] if (length(to_impute) == 0L) { @@ -201,23 +200,32 @@ predict.missRanger <- function( # This can fire only if the first iteration in missRanger() was the best, and only # for maximal one variable. forests_missing <- setdiff(to_impute, names(object$forests)) - if (verbose >= 1L && length(forests_missing) > 0L) { - message( - "\nNo random forest for ", forests_missing, - ". Univariate imputation done for this variable." - ) + if (length(forests_missing) > 0L) { + if (verbose >= 1L) { + message( + "\nNo random forest for ", forests_missing, + ". Univariate imputation done for this variable." + ) + } + to_impute <- setdiff(to_impute, forests_missing) } - to_impute <- setdiff(to_impute, forests_missing) # Do we have rows of "hard case"? If no, a single iteration is sufficient. easy <- rowSums(to_fill[, intersect(to_impute, impute_by), drop = FALSE]) <= 1L if (all(easy)) { iter <- 1L + } else { + # We impute first the column with most missings in *hard case* rows to minimize + # impact of univariate imputations (here, the case above with missing forest is + # ignored for simplicity) + hard_counts <- colSums(to_fill[, to_impute, drop = FALSE] & !easy) + ord <- order(hard_counts, decreasing = TRUE) + to_impute <- to_impute[ord] + hard_counts <- hard_counts[ord] } for (j in seq_len(iter)) { for (v in to_impute) { - y <- newdata[[v]] pred <- stats::predict( object$forests[[v]], newdata[to_fill[, v], ], @@ -232,15 +240,16 @@ predict.missRanger <- function( ytrain <- ytrain[!is.na(ytrain)] # To align with OOB predictions } pred <- pmm(xtrain = xtrain, xtest = pred, ytrain = ytrain, k = pmm.k) - } else if (is.logical(y)) { + } else if (is.logical(newdata[[v]])) { pred <- as.logical(pred) - } else if (is.character(y)) { + } else if (is.character(newdata[[v]])) { pred <- as.character(pred) } newdata[[v]][to_fill[, v]] <- pred } - if (j == 1L) { + if (j == 1L && iter > 1L) { to_fill <- to_fill & !easy + to_impute <- to_impute[hard_counts > 0L] # Need to fill only hard cases when j>1 } } return(newdata) diff --git a/man/predict.missRanger.Rd b/man/predict.missRanger.Rd index 4cda811..e61c877 100644 --- a/man/predict.missRanger.Rd +++ b/man/predict.missRanger.Rd @@ -52,7 +52,7 @@ The out-of-sample algorithm works as follows: \item Impute univariately all relevant columns by randomly drawing values from the original, unimputed data. This step will only impact "hard case" rows. \item Replace univariate imputations by predictions of random forests. This is done -sequentially over variables in descending order of number of missings +sequentially over variables in decreasing order of missings in "hard case" rows (to minimize the impact of univariate imputations). Optionally, this is followed by predictive mean matching (PMM). \item Repeat Step 2 for "hard case" rows multiple times. diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index f5ab86d..f7edca8 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -152,6 +152,19 @@ test_that("OOS does not fail if there is all missings, even of wrong type", { expect_no_error(x1 <- predict(imp, new_rows, seed = 1L)) }) +test_that("case works where easy case fills a complete column", { + imp <- missRanger(iris2, verbose = 0, seed = 1L, num.trees = 10, keep_forests = TRUE) + + X <- data.frame( + Sepal.Length = c(5.1, NA), + Sepal.Width = c(3.5, 3.4), + Petal.Length = c(NA, 1.4), + Petal.Width = c(NA, 0.3), + Species = factor("setosa") + ) + expect_no_error(predict(imp, X)) +}) + n <- 200L X <- data.frame( @@ -164,6 +177,15 @@ X <- data.frame( X_NA <- generateNA(X[1:5], p = 0.2, seed = 1L) +test_that("non-syntactic column names work", { + X_NA2 <- X_NA + colnames(X_NA2)[1:2] <- c("1bad name", "2 also bad") + + imp1 <- missRanger(X_NA2, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE) + + expect_equal(colnames(predict(imp1, head(X_NA2))), colnames(X_NA2)) +}) + test_that("OOS does not fail if there is all missings, even of wrong type (MORE TYPES)", { imp1 <- missRanger( X_NA, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE @@ -182,12 +204,3 @@ test_that("OOS does not fail if there is all missings, even of wrong type (MORE expect_equal(lapply(x1, class), xp) }) -test_that("non-syntactic column names work", { - X_NA2 <- X_NA - colnames(X_NA2)[1:2] <- c("1bad name", "2 also bad") - - imp1 <- missRanger(X_NA2, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE) - - expect_equal(colnames(predict(imp1, head(X_NA2))), colnames(X_NA2)) -}) -