-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes multilabel classifiers and multilabel generator (#249)
* Fixes multilabel classifiers and multilabel generator * Formatting
- Loading branch information
Alberto Cano
authored
Apr 14, 2022
1 parent
ff3efe3
commit eab15f7
Showing
6 changed files
with
37 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
import com.yahoo.labs.samoa.instances.MultiLabelPrediction; | ||
import com.yahoo.labs.samoa.instances.Prediction; | ||
import moa.classifiers.AbstractMultiLabelLearner; | ||
import moa.classifiers.MultiLabelLearner; | ||
import moa.classifiers.MultiTargetRegressor; | ||
import moa.core.StringUtils; | ||
|
||
|
@@ -38,7 +39,7 @@ | |
* @author Jesse Read ([email protected]) | ||
* @version $Revision: 1 $ | ||
*/ | ||
public class MajorityLabelset extends AbstractMultiLabelLearner implements MultiTargetRegressor { | ||
public class MajorityLabelset extends AbstractMultiLabelLearner implements MultiLabelLearner { | ||
//AbstractClassifier { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,11 +28,14 @@ | |
import com.github.javacliparser.FloatOption; | ||
import com.github.javacliparser.IntOption; | ||
import moa.streams.InstanceStream; | ||
import moa.streams.MultiTargetInstanceStream; | ||
import moa.tasks.TaskMonitor; | ||
import com.yahoo.labs.samoa.instances.Attribute; | ||
import com.yahoo.labs.samoa.instances.DenseInstance; | ||
import com.yahoo.labs.samoa.instances.Instance; | ||
import com.yahoo.labs.samoa.instances.Instances; | ||
import com.yahoo.labs.samoa.instances.InstancesHeader; | ||
import com.yahoo.labs.samoa.instances.Range; | ||
import com.yahoo.labs.samoa.instances.SparseInstance; | ||
import moa.core.FastVector; | ||
import moa.core.Utils; | ||
|
@@ -43,7 +46,7 @@ | |
* @author Jesse Read (([email protected])) | ||
* @version $Revision: 7 $ | ||
*/ | ||
public class MetaMultilabelGenerator extends AbstractOptionHandler implements InstanceStream { | ||
public class MetaMultilabelGenerator extends AbstractOptionHandler implements MultiTargetInstanceStream { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
||
|
@@ -153,19 +156,25 @@ public void restart() { | |
* @param si single-label Instances | ||
*/ | ||
protected MultilabelInstancesHeader generateMultilabelHeader(Instances si) { | ||
Instances mi = new Instances(si, 0, 0); | ||
mi.setClassIndex(-1); | ||
mi.deleteAttributeAt(mi.numAttributes() - 1); | ||
FastVector bfv = new FastVector(); | ||
bfv.addElement("0"); | ||
bfv.addElement("1"); | ||
for (int i = 0; i < this.m_L; i++) { | ||
mi.insertAttributeAt(new Attribute("class" + i, bfv), i); | ||
} | ||
this.multilabelStreamTemplate = mi; | ||
this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + m_A + "S" + metaRandomSeedOption.getValue() + ": -C " + this.m_L); | ||
this.multilabelStreamTemplate.setClassIndex(this.m_L); | ||
return new MultilabelInstancesHeader(multilabelStreamTemplate, m_L); | ||
Instances mi = new Instances(si, 0, 0); | ||
mi.deleteAttributeAt(mi.numAttributes() - 1); | ||
FastVector bfv = new FastVector(); | ||
bfv.addElement("0"); | ||
bfv.addElement("1"); | ||
for (int i = 0; i < this.m_L; i++) { | ||
mi.insertAttributeAt(new Attribute("class" + i, bfv), i); | ||
} | ||
|
||
Range range = new Range(Integer.toString((numLabelsOption.getValue()))); | ||
|
||
this.multilabelStreamTemplate = mi; | ||
this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + m_A + "S" + metaRandomSeedOption.getValue() + ": -C " + this.m_L); | ||
this.multilabelStreamTemplate.setClassIndex(Integer.MAX_VALUE); | ||
this.multilabelStreamTemplate.setRangeOutputIndices(range); | ||
|
||
MultilabelInstancesHeader header = new MultilabelInstancesHeader(multilabelStreamTemplate, m_L); | ||
header.setRangeOutputIndices(range); | ||
return header; | ||
} | ||
|
||
/** | ||
|
@@ -267,7 +276,7 @@ private double joint(int k, int y[]) { | |
private Instance generateMLInstance(HashSet<Integer> Y) { | ||
|
||
// create a multi-label instance: | ||
Instance x_ml = new SparseInstance(this.multilabelStreamTemplate.numAttributes()); | ||
Instance x_ml = new DenseInstance(this.multilabelStreamTemplate.numAttributes()); | ||
x_ml.setDataset(this.multilabelStreamTemplate); | ||
|
||
// set classes | ||
|
@@ -472,7 +481,7 @@ public int compare(HashSet Y1, HashSet Y2) { | |
} | ||
|
||
// shuffle | ||
Collections.shuffle(Arrays.asList(map_set)); | ||
Collections.shuffle(Arrays.asList(map_set), m_MetaRandom); | ||
|
||
// return | ||
return map_set; | ||
|
@@ -545,7 +554,7 @@ private ArrayList<Integer> getShuffledListToLWithoutK(int L, int k) { | |
list.add(j); | ||
} | ||
} | ||
Collections.shuffle(list); | ||
Collections.shuffle(list, m_MetaRandom); | ||
return list; | ||
} | ||
|
||
|