-
Notifications
You must be signed in to change notification settings - Fork 44
/
ComputeObservedCoherence.py
157 lines (134 loc) · 5.14 KB
/
ComputeObservedCoherence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Author: Jey Han Lau
Date: May 2013
"""
import argparse
import sys
import operator
import math
import codecs
import numpy as np
from collections import defaultdict
#parser arguments
desc = "Computes the observed coherence for a given topic and word-count file."
parser = argparse.ArgumentParser(description=desc)
#####################
#positional argument#
#####################
parser.add_argument("topic_file", help="file that contains the topics")
parser.add_argument("metric", help="type of evaluation metric", choices=["pmi","npmi","lcp"])
parser.add_argument("wordcount_file", help="file that contains the word counts")
###################
#optional argument#
###################
parser.add_argument("-t", "--topns", nargs="+", type=int, default=[10], \
    help="list of top-N topic words to consider for computing coherence; e.g. '-t 5 10' means it " + \
    " will compute coherence over top-5 words and top-10 words and then take the mean of both values." + \
    " Default = [10]")
args = parser.parse_args()
#parameters
colloc_sep = "_" #symbol for concatenating collocations (underscores are turned into spaces before lookup)
#input
# NOTE(review): both files are opened here and read by the main section below;
# they are never explicitly closed (script-lifetime handles).
topic_file = codecs.open(args.topic_file, "r", "utf-8")
wc_file = codecs.open(args.wordcount_file, "r", "utf-8")
#constants
WTOTALKEY = "!!<TOTAL_WINDOWS>!!" #key name for total number of windows (in word count file)
#global variables
window_total = 0 #total number of windows; populated from the WTOTALKEY entry of the word-count file
wordcount = {} #a dictionary of word counts, for single words ("w") and pairs ("w1|w2")
wordpos = {} #a dictionary of pos distribution (not referenced anywhere in this file's visible code)
###########
#functions#
###########
#use utf-8 for stdout (Python 2 idiom: wrap the raw stdout in a utf-8 writer)
sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
#compute the association between two words
def calc_assoc(word1, word2, metric=None, counts=None, total=None):
    """Compute the association score between two words.

    Parameters
    ----------
    word1, word2 : str
        The two words (collocations already have their underscores replaced
        by spaces before this is called).
    metric : str, optional
        One of "pmi", "npmi" or "lcp". Defaults to the module-level
        ``args.metric`` so existing callers are unaffected.
    counts : dict, optional
        Word/pair-count dictionary; pair keys are "w1|w2" in either order.
        Defaults to the module-level ``wordcount``.
    total : int, optional
        Total number of observation windows. Defaults to the module-level
        ``window_total``.

    Returns
    -------
    float
        The association score; 0.0 for PMI/NPMI when any required count is
        missing (the conventional "no evidence" value).
    """
    # Resolve defaults at call time so the module globals are only touched
    # when the caller did not supply explicit values (keeps the function
    # testable in isolation).
    if metric is None:
        metric = args.metric
    if counts is None:
        counts = wordcount
    if total is None:
        total = window_total

    # The pair count may be stored under either ordering of the two words.
    combined_count = counts.get(word1 + "|" + word2,
                                counts.get(word2 + "|" + word1, 0))
    w1_count = counts.get(word1, 0)
    w2_count = counts.get(word2, 0)

    if metric in ("pmi", "npmi"):
        if w1_count == 0 or w2_count == 0 or combined_count == 0:
            # No evidence for one of the events: define the score as 0.0
            # rather than raising on log(0).
            result = 0.0
        else:
            # PMI = log10( P(w1,w2) / (P(w1) * P(w2)) )
            result = math.log((float(combined_count) * float(total)) /
                              float(w1_count * w2_count), 10)
            if metric == "npmi":
                # NPMI = PMI / -log10(P(w1,w2))
                denom = -1.0 * math.log(float(combined_count) / total, 10)
                if denom == 0.0:
                    # P(w1,w2) == 1: the original code divided by zero here.
                    # Use the NPMI limit value for perfect co-occurrence.
                    result = 1.0
                else:
                    result = result / denom
    elif metric == "lcp":
        if combined_count == 0:
            if w2_count != 0:
                # Back off to the marginal probability of word2.
                result = math.log(float(w2_count) / total, 10)
            else:
                # Both unseen: smooth with a pseudo-count of 1.
                result = math.log(float(1.0) / total, 10)
        else:
            # log conditional probability: log10( P(w2|w1) )
            result = math.log((float(combined_count)) / (float(w1_count)), 10)
    return result
#compute topic coherence given a list of topic words
def calc_topic_coherence(topic_words, assoc_fn=None, sep=None):
    """Compute the coherence of a topic as the mean pairwise association
    over all unordered pairs of distinct topic words.

    Parameters
    ----------
    topic_words : list of str
        Top-N topic words; collocations/bigrams are joined with ``sep``.
    assoc_fn : callable, optional
        Function ``(w1, w2) -> float`` used to score a word pair.
        Defaults to the module-level ``calc_assoc``.
    sep : str, optional
        Collocation separator to replace with a space before scoring.
        Defaults to the module-level ``colloc_sep``.

    Returns
    -------
    float
        Mean pairwise association; 0.0 when there are fewer than two
        distinct words (the original raised ZeroDivisionError here).
    """
    # Resolve defaults at call time so module globals are only needed when
    # the caller does not supply them.
    if assoc_fn is None:
        assoc_fn = calc_assoc
    if sep is None:
        sep = colloc_sep

    topic_assoc = []
    for w1_id in range(0, len(topic_words) - 1):
        target_word = topic_words[w1_id]
        # Replace the separator with a space if it's a collocation/bigram.
        w1 = " ".join(target_word.split(sep))
        for w2_id in range(w1_id + 1, len(topic_words)):
            topic_word = topic_words[w2_id]
            w2 = " ".join(topic_word.split(sep))
            # Skip duplicate words: a word paired with itself carries no
            # coherence information.
            if target_word != topic_word:
                topic_assoc.append(assoc_fn(w1, w2))

    # Guard: no scorable pairs (empty topic, single word, or all duplicates).
    if not topic_assoc:
        return 0.0
    return float(sum(topic_assoc)) / len(topic_assoc)
######
#main#
######
#process the word count file(s)
#Each line is either "word|count" or "w1|w2|count"; pair keys are normalised
#so the lexicographically smaller word comes first.
for line in wc_file:
    line = line.strip()
    data = line.split("|")
    if len(data) == 2:
        #single-word count
        wordcount[data[0]] = int(data[1])
    elif len(data) == 3:
        #pair count: canonicalise key ordering so lookups are unambiguous
        if data[0] < data[1]:
            key = data[0] + "|" + data[1]
        else:
            key = data[1] + "|" + data[0]
        wordcount[key] = int(data[2])
    else:
        #malformed line: abort with a message (Python 2 print statement)
        print "ERROR: wordcount format incorrect. Line =", line
        raise SystemExit
#get the total number of windows
#(stored in the word-count file under the special WTOTALKEY entry)
if WTOTALKEY in wordcount:
    window_total = wordcount[WTOTALKEY]
#read the topic file and compute the observed coherence
topic_coherence = defaultdict(list) # {topicid: [tc]}
topic_tw = {} #{topicid: topN_topicwords}
for topic_id, line in enumerate(topic_file):
    #keep only the largest top-N prefix; each requested N is scored below
    topic_list = line.split()[:max(args.topns)]
    topic_tw[topic_id] = " ".join(topic_list)
    for n in args.topns:
        #one coherence value per requested top-N cutoff
        topic_coherence[topic_id].append(calc_topic_coherence(topic_list[:n]))
#sort the topic coherence scores in terms of topic id
tc_items = sorted(topic_coherence.items())
mean_coherence_list = []
for item in tc_items:
    topic_words = topic_tw[item[0]].split()
    #per-topic score is the mean over the per-N scores
    mean_coherence = np.mean(item[1])
    mean_coherence_list.append(mean_coherence)
    #trailing commas suppress the newline (Python 2 print statement)
    print ("[%.2f] (" % mean_coherence),
    for i in item[1]:
        print ("%.2f;" % i),
    print ")", topic_tw[item[0]]
#print the overall topic coherence for all topics
print "=========================================================================="
print "Average Topic Coherence = %.3f" % np.mean(mean_coherence_list)
print "Median Topic Coherence = %.3f" % np.median(mean_coherence_list)