From 142ef67fa656e09c90732bb05e99ab6b4f17d5f7 Mon Sep 17 00:00:00 2001
From: Ishan Dutta
Date: Sun, 2 Jul 2023 22:45:47 +0530
Subject: [PATCH] Add support for CIFAR10 Dataset in the DCGAN Module

---
 src/pl_bolts/models/gans/dcgan/dcgan_module.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/pl_bolts/models/gans/dcgan/dcgan_module.py b/src/pl_bolts/models/gans/dcgan/dcgan_module.py
index 01202a25b..8abf043ae 100644
--- a/src/pl_bolts/models/gans/dcgan/dcgan_module.py
+++ b/src/pl_bolts/models/gans/dcgan/dcgan_module.py
@@ -14,7 +14,7 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
-    from torchvision.datasets import LSUN, MNIST
+    from torchvision.datasets import LSUN, MNIST, CIFAR10
 else:  # pragma: no cover
     warn_missing_pkg("torchvision")
 
@@ -35,8 +35,7 @@ class DCGAN(LightningModule):
         python dcgan_module.py --gpus 1
 
         # cifar10
-        python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
-
+        python dcgan_module.py --gpus 1 --dataset cifar10
     """
 
     def __init__(
@@ -174,7 +173,7 @@ def cli_main(args=None):
 
     parser = ArgumentParser()
    parser.add_argument("--batch_size", default=64, type=int)
-    parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"])
+    parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist", "cifar10"])
     parser.add_argument("--data_dir", default="./", type=str)
     parser.add_argument("--image_size", default=64, type=int)
     parser.add_argument("--num_workers", default=8, type=int)
@@ -202,6 +201,16 @@ def cli_main(args=None):
         )
         dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
         image_channels = 1
+    elif script_args.dataset == "cifar10":
+        transforms = transform_lib.Compose(
+            [
+                transform_lib.Resize(script_args.image_size),
+                transform_lib.ToTensor(),
+                transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        )
+        dataset = CIFAR10(root=script_args.data_dir, download=True, transform=transforms)
+        image_channels = 3
 
     dataloader = DataLoader(
         dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers
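
Note (not part of the patch): a minimal standalone sketch of the pipeline the new cifar10 branch
builds, useful as a quick sanity check. The values below simply mirror the cli_main defaults
visible in the hunks above (data_dir "./", batch_size 64, image_size 64, num_workers 8); the
shape assertion is an addition for illustration, not code from the module.

    # Sketch only: rebuilds the CIFAR10 transform/DataLoader that the new elif branch
    # constructs and checks that batches have the 3-channel shape the DCGAN expects.
    from torch.utils.data import DataLoader
    from torchvision import transforms as transform_lib
    from torchvision.datasets import CIFAR10

    image_size = 64  # mirrors the --image_size default

    transforms = transform_lib.Compose(
        [
            transform_lib.Resize(image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = CIFAR10(root="./", download=True, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

    imgs, _ = next(iter(dataloader))
    assert imgs.shape == (64, 3, image_size, image_size)  # image_channels == 3 for CIFAR10

Because the branch sets image_channels = 3 itself, the channel count is inferred from the
dataset choice, which is presumably why the docstring example drops the explicit
--image_channels 3 flag.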