Skip to content

Commit 2c980e0

Browse files
author
egor
committed
gpu compatability
1 parent 935f08f commit 2c980e0

File tree

6 files changed

+314
-33
lines changed

6 files changed

+314
-33
lines changed

GM.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ def __init__(
6565
n_components,
6666
n_features,
6767
covariance_type="full",
68-
eps=1.0e-6,
68+
eps=1.0e-3,
6969
init_means="random",
7070
mu_init=None,
7171
var_init=None,
72-
device="cpu",
7372
verbose=True,
7473
):
7574
"""
@@ -113,13 +112,9 @@ def __init__(
113112
assert self.covariance_type in ["full", "diag"]
114113
assert self.init_means in ["kmeans", "random"]
115114

116-
self.device = device
117115
self.verbose = verbose
118116
self._init_params()
119117

120-
def _to_device(self, x):
121-
return torch.tensor(x, device=self.device)
122-
123118
def _init_params(self):
124119
if self.mu_init is not None:
125120
assert self.mu_init.size() == (
@@ -132,10 +127,9 @@ def _init_params(self):
132127
)
133128
# (1, k, d)
134129
self.mu = torch.nn.Parameter(self.mu_init, requires_grad=True)
135-
self.mu = self._to_device(self.mu)
136130
else:
137131
self.mu = torch.nn.Parameter(
138-
torch.randn(1, self.n_components, self.n_features, device=self.device),
132+
torch.randn(1, self.n_components, self.n_features),
139133
requires_grad=True,
140134
)
141135

@@ -151,14 +145,16 @@ def _init_params(self):
151145
% (self.n_components, self.n_features)
152146
)
153147
self.var = torch.nn.Parameter(self.var_init, requires_grad=True)
154-
self.var = self._to_device(self.var)
155148
else:
156149
self.var = torch.nn.Parameter(
157150
torch.ones(
158-
1, self.n_components, self.n_features, device=self.device
151+
1,
152+
self.n_components,
153+
self.n_features,
159154
),
160155
requires_grad=True,
161156
)
157+
162158
elif self.covariance_type == "full":
163159
if self.var_init is not None:
164160
# (1, k, d, d)
@@ -172,7 +168,6 @@ def _init_params(self):
172168
% (self.n_components, self.n_features, self.n_features)
173169
)
174170
self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
175-
self.var = self._to_device(self.var)
176171
else:
177172
self.var = torch.nn.Parameter(
178173
torch.eye(self.n_features)
@@ -183,9 +178,7 @@ def _init_params(self):
183178

184179
# (1, k, 1)
185180
self.pi = torch.nn.Parameter(
186-
torch.Tensor(1, self.n_components, 1, device=self.device).fill_(
187-
1.0 / self.n_components
188-
),
181+
torch.Tensor(1, self.n_components, 1).fill_(1.0 / self.n_components),
189182
requires_grad=True,
190183
)
191184

@@ -206,25 +199,30 @@ def _set_marginal(self, indices=[]):
206199
self.var.data = self.var_chached.data
207200

208201
else:
202+
device = self.mu.data.device
209203
max_dimension = self.mu_chached.shape[-1]
210204
assert any(
211205
[~(idx <= max_dimension) for idx in indices]
212206
), f"One of provided indices {indices} is higher than a number of dimensions the model was fitted on {max_dimension}."
213207

214-
self.mu.data = torch.zeros(1, self.n_components, len(indices))
208+
self.mu.data = torch.zeros(
209+
1, self.n_components, len(indices), device=device
210+
)
215211
for i, ii in enumerate(indices):
216212
self.mu.data[:, :, i] = self.mu_chached[:, :, ii]
217213

218-
if self.covariance_type is "full":
214+
if self.covariance_type == "full":
219215
self.var.data = torch.zeros(
220-
1, self.n_components, len(indices), len(indices)
216+
1, self.n_components, len(indices), len(indices), device=device
221217
)
222218

223219
for i, ii in enumerate(indices):
224220
for j, jj in enumerate(indices):
225221
self.var.data[:, :, i, j] = self.var_chached[:, :, ii, jj]
226222
else:
227-
self.var.data = torch.zeros(1, self.n_components, len(indices))
223+
self.var.data = torch.zeros(
224+
1, self.n_components, len(indices), device=device
225+
)
228226
for i, ii in enumerate(indices):
229227
self.mu_chached.data[:, :, i] = self.var[:, :, ii]
230228

@@ -304,8 +302,7 @@ def fit_em(self, x, delta=1e-5, n_iter=300, warm_start=False):
304302
var_init=self.var_init,
305303
eps=self.eps,
306304
)
307-
for p in self.parameters():
308-
p.data = self._to_device(p.data)
305+
309306
if self.init_means == "kmeans":
310307
(self.mu.data,) = self.get_kmeans_mu(x, n_centers=self.n_components)
311308

@@ -331,7 +328,7 @@ def fit_grad(self, x, n_iter=1000, learning_rate=1e-1):
331328
optimizer = torch.optim.Adam([self.pi, self.mu, self.var], lr=learning_rate)
332329

333330
# Initialise the minimum loss at infinity.
334-
x = self._to_device(x)
331+
# x = self._to_device(x)
335332
# Iterate over the number of iterations.
336333
for i in range(n_iter):
337334
optimizer.zero_grad()
@@ -451,8 +448,8 @@ def _estimate_log_prob(self, x):
451448
x = self.check_size(x)
452449

453450
if self.covariance_type == "full":
454-
mu = self.mu
455-
var = self.var
451+
mu = self.mu.detach()
452+
var = self.var.detach()
456453

457454
if var.shape[2] == 1:
458455
precision = 1 / var
@@ -585,7 +582,7 @@ def __score(self, x, as_average=True):
585582
(or)
586583
per_sample_score: torch.Tensor (n)
587584
"""
588-
weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
585+
weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi).detach()
589586
per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
590587

591588
if as_average:
@@ -722,7 +719,7 @@ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
722719

723720
delta = torch.norm((center_old - center), dim=1).max()
724721

725-
return self._to_device(center.unsqueeze(0) * (x_max - x_min) + x_min)
722+
return center.unsqueeze(0) * (x_max - x_min) + x_min
726723

727724
def print_verbose(self, string):
728725
if self.verbose:

MI.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@ def __init__(
88
n_components,
99
n_features,
1010
covariance_type="full",
11-
eps=1.0e-6,
11+
eps=1.0e-3,
1212
init_means="kmeans",
1313
mu_init=None,
1414
var_init=None,
15-
device="cpu",
1615
verbose=True,
1716
fit_mode="em",
18-
n_iter=3e2,
19-
delta=1e-5,
17+
n_iter=1e2,
18+
delta=1e-3,
2019
learning_rate=1e-2,
2120
warm_start=False,
2221
):
@@ -29,7 +28,6 @@ def __init__(
2928
init_means,
3029
mu_init,
3130
var_init,
32-
device,
3331
verbose,
3432
)
3533

@@ -61,11 +59,11 @@ def compute_mi(self, data_joint, indices_a, indices_b):
6159
"""
6260

6361
self.joint = self.logscore_samples(data_joint)
64-
65-
sample_a = torch.index_select(data_joint, 1, torch.tensor(indices_a))
62+
device = data_joint.device
63+
sample_a = torch.index_select(data_joint, 1, torch.tensor(indices_a).to(device))
6664
self.a = self.logscore_samples(sample_a, indices_a)
6765

68-
sample_b = torch.index_select(data_joint, 1, torch.tensor(indices_b))
66+
sample_b = torch.index_select(data_joint, 1, torch.tensor(indices_b).to(device))
6967
self.b = self.logscore_samples(sample_b, indices_b)
7068

7169
mi = (self.joint - self.a - self.b).mean()

__pycache__/GM.cpython-310.pyc

-290 Bytes
Binary file not shown.

__pycache__/MI.cpython-310.pyc

2 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)