Skip to content

Commit

Permalink
Merge pull request #14 from BerndDoser/mises-fisher
Browse files Browse the repository at this point in the history
BCEWithLogitsLoss for reconstruction loss
  • Loading branch information
BerndDoser authored Sep 18, 2023
2 parents fe4e526 + 7ee5671 commit 9e1fb00
Show file tree
Hide file tree
Showing 20 changed files with 457 additions and 69 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__pycache__
*.ckpt
HiPSter
lightning_logs
HiPSter/
lightning_logs/
data/MNIST/
wandb
wandb/
18 changes: 18 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: illustris-svae",
"type": "python",
"request": "launch",
"program": "main.py",
"args": "fit -c experiments/illustris-svae.yaml",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: gz-svae",
"type": "python",
"request": "launch",
"program": "main.py",
"args": "fit -c experiments/gz-svae.yaml",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: Current File",
"type": "python",
Expand Down
42 changes: 0 additions & 42 deletions data/MNISTDataModule.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
97 changes: 97 additions & 0 deletions devel/test-bce-predict.ipynb

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions devel/test-bce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import lightning.pytorch as pl
import yaml

import data
import models

if __name__ == "__main__":

with open("experiments/illustris-svae.yaml", "r", encoding="utf-8") as stream:
config = yaml.load(stream, Loader=yaml.Loader)

model = models.RotationalSphericalVariationalAutoencoder(**(config["model"]["init_args"]))

data_module = data.IllustrisSdssDataModule(
data_directories=["/home/doserbd/data/machine-learning/SKIRT_synthetic_images/TNG50/sdss/snapnum_095/data/"],
num_workers=8, batch_size=32)

trainer = pl.Trainer(accelerator='gpu', max_epochs=-1)
trainer.fit(model, data_module)

print("Done.")
101 changes: 101 additions & 0 deletions devel/test-predict.ipynb

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions devel/test-svae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import argparse

import torch
import yaml

import data
import models

if __name__ == "__main__":

parser = argparse.ArgumentParser(description="Test SVAE")
parser.add_argument("--config", "-c", default="config.yaml",
help="config file (default = 'config.yaml').")

vars = vars(parser.parse_args())
if "config" in vars:
with open(vars["config"], "r", encoding="utf-8") as stream:
config = yaml.load(stream, Loader=yaml.Loader)

model = models.RotationalSphericalVariationalAutoencoder(**(config["model"]["init_args"]))

checkpoint = torch.load("lightning_logs/version_13/checkpoints/epoch=43-step=3344.ckpt")
model.load_state_dict(checkpoint["state_dict"])

data_module = data.IllustrisSdssDataModule(**(config["data"]["init_args"]))
data_module.setup("predict")
dataloader = data_module.predict_dataloader()

for batch in dataloader:
image = batch["image"]
z_mean, _ = model.encode(image)
print(z_mean)
break
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ dependencies:
- pandas
- pip
- pip:
- ipykernel
- jsonargparse[signatures]
- tensorboard
- tensorrt
- python=3.10
- pytorch-gpu
Expand Down
39 changes: 39 additions & 0 deletions experiments/gz-svae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalVariationalAutoencoder
init_args:
h_dim: 256
z_dim: 3
distribution: "vmf"

data:
class_path: data.GalaxyZooDataModule
init_args:
data_dir: /hits/basement/ain/Data/KaggleGalaxyZoo/images_training_rev1
extension: jpg
batch_size: 256
shuffle: True
num_workers: 12

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.001

lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: min
factor: 0.1
patience: 500
cooldown: 500
min_lr: 1.e-5
monitor: train_loss
verbose: True

trainer:
max_epochs: -1
accelerator: gpu
devices: 1
precision: 32
46 changes: 46 additions & 0 deletions experiments/illustris-svae-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalVariationalAutoencoder
init_args:
h_dim: 256
z_dim: 3
distribution: "vmf"

data:
class_path: data.IllustrisSdssDataModule
init_args:
data_directories: ["/home/doserbd/data/machine-learning/SKIRT_synthetic_images/TNG50/sdss/snapnum_095/data/"]
extension: fits
minsize: 100
batch_size: 32
shuffle: True
num_workers: 8

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.001

lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: min
factor: 0.1
patience: 500
cooldown: 500
min_lr: 1.e-5
monitor: train_loss
verbose: True

trainer:
max_epochs: -1
accelerator: gpu
devices: 1
precision: 32
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
name: illustris-svae-bce
log_model: True
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ model:
h_dim: 256
z_dim: 3
distribution: "vmf"
spherical_loss_weight: 1.e-4

data:
class_path: data.IllustrisSdssDataModule
Expand All @@ -20,7 +21,7 @@ data:
minsize: 100
batch_size: 128
shuffle: True
num_workers: 32
num_workers: 16

optimizer:
class_path: torch.optim.Adam
Expand All @@ -42,4 +43,10 @@ trainer:
max_epochs: -1
accelerator: gpu
devices: 4
precision: 32
precision: 32
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
name: illustris-svae
log_model: True
42 changes: 42 additions & 0 deletions experiments/illustris-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
seed_everything: 42

model:
class_path: models.RotationalSphericalAutoencoder

data:
class_path: data.IllustrisSdssDataModule
init_args:
data_directories: ["/home/doserbd/data/machine-learning/SKIRT_synthetic_images/TNG50/sdss/snapnum_095/data/"]
extension: fits
minsize: 100
batch_size: 32
shuffle: True
num_workers: 8

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.001

lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: min
factor: 0.1
patience: 500
cooldown: 500
min_lr: 1.e-5
monitor: train_loss
verbose: True

trainer:
max_epochs: -1
accelerator: gpu
devices: 1
precision: 32
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
name: illustris-test
log_model: True
4 changes: 3 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
It contains the following modules:
1. `rotational_spherical_projecting_autoencoder`:
1. `rotational_spherical_autoencoder`:
A plain convolutional autoencoder projecting on a sphere with naive rotation invariance.
2. `rotational_spherical_variational_autoencoder`:
A convolutional variational autoencoder projecting on a sphere with naive rotation invariance.
"""

from .rotational_spherical_autoencoder import RotationalSphericalAutoencoder
Expand Down
Loading

0 comments on commit 9e1fb00

Please sign in to comment.