Commit 1b7705a

Author: egor
Message: fix diff

Parent: 2c980e0
8 files changed: +64 -41 lines
GM.py (+41 -26)
@@ -65,11 +65,13 @@ def __init__(
         n_components,
         n_features,
         covariance_type="full",
-        eps=1.0e-3,
-        init_means="random",
+        eps=1.0e-8,
+        cov_reg=1e-6,
+        init_means="kmeans",
         mu_init=None,
         var_init=None,
         verbose=True,
+        device="cpu",
     ):
         """
         Initializes the model and brings all tensors into their required shape.
@@ -108,11 +110,13 @@ def __init__(
 
         self.covariance_type = covariance_type
         self.init_means = init_means
+        self.cov_reg = cov_reg
 
         assert self.covariance_type in ["full", "diag"]
         assert self.init_means in ["kmeans", "random"]
 
         self.verbose = verbose
+        self.device = device
         self._init_params()
 
     def _init_params(self):
@@ -182,7 +186,10 @@ def _init_params(self):
             requires_grad=True,
         )
 
-        self.params = [self.pi, self.mu, self.var]
+        self.mu.to(self.device)
+        self.var.to(self.device)
+        self.pi.to(self.device)
+
         self.fitted = False
 
     def _finish_optimization(self):
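
A note on the three `.to(self.device)` calls added above: PyTorch's `Tensor.to()` is not an in-place operation, so each call returns a new tensor that is immediately discarded and the parameters stay on their original device. A minimal sketch of a rebinding that would actually move them, assuming `self.mu`, `self.var`, and `self.pi` are the leaf tensors created earlier in `_init_params`:

    # Tensor.to() returns a copy; rebind .data so the leaf tensors move.
    self.mu.data = self.mu.data.to(self.device)
    self.var.data = self.var.data.to(self.device)
    self.pi.data = self.pi.data.to(self.device)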
@@ -208,6 +215,7 @@ def _set_marginal(self, indices=[]):
         self.mu.data = torch.zeros(
             1, self.n_components, len(indices), device=device
         )
+
         for i, ii in enumerate(indices):
             self.mu.data[:, :, i] = self.mu_chached[:, :, ii]
 
@@ -268,6 +276,7 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
             n_iter: int
             warm_start: bool
         """
+
         if not warm_start and self.fitted:
             self._init_params()
 
@@ -289,22 +298,12 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
             self.__em(x)
             self.log_likelihood = self.__score(x)
             self.print_verbose(f"score {self.log_likelihood.item()}")
+
             if torch.isinf(self.log_likelihood.abs()) or torch.isnan(
                 self.log_likelihood
             ):
-
                 # When the log-likelihood assumes unbound values, reinitialize model
-                self.__init__(
-                    self.n_components,
-                    self.n_features,
-                    covariance_type=self.covariance_type,
-                    mu_init=self.mu_init,
-                    var_init=self.var_init,
-                    eps=self.eps,
-                )
-
-                if self.init_means == "kmeans":
-                    (self.mu.data,) = self.get_kmeans_mu(x, n_centers=self.n_components)
+                self.__reset(x)
 
             i += 1
             j = self.log_likelihood - log_likelihood_old
@@ -316,6 +315,12 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
 
         self._finish_optimization()
 
+    def __reset(self, x):
+        print("RESET")
+        self._init_params()
+        if self.init_means == "kmeans":
+            self.mu.data = self.get_kmeans_mu(x, n_centers=self.n_components)
+
     def fit_grad(self, x, n_iter=1000, learning_rate=1e-1):
 
         # TODO make sure constrains for self.var & self.pi are satisfied
@@ -448,8 +453,8 @@ def _estimate_log_prob(self, x):
         x = self.check_size(x)
 
         if self.covariance_type == "full":
-            mu = self.mu.detach()
-            var = self.var.detach()
+            mu = self.mu.detach().to(x.device)
+            var = self.var.detach().to(x.device)
 
             if var.shape[2] == 1:
                 precision = 1 / var
@@ -490,12 +495,17 @@ def _calculate_log_det(self, var):
             var: torch.Tensor (1, k, d, d)
         """
         log_det = torch.empty(size=(self.n_components,)).to(var.device)
-
         for k in range(self.n_components):
-            log_det[k] = (
-                2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0, k]))).sum()
-            )
-
+            try:
+                dI = self.cov_reg * torch.eye(var[0, k].shape[0]).to(var.device)
+                log_det[k] = (
+                    2
+                    * torch.log(
+                        torch.diagonal(torch.linalg.cholesky(var[0, k] + dI))
+                    ).sum()
+                )
+            except:
+                log_det[k] = torch.logdet(var[0, k])
         return log_det.unsqueeze(-1)
 
     def _e_step(self, x):
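
The `_calculate_log_det` hunk above is the core numerical fix of this commit: a small diagonal jitter `cov_reg * I` is added to each covariance before the Cholesky factorization, and `torch.logdet` serves as a fallback when factorization still fails. A self-contained sketch of the same technique (the helper name `stable_log_det` is illustrative, not part of the repo):

    import torch

    def stable_log_det(var_k: torch.Tensor, cov_reg: float = 1e-6) -> torch.Tensor:
        """Log-determinant of a (d, d) covariance via a jittered Cholesky factor.

        log|S| = 2 * sum(log(diag(L)))  where  L @ L.T = S + cov_reg * I.
        """
        d = var_k.shape[0]
        jitter = cov_reg * torch.eye(d, device=var_k.device, dtype=var_k.dtype)
        try:
            L = torch.linalg.cholesky(var_k + jitter)
            return 2 * torch.log(torch.diagonal(L)).sum()
        except RuntimeError:
            # Cholesky fails when the matrix is not positive definite even
            # after jittering; fall back to the more tolerant logdet.
            return torch.logdet(var_k)

One design note: the diff's bare `except:` also swallows `KeyboardInterrupt`; narrowing it to `RuntimeError` (which `torch.linalg.LinAlgError` subclasses) keeps the same fallback behavior.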
@@ -555,7 +565,6 @@ def _m_step(self, x, log_resp):
         var = x2 - 2 * xmu + mu2 + self.eps
 
         pi = pi / x.shape[0]
-
         return pi, mu, var
 
     def __em(self, x):
@@ -582,7 +591,10 @@ def __score(self, x, as_average=True):
             (or)
             per_sample_score: torch.Tensor (n)
         """
-        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi).detach()
+
+        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi).detach().to(
+            x.device
+        )
         per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
 
         if as_average:
@@ -668,10 +680,9 @@ def __update_pi(self, pi):
             self.n_components,
             1,
         )
-
         self.pi.data = pi
 
-    def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
+    def get_kmeans_mu(self, x, n_centers, init_times=2, min_delta=1e-2):
         """
         Find an initial value for the mean. Requires a threshold min_delta for the k-means algorithm to stop iterating.
         The algorithm is repeated init_times often, after which the best centerpoint is returned.
@@ -687,6 +698,10 @@ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
 
         min_cost = np.inf
 
+        center = x[
+            np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False),
+            ...,
+        ]
         for i in range(init_times):
             tmp_center = x[
                 np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False),
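
Taken together, the GM.py changes tighten the EM numerics (a much smaller variance floor `eps`, the new `cov_reg` jitter, k-means initialization by default) and thread a `device` argument through the model; the k-means defaults were also relaxed (`init_times` 50 to 2, `min_delta` 1e-3 to 1e-2), trading initialization quality for speed. A hedged usage sketch against the updated constructor (the import path is assumed from this repo's layout):

    import torch
    from GM import GaussianMixture

    x = torch.randn(1500, 2)  # (n, d) synthetic data
    model = GaussianMixture(
        n_components=10,
        n_features=2,
        covariance_type="full",
        eps=1.0e-8,        # variance floor added in the M-step
        cov_reg=1e-6,      # diagonal jitter used by _calculate_log_det
        init_means="kmeans",
        device="cpu",
    )
    model.fit_em(x, delta=1e-5, n_iter=300)
    labels = model.predict(x)  # hard component assignments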

MI.py (+6 -2)
@@ -8,27 +8,31 @@ def __init__(
         n_components,
         n_features,
         covariance_type="full",
-        eps=1.0e-3,
+        eps=1.0e-6,
+        cov_reg=1e-6,
         init_means="kmeans",
         mu_init=None,
         var_init=None,
         verbose=True,
         fit_mode="em",
         n_iter=1e2,
-        delta=1e-3,
+        delta=1e-6,
         learning_rate=1e-2,
         warm_start=False,
+        device="cpu",
     ):
 
         super().__init__(
             n_components,
             n_features,
             covariance_type,
             eps,
+            cov_reg,
             init_means,
             mu_init,
             var_init,
             verbose,
+            device,
         )
 
         assert fit_mode in [
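
Because `super().__init__` passes its arguments positionally, the new `cov_reg` and `device` entries must land exactly where `GaussianMixture.__init__` expects them (right after `eps` and after `verbose`, matching the GM.py signature above). A keyword-forwarding variant is less fragile to future signature changes; a sketch, same parameters, purely illustrative:

    super().__init__(
        n_components,
        n_features,
        covariance_type=covariance_type,
        eps=eps,
        cov_reg=cov_reg,
        init_means=init_means,
        mu_init=mu_init,
        var_init=var_init,
        verbose=verbose,
        device=device,
    )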

__pycache__/GM.cpython-310.pyc (463 Bytes, binary file not shown)
__pycache__/MI.cpython-310.pyc (45 Bytes, binary file not shown)
example_em_.pdf (deleted, -31.5 KB, binary file not shown)
example_em_sample.pdf (deleted, -31.2 KB, binary file not shown)

test_diff.py (+9 -4)
@@ -67,7 +67,6 @@ def test(model, device, test_loader, MILoss):
             data, target = data.to(device), target.to(device)
             output = model(data)
             # loss = -
-            test_loss += MILoss(output, target).item()
             # F.nll_loss(
             # output, target, reduction="sum"
             # ).item()  # sum up batch loss
@@ -94,14 +93,14 @@ def main():
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=64,
+        default=256,
         metavar="N",
         help="input batch size for training (default: 64)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=1000,
+        default=500,
         metavar="N",
         help="input batch size for testing (default: 1000)",
     )
@@ -188,7 +187,13 @@ def main():
     def MILoss(predict, yhat):
         yohe = torch.nn.functional.one_hot(yhat, num_classes=10)
         sample = torch.cat([predict, yohe], dim=1)
-        model = MIGM(10, 20, init_means="kmeans", verbose=False)
+        model = MIGM(
+            n_components=4,
+            n_features=20,
+            init_means="kmeans",
+            verbose=False,
+            device=device,
+        )
         model.to(device)
         model.fit(sample)
         indices = [i for i in range(0, 20)]
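
Two things worth noting in this file: removing `test_loss += MILoss(output, target).item()` appears to leave `test_loss` unaccumulated in `test()` (unless it is updated elsewhere in the file), and the `--batch-size`/`--test-batch-size` help strings still advertise the old defaults (64 and 1000). On the shapes in `MILoss`: with 10 classes, `predict` is (N, 10) and the one-hot target is (N, 10), so the concatenated `sample` is (N, 20), which matches `n_features=20`. A small sketch of that shape check (tensor values are illustrative):

    import torch

    predict = torch.randn(256, 10)                # model outputs, (N, 10)
    yhat = torch.randint(0, 10, (256,))           # class labels, (N,)
    yohe = torch.nn.functional.one_hot(yhat, num_classes=10).float()  # (N, 10)
    sample = torch.cat([predict, yohe], dim=1)    # (N, 20) -> n_features=20
    assert sample.shape == (256, 20)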

test_pdf.py → test_gm.py (renamed, +8 -9)
@@ -14,8 +14,9 @@
 
 
 def main():
-    n_components = 1
-    n, d = 100, 2
+    n_components = 10
+    n, d = 1500, 2
+    use_plots = False
 
     data = []
     for i in range(n_components):
@@ -30,26 +31,24 @@ def main():
         data.append(torch.cat([x_, y_], 1))
 
     data = torch.cat(data, 0)
-    print(data.shape)
-    # Next, the Gaussian mixture is instantiated and ..
 
     model = GaussianMixture(n_components, d)
     model.fit_em(data)
 
     # .. used to predict the data points as they where shifted
     y = model.predict(data)
-    # model.set_marginal(indices=[0])
     x1 = model.predict(data[:, 0], marginals=[0])
-    # model.set_marginal(indices=[1])
     x2 = model.predict(data[:, 1], marginals=[1])
 
-    plot(data, y, x1, x2, n)
+    if use_plots:
+        plot(data, y, x1, x2, n)
 
-    # model.set_marginal(indices=[])
     data, y = model.sample(n * n_components)
     x1, _ = model.sample(n * n_components, marginals=[0])
     x2, _ = model.sample(n * n_components, marginals=[1])
-    plot(data, y, x1, x2, n, sample=True)
+
+    if use_plots:
+        plot(data, y, x1, x2, n, sample=True)
 
 
 def plot(data, y, x1, x2, n, sample=False):
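
The `marginals=[0]` calls exercise the model's marginalization support: the marginal of a Gaussian mixture over a subset of dimensions is again a Gaussian mixture with the same weights and with the means and covariances restricted to those dimensions. A minimal sketch of that idea for full covariances (function and names are illustrative, not this repo's API):

    import torch

    def marginal_params(pi, mu, var, dims):
        """Restrict GMM parameters to the given dimensions.

        pi:  (k,) mixture weights, unchanged by marginalization
        mu:  (k, d) component means
        var: (k, d, d) component covariances
        """
        idx = torch.tensor(dims)
        mu_m = mu[:, idx]               # (k, len(dims))
        var_m = var[:, idx][:, :, idx]  # (k, len(dims), len(dims))
        return pi, mu_m, var_m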
