Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix predict empty column #74

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ The out-of-sample algorithm works as follows:
1. Impute univariately all relevant columns by randomly drawing values
from the original, unimputed data. This step will only impact "hard case" rows.
2. Replace univariate imputations by predictions of random forests. This is done
sequentially over variablse in descending order of number of missings
(to minimize the impact of univariate imputations). Optionally, this is followed
by predictive mean matching (PMM).
sequentially over variables in decreasing order of missings in "hard case"
rows (to minimize the impact of univariate imputations).
Optionally, this is followed by predictive mean matching (PMM).
3. Repeat Step 2 for "hard case" rows multiple times.

### Possibly breaking changes
Expand Down
39 changes: 24 additions & 15 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ summary.missRanger <- function(object, ...) {
#' 1. Impute univariately all relevant columns by randomly drawing values
#' from the original, unimputed data. This step will only impact "hard case" rows.
#' 2. Replace univariate imputations by predictions of random forests. This is done
#' sequentially over variables in descending order of number of missings
#' sequentially over variables in decreasing order of missings in "hard case" rows
#' (to minimize the impact of univariate imputations). Optionally, this is followed
#' by predictive mean matching (PMM).
#' 3. Repeat Step 2 for "hard case" rows multiple times.
Expand Down Expand Up @@ -104,11 +104,10 @@ predict.missRanger <- function(
# (a) Only those in newdata
to_impute <- intersect(object$to_impute, colnames(newdata))

# (b) Only those with missings, and in decreasing order
# to minimize impact of univariate imputations
# (b) Only those with missings
to_fill <- is.na(newdata[, to_impute, drop = FALSE])
m <- sort(colSums(to_fill), decreasing = TRUE)
to_impute <- names(m[m > 0])
missing_counts <- colSums(to_fill)
to_impute <- to_impute[missing_counts > 0L]
to_fill <- to_fill[, to_impute, drop = FALSE]

if (length(to_impute) == 0L) {
Expand Down Expand Up @@ -201,23 +200,32 @@ predict.missRanger <- function(
# This can fire only if the first iteration in missRanger() was the best, and only
# for maximal one variable.
forests_missing <- setdiff(to_impute, names(object$forests))
if (verbose >= 1L && length(forests_missing) > 0L) {
message(
"\nNo random forest for ", forests_missing,
". Univariate imputation done for this variable."
)
if (length(forests_missing) > 0L) {
if (verbose >= 1L) {
message(
"\nNo random forest for ", forests_missing,
". Univariate imputation done for this variable."
)
}
to_impute <- setdiff(to_impute, forests_missing)
}
to_impute <- setdiff(to_impute, forests_missing)

# Do we have rows of "hard case"? If no, a single iteration is sufficient.
easy <- rowSums(to_fill[, intersect(to_impute, impute_by), drop = FALSE]) <= 1L
if (all(easy)) {
iter <- 1L
} else {
# We impute first the column with most missings in *hard case* rows to minimize
# impact of univariate imputations (here, the case above with missing forest is
# ignored for simplicity)
hard_counts <- colSums(to_fill[, to_impute, drop = FALSE] & !easy)
ord <- order(hard_counts, decreasing = TRUE)
to_impute <- to_impute[ord]
hard_counts <- hard_counts[ord]
}

for (j in seq_len(iter)) {
for (v in to_impute) {
y <- newdata[[v]]
pred <- stats::predict(
object$forests[[v]],
newdata[to_fill[, v], ],
Expand All @@ -232,15 +240,16 @@ predict.missRanger <- function(
ytrain <- ytrain[!is.na(ytrain)] # To align with OOB predictions
}
pred <- pmm(xtrain = xtrain, xtest = pred, ytrain = ytrain, k = pmm.k)
} else if (is.logical(y)) {
} else if (is.logical(newdata[[v]])) {
pred <- as.logical(pred)
} else if (is.character(y)) {
} else if (is.character(newdata[[v]])) {
pred <- as.character(pred)
}
newdata[[v]][to_fill[, v]] <- pred
}
if (j == 1L) {
if (j == 1L && iter > 1L) {
to_fill <- to_fill & !easy
to_impute <- to_impute[hard_counts > 0L] # Need to fill only hard cases when j>1
}
}
return(newdata)
Expand Down
2 changes: 1 addition & 1 deletion man/predict.missRanger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 22 additions & 9 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ test_that("OOS does not fail if there is all missings, even of wrong type", {
expect_no_error(x1 <- predict(imp, new_rows, seed = 1L))
})

test_that("case works where easy case fills a complete column", {
imp <- missRanger(iris2, verbose = 0, seed = 1L, num.trees = 10, keep_forests = TRUE)

X <- data.frame(
Sepal.Length = c(5.1, NA),
Sepal.Width = c(3.5, 3.4),
Petal.Length = c(NA, 1.4),
Petal.Width = c(NA, 0.3),
Species = factor("setosa")
)
expect_no_error(predict(imp, X))
})

n <- 200L

X <- data.frame(
Expand All @@ -164,6 +177,15 @@ X <- data.frame(

X_NA <- generateNA(X[1:5], p = 0.2, seed = 1L)

test_that("non-syntactic column names work", {
X_NA2 <- X_NA
colnames(X_NA2)[1:2] <- c("1bad name", "2 also bad")

imp1 <- missRanger(X_NA2, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE)

expect_equal(colnames(predict(imp1, head(X_NA2))), colnames(X_NA2))
})

test_that("OOS does not fail if there is all missings, even of wrong type (MORE TYPES)", {
imp1 <- missRanger(
X_NA, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE
Expand All @@ -182,12 +204,3 @@ test_that("OOS does not fail if there is all missings, even of wrong type (MORE
expect_equal(lapply(x1, class), xp)
})

test_that("non-syntactic column names work", {
X_NA2 <- X_NA
colnames(X_NA2)[1:2] <- c("1bad name", "2 also bad")

imp1 <- missRanger(X_NA2, num.trees = 20L, verbose = 0L, seed = 1L, keep_forests = TRUE)

expect_equal(colnames(predict(imp1, head(X_NA2))), colnames(X_NA2))
})

Loading