Skip to content

Commit

Permalink
add truncated normals, default truncation at thresholds -1.5 and 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 25, 2020
1 parent 16342fa commit 46873cf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from adabelief_pytorch import AdaBelief
from gsa_pytorch import GSA

from scipy.stats import truncnorm

# asserts

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
Expand Down Expand Up @@ -110,6 +112,10 @@ def slerp(val, low, high):
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res

def truncated_normal(size, threshold = 1.5):
values = truncnorm.rvs(-threshold, threshold, size = size)
return torch.from_numpy(values)

# helper classes

class NanException(Exception):
Expand Down Expand Up @@ -1084,7 +1090,7 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):

# latents and noise

latents = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
latents = truncated_normal((num_rows ** 2, latent_dim)).float().cuda(self.rank)

# regular

Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.12.1'
__version__ = '0.12.2'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'pillow',
'pytorch-fid',
'retry',
'scipy',
'torch>=1.6',
'torchvision',
'tqdm'
Expand Down

0 comments on commit 46873cf

Please sign in to comment.