
Commit

Merge pull request #75 from mayer79/better-order-in-predict
Improve imputation order in predict()
mayer79 authored Jul 31, 2024
2 parents 23156f6 + 7def1e2 commit 33823d6
Showing 3 changed files with 32 additions and 32 deletions.
9 changes: 4 additions & 5 deletions NEWS.md
@@ -7,17 +7,16 @@ Out-of-sample application is now possible! Thanks to [@jeandigitale](https://git
This means you can run `imp <- missRanger(..., keep_forests = TRUE)` and then apply its models to new data via `predict(imp, newdata)`. The "missRanger" object can be saved/loaded as a binary file, e.g., via `saveRDS()`/`readRDS()` for later use.
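
A minimal sketch of this workflow (the data set, the injected missing values, and `num.trees` are illustrative choices for this example, not part of the package news):

```r
library(missRanger)

# Training data with some artificial missing values
train <- iris
set.seed(1)
train[sample(nrow(train), 20), "Sepal.Width"] <- NA

# Keep the fitted forests so the imputation model can be reused later
imp <- missRanger(train, num.trees = 100, keep_forests = TRUE, seed = 1)

# Out-of-sample imputation of new rows
new_rows <- iris[1:3, ]
new_rows$Sepal.Width[2] <- NA
filled <- predict(imp, newdata = new_rows)

# Store the "missRanger" object as a binary file for later use
saveRDS(imp, "imputer.rds")
imp2 <- readRDS("imputer.rds")
```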

Note that out-of-sample imputation works best for rows in `newdata` with only one
-missing value (actually counting only missings in variables used as covariates in random forests). We call this the "easy case". In the "hard case",
+missing value (counting only missings in variables used as covariates in random forests). We call this the "easy case". In the "hard case",
even multiple iterations (set by `iter`) can lead to unsatisfactory results.

The out-of-sample algorithm works as follows:

1. Impute univariately all relevant columns by randomly drawing values
-from the original, unimputed data. This step will only impact "hard case" rows.
+from the original unimputed data. This step will only impact "hard case" rows.
2. Replace univariate imputations by predictions of random forests. This is done
-sequentially over variables in decreasing order of missings in "hard case"
-rows (to minimize the impact of univariate imputations).
-Optionally, this is followed by predictive mean matching (PMM).
+sequentially over variables, where the variables are sorted to minimize the impact
+of univariate imputations. Optionally, this is followed by predictive mean matching (PMM).
3. Repeat Step 2 for "hard case" rows multiple times.
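
The sorting in Step 2 can be illustrated with a small standalone sketch (the objects below are invented for this example; the package computes the order internally from the missingness pattern of `newdata`):

```r
# Missingness pattern of the rows to impute, restricted to variables that are
# both imputed and used as covariates in the forests (toy data).
miss <- cbind(
  x1 = c(TRUE,  FALSE, TRUE,  FALSE),
  x2 = c(FALSE, TRUE,  FALSE, TRUE),
  x3 = c(TRUE,  TRUE,  TRUE,  FALSE)
)

hard_rows <- rowSums(miss) > 1L                         # rows with 2+ missing covariates
hard_counts <- colSums(miss[hard_rows, , drop = FALSE]) # missings per variable in hard rows
colnames(miss)[order(hard_counts, decreasing = TRUE)]
# "x3" "x1" "x2": the variable missing most often in hard-case rows is imputed first,
# so as few predictions as possible rely on the univariate starting values.
```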

### Possibly breaking changes
46 changes: 24 additions & 22 deletions R/methods.R
@@ -52,18 +52,17 @@ summary.missRanger <- function(object, ...) {
#' This can be enforced by `predict(..., iter = 0)` or via `missRanger(. ~ 1, ...)`.
#'
#' Note that out-of-sample imputation works best for rows in `newdata` with only one
-#' missing value (actually counting only missings in variables used as covariates
+#' missing value (counting only missings in variables used as covariates
#' in random forests). We call this the "easy case". In the "hard case",
#' even multiple iterations (set by `iter`) can lead to unsatisfactory results.
#'
#' @details
#' The out-of-sample algorithm works as follows:
#' 1. Impute univariately all relevant columns by randomly drawing values
-#' from the original, unimputed data. This step will only impact "hard case" rows.
+#' from the original unimputed data. This step will only impact "hard case" rows.
#' 2. Replace univariate imputations by predictions of random forests. This is done
-#' sequentially over variables in decreasing order of missings in "hard case" rows
-#' (to minimize the impact of univariate imputations). Optionally, this is followed
-#' by predictive mean matching (PMM).
+#' sequentially over variables, where the variables are sorted to minimize the impact
+#' of univariate imputations. Optionally, this is followed by predictive mean matching (PMM).
#' 3. Repeat Step 2 for "hard case" rows multiple times.
#'
#' @param object 'missRanger' object.
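
Since PMM is only mentioned in passing here, a toy sketch of the idea for numeric variables may help; `pmm_sketch()` is invented for illustration and is not the package's own implementation:

```r
# Toy predictive mean matching: a raw prediction for a missing value is replaced
# by the observed value of one of the k training cases whose own predictions are
# closest to it.
pmm_sketch <- function(pred_train, obs_train, pred_missing, k = 3L) {
  vapply(pred_missing, function(p) {
    donors <- order(abs(pred_train - p))[seq_len(k)]  # k nearest donors by prediction
    sample(obs_train[donors], 1L)                     # copy one donor's observed value
  }, FUN.VALUE = numeric(1))
}

set.seed(1)
pmm_sketch(
  pred_train = c(1.1, 2.0, 2.9, 4.2), obs_train = c(1, 2, 3, 4),
  pred_missing = c(2.4, 3.8)
)
```
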
@@ -140,7 +139,7 @@ predict.missRanger <- function(
v_orig <- data_raw[[v]]

if (all(is.na(v_new))) {
-next # NA can be of wrong class!
+next # NA of wrong class is fine!
}
# class() distinguishes numeric, integer, logical, factor, character, Date, ...
# - variables in to_impute are numeric, integer, logical, factor, or character
@@ -169,8 +168,7 @@
}

# UNIVARIATE IMPUTATION
-# has no effect for "easy case" rows, but is not very expensive


for (v in to_impute) {
bad <- to_fill[, v]
v_orig <- data_raw[[v]]
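
The body of this loop is collapsed in the diff view. As a rough, hypothetical stand-in (not the hidden code itself), Step 1 of the algorithm amounts to drawing from the observed values of the original training column; whether sampling happens with replacement is an assumption made here:

```r
# Hypothetical, simplified stand-in for the collapsed loop body: fill the missing
# entries of column v by sampling observed values of the unimputed training data.
donor_values <- v_orig[!is.na(v_orig)]
newdata[[v]][bad] <- sample(donor_values, size = sum(bad), replace = TRUE)
```
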
@@ -183,7 +181,7 @@
}
}

-if (length(impute_by) == 0L || iter < 1L) {
+if (length(impute_by) == 0L || iter == 0L) {
if (verbose >= 1L) {
message("\nOnly univariate imputations done")
}
@@ -198,7 +196,7 @@

# Do we have a random forest for all variables with missings?
# This can fire only if the first iteration in missRanger() was the best, and only
-# for maximal one variable.
+# for at most one variable. It is a rare case.
forests_missing <- setdiff(to_impute, names(object$forests))
if (length(forests_missing) > 0L) {
if (verbose >= 1L) {
@@ -210,20 +208,23 @@
to_impute <- setdiff(to_impute, forests_missing)
}

-# Do we have rows of "hard case"? If no, a single iteration is sufficient.
-easy <- rowSums(to_fill[, intersect(to_impute, impute_by), drop = FALSE]) <= 1L
-if (all(easy)) {
+# Do we have rows of "hard case"? If no, a single iteration is sufficient
+hard_cols <- intersect(to_impute, impute_by)
+hard_rows <- rowSums(to_fill[, hard_cols, drop = FALSE]) > 1L
+if (!any(hard_rows)) {
iter <- 1L
-} else {
-# We impute first the column with most missings in *hard case* rows to minimize
-# impact of univariate imputations (here, the case above with missing forest is
-# ignored for simplicity)
-hard_counts <- colSums(to_fill[, to_impute, drop = FALSE] & !easy)
-ord <- order(hard_counts, decreasing = TRUE)
-to_impute <- to_impute[ord]
-hard_counts <- hard_counts[ord]
}

+# We first impute hard columns, then the rest.
+# Sorting hard columns is done in decreasing order of missings, counting only
+# rows of hard case. Sorting of the rest is irrelevant.
+# We ignore the special case where one forest is missing
+hard_counts <- colSums(to_fill[hard_rows, hard_cols, drop = FALSE])
+to_impute <- c(
+  hard_cols[order(hard_counts, decreasing = TRUE)],
+  setdiff(to_impute, hard_cols) # rest
+)

for (j in seq_len(iter)) {
for (v in to_impute) {
pred <- stats::predict(
@@ -248,7 +249,8 @@
newdata[[v]][to_fill[, v]] <- pred
}
if (j == 1L && iter > 1L) {
-to_fill <- to_fill & !easy
+to_fill <- to_fill & hard_rows
+hard_counts <- colSums(to_fill[, to_impute, drop = FALSE])
to_impute <- to_impute[hard_counts > 0L] # Need to fill only hard cases when j>1
}
}
9 changes: 4 additions & 5 deletions man/predict.missRanger.Rd

Some generated files are not rendered by default.
