Skip to content

Commit

Permalink
Merge pull request #17 from BerndDoser/hipster
Browse files Browse the repository at this point in the history
Hipster refactoring and model unit tests
  • Loading branch information
BerndDoser authored Sep 26, 2023
2 parents 9e1fb00 + 229d965 commit 0fb663f
Show file tree
Hide file tree
Showing 27 changed files with 119 additions and 406 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
push:
branches: [ "**" ]
pull_request:
branches: [ "main" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
with:
submodules: 'recursive'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,10 @@
"files.trimFinalNewlines": true,
"python.analysis.extraPaths": [
"./external/s-vae-pytorch"
]
],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
16 changes: 0 additions & 16 deletions devel/cli-vae.py

This file was deleted.

16 changes: 0 additions & 16 deletions devel/rot-vae-cli.py

This file was deleted.

8 changes: 7 additions & 1 deletion devel/rot-vae-onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import os
import sys

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(script_dir, '../'))

import models

model = models.rotational_spherical_variational_autoencoder()
model = models.RotationalVariationalAutoencoder()

# Test the model with a dummy input
model(model.example_input_array)
Expand Down
16 changes: 0 additions & 16 deletions devel/svae-cli.py

This file was deleted.

28 changes: 0 additions & 28 deletions devel/svae-fit.py

This file was deleted.

15 changes: 0 additions & 15 deletions devel/svae-predict.py

This file was deleted.

2 changes: 1 addition & 1 deletion devel/test-bce-predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"import data\n",
"import models\n",
"\n",
"model = models.RotationalSphericalVariationalAutoencoder.load_from_checkpoint(\n",
"model = models.RotationalVariationalAutoencoder.load_from_checkpoint(\n",
" # \"spherinator/euj7yhpr/checkpoints/epoch=34-step=16205.ckpt\")\n",
" \"spherinator/iktq1u6p/checkpoints/epoch=8-step=4167.ckpt\")\n",
"\n",
Expand Down
21 changes: 0 additions & 21 deletions devel/test-bce.py

This file was deleted.

2 changes: 1 addition & 1 deletion devel/test-predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"import data\n",
"import models\n",
"\n",
"model = models.RotationalSphericalAutoencoder.load_from_checkpoint(\n",
"model = models.RotationalAutoencoder.load_from_checkpoint(\n",
" \"spherinator/ksl5gx72/checkpoints/epoch=10-step=5093.ckpt\")\n",
"\n",
"data_module = data.IllustrisSdssDataModule(\n",
Expand Down
33 changes: 0 additions & 33 deletions devel/test-svae.py

This file was deleted.

2 changes: 1 addition & 1 deletion experiments/Illustris.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalAutoencoder
class_path: models.RotationalAutoencoder
init_args:
bottleneck: 3
rotations: 36
Expand Down
2 changes: 1 addition & 1 deletion experiments/gz-svae.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalVariationalAutoencoder
class_path: models.RotationalVariationalAutoencoder
init_args:
h_dim: 256
z_dim: 3
Expand Down
2 changes: 1 addition & 1 deletion experiments/illustris-svae-test.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalVariationalAutoencoder
class_path: models.RotationalVariationalAutoencoder
init_args:
h_dim: 256
z_dim: 3
Expand Down
2 changes: 1 addition & 1 deletion experiments/illustris-svae.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalVariationalAutoencoder
class_path: models.RotationalVariationalAutoencoder
init_args:
h_dim: 256
z_dim: 3
Expand Down
2 changes: 1 addition & 1 deletion experiments/illustris-test.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalAutoencoder
class_path: models.RotationalAutoencoder

data:
class_path: data.IllustrisSdssDataModule
Expand Down
7 changes: 3 additions & 4 deletions hipster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import healpy
import numpy
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
import yaml
from astropy.io.votable import writeto
Expand Down Expand Up @@ -235,8 +234,8 @@ def project_dataset(self, model, dataloader, rotation_steps):
images = batch['image']
for r in range(rotation_steps):
rot_images = functional.rotate(images, 360/rotation_steps*r, expand=False) # rotate
crop_images = functional.center_crop(rot_images, [256,256]) # crop
scaled_images = functional.resize(crop_images, [128,128], antialias=False) # scale
crop_images = functional.center_crop(rot_images, [self.crop_size, self.crop_size]) # crop
scaled_images = functional.resize(crop_images, [self.output_size, self.output_size], antialias=False) # scale
with torch.no_grad():
coordinates = model.project(scaled_images)
reconstruction = model.reconstruct(coordinates)
Expand Down Expand Up @@ -273,7 +272,7 @@ def generate_catalog(self, model, dataloader, catalog_file):
if answer != "Yes":
return
print("projecting dataset:")
coordinates, rotations, losses = self.project_dataset(model, dataloader, 36)
coordinates, rotations, losses = self.project_dataset(model, dataloader, model.rotations)
coordinates = coordinates.cpu().detach().numpy()
rotations = rotations.cpu().detach().numpy()
losses = losses.cpu().detach().numpy()
Expand Down
2 changes: 0 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import torch
from lightning.pytorch.cli import LightningCLI
import models
import data

torch.set_float32_matmul_precision('high')

Expand Down
13 changes: 6 additions & 7 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
A plain convolutional autoencoder projecting on a sphere with naive rotation invariance.
2. `rotational_spherical_variational_autoencoder`:
A convolutional variational autoencoder projecting on a sphere with naive rotation invariance.
3. `SVAE`:
Straightforward implementation of hyperspherical variational autoencoder (SVAE).
"""

from .rotational_spherical_autoencoder import RotationalSphericalAutoencoder
from .rotational_spherical_variational_autoencoder import \
RotationalSphericalVariationalAutoencoder
from .rotational_autoencoder import RotationalAutoencoder
from .rotational_variational_autoencoder import RotationalVariationalAutoencoder
from .svae import SVAE
from .vae import VAE

__all__ = [
'RotationalSphericalAutoencoder',
'RotationalSphericalVariationalAutoencoder',
'RotationalAutoencoder',
'RotationalVariationalAutoencoder',
'SVAE',
'VAE',
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from .spherinator_module import SpherinatorModule


class RotationalSphericalAutoencoder(SpherinatorModule):
class RotationalAutoencoder(SpherinatorModule):

def __init__(self,
rotations: int = 36,
bottleneck: int = 3):
super(RotationalSphericalAutoencoder, self).__init__()
super(RotationalAutoencoder, self).__init__()
self.bottleneck = bottleneck
self.rotations = rotations
self.conv0 = nn.Conv2d(in_channels=3, out_channels=16,
Expand Down
Loading

0 comments on commit 0fb663f

Please sign in to comment.