diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index fe6ccfd8a8..d1d52779bc 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -1,10 +1,8 @@ # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py -from torch import Tensor, nn +import torch.nn as nn +from torch import Tensor -from pl_bolts.utils.stability import under_review - -@under_review() class DCGANGenerator(nn.Module): def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: """ @@ -50,7 +48,6 @@ def forward(self, noise: Tensor) -> Tensor: return self.gen(noise) -@under_review() class DCGANDiscriminator(nn.Module): def __init__(self, feature_maps: int, image_channels: int) -> None: """ diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index 6b80eb95c8..92389118cd 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -2,14 +2,14 @@ from typing import Any import torch +import torch.nn as nn from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor, nn +from torch import Tensor from torch.utils.data import DataLoader from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -19,7 +19,6 @@ warn_missing_pkg("torchvision") -@under_review() class DCGAN(LightningModule): """DCGAN implementation. @@ -28,7 +27,7 @@ class DCGAN(LightningModule): from pl_bolts.models.gans import DCGAN m = DCGAN() - Trainer(gpus=2).fit(m) + Trainer(accelerator="gpu", devices=2).fit(m) Example CLI:: @@ -80,10 +79,10 @@ def _get_discriminator(self) -> nn.Module: def _weights_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: - torch.nn.init.normal_(m.weight, 0.0, 0.02) + nn.init.normal_(m.weight, 0.0, 0.02) elif classname.find("BatchNorm") != -1: - torch.nn.init.normal_(m.weight, 1.0, 0.02) - torch.nn.init.zeros_(m.bias) + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) def configure_optimizers(self): lr = self.hparams.learning_rate @@ -173,7 +172,6 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@under_review() def cli_main(args=None): seed_everything(1234)