
Training stability: Continue training even if a data batch causes OOM #113

Closed
wants to merge 4 commits

Conversation

RuntimeRacer (Contributor)

Following my investigations mentioned in #110, I implemented some code that does the following when a training batch causes an OOM error:

  1. Prints the metadata and text contents of the offending batch to the console
  2. Deletes all references to the batch object from the current training cycle
  3. Deletes all references to the loss output from the current training cycle, if it exists
  4. Empties the CUDA cache to free up stale memory
  5. Continues training with the next batch

This should improve the training process in several ways:

  • Easier identification of dataset content that is causing issues
  • Training can still resume and eventually finish the epoch if only a few data elements cause issues.

Additionally, I discovered that a grad scaling step is missing in the pre-training OOM check, which was exactly the code where the OOM exception occurred during training in my case, so I added that as well.
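A minimal sketch of the recovery logic described above, assuming a typical PyTorch AMP training loop; train_dl, params, compute_loss, model, optimizer, and scaler are placeholder names rather than the actual code in this repository:

    import torch

    # Placeholder names: train_dl, params, model, optimizer, scaler, compute_loss.
    for batch_idx, batch in enumerate(train_dl):
        loss = None
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(params, model, batch, is_training=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            # 1. Print metadata and text contents of the offending batch.
            print(f"OOM at batch {batch_idx}: {batch.get('supervisions', batch)}")
            # 2./3. Drop all references to the batch and the loss output.
            del batch
            if loss is not None:
                del loss
            # 4. Release cached blocks so the next batch starts from a clean pool.
            torch.cuda.empty_cache()
            # 5. Continue training with the next batch.
            continue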

@lifeiteng (Owner)

You can load the batch data saved by display_and_save_batch(batch, params=params) with torch.load(xxxx), then analyse the data (mainly the durations).
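A minimal sketch of that inspection; the file path and the dictionary keys are assumptions based on a typical Lhotse-style batch and are not guaranteed to match exactly what display_and_save_batch wrote:

    import torch

    # Load the batch dumped by display_and_save_batch (path is an assumption).
    batch = torch.load("exp/batch-xxxx.pt", map_location="cpu")

    # Print every field; tensors are summarised by their shape.
    for key, value in batch.items():
        print(key, value.shape if torch.is_tensor(value) else value)

    # If the batch carries Lhotse-style supervisions, the texts (and the cut
    # durations, if present) are usually what explain an OOM.
    supervisions = batch.get("supervisions", {})
    print(supervisions.get("text"))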

@lifeiteng (Owner)

@RuntimeRacer should we close this?

@chenjiasheng (Collaborator) commented May 12, 2023

I applied a similar approach in my private training code a few days ago.

But I have to point out that this kind of fix cannot recover from OOM errors that happen during backward.
I used a while loop to deliberately exhaust rank 0's GPU memory in 100 MB chunks just before scaler.scale(loss).backward(), so as to trigger an OOM in the backward stage. The result shows that rank 0 can enter the exception block, but the other ranks hang forever.
I further tried another approach, using the timeout argument of torch.distributed.init_process_group(), hoping that the ranks stuck in backward() would raise a TimeoutError. However, they don't.
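A minimal sketch of that reproduction; exhaust_gpu_memory is a hypothetical helper, and the 100 MB chunk size follows the description above:

    import torch
    import torch.distributed as dist

    def exhaust_gpu_memory(chunk_mb=100):
        # Grab GPU memory in chunk_mb-sized blocks until allocation fails.
        hoard = []
        try:
            while True:
                hoard.append(torch.empty(chunk_mb * 1024 * 1024,
                                         dtype=torch.uint8, device="cuda"))
        except RuntimeError:
            pass
        return hoard  # keep the references alive so the memory stays occupied

    # In the training step, just before the backward pass:
    if dist.get_rank() == 0:
        hoard = exhaust_gpu_memory()
    # rank 0 now OOMs in backward and enters its except block, while the
    # other ranks hang forever in the gradient all-reduce:
    # scaler.scale(loss).backward()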

In my experience, this PR can avoid OOM hangs for AR model training, but not for NAR model training.
I further confirmed the reason: the backward pass of the NAR model needs to allocate a larger chunk of memory, up to 480 MB. Under serious memory fragmentation, an OOM can occur even if only 28 GB of the 32 GB of GPU memory is currently in use.

I really want to solve this problem too, but I don't know how.
As a workaround for now, since an OOM in the forward part can be recovered this way, I create a big temporary tensor that takes up 1 GB of contiguous memory before the forward stage and release it after forward is done, so that this 1 GB of memory is reserved exclusively for backward (see the sketch below).
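A minimal sketch of that workaround; the 1 GB size is taken from the description above, and compute_loss, params, model, batch, and scaler are placeholder names:

    import torch

    # Reserve ~1 GB of contiguous GPU memory before the forward pass.
    reserve = torch.empty(1024 ** 3, dtype=torch.uint8, device="cuda")

    with torch.cuda.amp.autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(params, model, batch, is_training=True)

    # Release the reservation: the caching allocator can hand this block to
    # backward, which needs a large contiguous allocation for the NAR model.
    del reserve

    scaler.scale(loss).backward()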

And one more thing: no explicit cleanup is needed; most of the memory is released when leaving the exception handling block anyway.

@RuntimeRacer (Contributor, Author)

@chenjiasheng thank you for your detailed comment. Yes, this issue probably needs a bit more in-depth analysis.

In my case it turned out that the error was thrown because of a token generation problem with non-Latin characters, which caused VRAM to be flooded. As pointed out by @lifeiteng in #111 (comment), the model needs to be improved with regard to supporting symbols from various languages to overcome this.

However, the main purpose of this PR, recovering the training process from an unexpected OOM error, is not fully achieved by this yet; I still don't understand why these errors were not caught by the OOM check on startup in the first place. So maybe this PR should be converted to draft state in the meantime.

@chenjiasheng (Collaborator)

As I pointed out, the ranks that did not encounter an OOM error should also skip the batch (continue), to avoid inconsistency among the ranks. The inconsistency may be caused by subsequent steps such as update_averaged_model and scaler.update.
I think inconsistency is more harmful than a plain failure.

@RuntimeRacer
Could you please:

  1. modify the code as I suggested:
    is_oom = 0
    try:
        ...  # forward/backward/optimizer step for this batch
    except:  # noqa
        is_oom = 1
    # Sum the per-rank OOM flags so every rank sees the same decision.
    is_any_rank_oom = torch.tensor(is_oom).cuda(local_rank)
    torch.distributed.all_reduce(is_any_rank_oom)
    # If any rank hit OOM, all ranks skip this batch together.
    if is_any_rank_oom.item() > 0:
        continue
  2. carry out some tests on:
    1. different OOM time points (during forward or during backward)
    2. single-rank and multi-rank training
    3. different numbers of ranks that encounter an OOM error simultaneously

@RuntimeRacer (Contributor, Author)

Closing this since the issue was related to the input data format / combination itself, not to the way errors are being handled.
Also, I currently don't have the time to do the detailed evaluation proposed by @chenjiasheng.
