Skip to content

Commit

Permalink
GH-6783 Add custom metric function to UpliftDRF (#15592)
Browse files Browse the repository at this point in the history
* Implement ATE, ATT, ATC metrics
* Enable custom metric for UpliftDRF
* Fix score with treatment column
* Add tests
  • Loading branch information
maurever authored Sep 15, 2023
1 parent f6d120b commit 4c5a01b
Show file tree
Hide file tree
Showing 23 changed files with 752 additions and 36 deletions.
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/schemas/UpliftDRFV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static final class UpliftDRFParametersV3 extends SharedTreeV3.SharedTreeP
"categorical_encoding",
"distribution",
"check_constant_response",
"custom_metric_func",
"treatment_column",
"uplift_metric",
"auuc_type",
Expand Down
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ protected final boolean doScoringAndSaveModel(boolean finalScoring, boolean oob,
ModelMetrics mmv = scv.scoreAndMakeModelMetrics(_model, _parms.valid(), v, build_tree_one_node);
_lastScoredTree = _model._output._ntrees;
out._validation_metrics = mmv;
out._validation_metrics._description = "Validation metrics";
if (_model._output._ntrees>0 || scoreZeroTrees()) //don't score the 0-tree model - the error is too large
out._scored_valid[out._ntrees].fillFrom(mmv);
}
Expand Down
14 changes: 12 additions & 2 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ public boolean providesVarImp() {
error("_treatment_column", "The treatment column has to be defined.");
if (_parms._custom_distribution_func != null)
error("_custom_distribution_func", "The custom distribution is not yet supported for Uplift DRF.");
if (_parms._custom_metric_func != null)
error("_custom_metric_func", "The custom metric is not yet supported for Uplift DRF.");
if (_parms._stopping_metric != ScoreKeeper.StoppingMetric.AUTO)
error("_stopping_metric", "The early stopping is not yet supported for Uplift DRF.");
if (_parms._stopping_rounds != 0)
Expand Down Expand Up @@ -404,6 +402,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Number of Trees"); colTypes.add("long"); colFormat.add("%d");
// ATE/ATT/ATC are declared as "double" columns, so they need a floating-point format
// ("%d" is an integer-only conversion); use the same "%.5f" as the double AUUC columns.
colHeaders.add("Training ATE"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training ATT"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training ATC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Training AUUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
Expand All @@ -413,6 +414,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
}

if (_output._validation_metrics != null) {
// ATE/ATT/ATC are declared as "double" columns, so they need a floating-point format
// ("%d" is an integer-only conversion); use the same "%.5f" as the double AUUC columns.
colHeaders.add("Validation ATE"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation ATT"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation ATC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Validation AUUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
Expand Down Expand Up @@ -443,6 +447,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
table.set(row, col++, PrettyPrint.msecs(_training_time_ms[i] - job.start_time(), true));
table.set(row, col++, i);
ScoreKeeper st = _scored_train[i];
table.set(row, col++, st._ate);
table.set(row, col++, st._att);
table.set(row, col++, st._atc);
table.set(row, col++, st._auuc_nbins);
table.set(row, col++, st._AUUC);
table.set(row, col++, st._auuc_normalized);
Expand All @@ -451,6 +458,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,

if (_output._validation_metrics != null) {
st = _scored_valid[i];
table.set(row, col++, st._ate);
table.set(row, col++, st._att);
table.set(row, col++, st._atc);
table.set(row, col++, st._auuc_nbins);
table.set(row, col++, st._AUUC);
table.set(row, col++, st._auuc_normalized);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package hex.util;

import hex.AUUC;
import hex.Model;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
Expand Down
10 changes: 9 additions & 1 deletion h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -2222,10 +2222,15 @@ protected void setupLocal() {
Chunk weightsChunk = _hasWeights && _computeMetrics ? chks[_output.weightsIdx()] : null;
Chunk offsetChunk = _output.hasOffset() ? chks[_output.offsetIdx()] : null;
Chunk responseChunk = null;
Chunk treatmentChunk = null;
float [] actual = null;
_mb = Model.this.makeMetricBuilder(_domain);
if (_computeMetrics) {
if (_output.hasResponse()) {
if (_output.hasTreatment()){
actual = new float[2];
responseChunk = chks[_output.responseIdx()];
treatmentChunk = chks[_output.treatmentIdx()];
} else if (_output.hasResponse()) {
actual = new float[1];
responseChunk = chks[_output.responseIdx()];
} else
Expand All @@ -2252,6 +2257,9 @@ protected void setupLocal() {
for (int i = 0; i < actual.length; ++i)
actual[i] = (float) data(chks, row, i);
}
if (treatmentChunk != null) {
actual[1] = (float) treatmentChunk.atd(row);
}
_mb.perRow(preds, actual, weight, offset, Model.this);
// Handle custom metric
customMetricPerRow(preds, actual, weight, offset, Model.this);
Expand Down
48 changes: 39 additions & 9 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@

public class ModelMetricsBinomialUplift extends ModelMetricsSupervised {
public final AUUC _auuc;
public double _ate;
public double _att;
public double _atc;

/**
 * Creates binomial uplift metrics holding the treatment-effect averages and the AUUC object.
 *
 * @param model        model the metrics belong to (forwarded to the superclass)
 * @param frame        frame the metrics were computed on (forwarded to the superclass)
 * @param nobs         number of observations (forwarded to the superclass)
 * @param domain       response domain (forwarded to the superclass)
 * @param ate          average treatment effect, stored in {@code _ate}
 * @param att          average treatment effect on the treated, stored in {@code _att}
 * @param atc          average treatment effect on the control, stored in {@code _atc}
 * @param sigma        sigma value forwarded to the superclass
 * @param auuc         AUUC object, stored in {@code _auuc}
 * @param customMetric custom metric forwarded to the superclass
 */
public ModelMetricsBinomialUplift(Model model, Frame frame, long nobs, String[] domain,
                                  double ate, double att, double atc, double sigma, AUUC auuc,
                                  CustomMetric customMetric) {
    // The literal 0 is passed in the superclass constructor's fourth position, as before.
    super(model, frame, nobs, 0, domain, sigma, customMetric);
    _ate = ate;
    _att = att;
    _atc = atc;
    _auuc = auuc;
}

Expand All @@ -30,6 +36,9 @@ public static ModelMetricsBinomialUplift getFromDKV(Model model, Frame frame) {
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append("ATE:" ).append((float) _ate).append("\n");
sb.append("ATT:" ).append((float) _att).append("\n");
sb.append("ATC:" ).append((float) _atc).append("\n");
if(_auuc != null){
sb.append("Default AUUC: ").append((float) _auuc.auuc()).append("\n");
sb.append("Qini AUUC: ").append((float) _auuc.auucByType(AUUC.AUUCType.qini)).append("\n");
Expand All @@ -50,6 +59,12 @@ public String toString() {
/** Returns the normalized AUUC value as reported by the underlying {@code _auuc} object. */
public double auucNormalized() {
    return _auuc.auucNormalized();
}

/** Returns the number of bins ({@code _nBins}) of the underlying {@code _auuc} object. */
public int nbins() {
    return _auuc._nBins;
}

/** Returns the stored average treatment effect (ATE). */
public double ate() {
    return _ate;
}

/** Returns the stored average treatment effect on the treated (ATT). */
public double att() {
    return _att;
}

/** Returns the stored average treatment effect on the control (ATC). */
public double atc() {
    return _atc;
}

@Override
protected StringBuilder appendToStringMetrics(StringBuilder sb) {
Expand Down Expand Up @@ -127,13 +142,13 @@ public UpliftBinomialMetrics(String[] domain, double[] thresholds) {
_mb = new MetricBuilderBinomialUplift(domain, thresholds);
Chunk uplift = chks[0];
Chunk actuals = chks[1];
Chunk treatment =chks[2];
Chunk treatment = chks[2];
double[] ds = new double[1];
float[] acts = new float[2];
for (int i=0; i<chks[0]._len;++i) {
ds[0] = uplift.atd(i);
acts[0] = (float) actuals.atd(i);
acts[1] = (float )treatment.atd(i);
acts[1] = (float) treatment.atd(i);
_mb.perRow(ds, acts, 1, 0, null);
}
}
Expand All @@ -143,7 +158,10 @@ public UpliftBinomialMetrics(String[] domain, double[] thresholds) {
public static class MetricBuilderBinomialUplift extends MetricBuilderSupervised<MetricBuilderBinomialUplift> {

protected AUUC.AUUCBuilder _auuc;

public double _sumTE;
public double _sumTETreatment;
public long _treatmentCount;

public MetricBuilderBinomialUplift( String[] domain, double[] thresholds) {
super(2,domain);
if(thresholds != null) {
Expand All @@ -163,17 +181,20 @@ public MetricBuilderBinomialUplift( String[] domain) {
/**
 * Accumulates a single scored row into the running uplift statistics.
 *
 * Rows are skipped (nothing is accumulated) when the actual response is NaN, the
 * prediction array contains NaNs, the weight is zero or NaN, or the response is not 0/1.
 *
 * @param ds     predictions for the row; {@code ds[0]} is used as the predicted uplift
 * @param yact   actuals; {@code yact[0]} is the response and {@code yact[1]} the treatment flag
 * @param weight row weight
 * @param offset not used by this method
 * @param m      not used by this method
 * @return the {@code ds} array that was passed in
 */
public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model m) {
    assert _auuc == null || yact.length == 2 : "Treatment must be included in `yact` when calculating AUUC";
    if (Float.isNaN(yact[0])) return ds; // No errors if actual is missing
    if (ArrayUtils.hasNaNs(ds)) return ds; // No errors if prediction has missing values (can happen for GLM)
    if (weight == 0 || Double.isNaN(weight)) return ds;
    int y = (int) yact[0];
    if (y != 0 && y != 1) return ds; // The actual is effectively a NaN
    _wY += weight * y;
    _wYY += weight * y * y;
    _count++;
    _wcount += weight;
    int treatmentGroup = (int) yact[1]; // treatment = 1, control = 0
    double treatmentEffect = ds[0] * weight; // weighted prediction for this row
    _sumTE += treatmentEffect;
    _sumTETreatment += treatmentGroup * treatmentEffect;
    // NOTE(review): _treatmentCount is declared as long, so this compound assignment
    // narrows the double product and drops fractional weights - confirm this is intended.
    _treatmentCount += treatmentGroup * weight;
    if (_auuc != null) {
        // Accumulate the row into the AUUC builder exactly once.
        // NOTE(review): the weight-scaled prediction is passed together with the weight;
        // confirm the AUUC builder expects a pre-weighted value here.
        _auuc.perRow(treatmentEffect, weight, y, treatmentGroup);
    }
    return ds;
}
Expand All @@ -183,6 +204,9 @@ public double[] perRow(double[] ds, float[] yact, double weight, double offset,
if(_auuc != null) {
_auuc.reduce(mb._auuc);
}
_sumTE += mb._sumTE;
_sumTETreatment += mb._sumTETreatment;
_treatmentCount += mb._treatmentCount;
}

/**
Expand Down Expand Up @@ -231,15 +255,21 @@ private ModelMetrics makeModelMetrics(final Model m, final Frame f, final Frame

/**
 * Builds the uplift metrics object from the statistics accumulated by this builder.
 *
 * ATE is the accumulated weighted prediction sum divided by the total weight; ATT and ATC
 * split that sum by treatment group. When no weight was accumulated all three stay NaN and
 * an empty AUUC is used; ATT/ATC also stay NaN when their group count is not positive.
 *
 * @param m    model the metrics belong to; registered with the model when non-null
 * @param f    frame the metrics were computed on
 * @param auuc precomputed AUUC, or {@code null} to build one from the accumulated builder
 * @return the newly created metrics object
 */
private ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) {
    double sigma = Double.NaN;
    double ate = Double.NaN;
    double att = Double.NaN;
    double atc = Double.NaN;
    if (_wcount > 0) {
        if (auuc == null) {
            sigma = weightedSigma();
            // NOTE(review): m is dereferenced here but null-checked further down - confirm
            // that a null model is never combined with a null auuc.
            auuc = new AUUC(_auuc, m._parms._auuc_type);
        }
        ate = _sumTE / _wcount;
        // Guard the group averages so an empty group yields NaN instead of +/-Infinity.
        if (_treatmentCount > 0) {
            att = _sumTETreatment / _treatmentCount;
        }
        double controlCount = _wcount - _treatmentCount;
        if (controlCount > 0) {
            atc = (_sumTE - _sumTETreatment) / controlCount;
        }
    } else {
        auuc = new AUUC();
    }
    ModelMetricsBinomialUplift mm = new ModelMetricsBinomialUplift(m, f, _count, _domain, ate, att, atc, sigma, auuc, _customMetric);
    if (m != null) m.addModelMetrics(mm);
    return mm;
}
Expand Down
6 changes: 6 additions & 0 deletions h2o-core/src/main/java/hex/ScoreKeeper.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class ScoreKeeper extends Iced {
public double _auuc_normalized = Double.NaN;
public double _qini = Double.NaN;
public int _auuc_nbins = 0;
public double _ate = Double.NaN;
public double _att = Double.NaN;
public double _atc = Double.NaN;

public ScoreKeeper() {}

Expand Down Expand Up @@ -125,6 +128,9 @@ else if (m instanceof ModelMetricsMultinomial) {
_auuc_normalized = ((ModelMetricsBinomialUplift)m).auucNormalized();
_qini = ((ModelMetricsBinomialUplift)m).qini();
_auuc_nbins = ((ModelMetricsBinomialUplift)m).nbins();
_ate = ((ModelMetricsBinomialUplift)m).ate();
_att = ((ModelMetricsBinomialUplift)m).att();
_atc = ((ModelMetricsBinomialUplift)m).atc();
}
if (customMetric != null ) {
_custom_metric = customMetric.value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
public class ModelMetricsBinomialUpliftV3<I extends ModelMetricsBinomialUplift, S extends water.api.schemas3.ModelMetricsBinomialUpliftV3<I, S>>
extends ModelMetricsBaseV3<I,S> {

@API(help="Average Treatment Effect.", direction=API.Direction.OUTPUT)
public double ate;

@API(help="Average Treatment Effect on the Treated.", direction=API.Direction.OUTPUT)
public double att;

@API(help="Average Treatment Effect on the Control.", direction=API.Direction.OUTPUT)
public double atc;

@API(help="The default AUUC for this scoring run.", direction=API.Direction.OUTPUT)
public double AUUC;

Expand Down Expand Up @@ -40,6 +49,9 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {

// The treatment-effect averages are read from the metrics object itself, not from the
// AUUC object, so copy them regardless of whether an AUUC is present.
ate = modelMetrics.ate();
att = modelMetrics.att();
atc = modelMetrics.atc();
AUUC auuc = modelMetrics._auuc;
if (null != auuc) {
AUUC = auuc.auuc();
auuc_normalized = auuc.auucNormalized();
qini = auuc.qini();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
``upload_custom_metric``
------------------------

- Available in: GBM, DRF, Deeplearning, GLM
- Available in: GBM, DRF, Deeplearning, GLM, UpliftDRF
- Hyperparameter: no

Description
Expand Down
Loading

0 comments on commit 4c5a01b

Please sign in to comment.