diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
new file mode 100644
index 0000000..993256a
--- /dev/null
+++ b/.github/workflows/python-package.yml
@@ -0,0 +1,42 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python package
+
+on:
+  push:
+    branches: [ "**" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+
+    steps:
+    - uses: actions/checkout@v2
+      with:
+        submodules: 'recursive'
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v3
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install flake8 pytest
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        pytest
diff --git a/.vscode/settings.json b/.vscode/settings.json
index d7fe692..c34b980 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -5,5 +5,10 @@
     "files.trimFinalNewlines": true,
     "python.analysis.extraPaths": [
         "./external/s-vae-pytorch"
-    ]
+    ],
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
 }
\ No newline at end of file
diff --git a/devel/cli-vae.py b/devel/cli-vae.py
deleted file mode 100644
index 2d96906..0000000
--- a/devel/cli-vae.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-from lightning.pytorch.cli import LightningCLI
-from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
-
-import models
-
-torch.set_float32_matmul_precision('high')
-
-def cli_main():
-    cli = LightningCLI(models.VAE(latent_dim=32, input_height=28, input_width=28, input_channels=1, lr=0.0001, batch_size=32),
-                       MNISTDataModule,
-                       save_config_kwargs={"overwrite": True})
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/devel/rot-vae-cli.py b/devel/rot-vae-cli.py
deleted file mode 100644
index 5da6c44..0000000
--- a/devel/rot-vae-cli.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-from lightning.pytorch.cli import LightningCLI
-
-import data
-import models
-
-torch.set_float32_matmul_precision('high')
-
-def cli_main():
-    cli = LightningCLI(models.RotationalSphericalVariationalAutoencoder,
-                       data.GalaxyZooDataModule,
-                       save_config_kwargs={"overwrite": True})
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/devel/rot-vae-onnx.py b/devel/rot-vae-onnx.py
index af903a8..adc9245 100644
--- a/devel/rot-vae-onnx.py
+++ b/devel/rot-vae-onnx.py
@@ -1,6 +1,12 @@
+import os
+import sys
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(script_dir, '../'))
+
 import models
 
-model = models.rotational_spherical_variational_autoencoder()
+model = models.RotationalVariationalAutoencoder()
 
 # Test the model with a dummy input
 model(model.example_input_array)
diff --git a/devel/svae-cli.py b/devel/svae-cli.py
deleted file mode 100644
index 2161fe0..0000000
--- a/devel/svae-cli.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-from lightning.pytorch.cli import LightningCLI
-
-import data
-import models
-
-torch.set_float32_matmul_precision('high')
-
-def cli_main():
-    LightningCLI(models.SVAE,
-                 data.MNISTDataModule,
-                 save_config_kwargs={"overwrite": True})
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/devel/svae-fit.py b/devel/svae-fit.py
deleted file mode 100644
index f5ad694..0000000
--- a/devel/svae-fit.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import lightning.pytorch as pl
-import torch
-import torch.nn as nn
-import torch.utils.data
-from torchvision import transforms
-from torchvision.datasets import MNIST
-
-import data
-import models
-
-# hidden dimension and dimension of latent space
-H_DIM = 128
-Z_DIM = 2
-DIST = 'vmf'  # 'normal' or 'vmf'
-
-# model = models.VAE(latent_dim=1024, input_height=32, input_width=32, input_channels=1, lr=0.0001, batch_size=32)
-model = models.SVAE(h_dim=H_DIM, z_dim=Z_DIM, distribution=DIST)
-
-# dataset = MNIST('./data', download=True, transform=transforms.ToTensor())
-# train_loader = torch.utils.data.DataLoader(dataset, num_workers=12, batch_size=32)
-
-# trainer = pl.Trainer(accelerator='gpu', devices=1)
-# trainer.fit(model, train_dataloaders=train_loader)
-
-data = data.MNISTDataModule(batch_size=32, num_workers=4)
-
-trainer = pl.Trainer(accelerator='gpu', devices=1)
-trainer.fit(model, datamodule=data)
diff --git a/devel/svae-predict.py b/devel/svae-predict.py
deleted file mode 100644
index 8953b0c..0000000
--- a/devel/svae-predict.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import lightning.pytorch as pl
-import torch
-import torch.nn as nn
-import torch.utils.data
-from torchvision import transforms
-from torchvision.datasets import MNIST
-
-import data
-import models
-
-data = data.MNISTDataModule(batch_size=32, num_workers=4)
-model = models.SVAE.load_from_checkpoint("./lightning_logs/version_3/checkpoints/epoch=340-step=586179.ckpt")
-
-trainer = pl.Trainer(accelerator='gpu', devices=1)
-predictions = trainer.predict(model, datamodule=data)
diff --git a/devel/test-bce-predict.ipynb b/devel/test-bce-predict.ipynb
index ca4ca61..3678e43 100644
--- a/devel/test-bce-predict.ipynb
+++ b/devel/test-bce-predict.ipynb
@@ -13,7 +13,7 @@
     "import data\n",
     "import models\n",
     "\n",
-    "model = models.RotationalSphericalVariationalAutoencoder.load_from_checkpoint(\n",
+    "model = models.RotationalVariationalAutoencoder.load_from_checkpoint(\n",
     "    # \"spherinator/euj7yhpr/checkpoints/epoch=34-step=16205.ckpt\")\n",
     "    \"spherinator/iktq1u6p/checkpoints/epoch=8-step=4167.ckpt\")\n",
     "\n",
diff --git a/devel/test-bce.py b/devel/test-bce.py
deleted file mode 100644
index 459ad1b..0000000
--- a/devel/test-bce.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import lightning.pytorch as pl
-import yaml
-
-import data
-import models
-
-if __name__ == "__main__":
-
-    with open("experiments/illustris-svae.yaml", "r", encoding="utf-8") as stream:
-        config = yaml.load(stream, Loader=yaml.Loader)
-
-    model = models.RotationalSphericalVariationalAutoencoder(**(config["model"]["init_args"]))
-
-    data_module = data.IllustrisSdssDataModule(
-        data_directories=["/home/doserbd/data/machine-learning/SKIRT_synthetic_images/TNG50/sdss/snapnum_095/data/"],
-        num_workers=8, batch_size=32)
-
-    trainer = pl.Trainer(accelerator='gpu', max_epochs=-1)
-    trainer.fit(model, data_module)
-
-    print("Done.")
diff --git a/devel/test-predict.ipynb b/devel/test-predict.ipynb
index b8ec507..178c51e 100644
--- a/devel/test-predict.ipynb
+++ b/devel/test-predict.ipynb
@@ -13,7 +13,7 @@
     "import data\n",
     "import models\n",
     "\n",
-    "model = models.RotationalSphericalAutoencoder.load_from_checkpoint(\n",
+    "model = models.RotationalAutoencoder.load_from_checkpoint(\n",
     "    \"spherinator/ksl5gx72/checkpoints/epoch=10-step=5093.ckpt\")\n",
     "\n",
     "data_module = data.IllustrisSdssDataModule(\n",
diff --git a/devel/test-svae.py b/devel/test-svae.py
deleted file mode 100644
index b47b278..0000000
--- a/devel/test-svae.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import argparse
-
-import torch
-import yaml
-
-import data
-import models
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="Test SVAE")
-    parser.add_argument("--config", "-c", default="config.yaml",
-                        help="config file (default = 'config.yaml').")
-
-    vars = vars(parser.parse_args())
-    if "config" in vars:
-        with open(vars["config"], "r", encoding="utf-8") as stream:
-            config = yaml.load(stream, Loader=yaml.Loader)
-
-    model = models.RotationalSphericalVariationalAutoencoder(**(config["model"]["init_args"]))
-
-    checkpoint = torch.load("lightning_logs/version_13/checkpoints/epoch=43-step=3344.ckpt")
-    model.load_state_dict(checkpoint["state_dict"])
-
-    data_module = data.IllustrisSdssDataModule(**(config["data"]["init_args"]))
-    data_module.setup("predict")
-    dataloader = data_module.predict_dataloader()
-
-    for batch in dataloader:
-        image = batch["image"]
-        z_mean, _ = model.encode(image)
-        print(z_mean)
-        break
diff --git a/experiments/Illustris.yaml b/experiments/Illustris.yaml
index 1ccd2a9..ba4d9e4 100644
--- a/experiments/Illustris.yaml
+++ b/experiments/Illustris.yaml
@@ -1,7 +1,7 @@
 seed_everything: 42
 
 model:
-  class_path: models.RotationalSphericalAutoencoder
+  class_path: models.RotationalAutoencoder
   init_args:
     bottleneck: 3
     rotations: 36
diff --git a/experiments/gz-svae.yaml b/experiments/gz-svae.yaml
index da3ab4c..f8123d5 100644
--- a/experiments/gz-svae.yaml
+++ b/experiments/gz-svae.yaml
@@ -1,7 +1,7 @@
 seed_everything: 42
 
 model:
-  class_path: models.RotationalSphericalVariationalAutoencoder
+  class_path: models.RotationalVariationalAutoencoder
   init_args:
     h_dim: 256
     z_dim: 3
diff --git a/experiments/illustris-svae-test.yaml b/experiments/illustris-svae-test.yaml
index 1008c94..526b97c 100644
--- a/experiments/illustris-svae-test.yaml
+++ b/experiments/illustris-svae-test.yaml
@@ -1,7 +1,7 @@
 seed_everything: 42
 
 model:
-  class_path: models.RotationalSphericalVariationalAutoencoder
+  class_path: models.RotationalVariationalAutoencoder
   init_args:
     h_dim: 256
     z_dim: 3
diff --git a/experiments/illustris-svae.yaml b/experiments/illustris-svae.yaml
index 7c4bd5d..8f49d35 100644
--- a/experiments/illustris-svae.yaml
+++ b/experiments/illustris-svae.yaml
@@ -1,7 +1,7 @@
 seed_everything: 42
 
 model:
-  class_path: models.RotationalSphericalVariationalAutoencoder
+  class_path: models.RotationalVariationalAutoencoder
   init_args:
     h_dim: 256
     z_dim: 3
diff --git a/experiments/illustris-test.yaml b/experiments/illustris-test.yaml
index 56f155b..d01288e 100644
--- a/experiments/illustris-test.yaml
+++ b/experiments/illustris-test.yaml
@@ -1,7 +1,7 @@
 seed_everything: 42
 
 model:
-  class_path: models.RotationalSphericalAutoencoder
+  class_path: models.RotationalAutoencoder
 
 data:
   class_path: data.IllustrisSdssDataModule
diff --git a/hipster.py b/hipster.py
index 1fbf9c8..11a5667 100755
--- a/hipster.py
+++ b/hipster.py
@@ -13,7 +13,6 @@
 import healpy
 import numpy
 import torch
-import torchvision.transforms as transforms
 import torchvision.transforms.functional as functional
 import yaml
 from astropy.io.votable import writeto
@@ -235,8 +234,8 @@ def project_dataset(self, model, dataloader, rotation_steps):
             images = batch['image']
             for r in range(rotation_steps):
                 rot_images = functional.rotate(images, 360/rotation_steps*r, expand=False) # rotate
-                crop_images = functional.center_crop(rot_images, [256,256]) # crop
-                scaled_images = functional.resize(crop_images, [128,128], antialias=False) # scale
+                crop_images = functional.center_crop(rot_images, [self.crop_size, self.crop_size]) # crop
+                scaled_images = functional.resize(crop_images, [self.output_size, self.output_size], antialias=False) # scale
                 with torch.no_grad():
                     coordinates = model.project(scaled_images)
                     reconstruction = model.reconstruct(coordinates)
@@ -273,7 +272,7 @@ def generate_catalog(self, model, dataloader, catalog_file):
             if answer != "Yes":
                 return
         print("projecting dataset:")
-        coordinates, rotations, losses = self.project_dataset(model, dataloader, 36)
+        coordinates, rotations, losses = self.project_dataset(model, dataloader, model.rotations)
        coordinates = coordinates.cpu().detach().numpy()
        rotations = rotations.cpu().detach().numpy()
        losses = losses.cpu().detach().numpy()
diff --git a/main.py b/main.py
index 03d602c..a7bd204 100644
--- a/main.py
+++ b/main.py
@@ -3,8 +3,6 @@
 import torch
 from lightning.pytorch.cli import LightningCLI
 
-import models
-import data
 
 torch.set_float32_matmul_precision('high')
 
diff --git a/models/__init__.py b/models/__init__.py
index 3c92718..0e2c939 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -8,17 +8,16 @@
    A plain convolutional autoencoder projecting on a sphere with naive rotation invariance.
 2. `rotational_spherical_variational_autoencoder`:
    A convolutional variational autoencoder projecting on a sphere with naive rotation invariance.
+3. `SVAE`:
+   Straightforward implementation of a hyperspherical variational autoencoder (SVAE).
""" -from .rotational_spherical_autoencoder import RotationalSphericalAutoencoder -from .rotational_spherical_variational_autoencoder import \ - RotationalSphericalVariationalAutoencoder +from .rotational_autoencoder import RotationalAutoencoder +from .rotational_variational_autoencoder import RotationalVariationalAutoencoder from .svae import SVAE -from .vae import VAE __all__ = [ - 'RotationalSphericalAutoencoder', - 'RotationalSphericalVariationalAutoencoder', + 'RotationalAutoencoder', + 'RotationalVariationalAutoencoder', 'SVAE', - 'VAE', ] diff --git a/models/rotational_spherical_autoencoder.py b/models/rotational_autoencoder.py similarity index 97% rename from models/rotational_spherical_autoencoder.py rename to models/rotational_autoencoder.py index e50684c..a3243fc 100644 --- a/models/rotational_spherical_autoencoder.py +++ b/models/rotational_autoencoder.py @@ -7,12 +7,12 @@ from .spherinator_module import SpherinatorModule -class RotationalSphericalAutoencoder(SpherinatorModule): +class RotationalAutoencoder(SpherinatorModule): def __init__(self, rotations: int = 36, bottleneck: int = 3): - super(RotationalSphericalAutoencoder, self).__init__() + super(RotationalAutoencoder, self).__init__() self.bottleneck = bottleneck self.rotations = rotations self.conv0 = nn.Conv2d(in_channels=3, out_channels=16, diff --git a/models/rotational_spherical_variational_autoencoder.py b/models/rotational_variational_autoencoder.py similarity index 85% rename from models/rotational_spherical_variational_autoencoder.py rename to models/rotational_variational_autoencoder.py index 083a899..3e2204d 100644 --- a/models/rotational_spherical_variational_autoencoder.py +++ b/models/rotational_variational_autoencoder.py @@ -16,15 +16,23 @@ VonMisesFisher) -class RotationalSphericalVariationalAutoencoder(SpherinatorModule): - - def __init__(self, h_dim=256, z_dim=2, distribution='normal', spherical_loss_weight=1e-4): +class RotationalVariationalAutoencoder(SpherinatorModule): + + def __init__(self, + h_dim: int = 256, + z_dim: int = 2, + distribution: str = 'normal', + rotations: int = 36, + beta: float = 1.0, + spherical_loss_weight: float = 1e-4): """ - RotationalSphericalVariationalAutoencoder initializer + RotationalVariationalAutoencoder initializer :param h_dim: dimension of the hidden layers :param z_dim: dimension of the latent representation :param distribution: string either `normal` or `vmf`, indicates which distribution to use + :param rotations: number of rotations + :param beta: factor for beta-VAE :param spherical_loss_weight: weight of the spherical loss """ super().__init__() @@ -32,7 +40,7 @@ def __init__(self, h_dim=256, z_dim=2, distribution='normal', spherical_loss_wei self.example_input_array = torch.randn(1, 3, 64, 64) self.h_dim, self.z_dim, self.distribution = h_dim, z_dim, distribution - self.spherical_loss_weight = spherical_loss_weight + self.rotations, self.beta, self.spherical_loss_weight = rotations, beta, spherical_loss_weight self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(5,5), stride=2, padding=2) self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5,5), stride=2, padding=2) @@ -115,13 +123,12 @@ def spherical_loss(self, coordinates): def training_step(self, batch, batch_idx): images = batch["image"] - rotations = 36 - losses = torch.zeros(images.shape[0], rotations) - losses_recon = torch.zeros(images.shape[0], rotations) - losses_KL = torch.zeros(images.shape[0], rotations) - losses_spher = torch.zeros(images.shape[0], rotations) - 
for i in range(rotations): - x = functional.rotate(images, 360.0 / rotations * i, expand=False) + losses = torch.zeros(images.shape[0], self.rotations) + losses_recon = torch.zeros(images.shape[0], self.rotations) + losses_KL = torch.zeros(images.shape[0], self.rotations) + losses_spher = torch.zeros(images.shape[0], self.rotations) + for i in range(self.rotations): + x = functional.rotate(images, 360.0 / self.rotations * i, expand=False) x = functional.center_crop(x, [256,256]) input = functional.resize(x, [64,64], antialias=False) @@ -138,7 +145,7 @@ def training_step(self, batch, batch_idx): loss_spher = self.spherical_loss(z_mean) - losses[:,i] = loss_recon + loss_KL + self.spherical_loss_weight * loss_spher + losses[:,i] = loss_recon + self.beta * loss_KL #+ self.spherical_loss_weight * loss_spher losses_recon[:,i] = loss_recon losses_KL[:,i] = loss_KL losses_spher[:,i] = loss_spher diff --git a/models/vae.py b/models/vae.py deleted file mode 100644 index 1d665b4..0000000 --- a/models/vae.py +++ /dev/null @@ -1,224 +0,0 @@ -# Source: https://github.com/alpercanberk/pytorch-lightning-vae - -from collections import namedtuple -from typing import Optional - -import lightning.pytorch as pl -import torch -import torch.nn as nn -from torch.optim import Adam -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torchvision.utils import make_grid - - -class Stack(nn.Module): - def __init__(self, channels, height, width): - super(Stack, self).__init__() - self.channels = channels - self.height = height - self.width = width - - def forward(self, x): - return x.view(x.size(0), self.channels, self.height, self.width) - - -class VAE(pl.LightningModule): - def __init__(self, - latent_dim:int, - input_height:int, - input_width:int, - input_channels:int, - lr: float, - batch_size: int, - save_path: Optional[str] = None, **kwargs): - """Init function for the VAE - Args: - latent_dim (int): Latent Hidden Size - reconstruction loss vs KL-Divergence Loss - lr (float): Learning Rate, will not be used if auto_lr_find is used. 
-            dataset (Optional[str]): Dataset to used
-            save_path (Optional[str]): Path to save images
-        """
-
-        super().__init__()
-        self.latent_dim = latent_dim
-
-        self.save_hyperparameters()
-
-        self.lr = lr
-
-        self.batch_size = batch_size
-
-        ENC_OUT_DIM = 128
-
-        self.encoder = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(input_channels*input_height*input_width, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
-            nn.Linear(392, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
-            nn.Linear(196, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
-            nn.Linear(128, ENC_OUT_DIM)
-        )
-        self.decoder = nn.Sequential(
-            nn.Linear(latent_dim, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
-            nn.Linear(128, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
-            nn.Linear(196, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
-            nn.Linear(392, input_channels*input_height*input_width),
-            nn.Sigmoid(),
-            Stack(input_channels, input_height, input_width),
-        )
-
-        self.hidden2mu = nn.Linear(ENC_OUT_DIM, latent_dim)
-        self.hidden2log_var = nn.Linear(ENC_OUT_DIM, latent_dim)
-
-        self.log_scale = nn.Parameter(torch.Tensor([0.0]))
-
-        self.example_input_array = torch.randn(input_channels, input_height, input_width)
-
-
-    def encode(self, x):
-        hidden = self.encoder(x)
-        mu = self.hidden2mu(hidden)
-        log_var = self.hidden2log_var(hidden)
-        return mu, log_var
-
-    def decode(self, x):
-        x = self.decoder(x)
-        return x
-
-    def reparametrize(self, mu, log_var):
-        # Reparametrization Trick to allow gradients to backpropagate from the
-        # stochastic part of the model
-        sigma = torch.exp(0.5*log_var)
-        z = torch.randn_like(sigma)
-        return mu + sigma*z
-
-    def gaussian_likelihood(self, x_hat, logscale, x):
-        scale = torch.exp(logscale)
-        mean = x_hat
-        dist = torch.distributions.Normal(mean, scale)
-
-        # measure prob of seeing image under p(x|z)
-        log_pxz = dist.log_prob(x)
-
-        return log_pxz.sum(dim=(1, 2, 3))
-
-    def kl_divergence(self, z, mu, std):
-        # --------------------------
-        # Monte carlo KL divergence
-        # --------------------------
-        # 1. define the first two probabilities (in this case Normal for both)
-        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
-        q = torch.distributions.Normal(mu, std)
-
-        # 2. get the probabilities from the equation
-        log_qzx = q.log_prob(z)
-        log_pz = p.log_prob(z)
-
-        # kl
-        kl = (log_qzx - log_pz)
-        kl = kl.sum(-1)
-        return kl
-
-    def forward(self, x):
-
-        mu, log_var = self.encode(x)
-        std = torch.exp(log_var / 2)
-
-        #Sample from distribution
-        q = torch.distributions.Normal(mu, std)
-        z = q.rsample()
-
-        #Push sample through decoder
-        x_hat = self.decode(z)
-
-        return mu, std, z, x_hat
-
-
-    def training_step(self, batch, batch_idx):
-
-        x, _ = batch
-
-        mu, std, z, x_hat = self.forward(x)
-
-        # reconstruction loss
-        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
-
-        #expectation under z of the kl divergence between q(z|x) and
-        #a standard normal distribution of the same shape
-        kl = self.kl_divergence(z, mu, std)
-
-        # elbo
-        elbo = (kl - recon_loss)
-        elbo = elbo.mean()
-
-        self.log('train_kl_loss', kl.mean(), on_step=True,
-                 on_epoch=True, prog_bar=False)
-        self.log('train_recon_loss', recon_loss.mean(), on_step=True,
-                 on_epoch=True, prog_bar=False)
-        self.log('train_loss', elbo, on_step=True,
-                 on_epoch=True, prog_bar=True)
-
-        # train_images = make_grid(x[:16]).cpu().numpy()
-        return elbo
-
-    def validation_step(self, batch, batch_idx):
-
-        x, _ = batch
-
-        mu, std, z, x_hat = self.forward(x)
-
-        # reconstruction loss
-        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
-
-        #expectation under z of the kl divergence between q(z|x) and
-        #a standard normal distribution of the same shape
-        kl = self.kl_divergence(z, mu, std)
-
-        # elbo
-        elbo = kl - recon_loss
-        elbo = elbo.mean()
-
-        self.log('val_kl_loss', kl.mean(), on_step=False, on_epoch=True)
-        self.log('val_recon_loss', recon_loss.mean(), on_step=False, on_epoch=True)
-        self.log('val_loss', elbo, on_step=False, on_epoch=True)
-
-        self.logger.experiment.add_image('Normalized Inputs', make_grid(x[:8]))
-
-
-        return x_hat, elbo
-
-    def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
-        lr_scheduler = ReduceLROnPlateau(optimizer,)
-        return {
-            "optimizer": optimizer, "lr_scheduler": lr_scheduler,
-            "monitor": "val_loss"
-        }
-
-    def interpolate(self, x1, x2):
-        assert x1.shape == x2.shape, "Inputs must be of the same shape"
-        if x1.dim() == 3:
-            x1 = x1.unsqueeze(0)
-        if x2.dim() == 3:
-            x2 = x2.unsqueeze(0)
-        if self.training:
-            raise Exception(
-                "This function should not be called when model is still "
-                "in training mode. Use model.eval() before calling the "
-                "function")
-        mu1, lv1 = self.encode(x1)
-        mu2, lv2 = self.encode(x2)
-        z1 = self.reparametrize(mu1, lv1)
-        z2 = self.reparametrize(mu2, lv2)
-        weights = torch.arange(0.1, 0.9, 0.1)
-        intermediate = [self.decode(z1)]
-        for wt in weights:
-            inter = (1.-wt)*z1 + wt*z2
-            intermediate.append(self.decode(inter))
-        intermediate.append(self.decode(z2))
-        out = torch.stack(intermediate, dim=0).squeeze(1)
-        return out, (mu1, lv1), (mu2, lv2)
-
-    @staticmethod
-    def custom_transform(normalization):
-        return None, None
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..952c648
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,6 @@
+[tool.pytest.ini_options]
+timeout = 300
+testpaths = ['tests']
+filterwarnings = [
+    "ignore::DeprecationWarning"
+]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a18653c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+lightning
+scipy
+torch
+torchvision
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_rotational_variational_autoencoder.py b/tests/test_rotational_variational_autoencoder.py
new file mode 100644
index 0000000..1cbe9fd
--- /dev/null
+++ b/tests/test_rotational_variational_autoencoder.py
@@ -0,0 +1,14 @@
+from models import RotationalVariationalAutoencoder
+
+def test_rotational_variational_autoencoder():
+
+    z_dim = 2
+    model = RotationalVariationalAutoencoder(z_dim=z_dim)
+    input = model.example_input_array
+    batch_size = input.shape[0]
+
+    (z_mean, z_var), (_, _), _, recon = model(input)
+
+    assert z_mean.shape == (batch_size, z_dim)
+    assert z_var.shape == (batch_size, z_dim)
+    assert recon.shape == input.shape
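
Note on the pyproject.toml hunk: pytest only reads configuration from the `[tool.pytest.ini_options]` table (a plain `[tool.pytest]` table is silently ignored), so the two tables were merged above; the `timeout` option is provided by the pytest-timeout plugin, which is not listed in requirements.txt and would need to be installed for the setting to take effect.

For quick manual verification of the renamed model, a minimal smoke-test sketch mirroring the new tests/test_rotational_variational_autoencoder.py; the forward return structure `(z_mean, z_var), (q_z, p_z), z, recon` is assumed from the tuple unpacked in that test, and the expected shapes follow from `example_input_array = torch.randn(1, 3, 64, 64)`:

    import torch

    from models import RotationalVariationalAutoencoder

    # Constructor arguments taken from the new __init__ signature in this diff.
    model = RotationalVariationalAutoencoder(z_dim=2, rotations=36, beta=1.0)
    model.eval()

    # Run a forward pass on the built-in example input.
    with torch.no_grad():
        (z_mean, z_var), _, _, recon = model(model.example_input_array)

    print(z_mean.shape)  # expected: torch.Size([1, 2])
    print(recon.shape)   # expected: torch.Size([1, 3, 64, 64])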