@@ -65,11 +65,10 @@ def __init__(
65
65
n_components ,
66
66
n_features ,
67
67
covariance_type = "full" ,
68
- eps = 1.0e-6 ,
68
+ eps = 1.0e-3 ,
69
69
init_means = "random" ,
70
70
mu_init = None ,
71
71
var_init = None ,
72
- device = "cpu" ,
73
72
verbose = True ,
74
73
):
75
74
"""
@@ -113,13 +112,9 @@ def __init__(
113
112
assert self .covariance_type in ["full" , "diag" ]
114
113
assert self .init_means in ["kmeans" , "random" ]
115
114
116
- self .device = device
117
115
self .verbose = verbose
118
116
self ._init_params ()
119
117
120
- def _to_device (self , x ):
121
- return torch .tensor (x , device = self .device )
122
-
123
118
def _init_params (self ):
124
119
if self .mu_init is not None :
125
120
assert self .mu_init .size () == (
@@ -132,10 +127,9 @@ def _init_params(self):
132
127
)
133
128
# (1, k, d)
134
129
self .mu = torch .nn .Parameter (self .mu_init , requires_grad = True )
135
- self .mu = self ._to_device (self .mu )
136
130
else :
137
131
self .mu = torch .nn .Parameter (
138
- torch .randn (1 , self .n_components , self .n_features , device = self . device ),
132
+ torch .randn (1 , self .n_components , self .n_features ),
139
133
requires_grad = True ,
140
134
)
141
135
@@ -151,14 +145,16 @@ def _init_params(self):
151
145
% (self .n_components , self .n_features )
152
146
)
153
147
self .var = torch .nn .Parameter (self .var_init , requires_grad = True )
154
- self .var = self ._to_device (self .var )
155
148
else :
156
149
self .var = torch .nn .Parameter (
157
150
torch .ones (
158
- 1 , self .n_components , self .n_features , device = self .device
151
+ 1 ,
152
+ self .n_components ,
153
+ self .n_features ,
159
154
),
160
155
requires_grad = True ,
161
156
)
157
+
162
158
elif self .covariance_type == "full" :
163
159
if self .var_init is not None :
164
160
# (1, k, d, d)
@@ -172,7 +168,6 @@ def _init_params(self):
172
168
% (self .n_components , self .n_features , self .n_features )
173
169
)
174
170
self .var = torch .nn .Parameter (self .var_init , requires_grad = False )
175
- self .var = self ._to_device (self .var )
176
171
else :
177
172
self .var = torch .nn .Parameter (
178
173
torch .eye (self .n_features )
@@ -183,9 +178,7 @@ def _init_params(self):
183
178
184
179
# (1, k, 1)
185
180
self .pi = torch .nn .Parameter (
186
- torch .Tensor (1 , self .n_components , 1 , device = self .device ).fill_ (
187
- 1.0 / self .n_components
188
- ),
181
+ torch .Tensor (1 , self .n_components , 1 ).fill_ (1.0 / self .n_components ),
189
182
requires_grad = True ,
190
183
)
191
184
@@ -206,25 +199,30 @@ def _set_marginal(self, indices=[]):
206
199
self .var .data = self .var_chached .data
207
200
208
201
else :
202
+ device = self .mu .data .device
209
203
max_dimension = self .mu_chached .shape [- 1 ]
210
204
assert any (
211
205
[~ (idx <= max_dimension ) for idx in indices ]
212
206
), f"One of provided indices { indices } is higher than a number of dimensions the model was fitted on { max_dimension } ."
213
207
214
- self .mu .data = torch .zeros (1 , self .n_components , len (indices ))
208
+ self .mu .data = torch .zeros (
209
+ 1 , self .n_components , len (indices ), device = device
210
+ )
215
211
for i , ii in enumerate (indices ):
216
212
self .mu .data [:, :, i ] = self .mu_chached [:, :, ii ]
217
213
218
- if self .covariance_type is "full" :
214
+ if self .covariance_type == "full" :
219
215
self .var .data = torch .zeros (
220
- 1 , self .n_components , len (indices ), len (indices )
216
+ 1 , self .n_components , len (indices ), len (indices ), device = device
221
217
)
222
218
223
219
for i , ii in enumerate (indices ):
224
220
for j , jj in enumerate (indices ):
225
221
self .var .data [:, :, i , j ] = self .var_chached [:, :, ii , jj ]
226
222
else :
227
- self .var .data = torch .zeros (1 , self .n_components , len (indices ))
223
+ self .var .data = torch .zeros (
224
+ 1 , self .n_components , len (indices ), device = device
225
+ )
228
226
for i , ii in enumerate (indices ):
229
227
self .mu_chached .data [:, :, i ] = self .var [:, :, ii ]
230
228
@@ -304,8 +302,7 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
304
302
var_init = self .var_init ,
305
303
eps = self .eps ,
306
304
)
307
- for p in self .parameters ():
308
- p .data = self ._to_device (p .data )
305
+
309
306
if self .init_means == "kmeans" :
310
307
(self .mu .data ,) = self .get_kmeans_mu (x , n_centers = self .n_components )
311
308
@@ -331,7 +328,7 @@ def fit_grad(self, x, n_iter=1000, learning_rate=1e-1):
331
328
optimizer = torch .optim .Adam ([self .pi , self .mu , self .var ], lr = learning_rate )
332
329
333
330
# Initialise the minimum loss at infinity.
334
- x = self ._to_device (x )
331
+ # x = self._to_device(x)
335
332
# Iterate over the number of iterations.
336
333
for i in range (n_iter ):
337
334
optimizer .zero_grad ()
@@ -451,8 +448,8 @@ def _estimate_log_prob(self, x):
451
448
x = self .check_size (x )
452
449
453
450
if self .covariance_type == "full" :
454
- mu = self .mu
455
- var = self .var
451
+ mu = self .mu . detach ()
452
+ var = self .var . detach ()
456
453
457
454
if var .shape [2 ] == 1 :
458
455
precision = 1 / var
@@ -585,7 +582,7 @@ def __score(self, x, as_average=True):
585
582
(or)
586
583
per_sample_score: torch.Tensor (n)
587
584
"""
588
- weighted_log_prob = self ._estimate_log_prob (x ) + torch .log (self .pi )
585
+ weighted_log_prob = self ._estimate_log_prob (x ) + torch .log (self .pi ). detach ()
589
586
per_sample_score = torch .logsumexp (weighted_log_prob , dim = 1 )
590
587
591
588
if as_average :
@@ -722,7 +719,7 @@ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
722
719
723
720
delta = torch .norm ((center_old - center ), dim = 1 ).max ()
724
721
725
- return self . _to_device ( center .unsqueeze (0 ) * (x_max - x_min ) + x_min )
722
+ return center .unsqueeze (0 ) * (x_max - x_min ) + x_min
726
723
727
724
def print_verbose (self , string ):
728
725
if self .verbose :
0 commit comments