Skip to content

Commit

Permalink
iSOUP-Tree integration files (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
aosojnik authored Apr 14, 2022
1 parent ae31cb7 commit ff3efe3
Show file tree
Hide file tree
Showing 5 changed files with 1,133 additions and 412 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ public double[] getBranchSplitVarianceOutput(DoubleVector[] postSplitDists) {

}

private double computeVariance(double n, double sum, double squares) {
protected double computeVariance(double n, double sum, double squares) {
if (n > 1) {
return (squares - sum * sum / n) / (n - 1);
}
return 0;
}

private double computeVariance(DoubleVector v) {
protected double computeVariance(DoubleVector v) {
return computeVariance(v.getValue(0), v.getValue(1), v.getValue(2));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package moa.classifiers.multilabel.core.splitcriteria;

import moa.core.DoubleVector;

public class PCTWeightedICVarianceReduction extends WeightedICVarianceReduction {

private static final long serialVersionUID = 1L;

public DoubleVector inputWeights;

public double targetInputTradeoff; // \in [0, 1]

public PCTWeightedICVarianceReduction(DoubleVector targetWeights, DoubleVector inputWeights,
double targetInputTradeoff) {
super(targetWeights);
this.inputWeights = inputWeights;
this.targetInputTradeoff = targetInputTradeoff;
}

public double getMeritOfSplit(DoubleVector[] preSplitTargetDist, DoubleVector[][] postSplitTargetDists,
DoubleVector[] preSplitInputDist, DoubleVector[][] postSplitInputDists) {
double targetError = 0;
double inputError = 0;
int numTargets = preSplitTargetDist.length;
int numInputs = preSplitInputDist.length;
for (int i = 0; i < numTargets; i++) {
if (postSplitTargetDists[i][0].getValue(0) == 0 || postSplitTargetDists[i][1].getValue(0) == 0)
// If the split would result in a leaf without labeled instances, discard it
return Double.NaN;
targetError += weights.getValue(i)
* getMeritOfSplitForAttribute(preSplitTargetDist[i], postSplitTargetDists[i]);
}
for (int i = 0; i < numInputs; i++)
inputError += inputWeights.getValue(i)
* getMeritOfSplitForAttribute(preSplitInputDist[i], postSplitInputDists[i]);

return targetInputTradeoff * targetError / weights.sumOfValues()
+ (1 - targetInputTradeoff) * inputError / inputWeights.sumOfValues();
}

//@Override
protected double getMeritOfSplitForAttribute(DoubleVector preSplitDist, DoubleVector[] postSplitDists) {
double merit = 0;
// count number of branches with weightSeen higher than threshold
int count = 0;
for (int i = 0; i < postSplitDists.length; i++)
if (postSplitDists[i].getValue(0) >= 0.05 * preSplitDist.getValue(0))
count = count + 1;
// Consider split if all branches have required weight seen
if (count == postSplitDists.length) {
double varPreSplit = computeVariance(preSplitDist);
double sumVarPostSplit = 0;
double weightTotal = 0;
for (int i = 0; i < postSplitDists.length; i++) {
weightTotal += postSplitDists[i].getValue(0);
}
double[] variances = getBranchSplitVarianceOutput(postSplitDists);
for (int i = 0; i < variances.length; i++)
if (postSplitDists[i].getValue(0) > 0)
sumVarPostSplit += (postSplitDists[i].getValue(0) / weightTotal * variances[i]); // weight variance
if (varPreSplit > 0.0)
merit = 1 - sumVarPostSplit / varPreSplit;
else
merit = 0.0;
}
/*
* if(merit<0 || merit>1) System.out.println("out of range");
*/
return merit;
}

// private double computeVariance(double n, double sum, double squares) {
// if (n > 1) {
// return (squares - sum * sum / n) / (n - 1);
// }
// return 0;
// }

// private double computeVariance(DoubleVector v) {
// return computeVariance(v.getValue(0), v.getValue(1), v.getValue(2));
// }

public double getMeritOfSplit(DoubleVector[] preSplitTargetDist, DoubleVector[] preSplitInputDist,
DoubleVector[][] postSplitTargetDists, DoubleVector[][] postSplitInputDists) {
double error = 0;
int numOutputs = preSplitTargetDist.length;
for (int i = 0; i < numOutputs; i++)
error += weights.getValue(i) * getMeritOfSplitForOutput(preSplitTargetDist, postSplitTargetDists, i);
return error / weights.sumOfValues();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package moa.classifiers.multilabel.core.splitcriteria;

import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.tasks.TaskMonitor;

/**
* Weighted intra cluster variance reduction split criterion
*
* @author Aljaž Osojnik <[email protected]>
*
*/

public class WeightedICVarianceReduction extends ICVarianceReduction {

/**
*
*/
private static final long serialVersionUID = 1L;

public DoubleVector weights;

public WeightedICVarianceReduction(DoubleVector targetWeights) {
super();
this.weights = targetWeights;
}

@Override
public double getMeritOfSplit(DoubleVector[] preSplitDist, DoubleVector[][] postSplitDists) {
double error = 0;
int numOutputs = preSplitDist.length;
for (int i = 0; i < numOutputs; i++)
// TEMPORARY
// error += (1 / weights.getValue(i)) *
// getMeritOfSplitForOutput(preSplitDist,postSplitDists,i);
error += weights.getValue(i) * getMeritOfSplitForOutput(preSplitDist, postSplitDists, i);
return error / weights.sumOfValues();
}

@Override
public void getDescription(StringBuilder sb, int indent) {

}

@Override
protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
}

}
Loading

0 comments on commit ff3efe3

Please sign in to comment.