@@ -65,11 +65,13 @@ def __init__(
         n_components,
         n_features,
         covariance_type="full",
-        eps=1.0e-3,
-        init_means="random",
+        eps=1.0e-8,
+        cov_reg=1e-6,
+        init_means="kmeans",
         mu_init=None,
         var_init=None,
         verbose=True,
+        device="cpu",
     ):
         """
         Initializes the model and brings all tensors into their required shape.
@@ -108,11 +110,13 @@ def __init__(
 
         self.covariance_type = covariance_type
         self.init_means = init_means
+        self.cov_reg = cov_reg
 
         assert self.covariance_type in ["full", "diag"]
         assert self.init_means in ["kmeans", "random"]
 
         self.verbose = verbose
+        self.device = device
         self._init_params()
 
     def _init_params(self):
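Reviewer note: taken together, the two hunks above change the constructor's defaults and surface. A minimal usage sketch of the new signature follows; the class name GaussianMixture and its import path are assumptions, not shown in this diff.

import torch
from gmm import GaussianMixture  # hypothetical import path

model = GaussianMixture(
    n_components=4,
    n_features=2,
    covariance_type="full",
    eps=1.0e-8,           # new default: numerical floor added to the variances
    cov_reg=1e-6,         # new: diagonal jitter used in _calculate_log_det
    init_means="kmeans",  # new default (was "random")
    device="cpu",         # new: where the mixture parameters live
)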
@@ -182,7 +186,10 @@ def _init_params(self):
             requires_grad=True,
         )
 
-        self.params = [self.pi, self.mu, self.var]
+        self.mu.data = self.mu.data.to(self.device)  # .to() is out-of-place; reassign .data to actually move the parameter
+        self.var.data = self.var.data.to(self.device)
+        self.pi.data = self.pi.data.to(self.device)
+
         self.fitted = False
 
     def _finish_optimization(self):
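Note on the hunk above: Tensor.to() is out-of-place, which is why the replacement lines reassign .data rather than calling .to() bare. A self-contained illustration:

import torch

p = torch.zeros(3, requires_grad=True)  # a leaf parameter, like self.mu
p.to("cpu")                             # returns a tensor; p itself is untouched
p.data = p.data.to("cpu")               # reassigning .data actually moves the storage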
@@ -208,6 +215,7 @@ def _set_marginal(self, indices=[]):
         self.mu.data = torch.zeros(
             1, self.n_components, len(indices), device=device
         )
+
         for i, ii in enumerate(indices):
             self.mu.data[:, :, i] = self.mu_chached[:, :, ii]
 
@@ -268,6 +276,7 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
         n_iter: int
         warm_start: bool
         """
+
         if not warm_start and self.fitted:
             self._init_params()
 
@@ -289,22 +298,12 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
             self.__em(x)
             self.log_likelihood = self.__score(x)
             self.print_verbose(f"score {self.log_likelihood.item()}")
+
             if torch.isinf(self.log_likelihood.abs()) or torch.isnan(
                 self.log_likelihood
             ):
-
                 # When the log-likelihood assumes unbound values, reinitialize model
-                self.__init__(
-                    self.n_components,
-                    self.n_features,
-                    covariance_type=self.covariance_type,
-                    mu_init=self.mu_init,
-                    var_init=self.var_init,
-                    eps=self.eps,
-                )
-
-                if self.init_means == "kmeans":
-                    (self.mu.data,) = self.get_kmeans_mu(x, n_centers=self.n_components)
+                self.__reset(x)
 
             i += 1
             j = self.log_likelihood - log_likelihood_old
@@ -316,6 +315,12 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
 
         self._finish_optimization()
 
+    def __reset(self, x):
+        print("RESET")
+        self._init_params()
+        if self.init_means == "kmeans":
+            self.mu.data = self.get_kmeans_mu(x, n_centers=self.n_components)
+
     def fit_grad(self, x, n_iter=1000, learning_rate=1e-1):
 
         # TODO: make sure constraints for self.var & self.pi are satisfied
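For intuition on why fit_em guards the score at all: when a component collapses onto (near-)duplicate points, its covariance becomes singular and the log-likelihood degenerates to NaN or an infinity, which is what triggers __reset above. A two-line demonstration:

import torch

var = torch.zeros(2, 2)    # a fully collapsed 2-D covariance
print(torch.logdet(var))   # tensor(-inf): the mixture score becomes unbounded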
@@ -448,8 +453,8 @@ def _estimate_log_prob(self, x):
         x = self.check_size(x)
 
         if self.covariance_type == "full":
-            mu = self.mu.detach()
-            var = self.var.detach()
+            mu = self.mu.detach().to(x.device)
+            var = self.var.detach().to(x.device)
 
             if var.shape[2] == 1:
                 precision = 1 / var
@@ -490,12 +495,17 @@ def _calculate_log_det(self, var):
         var: torch.Tensor (1, k, d, d)
         """
         log_det = torch.empty(size=(self.n_components,)).to(var.device)
-
         for k in range(self.n_components):
-            log_det[k] = (
-                2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0, k]))).sum()
-            )
-
+            try:
+                dI = self.cov_reg * torch.eye(var[0, k].shape[0]).to(var.device)
+                log_det[k] = (
+                    2
+                    * torch.log(
+                        torch.diagonal(torch.linalg.cholesky(var[0, k] + dI))
+                    ).sum()
+                )
+            except RuntimeError:  # Cholesky can still fail; fall back to logdet
+                log_det[k] = torch.logdet(var[0, k])
         return log_det.unsqueeze(-1)
 
     def _e_step(self, x):
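The hunk above is the core numerical fix: add cov_reg * I as jitter before the Cholesky factorization, and fall back to torch.logdet when factorization still fails. A standalone sketch of the same idea; the function name stabilized_logdet is mine, not the repository's:

import torch

def stabilized_logdet(cov, jitter=1e-6):
    # Cholesky-based log-determinant with a small diagonal jitter, as cov_reg does above.
    dI = jitter * torch.eye(cov.shape[0], device=cov.device)
    try:
        L = torch.linalg.cholesky(cov + dI)
        return 2 * torch.log(torch.diagonal(L)).sum()
    except RuntimeError:
        return torch.logdet(cov)  # last resort for matrices the jitter cannot save

cov = torch.tensor([[1.0, 1.0], [1.0, 1.0]])  # rank-deficient covariance
print(stabilized_logdet(cov))                 # finite (about -13.1) thanks to the jitter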
@@ -555,7 +565,6 @@ def _m_step(self, x, log_resp):
         var = x2 - 2 * xmu + mu2 + self.eps
 
         pi = pi / x.shape[0]
-
         return pi, mu, var
 
     def __em(self, x):
def __em (self , x ):
@@ -582,7 +591,10 @@ def __score(self, x, as_average=True):
582
591
(or)
583
592
per_sample_score: torch.Tensor (n)
584
593
"""
585
- weighted_log_prob = self ._estimate_log_prob (x ) + torch .log (self .pi ).detach ()
594
+
595
+ weighted_log_prob = self ._estimate_log_prob (x ) + torch .log (self .pi ).detach ().to (
596
+ x .device
597
+ )
586
598
per_sample_score = torch .logsumexp (weighted_log_prob , dim = 1 )
587
599
588
600
if as_average :
@@ -668,10 +680,9 @@ def __update_pi(self, pi):
             self.n_components,
             1,
         )
-
         self.pi.data = pi
 
-    def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
+    def get_kmeans_mu(self, x, n_centers, init_times=2, min_delta=1e-2):
         """
         Find an initial value for the mean. Requires a threshold min_delta for the k-means algorithm to stop iterating.
         The algorithm is repeated init_times times, after which the best centerpoint is returned.
@@ -687,6 +698,10 @@ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
 
         min_cost = np.inf
 
+        center = x[
+            np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False),
+            ...,
+        ]
         for i in range(init_times):
             tmp_center = x[
                 np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False),