Skip to content

Commit

Permalink
✅ Add tests for the corruptions
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jan 3, 2024
1 parent ad06558 commit 57dab0d
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 1 deletion.
89 changes: 89 additions & 0 deletions tests/transforms/test_corruptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import torch

from torch_uncertainty.transforms.corruptions import (
DefocusBlur,
Frost,
GaussianNoise,
GlassBlur,
ImpulseNoise,
JPEGCompression,
Pixelate,
ShotNoise,
)


class TestCorruptions:
"""Testing the Corruptions transform."""

def test_gaussian_noise(self):
with pytest.raises(ValueError):
_ = GaussianNoise(-1)
with pytest.raises(TypeError):
_ = GaussianNoise(0.1)
inputs = torch.rand(3, 32, 32)
transform = GaussianNoise(1)
transform(inputs)

def test_shot_noise(self):
with pytest.raises(ValueError):
_ = ShotNoise(-1)
with pytest.raises(TypeError):
_ = ShotNoise(0.1)
inputs = torch.rand(3, 32, 32)
transform = ShotNoise(1)
transform(inputs)

def test_impulse_noise(self):
with pytest.raises(ValueError):
_ = ImpulseNoise(-1)
with pytest.raises(TypeError):
_ = ImpulseNoise(0.1)
inputs = torch.rand(3, 32, 32)
transform = ImpulseNoise(1)
transform(inputs)

def test_glass_blur(self):
with pytest.raises(ValueError):
_ = GlassBlur(-1)
with pytest.raises(TypeError):
_ = GlassBlur(0.1)
inputs = torch.rand(3, 32, 32)
transform = GlassBlur(1)
transform(inputs)

def test_defocus_blur(self):
with pytest.raises(ValueError):
_ = DefocusBlur(-1)
with pytest.raises(TypeError):
_ = DefocusBlur(0.1)
inputs = torch.rand(3, 32, 32)
transform = DefocusBlur(1)
transform(inputs)

def test_jpeg_compression(self):
with pytest.raises(ValueError):
_ = JPEGCompression(-1)
with pytest.raises(TypeError):
_ = JPEGCompression(0.1)
inputs = torch.rand(3, 32, 32)
transform = JPEGCompression(1)
transform(inputs)

def test_pixelate(self):
with pytest.raises(ValueError):
_ = Pixelate(-1)
with pytest.raises(TypeError):
_ = Pixelate(0.1)
inputs = torch.rand(3, 32, 32)
transform = Pixelate(1)
transform(inputs)

def test_frost(self):
with pytest.raises(ValueError):
_ = Frost(-1)
with pytest.raises(TypeError):
_ = Frost(0.1)
inputs = torch.rand(3, 32, 32)
transform = Frost(1)
transform(inputs)
38 changes: 37 additions & 1 deletion torch_uncertainty/transforms/corruptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Adapted from https://github.com/hendrycks/robustness."""

from importlib import util
from io import BytesIO

import cv2
if util.find_spec("cv2"):
import cv2

import numpy as np
import torch
from PIL import Image
Expand All @@ -19,12 +22,27 @@

from torch_uncertainty.datasets import FrostImages

__all__ = [
"GaussianNoise",
"ShotNoise",
"ImpulseNoise",
"SpeckleNoise",
"GaussianBlur",
"GlassBlur",
"DefocusBlur",
"JPEGCompression",
"Pixelate",
"Frost",
]


class GaussianNoise(nn.Module):
def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.scale = [0, 0.04, 0.06, 0.08, 0.09, 0.10][severity]

Expand All @@ -43,6 +61,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.scale = [500, 250, 100, 75, 50][severity - 1]

Expand All @@ -61,6 +81,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.scale = [0, 0.01, 0.02, 0.03, 0.05, 0.07][severity]

Expand All @@ -83,6 +105,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.scale = [0.06, 0.1, 0.12, 0.16, 0.2][severity - 1]

Expand All @@ -105,6 +129,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1]

Expand All @@ -127,6 +153,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.sigma = [0.05, 0.25, 0.4, 0.25, 0.4][severity - 1]
self.max_delta = 1
Expand Down Expand Up @@ -176,6 +204,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.radius = [0.3, 0.4, 0.5, 1, 1.5][severity - 1]
self.alias_blur = [0.4, 0.5, 0.6, 0.2, 0.1][severity - 1]
Expand Down Expand Up @@ -206,6 +236,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.quality = [80, 65, 58, 50, 40][severity - 1]

Expand All @@ -226,6 +258,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.quality = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]

Expand All @@ -250,6 +284,8 @@ def __init__(self, severity: int) -> None:
super().__init__()
if not (0 <= severity <= 5):
raise ValueError("Severity must be between 0 and 5.")
if not isinstance(severity, int):
raise TypeError("Severity must be an integer.")
self.severity = severity
self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][
severity - 1
Expand Down

0 comments on commit 57dab0d

Please sign in to comment.