Skip to content

Commit

Permalink
minor dcgan-import fix (#921)
Browse files Browse the repository at this point in the history
* minor dcgan-import fix
* fix under_review import
* update docstring

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Atharva-Phatak and pre-commit-ci[bot] authored Oct 31, 2022
1 parent e9c70c6 commit c91726f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
7 changes: 2 additions & 5 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
from torch import Tensor, nn
import torch.nn as nn
from torch import Tensor

from pl_bolts.utils.stability import under_review


@under_review()
class DCGANGenerator(nn.Module):
def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
"""
Expand Down Expand Up @@ -50,7 +48,6 @@ def forward(self, noise: Tensor) -> Tensor:
return self.gen(noise)


@under_review()
class DCGANDiscriminator(nn.Module):
def __init__(self, feature_maps: int, image_channels: int) -> None:
"""
Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from typing import Any

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor, nn
from torch import Tensor
from torch.utils.data import DataLoader

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -19,7 +19,6 @@
warn_missing_pkg("torchvision")


@under_review()
class DCGAN(LightningModule):
"""DCGAN implementation.
Expand All @@ -28,7 +27,7 @@ class DCGAN(LightningModule):
from pl_bolts.models.gans import DCGAN
m = DCGAN()
Trainer(gpus=2).fit(m)
Trainer(accelerator="gpu", devices=2).fit(m)
Example CLI::
Expand Down Expand Up @@ -80,10 +79,10 @@ def _get_discriminator(self) -> nn.Module:
def _weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)

def configure_optimizers(self):
lr = self.hparams.learning_rate
Expand Down Expand Up @@ -173,7 +172,6 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
return parser


@under_review()
def cli_main(args=None):
seed_everything(1234)

Expand Down

0 comments on commit c91726f

Please sign in to comment.