diff --git a/.gitignore b/.gitignore index e8bc26f..eb3e7b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /venv -docs/.DS_Store \ No newline at end of file +docs/.DS_Store +diffusion-models/notebooks/mnist_data/ \ No newline at end of file diff --git a/diffusion-models/notebooks/hf_diffusers.ipynb b/diffusion-models/notebooks/hf_diffusers.ipynb index 66089a7..bd18dfc 100644 --- a/diffusion-models/notebooks/hf_diffusers.ipynb +++ b/diffusion-models/notebooks/hf_diffusers.ipynb @@ -9,9 +9,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/usr/local/lib/python3.12/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.\n", + "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. 
Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.\n", " deprecate(\"Transformer2DModelOutput\", \"1.0.0\", deprecation_message)\n" ] } @@ -41,15 +41,19 @@ " in_channels=1,\n", " out_channels=1,\n", " layers_per_block=1,\n", - " block_out_channels=(32, 64),\n", + " block_out_channels=(8, 16, 32),\n", " down_block_types=(\n", " \"DownBlock2D\",\n", " \"DownBlock2D\",\n", + " \"DownBlock2D\",\n", " ),\n", " up_block_types=(\n", " \"UpBlock2D\",\n", " \"UpBlock2D\",\n", + " \"UpBlock2D\",\n", " ),\n", + " num_class_embeds=10,\n", + " norm_num_groups=1,\n", ")" ] }, @@ -62,43 +66,63 @@ "data": { "text/plain": [ "UNet2DModel(\n", - " (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv_in): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_proj): Timesteps()\n", " (time_embedding): TimestepEmbedding(\n", - " (linear_1): Linear(in_features=32, out_features=128, bias=True)\n", + " (linear_1): Linear(in_features=8, out_features=32, bias=True)\n", " (act): SiLU()\n", - " (linear_2): Linear(in_features=128, out_features=128, bias=True)\n", + " (linear_2): Linear(in_features=32, out_features=32, bias=True)\n", " )\n", + " (class_embedding): Embedding(10, 32)\n", " (down_blocks): ModuleList(\n", " (0): DownBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)\n", - " (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 8, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=8, bias=True)\n", + " (norm2): GroupNorm(1, 8, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): 
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", - " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (1): DownBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", - " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 8, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=16, bias=True)\n", + " (norm2): GroupNorm(1, 16, eps=1e-05, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nonlinearity): SiLU()\n", + " (conv_shortcut): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (downsamplers): ModuleList(\n", + " (0): Downsample2D(\n", + " (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (2): DownBlock2D(\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock2D(\n", + " (norm1): GroupNorm(1, 16, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (norm2): GroupNorm(1, 32, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " 
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_shortcut): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " )\n", @@ -107,53 +131,82 @@ " (0): UpBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", - " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 64, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (norm2): GroupNorm(1, 32, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", - " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 48, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (norm2): GroupNorm(1, 32, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_shortcut): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", - " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (1): UpBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)\n", - " (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 48, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=16, bias=True)\n", + " (norm2): GroupNorm(1, 16, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_shortcut): Conv2d(48, 16, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)\n", - " (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 24, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(24, 16, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=16, bias=True)\n", + " (norm2): GroupNorm(1, 16, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_shortcut): Conv2d(24, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (upsamplers): ModuleList(\n", + " (0): Upsample2D(\n", + " (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (2): UpBlock2D(\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock2D(\n", + " (norm1): GroupNorm(1, 24, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=8, bias=True)\n", + " (norm2): GroupNorm(1, 8, eps=1e-05, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nonlinearity): SiLU()\n", + " (conv_shortcut): Conv2d(24, 8, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (1): ResnetBlock2D(\n", + " (norm1): GroupNorm(1, 16, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=8, bias=True)\n", + " (norm2): GroupNorm(1, 8, eps=1e-05, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nonlinearity): SiLU()\n", + " (conv_shortcut): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " )\n", @@ -161,31 +214,31 @@ " (mid_block): UNetMidBlock2D(\n", " 
(attentions): ModuleList(\n", " (0): Attention(\n", - " (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)\n", - " (to_q): Linear(in_features=64, out_features=64, bias=True)\n", - " (to_k): Linear(in_features=64, out_features=64, bias=True)\n", - " (to_v): Linear(in_features=64, out_features=64, bias=True)\n", + " (group_norm): GroupNorm(1, 32, eps=1e-05, affine=True)\n", + " (to_q): Linear(in_features=32, out_features=32, bias=True)\n", + " (to_k): Linear(in_features=32, out_features=32, bias=True)\n", + " (to_v): Linear(in_features=32, out_features=32, bias=True)\n", " (to_out): ModuleList(\n", - " (0): Linear(in_features=64, out_features=64, bias=True)\n", + " (0): Linear(in_features=32, out_features=32, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", - " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", - " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", + " (norm1): GroupNorm(1, 32, eps=1e-05, affine=True)\n", + " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (time_emb_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (norm2): GroupNorm(1, 32, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", - " (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)\n", + " (conv_norm_out): GroupNorm(1, 8, eps=1e-05, affine=True)\n", " (conv_act): SiLU()\n", - " (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (conv_out): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1))\n", ")" ] }, @@ -200,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -212,152 +265,27 @@ "train_dataset = datasets.MNIST(root='mnist_data', train=True, download=True, transform=transform)\n", "test_dataset = datasets.MNIST(root='mnist_data', train=False, download=True, transform=transform)\n", "\n", - "# train_noisy_dataset = MNISTDataset(train_dataset)\n", - "# test_noisy_dataset = MNISTDataset(test_dataset)\n", - "\n", "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", - "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 1, 28, 28])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample_image = train_dataset[0][0].unsqueeze(0)\n", - "sample_image.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "output = model(sample_image, timestep=0).sample" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", + "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n", "\n", - "plt.imshow(output.squeeze().detach().numpy(), cmap='gray')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "from diffusers import DDPMScheduler\n", "\n", - "noise_scheduler = DDPMScheduler(num_train_timesteps=150)" + "# " ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "noise = torch.randn(sample_image.shape)\n", - "timesteps = torch.LongTensor([20])\n", - "noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n", - "\n", - "output = model(noisy_image, timestep=50).sample" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))\n", - "\n", - "ax1.imshow(sample_image.squeeze().detach().numpy(), cmap='gray')\n", - "ax1.set_title('Original Image')\n", - "\n", - "ax2.imshow(noisy_image.squeeze().detach().numpy(), cmap='gray')\n", - "ax2.set_title('Noisy Image')\n", - "\n", - "ax3.imshow(output.squeeze().detach().numpy(), cmap='gray')\n", - "ax3.set_title('Denoised Image')\n", + "from diffusers import DDPMScheduler\n", "\n", - "plt.show()" + "noise_scheduler = DDPMScheduler(num_train_timesteps=150) " ] }, { "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(1.0710, grad_fn=)\n" - ] - } - ], - "source": [ - "noise_pred = model(noisy_image, timesteps).sample\n", - "loss = F.mse_loss(noise_pred, noise)\n", - "print(loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -366,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -392,30 +320,72 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "0it [00:00, ?it/s]/usr/local/lib/python3.12/site-packages/diffusers/configuration_utils.py:140: FutureWarning: Accessing config attribute `num_train_timesteps` directly via 'DDPMScheduler' object attribute is deprecated. Please access 'num_train_timesteps' over 'DDPMScheduler's config object instead, e.g. 'scheduler.config.num_train_timesteps'.\n", + "0it [00:00, ?it/s]/usr/local/lib/python3.10/dist-packages/diffusers/configuration_utils.py:140: FutureWarning: Accessing config attribute `num_train_timesteps` directly via 'DDPMScheduler' object attribute is deprecated. 
Please access 'num_train_timesteps' over 'DDPMScheduler's config object instead, e.g. 'scheduler.config.num_train_timesteps'.\n", " deprecate(\"direct config name access\", \"1.0.0\", deprecation_message, standard_warn=False)\n", - "1875it [26:23, 1.18it/s]\n" + "1875it [04:51, 6.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 Loss: 0.059813931584358215\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1875it [04:47, 6.52it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 Loss: 0.047683872282505035\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1875it [04:52, 6.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 Loss: 0.034084219485521317\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1875it [05:02, 6.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 Loss: 0.042306363582611084\n" + "Epoch 3 Loss: 0.039008501917123795\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "569it [08:01, 1.18it/s]\n" + "57it [00:09, 6.02it/s]\n" ] }, { @@ -425,10 +395,20 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[24], line 14\u001b[0m\n\u001b[1;32m 12\u001b[0m noise_pred \u001b[38;5;241m=\u001b[39m model(noisy_images, timesteps, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(noise_pred, noise)\n\u001b[0;32m---> 14\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m 
optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 16\u001b[0m lr_scheduler\u001b[38;5;241m.\u001b[39mstep()\n", - "File \u001b[0;32m/usr/local/lib/python3.12/site-packages/torch/_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 524\u001b[0m )\n\u001b[0;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.12/site-packages/torch/autograd/__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a 
multi-line function\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.12/site-packages/torch/autograd/graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 745\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n", + "Cell \u001b[0;32mIn[8], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m timesteps \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m, noise_scheduler\u001b[38;5;241m.\u001b[39mnum_train_timesteps, (bs,), device\u001b[38;5;241m=\u001b[39mclean_images\u001b[38;5;241m.\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mlong()\n\u001b[1;32m 11\u001b[0m noisy_images \u001b[38;5;241m=\u001b[39m noise_scheduler\u001b[38;5;241m.\u001b[39madd_noise(clean_images, noise, timesteps)\n\u001b[0;32m---> 13\u001b[0m noise_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoisy_images\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 14\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(noise_pred, noise)\n\u001b[1;32m 15\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/diffusers/models/unets/unet_2d.py:329\u001b[0m, in \u001b[0;36mUNet2DModel.forward\u001b[0;34m(self, sample, timestep, class_labels, return_dict)\u001b[0m\n\u001b[1;32m 327\u001b[0m sample, skip_sample \u001b[38;5;241m=\u001b[39m upsample_block(sample, res_samples, emb, skip_sample)\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 329\u001b[0m sample \u001b[38;5;241m=\u001b[39m \u001b[43mupsample_block\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mres_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43memb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[38;5;66;03m# 6. 
post-process\u001b[39;00m\n\u001b[1;32m 332\u001b[0m sample \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv_norm_out(sample)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/diffusers/models/unets/unet_2d_blocks.py:2673\u001b[0m, in \u001b[0;36mUpBlock2D.forward\u001b[0;34m(self, hidden_states, res_hidden_states_tuple, temb, upsample_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2669\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mcheckpoint\u001b[38;5;241m.\u001b[39mcheckpoint(\n\u001b[1;32m 2670\u001b[0m create_custom_forward(resnet), hidden_states, temb\n\u001b[1;32m 2671\u001b[0m )\n\u001b[1;32m 2672\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2673\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mresnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2675\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupsamplers \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2676\u001b[0m 
\u001b[38;5;28;01mfor\u001b[39;00m upsampler \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupsamplers:\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/diffusers/models/resnet.py:345\u001b[0m, in \u001b[0;36mResnetBlock2D.forward\u001b[0;34m(self, input_tensor, temb, *args, **kwargs)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_emb_proj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 344\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mskip_time_act:\n\u001b[0;32m--> 345\u001b[0m temb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnonlinearity\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 346\u001b[0m temb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_emb_proj(temb)[:, :, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m]\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_embedding_norm \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdefault\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/activation.py:396\u001b[0m, in \u001b[0;36mSiLU.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 396\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msilu\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:2102\u001b[0m, in \u001b[0;36msilu\u001b[0;34m(input, inplace)\u001b[0m\n\u001b[1;32m 2100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inplace:\n\u001b[1;32m 2101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_nn\u001b[38;5;241m.\u001b[39msilu_(\u001b[38;5;28minput\u001b[39m)\n\u001b[0;32m-> 2102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msilu\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } @@ -441,11 +421,12 @@ " for i, (clean_images, labels) in tqdm(enumerate(train_loader)):\n", " noise = torch.randn(clean_images.shape)\n", " bs = clean_images.shape[0]\n", + " labels = labels\n", "\n", " timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()\n", " noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n", "\n", - " noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n", + " noise_pred = model(noisy_images, timesteps, labels, return_dict=False)[0]\n", " loss = F.mse_loss(noise_pred, noise)\n", " loss.backward()\n", " optimizer.step()\n", @@ -458,60 +439,89 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ - "batch = next(iter(train_loader))" + "import numpy as np" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch._C.Generator" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from diffusers import DDPMPipeline" + "type(torch.manual_seed(42))" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler)\n" + "def inference(model: UNet2DModel,\n", + " scheduler: DDPMScheduler,\n", + " 
batch_size: int,\n", + " generator: torch._C.Generator,\n", + " num_inference_steps: int,\n", + " label: int) -> np.ndarray:\n", + " \n", + " image_shape = (batch_size, 1, 28, 28)\n", + " labels = torch.full((batch_size,), label)\n", + "\n", + " image = torch.randn(image_shape)\n", + "\n", + " # set step values\n", + " scheduler.set_timesteps(num_inference_steps)\n", + "\n", + " for t in scheduler.timesteps:\n", + " # 1. predict noise model_output\n", + " model_output = model(image, t, labels).sample\n", + "\n", + " # 2. compute previous image: x_t -> x_t-1\n", + " image = scheduler.step(model_output, t, image, generator=generator).prev_sample\n", + "\n", + " image = (image / 2 + 0.5).clamp(0, 1)\n", + " image = image.permute(0, 2, 3, 1)\n", + "\n", + " return image.detach().numpy()" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 14, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 150/150 [00:05<00:00, 28.16it/s]\n" - ] - } - ], + "outputs": [], "source": [ - "images = pipeline(\n", - " batch_size=10,\n", - " generator=torch.manual_seed(42),\n", - " num_inference_steps=150,\n", - ").images" + "images = inference(model=model,\n", + " scheduler=noise_scheduler,\n", + " batch_size=10,\n", + " generator=torch.manual_seed(42),\n", + " num_inference_steps=150,\n", + " label=3)" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -522,6 +532,7 @@ ], "source": [ "# show images\n", + "import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(2, 5, figsize=(20, 8))\n", "for i in range(10):\n", " ax[i // 5, i % 5].imshow(images[i], cmap='gray')\n", @@ -552,7 +563,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/diffusion-models/requirements.txt b/diffusion-models/requirements.txt index cdc2b5a..918624e 100644 --- a/diffusion-models/requirements.txt +++ b/diffusion-models/requirements.txt @@ -1,4 +1,6 @@ torch==2.3.0 torchvision==0.18.0 matplotlib==3.9.0 -seaborn==0.13.2 \ No newline at end of file +seaborn==0.13.2 +diffusers==0.28.2 +transformers==4.41.2 \ No newline at end of file diff --git a/docs/multi_2/subpage_2.md b/docs/hf_diffusers/hf_ecg_gen.md similarity index 100% rename from docs/multi_2/subpage_2.md rename to docs/hf_diffusers/hf_ecg_gen.md diff --git a/docs/hf_diffusers/hf_mnist_gen.md b/docs/hf_diffusers/hf_mnist_gen.md new file mode 100644 index 0000000..e8f08a3 --- /dev/null +++ b/docs/hf_diffusers/hf_mnist_gen.md @@ -0,0 +1,49 @@ +In this section we will introduce a core component of the Diffusers library `UNet2DModel`. This page will closely follow the notebook `hf_diffusers_mnist.ipynb`, so if you're interested in a quick run through and want to play around yourself, then check that out. Here, we will go through the code in more detail and explain the different components of the model, but feel free to refer back to this page. + + + +## `UNet2DModel` +The `UNet2DModel` is a 2D U-Net model that is used for image generation. It abstracts out all of the annoying details of building the model, and allows you to focus on the important parts of your project. 
+ + +``` python +model = UNet2DModel( + sample_size=28, # (1) + in_channels=1, # (2) + out_channels=1, # (3) + layers_per_block=1, # (4) + block_out_channels=(8, 16, 32), # (5) + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + num_class_embeds=10, # (6) + norm_num_groups=1, # (7) +) +``` + +1. The size of the input image. In this case, it is 28x28 pixels +2. The number of channels in the input image. In this case, it is 1, as the images are grayscale, but you would use 3 for RGB images +3. The number of channels in the output image. In this case, it is 1, as the images are grayscale, but you would use 3 for RGB images +4. The number of layers in each block +5. The number of output channels in each block - in a convolutional neural network, the number of channels is the number of filters that are applied to the input image. +6. The number of classes your dataset has if you are doing conditional generation. +7. The number of groups to use for the normalization layer. This is a hyperparameter that you can tune to improve the performance of your model. + +There are a number of additional parameters, but these are the most important ones to understand when you are getting started with the `UNet2DModel`. Let's look at the block types in more detail, and some of the additional parameters. + +## Further reading +
+ +- :fontawesome-solid-book-open:{ .lg .middle } [__CI/CD - Pre-commit resources__](../resources/references.md#pre-commit) + + --- + Information on GitHub Actions, Black, Flake8, Mypy, Isort, and Git Hooks + +
\ No newline at end of file diff --git a/docs/hf_diffusers/index.md b/docs/hf_diffusers/index.md new file mode 100644 index 0000000..8acff67 --- /dev/null +++ b/docs/hf_diffusers/index.md @@ -0,0 +1,4 @@ +# Hugging Face Diffusers +In this section, we make life easy for ourselves and use the Hugging Face Diffusers Library. We will look at two main tasks: image generation and ECG generation. We will start with image generation, and in doing so, we'll walk through most of the building blocks required to understand the second task. + +Let's get started! \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index f8da25f..4cb81ab 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,9 +2,9 @@ Logo -# COURSE TITLE +# Introduction to diffusion models in generative AI -Welcome to the material for COURSE NAME. +Welcome to the material for Introduction to diffusion models in generative AI. Please check the [official Mkdocs Material documentation](https://squidfunk.github.io/mkdocs-material/) for more information on how to use this template. diff --git a/docs/multi_2/index.md b/docs/multi_2/index.md deleted file mode 100644 index 65f7a9a..0000000 --- a/docs/multi_2/index.md +++ /dev/null @@ -1,2 +0,0 @@ -# Main landing page -You can see the links to subpages on the left. \ No newline at end of file diff --git a/docs/multi_2/subpage_1.md b/docs/multi_2/subpage_1.md deleted file mode 100644 index 606dd7d..0000000 --- a/docs/multi_2/subpage_1.md +++ /dev/null @@ -1,16 +0,0 @@ -This is another way of doing subpages. Reference the yml file. - -## Information -Here is some course content. - - - -## Further reading -
- -- :fontawesome-solid-book-open:{ .lg .middle } [__CI/CD - Pre-commit resources__](../resources/references.md#pre-commit) - - --- - Information on GitHub Actions, Black, Flake8, Mypy, Isort, and Git Hooks - -
\ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 0abbd81..3ff190f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,6 @@ -site_name: SITE NAME -repo_url: https://github.com/acceleratescience/REPO-TITLE -site_url: https://acceleratescience.github.io/REPO-TITLE/ +site_name: Introduction to diffusion models in generative AI +repo_url: https://github.com/acceleratescience/diffusion-models +site_url: https://acceleratescience.github.io/diffusion-models/ nav: - Home: - Welcome!: index.md @@ -10,10 +10,10 @@ nav: - Multi Page 1: - Subpage 1: multi_1/subpage_1.md - Subpage 2: multi_1/subpage_2.md - - Multi Page 2: - - multi_2/index.md - - Page 1: multi_2/subpage_1.md - - Page 2: multi_2/subpage_2.md + - Hugging Face Diffusers: + - hf_diffusers/index.md + - Image Generation: hf_diffusers/hf_mnist_gen.md + - ECG Generation: hf_diffusers/hf_ecg_gen.md - Resources: - resources/index.md - Slides: resources/slides.md