-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutil.py
134 lines (112 loc) · 4.48 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Contains helper functions
"""
import csv
import json
from tqdm import tqdm
def zero_shot_predict(zero_shot_classifier, zero_shot_labels, data):
    """
    Labels every case in a dataset with a zero-shot classifier.

    The first entry of each prediction's "labels" list is reduced to a binary
    integer: 1 when it equals zero_shot_labels[1], otherwise 0.

    :param zero_shot_classifier: a pipeline for zero-shot classification created using Hugging Face's transformers library
    :param zero_shot_labels: a list of strings for the zero-shot labels
    :param data: a dataset created using Hugging Face's datasets library
    :return: a list of dictionaries with the following keys: sentence, prediction and answer
    """
    results = []
    for example in tqdm(data):
        sentence = example["sentence"]
        output = zero_shot_classifier(sentence, zero_shot_labels)
        top_label = output["labels"][0]
        # Binary mapping: only the label at index 1 counts as class 1.
        predicted = 1 if top_label == zero_shot_labels[1] else 0
        results.append({'sentence': sentence, 'prediction': predicted, "answer": example["label"]})
    return results
def display_result(predictions):
    """
    Displays the performance of a set of binary predictions.

    Prints the percentage of correct predictions followed by the counts of
    true positives, true negatives, false positives and false negatives.
    A prediction of 1 is treated as positive; any other value as negative.
    When there are no predictions the percentage is reported as "N/A"
    instead of dividing by zero.

    :param predictions: a list of dictionaries with the following keys: sentence, prediction and answer
    :return: None
    """
    true_positive = 0
    true_negative = 0
    false_positive = 0
    false_negative = 0

    for case in predictions:
        predicted_positive = case['prediction'] == 1
        if case['prediction'] == case['answer']:
            # Prediction is correct
            if predicted_positive:
                true_positive += 1
            else:
                true_negative += 1
        elif predicted_positive:
            # Prediction is wrong: predicted positive, answer was not
            false_positive += 1
        else:
            # Prediction is wrong: predicted negative, answer was not
            false_negative += 1

    total = true_positive + true_negative + false_positive + false_negative
    if total:
        percentage = (true_positive + true_negative) / total
        print("Percentage correct: ", str(percentage*100) + "%")
    else:
        # Nothing to score: avoid ZeroDivisionError on an empty input.
        print("Percentage correct: ", "N/A (no predictions)")

    print("true_positive ", true_positive)
    print("true_negative ", true_negative)
    print("false_positive ", false_positive)
    print("false_negative ", false_negative)
def generate_training_csv(predictions, path):
    """
    Creates a CSV file in a format that's understandable by HappyTextClassification objects for training.

    The file starts with the header row "text,label" and is followed by one
    row per case holding its sentence and its prediction. Any existing file at
    the given path is overwritten.

    :param predictions: a list of dictionaries with the two keys: sentence and prediction
    :param path: a string that contains the path to a CSV file.
    :return: None
    """
    # newline='' lets the csv module control line endings; the explicit UTF-8
    # encoding keeps non-ASCII sentences from failing (or being written in a
    # locale-dependent encoding) on platforms whose default is not UTF-8.
    with open(path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["text", "label"])
        writer.writerows(
            [case["sentence"], case["prediction"]] for case in predictions
        )
def generate_training_json(predictions, path):
    """
    Creates a JSON file containing a list of {"text", "label"} objects.

    Each case becomes one object whose "text" is the case's sentence and whose
    "label" is "positive" when the case's prediction is truthy and "negative"
    otherwise. Any existing file at the given path is overwritten.

    :param predictions: a list of dictionaries with the keys sentence and prediction
    :param path: a string that contains the path to a JSON file.
    :return: None
    """
    records = [
        {
            'text': item["sentence"],
            'label': "positive" if item["prediction"] else "negative",
        }
        for item in predictions
    ]

    with open(path, 'w') as json_file:
        json.dump(records, json_file)
def happy_tc_predict(happy_tc, data):
    """
    Classifies every case in a dataset with a HappyTextClassification object.

    The prediction is 1 when the label returned by classify_text is
    "LABEL_1", otherwise 0.

    :param happy_tc: a HappyTextClassification object
    :param data: a Hugging Face dataset
    :return: a list of dictionaries with the following keys: sentence, prediction and answer
    """
    labeled = []
    for example in tqdm(data):
        outcome = happy_tc.classify_text(example["sentence"])
        is_positive = outcome.label == "LABEL_1"
        labeled.append({
            'sentence': example["sentence"],
            'prediction': 1 if is_positive else 0,
            "answer": example["label"],
        })
    return labeled
def textblob_predict(textblob, data):
    """
    Classifies every case in a dataset with a TextBlob classifier.

    The prediction is 1 when the classifier returns "positive", otherwise 0.

    :param textblob: A TextBlob NaiveBayesClassifier
    :param data: a Hugging Face dataset
    :return: a list of dictionaries with the following keys: sentence, prediction and answer
    """
    def _label_case(example):
        # Build the result record for a single case.
        verdict = textblob.classify(example["sentence"])
        as_int = 1 if verdict == "positive" else 0
        return {'sentence': example["sentence"], 'prediction': as_int, "answer": example["label"]}

    return [_label_case(example) for example in tqdm(data)]