-
Notifications
You must be signed in to change notification settings - Fork 10
/
entropy_estimators.py
207 lines (164 loc) · 7.16 KB
/
entropy_estimators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Written by Greg Ver Steeg (http://www.isi.edu/~gregv/npeet.html)
import scipy.spatial as ss
from scipy.special import digamma
from math import log
import numpy.random as nr
import numpy as np
import random
# continuous estimators
def entropy(x, k=3, base=2):
    """
    The classic K-L k-nearest neighbor continuous entropy estimator.

    x should be a list of vectors, e.g. x = [[1.3],[3.7],[5.1],[2.4]]
    if x is a one-dimensional scalar and we have four samples.

    k is the rank of the neighbor whose max-norm distance is used for each
    sample (it must be at most len(x) - 1); base is the base of the logarithm
    the result is expressed in.

    Returns the entropy estimate as a float. Raises AssertionError when k is
    too large for the number of samples.
    """
    assert k <= len(x)-1, "Set k smaller than num. samples - 1"
    d = len(x[0])
    N = len(x)
    intens = 1e-10  # small noise to break degeneracy, see doc.
    x = [list(p + intens * nr.rand(d)) for p in x]
    tree = ss.cKDTree(x)
    # Query every sample at once. k+1 neighbors are requested because each
    # sample is returned as its own nearest neighbor; column k is therefore
    # the distance to the k-th other sample (p=inf means max-norm).
    nn = tree.query(x, k+1, p=float('inf'))[0][:, k]
    const = digamma(N)-digamma(k) + d*log(2)
    # np.mean needs a real sequence: on Python 3 map() yields an iterator,
    # which numpy cannot average.
    return (const + d*np.mean([log(r) for r in nn]))/log(base)
def mi(x, y, k=3, base=2):
    """
    Mutual information of x and y.

    x, y should be lists of vectors with the same number of samples,
    e.g. x = [[1.3],[3.7],[5.1],[2.4]] if x is a one-dimensional scalar
    and we have four samples. k is the neighbor rank used in the joint
    space and base is the logarithm base of the result.
    """
    assert len(x) == len(y), "Lists should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    jitter = 1e-10  # small noise to break degeneracy, see doc.
    x = [list(vec + jitter * nr.rand(len(x[0]))) for vec in x]
    y = [list(vec + jitter * nr.rand(len(y[0]))) for vec in y]
    joint = zip2(x, y)
    # Distance to the k-th neighbor in the joint space; p=inf means max-norm.
    max_norm = float('inf')
    joint_tree = ss.cKDTree(joint)
    radii = []
    for pt in joint:
        dists, _ = joint_tree.query(pt, k + 1, p=max_norm)
        radii.append(dists[k])
    # Average digamma of the marginal neighbor counts within those radii.
    psi_x = avgdigamma(x, radii)
    psi_y = avgdigamma(y, radii)
    psi_k = digamma(k)
    psi_n = digamma(len(x))
    return (-psi_x - psi_y + psi_k + psi_n) / log(base)
def cmi(x, y, z, k=3, base=2):
    """
    Mutual information of x and y, conditioned on z.

    x, y, z should be lists of vectors with the same number of samples,
    e.g. x = [[1.3],[3.7],[5.1],[2.4]] if x is a one-dimensional scalar
    and we have four samples. k is the neighbor rank used in the joint
    (x, y, z) space and base is the logarithm base of the result.

    Raises AssertionError when the three lists differ in length or when k
    is too large for the number of samples.
    """
    assert len(x) == len(y), "Lists should have same length"
    # z must line up sample-for-sample as well: zip() would otherwise drop
    # trailing samples silently when building the joint space.
    assert len(x) == len(z), "Lists should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    intens = 1e-10  # small noise to break degeneracy, see doc.
    x = [list(p + intens * nr.rand(len(x[0]))) for p in x]
    y = [list(p + intens * nr.rand(len(y[0]))) for p in y]
    z = [list(p + intens * nr.rand(len(z[0]))) for p in z]
    points = zip2(x, y, z)
    # Find nearest neighbors in joint space, p=inf means max-norm
    tree = ss.cKDTree(points)
    dvec = [tree.query(point, k+1, p=float('inf'))[0][k] for point in points]
    # Average digamma of neighbor counts in the (x,z), (y,z) and z subspaces.
    a = avgdigamma(zip2(x, z), dvec)
    b = avgdigamma(zip2(y, z), dvec)
    c = avgdigamma(z, dvec)
    d = digamma(k)
    return (-a-b+c+d)/log(base)
def kldiv(x, xp, k=3, base=2):
    """
    KL Divergence between p and q for x~p(x), xp~q(x).

    x, xp should be lists of vectors of the same dimension,
    e.g. x = [[1.3],[3.7],[5.1],[2.4]] if x is a one-dimensional scalar
    and we have four samples. k is the neighbor rank used in both sample
    sets and base is the logarithm base of the result.

    Raises AssertionError when k is too large for either sample set or when
    the two sets have different dimension. No noise is added here, so a
    zero neighbor distance (duplicate points) makes log() raise ValueError.
    """
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    assert k <= len(xp) - 1, "Set k smaller than num. samples - 1"
    assert len(x[0]) == len(xp[0]), "Two distributions must have same dim."
    d = len(x[0])
    n = len(x)
    m = len(xp)
    const = log(m) - log(n-1)
    tree = ss.cKDTree(x)
    treep = ss.cKDTree(xp)
    max_norm = float('inf')
    # Within x each point is its own nearest neighbor, hence k+1 and column k.
    nn = tree.query(x, k+1, p=max_norm)[0][:, k]
    # Against xp the k-th neighbor is column k-1. scipy squeezes the result
    # to one dimension when k == 1, so force a (n, k) shape before indexing.
    nnp = np.asarray(treep.query(x, k, p=max_norm)[0]).reshape(n, -1)[:, k-1]
    # Real lists are required: on Python 3 map() yields an iterator that
    # np.mean cannot average.
    mean_log_nnp = np.mean([log(r) for r in nnp])
    mean_log_nn = np.mean([log(r) for r in nn])
    return (const + d*mean_log_nnp - d*mean_log_nn)/log(base)
# Discrete estimators
def entropyd(sx, base=2):
    """
    Discrete entropy estimator.

    sx is a list of samples, which can be any hashable objects; base is the
    logarithm base of the returned entropy.
    """
    probs = hist(sx)
    return entropyfromprobs(probs, base=base)
def midd(x, y, base=2):
    """
    Discrete mutual information estimator.

    x and y are lists of samples, which can be any hashable objects; they are
    paired up position by position. base is the logarithm base of the result
    and is passed through to entropyd (default 2).

    Computed as H(x) + H(y) - H(x, y).
    """
    joint = list(zip(x, y))
    return -entropyd(joint, base=base)+entropyd(x, base=base)+entropyd(y, base=base)
def cmidd(x, y, z, base=2):
    """
    Discrete conditional mutual information estimator of x and y given z.

    x, y and z are lists of samples, which can be any hashable objects; they
    are combined position by position. base is the logarithm base of the
    result and is passed through to entropyd (default 2).

    Computed as H(y, z) + H(x, z) - H(x, y, z) - H(z).
    """
    h_yz = entropyd(list(zip(y, z)), base=base)
    h_xz = entropyd(list(zip(x, z)), base=base)
    h_xyz = entropyd(list(zip(x, y, z)), base=base)
    h_z = entropyd(z, base=base)
    return h_yz+h_xz-h_xyz-h_z
def hist(sx):
    """
    Histogram from list of samples.

    sx is a sized collection of hashable samples. Returns a list holding the
    relative frequency of each distinct sample, in first-seen order; an empty
    input gives an empty list.
    """
    counts = dict()
    for s in sx:
        counts[s] = counts.get(s, 0) + 1
    total = len(sx)
    # Return a concrete list: on Python 3 map() would hand back a one-shot
    # iterator that cannot be measured, indexed or consumed twice.
    return [float(c)/total for c in counts.values()]
def entropyfromprobs(probs, base=2):
    """
    Turn a normalized collection of probabilities of discrete outcomes into
    an entropy, expressed in the given logarithm base.
    """
    total = sum(elog(p) for p in probs)
    return -total / log(base)
def elog(x):
    """
    Return x*log(x) for use in entropy sums.

    For entropy, 0 log 0 = 0, but log(0) raises an error, so any x at or
    below 0 or at or above 1 yields 0 instead of being passed to log.
    """
    outside_open_unit_interval = x <= 0. or x >= 1.
    return 0 if outside_open_unit_interval else x * log(x)
# Mixed estimators
def micd(x, y, k=3, base=2, warning=True):
    """
    If x is continuous and y is discrete, compute mutual information.

    The estimate is the entropy of x minus the entropy of x restricted to each
    distinct value of y, weighted by that value's relative frequency. When a
    value of y has too few samples for the k-nearest-neighbor estimator, the
    overall entropy of x is used for it instead and, if warning is true, a
    message is printed.
    """
    total_entropy = entropy(x, k, base)
    n = len(y)
    # Relative frequency of each discrete value, accumulated sample by sample.
    weight = dict()
    for label in y:
        weight[label] = weight.get(label, 0) + 1. / n
    result = total_entropy
    for label in set(weight.keys()):
        subset = [x[idx] for idx in range(n) if y[idx] == label]
        if k <= len(subset) - 1:
            conditional = entropy(subset, k, base)
        else:
            if warning:
                print("Warning, after conditioning, on y={0} insufficient data. Assuming maximal entropy in this case.".format(label))
            conditional = total_entropy
        result -= weight[label] * conditional
    return result  # units already applied
# Utility functions
def vectorize(scalarlist):
    """
    Wrap every scalar of the input in its own one-element tuple, giving a
    list of one-d vectors.
    """
    # zip over a single iterable yields exactly one 1-tuple per element.
    return list(zip(scalarlist))
def shuffle_test(measure, x, y, z=False, ns=200, ci=0.95, **kwargs):
    """
    Shuffle test.

    Shuffles a copy of the x-values 'ns' times and evaluates
    measure(x, y, [z], **kwargs) after each shuffle; z is passed only when it
    is truthy. 'measure' could be mi or cmi, for example.

    Returns the mean of the results and a (lower, upper) pair taken from the
    sorted results at the positions matching the 'ci' confidence level
    (0.95 by default). Mutual information and CMI should have a mean near
    zero.
    """
    shuffled = x[:]  # work on a copy so the caller's x keeps its order
    results = []
    for _ in range(ns):
        random.shuffle(shuffled)
        call_args = (shuffled, y, z) if z else (shuffled, y)
        results.append(measure(*call_args, **kwargs))
    results.sort()
    lower = results[int((1.-ci)/2*ns)]
    upper = results[int((1.+ci)/2*ns)]
    return np.mean(results), (lower, upper)
# Internal functions
def avgdigamma(points, dvec):
    """
    Average digamma of the neighbor counts around each point.

    For the i-th point, counts the points lying within max-norm distance
    dvec[i] of it (strictly inside, and including the point itself), and
    returns the mean of digamma(count) over all points, i.e. <psi(nx)>.
    """
    N = len(points)
    tree = ss.cKDTree(points)
    max_norm = float('inf')

    def neighbor_count(idx, center):
        # Subtlety: the radius is shrunk by 1e-15 so the boundary point is not
        # included, but the center point is, which implicitly adds 1 to the
        # Kraskov definition.
        radius = dvec[idx] - 1e-15
        return len(tree.query_ball_point(center, radius, p=max_norm))

    return sum(
        (digamma(neighbor_count(idx, points[idx])) / N for idx in range(N)),
        0.,
    )
def zip2(*args):
    """
    Join several lists of vectors into one list of vectors in the joint space.

    E.g. zip2([[1],[2],[3]],[[4],[5],[6]]) = [[1,4],[2,5],[3,6]]
    """
    joint = []
    for group in zip(*args):
        # Concatenate this sample's vectors left to right, starting from [].
        merged = []
        for vec in group:
            merged = merged + vec
        joint.append(merged)
    return joint