Skip to content

Commit

Permalink
Merge pull request #74 from mayer79/fix-predict-empty-column
Browse files Browse the repository at this point in the history
Fix predict empty column
  • Loading branch information
mayer79 authored Jul 30, 2024
2 parents 1446de6 + ca4be68 commit 23156f6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 28 deletions.
6 changes: 3 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ The out-of-sample algorithm works as follows:
1. Impute univariately all relevant columns by randomly drawing values
from the original, unimputed data. This step will only impact "hard case" rows.
2. Replace univariate imputations by predictions of random forests. This is done
sequentially over variablse in descending order of number of missings
(to minimize the impact of univariate imputations). Optionally, this is followed
by predictive mean matching (PMM).
sequentially over variables in decreasing order of missings in "hard case"
rows (to minimize the impact of univariate imputations).
Optionally, this is followed by predictive mean matching (PMM).
3. Repeat Step 2 for "hard case" rows multiple times.

### Possibly breaking changes
Expand Down
39 changes: 24 additions & 15 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ summary.missRanger <- function(object, ...) {
#' 1. Impute univariately all relevant columns by randomly drawing values
#' from the original, unimputed data. This step will only impact "hard case" rows.
#' 2. Replace univariate imputations by predictions of random forests. This is done
#' sequentially over variables in descending order of number of missings
#' sequentially over variables in decreasing order of missings in "hard case" rows
#' (to minimize the impact of univariate imputations). Optionally, this is followed
#' by predictive mean matching (PMM).
#' 3. Repeat Step 2 for "hard case" rows multiple times.
Expand Down Expand Up @@ -104,11 +104,10 @@ predict.missRanger <- function(
# (a) Only those in newdata
to_impute <- intersect(object$to_impute, colnames(newdata))

# (b) Only those with missings, and in decreasing order
# to minimize impact of univariate imputations
# (b) Only those with missings
to_fill <- is.na(newdata[, to_impute, drop = FALSE])
m <- sort(colSums(to_fill), decreasing = TRUE)
to_impute <- names(m[m > 0])
missing_counts <- colSums(to_fill)
to_impute <- to_impute[missing_counts > 0L]
to_fill <- to_fill[, to_impute, drop = FALSE]

if (length(to_impute) == 0L) {
Expand Down Expand Up @@ -201,23 +200,32 @@ predict.missRanger <- function(
# This can fire only if the first iteration in missRanger() was the best, and
# then for at most one variable.
forests_missing <- setdiff(to_impute, names(object$forests))
if (verbose >= 1L && length(forests_missing) > 0L) {
message(
"\nNo random forest for ", forests_missing,
". Univariate imputation done for this variable."
)
if (length(forests_missing) > 0L) {
if (verbose >= 1L) {
message(
"\nNo random forest for ", forests_missing,
". Univariate imputation done for this variable."
)
}
to_impute <- setdiff(to_impute, forests_missing)
}
to_impute <- setdiff(to_impute, forests_missing)

# Do we have rows of "hard case"? If no, a single iteration is sufficient.
easy <- rowSums(to_fill[, intersect(to_impute, impute_by), drop = FALSE]) <= 1L
if (all(easy)) {
iter <- 1L
} else {
# We impute first the column with most missings in *hard case* rows to minimize
# impact of univariate imputations (here, the case above with missing forest is
# ignored for simplicity)
hard_counts <- colSums(to_fill[, to_impute, drop = FALSE] & !easy)
ord <- order(hard_counts, decreasing = TRUE)
to_impute <- to_impute[ord]
hard_counts <- hard_counts[ord]
}

for (j in seq_len(iter)) {
for (v in to_impute) {
y <- newdata[[v]]
pred <- stats::predict(
object$forests[[v]],
newdata[to_fill[, v], ],
Expand All @@ -232,15 +240,16 @@ predict.missRanger <- function(
ytrain <- ytrain[!is.na(ytrain)] # To align with OOB predictions
}
pred <- pmm(xtrain = xtrain, xtest = pred, ytrain = ytrain, k = pmm.k)
} else if (is.logical(y)) {
} else if (is.logical(newdata[[v]])) {
pred <- as.logical(pred)
} else if (is.character(y)) {
} else if (is.character(newdata[[v]])) {
pred <- as.character(pred)
}
newdata[[v]][to_fill[, v]] <- pred
}
if (j == 1L) {
if (j == 1L && iter > 1L) {
to_fill <- to_fill & !easy
to_impute <- to_impute[hard_counts > 0L] # Need to fill only hard cases when j>1
}
}
return(newdata)
Expand Down
2 changes: 1 addition & 1 deletion man/predict.missRanger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 22 additions & 9 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ test_that("OOS does not fail if there is all missings, even of wrong type", {
expect_no_error(x1 <- predict(imp, new_rows, seed = 1L))
})

# Regression test for out-of-sample prediction when a column to impute has
# missing values only in "easy" rows (rows with at most one missing value).
# Such a column is completely filled after the first pass, so the later passes
# over the "hard case" rows must not try to predict for it again.
test_that("case works where easy case fills a complete column", {
  imp <- missRanger(
    iris2, verbose = 0L, seed = 1L, num.trees = 10L, keep_forests = TRUE
  )

  # Row 1 lacks two values (hard case); row 2 lacks only Sepal.Length (easy).
  X <- data.frame(
    Sepal.Length = c(5.1, NA),
    Sepal.Width = c(3.5, 3.4),
    Petal.Length = c(NA, 1.4),
    Petal.Width = c(NA, 0.3),
    Species = factor("setosa")
  )

  # Fixed seed, as in the neighboring predict() tests, so that the random
  # univariate starting values are reproducible from run to run.
  expect_no_error(predict(imp, X, seed = 1L))
})

n <- 200L

X <- data.frame(
Expand All @@ -164,6 +177,15 @@ X <- data.frame(

# Shared fixture for the tests below: the first five columns of X passed
# through generateNA() with p = 0.2 and a fixed seed.
# NOTE(review): presumably p is the share of values set to NA per column --
# confirm against the generateNA() documentation.
X_NA <- generateNA(X[1:5], p = 0.2, seed = 1L)

# Column names that are not valid R identifiers must pass through fitting and
# out-of-sample prediction unchanged.
test_that("non-syntactic column names work", {
  odd_names <- c("1bad name", "2 also bad")
  dat <- X_NA
  colnames(dat)[1:2] <- odd_names

  fit <- missRanger(
    dat,
    num.trees = 20L,
    verbose = 0L,
    seed = 1L,
    keep_forests = TRUE
  )

  # Predict on the first rows only; the column names must be preserved.
  filled <- predict(fit, head(dat))
  expect_equal(colnames(filled), colnames(dat))
})

test_that("OOS does not fail if there is all missings, even of wrong type (MORE TYPES)", {
imp1 <- missRanger(
X_NA, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE
Expand All @@ -182,12 +204,3 @@ test_that("OOS does not fail if there is all missings, even of wrong type (MORE
expect_equal(lapply(x1, class), xp)
})

# Column names that are not valid R identifiers must pass through fitting and
# out-of-sample prediction unchanged.
test_that("non-syntactic column names work", {
  odd_names <- c("1bad name", "2 also bad")
  dat <- X_NA
  colnames(dat)[1:2] <- odd_names

  fit <- missRanger(
    dat,
    num.trees = 20L,
    verbose = 0L,
    seed = 1L,
    keep_forests = TRUE
  )

  # Predict on the first rows only; the column names must be preserved.
  filled <- predict(fit, head(dat))
  expect_equal(colnames(filled), colnames(dat))
})

0 comments on commit 23156f6

Please sign in to comment.