Skip to content

Commit

Permalink
Undo model modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
shuklabhay committed Aug 6, 2024
1 parent 92f5bfc commit 92e6c44
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,5 @@ cython_debug/
# User directories
data/
.vscode/
.DS_Store
.DS_Store
DCGAN_generated_audio.wav
43 changes: 23 additions & 20 deletions src/dcgan_architecture.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
import torch
import torch.nn as nn

from utils.helpers import save_model

# Constants
# Sizes used for training and for the generator's latent input.
BATCH_SIZE = 16
LATENT_DIM = 100
N_EPOCHS = 10

# Dimensions of one sample: channels x frames x frequency bins.
N_CHANNELS = 2  # Left, right
N_FRAMES = 345
N_FREQ_BINS = 257

# Intervals derived from N_EPOCHS (half of it and all of it, respectively).
# NOTE(review): presumably counted in epochs by the training loop -- confirm
# against the caller, which is not visible here.
VALIDATION_INTERVAL = N_EPOCHS // 2
SAVE_INTERVAL = N_EPOCHS // 1


# Model Components
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.deconv_blocks = nn.Sequential(
nn.ConvTranspose2d(LATENT_DIM, 512, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(LATENT_DIM, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(512, 256, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(32, 16, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(16, 8, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(8),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(8, 4, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(4, N_CHANNELS, kernel_size=6, stride=2, padding=2),
nn.ConvTranspose2d(4, N_CHANNELS, kernel_size=4, stride=2, padding=1),
nn.Upsample(
size=(N_FRAMES, N_FREQ_BINS), mode="bilinear", align_corners=False
),
Expand All @@ -57,30 +60,30 @@ def __init__(self):
super(Discriminator, self).__init__()
self.conv_blocks = nn.Sequential(
nn.Upsample(size=(512, 512), mode="bilinear", align_corners=False),
nn.Conv2d(N_CHANNELS, 4, kernel_size=6, stride=2, padding=2),
nn.Conv2d(N_CHANNELS, 4, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(4, 8, kernel_size=6, stride=2, padding=2),
nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(8),
nn.Conv2d(8, 16, kernel_size=6, stride=2, padding=2),
nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(16),
nn.Conv2d(16, 32, kernel_size=6, stride=2, padding=2),
nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, kernel_size=6, stride=2, padding=2),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(64),
nn.Conv2d(64, 128, kernel_size=6, stride=2, padding=2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(128),
nn.Conv2d(128, 256, kernel_size=6, stride=2, padding=2),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(256),
nn.Conv2d(256, 512, kernel_size=6, stride=2, padding=2),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm2d(512),
nn.Conv2d(512, 1, kernel_size=6, stride=2, padding=2),
nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
nn.Sigmoid(),
)

Expand Down

0 comments on commit 92e6c44

Please sign in to comment.