Skip to content

Commit 4bde28b

Browse files
authored
Merge pull request #1 from ghabault/patch-1
Correct typos
2 parents 0f1950b + ab50c9e commit 4bde28b

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

toy-density-estimation/vanilla_diffusion_model.ipynb

+8-12
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
"data = torch.vstack([y, x])\n",
8686
"\n",
8787
"def transform(data):\n",
88-
" min_, max_ = torch.min(data, axis=1), torch.max(data, axis=1)\n",
88+
" min_, max_ = torch.min(data, dim=1), torch.max(data, dim=1)\n",
8989
" data_transformed = 2 * (data.sub(min_.values[:, None])).div((max_.values - min_.values)[:, None]) - 1\n",
9090
" return data_transformed, min_, max_\n",
9191
"\n",
@@ -195,7 +195,7 @@
195195
"# betas = linear_beta_schedule(timesteps)\n",
196196
"betas = linear_beta_schedule(timesteps)\n",
197197
"alphas = 1 - betas\n",
198-
"alphas_ = torch.cumprod(alphas, axis=0)\n",
198+
"alphas_ = torch.cumprod(alphas, dim=0)\n",
199199
"variance = 1 - alphas_\n",
200200
"sd = torch.sqrt(variance)\n",
201201
"\n",
@@ -213,7 +213,7 @@
213213
"added_noise_at_t = get_noisy(data_transformed, timesteps-1)\n",
214214
"plt.scatter(added_noise_at_t[0], added_noise_at_t[1])\n",
215215
"\n",
216-
"posterior_variance = (1 - alphas) * (1 - alphas_prev_) / (1 - alphas)"
216+
"posterior_variance = (1 - alphas) * (1 - alphas_prev_) / (1 - alphas_)"
217217
]
218218
},
219219
{
@@ -857,7 +857,7 @@
857857
" posterior_data = posterior_variance[timestep]\n",
858858
" data_in_batch = torch.normal(mean_data, torch.sqrt(posterior_data)).T\n",
859859
" datas.append(data_in_batch.cpu().detach())\n",
860-
" return datas, data_in_batch\n",
860+
" return datas, data_in_batch.cpu().detach()\n",
861861
"\n",
862862
"datas, data_in_batch = generate_data(denoising_model)"
863863
]
@@ -890,8 +890,7 @@
890890
}
891891
],
892892
"source": [
893-
"data_in_batch = data_in_batch.cpu().detach()\n",
894-
"data_pred = reverse_transform(data_in_batch.cpu().detach(), min_, max_)\n",
893+
"data_pred = reverse_transform(data_in_batch, min_, max_)\n",
895894
"_, (ax1, ax2) = plt.subplots(1, 2)\n",
896895
"\n",
897896
"ax1.scatter(data[0], data[1])\n",
@@ -1266,8 +1265,7 @@
12661265
}
12671266
],
12681267
"source": [
1269-
"data_in_batch = data_in_batch.cpu().detach()\n",
1270-
"data_pred = reverse_transform(data_in_batch.cpu().detach(), min_circles, max_circles)\n",
1268+
"data_pred = reverse_transform(data_in_batch, min_circles, max_circles)\n",
12711269
"_, (ax1, ax2) = plt.subplots(1, 2)\n",
12721270
"\n",
12731271
"ax1.scatter(circles[0], circles[1])\n",
@@ -1517,8 +1515,7 @@
15171515
}
15181516
],
15191517
"source": [
1520-
"data_in_batch = data_in_batch.cpu().detach()\n",
1521-
"data_pred = reverse_transform(data_in_batch.cpu().detach(), min_moons, max_moons)\n",
1518+
"data_pred = reverse_transform(data_in_batch, min_moons, max_moons)\n",
15221519
"_, (ax1, ax2) = plt.subplots(1, 2)\n",
15231520
"\n",
15241521
"ax1.scatter(make_moons[0], make_moons[1])\n",
@@ -1780,8 +1777,7 @@
17801777
}
17811778
],
17821779
"source": [
1783-
"data_in_batch = data_in_batch.cpu().detach()\n",
1784-
"data_pred = reverse_transform(data_in_batch.cpu().detach(), min_complex, max_complex)\n",
1780+
"data_pred = reverse_transform(data_in_batch, min_complex, max_complex)\n",
17851781
"_, (ax1, ax2) = plt.subplots(1, 2)\n",
17861782
"\n",
17871783
"ax1.scatter(complex_data[0], complex_data[1])\n",

0 commit comments

Comments
 (0)