DGM Geometry

Here, we study how the geometry of deep generative models (DGMs) can inform our understanding of phenomena like the likelihood out-of-distribution paradox. In tandem with these topics, we also study algorithms for local intrinsic dimension (LID) estimation of datapoints.

Installation

We use a conda environment for this project. To create the environment, run:

conda env create -f env.yaml
# this will create an environment named dgm_geometry
conda activate dgm_geometry
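
As a quick sanity check that the environment is active, you can try importing the core deep learning dependency (this assumes env.yaml installs PyTorch, which the training code relies on):

python -c "import torch; print(torch.__version__)"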

To download all the checkpoints and resources, and to set up the appropriate environment variables, run the following:

python scripts/download_resources.py

You may skip this stage if you want to train your own models, but it is recommended, as some of the notebooks rely on the downloaded resources.

Training a Deep Generative Model

Most of the capabilities in the codebase are accessed through the training script. We use PyTorch Lightning for training, and Lightning callbacks for monitoring the behaviour and properties of the manifold induced by the generative model. Even when no training is involved, we still use the training script, but load checkpoints and set the epoch count to zero.

Training involves running scripts/train.py alongside a dataset and an experiment configuration. To get started, you can run the following examples for training flows or diffusions on image datasets:

# train a greyscale diffusion; replace <grayscale-data> with e.g. mnist or fmnist
python scripts/train.py dataset=<grayscale-data> +experiment=train_diffusion_greyscale
# train an RGB diffusion; replace <rgb-data> with e.g. cifar10
python scripts/train.py dataset=<rgb-data> +experiment=train_diffusion_rgb
# train a greyscale flow; replace <grayscale-data> with e.g. mnist or fmnist
python scripts/train.py dataset=<grayscale-data> +experiment=train_flow_greyscale
# train an RGB flow; replace <rgb-data> with e.g. cifar10
python scripts/train.py dataset=<rgb-data> +experiment=train_flow_rgb

For example:

python scripts/train.py dataset=mnist +experiment=train_diffusion_greyscale
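
Because the scripts are driven by Hydra, individual configuration values can also be overridden from the command line. The exact keys depend on the configuration files; the following is a hypothetical sketch using the train.trainer group that appears in the test-run examples below together with the standard Lightning max_epochs option:

# hypothetical override: shorten training by capping the number of epochs
python scripts/train.py dataset=mnist +experiment=train_diffusion_greyscale train.trainer.max_epochs=10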

Tracking

We use mlflow for tracking and logging; all the artifacts will be available in the outputs subdirectory. To start the mlflow UI, run the following:

cd outputs
mlflow ui

Once it is running, you can click on the provided link to view all the experiments. The logs are typically stored in the artifacts directory.
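
If the default port is already in use, you can pass a different port or host; these are standard flags of the mlflow ui command:

mlflow ui --port 8080 --host 0.0.0.0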

Test Runs

Use the following command to see the fully resolved configuration that the script will end up running:

python scripts/train.py <training-options> --help --resolve

To perform a test run in development mode, you can run the following:

python scripts/train.py <training-options> dev_run=true train.trainer.callbacks=null train.trainer.fast_dev_run=True

For example:

python scripts/train.py dataset=cifar10 +experiment=train_diffusion_rgb --help --resolve # show configurations
python scripts/train.py dataset=cifar10 +experiment=train_diffusion_rgb dev_run=true train.trainer.callbacks=null train.trainer.fast_dev_run=True # run without training logic

Maintaining

Simple Tests

Please sort imports, format the code, and run the tests before pushing any changes:

isort data lid models tests
black -l 100 .
pytest tests

Hydra Tests

In addition, please ensure that the Hydra scripts are also working as expected. By default, the pytest command will only check the barebone YAML files for the Hydra scripts. Before merging major PRs, please run the following command, which serves as an integration test. It will take some time, but it will ensure that all the scripts are backwards compatible:

ENABLE_MLFLOW_LOGGING=True pytest tests/hydra

If, for example, the tests fail on specific settings, you can test them individually by setting the SCRIPT_LEVEL variable. Examples include:

SCRIPT_LEVEL=ALL ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> # to run all the scripts
SCRIPT_LEVEL=0 ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> # to run a specific script
SCRIPT_LEVEL=0,2 ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> # to run multiple scripts
SCRIPT_LEVEL=0-2 ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> # to run a range of scripts
SCRIPT_LEVEL=0,2,3-5,7-10 ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> # to run multiple ranges of scripts

This mechanism allows for more granular control over the tests. For example, when a test fails, pytest will report an error on a setting such as [True,setting{idx}]; you can then rerun with SCRIPT_LEVEL=idx to debug the error associated with that script. For a full list of the scripts used for testing, please look at the corresponding script's test directory under tests/hydra.
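
For instance, if pytest reports a failure on [True,setting4], a typical debugging loop would rerun just that setting (the -x flag is the standard pytest option to stop at the first failure):

SCRIPT_LEVEL=4 ENABLE_MLFLOW_LOGGING=True pytest tests/hydra/<hydra-test-file> -x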

Additional Notes on Hydra Tests

Note that even without setting these variables, the barebone resolved configurations are compared to ground-truth configurations stored in tests/resources/hydra_config. If you want to add a new configuration, please add it to the tests/resources/hydra_config directory and then run the tests to ensure that the configurations are correct. When the test configurations are being compared, the current version of the resolved configuration can be found under outputs/hydra_config; you may want to compare the configurations in that directory with the ones in tests/resources/hydra_config to verify that they are correct. This comparison happens automatically when you run the tests. In addition, you can qualitatively monitor the runs by opening up your mlflow server and looking at the new runs in hydra_config. These runs are tagged with the current date and time and the setting (look for the setting and timestamp tags on these runs).
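
A quick way to spot a discrepancy is to diff the freshly resolved configuration against the stored ground truth; the file name below is a placeholder for whichever configuration you are checking:

diff outputs/hydra_config/<config-name>.yaml tests/resources/hydra_config/<config-name>.yaml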

Maintaining the Website

This repository also hosts the content for the website related to these projects. The website consists of HTML files and runnable Jupyter notebooks. All of our notebooks are maintained on the website, which is built with a Quarto plugin. You may update the content of the website by changing the notebooks and .qmd files in the docs/ directory. To see the updates in real time, run the following command, which will start a local server:

# download and install Quarto from https://quarto.org/docs/get-started/
cd docs_quarto
quarto preview # opens up a local server on port 4200

To publish the website, copy the rendered output into the docs directory:

cp -r docs_quarto/_output/* docs/
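
If the _output directory has not been generated yet, a full render-and-publish cycle might look like the following (this assumes the Quarto project is configured to render into _output, as the copy command above implies):

cd docs_quarto
quarto render   # builds the static site into _output/
cd ..
cp -r docs_quarto/_output/* docs/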