diff --git a/BayesClassifier.java b/BayesClassifier.java index 40f4a88..cbdf520 100644 --- a/BayesClassifier.java +++ b/BayesClassifier.java @@ -1,9 +1,7 @@ package de.daslaboratorium.machinelearning.classifier; -import java.util.AbstractMap.SimpleEntry; import java.util.Collection; import java.util.Comparator; -import java.util.Map.Entry; import java.util.SortedSet; import java.util.TreeSet; @@ -97,22 +95,25 @@ public int compare(Classification o1, * @return The category the set of features is classified as. */ @Override - public K classify(Collection features) { + public Classification classify(Collection features) { SortedSet> probabilites = this.categoryProbabilities(features); - System.out.println("Results:\t"); - for (Classification prob : probabilites) - System.out.println(prob); - if (probabilites.size() > 0) { - System.out.println("Classified as " + - probabilites.last().getCategory()); - return probabilites.last().getCategory(); - } else { - System.out.println("No results"); + return probabilites.last(); } return null; } + /** + * Classifies the given set of features. and return the full details of the + * classification. + * + * @return The set of categories the set of features is classified as. + */ + public Collection> classifyDetailed( + Collection features) { + return this.categoryProbabilities(features); + } + } diff --git a/Classifier.java b/Classifier.java index 54b6fb1..d1b74cb 100644 --- a/Classifier.java +++ b/Classifier.java @@ -35,7 +35,7 @@ public abstract class Classifier implements IFeatureProbability { /** * The initial memory capacity or how many classifications are memorized. */ - private static final int MEMORY_CAPACITY = 5; + private int memoryCapacity = 200; /** * A dictionary mapping features to their number of occurrences in each @@ -114,6 +114,28 @@ public int getCategoriesTotal() { return toReturn; } + /** + * Retrieves the memory's capacity. + * + * @return The memory's capacity. + */ + public int getMemoryCapacity() { + return memoryCapacity; + } + + /** + * Sets the memory's capacity. If the new value is less than the old + * value, the memory will be truncated accordingly. + * + * @param memoryCapacity The new memory capacity. + */ + public void setMemoryCapacity(int memoryCapacity) { + for (int i = this.memoryCapacity; i > memoryCapacity; i--) { + this.memoryQueue.poll(); + } + this.memoryCapacity = memoryCapacity; + } + /** * Increments the count of a given feature in the given category. This is * equal to telling the classifier, that this feature has occurred in this @@ -140,7 +162,7 @@ public void incrementFeature(T feature, K category) { Integer totalCount = this.totalFeatureCount.get(feature); if (totalCount == null) { - this.totalFeatureCount.put(feature, 1); + this.totalFeatureCount.put(feature, 0); totalCount = this.totalFeatureCount.get(feature); } this.totalFeatureCount.put(feature, ++totalCount); @@ -155,7 +177,7 @@ public void incrementFeature(T feature, K category) { public void incrementCategory(K category) { Integer count = this.totalCategoryCount.get(category); if (count == null) { - this.totalCategoryCount.put(category, 1); + this.totalCategoryCount.put(category, 0); count = this.totalCategoryCount.get(category); } this.totalCategoryCount.put(category, ++count); @@ -179,20 +201,23 @@ public void decrementFeature(T feature, K category) { if (count == null) { return; } - if (count == 1) { + if (count.intValue() == 1) { features.remove(feature); + if (features.size() == 0) { + this.featureCountPerCategory.remove(category); + } } else { - count--; + features.put(feature, --count); } Integer totalCount = this.totalFeatureCount.get(feature); if (totalCount == null) { return; } - if (totalCount == 1) { + if (totalCount.intValue() == 1) { this.totalFeatureCount.remove(feature); } else { - totalCount--; + this.totalFeatureCount.put(feature, --totalCount); } } @@ -207,10 +232,10 @@ public void decrementCategory(K category) { if (count == null) { return; } - if (count == 1) { + if (count.intValue() == 1) { this.totalCategoryCount.remove(category); } else { - count--; + this.totalCategoryCount.put(category, --count); } } @@ -360,20 +385,14 @@ public void learn(K category, Collection features) { */ public void learn(Classification classification) { - System.out.println("Learning new classification:\t" - + classification); - for (T feature : classification.getFeatureset()) this.incrementFeature(feature, classification.getCategory()); this.incrementCategory(classification.getCategory()); this.memoryQueue.offer(classification); - if (this.memoryQueue.size() > Classifier.MEMORY_CAPACITY) { + if (this.memoryQueue.size() > this.memoryCapacity) { Classification toForget = this.memoryQueue.remove(); - System.out.println("Memory over capacity. Forgetting about\t" - + toForget); - for (T feature : toForget.getFeatureset()) this.decrementFeature(feature, toForget.getCategory()); this.decrementCategory(toForget.getCategory()); @@ -387,6 +406,6 @@ public void learn(Classification classification) { * @param features The features to classify. * @return The category most likely. */ - public abstract K classify(Collection features); + public abstract Classification classify(Collection features); } diff --git a/ClassifierTester.java b/ClassifierTester.java deleted file mode 100644 index 4bb2fb1..0000000 --- a/ClassifierTester.java +++ /dev/null @@ -1,43 +0,0 @@ -package de.daslaboratorium.machinelearning.classifier; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.util.Arrays; -import java.util.Collection; - -public class ClassifierTester { - - /** - * @param args - * @throws IOException - */ - public static void main(String[] args) throws IOException { - BayesClassifier classifier = - new BayesClassifier(); - - BufferedReader io =new BufferedReader(new InputStreamReader(System.in)); - String line = ""; - System.out.print("> "); - while ((line = io.readLine()) != null) { - String[] tokens = line.split("\\s"); - if (tokens.length < 3) { - System.out.println("not enough params"); - continue; - } - if (tokens[0].startsWith("t")) { - Collection context = - Arrays.asList( - Arrays.copyOfRange(tokens, 2, tokens.length)); - classifier.learn(tokens[1], context); - } else if (tokens[0].startsWith("c")) { - Collection context = - Arrays.asList( - Arrays.copyOfRange(tokens, 1, tokens.length)); - classifier.classify(context); - } - System.out.print("> "); - } - } - -} diff --git a/forgetful-learning.pdf b/forgetful-learning.pdf new file mode 100644 index 0000000..ad681c7 Binary files /dev/null and b/forgetful-learning.pdf differ