
GPU ram memory increase until overflow when using PSNR and SSIM #2597

Closed
ouioui199 opened this issue Jun 14, 2024 · 3 comments
Labels: bug / fix (Something isn't working) · question (Further information is requested) · v1.3.x

Comments


ouioui199 commented Jun 14, 2024

🐛 Bug

Hello all,

I'm implementing CycleGAN with Lightning. I use PSNR and SSIM from torchmetrics for evaluation.
During training, I see that my GPU RAM usage increases non-stop until it overflows and the whole training shuts down.
This might be similar to #2481.

To Reproduce

Add this to the __init__ method of the model class:

self.train_metrics = MetricCollection({"PSNR": PeakSignalNoiseRatio(), "SSIM": StructuralSimilarityIndexMeasure()})
self.valid_metrics = self.train_metrics.clone(prefix='val_')

In training_step method:
train_metrics = self.train_metrics(fake, real)

In validation_step method:
valid_metrics = self.valid_metrics(fake, real)
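
For context, here is a minimal sketch of how these snippets fit together in a LightningModule. Everything beyond the reported snippets (the CycleGANModule name, the placeholder generator, and the L1 loss) is a hypothetical stand-in, not code from the report:

import torch
from lightning.pytorch import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

class CycleGANModule(LightningModule):  # hypothetical module name
    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Conv2d(3, 3, 3, padding=1)  # placeholder for the CycleGAN generator
        self.train_metrics = MetricCollection({"PSNR": PeakSignalNoiseRatio(), "SSIM": StructuralSimilarityIndexMeasure()})
        self.valid_metrics = self.train_metrics.clone(prefix='val_')

    def training_step(self, batch, batch_idx):
        real = batch
        fake = self.generator(real)
        loss = torch.nn.functional.l1_loss(fake, real)  # placeholder for the CycleGAN losses
        train_metrics = self.train_metrics(fake, real)  # as reported: fake is still attached to the graph here
        self.log_dict(train_metrics)
        return loss

    def validation_step(self, batch, batch_idx):
        real = batch
        fake = self.generator(real)
        valid_metrics = self.valid_metrics(fake, real)
        self.log_dict(valid_metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())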

Environment

  • TorchMetrics version: 1.3.0 installed via pip
  • Python: 3.11.7
  • PyTorch: 2.1.2
  • Issue encountered when training on Windows 10

Easy fix proposal

I tried to debug the code.
When inspecting train_metrics, I get this:

"{'PSNR': tensor(10.5713, device='cuda:0', grad_fn=<SqueezeBackward0>), 'SSIM': tensor(0.0373, device='cuda:0', grad_fn=<SqueezeBackward0>)}"

which is odd, because I assumed metrics aren't supposed to be attached to the computational graph.
When inspecting valid_metrics, I don't see any grad_fn.
Guessing that this was the issue, I called fake.detach() when computing train_metrics.
Now the training is stable and the GPU memory no longer keeps increasing.
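
Concretely, the workaround is a one-line change to the training_step snippet above (a sketch; fake is the generator output as before):

train_metrics = self.train_metrics(fake.detach(), real)  # detach so no computational graph is kept alive by the metric states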

ouioui199 added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels on Jun 14, 2024

Hi! thanks for your contribution!, great first issue!

Borda (Member) commented Aug 21, 2024

@ouioui199 Looking at your example (could you please share the full sample code?), I am wondering whether you also call compute in the epoch-end hook?
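
For reference, the pattern being asked about usually looks like the following (a sketch only; the hook name follows Lightning 2.x and is not taken from the report):

def on_train_epoch_end(self):
    epoch_metrics = self.train_metrics.compute()  # aggregate the state accumulated over the epoch
    self.log_dict(epoch_metrics)
    self.train_metrics.reset()  # release the accumulated state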

Borda added the question (Further information is requested) label and removed the help wanted (Extra attention is needed) label on Aug 21, 2024
SkafteNicki (Member) commented

Hi, sorry for the slow reply from my side.
Let's split the issue into two sub-issues: the first being the metric output still carrying the computational graph, and the second being the increasing memory.

Computational graph

I am not sure where you got this from, but the computational graph is not automatically detached from the metric in general. It depends on the is_differentiable attribute of the metric. For both PeakSignalNoiseRatio() and StructuralSimilarityIndexMeasure() this attribute is True, because users may want to use these metrics as loss functions for backpropagation. Thus,

import torch
from torchmetrics.image import PeakSignalNoiseRatio
psnr = PeakSignalNoiseRatio()
print(psnr.is_differentiable)  # returns True
preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
print(psnr(preds, target))  # returns tensor(2.5527, grad_fn=<CloneBackward0>)

whereas a classification metric such as Accuracy will not carry the computational graph, because it is not differentiable:

from torchmetrics.classification import Accuracy
accuracy = Accuracy(task="multiclass", num_classes=4)
print(accuracy.is_differentiable)  # returns False
preds = torch.tensor([0, 1, 2, 3], requires_grad=True, dtype=torch.float)
target = torch.tensor([0, 2, 2, 3])
print(accuracy(preds, target))  # returns tensor(0.7500)

So calling detach before evaluating the metric would be the correct approach here.
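
To make that concrete with the PeakSignalNoiseRatio example from above (a sketch, not part of the original comment; preds and target are the first example's tensors):

preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
print(psnr(preds.detach(), target))  # tensor(2.5527) as before, but now without a grad_fn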

Memory keeps increasing

It is fairly well laid out in the documentation that there are two different kinds of metrics in torchmetrics, metrics with constant memory and metrics with increasing memory:
https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metrics-and-memory-management
Sadly, PeakSignalNoiseRatio() and StructuralSimilarityIndexMeasure() fall into the second category out of the box, i.e. it is simply not possible to calculate these metrics correctly over multiple batches without holding some internal state per batch in memory. As more batches are seen, this state grows. So what you are seeing is fully expected.
However, for these two metrics there are ways to get constant memory by setting the appropriate arguments. For both PeakSignalNoiseRatio() and StructuralSimilarityIndexMeasure() this argument is called reduction. If it is set, we reduce the batches into a constant-size state that prevents memory overflow, but what kind of reduction is appropriate depends on the application. So initializing the metrics as

from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
psnr = PeakSignalNoiseRatio(reduction="elementwise_mean")
ssim = StructuralSimilarityIndexMeasure(reduction="elementwise_mean")

will solve this issue.
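
Applied to the MetricCollection from the original report, that would look like this (a sketch combining the snippet above with the reported __init__ code):

self.train_metrics = MetricCollection({
    "PSNR": PeakSignalNoiseRatio(reduction="elementwise_mean"),
    "SSIM": StructuralSimilarityIndexMeasure(reduction="elementwise_mean"),
})
self.valid_metrics = self.train_metrics.clone(prefix='val_')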

Closing issue.
