Skip to content

Commit

Permalink
[Bug] Replace mem leaking torch gaussian_blur in augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Dec 18, 2024
1 parent 48873e0 commit f9cadb7
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.3
hooks:
- id: ruff
args: [ --fix ]
Expand Down
21 changes: 11 additions & 10 deletions doctr/transforms/functional/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from torchvision.transforms import functional as F

from doctr.utils.geometry import rotate_abs_geoms
Expand Down Expand Up @@ -113,24 +114,24 @@ def crop_detection(


def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
"""Crop and image and associated bboxes
"""Apply a random shadow effect to an image using NumPy for blurring.
Args:
img: image to modify
opacity_range: the minimum and maximum desired opacity of the shadow
**kwargs: additional arguments to pass to `create_shadow_mask`
img: Image to modify (C, H, W) as a PyTorch tensor.
opacity_range: The minimum and maximum desired opacity of the shadow.
**kwargs: Additional arguments to pass to `create_shadow_mask`.
Returns:
shaded image
Shadowed image as a PyTorch tensor (same shape as input).
"""
shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)

opacity = np.random.uniform(*opacity_range)
shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])

# Add some blur to make it believable
k = 7 + 2 * int(4 * np.random.rand(1))
# Apply Gaussian blur to the shadow mask
sigma = np.random.uniform(0.5, 5.0)
shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)

shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0) # Add channel dimension

return opacity * shadow_tensor * img + (1 - opacity) * img
44 changes: 43 additions & 1 deletion doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import numpy as np
import torch
from PIL.Image import Image
from scipy.ndimage import gaussian_filter
from torch.nn.functional import pad
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

from ..functional.pytorch import random_shadow

__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
__all__ = [
"Resize",
"GaussianNoise",
"ChannelShuffle",
"RandomHorizontalFlip",
"RandomShadow",
"RandomResize",
"GaussianBlur",
]


class Resize(T.Resize):
Expand Down Expand Up @@ -142,6 +151,39 @@ def extra_repr(self) -> str:
return f"mean={self.mean}, std={self.std}"


class GaussianBlur(torch.nn.Module):
"""Apply Gaussian Blur to the input tensor
>>> import torch
>>> from doctr.transforms import GaussianBlur
>>> transfo = GaussianBlur(sigma=(0.0, 1.0))
Args:
sigma : standard deviation range for the gaussian kernel
"""

def __init__(self, sigma: tuple[float, float]) -> None:
super().__init__()
self.sigma_range = sigma

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Sample a random sigma value within the specified range
sigma = torch.empty(1).uniform_(*self.sigma_range).item()

# Apply Gaussian blur along spatial dimensions only
blurred = torch.tensor(
gaussian_filter(
x.numpy(),
sigma=sigma,
mode="reflect",
truncate=4.0,
),
dtype=x.dtype,
device=x.device,
)
return blurred


class ChannelShuffle(torch.nn.Module):
"""Randomly shuffle channel order of a given image"""

Expand Down
6 changes: 3 additions & 3 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm

from doctr import transforms as T
Expand Down Expand Up @@ -261,12 +261,12 @@ def main(args):
img_transforms = T.OneOf([
Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
]),
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
RandomGrayscale(p=0.15),
]),
RandomPhotometricDistort(p=0.3),
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ def main(args):
img_transforms = T.OneOf([
T.Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.2),
]),
T.Compose([
T.RandomApply(T.RandomJpegQuality(60), 0.15),
# T.RandomApply(T.RandomShadow(), 0.2), # Broken atm on GPU
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.15),
]),
T.Compose([
Expand Down
33 changes: 33 additions & 0 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from doctr.transforms import (
ChannelShuffle,
ColorInversion,
GaussianBlur,
GaussianNoise,
RandomCrop,
RandomHorizontalFlip,
Expand Down Expand Up @@ -278,6 +279,38 @@ def test_gaussian_noise(input_dtype, input_shape):
assert torch.all(transformed <= 1.0)


@pytest.mark.parametrize(
"input_dtype, input_shape",
[
[torch.float32, (3, 32, 32)],
[torch.uint8, (3, 32, 32)],
],
)
def test_gaussian_blur(input_dtype, input_shape):
sigma_range = (0.0, 1.0)
transform = GaussianBlur(sigma=sigma_range)

input_t = torch.rand(input_shape, dtype=torch.float32)

if input_dtype == torch.uint8:
input_t = (255 * input_t).round().to(dtype=torch.uint8)

blurred = transform(input_t)

assert isinstance(blurred, torch.Tensor)
assert blurred.shape == input_shape
assert blurred.dtype == input_dtype

if input_dtype == torch.uint8:
assert torch.any(blurred != input_t)
assert torch.all(blurred <= 255)
assert torch.all(blurred >= 0)
else:
assert torch.any(blurred != input_t)
assert torch.all(blurred <= 1.0)
assert torch.all(blurred >= 0.0)


@pytest.mark.parametrize(
"p,target",
[
Expand Down

0 comments on commit f9cadb7

Please sign in to comment.