Skip to content

Commit

Permalink
Fix #45: Memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
BerndDoser committed Nov 17, 2023
1 parent baa1f2f commit 0be564c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
58 changes: 34 additions & 24 deletions callbacks/log_reconstruction_callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import matplotlib.pyplot as plt
import gc

import matplotlib
import torch
import torchvision.transforms.functional as functional
from lightning.pytorch.callbacks import Callback
from matplotlib import figure

matplotlib.use('Agg')


class LogReconstructionCallback(Callback):
Expand All @@ -27,24 +32,20 @@ def on_train_epoch_end(self, trainer, pl_module):
batch_size = samples.shape[0]
losses = torch.zeros(batch_size, pl_module.rotations)
images = torch.zeros(
(
batch_size,
3,
pl_module.input_size,
pl_module.input_size,
pl_module.rotations,
)
batch_size,
3,
pl_module.input_size,
pl_module.input_size,
pl_module.rotations,
)
recons = torch.zeros(
(
batch_size,
3,
pl_module.input_size,
pl_module.input_size,
pl_module.rotations,
)
batch_size,
3,
pl_module.input_size,
pl_module.input_size,
pl_module.rotations,
)
coords = torch.zeros((batch_size, pl_module.z_dim, pl_module.rotations))
coords = torch.zeros(batch_size, pl_module.z_dim, pl_module.rotations)
for r in range(pl_module.rotations):
rotate = functional.rotate(
samples, 360.0 / pl_module.rotations * r, expand=False
Expand All @@ -66,15 +67,24 @@ def on_train_epoch_end(self, trainer, pl_module):
min_idx = torch.min(losses, dim=1)[1]

# Plot the original samples and their reconstructions side by side
fig, axs = plt.subplots(self.num_samples, 2, figsize=(6, 2 * self.num_samples))
fig = figure.Figure(figsize=(6, 2 * self.num_samples))
ax = fig.subplots(self.num_samples, 2)
for i in range(self.num_samples):
axs[i, 0].imshow(images[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
axs[i, 0].set_title("Original")
axs[i, 0].axis("off")
axs[i, 1].imshow(recons[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
axs[i, 1].set_title("Reconstruction")
axs[i, 1].axis("off")
plt.tight_layout()
ax[i, 0].imshow(images[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
ax[i, 0].set_title("Original")
ax[i, 0].axis("off")
ax[i, 1].imshow(recons[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
ax[i, 1].set_title("Reconstruction")
ax[i, 1].axis("off")
fig.tight_layout()

# Log the figure at W&B
trainer.logger.log_image(key="Reconstructions", images=[fig])

# Clear the figure and free memory
# Memory leak issue: https://github.com/matplotlib/matplotlib/issues/27138
for i in range(self.num_samples):
ax[i, 0].clear()
ax[i, 1].clear()
fig.clear()
gc.collect()
2 changes: 1 addition & 1 deletion tests/test_log_reconstruction_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_on_train_epoch_end(z_dim):
model = RotationalVariationalAutoencoderPower(z_dim=z_dim)

datamodule = ShapesDataModule(
"tests/data/shapes", num_workers=1, batch_size=12, shuffle=False
"tests/data/shapes", batch_size=12, shuffle=False
)
datamodule.setup("fit")
# data_loader = data_module.train_dataloader()
Expand Down

0 comments on commit 0be564c

Please sign in to comment.