From e158ee095480d97ae1513e9a65ffb8ec988de22f Mon Sep 17 00:00:00 2001 From: Veronika Maurerova Date: Wed, 13 Sep 2023 14:43:22 +0200 Subject: [PATCH] fix make metrics bug --- h2o-core/src/main/java/hex/AUUC.java | 19 +++++++++--------- .../java/hex/ModelMetricsBinomialUplift.java | 20 +++++++++---------- h2o-core/src/test/java/hex/AUUCTest.java | 2 +- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/h2o-core/src/main/java/hex/AUUC.java b/h2o-core/src/main/java/hex/AUUC.java index 1e7d4b2d1804..2129215cb9a3 100644 --- a/h2o-core/src/main/java/hex/AUUC.java +++ b/h2o-core/src/main/java/hex/AUUC.java @@ -71,20 +71,21 @@ public double[] upliftRandomByType(AUUCType type){ int idx = getIndexByAUUCType(type); return idx < 0 ? null : _upliftRandom[idx]; } - - public AUUC(int nBins, Vec probs, Vec y, Vec uplift, AUUCType auucType) { - this(new AUUCImpl(calculateQuantileThresholds(nBins, probs)).doAll(probs, y, uplift)._bldr, auucType); - } - public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) { - this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType); + public AUUC(Vec probs, Vec y, Vec uplift, AUUCType auucType, int nbins) { + this(new AUUCImpl(calculateQuantileThresholds(nbins, probs)).doAll(probs, y, uplift)._bldr, auucType); } public AUUC(AUUCBuilder bldr, AUUCType auucType) { this(bldr, true, auucType); } + + + public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) { + this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType); + } - private AUUC(AUUCBuilder bldr, boolean trueProbabilities, AUUCType auucType) { + public AUUC(AUUCBuilder bldr, boolean trueProbabilities, AUUCType auucType) { _auucType = auucType; _auucTypeIndx = getIndexByAUUCType(_auucType); _nBins = bldr._nBins; @@ -316,11 +317,11 @@ public double auucRandom(int idx){ public double auucNormalized(){ return auucNormalized(_auucTypeIndx); } - private static class AUUCImpl extends MRTask { + public static class AUUCImpl extends MRTask { final double[] _thresholds; AUUCBuilder _bldr; - AUUCImpl(double[] thresholds) { + public AUUCImpl(double[] thresholds) { _thresholds = thresholds; } diff --git a/h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java b/h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java index 13e38277d5eb..5b6c86cf1999 100644 --- a/h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java +++ b/h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java @@ -4,7 +4,6 @@ import water.Scope; import water.exceptions.H2OIllegalArgumentException; import water.fvec.*; -import water.util.ArrayUtils; import water.util.Log; import java.util.Arrays; @@ -163,8 +162,7 @@ static public ModelMetricsBinomialUplift make(Vec predictedProbs, Vec actualLabe mb = new UpliftBinomialMetrics(labels.domain(), customAuucThresholds).doAll(fr)._mb; } labels.remove(); - ModelMetricsBinomialUplift mm = (ModelMetricsBinomialUplift) mb.makeModelMetrics(null, fr, new Frame(predictedProbs), - fr.vec("labels"), fr.vec("treatment"), auucType, auucNbins); // use the Vecs from the frame (to make sure the ESPC is identical) + ModelMetricsBinomialUplift mm = (ModelMetricsBinomialUplift) mb.makeModelMetrics(null, fr, auucType); mm._description = "Computed on user-given predictions and labels."; return mm; } finally { @@ -274,7 +272,7 @@ public double[] perRow(double[] ds, float[] yact, double weight, double offset, treatment = frameWithExtraColumns.vec(m._parms._treatment_column); } } - int auucNbins = m==null || m._parms._auuc_nbins == -1? + int auucNbins = m==null || m._parms._auuc_nbins == -1? AUUC.NBINS : m._parms._auuc_nbins; return makeModelMetrics(m, f, preds, resp, treatment, auucType, auucNbins); } @@ -282,17 +280,19 @@ public double[] perRow(double[] ds, float[] yact, double weight, double offset, private ModelMetrics makeModelMetrics(final Model m, final Frame f, final Frame preds, final Vec resp, final Vec treatment, AUUC.AUUCType auucType, int nbins) { AUUC auuc = null; - if (preds != null && resp != null && treatment != null) { - if (_auuc == null || _auuc._nBins > 0) { - auuc = new AUUC(nbins, preds.vec(0), resp, treatment, auucType); - } else { - auuc = new AUUC(_auuc._thresholds, preds.vec(0), resp, treatment, auucType); + if (preds != null) { + if (resp != null) { + auuc = new AUUC(preds.vec(0), resp, treatment, auucType, nbins); } } return makeModelMetrics(m, f, auuc); } - private ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) { + private ModelMetrics makeModelMetrics(final Model m, final Frame f, AUUC.AUUCType auucType) { + return makeModelMetrics(m, f, new AUUC(_auuc, auucType)); + } + + public ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) { double sigma = Double.NaN; double ate = Double.NaN; double atc = Double.NaN; diff --git a/h2o-core/src/test/java/hex/AUUCTest.java b/h2o-core/src/test/java/hex/AUUCTest.java index a5a4db56bf45..381406bfc601 100644 --- a/h2o-core/src/test/java/hex/AUUCTest.java +++ b/h2o-core/src/test/java/hex/AUUCTest.java @@ -94,7 +94,7 @@ private static AUUC doAUUC(int nbins, double[] probs, double[] y, double[] treat } Frame fr = ArrayUtils.frame(new String[]{"probs", "y", "treatment"}, rows); fr.vec("treatment").setDomain(new String[]{"0", "1"}); - AUUC auuc = new AUUC(nbins, fr.vec("probs"),fr.vec("y"), fr.vec("treatment"), type); + AUUC auuc = new AUUC(fr.vec("probs"),fr.vec("y"), fr.vec("treatment"), type, nbins); fr.remove(); return auuc; }