Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Kerfoot <[email protected]>
  • Loading branch information
ericspod committed Sep 6, 2024
1 parent f384232 commit 6ec6692
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
6 changes: 4 additions & 2 deletions generation/2d_vqgan/2d_vqgan_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@
}
],
"source": [
"batch_size = 16\n",
"\n",
"train_transforms = mt.Compose(\n",
" [\n",
" mt.LoadImaged(keys=[\"image\"]),\n",
Expand All @@ -239,7 +241,7 @@
" ]\n",
")\n",
"train_ds = CacheDataset(data=train_datalist, transform=train_transforms)\n",
"train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4, persistent_workers=True)"
"train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down Expand Up @@ -276,7 +278,7 @@
" ]\n",
")\n",
"val_ds = CacheDataset(data=val_datalist, transform=val_transforms)\n",
"val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4, persistent_workers=True)"
"val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down
6 changes: 4 additions & 2 deletions generation/2d_vqvae/2d_vqvae_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@
"train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n",
"train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n",
"image_size = 64\n",
"batch_size = 16\n",
"\n",
"train_transforms = mt.Compose(\n",
" [\n",
" mt.LoadImaged(keys=[\"image\"]),\n",
Expand All @@ -238,7 +240,7 @@
" ]\n",
")\n",
"train_ds = Dataset(data=train_datalist, transform=train_transforms)\n",
"train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)"
"train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down Expand Up @@ -317,7 +319,7 @@
" ]\n",
")\n",
"val_ds = Dataset(data=val_datalist, transform=val_transforms)\n",
"val_loader = DataLoader(val_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)"
"val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@
"train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n",
"train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n",
"image_size = 64\n",
"batch_size = 16\n",
"train_transforms = transforms.Compose(\n",
" [\n",
" transforms.LoadImaged(keys=[\"image\"]),\n",
Expand All @@ -243,7 +244,7 @@
" ]\n",
")\n",
"train_ds = Dataset(data=train_datalist, transform=train_transforms)\n",
"train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True)"
"train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down Expand Up @@ -322,7 +323,7 @@
" ]\n",
")\n",
"val_ds = Dataset(data=val_datalist, transform=val_transforms)\n",
"val_loader = DataLoader(val_ds, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True)"
"val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)"
]
},
{
Expand Down

0 comments on commit 6ec6692

Please sign in to comment.