
GPU memory explode after 3 steps #5

Open
InitialBug opened this issue Jun 7, 2018 · 31 comments

@InitialBug

InitialBug commented Jun 7, 2018

I use a Titan X GPU, but the GPU memory grows rapidly, and after 3 batches it runs out of memory.
I have checked your code line by line, and I still don't know what's wrong with it.

@hengruo
Owner

hengruo commented Jun 8, 2018

@InitialBug I'm testing on a 1080ti and got the same result. The program crashes every time it runs into the self-attention layer. I'm sorry that I don't know why. The original model could run normally, but now it fails.

@hengruo
Owner

hengruo commented Jun 14, 2018

@InitialBug Hey, I fixed the issue! The original model couldn't release the computational graph after each epoch, so it consumed enormous memory. I modified CQAttention so that the graph can be released automatically. Now the model runs as fast as a rocket!
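For anyone hitting the same thing: a common PyTorch cause of "memory grows every step" is keeping a reference to a tensor that is still attached to the graph (e.g. accumulating the raw loss). This is just a minimal, self-contained sketch of that general pattern, not the actual change made to CQAttention:

```python
import torch

model = torch.nn.Linear(10, 1)          # toy stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

running_loss = 0.0
for step in range(100):
    x = torch.randn(32, 10)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # running_loss += loss        # BAD: keeps every step's graph alive, memory grows
    running_loss += loss.item()   # OK: detaches to a plain float so the graph can be freed
```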

@deepakkumar1984

I have two 8GB M60 cards, but it still fails with a CUDA out-of-memory exception. Does it support multi-GPU for now?

@hengruo
Owner

hengruo commented Jun 19, 2018

@deepakkumar1984 I only have one GPU, so I can't test on multiple GPUs. I'd be very grateful if you could commit the relevant code.

@deepakkumar1984

I tried a V100 GPU with 16GB of HBM VRAM. With that, too, it gives an out-of-memory exception:

THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory

Traceback (most recent call last):
File "main.py", line 294, in
app.run(main)
File "/home/ubuntu/.local/lib/python3.5/site-packages/absl/app.py", line 274, in run
_run_main(main, args)
File "/home/ubuntu/.local/lib/python3.5/site-packages/absl/app.py", line 238, in _run_main
sys.exit(main(argv))
File "main.py", line 275, in main
train_entry(config)
File "main.py", line 240, in train_entry
train(model, optimizer, scheduler, train_dataset, iter, L)
File "main.py", line 145, in train
p1, p2 = model(Cwid, Ccid, Qwid, Qcid)
File "/home/ubuntu/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in call
result = self.forward(*input, **kwargs)
File "/home/ubuntu/git/QANet-pytorch/models.py", line 254, in forward
M3 = self.model_enc_blks(M2)
File "/home/ubuntu/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in call
result = self.forward(*input, **kwargs)
File "/home/ubuntu/.local/lib/python3.5/site-packages/torch/nn/modules/container.py", line 91, in forward
input = module(input)
File "/home/ubuntu/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in call
result = self.forward(*input, **kwargs)
File "/home/ubuntu/git/QANet-pytorch/models.py", line 154, in forward
out = self.self_att(out)
File "/home/ubuntu/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in call
result = self.forward(*input, **kwargs)
File "/home/ubuntu/git/QANet-pytorch/models.py", line 98, in forward
out = torch.bmm(WQs[i], WKs[i].transpose(1, 2))
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58

Any idea???

@hengruo
Owner

hengruo commented Jun 21, 2018

@deepakkumar1984 Please set batch_size = 28. I tested it on a 16GB GPU and that setting runs smoothly. BTW, the current F1/EM is very low. I'm fixing that.

@deepakkumar1984

deepakkumar1984 commented Jun 21, 2018 via email

@InitialBug
Author

@hengruo The reason for the low score may be a problem with your optim scheduler. I printed the optimizer's learning rate, and it doesn't increase from 0 to 0.001; instead it sits around 1e-7 after 1000 steps, and the training loss doesn't converge because of that. After changing the optimizer to Adam without the increasing learning rate, the model converges fast, and I got a 53 F1 score after 10,000 training steps, though with different hyperparameters.
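A minimal sketch of that change (assuming the standard torch.optim API; `model` here stands for the QANet model built in main.py):

```python
import torch.optim as optim

# Replace the warmup/ExponentialLR setup with plain Adam at a fixed learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# ...and simply don't call scheduler.step() in the training loop.
```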

@hengruo
Owner

hengruo commented Jun 21, 2018

@InitialBug Wow, that's amazing! Could you please publish your hyperparameters?
And have you updated your local code? I did see the 1e-7 learning rate, but I have fixed that. Now the learning rate increases from 0 to 1e-3 during the first 1000 steps and decreases exponentially afterwards.
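One way to express that schedule with LambdaLR is sketched below; the linear warmup is only for illustration (the exact warmup curve in the repo may differ), and the gamma is the same 0.9999 value used for ExponentialLR before:

```python
import torch.optim as optim

base_lr = 1e-3
warmup_steps = 1000
gamma = 0.9999   # per-step exponential decay after warmup

optimizer = optim.Adam(model.parameters(), lr=base_lr)

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps           # warmup: 0 -> base_lr over the first 1000 steps
    return gamma ** (step - warmup_steps)    # exponential decay afterwards

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# call scheduler.step() once per training step, after optimizer.step()
```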

@InitialBug
Author

InitialBug commented Jun 21, 2018

@hengruo I have tried many different hyperparameters, but I think the root problem is the optim function; you can simply try Adam with a fixed learning rate of 0.001. I think you will get better results even after 1,000 steps, but I don't know why. This is my experience after 5 days of debugging.
Because of my personal naming style, I have changed some hyperparameter names in the config; here are the key parameters in the code.
```python
flags.DEFINE_integer("char_dim", 64, "Embedding dimension for char")
flags.DEFINE_integer("word_len", 16, "Limit length for one word")
flags.DEFINE_integer("emb_dim", 96, "Dimension of connectors of each layer")
flags.DEFINE_integer("head_num", 6, "Number of heads in multi-head attention")
flags.DEFINE_integer("attention_map_dim", 64, "MultiHeadAttention middle dimension")
flags.DEFINE_integer("kernel_size", 7, "kernel size of CNN")
flags.DEFINE_integer("block_num", 7, "Number of blocks in the model encoder layer")
```

The attention_map_dim refers to d_v and d_k in the multi-head attention layer, and I find that a lower character dimension makes the model converge faster.

@hengruo
Owner

hengruo commented Jun 21, 2018

@InitialBug THANK YOU VERY MUCH!!! I'm testing with your settings!

@InitialBug
Author

@hengruo Any good news? I have tested many different hyperparameters, but the best F1 score is only 64.3 so far. I wonder if the model has some problem.

@Jimmy880

I ran into the same problem.
I used a 12GB GeForce GTX Titan X and got an "out of memory" error when running main.py in train mode.
Now I have to reduce batch_size in config.py to 16 to run it.

@BangLiu

BangLiu commented Jul 5, 2018

@InitialBug May I ask what is the current best performance you can get?

  1. In the QANet paper, they use warmup and then fix lr=0.001 after 1000 steps. I revised the code according to the paper (this repository uses an ExponentialLR scheduler with a 0.9999 gamma value) and found that training becomes quite unstable, usually after around 5 epochs. The best performance is only around 30 (F1); it can then change dramatically to 4%, and it varies a lot and unpredictably.

  2. If I use the ExponentialLR scheduler, the performance grows. However, after around 10 epochs the improvement becomes quite slow. When debugging, I manually load the best checkpoint and set lr=0.001; after that, every time I find the improvement has slowed down, I set lr = 0.1 * old_lr. But the performance is still only around 61 (F1).

I am not sure why the fixed LR doesn't work well, as it should be fine according to the original paper.
And I don't know why the performance is not good enough (only 61 F1).

@BangLiu

BangLiu commented Jul 5, 2018

Another problem is that it seems we don't have an exponential moving average (EMA) of the weights here.
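For anyone who wants to add it, here is a minimal sketch of what such a weight EMA could look like (not taken from any of the repos discussed here; the decay value is illustrative):

```python
import torch

class EMA:
    """Minimal exponential moving average of model weights."""
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {name: param.detach().clone()
                       for name, param in model.named_parameters() if param.requires_grad}

    def update(self, model):
        # Call once after every optimizer.step().
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.shadow[name] = (self.decay * self.shadow[name]
                                         + (1.0 - self.decay) * param.detach())

    def copy_to(self, model):
        # Load the averaged weights into the model, e.g. before evaluation.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.requires_grad:
                    param.copy_(self.shadow[name])
```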

@InitialBug
Author

@BangLiu The best F1 score is around 66, but when I keep training, the model overfits. I have re-implemented part of model.py, but most of the modules are the same as in this repository. I also use the exponential moving average, but it seems to make training unstable as well. I think the only differences between my model and the paper are the hyperparameters and the stochastic depth method (layer dropout). But I don't have enough GPU memory (12GB is not enough).
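For reference, a minimal sketch of that layer-dropout idea as described in the paper (deeper residual sublayers are dropped with higher probability; the exact rescaling details vary between implementations):

```python
import torch

def stochastic_residual(sublayer_out, residual, layer_idx, total_layers,
                        p_last=0.9, training=True):
    """Stochastic depth (layer dropout) for one residual sublayer.
    Deeper layers are dropped more often: drop_prob = (l / L) * (1 - p_last)."""
    drop_prob = (layer_idx / total_layers) * (1.0 - p_last)
    if training and torch.rand(1).item() < drop_prob:
        return residual                 # sublayer skipped, only the residual passes through
    return sublayer_out + residual      # sublayer kept
```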

@BangLiu

BangLiu commented Jul 5, 2018

@InitialBug I uploaded my implementation based on this repository to: https://github.com/BangLiu/QANet-PyTorch

You are welcome to test my code. I get memory explosion with this repository's implementation at batch_size 32, but my implementation doesn't have that problem. However, its performance is currently not good. I implemented EMA in my code repository and tested it in QANet_trainer.py (you can check whether my implementation is correct), but the EMA makes the performance quite bad (F1 is less than 10% ....) and I don't know why.

@hengruo It would be great if you could also take a look and see what is different between our implementations. I think we should get about 78 ~ 80 F1, otherwise it is not correct....

@haibarasfs

@BangLiu I also ran into the same problem when testing your code. I use one 1080ti card.

@BangLiu

BangLiu commented Jul 13, 2018

@haibarasfs Currently the best performance I can get is F1 64, EM 50 (after a long time training ......).
I haven't figured out the difference. Ideally, it should reach around 77 quite soon according to the paper.

@BangLiu

BangLiu commented Jul 13, 2018

@haibarasfs You mean memory exploding? If you use a 1080ti card, then the batch_size may need to be smaller...

@hengruo
Owner

hengruo commented Jul 14, 2018

@InitialBug @BangLiu Sorry for my late response. I was traveling, so I hardly made any progress. Thanks for your contributions! I've started working on it again, and if I make any progress I'll let you know. BTW, I added EMA.

@jayleicn

Not sure whether you guys have seen the Tensorflow implementation; it gets even better results than the original paper, see https://github.com/NLPLearn/QANet#results. Hope it helps!

@andy840314

I have implemented a repository, QANet, mostly based on this repository and another Tensorflow implementation, Tensorflow QANet. I can reach F1: 75.0, EM: 64.0 in 60000 steps. You could take a look!

@BangLiu

BangLiu commented Jul 23, 2018

@andy840314 What do you think may cause the 5.0-point gap between your implementation and that of Tensorflow QANet? I am also trying to reach that performance.

@andy840314

@BangLiu I'm not sure, but I will try adding EMA first.

@hengruo
Owner

hengruo commented Jul 24, 2018

@InitialBug @deepakkumar1984 @Jimmy880 @BangLiu @haibarasfs I contacted one of QANet's authors. They've published their model. Link: https://github.com/tensorflow/tpu/tree/master/models/experimental/qanet
I'm tuning with their hyper-parameters.

@BangLiu

BangLiu commented Jul 24, 2018

@andy840314 @hengruo I tested andy's code; it can achieve F1 74.128853, EM 62.707499 in 14 epochs, and F1 70.157651, EM 58.185461 in 4 epochs.
I compared andy's version with hengruo's version. Currently, I think the differences include:

  1. Initialized_Conv1d
    In this class, the weights are initialized with a specific method, and this convolution class is used many times in andy's implementation. Previously, when I tested hengruo's version, I mainly used DepthWiseSeparableConvolution.
  2. The embedding encoder in QANet is shared between Context and Question in andy's version; in hengruo's version, we use two separate encoders for Question and Context (see the sketch after this list).
  3. In my repository, I use multi-head attention for self-encoding; besides, my positional encoding and highway network implementations are also not the same as andy's version.
  4. Other potential differences that I haven't looked at in detail.

I am wondering which of these key differences make the performance increase. Hope this helps.
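To make points 1 and 2 above concrete, here is a rough sketch; the module name, the init scheme, and the EncoderBlock placeholder are illustrative, not copied from either repo:

```python
import torch
import torch.nn as nn

class InitializedConv1d(nn.Module):
    """Conv1d with an explicit weight initialization (the scheme here is only an example;
    the actual repos may use a different one)."""
    def __init__(self, in_ch, out_ch, kernel_size=1, relu=False, bias=False):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size,
                              padding=kernel_size // 2, bias=bias)
        self.relu = relu
        if relu:
            nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')
        else:
            nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return torch.relu(out) if self.relu else out

# Point 2: one embedding-encoder block whose parameters are shared between context
# and question, instead of two separately parameterized encoders.
# (EncoderBlock is a placeholder name.)
# shared_enc = EncoderBlock(...)
# C = shared_enc(context_emb)
# Q = shared_enc(question_emb)
```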

@hackiey

hackiey commented Jul 25, 2018

@BangLiu I have implemented QANet with EMA, also mostly based on the Tensorflow implementation. The performance is EM: 67.317 and F1: 76.953 (without EMA), EM: 70.155 and F1: 79.432 (with EMA) after 22 epochs (2730 batches per epoch, 60060 steps).

@BangLiu

BangLiu commented Jul 26, 2018

@hackiey That's great! I will test it.

@jayleicn

I tested @hackiey's code; it can reach EM 67.515 after 18 epochs and EM 68.09 after 56 epochs, w/o EMA.

@BangLiu

BangLiu commented Jul 27, 2018

@hengruo I think in the train() function of your code, clip_grad_norm_ should be called before optimizer.step()?
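For reference, the intended ordering inside train() would be something like the sketch below (the names follow the train(model, optimizer, scheduler, ...) signature seen in the traceback above; the max_norm value is illustrative):

```python
import torch

# Inside the per-batch training step:
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # clip gradients first
optimizer.step()                                                  # then apply the update
scheduler.step()
```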
