diff --git a/src/pl_bolts/models/gans/dcgan/dcgan_module.py b/src/pl_bolts/models/gans/dcgan/dcgan_module.py index 01202a25b..d26dff0db 100644 --- a/src/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/src/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -14,7 +14,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - from torchvision.datasets import LSUN, MNIST + from torchvision.datasets import CIFAR10, LSUN, MNIST else: # pragma: no cover warn_missing_pkg("torchvision") @@ -35,7 +35,7 @@ class DCGAN(LightningModule): python dcgan_module.py --gpus 1 # cifar10 - python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3 + python dcgan_module.py --gpus 1 --dataset cifar10 """ @@ -174,7 +174,7 @@ def cli_main(args=None): parser = ArgumentParser() parser.add_argument("--batch_size", default=64, type=int) - parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"]) + parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist", "cifar10"]) parser.add_argument("--data_dir", default="./", type=str) parser.add_argument("--image_size", default=64, type=int) parser.add_argument("--num_workers", default=8, type=int) @@ -202,6 +202,16 @@ def cli_main(args=None): ) dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms) image_channels = 1 + elif script_args.dataset == "cifar10": + transforms = transform_lib.Compose( + [ + transform_lib.Resize(script_args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) + dataset = CIFAR10(root=script_args.data_dir, download=True, transform=transforms) + image_channels = 3 dataloader = DataLoader( dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers