|
85 | 85 | "data = torch.vstack([y, x])\n",
|
86 | 86 | "\n",
|
87 | 87 | "def transform(data):\n",
|
88 |
| - " min_, max_ = torch.min(data, axis=1), torch.max(data, axis=1)\n", |
| 88 | + " min_, max_ = torch.min(data, dim=1), torch.max(data, dim=1)\n", |
89 | 89 | " data_transformed = 2 * (data.sub(min_.values[:, None])).div((max_.values - min_.values)[:, None]) - 1\n",
|
90 | 90 | " return data_transformed, min_, max_\n",
|
91 | 91 | "\n",
|
|
195 | 195 | "# betas = linear_beta_schedule(timesteps)\n",
|
196 | 196 | "betas = linear_beta_schedule(timesteps)\n",
|
197 | 197 | "alphas = 1 - betas\n",
|
198 |
| - "alphas_ = torch.cumprod(alphas, axis=0)\n", |
| 198 | + "alphas_ = torch.cumprod(alphas, dim=0)\n", |
199 | 199 | "variance = 1 - alphas_\n",
|
200 | 200 | "sd = torch.sqrt(variance)\n",
|
201 | 201 | "\n",
|
|
213 | 213 | "added_noise_at_t = get_noisy(data_transformed, timesteps-1)\n",
|
214 | 214 | "plt.scatter(added_noise_at_t[0], added_noise_at_t[1])\n",
|
215 | 215 | "\n",
|
216 |
| - "posterior_variance = (1 - alphas) * (1 - alphas_prev_) / (1 - alphas)" |
| 216 | + "posterior_variance = (1 - alphas) * (1 - alphas_prev_) / (1 - alphas_)" |
217 | 217 | ]
|
218 | 218 | },
|
219 | 219 | {
|
|
857 | 857 | " posterior_data = posterior_variance[timestep]\n",
|
858 | 858 | " data_in_batch = torch.normal(mean_data, torch.sqrt(posterior_data)).T\n",
|
859 | 859 | " datas.append(data_in_batch.cpu().detach())\n",
|
860 |
| - " return datas, data_in_batch\n", |
| 860 | + " return datas, data_in_batch.cpu().detach()\n", |
861 | 861 | "\n",
|
862 | 862 | "datas, data_in_batch = generate_data(denoising_model)"
|
863 | 863 | ]
|
|
890 | 890 | }
|
891 | 891 | ],
|
892 | 892 | "source": [
|
893 |
| - "data_in_batch = data_in_batch.cpu().detach()\n", |
894 |
| - "data_pred = reverse_transform(data_in_batch.cpu().detach(), min_, max_)\n", |
| 893 | + "data_pred = reverse_transform(data_in_batch, min_, max_)\n", |
895 | 894 | "_, (ax1, ax2) = plt.subplots(1, 2)\n",
|
896 | 895 | "\n",
|
897 | 896 | "ax1.scatter(data[0], data[1])\n",
|
|
1266 | 1265 | }
|
1267 | 1266 | ],
|
1268 | 1267 | "source": [
|
1269 |
| - "data_in_batch = data_in_batch.cpu().detach()\n", |
1270 |
| - "data_pred = reverse_transform(data_in_batch.cpu().detach(), min_circles, max_circles)\n", |
| 1268 | + "data_pred = reverse_transform(data_in_batch, min_circles, max_circles)\n", |
1271 | 1269 | "_, (ax1, ax2) = plt.subplots(1, 2)\n",
|
1272 | 1270 | "\n",
|
1273 | 1271 | "ax1.scatter(circles[0], circles[1])\n",
|
|
1517 | 1515 | }
|
1518 | 1516 | ],
|
1519 | 1517 | "source": [
|
1520 |
| - "data_in_batch = data_in_batch.cpu().detach()\n", |
1521 |
| - "data_pred = reverse_transform(data_in_batch.cpu().detach(), min_moons, max_moons)\n", |
| 1518 | + "data_pred = reverse_transform(data_in_batch, min_moons, max_moons)\n", |
1522 | 1519 | "_, (ax1, ax2) = plt.subplots(1, 2)\n",
|
1523 | 1520 | "\n",
|
1524 | 1521 | "ax1.scatter(make_moons[0], make_moons[1])\n",
|
|
1780 | 1777 | }
|
1781 | 1778 | ],
|
1782 | 1779 | "source": [
|
1783 |
| - "data_in_batch = data_in_batch.cpu().detach()\n", |
1784 |
| - "data_pred = reverse_transform(data_in_batch.cpu().detach(), min_complex, max_complex)\n", |
| 1780 | + "data_pred = reverse_transform(data_in_batch, min_complex, max_complex)\n", |
1785 | 1781 | "_, (ax1, ax2) = plt.subplots(1, 2)\n",
|
1786 | 1782 | "\n",
|
1787 | 1783 | "ax1.scatter(complex_data[0], complex_data[1])\n",
|
|
0 commit comments