Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor dcgan-import fix #921

Merged
merged 9 commits into from
Oct 31, 2022
7 changes: 2 additions & 5 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
from torch import Tensor, nn
import torch.nn as nn
from torch import Tensor

from pl_bolts.utils.stability import under_review


@under_review()
class DCGANGenerator(nn.Module):
def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
"""
Expand Down Expand Up @@ -50,7 +48,6 @@ def forward(self, noise: Tensor) -> Tensor:
return self.gen(noise)


@under_review()
class DCGANDiscriminator(nn.Module):
def __init__(self, feature_maps: int, image_channels: int) -> None:
"""
Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from typing import Any

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor, nn
from torch import Tensor
from torch.utils.data import DataLoader

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -19,7 +19,6 @@
warn_missing_pkg("torchvision")


@under_review()
class DCGAN(LightningModule):
"""DCGAN implementation.

Expand All @@ -28,7 +27,7 @@ class DCGAN(LightningModule):
from pl_bolts.models.gans import DCGAN

m = DCGAN()
Trainer(gpus=2).fit(m)
Trainer(accelerator="gpu", devices=2).fit(m)

Example CLI::

Expand Down Expand Up @@ -80,10 +79,10 @@ def _get_discriminator(self) -> nn.Module:
def _weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)

def configure_optimizers(self):
lr = self.hparams.learning_rate
Expand Down Expand Up @@ -173,7 +172,6 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
return parser


@under_review()
def cli_main(args=None):
seed_everything(1234)

Expand Down