From fd334383c978c48c6f5904f109143cd70347afd1 Mon Sep 17 00:00:00 2001 From: YanniZhangYZ Date: Tue, 3 Oct 2023 17:33:32 +0200 Subject: [PATCH] Add scripts for visualizing eg3d generator architecture. --- eg3d/visualize_pkl.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 eg3d/visualize_pkl.py diff --git a/eg3d/visualize_pkl.py b/eg3d/visualize_pkl.py new file mode 100644 index 00000000..cc03a973 --- /dev/null +++ b/eg3d/visualize_pkl.py @@ -0,0 +1,37 @@ +import pickle +import torch +import io +import legacy +import dnnlib +import json + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + + +device = torch.device('cpu') +with dnnlib.util.open_url("ffhqrebalanced512-128.pkl") as f: + + data = legacy.load_network_pkl(f) + G = data['G'].to(device)# type: ignore + + G_em = data['G_ema'].to(device) # type: ignore + D = data['D'].to(device) + # train_set = data['training_set_kwargs'].to(device) # type: ignore + # augment_pipe = data['augment_pipe'].to(device) + + print(G) + print("--------------------------------------------") + print(G_em) + print("--------------------------------------------") + # print(D)