From 46873cf45c28453d31d6f9a9e829b2d4bdc7b805 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 25 Nov 2020 12:49:07 -0800 Subject: [PATCH] add truncated normals, default truncation at thresholds -1.5 and 1.5 --- lightweight_gan/lightweight_gan.py | 8 +++++++- lightweight_gan/version.py | 2 +- setup.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 8ab6f0b..de4444a 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -34,6 +34,8 @@ from adabelief_pytorch import AdaBelief from gsa_pytorch import GSA +from scipy.stats import truncnorm + # asserts assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' @@ -110,6 +112,10 @@ def slerp(val, low, high): res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res +def truncated_normal(size, threshold = 1.5): + values = truncnorm.rvs(-threshold, threshold, size = size) + return torch.from_numpy(values) + # helper classes class NanException(Exception): @@ -1084,7 +1090,7 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0): # latents and noise - latents = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank) + latents = truncated_normal((num_rows ** 2, latent_dim)).float().cuda(self.rank) # regular diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index f8d9095..92a60bd 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.12.1' +__version__ = '0.12.2' diff --git a/setup.py b/setup.py index 823a127..394cb67 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ 'pillow', 'pytorch-fid', 'retry', + 'scipy', 'torch>=1.6', 'torchvision', 'tqdm'