-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathClassifier.java
76 lines (67 loc) · 2.14 KB
/
Classifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import java.io.File;
import java.util.Scanner;
import java.util.Set;
import java.util.HashSet;
/**
 * Command-line driver: trains a {@code NaiveBayes} model on the emails in
 * {@code ./DataSet/train}, predicts a label for every email in
 * {@code ./DataSet/test}, and prints how many predictions matched.
 */
public class Classifier {

    /**
     * Program entry point.
     *
     * @param args command-line arguments; {@code args[0]} must be an integer
     *             smoothing parameter that is handed to {@code NaiveBayes}.
     */
    public static void main(String[] args) {
        // Parse the smoothing parameter to use.
        if (args.length == 0) {
            failWithUsage();
            return;
        }
        int smooth;
        try {
            smooth = Integer.parseInt(args[0]);
        } catch (NumberFormatException e) {
            failWithUsage();
            return;
        }

        // Parse training and test data into sets.
        Set<Email> trainEmails = parseEmails("./DataSet/train");
        Set<Email> testEmails = parseEmails("./DataSet/test");

        // Train on the training set, then count correct predictions on the test set.
        NaiveBayes nb = new NaiveBayes(smooth);
        nb.train(trainEmails);
        int correctPred = 0;
        for (Email email : testEmails) {
            if (email.getLabel().equals(nb.predict(email))) {
                correctPred++;
            }
        }

        // Print accuracy statistics.
        computeAccuracy(correctPred, testEmails.size());
    }

    /**
     * Reports a missing or malformed smoothing parameter and terminates the JVM
     * with a non-zero status so that scripts can detect the failure.
     */
    private static void failWithUsage() {
        System.err.println("Please be sure to include a valid smoothing parameter.");
        System.exit(1);
    }

    /**
     * Parses a file containing emails and corresponding word counts.
     *
     * <p>Each non-blank line is read as whitespace-separated tokens: an email id,
     * a label, and then zero or more {@code word count} pairs where the count is
     * an integer. Blank lines are skipped.
     *
     * <p>Parsing is best-effort: if the file cannot be opened or a line is
     * malformed, the error is reported on standard error and the emails read up
     * to that point are returned.
     *
     * @param fileName File path to the file containing the emails to parse.
     * @return a Set of Email objects; empty if nothing could be read.
     */
    public static Set<Email> parseEmails(String fileName) {
        Set<Email> emails = new HashSet<Email>();
        // try-with-resources guarantees the underlying file handle is released.
        try (Scanner sc = new Scanner(new File(fileName))) {
            while (sc.hasNextLine()) {
                String rawLine = sc.nextLine();
                if (rawLine.trim().isEmpty()) {
                    // A blank line carries no email; skipping it keeps one stray
                    // newline from aborting the rest of the file.
                    continue;
                }
                try (Scanner line = new Scanner(rawLine)) {
                    String eid = line.next();
                    String label = line.next();
                    Email email = new Email(eid, label);
                    while (line.hasNext()) {
                        String word = line.next();
                        int count = line.nextInt();
                        email.addWord(word, count);
                    }
                    emails.add(email);
                }
            }
        } catch (Exception e) {
            // Deliberately broad: this is the boundary where a missing file or a
            // malformed line is turned into a diagnostic instead of a crash.
            System.err.println("Failed to parse emails from " + fileName
                    + "; returning the emails read so far.");
            e.printStackTrace();
        }
        return emails;
    }

    /**
     * Prints out the accuracy of the predictions to the console.
     *
     * <p>When {@code total} is not positive there is nothing to divide by, so the
     * accuracy is reported as {@code N/A} rather than {@code NaN}.
     *
     * @param matches Integer representing number of correct predictions made.
     * @param total Integer representing total number of predictions made.
     */
    public static void computeAccuracy(int matches, int total) {
        System.out.println("Test-Data Prediction Statistics");
        System.out.println("-------------------------------");
        System.out.printf("Matches: %d\n", matches);
        if (total <= 0) {
            System.out.print("Accuracy: N/A (no predictions were made)\n\n");
            return;
        }
        System.out.printf("Accuracy: %.2f%%\n\n", 100.0 * matches / total);
    }
}