Skip to content

Commit

Permalink
Merge pull request #73 from mayer79/predict-num-threads
Browse files Browse the repository at this point in the history
Predict num threads
  • Loading branch information
mayer79 authored Jul 28, 2024
2 parents d935791 + fded956 commit 1446de6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The out-of-sample algorithm works as follows:
- For variables that can't be used, more information is printed.
- If `keep_forests = TRUE`, the argument `data_only` is set to `FALSE` by default.
- "missRanger" object now stores `pmm.k`.
- `verbose` argument is passed to `ranger()` as well.

# missRanger 2.5.0

Expand Down
21 changes: 18 additions & 3 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,25 @@ summary.missRanger <- function(object, ...) {
#' @param pmm.k Number of candidate predictions of the original dataset
#' for predictive mean matching (PMM). By default the same value as during fitting.
#' @param iter Number of iterations for "hard case" rows. 0 for univariate imputation.
#' @param num.threads Number of threads used by ranger's predict function.
#' The default `NULL` uses all threads.
#' @param seed Integer seed used for initial univariate imputation and PMM.
#' @param verbose Should info be printed? (1 = yes/default, 0 for no).
#' @param ... Currently not used.
#' @param ... Passed to the predict function of ranger.
#' @export
#' @examples
#' iris2 <- generateNA(iris, seed = 20, p = c(Sepal.Length = 0.2, Species = 0.1))
#' imp <- missRanger(iris2, pmm.k = 5, num.trees = 100, keep_forests = TRUE, seed = 2)
#' predict(imp, head(iris2), seed = 3)
predict.missRanger <- function(
object, newdata, pmm.k = object$pmm.k, iter = 4L, seed = NULL, verbose = 1L, ...
object,
newdata,
pmm.k = object$pmm.k,
iter = 4L,
num.threads = NULL,
seed = NULL,
verbose = 1L,
...
) {
stopifnot(
"'newdata' should be a data.frame!" = is.data.frame(newdata),
Expand Down Expand Up @@ -209,7 +218,13 @@ predict.missRanger <- function(
for (j in seq_len(iter)) {
for (v in to_impute) {
y <- newdata[[v]]
pred <- stats::predict(object$forests[[v]], newdata[to_fill[, v], ])$predictions
pred <- stats::predict(
object$forests[[v]],
newdata[to_fill[, v], ],
num.threads = num.threads,
verbose = verbose >= 1L,
...
)$predictions
if (pmm.k >= 1) {
xtrain <- object$forests[[v]]$predictions
ytrain <- data_raw[[v]]
Expand Down
1 change: 1 addition & 0 deletions R/missRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ missRanger <- function(
save.memory = save.memory,
x = data[!v.na, completed, drop = FALSE],
y = y,
verbose = verbose >= 1,
...
)

Expand Down
6 changes: 5 additions & 1 deletion man/predict.missRanger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1446de6

Please sign in to comment.