GPU out of memory problem during detector.fit() #93

Open
dingw00 opened this issue Jan 10, 2025 · 6 comments
dingw00 commented Jan 10, 2025

Hello, first of all, thanks a lot for your dedication and contributions to developing this library of OoD detectors. It helps a lot with our research. It’s greatly appreciated.

I came across a small issue: when I run detector.fit() on the ImageNet-1k dataset and fit the detectors on GPU, the GPUs run out of memory. The reason is that all 50,000 embeddings and labels are kept in GPU memory, which is quite demanding even on a server.

I think a more practical approach, given the limitation of GPU memory, is to store the embeddings and labels on the CPU and, when processing them in the subsequent detector.fit_features() step, iterate through them and transfer each batch of (embeddings, labels) to the GPU.

Related code in pytorch_ood.utils:

from typing import Callable, Optional, Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.utils import TensorBuffer, is_known


def extract_features(
    data_loader: DataLoader, model: Callable[[Tensor], Tensor], device: Optional[str]
) -> Tuple[Tensor, Tensor]:
    """
    Helper to extract outputs from model. Ignores OOD inputs.

    :param data_loader: dataset to extract from
    :param model: neural network to pass inputs to
    :param device: device used for calculations
    :return: Tuple with outputs and labels
    """
    # TODO: add option to buffer to GPU
    buffer = TensorBuffer()

    with torch.no_grad():
        for batch in data_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            known = is_known(y)
            if known.any():
                z = model(x[known])
                z = z.view(known.sum(), -1)  # flatten
                buffer.append("embedding", z)
                buffer.append("label", y[known])

        if buffer.is_empty():
            raise ValueError("No IN instances in loader")

    z = buffer.get("embedding")
    y = buffer.get("label")
    return z, y
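For illustration, this is roughly the batched fitting I have in mind. It is only a sketch: fit_in_batches and partial_fit_features are hypothetical names, not part of pytorch-ood.

import torch
from torch import Tensor


def fit_in_batches(detector, z: Tensor, y: Tensor, device: str, batch_size: int = 4096):
    # z and y stay on the CPU; only one chunk at a time is moved to the GPU
    for start in range(0, z.shape[0], batch_size):
        z_batch = z[start:start + batch_size].to(device)
        y_batch = y[start:start + batch_size].to(device)
        # hypothetical incremental update; the current detector API has no such method
        detector.partial_fit_features(z_batch, y_batch)
    return detector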
kkirchheim (Owner) commented

Hello, thank you for reporting this problem.

I am currently not sure if I understand what is happening.

I think a more practical approach, given the limitation of GPU memory, is to store the embeddings and labels on the CPU and, when processing them in the subsequent detector.fit_features() step, iterate through them and transfer each batch of (embeddings, labels) to the GPU.

I completely agree with this sentiment. To my understanding, this is how it is implemented at the moment.

In the extract_features method, the labels and embeddings are stored in a TensorBuffer. By default, this buffer stores everything on buffer.device, which already defaults to "cpu" (see here); that is why buffering to the GPU is only noted as a TODO. This means that GPU VRAM should remain constant during feature extraction. If the OOM really occurs in extract_features, I would suspect that the batch size is too large.
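For example, a quick check would be something along these lines (just a sketch; extract_with_small_batches is a hypothetical helper name, and the dataset and model arguments are placeholders for your setup):

from typing import Tuple

from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from pytorch_ood.utils import extract_features


def extract_with_small_batches(dataset: Dataset, model: nn.Module) -> Tuple[Tensor, Tensor]:
    # only one (small) batch of inputs lives on the GPU at a time; the extracted
    # embeddings and labels are buffered on the CPU by default
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    return extract_features(loader, model, device="cuda")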

Maybe there is a bug somewhere that I cannot see right now. Which detector are you using exactly? If your code still causes an OOM, even with a small batch size, could you provide me with a minimal example where the VRAM usage of the extract_features method increases over time?

Could it be possible that the OOM does not occur in extract_features but in the following call to detector.fit_features()?

dingw00 (Author) commented Jan 12, 2025

Hello,

Thank you for your attention and response. I examined the problem again and realized that there is no OOM problem with extract_features(); sorry for the poor problem localization. Yes, you are right, the TensorBuffer stores the data on the CPU by default. The OOM problem is related to specific detectors in detector.fit_features().

Specifically, I get OOM problems when fitting the following detectors on GPU:

  1. KLMatching.fit_features(device='cuda')
     Caused by logits, labels = logits.to(device), labels.to(device) in line 75.
  2. SHE.fit(device='cuda')
     Caused by return self.fit_features(x.to(device), y.to(device)) in line 73.
  3. Mahalanobis/RMD.fit_features(device='cuda')
     Caused by the following, starting from line 92:

z, y = z.to(device), y.to(device)
...
self.mu = torch.zeros(size=(n_classes, z.shape[-1]), device=device)
self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device)

Since I have n_classes=1000 and z.shape[-1]=1024, the Mahalanobis detector's algorithm uses too much memory even when I fit on the CPU, causing the program to crash.

Fitting the above-mentioned OOD detectors on GPU can therefore cause an OOM problem, depending on the available VRAM in practice. To circumvent this, I can fit on the CPU instead. But then I get the error that model and x are not on the same device in extract_features() at model(x), because the CNN model was placed in GPU VRAM. Anyway, at present I can work around this by doing

model.to('cpu')
detector.fit(data_loader, device='cpu')
model.to('cuda')

To avoid such errors, could you also consider ensuring the model's device, for example by adding model.to(device) along with x.to(device) when we call or fit the detectors via detector(x) / detector.fit(data_loader)?

Another small issue concerns MCD.fit(). This method doesn't take any other arguments like device, which caused an error when I made a unified call detector.fit(device='cuda') to the set of OOD detectors I was evaluating. I suggest making it the same as the MaxSoftmax detector:

def fit_features(self: Self, *args, **kwargs) -> Self:
    """
    Not required
    """
    return self

def fit(self: Self, *args, **kwargs) -> Self:
    """
    Not required
    """
    return self

dingw00 closed this as completed Jan 12, 2025
dingw00 reopened this Jan 12, 2025
kkirchheim (Owner) commented

Okay, so the problem for KLMatching and SHE seems to be that after the feature extraction, all of the data is moved to the device at once.

Since I have n_classes=1000 and z.shape[-1]=1024, the Mahalanobis detector's algorithm uses too much memory even when I fit on the CPU, causing the program to crash.

If I am not mistaken, the centers self.mu would allocate 1000 $\times$ 1024 $\times$ 4 bytes $\approx$ 4 MB, and similarly for self.cov. This is unlikely to cause your OOM. So the overall problem, again, is probably moving all of the features to the device at once.

For some of the detectors, implementing something like batch processing would probably be possible. For example, KLMatching could fit on a per-class basis, reducing VRAM usage by a factor of 1000 in your use case. The same goes for SHE. But this would slow things down.
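Roughly along these lines, just as a sketch (logits and y are the CPU tensors returned by feature extraction; fit_kl_matching_per_class is a hypothetical name, not the current implementation):

import torch


def fit_kl_matching_per_class(logits: torch.Tensor, y: torch.Tensor, device: str) -> torch.Tensor:
    # move only the logits of one class to the GPU at a time, so peak VRAM
    # scales with the largest class instead of the whole dataset
    dists = []
    for label in torch.unique(y):
        logits_c = logits[y == label].to(device)
        dists.append(logits_c.softmax(dim=1).mean(dim=0).cpu())
    return torch.stack(dists)  # one typical softmax distribution per class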

On the other hand, if you have, say, 1,000,000 instances (roughly the ImageNet training set) with 1024 features, that is only $\approx$ 4 GB of memory. So I am not sure why you run into an OOM with only 50k instances. Still, I will likely implement this, and we will see whether it resolves the problem.

To avoid such errors, could you also consider ensuring the model's device, for example by adding model.to(device)

This is probably a good idea. It might be confusing to some that calling the detector has the side effect of moving the model to a different device, but for now it is probably the most easy-to-use variant.

Another small issue concerns MCD.fit()

Good point. I will add a catch-all **kwargs to MCD.

kkirchheim (Owner) commented Jan 13, 2025

I addressed some of the issues. Could you install the branch containing the fixes with

pip install git+https://github.com/kkirchheim/pytorch-ood.git@93-gpu-out-of-memory-problem-during-detectorfit

and test whether you still run out of memory?

dingw00 (Author) commented Jan 14, 2025

On the other hand, if you have, say, 1,000,000 instances (roughly the ImageNet training set) with 1024 features, that is only ≈ 4 GB of memory. So I am not sure why you run into an OOM with only 50k instances.

I found the reason. It was not a bug in the detectors. It was actually caused by the feature maps of the ViT and SWIN models I used (I took the features before average pooling). The features are 50,000 × 49 × 1024 values in total, ≈ 10 GB. I have switched to the 50,000 × 1024 features after average pooling. With this number of features, I don't have any OOM problems.
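A quick back-of-the-envelope check in float32 (4 bytes per value), assuming the 7 × 7 = 49-token feature map before pooling, confirms the difference:

# pre-pooling ViT/SWIN feature maps: 50,000 x 49 x 1024 values
print(50_000 * 49 * 1024 * 4 / 1e9)   # ~10.0 GB
# pooled features: 50,000 x 1024 values
print(50_000 * 1024 * 4 / 1e9)        # ~0.2 GB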

Now all the detectors fit fine on the GPU!

The remaining problem is now the device inconsistency error.

This is probably a good idea. It might be confusing to some that calling the detector has the side effect of moving the model to a different device, but for now it is probably the most easy-to-use variant.

Yes, that's true. To avoid this side effect, why not move the model (backbone) back to its original device after the calculations inside the detector, so that users won't see any change in model.device from the outside?
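For example, something like this hypothetical wrapper, assuming the backbone is an nn.Module whose device can be read from its parameters:

import torch


def run_on_device(model: torch.nn.Module, x: torch.Tensor, device: str) -> torch.Tensor:
    # temporarily move the model, then restore its original device so the
    # caller never observes a device change
    original_device = next(model.parameters()).device
    try:
        model.to(device)
        return model(x.to(device))
    finally:
        model.to(original_device)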

However, I came across this error running SHE.fit(): detector.backbone is typed as a Callable in your library, not as an nn.Module, so it is not possible to switch model.device.
(screenshot of the error)
Do you have any other idea for handling the model.device switching?

There are also still some errors with the Mahalanobis detector when I try Mahalanobis.fit(device='cpu'): in extract_features(), at z = model(x[known]), model and x are not on the same device.

Besides, SHE.py line 99 (mask = y_hat_batch == y_batch) can cause an error when fitting on the GPU, because y_hat_batch is on the GPU while y_batch is still on the CPU.
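A minimal fix there could be to move the labels to the same device before the comparison, for example (just the idea, not the exact library code):

# ensure both sides of the comparison are on the same device
y_batch = y_batch.to(y_hat_batch.device)
mask = y_hat_batch == y_batch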

kkirchheim (Owner) commented

why not move the model (backbone) back to its original device

There might be some problems in cases where models are distributed across different devices. Anyway, I think it makes sense to just accept the fact that the underlying model can be moved to a different device for now. As long as this is documented, I think it should be fine.

detector.backbone is typed as a Callable in your library, not as an nn.Module, so it is not possible to switch model.device.

That is true; I will add a type check before the to(device) call.
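Roughly like this, as a sketch with a hypothetical helper name:

import torch
from typing import Callable


def move_if_module(backbone: Callable, device: str) -> None:
    # only nn.Module backbones expose .to(); plain callables are left untouched
    if isinstance(backbone, torch.nn.Module):
        backbone.to(device)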
