Skip to content

Commit

Permalink
🔥 Avoid code copies in tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jan 3, 2024
1 parent 8751d80 commit 57165a8
Showing 1 changed file with 25 additions and 124 deletions.
149 changes: 25 additions & 124 deletions auto_tutorials_source/tutorial_corruptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
torchvision and matplotlib.
"""
import torch
from torchvision.datasets import CIFAR10, MNIST
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize

from torchvision.utils import make_grid
Expand All @@ -25,189 +25,90 @@ def get_images(main_transform, severity):
ds = CIFAR10("./data", train=False, download=False, transform=ds_transforms)
return make_grid([ds[i][0] for i in range(6)]).permute(1, 2, 0)

def show_images(transform):
print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(transform, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(transform, severity))
plt.show()

# %%
# 1. Gaussian Noise
# ~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GaussianNoise

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GaussianNoise, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GaussianNoise, severity))
plt.show()
show_images(GaussianNoise)

# %%
# 2. Shot Noise
# ~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import ShotNoise

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(ShotNoise, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(ShotNoise, severity))
plt.show()
show_images(ShotNoise)

# %%
# 3. Impulse Noise
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import ImpulseNoise

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(ImpulseNoise, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(ImpulseNoise, severity))
plt.show()
show_images(ImpulseNoise)

# %%
# 4. Speckle Noise
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import SpeckleNoise

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(SpeckleNoise, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(SpeckleNoise, severity))
plt.show()
show_images(SpeckleNoise)

# %%
# 5. Gaussian Blur
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GaussianBlur

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GaussianBlur, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GaussianBlur, severity))
plt.show()

show_images(GaussianBlur)

# %%
# 6. Glass Blur
# ~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GlassBlur

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GlassBlur, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(GlassBlur, severity))
plt.show()

show_images(GlassBlur)

# %%
# 7. Defocus Blur
# ~~~~~~~~~~~~~~~

from torch_uncertainty.transforms.corruptions import DefocusBlur

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(DefocusBlur, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(DefocusBlur, severity))
plt.show()
show_images(DefocusBlur)

#%%
# 8. JPEG Compression
# ~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import JPEGCompression

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(JPEGCompression, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(JPEGCompression, severity))
plt.show()
show_images(JPEGCompression)

#%%
# 9. Pixelate
# ~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import Pixelate

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(Pixelate, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(Pixelate, severity))
plt.show()
show_images(Pixelate)

#%%
# 10. Frost
# ~~~~~~~~
from torch_uncertainty.transforms.corruptions import Frost

print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(Frost, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(Frost, severity))
plt.show()

show_images(Frost)

# %%
# Reference
Expand Down

0 comments on commit 57165a8

Please sign in to comment.