commit 2c52af3 (0 parents)
Showing 22 changed files with 4,040 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
check_dirs := . | ||
|
||
quality: | ||
	black --check --preview $(check_dirs) | |
	isort --check-only $(check_dirs) | |
	flake8 $(check_dirs) | |
|
||
style: | ||
	black --preview $(check_dirs) | |
	isort $(check_dirs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# muse-open-reproduction | ||
A repo to train the best and fastest text2image model! | ||
|
||
## Goal | ||
This repo is a reproduction of the [MUSE](https://arxiv.org/abs/2301.00704) model. The goal is to create a simple and scalable repo to reproduce MUSE and build knowledge about VQ + transformers at scale. | |
We will use the deduplicated LAION-2B + COYO-700M datasets for training. | |
|
||
Project stages: | ||
1. Set up the codebase and train a class-conditional model on imagenet. | |
2. Conduct text2image experiments on CC12M. | |
3. Train improved VQGAN models. | |
4. Train the full (base-256) model on LAION + COYO. | ||
5. Train the full (base-512) model on LAION + COYO. | ||
|
||
|
||
## Steps | ||
### Set up the codebase and train a class-conditional model on imagenet. | |
- [x] Set up the repo structure. | |
- [x] Add transformer and VQGAN models. | |
- [x] Add generation support for the model. | |
- [x] Port the VQGAN from [maskgit](https://github.com/google-research/maskgit) repo for imagenet experiment. | ||
- [ ] Finish and verify masking utils. | ||
- [ ] Add the masking arccos scheduling function from MUSE. | ||
- [x] Add EMA. | ||
- [ ] Support OmegaConf for training configuration. | |
- [ ] Add W&B logging utils. | |
- [ ] Add WebDataset support. Not strictly needed for the imagenet experiment, but we can work on this in parallel (LAION is already available in this format, so it will be easier to use). | |
- [ ] Add a training script for class-conditional generation using imagenet. (WIP) | |
- [ ] Make the codebase ready for cluster training. | |
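For the masking items in the list above, a minimal sketch of the arccos schedule may help. The MUSE paper samples the masking rate r from a truncated arccos distribution with density p(r) = (2/π)(1 − r²)^(−1/2); drawing u uniformly from [0, 1) and taking r = cos(πu/2) follows that density. The function names below are hypothetical and are not taken from this repo's masking utils, and plain Python lists stand in for tensors.

```python
import math
import random
from typing import List


def sample_mask_ratio(rng: random.Random) -> float:
    """Draw a masking rate r in (0, 1] with density (2/pi) / sqrt(1 - r^2)."""
    u = rng.random()  # uniform in [0, 1)
    return math.cos(0.5 * math.pi * u)


def random_mask(num_tokens: int, rng: random.Random) -> List[bool]:
    """Return a boolean mask with a sampled fraction of positions set to True.

    Hypothetical helper: at least one token is always masked, an assumption
    made here so that a loss over masked positions is never empty.
    """
    ratio = sample_mask_ratio(rng)
    num_masked = min(num_tokens, max(1, math.ceil(ratio * num_tokens)))
    masked_positions = set(rng.sample(range(num_tokens), num_masked))
    return [i in masked_positions for i in range(num_tokens)]


if __name__ == "__main__":
    rng = random.Random(0)
    mask = random_mask(256, rng)  # 256 image tokens, chosen for illustration
    print(sum(mask), "of", len(mask), "tokens masked")
```

The expected masking rate under this schedule is 2/π (about 0.64), so most training steps mask more than half of the tokens.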
|
||
### Conduct text2image experiments on CC12M. | ||
- [ ] Finish data loading, pre-processing utils. | ||
- [ ] Add CLIP and T5 support. | ||
- [ ] Add text2image training script. | ||
- [ ] Add evaluation scripts (FID, CLIP score). | |
- [ ] Train on CC12M. Here we could conduct different experiments: | |
    - [ ] Train on CC12M with T5 conditioning. | |
    - [ ] Train on CC12M with CLIP conditioning. | |
    - [ ] Train on CC12M with CLIP + T5 conditioning (probably costly during training and experiments). | |
    - [ ] Self conditioning from Bit Diffusion paper. | |
- [ ] Collect different prompts for intermediate evaluations (can reuse the prompts from dalle-mini and parti-prompts). | |
- [ ] Set up a space where people can play with the model, provide feedback, compare with other models, etc. | |
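For the evaluation item in the list above, one common definition of the CLIP score is the CLIPScore metric, w · max(cos(image, text), 0) with w = 2.5. The sketch below assumes the image and text embeddings have already been computed by some CLIP model and are passed in as plain, non-zero vectors; `clip_score` and `cosine_similarity` are hypothetical helper names, and the evaluation scripts in this repo may define the score differently (for example as the raw cosine similarity).

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def clip_score(
    image_emb: Sequence[float], text_emb: Sequence[float], weight: float = 2.5
) -> float:
    """CLIPScore-style metric: rescaled cosine similarity, clipped at zero."""
    return weight * max(cosine_similarity(image_emb, text_emb), 0.0)


if __name__ == "__main__":
    # Toy 2-d embeddings for illustration only; real CLIP embeddings
    # have hundreds of dimensions.
    print(clip_score([1.0, 0.0], [1.0, 1.0]))
```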
|
||
### Train improved VQGAN models. | |
- [ ] Add training component models for VQGAN (EMA, discriminator, LPIPS etc). | |
- [ ] VQGAN training script. | |
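The EMA component mentioned above keeps an exponential moving average of the model weights alongside the trained weights. A minimal sketch under stated assumptions: parameters are plain floats keyed by name (a real implementation would operate on tensors), and the `EMA` class and its `decay` default are hypothetical, not this repo's API.

```python
from typing import Dict


class EMA:
    """Track an exponential moving average of named scalar parameters."""

    def __init__(self, params: Dict[str, float], decay: float = 0.999) -> None:
        self.decay = decay
        # The shadow copy starts out equal to the current parameters.
        self.shadow = dict(params)

    def update(self, params: Dict[str, float]) -> None:
        """Blend the current parameters into the shadow copy."""
        for name, value in params.items():
            self.shadow[name] = (
                self.decay * self.shadow[name] + (1.0 - self.decay) * value
            )


if __name__ == "__main__":
    ema = EMA({"w": 0.0}, decay=0.9)
    ema.update({"w": 1.0})  # call once after each optimizer step
    print(ema.shadow["w"])
```

A decay close to 1 makes the averaged weights change slowly, which is why they are typically used for evaluation rather than for the gradient updates themselves.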
|
||
|
||
### Misc tasks | ||
- [ ] Create a space for visualizing and exploring the dataset. | |
- [ ] Create a space where people can try to find their own images and can opt-out of the dataset. | ||
|
||
|
||
## Repo structure (WIP) | ||
``` | ||
├── README.md | ||
├── configs -> All training config files. | ||
│ └── dummy_config.yaml | ||
├── muse | ||
│ ├── __init__.py | ||
│ ├── data.py -> All data related utils. Can create a data folder if needed. | ||
│ ├── logging.py -> Misc logging utils. | ||
│ ├── maskgit_vqgan.py -> VQGAN model from maskgit repo. | ||
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc | ||
│ ├── sampling.py -> Sampling/Generation utils. | ||
│ ├── taming_vqgan.py -> VQGAN model from taming repo. | ||
│ ├── training_utils.py -> Common training utils. | ||
│ └── transformer.py -> The main transformer model. | ||
├── pyproject.toml | ||
├── setup.cfg | ||
├── setup.py | ||
├── test.py | ||
└── training -> All training scripts. | ||
    ├── train_muse.py | |
    └── train_vqgan.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# A dir to store training configurations |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
experiment: | ||
    name: "imagenet" | |
    project: "muse" | |
    output_dir: "imagenet" | |
    max_train_examples: 110000000 | |
    num_eval_images: 1000 | |
    save_every: 1000 | |
    log_every: 50 | |
|
||
|
||
model: | ||
    vq_model: | |
        pretrained: "path to vq model" | |
|
||
    transformer: | |
        vocab_size: 2025 # (1024 + 1000 + 1 -> Vq + Imagenet + <mask>) | |
        hidden_size: 64 | |
        num_hidden_layers: 2 | |
        num_attention_heads: 4 | |
        intermediate_size: 256 | |
        hidden_dropout: 0.1 | |
        attention_dropout: 0.1 | |
        max_position_embeddings: 256 | |
        initializer_range: 0.02 | |
        layer_norm_eps: 1e-6 | |
        use_bias: False | |
|
||
    gradient_checkpointing: True | |
|
||
|
||
dataset: | ||
    params: | |
        path: "imagenet-1k-" # path to imagenet dataset | |
        streaming: True | |
        shuffle_buffer_size: 5000 | |
        batch_size: ${training.batch_size} | |
        workers: 1 | |
        class_mapping: "scripts/metadata/imagenet_idx_to_prompt.json" | |
        resolution: 256 | |
    preprocessing: | |
        resolution: 256 | |
        center_crop: True | |
        random_flip: True | |
|
||
|
||
optimizer: | ||
    name: adamw | |
    params: | |
        learning_rate: 0.0001 | |
        beta1: 0.9 | |
        beta2: 0.98 | |
        weight_decay: 0.01 | |
        epsilon: 0.00000001 | |
|
||
|
||
lr_scheduler: | ||
    scheduler: "ConstantWithWarmup" | |
    params: | |
        learning_rate: ${optimizer.params.learning_rate} | |
        warmup_steps: 500 | |
|
||
|
||
training: | ||
    gradient_accumulation_steps: 1 | |
    batch_size: 16 | |
    mixed_precision: bf16 | |
    use_ema: False | |
    seed: 42 | |
    max_train_steps: 1000 | |
    num_epochs: 100 |
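The `${training.batch_size}` and `${optimizer.params.learning_rate}` entries in this config look like OmegaConf-style interpolations: the value is read from another key when the config is resolved. The stdlib-only sketch below mimics that behaviour on a plain dict so the idea is visible without any dependency. `resolve` and `lookup` are hypothetical helpers, not OmegaConf functions; they only handle values that consist of a single `${...}` reference and do not guard against circular references.

```python
import re
from typing import Any, Dict

# Matches a string that is exactly one "${dotted.key}" reference.
_REFERENCE = re.compile(r"^\$\{([\w.]+)\}$")


def lookup(root: Dict[str, Any], dotted_key: str) -> Any:
    """Follow a dotted key such as 'training.batch_size' from the root."""
    node: Any = root
    for part in dotted_key.split("."):
        node = node[part]
    return node


def _resolve_node(root: Dict[str, Any], node: Any) -> Any:
    if isinstance(node, dict):
        return {key: _resolve_node(root, value) for key, value in node.items()}
    if isinstance(node, str):
        match = _REFERENCE.match(node)
        if match:
            # The referenced value may itself be a reference, so resolve it too.
            return _resolve_node(root, lookup(root, match.group(1)))
    return node


def resolve(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the config with every reference replaced by its value."""
    return _resolve_node(config, config)


# A small subset of the config above, written as a dict.
config = {
    "dataset": {"params": {"batch_size": "${training.batch_size}"}},
    "optimizer": {"params": {"learning_rate": 0.0001}},
    "lr_scheduler": {
        "params": {
            "learning_rate": "${optimizer.params.learning_rate}",
            "warmup_steps": 500,
        }
    },
    "training": {"batch_size": 16},
}

if __name__ == "__main__":
    resolved = resolve(config)
    print(resolved["dataset"]["params"]["batch_size"])
    print(resolved["lr_scheduler"]["params"]["learning_rate"])
```

Defining the batch size and learning rate once and referencing them elsewhere keeps the dataset loader and the scheduler in step with the training settings when either value is changed.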
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
__version__ = "0.0.1" | ||
|
||
from .maskgit_vqgan import MaskGitVQGAN | ||
from .taming_vqgan import VQGANModel | ||
from .transformer import MaskGitTransformer |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All data related utilities and loaders are defined here.""" |