diff --git a/Core/src/main/java/org/tribuo/impl/ArrayExample.java b/Core/src/main/java/org/tribuo/impl/ArrayExample.java index d5e2a2fb0..f296b9524 100644 --- a/Core/src/main/java/org/tribuo/impl/ArrayExample.java +++ b/Core/src/main/java/org/tribuo/impl/ArrayExample.java @@ -26,6 +26,7 @@ import org.tribuo.transform.TransformerMap; import org.tribuo.util.Merger; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -293,18 +294,19 @@ public int size() { @Override public void removeFeatures(List featureList) { - Map map = new HashMap<>(); + Map> map = new HashMap<>(); for (int i = 0; i < featureNames.length; i++) { - map.put(featureNames[i],i); + List list = map.computeIfAbsent(featureNames[i],(k) -> new ArrayList<>()); + list.add(i); } PriorityQueue removeQueue = new PriorityQueue<>(); for (Feature f : featureList) { - Integer i = map.get(f.getName()); + List i = map.get(f.getName()); if (i != null) { - // If we've found this feature ID remove it from the map to prevent double counting + // If we've found this feature remove it from the map to prevent double counting map.remove(f.getName()); - removeQueue.add(i); + removeQueue.addAll(i); } } diff --git a/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java b/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java index e4c9da46a..6e6440eae 100644 --- a/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java +++ b/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java @@ -16,6 +16,7 @@ package org.tribuo.impl; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -312,18 +313,19 @@ public int size() { @Override public void removeFeatures(List featureList) { - Map map = new HashMap<>(); + Map> map = new HashMap<>(); for (int i = 0; i < featureNames.length; i++) { - map.put(featureNames[i],i); + List list = map.computeIfAbsent(featureNames[i],(k) -> new ArrayList<>()); + list.add(i); } PriorityQueue removeQueue = new PriorityQueue<>(); for (Feature f : featureList) { - Integer i = map.get(f.getName()); + List i = map.get(f.getName()); if (i != null) { - // If we've found this feature ID remove it from the map to prevent double counting + // If we've found this feature remove it from the map to prevent double counting map.remove(f.getName()); - removeQueue.add(i); + removeQueue.addAll(i); } } diff --git a/Core/src/main/java/org/tribuo/impl/IndexedArrayExample.java b/Core/src/main/java/org/tribuo/impl/IndexedArrayExample.java index a85042f5e..80d3a9b90 100644 --- a/Core/src/main/java/org/tribuo/impl/IndexedArrayExample.java +++ b/Core/src/main/java/org/tribuo/impl/IndexedArrayExample.java @@ -24,6 +24,7 @@ import org.tribuo.Output; import org.tribuo.util.Merger; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -200,18 +201,19 @@ public void reduceByName(Merger merger) { @Override public void removeFeatures(List featureList) { - Map map = new HashMap<>(); + Map> map = new HashMap<>(); for (int i = 0; i < featureNames.length; i++) { - map.put(featureNames[i],i); + List list = map.computeIfAbsent(featureNames[i],(k) -> new ArrayList<>()); + list.add(i); } PriorityQueue removeQueue = new PriorityQueue<>(); for (Feature f : featureList) { - Integer i = map.get(f.getName()); + List i = map.get(f.getName()); if (i != null) { - // If we've found this feature ID remove it from the map to prevent double counting + // If we've found this feature remove it from the map to prevent double counting map.remove(f.getName()); - removeQueue.add(i); + removeQueue.addAll(i); } } diff --git a/Core/src/test/java/org/tribuo/ExampleTest.java b/Core/src/test/java/org/tribuo/ExampleTest.java index dd53191d2..8ab88235f 100644 --- a/Core/src/test/java/org/tribuo/ExampleTest.java +++ b/Core/src/test/java/org/tribuo/ExampleTest.java @@ -137,8 +137,21 @@ public void testArrayExampleRemove() { assertEquals(2,example.size()); assertEquals("A",example.lookup("A").name); assertEquals("C",example.lookup("C").name); + + example = new ArrayExample<>(output,new String[]{"A","B","C","D","E","A","C","E"},new double[]{1,1,1,1,1,1,1,1}); + featureList = new ArrayList<>(); + featureList.add(new Feature("D",1.0)); + featureList.add(new Feature("D",1.0)); + featureList.add(new Feature("B",1.0)); + featureList.add(new Feature("E",1.0)); + example.removeFeatures(featureList); + assertEquals(4,example.size()); + assertEquals("A",example.lookup("A").name); + assertEquals("C",example.lookup("C").name); } + + public static void checkDenseExample(Example expected, Example actual) { assertEquals(expected.size(),actual.size()); Iterator expectedItr = expected.iterator();