
FilterbankFeatures may return NaNs on CUDA device - torch autocast problem #11541

Open
yoadsn opened this issue Dec 11, 2024 · 0 comments
Labels: bug Something isn't working

yoadsn commented Dec 11, 2024

Describe the bug

When an autocast region is enabled, the FilterbankFeatures featurizer may generate NaN frames during the forward pass.

Steps/Code to reproduce bug

As is probably known, on CUDA devices autocast casts matmul operations to half precision when the hardware supports that data type.
So the following is expected:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assume "cuda" is available
with torch.amp.autocast(device.type):
    f = torch.tensor([0.0], dtype=torch.float32).to(device)
    x1 = torch.tensor([65519], dtype=torch.float32).to(device)
    r1 = torch.matmul(f, x1)
    print(r1)  # OK - returns a 0 tensor

    x2 = torch.tensor([65520], dtype=torch.float32).to(device)
    r2 = torch.matmul(f, x2)
    print(r2)  # Oops - returns a NaN tensor

This is not a bug, just the behavior of autocast on an FP16-capable GPU. The matmul is performed in float16, and 65520 is not representable in that data type: it overflows to infinity, and 0 * inf yields NaN.
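
For illustration, here is a minimal sketch of the overflow itself (my own example, not NeMo code):

import torch

# 65519 still rounds to the float16 maximum (65504), but 65520 rounds up to infinity
print(torch.tensor(65519.0).to(torch.float16))  # tensor(65504., dtype=torch.float16)
print(torch.tensor(65520.0).to(torch.float16))  # tensor(inf, dtype=torch.float16)

# 0 * inf is NaN, which is exactly what the autocast matmul ends up computing
print(torch.tensor(0.0, dtype=torch.float16) * torch.tensor(65520.0).to(torch.float16))  # tensor(nan, dtype=torch.float16)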

Here is a repro of the bug though:

import torch
from nemo.collections.asr.parts.utils.vad_utils import init_frame_vad_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Assume "cuda" is available
if __name__ == "__main__":
    vad_model = init_frame_vad_model("vad_multilingual_frame_marblenet")
    vad_model = vad_model.to(device)
    vad_model.eval()

    x = torch.tensor(
        [
            -0.9058,-0.5946,1.2434,-0.3962,-1.7921,2.4021,-1.4928,-0.0365,0.4774,0.2091,-2.2784,2.6600,
            -1.1311,-2.3960,4.6964,-4.2161,0.3519,2.6062,-3.5229,1.2587,1.6692,-4.6367,4.4893,-2.8256,
            -0.9259,3.5737,-4.0703,2.4223,-0.2424,-2.2224,1.6835,0.1224,-2.6710,3.6783,-3.3401,0.8552,
            1.9303,-4.3002,4.6046,-4.0270,0.7077,2.5835,-4.6308,3.4596,-0.4729,-2.3140,2.9415,-2.1423,
            -0.4900,2.3685,-3.8594,2.8201,-0.0154,-2.5574,4.3913,0.0000,
        ],  # chosen because these numbers produce the STFT output that triggers the casting problem
        dtype=torch.float32,
    )
    x = torch.concatenate([x, torch.zeros(257 - x.shape[0])], axis=0)  # pad up to the minimal input signal size with zeros
    x = x.unsqueeze(0).to(device)
    seq_len = torch.tensor([x.shape[1]], dtype=x.dtype).to(device)

    with torch.amp.autocast(vad_model.device.type):
        r, s = vad_model.preprocessor.featurizer(x, seq_len) # this internally does the matmul
        print(torch.isnan(r).cpu().any().item()) # We get a True here - there is a NaN in the result

I know the numbers look arbitrary - I pulled them from an audio sample that exhibited the problem - but on my audio dataset we get 1-2 NaN occurrences per 10-minute segment, so it's not that rare.
The audio comes from MP3 files, which may introduce artifacts into the STFT output; still, NaN should not be a reasonable output feature in any case.

Expected behavior

NaN-free features should be produced whenever the input audio signal contains no NaNs.

Environment overview (please complete the following information)

  • Environment location: Cloud, AWS EC2 instance (g4dn.xlarge, T4 GPU), AMI: NVIDIA GPU Cloud VMI Base 2024.10.1 x86_64-676eed8d-dcf5-4784-87d7-0de463205c17
  • Method of NeMo install: pip install nemo_toolkit[asr]

Environment details

OS Version: Linux ip-172-31-47-154 6.8.0-1017-aws _18~22.04.1-Ubuntu SMP Thu Oct 3 19:57:42 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
Python Version: Python 3.10.12
PyTorch Version: 2.5.1+cu124
CUDA Version: 12.4

Additional context

I would like to explain the cause of the bug.

The FilterbankFeatures.forward method contains the following code (parts removed for brevity):

def forward(self, x, seq_len, linear_spec=False):
    seq_len = self.get_seq_len(seq_len)

    # ... parts removed

    # disable autocast to get full range of stft values
#=>>
    with torch.amp.autocast(x.device.type, enabled=False):
        x = self.stft(x)

    # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
    # guard is needed for sqrt if grads are passed through
    guard = 0 if not self.use_grads else CONSTANT
    x = torch.view_as_real(x)
    x = torch.sqrt(x.pow(2).sum(-1) + guard)

    # ... parts removed

    # dot with filterbank energies
#=>>
    x = torch.matmul(self.fb.to(x.dtype), x)

    # ... parts removed

    return x, seq_len

Note the two highlighted sections.
The first disables autocast so that the stft can see "the full range" of values; on my GPU this produces a float64 tensor.
The operations up to the second highlighted row then run in float32.
The second highlighted row does a matmul. In many cases it runs inside an active autocast region and will, on my GPU, be downcast to float16.
When x contains values outside the float16 range, the matmul produces unexpected results, among them NaNs.
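
To make the mechanism concrete, here is a minimal sketch (the fb matrix and the spectrogram values below are made up for illustration, not taken from NeMo; it needs an FP16-capable CUDA GPU):

import torch

device = torch.device("cuda")

# hypothetical stand-ins for self.fb and the magnitude spectrogram x
fb = torch.zeros(1, 3, device=device)  # a filterbank row that should zero out the frame
x = torch.tensor([[70000.0], [1.0], [2.0]], device=device)  # one magnitude above the float16 max (65504)

with torch.amp.autocast(device.type):
    out = torch.matmul(fb, x)  # autocast runs this in float16: 0 * inf -> NaN

print(out)  # tensor([[nan]], device='cuda:0', dtype=torch.float16)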

One such scenario is Frame VAD inference: the VAD utility activates an autocast region around the inference, and inside that region the VAD model featurizes the raw audio signal before running inference.
As an example of where the autocast region is activated, see the generate_vad_frame_pred utility code in the NeMo repo.
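
Roughly, the pattern looks like this (a paraphrased sketch of what such a utility does, not the actual generate_vad_frame_pred code; it reuses vad_model, x, and seq_len from the repro above, and assumes the standard NeMo ASR forward keyword arguments):

with torch.amp.autocast(vad_model.device.type):
    # both featurization (preprocessor) and inference happen inside the autocast region
    log_probs = vad_model(input_signal=x, input_signal_length=seq_len)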

One solution I suggest - which I tested locally with a patch to the NeMo code and which seems to work - is to exclude that matmul itself from the autocast:

  # dot with filterbank energies
- x = torch.matmul(self.fb.to(x.dtype), x)
+ with torch.amp.autocast(x.device.type, enabled=False):
+     x = torch.matmul(self.fb.to(x.dtype), x)

And right after that, clamp the values or otherwise make sure that downstream downcasting cannot corrupt them.
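
For illustration, a self-contained sketch of that idea (the helper name and the clamping threshold are my own choices, not a tested NeMo patch):

import torch

def filterbank_matmul_full_precision(fb: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: run the filterbank matmul outside any autocast region,
    # then clamp the result into the float16 range so that later autocast ops
    # cannot overflow it to inf/NaN.
    with torch.amp.autocast(x.device.type, enabled=False):
        out = torch.matmul(fb.to(x.dtype), x)
    fp16_max = torch.finfo(torch.float16).max  # 65504.0
    return torch.clamp(out, min=-fp16_max, max=fp16_max)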

yoadsn added the bug label on Dec 11, 2024