Fix potential bug risks and fix anti-patterns #32

Open · wants to merge 6 commits into master
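Several of the commits replace mutable default arguments (drange=[-1,1], seeds=[22], pert_order=[0, 1, 2]) with a None sentinel. The defaults here are only read, so the change is defensive rather than a behavior fix, but the pitfall it guards against is real: a default list is created once, at function definition time, and shared across calls. A minimal sketch of the problem and the fix follows; the function and names are illustrative, not taken from this diff.

def append_item(item, bucket=[]):  # anti-pattern: the default list is created once, at definition time
    bucket.append(item)
    return bucket

append_item(1)  # [1]
append_item(2)  # [1, 2] -- state leaks between calls

def append_item_fixed(item, bucket=None):  # fix: None sentinel, fresh list per call
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

append_item_fixed(1)  # [1]
append_item_fixed(2)  # [2]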
22 changes: 10 additions & 12 deletions dnnlib/submission/submit.py
@@ -129,10 +129,9 @@ def get_path_from_template(path_template: str, path_type: PathType = PathType.AU
     # return correctly formatted path
     if path_type == PathType.WINDOWS:
         return str(pathlib.PureWindowsPath(path_template))
-    elif path_type == PathType.LINUX:
+    if path_type == PathType.LINUX:
         return str(pathlib.PurePosixPath(path_template))
-    else:
-        raise RuntimeError("Unknown platform")
+    raise RuntimeError("Unknown platform")


 def get_template_from_path(path: str) -> str:
@@ -158,9 +157,9 @@ def get_user_name():
     """Get the current user name."""
     if _user_name_override is not None:
         return _user_name_override
-    elif platform.system() == "Windows":
+    if platform.system() == "Windows":
         return os.getlogin()
-    elif platform.system() == "Linux":
+    if platform.system() == "Linux":
         try:
             import pwd
             return pwd.getpwuid(os.geteuid()).pw_name
@@ -283,15 +282,14 @@ def run_wrapper(submit_config: SubmitConfig) -> None:
     except:
         if is_local:
             raise
-        else:
-            traceback.print_exc()
+        traceback.print_exc()

-            log_src = os.path.join(submit_config.run_dir, "log.txt")
-            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
-            shutil.copyfile(log_src, log_dst)
+        log_src = os.path.join(submit_config.run_dir, "log.txt")
+        log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
+        shutil.copyfile(log_src, log_dst)

-            # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
-            exit_with_errcode = True
+        # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
+        exit_with_errcode = True
     finally:
         open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

2 changes: 1 addition & 1 deletion dnnlib/tflib/custom_ops.py
@@ -84,7 +84,7 @@ def _prepare_nvcc_cli(opts):
 #----------------------------------------------------------------------------
 # Main entry point.

-_plugin_cache = dict()
+_plugin_cache = {}

 def get_plugin(cuda_file):
     cuda_file_base = os.path.basename(cuda_file)
6 changes: 3 additions & 3 deletions dnnlib/tflib/network.py
@@ -23,7 +23,7 @@
 from .tfutil import TfExpression, TfExpressionEx

 _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
-_import_module_src = dict() # Source code for temporary modules created during pickle import.
+_import_module_src = {} # Source code for temporary modules created during pickle import.


 def import_handler(handler_func):
@@ -120,7 +120,7 @@ def _init_fields(self) -> None:
         self._build_func = None # User-supplied build function that constructs the network.
         self._build_func_name = None # Name of the build function.
         self._build_module_src = None # Full source code of the module containing the build function.
-        self._run_cache = dict() # Cached graph data for Network.run().
+        self._run_cache = {} # Cached graph data for Network.run().

     def _init_graph(self) -> None:
         # Collect inputs.
@@ -254,7 +254,7 @@ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[

     def __getstate__(self) -> dict:
         """Pickle export."""
-        state = dict()
+        state = {}
         state["version"] = 4
         state["name"] = self.name
         state["static_kwargs"] = dict(self.static_kwargs)
10 changes: 7 additions & 3 deletions dnnlib/tflib/tfutil.py
@@ -83,7 +83,7 @@ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:

 def _sanitize_tf_config(config_dict: dict = None) -> dict:
     # Defaults.
-    cfg = dict()
+    cfg = {}
     cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
     cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
     cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
@@ -227,20 +227,24 @@ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwar
     return var


-def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
+def convert_images_from_uint8(images, drange=None, nhwc_to_nchw=False):
     """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
     Can be used as an input transformation for Network.run().
     """
+    if drange is None:
+        drange = [-1,1]
     images = tf.cast(images, tf.float32)
     if nhwc_to_nchw:
         images = tf.transpose(images, [0, 3, 1, 2])
     return images * ((drange[1] - drange[0]) / 255) + drange[0]


-def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
+def convert_images_to_uint8(images, drange=None, nchw_to_nhwc=False, shrink=1):
     """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
     Can be used as an output transformation for Network.run().
     """
+    if drange is None:
+        drange = [-1,1]
     images = tf.cast(images, tf.float32)
     if shrink > 1:
         ksize = [1, 1, shrink, shrink]
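For context, these converters are normally supplied as input/output transforms to Network.run(), where callers usually leave drange at its default, so the None-sentinel version preserves existing behavior. A usage sketch, assuming a generator Gs and a latent batch z have already been set up as in run_generator.py:

# Convert the generator's float32 NCHW output to uint8 NHWC images in one call.
images = Gs.run(z, None, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))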
10 changes: 4 additions & 6 deletions module/cnf.py
@@ -25,10 +25,9 @@ def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_
                 # print(x.shape)
                 x = self.chain[i](x, context, logpx, integration_times, reverse)
             return x
-        else:
-            for i in inds:
-                x, logpx = self.chain[i](x, context, logpx, integration_times, reverse)
-            return x, logpx
+        for i in inds:
+            x, logpx = self.chain[i](x, context, logpx, integration_times, reverse)
+        return x, logpx


 class CNF(nn.Module):
@@ -115,8 +114,7 @@ def forward(self, x, context=None, logpx=None, integration_times=None, reverse=F

         if logpx is not None:
             return z_t, logpz_t
-        else:
-            return z_t
+        return z_t

     def num_evals(self):
         return self.odefunc._num_evals.item()
9 changes: 3 additions & 6 deletions module/normalization.py
@@ -40,8 +40,7 @@ def reset_parameters(self):
     def forward(self, x, c=None, logpx=None, reverse=False):
         if reverse:
             return self._reverse(x, logpx)
-        else:
-            return self._forward(x, logpx)
+        return self._forward(x, logpx)

     def _forward(self, x, logpx=None):
         num_channels = x.size(-1)
@@ -87,8 +86,7 @@ def _forward(self, x, logpx=None):

         if logpx is None:
             return y
-        else:
-            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)
+        return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)

     def _reverse(self, y, logpy=None):
         used_mean = self.running_mean
@@ -105,8 +103,7 @@ def _reverse(self, y, logpy=None):

         if logpy is None:
             return x
-        else:
-            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)
+        return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)

     def _logdetgrad(self, x, used_var):
         logdetgrad = -0.5 * torch.log(used_var + self.eps)
5 changes: 2 additions & 3 deletions module/odefunc.py
@@ -134,9 +134,8 @@ def forward(self, t, states):
             divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1)

             return dy, -divergence, torch.zeros_like(c).requires_grad_(True)
-        elif len(states) == 2:  # unconditional CNF
+        if len(states) == 2:  # unconditional CNF
             dy = self.diffeq(t, y)
             divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1)
             return dy, -divergence
-        else:
-            assert 0, "`len(states)` should be 2 or 3"
+        assert 0, "`len(states)` should be 2 or 3"
4 changes: 3 additions & 1 deletion module/utils.py
@@ -106,7 +106,9 @@ def set_random_seed(seed):


 # Visualization
-def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]):
+def visualize_point_clouds(pts, gtr, idx, pert_order=None):
+    if pert_order is None:
+        pert_order = [0, 1, 2]
     pts = pts.cpu().detach().numpy()[:, pert_order]
     gtr = gtr.cpu().detach().numpy()[:, pert_order]

2 changes: 1 addition & 1 deletion pretrained_networks.py
@@ -59,7 +59,7 @@ def get_path_or_url(path_or_gdrive_path):

 #----------------------------------------------------------------------------

-_cached_networks = dict()
+_cached_networks = {}

 def load_networks(path_or_gdrive_path):
     path_or_url = get_path_or_url(path_or_gdrive_path)
2 changes: 1 addition & 1 deletion run_generator.py
@@ -52,7 +52,7 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_
     all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]
     all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]
     all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component]
-    w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component]
+    w_dict = dict(zip(all_seeds, list(all_w))) # [layer, component]

     print('Generating images...')
     all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel]
5 changes: 3 additions & 2 deletions utils.py
@@ -10,7 +10,6 @@

 import numpy as np
 import PIL.Image
-import dnnlib
 import dnnlib.tflib as tflib
 import os
 import re
@@ -142,7 +141,9 @@ def get_style_loss(base_style, gram_target):

 #----------------------------------------------------------------------------

-def generate_im_official(network_pkl='gdrive:networks/stylegan2-ffhq-config-f.pkl', seeds=[22], truncation_psi=0.5):
+def generate_im_official(network_pkl='gdrive:networks/stylegan2-ffhq-config-f.pkl', seeds=None, truncation_psi=0.5):
+    if seeds is None:
+        seeds = [22]
     print('Loading networks from "%s"...' % network_pkl)
     _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
     noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]