From f271745b5f122f01d4766cdb541f2994775b0802 Mon Sep 17 00:00:00 2001 From: Veronika Maurerova Date: Fri, 11 Aug 2023 16:46:46 +0200 Subject: [PATCH] Fix make_metrics bug --- h2o-py/h2o/h2o.py | 1 + h2o-r/h2o-package/R/models.R | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/h2o-py/h2o/h2o.py b/h2o-py/h2o/h2o.py index 521cba050c24..94d4f664b322 100644 --- a/h2o-py/h2o/h2o.py +++ b/h2o-py/h2o/h2o.py @@ -2043,6 +2043,7 @@ def make_metrics(predicted, actual, domain=None, distribution=None, weights=None if weights is not None: params["weights_frame"] = weights.frame_id if treatment is not None: + assert treatment.ncol == 1, "`treatment` frame should have exactly 1 column" params["treatment_frame"] = treatment.frame_id allowed_auuc_types = ["qini", "lift", "gain", "AUTO"] assert auuc_type in allowed_auuc_types, "auuc_type should be "+(" ".join([str(type) for type in allowed_auuc_types])) diff --git a/h2o-r/h2o-package/R/models.R b/h2o-r/h2o-package/R/models.R index f6e4d509a4d8..5f7ae8138d97 100755 --- a/h2o-r/h2o-package/R/models.R +++ b/h2o-r/h2o-package/R/models.R @@ -1138,7 +1138,7 @@ h2o.make_metrics <- function(predicted, actuals, domain=NULL, distribution=NULL, predicted <- .validate.H2OFrame(predicted, required=TRUE) actuals <- .validate.H2OFrame(actuals, required=TRUE) weights <- .validate.H2OFrame(weights, required=FALSE) - treatment <- .validate.H2OFrame(treatment, required=TRUE) + treatment <- .validate.H2OFrame(treatment, required=FALSE) if (!is.character(auc_type)) stop("auc_type argument must be of type character") if (!(auc_type %in% c("MACRO_OVO", "MACRO_OVR", "WEIGHTED_OVO", "WEIGHTED_OVR", "NONE", "AUTO"))) { stop("auc_type argument must be MACRO_OVO, MACRO_OVR, WEIGHTED_OVO, WEIGHTED_OVR, NONE, AUTO") @@ -1157,8 +1157,8 @@ h2o.make_metrics <- function(predicted, actuals, domain=NULL, distribution=NULL, if (auuc_nbins < -1 || auuc_nbins == 0) { stop("auuc_nbins must be -1 or higher than 0.") } - params$auuc_type = auuc_type - params$auuc_nbins = auuc_nbins + params$auuc_type <- auuc_type + params$auuc_nbins <- auuc_nbins } params$domain <- domain params$distribution <- distribution