
FilterbankFeatures may return NaNs on CUDA device - torch autocast problem #11541

Closed
@yoadsn

Description

Describe the bug

When an autocast region is enabled, the FilterbankFeatures featurizer may generate NaN frames during its forward pass.

Steps/Code to reproduce bug

As is probably known, inside an autocast region on CUDA devices, matmul operations are cast to half precision when the hardware supports that data type.
So the following is expected:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Assume "cuda" is available
with torch.amp.autocast(device.type):
    f = torch.tensor([0.0], dtype=torch.float32).to(device)
    x1 = torch.tensor([65519], dtype=torch.float32).to(device)
    r1 = torch.matmul(f, x1)
    print(r1) # OK - returns a 0 tensor

    x2 = torch.tensor([65520], dtype=torch.float32).to(device)
    r2 = torch.matmul(f, x2)
    print(r2) # Oops - returns a NaN tensor

This is not a bug, just the behavior of autocast on an FP16-capable GPU: the matmul is performed in fp16, 65520 is out of range for that data type (it rounds up to inf), and 0 * inf yields NaN.
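
For reference, a minimal standalone sketch (not part of the original repro) of the overflow itself - 65519 still rounds down to the fp16 maximum of 65504, while 65520 rounds up to inf, and 0 * inf gives NaN:

import torch

x = torch.tensor([65519.0, 65520.0])
print(x.to(torch.float16)) # tensor([65504., inf], dtype=torch.float16) - 65520 overflows to inf
print(torch.tensor(0.0, dtype=torch.float16) * torch.tensor(float("inf"), dtype=torch.float16)) # tensor(nan, dtype=torch.float16)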

Here is a repro of the bug though:

import torch
from nemo.collections.asr.parts.utils.vad_utils import init_frame_vad_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Assume "cuda" is available
if __name__ == "__main__":
    vad_model = init_frame_vad_model("vad_multilingual_frame_marblenet")
    vad_model = vad_model.to(device)
    vad_model.eval()

    x = torch.tensor(
        [
            -0.9058,-0.5946,1.2434,-0.3962,-1.7921,2.4021,-1.4928,-0.0365,0.4774,0.2091,-2.2784,2.6600,
            -1.1311,-2.3960,4.6964,-4.2161,0.3519,2.6062,-3.5229,1.2587,1.6692,-4.6367,4.4893,-2.8256,
            -0.9259,3.5737,-4.0703,2.4223,-0.2424,-2.2224,1.6835,0.1224,-2.6710,3.6783,-3.3401,0.8552,
            1.9303,-4.3002,4.6046,-4.0270,0.7077,2.5835,-4.6308,3.4596,-0.4729,-2.3140,2.9415,-2.1423,
            -0.4900,2.3685,-3.8594,2.8201,-0.0154,-2.5574,4.3913,0.0000,
        ], # these numbers were chosen because they produce an STFT result that triggers the casting problem
        dtype=torch.float32,
    )
    x = torch.concatenate([x, torch.zeros(257 - x.shape[0])], axis=0) # pad with zeros up to the minimal input signal size
    x = x.unsqueeze(0).to(device)
    seq_len = torch.tensor([x.shape[1]], dtype=x.dtype).to(device)

    with torch.amp.autocast(vad_model.device.type):
        r, s = vad_model.preprocessor.featurizer(x, seq_len) # this internally does the matmul
        print(torch.isnan(r).cpu().any().item()) # We get a True here - there is a NaN in the result

I know the numbers seem arbitrary - I pulled them from an audio sample that exhibited the problem - but on my audio dataset we hit 1-2 NaN situations per 10-minute segment, so it's not that rare.
The audio comes from MP3 files, which may introduce artifacts into the STFT output - still, NaN is not a reasonable output feature in any case.

Expected behavior

Non-NaN features should be produced when the input audio signal contains no NaNs.

Environment overview (please complete the following information)

  • Environment location: Cloud, AWS EC2 instance (g4dn.xlarge, T4 GPU), AMI: NVIDIA GPU Cloud VMI Base 2024.10.1 x86_64-676eed8d-dcf5-4784-87d7-0de463205c17
  • Method of NeMo install: pip install using nemo_toolkit[asr]

Environment details

OS Version: Linux ip-172-31-47-154 6.8.0-1017-aws _18~22.04.1-Ubuntu SMP Thu Oct 3 19:57:42 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
Python Version: Python 3.10.12
PyTorch Version: 2.5.1+cu124
CUDA Version: 12.4

Additional context

I would like to explain the cause of the bug.

The FilterbankFeatures forward has the following code (removed parts for brevity):

def forward(self, x, seq_len, linear_spec=False):
    seq_len = self.get_seq_len(seq_len)

    # ... parts removed

    # disable autocast to get full range of stft values
#=>>
    with torch.amp.autocast(x.device.type, enabled=False):
        x = self.stft(x)

    # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
    # guard is needed for sqrt if grads are passed through
    guard = 0 if not self.use_grads else CONSTANT
    x = torch.view_as_real(x)
    x = torch.sqrt(x.pow(2).sum(-1) + guard)

    # ... parts removed

    # dot with filterbank energies
#=>>
    x = torch.matmul(self.fb.to(x.dtype), x)

    # ... parts removed

    return x, seq_len

Note the two highlighted sections.
The first disables autocast so that the stft can produce "the full range" of values; on my GPU this keeps the stft output in full precision.
The operations that follow, up to the second highlighted row, run in float32.
The second highlighted row does a matmul. In many cases this runs inside an active autocast region and, on my GPU, is downcast to fp16.
When x contains values outside the fp16 range, the matmul produces unexpected results, among them NaNs.
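
Here is a minimal, self-contained sketch (hypothetical shapes and values, not the actual NeMo tensors; assumes a CUDA device) of how that downcast bites - the filterbank matrix is mostly zeros, so once a magnitude overflows to inf in fp16, the 0 * inf terms turn the dot product into NaN:

import torch

device = torch.device("cuda")

fb = torch.zeros(64, 257, device=device) # stand-in for self.fb: sparse mel filterbank weights
fb[:, :16] = 0.5
mag = torch.full((1, 257, 10), 70000.0, device=device) # fp32 "magnitudes" above the fp16 max of 65504

with torch.amp.autocast(device.type):
    out = torch.matmul(fb.to(mag.dtype), mag) # autocast runs this matmul in fp16
    print(out.dtype, torch.isnan(out).any().item()) # torch.float16 True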

One such scenario is frame-VAD inference: the VAD util activates an autocast region around the inference call, and inside that region the VAD model featurizes the raw audio signal before running inference.
See generate_vad_frame_pred in the NeMo repo for an example of a util that activates such an autocast region.
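
For illustration, a stripped-down sketch of that pattern (hypothetical, not the actual util code) - the whole forward pass, featurizer included, executes inside the autocast region:

with torch.amp.autocast(vad_model.device.type):
    # the preprocessor (and thus the FilterbankFeatures matmul) runs inside this region
    log_probs = vad_model(input_signal=x, input_signal_length=seq_len)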

One solution I suggest - and that I tested locally with a patch to the NeMo code, which seems to work - is to exclude that matmul itself from autocast.

# dot with filterbank energies
+ with torch.amp.autocast(x.device.type, enabled=False):
    x = torch.matmul(self.fb.to(x.dtype), x)

And right after that, clip the values or otherwise ensure that no downstream downcasting corrupts them.
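
Putting both suggestions together, one possible shape for the patched section (a sketch only; clamping to the fp16 range is one way to guard the values, not the only option):

# dot with filterbank energies
with torch.amp.autocast(x.device.type, enabled=False):
    x = torch.matmul(self.fb.to(x.dtype), x)
    # clamp to the fp16-representable range so a later downcast cannot overflow to inf/NaN
    x = torch.clamp(x, min=torch.finfo(torch.float16).min, max=torch.finfo(torch.float16).max)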

Labels

ASR, bug