From 6ec669237022420ded725b2d54f812fcc61cf071 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 6 Sep 2024 11:50:58 +0100 Subject: [PATCH] Fix Signed-off-by: Eric Kerfoot --- generation/2d_vqgan/2d_vqgan_tutorial.ipynb | 6 ++++-- generation/2d_vqvae/2d_vqvae_tutorial.ipynb | 6 ++++-- .../2d_vqvae_transformer_tutorial.ipynb | 5 +++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/generation/2d_vqgan/2d_vqgan_tutorial.ipynb b/generation/2d_vqgan/2d_vqgan_tutorial.ipynb index 156d547ec..f95b0df61 100644 --- a/generation/2d_vqgan/2d_vqgan_tutorial.ipynb +++ b/generation/2d_vqgan/2d_vqgan_tutorial.ipynb @@ -222,6 +222,8 @@ } ], "source": [ + "batch_size = 16\n", + "\n", "train_transforms = mt.Compose(\n", " [\n", " mt.LoadImaged(keys=[\"image\"]),\n", @@ -239,7 +241,7 @@ " ]\n", ")\n", "train_ds = CacheDataset(data=train_datalist, transform=train_transforms)\n", - "train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4, persistent_workers=True)" + "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)" ] }, { @@ -276,7 +278,7 @@ " ]\n", ")\n", "val_ds = CacheDataset(data=val_datalist, transform=val_transforms)\n", - "val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4, persistent_workers=True)" + "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)" ] }, { diff --git a/generation/2d_vqvae/2d_vqvae_tutorial.ipynb b/generation/2d_vqvae/2d_vqvae_tutorial.ipynb index 100a490ac..b61dca0e7 100644 --- a/generation/2d_vqvae/2d_vqvae_tutorial.ipynb +++ b/generation/2d_vqvae/2d_vqvae_tutorial.ipynb @@ -221,6 +221,8 @@ "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n", "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n", "image_size = 64\n", + "batch_size = 16\n", + "\n", "train_transforms = mt.Compose(\n", " [\n", " mt.LoadImaged(keys=[\"image\"]),\n", @@ -238,7 +240,7 @@ " ]\n", ")\n", "train_ds = Dataset(data=train_datalist, transform=train_transforms)\n", - "train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)" + "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)" ] }, { @@ -317,7 +319,7 @@ " ]\n", ")\n", "val_ds = Dataset(data=val_datalist, transform=val_transforms)\n", - "val_loader = DataLoader(val_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)" + "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)" ] }, { diff --git a/generation/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb b/generation/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb index fee060561..4e43c4a13 100644 --- a/generation/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb +++ b/generation/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb @@ -226,6 +226,7 @@ "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n", "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n", "image_size = 64\n", + "batch_size = 16\n", "train_transforms = transforms.Compose(\n", " [\n", " transforms.LoadImaged(keys=[\"image\"]),\n", @@ -243,7 +244,7 @@ " ]\n", ")\n", "train_ds = Dataset(data=train_datalist, transform=train_transforms)\n", - "train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True)" + "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)" ] }, { @@ -322,7 +323,7 @@ " ]\n", ")\n", "val_ds = Dataset(data=val_datalist, transform=val_transforms)\n", - "val_loader = DataLoader(val_ds, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True)" + "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)" ] }, {