IB.py (forked from djstrouse/information-bottleneck)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.stats import multivariate_normal as mvn
import time
import math
import pickle
import copy
vlog = np.vectorize(math.log)
vexp = np.vectorize(math.exp)
# A word on notation: for probability variables, an underscore here denotes
# conditioning, so read _ as |.
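# e.g. py_x stands for p(y|x), and qt_x stands for q(t|x).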
# todos:
# allow refine_beta parameters to be set by model.fit()
# allow global setting of data types, including sparse arrays
# questions:
# run higher s for uniform smoothing on circlepluscigar
def entropy_term(x):
"""Helper function for entropy_single: calculates one term in the sum."""
if x==0: return 0.0
else: return -x*math.log2(x)
def entropy_single(p):
"""Returns entropy of p: H(p)=-sum(p*log(p)). (in bits)"""
ventropy_term = np.vectorize(entropy_term)
return np.sum(ventropy_term(p))
def entropy(P):
"""Returns entropy of a distribution, or series of distributions.
For the input array P [=] M x N, treats each col as a prob distribution
(over M elements), and thus returns N entropies. If P is a vector, treats
P as a single distribution and returns its entropy."""
if P.ndim==1: return entropy_single(P)
else:
M,N = P.shape
H = np.zeros(N)
for n in range(N):
H[n] = entropy_single(P[:,n])
return H
def kl_term(x,y):
"""Helper function for kl: calculates one term in the sum."""
if x>0 and y>0: return x*math.log2(x/y)
elif x==0: return 0.0
else: return math.inf
def kl_single(p,q):
"""Returns KL divergence of p and q: KL(p,q)=sum(p*log(p/q)). (in bits)"""
vkl_term = np.vectorize(kl_term)
return np.sum(vkl_term(p,q))
def kl(P,Q):
"""Returns KL divergence of one or more pairs of distributions.
For the input arrays P [=] M x N and Q [=] M x L, calculates KL of each col
of P with each col of Q, yielding the KL matrix DKL [=] N x L. If P and Q
are both vectors, returns a single KL divergence."""
if P.ndim==1 and Q.ndim==1: return kl_single(P,Q)
elif P.ndim==1 and Q.ndim!=1: # handle vector P case
M = len(P)
N = 1
M2,L = Q.shape
if M!=M2: raise ValueError("P and Q must have same number of rows")
DKL = np.zeros((1,L))
for l in range(L):
DKL[0,l] = kl_single(P,Q[:,l])
elif P.ndim!=1 and Q.ndim==1: # handle vector Q case
M,N = P.shape
M2 = len(Q)
L = 1
if M!=M2: raise ValueError("P and Q must have same number of rows")
DKL = np.zeros((N,1))
for n in range(N):
DKL[n,0] = kl_single(P[:,n],Q)
else:
M,N = P.shape
M2,L = Q.shape
if M!=M2: raise ValueError("P and Q must have same number of rows")
DKL = np.zeros((N,L))
for n in range(N):
for l in range(L):
DKL[n,l] = kl_single(P[:,n],Q[:,l])
return DKL
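# Illustrative use of the helpers above (a minimal sketch; the distributions
# below are made up for demonstration):
#   p = np.array([.5,.5]); q = np.array([.9,.1])
#   entropy(p)                        # -> 1.0 bit
#   kl(p,q)                           # -> ~0.74 bits
#   P = np.array([[.5,.2],[.5,.8]])   # two distributions, one per column
#   entropy(P)                        # -> array([1.0, ~0.72])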
class dataset:
"""A representation / interface for an IB dataset, primarily consisting of
the joint distribution p(x,y) and its marginals, conditionals, and stats.
Also includes functionality for accepting data point coordinates and applying
smoothing to yield an IB-appropriate p(x,y)."""
def __init__(self,
pxy=None,
coord=None, # 2d coordinates for x
labels=None, # labels y
gen_param=None, # generative parameters
name=None,
smoothing_type='uniform',
smoothing_center='data_point',
s=None, # smoothing scale
d=None, # neighborhood size
dt=None): # data type
if dt is None: self.dt = np.float32
else: self.dt = dt
if pxy is not None:
if not(isinstance(pxy,np.ndarray)):
raise ValueError('pxy must be a numpy array')
if np.any(pxy<0) or np.any(pxy>1):
raise ValueError('entries of pxy must be between 0 and 1')
if abs(np.sum(pxy)-1)>10**-8:
raise ValueError('pxy must be normalized; sum = %f' % np.sum(pxy))
pxy = pxy.astype(self.dt)
self.pxy = pxy # the distribution that (D)IB acts upon
if coord is not None:
if not(isinstance(coord,np.ndarray)): raise ValueError('coord must be a numpy array')
else: coord = coord.astype(self.dt)
self.coord = coord # locations of data points if geometric, assumed 2D
if labels is not None and len(labels)!=coord.shape[0]:
raise ValueError('number of labels must match number of rows in coord')
self.labels = labels # class labels of data (if synthetic)
if smoothing_type in ['u','uniform']: self.smoothing_type = 'uniform'
elif smoothing_type in ['t','topological']: self.smoothing_type = 'topological'
elif smoothing_type in ['m','metric']: self.smoothing_type = 'metric'
elif smoothing_type is not None: raise ValueError('invalid smoothing_type')
else: self.smoothing_type = None
if smoothing_center in ['d','data_point']: self.smoothing_center = 'data_point'
elif smoothing_center in ['m','neighborhood_mean']: self.smoothing_center = 'neighborhood_mean'
elif smoothing_center in ['b','blended']: self.smoothing_center = 'blended'
elif smoothing_center is not None: raise ValueError('invalid smoothing_center')
else: self.smoothing_center = None
if s is not None and s<0:
raise ValueError('s must be a positive scalar')
self.s = s # determines width of gaussian smoothing for coord->pxy
if d is not None and d<0:
raise ValueError('d must be a positive scalar')
self.d = d # determines neighborhood of a data point in gaussian smoothing for coord->pxy
if gen_param is not None:
if not(isinstance(gen_param,dict)):
raise ValueError('gen_param must be a dictionary')
self.gen_param = gen_param # generative parameters of data (if synthetic)
self.name = name # name of dataset, used for saving
if self.pxy is not None:
self.process_pxy()
elif self.coord is not None:
self.X = self.coord.shape[0]
if self.pxy is None and self.coord is not None and self.s is not None:
self.coord_to_pxy()
def __str__(self):
return(self.name)
def process_pxy(self,drop_zeros=True):
"""Drops unused x and y, and calculates info-theoretic stats of pxy."""
Xorig, Yorig = self.pxy.shape
px = self.pxy.sum(axis=1)
py = self.pxy.sum(axis=0)
if drop_zeros:
nzx = px>0 # find nonzero-prob entries
nzy = py>0
zx = np.where(px<=0)[0]
zy = np.where(py<=0)[0]
self.px = px[nzx] # drop zero-prob entries
self.py = py[nzy]
if hasattr(self,'Ygrid'): self.Ygrid = self.Ygrid[nzy] # Ygrid only exists when pxy was built from coord
pxy_orig = self.pxy
tmp = pxy_orig[nzx,:]
self.pxy = tmp[:,nzy] # pxy_orig with zero-prob x,y removed
else:
self.px = px
self.py = py
self.X = len(self.px)
self.Y = len(self.py)
if (Xorig-self.X)>0:
print('%i of %i Xs dropped due to zero prob; size now %i. Dropped IDs:' % (Xorig-self.X,Xorig,self.X))
print(zx)
if (Yorig-self.Y)>0:
print('%i of %i Ys dropped due to zero prob; size now %i. Dropped IDs:' % (Yorig-self.Y,Yorig,self.Y))
print(zy)
self.py_x = np.multiply(self.pxy.T,np.tile(1./self.px,(self.Y,1)))
self.hx = entropy(self.px)
self.hy = entropy(self.py)
self.hy_x = np.dot(self.px,entropy(self.py_x))
self.ixy = self.hy-self.hy_x
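# Note: the stats above use the standard identities
#   H(Y|X) = sum_x p(x) * H( p(y|x) )   and   I(X;Y) = H(Y) - H(Y|X),
# all measured in bits.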
def normalize_coord(self):
desired_r = 20
min_x1 = np.min(self.coord[:,0])
min_x2 = np.min(self.coord[:,1])
max_x1 = np.max(self.coord[:,0])
max_x2 = np.max(self.coord[:,1])
range_x1 = max_x1-min_x1
range_x2 = max_x2-min_x2
r = (range_x1+range_x2)/2
# zero-mean
self.coord = self.coord - np.mean(self.coord,axis=0)
# scale
self.coord = desired_r*self.coord/r
def make_bins(self,total_bins=2500,pad=None):
"""Compute appropriate spatial bins."""
if pad is None: pad = 2*self.s # bins further than this from all data points are dropped
self.pad = pad
# dimensional preprocessing
min_x1 = np.min(self.coord[:,0])
min_x2 = np.min(self.coord[:,1])
max_x1 = np.max(self.coord[:,0])
max_x2 = np.max(self.coord[:,1])
range_x1 = max_x1-min_x1
range_x2 = max_x2-min_x2
bins1 = int(math.sqrt(total_bins*range_x1/range_x2)) # divvy up bins according to spread of data
bins2 = int(math.sqrt(total_bins*range_x2/range_x1))
Y = int(bins1*bins2)
# generate bins
min_y1 = min_x1-pad
max_y1 = max_x1+pad
min_y2 = min_x2-pad
max_y2 = max_x2+pad
y1 = np.linspace(min_y1,max_y1,bins1,dtype=self.dt)
y2 = np.linspace(min_y2,max_y2,bins2,dtype=self.dt)
y1v,y2v = np.meshgrid(y2,y1)
Ygrid = np.array([np.reshape(y1v,Y),np.reshape(y2v,Y)]).T
return Y,bins1,bins2,y1v,y2v,Ygrid
def coord_to_pxy(self,total_bins=2500,pad=None,drop_distant=True,
drop_zeros=True,make_smoothed_coord_density=True):
"""Uses smoothing paramters to transform coord into pxy."""
# assumes 2D coord, total_bins is approximate
if self.smoothing_type is None: raise ValueError('smoothing_type not yet set')
if self.s is None: raise ValueError('smoothing scale, s, not yet set')
if self.smoothing_type=='uniform':
print('Smoothing coordinates: smoothing_type = uniform, scale s = %.2f' % self.s)
else:
if self.smoothing_center is None: raise ValueError('smoothing_center not yet set')
if self.d is None: raise ValueError('neighborhood size, d, not yet set')
if self.smoothing_type=='topological':
print('Smoothing coordinates: smoothing_type = topological, smoothing_center = %s, scale s = %.2f, neighborhood size d = %i' % (self.smoothing_center,self.s,self.d))
elif self.smoothing_type=='metric':
print('Smoothing coordinates: smoothing_type = metric, smoothing_center = %s, scale s = %.2f, neighborhood size d = %.1f' % (self.smoothing_center,self.s,self.d))
else: raise ValueError('invalid smoothing_type')
# compute appropriate spatial bins
Y,bins1,bins2,y1v,y2v,Ygrid = self.make_bins(total_bins=total_bins,pad=pad)
# construct gaussian-smoothed p(y|x), based on smoothing parameters
py_x = np.zeros((Y,self.X),dtype=self.dt)
smoothed_coord_density = np.zeros(y1v.shape,dtype=self.dt)
ycountv = np.zeros(Y) # counts data points within pad of each bin
ycount = np.zeros(y1v.shape)
for x in range(self.X):
# construct gaussian covariance
if self.smoothing_type in ["u","uniform"]:
S = (self.s**2)*np.eye(2)
mu = self.coord[x,:]
else: # nearest-neighbor, locally estimate covariance cases
distances = np.array([np.linalg.norm(self.coord[x,:]-self.coord[x2,:]) for x2 in range(self.X)])
# find neighbors
if self.smoothing_type in ["t","topological"]:
neighbors = list(np.argpartition(distances, self.d)[0:self.d])
elif self.smoothing_type in ["m","metric"]:
neighbors = list(np.where(distances<self.d)[0])
# if sufficiently many neighbors...
N = len(neighbors)
if N>=3:
# calculate S centered on data point...
if self.smoothing_center in ['d','data_point']: mu = self.coord[x,:]
# ...or on neighborhood mean (where hood includes data point)
else:
neighbors.append(x)
mu = np.mean(self.coord[neighbors,:],axis=0)
S = np.zeros((2,2))
for n in neighbors: S += np.outer(self.coord[n,:]-mu,self.coord[n,:]-mu)
S *= 1/N
# scale so that s determines volume, and S only determines shape
det = np.linalg.det(S)
S *= (self.s**2)/math.sqrt(det)
# if not enough neighbors, just use uniform
else: S = (self.s**2)*np.eye(2)
# for blended approach, switch definition of mu
if self.smoothing_center in ['b','blended']: mu = self.coord[x,:]
# in other two cases, mu remains as above
# use covariance to smooth data
rv = mvn(mu,S)
y = 0
py_x[:,x] = rv.pdf(Ygrid).astype(self.dt)
if make_smoothed_coord_density:
for y1 in range(y1v.shape[0]):
for y2 in range(y1v.shape[1]):
#py_x[y,x] = rv.pdf(Ygrid[y,:]).astype(self.dt)
smoothed_coord_density[y1,y2] += rv.pdf([y1v[y1,y2],y2v[y1,y2]]).astype(self.dt)
if drop_distant and np.linalg.norm(self.coord[x,:]-Ygrid[y,:])<self.pad:
ycountv[y] += 1
ycount[y1,y2] += 1
y += 1
if drop_distant:
# drop ybins that are too far away from data
ymask = ycountv>0
py_x = py_x[ymask,:]
print("Dropped %i ybins. Y reduced from %i to %i." % (Y-np.sum(ymask),Y,np.sum(ymask)))
Y = np.sum(ymask)
Ygrid = Ygrid[ymask,:]
self.bins_dropped = ycount==0
else: self.bins_dropped = None
self.Y = Y
self.Ygrid = Ygrid
# normalize p(y|x), since gaussian binned/truncated and bins dropped
for x in range(self.X): py_x[:,x] = py_x[:,x]/np.sum(py_x[:,x])
self.py_x = py_x
# package stuff for plotting smoothed density in coord space
if make_smoothed_coord_density: self.smoothed_coord_density = smoothed_coord_density/np.sum(smoothed_coord_density[:])
self.y1v = y1v
self.y2v = y2v
# construct p(x) and p(x,y)
self.px = (1/self.X)*np.ones(self.X,dtype=self.dt)
self.pxy = np.multiply(np.tile(self.px,(self.Y,1)),self.py_x).T
# calc and display I(x,y)
self.process_pxy(drop_zeros=drop_zeros)
print("I(X;Y) = %.3f" % self.ixy)
def plot_coord(self,save=False,path=None):
if self.coord is not None:
fig = plt.figure()
plt.scatter(self.coord[:,0],self.coord[:,1])
plt.axis('scaled')
plt.show()
if save:
if path is None: raise ValueError('must specify path to save figure')
else: fig.savefig(path+self.name+'_coord.pdf',bbox_inches='tight')
else:
print("coord not yet defined")
def plot_smoothed_coord(self,save=False,path=None):
fig = plt.figure()
plt.title('s = %i' % self.s,fontsize=18,fontweight='bold')
plt.contour(self.y1v,self.y2v,self.smoothed_coord_density)
plt.scatter(self.coord[:,0],self.coord[:,1])
#plt.axis('scaled')
plt.axis([-22,22,-15,15])
plt.show()
if save:
if path is None: raise ValueError('must specify path to save figure')
else: fig.savefig(path+self.name+'_smoothed_coord_s%i'%self.s+'.pdf',bbox_inches='tight')
def plot_pxy(self,save=False,path=None):
fig = plt.figure()
if self.pxy is not None:
if self.s==2:
plt.xlabel('Y',fontsize=14,fontweight='bold')
plt.ylabel('X',fontsize=14,fontweight='bold')
plt.contourf(self.pxy)
plt.show()
if save:
if path is None: raise ValueError('must specify path to save figure')
else: fig.savefig(path+self.name+'_pxy_s%i'%self.s+'.pdf',bbox_inches='tight')
else:
print("pxy not yet defined")
def save(self,directory,filename=None):
"""Pickles dataset in directory with filename."""
if filename is None: filename = self.name+'_dataset'
with open(directory+filename+'.pkl', 'wb') as output:
pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
def load(self,directory,filename=None):
"""Replaces current content with pickled data in directory with filename."""
if filename is None: filename = self.name+'_dataset'
with open(directory+filename+'.pkl', 'rb') as input:
obj = pickle.load(input)
self.__init__(pxy = obj.pxy, coord = obj.coord, labels = obj.labels,
gen_param = obj.gen_param, name = obj.name, s = obj.s)
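# Illustrative construction from a joint distribution (a minimal sketch; the
# pxy below is made up):
#   pxy = np.ones((4,4))/16.   # uniform joint over a 4x4 alphabet
#   ds = dataset(pxy=pxy,name='uniform4x4')
#   print(ds.ixy)              # -> 0.0, since X and Y are independent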
class model:
"""A representation / interface for an IB model, primarily consisting of
the encoder / clustering q(t|x) and its associated distributions.
Functions of main interest to users are (in order) __init__, fit,
report_metrics, report_param, and possibly clamp. The rest are primarily
helper functions that won't be called directly by the user."""
def __init__(self,ds,alpha,beta,Tmax=None,qt_x=None,p0=None,waviness=None,
ctol_abs=10**-4,ctol_rel=0.,cthresh=1,ptol=10**-8,zeroLtol=0,
geoapprox=False,step=None,dt=None,quiet=False):
"""ds is a dataset object (see dataset class above). alpha and beta are
IB parameters that appear in the generalized cost functional (see
Strouse & Schwab 2016). Tmax is the maximum number of clusters allowed,
i.e. the maximum cardinality of T. qt_x is the initialization of the
encoder. If not provided, qt_x will be initialized randomly based on
p0 and waviness (see init_qt_x below for details). ctol_abs, ctol_rel,
and cthresh are convergence tolerances; see check_converged below for
details. ptol is the threshold for considering a probability to be zero;
clusters with probability mass below ptol are pruned. zeroLtol governs
how aggressively converged solutions are replaced with the single-cluster
solution; if the single-cluster solution lowers L by more than zeroLtol, it
replaces the converged solution (see check_single_better below for details).
geoapprox determines whether an approximation to the IB update is used;
it applies only to geometric datasets where coord is available. quiet is a flag that
suppresses some output."""
if not(isinstance(ds,dataset)):
raise ValueError('ds must be a dataset')
self.ds = ds # dataset
if dt is None: self.dt = np.float32
else: self.dt = dt
if alpha<0: raise ValueError('alpha must be a non-negative scalar')
self.alpha = alpha
if not(beta>0): raise ValueError('beta must be a positive scalar')
self.beta = beta
if Tmax is None:
Tmax = ds.X
if not(quiet): print('Tmax set to %i based on X' % Tmax)
elif Tmax<1 or Tmax!=int(Tmax):
raise ValueError('Tmax must be a positive integer')
elif Tmax>ds.X:
print('Reduced Tmax from %i to %i based on X' % (Tmax,ds.X))
Tmax = ds.X
else: Tmax = int(Tmax)
self.Tmax = Tmax
self.T = Tmax
if ctol_rel==0 and ctol_abs==0:
raise ValueError('One of ctol_rel and ctol_abs must be positive')
if ctol_abs<0 or not(isinstance(ctol_abs,float)):
raise ValueError('ctol_abs must be a non-negative float')
self.ctol_abs = ctol_abs
if ctol_rel<0 or not(isinstance(ctol_rel,float)):
raise ValueError('ctol_rel must be a non-negative float')
self.ctol_rel = ctol_rel
if cthresh<1 or cthresh!=int(cthresh):
raise ValueError('cthresh must be a positive integer')
self.cthresh = cthresh
if not(ptol>0) or not(isinstance(ptol,float)):
raise ValueError('ptol must be a positive float')
self.ptol = ptol
if zeroLtol<0:
raise ValueError('zeroLtol must be a non-negative float or integer')
self.zeroLtol = zeroLtol
self.geoapprox = geoapprox
self.quiet = quiet
self.using_labels = False
self.clamped = False
self.conv_time = None
self.conv_condition = None
self.merged = False
if step is None: self.step = 0
if p0 is None:
if alpha==0: p0 = 1. # DIB default: deterministic init that spreads points evenly across clusters
else: p0 = .75 # non-DIB default: DIB-like init but with only 75% prob mass on "assigned" cluster
elif p0<-1 or p0>1 or not(isinstance(p0,(int,float))):
raise ValueError('p0 must be a float/int between -1 and 1')
else: p0 = float(p0)
self.p0 = p0
if waviness is not None and (waviness<0 or waviness>1 or not(isinstance(waviness,float))):
raise ValueError('waviness must be a float between 0 and 1')
self.waviness = waviness
start_time = time.time()
if qt_x is not None: # use initialization if provided
if not(isinstance(qt_x,np.ndarray)):
raise ValueError('qt_x must be a numpy array')
if isinstance(qt_x,np.ndarray):
if np.any(qt_x<0) or np.any(qt_x>1):
raise ValueError('entries of qt_x must be between 0 and 1')
if qt_x.shape[0]==1: # if single cluster
if np.any(qt_x!=1): raise ValueError('columns of qt_x must be normalized')
elif np.any(abs(np.sum(qt_x,axis=0)-1)>ptol): # if multi-cluster
raise ValueError('columns of qt_x must be normalized')
self.qt_x = qt_x.astype(self.dt)
self.T = qt_x.shape[0]
else: # initialize randomly if not
self.init_qt_x()
self.make_step(init=True)
self.step_time = time.time()-start_time
if not(self.quiet): print('step %i: ' % self.step + self.report_metrics())
def use_labels(self):
"""Uses labels to give model 'true' encoder q(t|x)."""
if self.ds.labels is None: raise ValueError('dataset does not have labels')
start_time = time.time()
self.using_labels = True
self.clamped = False
self.conv_time = None
self.conv_condition = None
self.merged = False
self.step = 0
label_alphabet = np.unique(self.ds.labels)
T = len(label_alphabet)
self.Tmax = T
self.T = T
qt_x = np.zeros((T,self.ds.X),dtype=self.dt)
for x in range(self.ds.X):
tstar = np.where(label_alphabet==self.ds.labels[x])[0][0]
qt_x[tstar,x] = 1
self.qt_x = qt_x.astype(self.dt)
self.make_step(init=True)
self.step_time = time.time()-start_time
if not(self.quiet): print('true init: ' + self.report_metrics())
def init_qt_x(self):
"""Initializes q(t|x) for generalized Information Bottleneck.
For p0 = 0: init is random noise. If waviness = None, normalized uniform
random vector. Otherwise, uniform over clusters +- uniform noise of
magnitude waviness.
For p0 positive: attempt to spread points as evenly across clusters as
possible. Prob mass p0 is given to the "assigned" clusters, and the
remaining 1-p0 prob mass is randomly assigned. If waviness = None, again
use a normalized random vector to assign the remaining mass. Otherwise,
uniform +- waviness again.
For p0 negative: just as above, except that all data points are "assigned"
to the same cluster (well, at least |p0| of their prob mass)."""
if self.p0==0: # don't insert any peaks; init is all "noise"
if self.waviness: # flat + wavy style noise
self.qt_x = np.ones((self.T,self.ds.X))+2*(np.random.rand(self.T,self.ds.X)-.5)*self.waviness # 1+-waviness%
for i in range(self.ds.X):
self.qt_x[:,i] = self.qt_x[:,i]/np.sum(self.qt_x[:,i]) # normalize
else: # uniform random vector
self.qt_x = np.random.rand(self.T,self.ds.X)
self.qt_x = np.multiply(self.qt_x,np.tile(1./np.sum(self.qt_x,axis=0),(self.T,1))) # renormalize
elif self.p0>0: # spread points evenly across clusters; "assigned" clusters for each data point get prob mass p0
if self.waviness:
# insert wavy noise part
self.qt_x = np.ones((self.T,self.ds.X))+2*(np.random.rand(self.T,self.ds.X)-.5)*self.waviness # 1+-waviness%
# choose clusters for each x to get spikes
n = math.ceil(float(self.ds.X)/float(self.T)) # approx number points per cluster
I = np.repeat(np.arange(0,self.T),n).astype("int") # data-to-cluster assignment vector
np.random.shuffle(I)
for i in range(self.ds.X):
self.qt_x[I[i],i] = 0 # zero out that cluster
self.qt_x[:,i] = (1-self.p0)*self.qt_x[:,i]/np.sum(self.qt_x[:,i]) # normalize others to 1-p0
self.qt_x[I[i],i] = self.p0 # insert p0 spike
else: # uniform random vector instead of wavy
self.qt_x = np.zeros((self.T,self.ds.X))
# choose clusters for each x to get spikes
n = math.ceil(float(self.ds.X)/float(self.T)) # approx number points per cluster
I = np.repeat(np.arange(0,self.T),n).astype("int") # data-to-cluster assignment vector
np.random.shuffle(I)
for i in range(self.ds.X):
u = np.random.rand(self.T)
u[I[i]] = 0
u = (1-self.p0)*u/np.sum(u)
u[I[i]] = self.p0
self.qt_x[:,i] = u
else: # put all points in the same cluster; primary cluster gets prob mass |p0|
p0 = -self.p0
if self.waviness:
self.qt_x = np.ones((self.T,self.ds.X))+2*(np.random.rand(self.T,self.ds.X)-.5)*self.waviness # 1+-waviness%
t = np.random.randint(self.T) # pick cluster to get delta spike
self.qt_x[t,:] = np.zeros((1,self.ds.X)) # zero out that cluster
self.qt_x = np.multiply(self.qt_x,np.tile(1./np.sum(self.qt_x,axis=0),(self.T,1))) # normalize the rest...
self.qt_x = (1-p0)*self.qt_x # ...to 1-p0
self.qt_x[t,:] = p0*np.ones((1,self.ds.X)) # put in delta spike
else: # uniform random vector instead of wavy
self.qt_x = np.zeros((self.T,self.ds.X))
# choose clusters for each x to get spikes
t = np.random.randint(self.T) # pick cluster to get delta spike
for i in range(self.ds.X):
u = np.random.rand(self.T)
u[t] = 0
u = (1-p0)*u/np.sum(u)
u[t] = p0
self.qt_x[:,i] = u
if self.qt_x.dtype != self.dt: self.qt_x = self.qt_x.astype(self.dt)
def qt_step(self):
"""Peforms q(t) update step for generalized Information Bottleneck."""
self.qt = np.dot(self.qt_x,self.ds.px).astype(self.dt)
dropped = self.qt<=self.ptol # clusters to drop due to near-zero prob
if any(dropped):
self.qt = self.qt[~dropped] # drop unused clusters
self.qt_x = self.qt_x[~dropped,:]
self.T = len(self.qt) # update number of clusters
self.qt_x = np.multiply(self.qt_x,np.tile(1./np.sum(self.qt_x,axis=0),(self.T,1))) # renormalize
self.qt = np.dot(self.qt_x,self.ds.px).astype(self.dt)
if not(self.quiet): print('%i cluster(s) dropped. Down to %i cluster(s).' % (np.sum(dropped),self.T))
def qy_t_step(self):
"""Peforms q(y|t) update step for generalized Information Bottleneck."""
self.qy_t = np.dot(self.ds.py_x,np.multiply(self.qt_x,np.outer(1./self.qt,self.ds.px)).T)
if self.qy_t.dtype != self.dt: self.qy_t = self.qy_x.astype(self.dt)
def query_coord(self,x,ptol=0):
"""Returns cluster assignment for new data point not in training set."""
# currently assumes uniform smoothing; needs to be extended to the nearest-neighbor smoothing types
if self.alpha!=0: raise ValueError('only implemented for DIB (alpha=0)')
if self.T==1: return 0
else:
# which y correspond to which spatial locations? Ygrid tells us!
py_x = mvn.pdf(self.ds.Ygrid,mean=x,cov=(self.ds.s**2)*np.eye(2))
ymask = py_x>ptol
perc_dropped = 100*(1-np.mean(ymask))
l = vlog(self.qt)-self.beta*kl(py_x[ymask],self.qy_t[ymask,:])
return np.argmax(l),perc_dropped
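# Illustrative use (a sketch; assumes a fitted DIB model m whose dataset was
# built with uniform smoothing, and the query point below is made up):
#   t, perc_dropped = m.query_coord(np.array([0.,0.]))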
def qt_x_step(self):
"""Peforms q(t|x) update step for generalized Information Bottleneck."""
if self.T==1: self.qt_x = np.ones((1,self.X),dtype=self.dt)
else:
self.qt_x = np.zeros((self.T,self.ds.X),dtype=self.dt)
for x in range(self.ds.X):
l = vlog(self.qt)-self.beta*kl(self.ds.py_x[:,x],self.qy_t) # [=] T x 1 # scales like X*Y*T
if self.alpha==0: self.qt_x[np.argmax(l),x] = 1
else: self.qt_x[:,x] = vexp(l/self.alpha)/np.sum(vexp(l/self.alpha)) # note: l/alpha<-745 is where underflow creeps in
if self.qt_x.dtype != self.dt: self.qt_x = self.qt_x.astype(self.dt)
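# Note: the update above implements the generalized (D)IB self-consistent
# equation (cf. Strouse & Schwab): for alpha > 0,
#   q(t|x)  propto  exp( (1/alpha) * ( log q(t) - beta * KL( p(y|x) || q(y|t) ) ) )
# and for alpha = 0 (DIB), all probability mass goes to the argmax over t of
# the bracketed quantity.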
def build_dist_mat(self):
"""Replaces the qy_t_step whens using geoapprox."""
self.Dxt = np.zeros((self.ds.X,self.T))
for x in range(self.ds.X):
for t in range(self.T):
for otherx in np.nditer(np.nonzero(self.qt_x[t,:])): # only iterate over x with nonzero involvement
self.Dxt[x,t] += self.qt_x[t,otherx]*np.linalg.norm(self.ds.coord[x,:]-self.ds.coord[otherx,:])**2
self.Dxt[x,t] *= 1/(self.ds.X*self.qt[t])
def qt_x_step_geoapprox(self):
"""Peforms q(t|x) update step for approximate generalized Information
Bottleneck, an algorithm for geometric clustering."""
if self.T==1: self.qt_x = np.ones((1,self.ds.X),dtype=self.dt)
else:
self.qt_x = np.zeros((self.T,self.ds.X),dtype=self.dt)
for x in range(self.ds.X):
l = vlog(self.qt)-(self.beta/(2*self.ds.s**2))*self.Dxt[x,:] # only substantive difference from qt_x_step
if self.alpha==0: self.qt_x[np.argmax(l),x] = 1
else: self.qt_x[:,x] = vexp(l/self.alpha)/np.sum(vexp(l/self.alpha)) # note: l/alpha<-745 is where underflow creeps in
if self.qt_x.dtype != self.dt: self.qt_x = self.qt_x.astype(self.dt)
def calc_metrics(self):
"""Calculates IB performance metrics.."""
self.ht = entropy(self.qt)
self.hy_t = np.dot(self.qt,entropy(self.qy_t))
self.iyt = self.ds.hy-self.hy_t
self.ht_x = np.dot(self.ds.px,entropy(self.qt_x))
self.ixt = self.ht-self.ht_x
self.L = self.ht-self.alpha*self.ht_x-self.beta*self.iyt
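# Note: L above is the generalized IB functional
#   L = H(T) - alpha*H(T|X) - beta*I(Y;T)
# (Strouse & Schwab 2016); alpha = 1 recovers the original IB objective and
# alpha = 0 the deterministic IB (DIB) objective.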
def report_metrics(self):
"""Returns string of model metrics."""
self.calc_metrics()
return 'I(X,T) = %.3f, H(T) = %.3f, T = %i, H(X) = %.3f, I(Y,T) = %.3f, I(X,Y) = %.3f, L = %.3f' % (self.ixt,self.ht,self.T,self.ds.hx,self.iyt,self.ds.ixy,self.L)
def report_param(self):
"""Returns string of model parameters."""
if self.p0 is None or self.qt_x is not None: p0_str = 'None'
else: p0_str = '%.3f' % self.p0
if self.waviness is None or self.qt_x is not None: waviness_str = 'None'
else: waviness_str = '%.2f' % self.waviness
if self.ds.smoothing_type is None: smoothing_type_str = 'None'
else: smoothing_type_str = self.ds.smoothing_type
if self.ds.smoothing_center is None: smoothing_center_str = 'None'
else: smoothing_center_str = self.ds.smoothing_center
if self.ds.s is None: s_str = 'None'
else: s_str = '%.2f' % self.ds.s
if self.ds.d is None: d_str = 'None'
elif self.ds.d==int(self.ds.d): d_str = '%i' % self.ds.d
else: d_str = '%.1f' % self.ds.d
return 'alpha = %.2f, beta = %.1f, Tmax = %i, p0 = %s, wav = %s, geo = %s,\nctol_abs = %.0e, ctol_rel = %.0e, cthresh = %i, ptol = %.0e, zeroLtol = %.0e\nsmoothing_type = %s, smoothing_center = %s, s = %s, d = %s' %\
(self.alpha, self.beta, self.Tmax, p0_str, waviness_str, self.geoapprox,
self.ctol_abs, self.ctol_rel, self.cthresh, self.ptol, self.zeroLtol,
smoothing_type_str, smoothing_center_str, s_str, d_str)
def make_step(self,init=False):
"""Performs one IB step."""
if not(init):
start_time = time.time()
if self.geoapprox: self.qt_x_step_geoapprox()
else: self.qt_x_step()
self.qt_step()
self.qy_t_step()
if self.geoapprox: self.build_dist_mat()
else: self.Dxt = None
self.calc_metrics()
self.step += 1
self.merged = False
if not(init):
self.step_time = time.time()-start_time
def clamp(self):
"""Clamps solution to argmax_t of q(t|x) for each x, i.e. hard clustering."""
print('before clamp: ' + self.report_metrics())
if self.alpha==0: print('WARNING: clamping with alpha=0; solution is likely already deterministic.')
for x in range(self.ds.X):
tstar = np.argmax(self.qt_x[:,x])
self.qt_x[tstar,x] = 1
self.qt_step()
self.qy_t_step()
if self.geoapprox: self.build_dist_mat()
self.clamped = True
print('after clamp: ' + self.report_metrics())
def panda(self,dist_to_keep=set()):
""""Return dataframe of model. If dist, include distributions.
If conv, include converged variables; otherwise include stepwise."""
df = pd.DataFrame(data={
'alpha': self.alpha, 'beta': self.beta, 'step': self.step,
'L': self.L, 'ixt': self.ixt, 'iyt': self.iyt, 'ht': self.ht,
'T': self.T, 'ht_x': self.ht_x, 'hy_t': self.hy_t,
'hx': self.ds.hx, 'ixy': self.ds.ixy, 'Tmax': self.Tmax,
'p0': self.p0, 'waviness': self.waviness, 'ptol': self.ptol,
'ctol_abs': self.ctol_abs, 'ctol_rel': self.ctol_rel,
'cthresh': self.cthresh, 'zeroLtol': self.zeroLtol,
'clamped': self.clamped, 'geoapprox': self.geoapprox,
'using_labels': self.using_labels, 'merged': self.merged,
'smoothing_type': self.ds.smoothing_type,
'smoothing_center': self.ds.smoothing_center,
's': self.ds.s, 'd': self.ds.d,
'step_time': self.step_time, 'conv_time': self.conv_time,
'conv_condition': self.conv_condition}, index = [0])
if 'qt_x' in dist_to_keep:
df['qt_x'] = [self.qt_x]
if 'qt' in dist_to_keep:
df['qt'] = [self.qt]
if 'qy_t' in dist_to_keep:
df['qy_t'] = [self.qy_t]
if 'Dxt' in dist_to_keep:
df['Dxt'] = [self.Dxt]
return df
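# Illustrative use (a sketch; the beta values below are made up): collect
# one-row summaries from several fits into a single dataframe.
#   rows = []
#   for b in (1.,5.,10.):
#       m = model(ds,alpha=0,beta=b)
#       m.fit()
#       rows.append(m.panda())
#   summary = pd.concat(rows,ignore_index=True)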
def depanda(self,df):
"""Replaces current model with one in df."""
self.alpha = df['alpha'][0]
self.beta = df['beta'][0]
self.step = df['step'][0]
self.L = df['L'][0]
self.ixt = df['ixt'][0]
self.iyt = df['iyt'][0]
self.ht = df['ht'][0]
self.T = df['T'][0]
self.ht_x = df['ht_x'][0]
self.hy_t = df['hy_t'][0]
self.hx = df['hx'][0]
self.ixy = df['ixy'][0]
self.Tmax = df['Tmax'][0]
self.p0 = df['p0'][0]
self.waviness = df['waviness'][0]
self.ptol = df['ptol'][0]
self.ctol_abs = df['ctol_abs'][0]
self.ctol_rel = df['ctol_rel'][0]
self.cthresh = df['cthresh'][0]
self.zeroLtol = df['zeroLtol'][0]
self.clamped = df['clamped'][0]
self.geoapprox = df['geoapprox'][0]
self.using_labels = df['using_labels'][0]
self.merged = df['merged'][0]
self.ds.smoothing_type = df['smoothing_type'][0]
self.ds.smoothing_center = df['smoothing_center'][0]
self.ds.s = df['s'][0]
self.ds.d = df['d'][0]
self.step_time = df['step_time'][0]
self.conv_time = df['conv_time'][0]
self.conv_condition = df['conv_condition'][0]
self.qt_x = df['qt_x'][0]
self.qt = df['qt'][0]
self.qy_t = df['qy_t'][0]
self.Dxt = df['Dxt'][0]
def append_conv_condition(self,cond):
if self.conv_condition is None: self.conv_condition = cond
else: self.conv_condition += '_AND_' + cond
def update_sw(self):
"""Appends current model / stats to the internal stepwise dataframe."""
if self.keep_steps:
# store stepwise data
self.metrics_sw = self.metrics_sw.append(self.panda(), ignore_index = True)
if bool(self.dist_to_keep): self.dist_sw = self.dist_sw.append(self.panda(self.dist_to_keep), ignore_index = True)
def check_converged(self):
"""Checks if most recent step triggered convergence, and stores step /
reverts model to last step if necessary."""
Lold = self.prev['L'][0]
# check for small changes
small_abs_changes = abs(Lold-self.L)<self.ctol_abs
small_rel_changes = (abs(Lold-self.L)/abs(Lold))<self.ctol_rel
if small_abs_changes or small_rel_changes: self.cstep += 1
else: self.cstep = 0 # reset counter of small changes in a row
if small_abs_changes and self.cstep>=self.cthresh:
self.conv_condition = 'small_abs_changes'
print('converged due to small absolute changes in objective')
if small_rel_changes and self.cstep>=self.cthresh:
self.append_conv_condition('small_rel_changes')
print('converged due to small relative changes in objective')
# check for objective becoming NaN
if np.isnan(self.L):
self.cstep = self.cthresh
self.append_conv_condition('cost_func_NaN')
print('stopped because objective = NaN')
L_abs_inc_flag = self.L>(Lold+self.ctol_abs)
L_rel_inc_flag = self.L>(Lold+(abs(Lold)*self.ctol_rel))
# check for reduction to single cluster
if self.T==1 and not(L_abs_inc_flag) and not(L_rel_inc_flag):
self.cstep = self.cthresh
self.append_conv_condition('single_cluster')
print('converged due to reduction to single cluster')
# check if obj went up by amount above threshold (after 1st step)
if (L_abs_inc_flag or L_rel_inc_flag) and self.step>1: # if so, don't store or count this step!
self.cstep = self.cthresh
if L_abs_inc_flag:
self.append_conv_condition('cost_func_abs_inc')
print('converged due to absolute increase in objective value')
if L_rel_inc_flag:
self.append_conv_condition('cost_func_rel_inc')
print('converged due to relative increase in objective value')
# revert to metrics/distributions from last step
self.prev.conv_condition = self.conv_condition
self.depanda(self.prev)
# otherwise, store step
else: self.update_sw()
# if converged, check if single cluster solution better
if self.cstep>=self.cthresh and self.T>1: self.check_single_better()
def check_single_better(self):
""" Replace converged step with single-cluster map if better."""
sqt_x = np.zeros((self.T,self.ds.X),dtype=self.dt)
sqt_x[0,:] = 1.
smodel = model(ds=self.ds,alpha=self.alpha,beta=self.beta,Tmax=self.Tmax,
qt_x=sqt_x,p0=self.p0,waviness=self.waviness,
ctol_abs=self.ctol_abs,ctol_rel=self.ctol_rel,cthresh=self.cthresh,
ptol=self.ptol,zeroLtol=self.zeroLtol,geoapprox=self.geoapprox,
quiet=True)
smodel.step = self.step
smodel.conv_condition = self.conv_condition + '_AND_force_single'
if smodel.L<(self.L-self.zeroLtol): # if better fit...
print("single-cluster mapping reduces L from %.4f to %.4f (zeroLtol = %.1e); replacing solution." % (self.L,smodel.L,self.zeroLtol))
# replace everything
self.depanda(smodel.panda(dist_to_keep={'qt_x','qt','qy_t','Dxt'}))
self.update_sw()
print('single-cluster solution: ' + self.report_metrics())
else: print("single-cluster mapping not better; changes L from %.4f to %.4f (zeroLtol = %.1e)." % (self.L,smodel.L,self.zeroLtol))
def check_merged_better(self,findbest=True):
"""Checks if merging any two clusters improves cost function.
If findbest = True, review all merges and choose best, if any improve L.
If findbest = False, just accept the first merge that improves L.
The latter option is useful when there are too many clusters to compare all merges."""
start_time = time.time()
anybetter = False
best = copy.deepcopy(self)
# iterate over cluster pairs
for t1 in range(self.T-1):
for t2 in range(t1+1,self.T):
if not(anybetter) or findbest:
# copy model
alt = copy.deepcopy(self)
alt.quiet = True
# t2 -> t1
alt.qt_x[t1,alt.qt_x[t2,:]==1] = 1
alt.qt_x[t2,:] = 0
# update other dist
alt.make_step(init=True)
# check if cost function L reduced relative to best so far
if alt.L<best.L:
best = copy.deepcopy(alt)
mergedt1 = t1
mergedt2 = t2
anybetter = True
if anybetter:
print('merged clusters %i and %i, reducing L from %.3f to %.3f' % (mergedt1,mergedt2,self.L,best.L))
self.__init__(ds=self.ds,alpha=self.alpha,beta=self.beta,
Tmax=self.Tmax,qt_x=best.qt_x,p0=self.p0,waviness=self.waviness,
ctol_abs=self.ctol_abs,ctol_rel=self.ctol_rel,cthresh=self.cthresh,
ptol=self.ptol,zeroLtol=self.zeroLtol,geoapprox=self.geoapprox,
step=self.step,quiet=self.quiet)
self.merged = True
self.step_time = time.time()-start_time
self.cstep = 0
self.conv_condition = None
self.update_sw()
return True
else:
self.merged = False
print('no merges reduce L')
return False
def fit(self,keep_steps=False,dist_to_keep={'qt_x','qt','qy_t','Dxt'}):
"""Runs generalized IB algorithm to convergence for current model.
keep_steps determines whether pre-convergence models / statistics about
them are kept. dist_to_keep is a set with the model distributions to be
kept for each step."""
fit_start_time = time.time()
self.keep_steps = keep_steps
self.dist_to_keep = dist_to_keep
print(20*'*'+' Beginning IB fit with the following parameters '+20*'*')
print(self.report_param())
print(88*'*')
# initialize stepwise dataframes, if tracking them
if self.keep_steps:
self.metrics_sw = self.panda()
if bool(self.dist_to_keep): self.dist_sw = self.panda(self.dist_to_keep)
# check if single cluster init
if self.T==1:
self.cstep = self.cthresh
print('converged due to initialization with single cluster')
self.conv_condition = 'single_cluster_init'
else: # init iterative parameters
self.cstep = 0
self.conv_condition = None
# save encoder init
self.qt_x0 = self.qt_x
# iterate to convergence
while self.cstep<self.cthresh:
self.prev = self.panda(dist_to_keep={'qt_x','qt','qy_t','Dxt'})
self.make_step()
print('step %i: ' % self.step + self.report_metrics())
self.check_converged()
if self.cstep>=self.cthresh and self.T>1 and self.alpha==0: self.check_merged_better()
# report
print('converged in %i step(s) to: ' % self.step + self.report_metrics())
# clean up
self.cstep = None
self.prev = None
self.step_time = None
# record total time to convergence
self.conv_time = time.time() - fit_start_time
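# Illustrative end-to-end use (a minimal sketch; the dataset, alpha, beta, and
# Tmax below are made-up values, not recommendations):
#   ds = dataset(coord=np.random.randn(50,2),name='toy',
#                smoothing_type='uniform',s=1.)
#   m = model(ds,alpha=0,beta=5.,Tmax=10)   # DIB with at most 10 clusters
#   m.fit()
#   print(m.report_metrics())
#   m.plot_qt_x()                           # hard clustering, so plotting works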
def plot_qt_x(self):
"""Visualizes clustering induced by q(t|x), if coord available."""
if self.ds.coord is None:
raise ValueError('coordinates not available; cannot plot')
if not(self.alpha==0 or self.clamped):
raise ValueError('qt_x not deterministic; cannot plot')
# build cluster assignment vector
cluster = np.zeros(self.ds.X)
for x in range(self.ds.X): cluster[x] = np.nonzero(self.qt_x[:,x])[0][0]
# scatter plot of coordinates, colored by cluster assignment
plt.figure()
plt.scatter(self.ds.coord[:,0],self.ds.coord[:,1],c=cluster)
plt.show()
def refine_beta(metrics_conv):
"""Helper function for IB to automate search over parameter beta."""
# parameters governing insertion of betas, or when there is a transition to NaNs (due to under/overflow)
l = 1 # number of betas to insert into gaps
del_R = .05 # if fractional change in I(Y;T) exceeds this between adjacent betas, insert more betas
del_C = .05 # if fractional change in H(T) or I(X;T) exceeds this between adjacent betas, insert more betas
del_T = 0 # if difference in number of clusters used exceeds this between adjacent betas, insert more betas
min_abs_res = 1e-2 # if beta diff smaller than this absolute threshold, don't insert; consider as phase transition
min_rel_res = 1e-2 # if beta diff smaller than this fractional threshold, don't insert
# parameters governing insertion of betas when I(X;T) doesn't reach zero
eps0 = 1e-2 # tolerance for considering I(X;T) to be zero
l0 = 1 # number of betas to insert at low beta end
f0 = .5 # new betas will be minbeta*f0**k for k = 1,...,l0
# parameters governing insertion of betas when I(T;Y) doesn't reach I(X;Y)
eps1 = .95 # tolerance for considering I(T;Y) to be I(X;Y)
l1 = 1 # number of betas to insert at high beta end
f1 = 2 # new betas will be maxbeta*f1**k for k = 1,...,l1
max_beta_allowed = 80 # any proposed betas above this will be filtered out and replaced with max_beta_allowed
# sort fits by beta
metrics_conv = metrics_conv.sort_values(by='beta')
# init
new_betas = []
NaNtran = False
ixy = metrics_conv['ixy'].iloc[0]
logT = math.log2(metrics_conv['Tmax'].iloc[0])
print('-----------------------------------')
# check that smallest beta was small enough
if metrics_conv['ixt'].min()>eps0: