diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 54fc17e..fb4f7a8 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -348,14 +348,17 @@ def __init__( width ): super().__init__() + freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w]) self.register_buffer('dct_weights', dct_weights) + chan_intermediate = max(3, chan_out // reduction) + self.net = nn.Sequential( - nn.Conv2d(chan_in, chan_out // reduction, 1), + nn.Conv2d(chan_in, chan_intermediate, 1), nn.LeakyReLU(0.1), - nn.Conv2d(chan_out // reduction, chan_out, 1), + nn.Conv2d(chan_intermediate, chan_out, 1), nn.Sigmoid() ) diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index cbe8c74..84d3488 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.17.1' +__version__ = '0.17.2'