-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #171 from BerndDoser/rot-loss
Alternative rotational invariance and convolution neural networks
- Loading branch information
Showing
33 changed files
with
1,349 additions
and
704 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Training configuration: selects the model, data module, optimizer,
# learning-rate scheduler and trainer settings by class path.
seed_everything: 42 | ||
|
||
model: | ||
class_path: spherinator.models.RotationalVariationalAutoencoderPower | ||
init_args: | ||
encoder: | ||
class_path: spherinator.models.ConvolutionalEncoder | ||
decoder: | ||
class_path: spherinator.models.ConvolutionalDecoder | ||
h_dim: 256 | ||
z_dim: 3 | ||
# NOTE(review): image_size is also set under data.init_args (224 in both
# places) - presumably the two values must agree; confirm.
image_size: 224 | ||
rotations: 1 | ||
beta: 1.0e-3 | ||
|
||
data: | ||
class_path: spherinator.data.ImagesDataModule | ||
init_args: | ||
# NOTE(review): absolute, user-specific path - adjust for your environment.
data_directory: /local_data/doserbd/data/pokemon | ||
extensions: ['jpg'] | ||
image_size: 224 | ||
batch_size: 32 | ||
shuffle: True | ||
num_workers: 16 | ||
|
||
optimizer: | ||
class_path: torch.optim.Adam | ||
init_args: | ||
lr: 1.e-3 | ||
|
||
lr_scheduler: | ||
class_path: lightning.pytorch.cli.ReduceLROnPlateau | ||
init_args: | ||
mode: min | ||
factor: 0.1 | ||
patience: 5 | ||
cooldown: 5 | ||
min_lr: 1.e-6 | ||
# Metric the scheduler watches; presumably the model logs it under this
# exact name - confirm.
monitor: train_loss | ||
verbose: True | ||
|
||
trainer: | ||
# NOTE(review): -1 presumably disables the epoch limit - confirm against the
# Lightning Trainer documentation.
max_epochs: -1 | ||
accelerator: gpu | ||
# NOTE(review): pins a single device index (3) - machine-specific.
devices: [3] | ||
precision: 32 | ||
callbacks: | ||
- class_path: spherinator.callbacks.LogReconstructionCallback | ||
init_args: | ||
num_samples: 6 | ||
# Checkpoint callback, currently disabled:
# - class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
# init_args: | ||
# monitor: train_loss | ||
# filename: "{epoch}-{train_loss:.2f}" | ||
# save_top_k: 3 | ||
# mode: min | ||
# every_n_epochs: 1 | ||
logger: | ||
class_path: lightning.pytorch.loggers.WandbLogger | ||
init_args: | ||
project: spherinator | ||
log_model: True | ||
entity: ain-space | ||
tags: | ||
- rot-loss | ||
- pokemon |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import torch | ||
import torchvision.transforms.v2 as transforms | ||
from lightning.pytorch import LightningDataModule | ||
from torch.utils.data import DataLoader | ||
|
||
from spherinator.data.images_dataset import ImagesDataset | ||
|
||
|
||
class ImagesDataModule(LightningDataModule):
    """Defines access to the ImagesDataset."""

    def __init__(
        self,
        data_directory: str,
        extensions: list[str] = ["jpg"],
        shuffle: bool = True,
        image_size: int = 64,
        batch_size: int = 32,
        num_workers: int = 1,
    ):
        """Initializes the data loader

        Args:
            data_directory (str): The data directory
            extensions (list[str], optional): File extensions (without the
                leading dot) of the images to load. Defaults to ["jpg"].
            shuffle (bool, optional): Whether or not to shuffle when reading.
                Defaults to True.
            image_size (int, optional): The size of the images. Defaults to 64.
            batch_size (int, optional): The batch size for training.
                Defaults to 32.
            num_workers (int, optional): How many workers to use for loading.
                Defaults to 1.
        """
        super().__init__()

        self.data_directory = data_directory
        self.extensions = extensions
        self.shuffle = shuffle
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Both are created lazily in setup().
        self.data_train = None
        self.dataloader_train = None

        self.transform_train = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), antialias=True),
                # Min-max normalization of each image to the range [0, 1].
                # NOTE(review): for a float tensor whose values are all equal
                # this divides 0 by 0 and yields NaN - confirm such images
                # cannot occur in the data.
                transforms.Lambda(
                    lambda x: (x - torch.min(x)) / (torch.max(x) - torch.min(x))
                ),
            ]
        )
        # The same pipeline is exposed under several names.
        self.transform_processing = self.transform_train
        self.transform_images = self.transform_train
        # Thumbnails: the training pipeline followed by a resize to 100 x 100.
        self.transform_thumbnail_images = transforms.Compose(
            [
                self.transform_train,
                transforms.Resize((100, 100), antialias=True),
            ]
        )

    def setup(self, stage: str):
        """Sets up the data set and data loaders.

        Calling this more than once with stage "fit" is safe: the data set and
        data loader are created on the first call and reused afterwards.

        Args:
            stage (str): Defines for which stage the data is needed.
                For the moment just fitting is supported.

        Raises:
            ValueError: If the stage is anything other than "fit".
        """
        if stage != "fit":
            raise ValueError(f"Stage {stage} not supported.")

        # Already set up: a repeated call is a no-op rather than an error.
        if self.data_train is not None:
            return

        self.data_train = ImagesDataset(
            data_directory=self.data_directory,
            extensions=self.extensions,
            transform=self.transform_train,
        )
        self.dataloader_train = DataLoader(
            self.data_train,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Gets the data loader for training."""
        return self.dataloader_train
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" Create dataset with all image files in a directory. | ||
""" | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import skimage.io as io | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def get_all_filenames(data_directory: str, extensions: list[str]):
    """Collect the paths of all files below a directory with a given extension.

    The directory tree is traversed with ``os.walk``, which already descends
    into every subdirectory, so no explicit recursion is needed.

    Args:
        data_directory (str): Root directory to search.
        extensions (list[str]): Accepted file extensions without the leading
            dot, e.g. ``["jpg"]``. The comparison is case-sensitive.

    Returns:
        list[str]: One path per matching file, each built by joining the
            directory reported by ``os.walk`` with the file name.
    """
    result = []
    for dirpath, _dirnames, filenames in os.walk(data_directory):
        for filename in filenames:
            # suffix is ".jpg" for "image.jpg"; drop the dot before comparing.
            if Path(filename).suffix[1:] in extensions:
                result.append(os.path.join(dirpath, filename))
    return result
|
||
|
||
class ImagesDataset(Dataset):
    """Dataset over every image file found below a directory."""

    def __init__(
        self,
        data_directory: str,
        extensions: list[str] = ["jpg"],
        transform=None,
    ):
        """Builds the sorted list of image files.

        Args:
            data_directory (str): The directory that is searched for images.
            extensions (list[str], optional): Accepted file extensions without
                the leading dot. Defaults to ["jpg"].
            transform (torchvision.transforms, optional): A single or a set of
                transformations applied to every image. Defaults to None.
        """
        found = get_all_filenames(data_directory, extensions)
        # Sorting makes the index -> file mapping independent of walk order.
        self.filenames = sorted(found)
        self.transform = transform

    def __len__(self) -> int:
        """Number of image files in the dataset.

        Returns:
            int: Number of items in dataset.
        """
        count = len(self.filenames)
        return count

    def __getitem__(self, index: int) -> torch.Tensor:
        """Loads the image at the given position and returns it as a tensor.

        Args:
            index: The index of the item to retrieve.

        Returns:
            data: The image as a tensor, after the optional transform.
        """
        image = io.imread(self.filenames[index])
        # Swap axis 0 and 2 to bring the color channel to the front.
        # NOTE(review): assuming imread returns (rows, cols, channels), the
        # result is (channels, cols, rows), i.e. the two spatial axes are
        # exchanged as well - confirm this is intended.
        tensor = torch.Tensor(image.swapaxes(0, 2))
        if not self.transform:
            return tensor
        return self.transform(tensor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class ConvolutionalDecoder2(nn.Module):
    """Decoder built from a linear layer followed by transposed convolutions.

    A latent vector is projected to a 1024 x 4 x 4 feature map and then
    upsampled by a factor of two in each of five transposed-convolution
    stages, ending in a 3 x 128 x 128 output.
    """

    def __init__(self, latent_dim: int):
        """Creates the decoder stages.

        Args:
            latent_dim (int): Size of the latent vector fed into the decoder.
        """
        super().__init__()

        # Shape comments give the output of each stage (channels x H x W).
        self.dec1 = nn.Sequential(
            nn.Linear(latent_dim, 1024 * 4 * 4),
            nn.Unflatten(1, (1024, 4, 4)),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )  # 1024 x 4 x 4
        self.dec2 = self._upsample(1024, 512)  # 512 x 8 x 8
        self.dec3 = self._upsample(512, 512)  # 512 x 16 x 16
        self.dec4 = self._upsample(512, 256)  # 256 x 32 x 32
        self.dec5 = self._upsample(256, 128)  # 128 x 64 x 64
        # The last stage has no ReLU: batch normalization is the final layer.
        self.dec6 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.BatchNorm2d(3),
        )  # 3 x 128 x 128

    @staticmethod
    def _upsample(in_channels: int, out_channels: int) -> nn.Sequential:
        """Builds one stage that doubles the spatial size of its input."""
        # kernel 4, stride 2, padding 1: (n - 1) * 2 - 2 + 4 = 2 * n
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decodes a batch of latent vectors.

        Args:
            x (torch.Tensor): Batch of latent vectors of shape
                (batch, latent_dim).

        Returns:
            torch.Tensor: Decoded batch of shape (batch, 3, 128, 128).
        """
        stages = (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5, self.dec6)
        for stage in stages:
            x = stage(x)
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.