From 0be564c6ce18e204468f9e356d4e7c96d6c0bc97 Mon Sep 17 00:00:00 2001
From: Bernd Doser
Date: Fri, 17 Nov 2023 14:32:41 +0100
Subject: [PATCH] Fix #45: Memory leak

---
 callbacks/log_reconstruction_callback.py  | 58 +++++++++++++----------
 tests/test_log_reconstruction_callback.py |  2 +-
 2 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/callbacks/log_reconstruction_callback.py b/callbacks/log_reconstruction_callback.py
index 11252d4..a8de913 100644
--- a/callbacks/log_reconstruction_callback.py
+++ b/callbacks/log_reconstruction_callback.py
@@ -1,7 +1,12 @@
-import matplotlib.pyplot as plt
+import gc
+
+import matplotlib
 import torch
 import torchvision.transforms.functional as functional
 from lightning.pytorch.callbacks import Callback
+from matplotlib import figure
+
+matplotlib.use('Agg')
 
 
 class LogReconstructionCallback(Callback):
@@ -27,24 +32,20 @@ def on_train_epoch_end(self, trainer, pl_module):
         batch_size = samples.shape[0]
         losses = torch.zeros(batch_size, pl_module.rotations)
         images = torch.zeros(
-            (
-                batch_size,
-                3,
-                pl_module.input_size,
-                pl_module.input_size,
-                pl_module.rotations,
-            )
+            batch_size,
+            3,
+            pl_module.input_size,
+            pl_module.input_size,
+            pl_module.rotations,
         )
         recons = torch.zeros(
-            (
-                batch_size,
-                3,
-                pl_module.input_size,
-                pl_module.input_size,
-                pl_module.rotations,
-            )
+            batch_size,
+            3,
+            pl_module.input_size,
+            pl_module.input_size,
+            pl_module.rotations,
        )
-        coords = torch.zeros((batch_size, pl_module.z_dim, pl_module.rotations))
+        coords = torch.zeros(batch_size, pl_module.z_dim, pl_module.rotations)
         for r in range(pl_module.rotations):
             rotate = functional.rotate(
                 samples, 360.0 / pl_module.rotations * r, expand=False
@@ -66,15 +67,24 @@ def on_train_epoch_end(self, trainer, pl_module):
         min_idx = torch.min(losses, dim=1)[1]
 
         # Plot the original samples and their reconstructions side by side
-        fig, axs = plt.subplots(self.num_samples, 2, figsize=(6, 2 * self.num_samples))
+        fig = figure.Figure(figsize=(6, 2 * self.num_samples))
+        ax = fig.subplots(self.num_samples, 2)
         for i in range(self.num_samples):
-            axs[i, 0].imshow(images[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
-            axs[i, 0].set_title("Original")
-            axs[i, 0].axis("off")
-            axs[i, 1].imshow(recons[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
-            axs[i, 1].set_title("Reconstruction")
-            axs[i, 1].axis("off")
-        plt.tight_layout()
+            ax[i, 0].imshow(images[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
+            ax[i, 0].set_title("Original")
+            ax[i, 0].axis("off")
+            ax[i, 1].imshow(recons[i, :, :, :, min_idx[i]].cpu().detach().numpy().T)
+            ax[i, 1].set_title("Reconstruction")
+            ax[i, 1].axis("off")
+        fig.tight_layout()
 
         # Log the figure at W&B
         trainer.logger.log_image(key="Reconstructions", images=[fig])
+
+        # Clear the figure and free memory
+        # Memory leak issue: https://github.com/matplotlib/matplotlib/issues/27138
+        for i in range(self.num_samples):
+            ax[i, 0].clear()
+            ax[i, 1].clear()
+        fig.clear()
+        gc.collect()
diff --git a/tests/test_log_reconstruction_callback.py b/tests/test_log_reconstruction_callback.py
index 91130c1..17e5c9b 100644
--- a/tests/test_log_reconstruction_callback.py
+++ b/tests/test_log_reconstruction_callback.py
@@ -49,7 +49,7 @@ def test_on_train_epoch_end(z_dim):
 
     model = RotationalVariationalAutoencoderPower(z_dim=z_dim)
     datamodule = ShapesDataModule(
-        "tests/data/shapes", num_workers=1, batch_size=12, shuffle=False
+        "tests/data/shapes", batch_size=12, shuffle=False
     )
     datamodule.setup("fit")
     # data_loader = data_module.train_dataloader()
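
The pattern this patch adopts can be reduced to a small, standalone sketch. The
helper names below (plot_side_by_side, log_and_release, log_fn) are hypothetical
and not part of the patch; the sketch only illustrates the idea: construct the
figure via matplotlib.figure.Figure instead of pyplot, so it is never registered
with pyplot's global figure manager, then clear the axes and the figure and run
gc.collect() once the figure has been logged, as a workaround for the leak
discussed in https://github.com/matplotlib/matplotlib/issues/27138.

    import gc

    import matplotlib
    from matplotlib import figure

    # Non-interactive backend: no GUI event loop keeps references to the figure.
    matplotlib.use("Agg")


    def plot_side_by_side(originals, reconstructions):
        """Hypothetical helper: plot each original next to its reconstruction."""
        n = len(originals)
        # figure.Figure is not tracked by pyplot, so it can be garbage-collected
        # as soon as the last reference to it is dropped.
        fig = figure.Figure(figsize=(6, 2 * n))
        ax = fig.subplots(n, 2, squeeze=False)
        for i in range(n):
            ax[i, 0].imshow(originals[i])
            ax[i, 0].set_title("Original")
            ax[i, 0].axis("off")
            ax[i, 1].imshow(reconstructions[i])
            ax[i, 1].set_title("Reconstruction")
            ax[i, 1].axis("off")
        fig.tight_layout()
        return fig, ax


    def log_and_release(fig, ax, log_fn):
        """Log the figure with a caller-supplied function, then free it."""
        log_fn(fig)
        # Clearing the axes and the figure before dropping the reference works
        # around the leak reported in matplotlib issue 27138; gc.collect()
        # forces the cycle collector to reclaim the figure immediately.
        for row in ax:
            for a in row:
                a.clear()
        fig.clear()
        gc.collect()

A pyplot-based equivalent (plt.subplots followed by plt.close(fig)) would also
release the figure, but only if plt.close is called on every code path; creating
the Figure directly avoids the global registry altogether, which appears to be
why the callback in this patch switches to it.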