Training error #11

Open
bycloudai opened this issue Jul 27, 2020 · 3 comments

bycloudai commented Jul 27, 2020

Hello again. I got this training error when running "train.py". How can I solve it?

(hiface) G:\HiFaceGAN\Face-Renovation-master>python train.py
train.py
dataset [TrainDataset] of size 7 was created
Network [HiFaceGANGenerator] was created. Total number of parameters: 128.0 million. To see the architecture, do print(network).
Network [MultiscaleDiscriminator] was created. Total number of parameters: 5.5 million. To see the architecture, do print(network).
create web directory ./checkpoints\exp1\web...
Traceback (most recent call last):
  File "train.py", line 93, in <module>
    main()
  File "train.py", line 52, in main
    trainer.run_generator_one_step(data_i)
  File "G:\HiFaceGAN\Face-Renovation-master\trainers\pix2pix_trainer.py", line 34, in run_generator_one_step
    g_losses, generated = self.pix2pix_model(data, mode='generator')
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\parallel\data_parallel.py", line 153, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "G:\HiFaceGAN\Face-Renovation-master\models\pix2pix_model.py", line 47, in forward
    g_loss, generated = self.compute_generator_loss(input_semantics, real_image)
  File "G:\HiFaceGAN\Face-Renovation-master\models\pix2pix_model.py", line 74, in compute_generator_loss
    fake_image = self.generate_fake(input_semantics)
  File "G:\HiFaceGAN\Face-Renovation-master\models\pix2pix_model.py", line 120, in generate_fake
    fake_image = self.netG(input_semantics)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "G:\HiFaceGAN\Face-Renovation-master\models\networks\generator.py", line 238, in forward
    x = self.head_0(x, xs[0])
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "G:\HiFaceGAN\Face-Renovation-master\models\networks\architecture.py", line 55, in forward
    dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "G:\HiFaceGAN\Face-Renovation-master\models\networks\normalization.py", line 100, in forward
    actv = self.mlp_shared(segmap)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\conv.py", line 353, in forward
    return self._conv_forward(input, self.weight)
  File "E:\Anaconda3\envs\hiface\lib\site-packages\torch\nn\modules\conv.py", line 350, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [128, 768, 3, 3], expected input[2, 1024, 4, 4] to have 768 channels, but got 1024 channels instead

I've done exactly what you described for degrade.py: input a 512x512 image and produce a paired image that is 512x1024.
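For reference, here's how I'm spot-checking one paired sample (a hypothetical check; the file path and which half is input vs. ground truth are assumptions on my part):

from PIL import Image

# Hypothetical spot check of a single sample produced by degrade.py.
# Assumes the two 512x512 halves sit side by side in one file; the
# path and the left/right order of the halves are guesses.
pair = Image.open('./training_t_full/example.jpg')
w, h = pair.size                          # expecting (1024, 512)
half_a = pair.crop((0, 0, w // 2, h))     # one 512x512 half
half_b = pair.crop((w // 2, 0, w, h))     # the other 512x512 half
print(pair.size, half_a.size, half_b.size)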

Here's the training config

import sys

class TrainOptions(object):
    dataroot = './training_t_full/'
    dataroot_assist = ''
    name = 'exp1'
    crop_size = 512

    gpu_ids = [0]  # set to [] for CPU-only training (not tested)
    gan_mode = 'ls'

    continue_train = False
    which_epoch = 'latest'

    D_steps_per_G = 1
    aspect_ratio = 1.0
    batchSize = 2
    beta1 = 0.0
    beta2 = 0.9
    cache_filelist_read = True
    cache_filelist_write = True
    checkpoints_dir = './checkpoints'
    choose_pair = [0, 1]
    coco_no_portraits = False
    contain_dontcare_label = False

    dataset_mode = 'train'
    debug = False
    display_freq = 100
    display_winsize = 256
    print_freq = 100
    save_epoch_freq = 1
    save_latest_freq = 5000

    init_type = 'xavier'
    init_variance = 0.02
    isTrain = True
    is_test = False

    semantic_nc = 3
    label_nc = 3
    output_nc = 3
    lambda_feat = 10.0
    lambda_kld = 0.05
    lambda_vgg = 10.0
    load_from_opt_file = False
    lr = 0.0002
    max_dataset_size = sys.maxsize
    model = 'pix2pix'
    nThreads = 2

    n_layers_D = 4
    num_D = 2
    ndf = 64
    nef = 16
    netD = 'multiscale'
    netD_subarch = 'n_layer'
    netG = 'hifacegan'  # spade, lipspade
    ngf = 64  # set to 48 for Titan X 12GB card
    niter = 30
    niter_decay = 20
    no_TTUR = False
    no_flip = False
    no_ganFeat_loss = False
    no_html = False
    no_instance = True
    no_pairing_check = False
    no_vgg_loss = False

    norm_D = 'spectralinstance'
    norm_E = 'spectralinstance'
    norm_G = 'spectralspadesyncbatch3x3'

    num_upsampling_layers = 'normal'
    optimizer = 'adam'
    phase = 'train'
    prd_resize = 512
    preprocess_mode = 'resize_and_crop'

    serial_batches = False
    tf_log = False
    train_phase = 3  # progressive training disabled (set initial phase to 0 to enable it)
    # 20200211
    #max_train_phase = 2 # default 3 (4x)
    max_train_phase = 3
    # training at 1024x1024 is also possible: set this to 4 and add more layers in the generator.
    upsample_phase_epoch_fq = 5
    use_vae = False
    z_dim = 256

thank you!

@darkcake

I have the same error. Please help!

kwea123 commented Aug 2, 2020

Change this parameter to 48:

ngf = 64 # set to 48 for Titan X 12GB card

It controls the input channel width: 48 means 48x16 = 768 channels, which matches the 768 hard-coded here:
if opt.use_vae:
    # In case of VAE, we will sample from random z vector
    self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
else:
    # Otherwise, we make the network deterministic by starting with
    # downsampled segmentation map instead of random z
    self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 768)
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 768)
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 768)

The training works in this case, but I don't know whether the results will be good... The code seems quite messy...
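For clarity, a minimal standalone reproduction of the shape clash (illustrative only, not code from the repo; the shapes are taken from the traceback above):

import torch
import torch.nn as nn

# With ngf=64 the generator trunk carries 16 * 64 = 1024 channels, while the
# conv inside SPADE was sized for 16 * 48 = 768 channels (weight [128, 768, 3, 3]
# in the traceback), so the forward pass fails:
conv = nn.Conv2d(in_channels=768, out_channels=128, kernel_size=3, padding=1)
x = torch.randn(2, 1024, 4, 4)  # batchSize=2, 1024 channels, 4x4 feature map
try:
    conv(x)
except RuntimeError as e:
    print(e)  # "... expected input[2, 1024, 4, 4] to have 768 channels, but got 1024 ..."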

Lotayou (Owner) commented Aug 7, 2020

@kwea123 Thanks for mentioning this. The code has now been reformatted and should work for both 48 and 64.

if opt.use_vae:
    # In case of VAE, we will sample from random z vector
    self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
else:
    # Otherwise, we make the network deterministic by starting with
    # downsampled segmentation map instead of random z
    self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

Most experiments reported in the paper were trained with ngf=48, which fits on a 12GB Titan X card, and the results are already good enough to beat SOTA. The final code release was tested on a company server with a 16GB P100 card, which allows training with ngf=64. I haven't benchmarked the performance difference between 48 and 64, though.
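For anyone who wants to sanity-check the new wiring, here's a minimal standalone sketch (a simplification, not the repo's code: SPADEResnetBlock is reduced to the single conv that used to clash, and the 128 hidden width is borrowed from the traceback):

import torch
import torch.nn as nn

def build_head(nf, semantic_nc=3):
    # Both convs derive their widths from nf, so any ngf stays consistent.
    fc = nn.Conv2d(semantic_nc, 16 * nf, 3, padding=1)   # input stem: 16*nf channels out
    mlp_shared = nn.Conv2d(16 * nf, 128, 3, padding=1)   # was hard-coded to 768 before
    return fc, mlp_shared

for nf in (48, 64):
    fc, mlp_shared = build_head(nf)
    feat = fc(torch.randn(2, 3, 16, 16))   # -> (2, 16*nf, 16, 16)
    out = mlp_shared(feat)                 # shapes now agree for any nf
    print(nf, tuple(out.shape))            # (2, 128, 16, 16)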
