Setting up tests
isamu-isozaki committed Aug 6, 2023
1 parent b04369a commit 8024ec2
Showing 4 changed files with 121 additions and 1 deletion.
5 changes: 5 additions & 0 deletions README.md
@@ -15,6 +15,11 @@ Project stages:
All the artifacts of this project will be uploaded to the [openMUSE](https://huggingface.co/openMUSE) organization on the huggingface hub.

## Usage
```python
from omegaconf import OmegaConf

from muse.modeling_transformer import MaskGitTransformer

conf = OmegaConf.load("configs/imagenet_text2image_max_vit_jewels.yaml")
model = MaskGitTransformer(**conf.model.transformer)
model = model.to("cuda")
```
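
The config added by this commit can be consumed the same way. A minimal sketch (not part of the diff; it assumes `conf.model.transformer` maps one-to-one onto the `MaskGitTransformer` constructor arguments, as the snippet above implies):

```python
from omegaconf import OmegaConf

from muse.modeling_transformer import MaskGitTransformer

# Load the ImageNet text2image config introduced by this commit
conf = OmegaConf.load("configs/imagenet_text2image_jewels.yaml")

# conf.model.transformer carries the transformer hyperparameters listed in the YAML
model = MaskGitTransformer(**conf.model.transformer)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```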

### Installation

113 changes: 113 additions & 0 deletions configs/imagenet_text2image_jewels.yaml
@@ -0,0 +1,113 @@
wandb:
  entity: null
  mode: "offline"

experiment:
  project: "muse"
  name: "imagenet-text2image"
  output_dir: "imagenet-text2image"
  max_train_examples: 1281167 # total number of imagenet examples
  max_eval_examples: 12800
  save_every: 1000
  eval_every: 1000
  generate_every: 1000
  log_every: 50
  log_grad_norm_every: 500
  resume_from_checkpoint: False
  resume_lr_scheduler: True
  num_nodes: null
  num_gpus_per_node: null

model:
  vq_model:
    pretrained: "openMUSE/maskgit-vqgan-imagenet-f16-256"
    type: "maskgit_vqgan"
  text_encoder:
    type: "t5"
    pretrained: "google/t5-v1_1-large"

  transformer:
    vocab_size: 1040
    max_position_embeddings: 256
    hidden_size: 1024
    num_hidden_layers: 24
    num_attention_heads: 16
    intermediate_size: 4096
    add_cross_attention: True
    encoder_hidden_size: 1024
    project_encoder_hidden_states: False
    codebook_size: 1024
    num_vq_tokens: 256
    initializer_range: 0.02
    norm_type: "rmsnorm"
    layer_norm_eps: 1e-6
    use_normformer: False
    use_encoder_layernorm: True
    use_mlm_layer: True
    use_mlm_layernorm: True
    use_bias: False
    hidden_dropout: 0.0
    attention_dropout: 0.0

  gradient_checkpointing: True
  enable_xformers_memory_efficient_attention: True
  offline: True


dataset:
  type: "classification"
  params:
    train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar"
    eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar"
    imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json"
    validation_prompts_file: null
    batch_size: ${training.batch_size}
    shuffle_buffer_size: 1000
    num_workers: 2
    resolution: 256
    pin_memory: True
    persistent_workers: True
  preprocessing:
    max_seq_length: 16
    resolution: 256
    center_crop: True
    random_flip: False


optimizer:
  name: adamw
  params: # default adamw params
    learning_rate: 1e-4
    scale_lr: False # scale learning rate by total batch size
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.01
    epsilon: 1e-8


lr_scheduler:
  scheduler: "constant_with_warmup"
  params:
    learning_rate: ${optimizer.params.learning_rate}
    warmup_steps: 2000


training:
  gradient_accumulation_steps: 1
  batch_size: 16
  mixed_precision: "no"
  enable_tf32: True
  use_ema: False
  seed: 9345104
  max_train_steps: 200000
  overfit_one_batch: False
  cond_dropout_prob: 0.1
  min_masking_rate: 0.0
  label_smoothing: 0.0
  max_grad_norm: null
  guidance_scale: 2.0
  generation_timesteps: 4
  # related to vae code sampling
  use_soft_code_target: False
  use_stochastic_code: False
  soft_code_temp: 1.0
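
The `${...}` entries in this file are OmegaConf interpolations; they resolve on access rather than being copied at load time. A small sketch of that behaviour (it assumes only that the file is loaded with OmegaConf, as in the README snippet above):

```python
from omegaconf import OmegaConf

conf = OmegaConf.load("configs/imagenet_text2image_jewels.yaml")

# dataset.params.batch_size points at training.batch_size
print(conf.dataset.params.batch_size)  # -> 16

# lr_scheduler.params.learning_rate mirrors optimizer.params.learning_rate
assert conf.lr_scheduler.params.learning_rate == conf.optimizer.params.learning_rate
```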
1 change: 1 addition & 0 deletions muse/modeling_transformer.py
@@ -2013,6 +2013,7 @@ def __init__(
    shrinkage_rate = 0.25,
    dropout = 0.,
):
    super().__init__()
    # As argued in the paper, one function of this MBConv layer is to provide a conditional
    # position encoding, in particular via the depthwise convolution, so that explicit
    # positional embeddings are not needed.
    hidden_dim = int(expansion_rate * dim_out)
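
For context on that comment, here is a minimal sketch of an MBConv-style block (not the repository's implementation; the squeeze-excitation stage of the real MaxViT block is omitted) showing how the depthwise 3x3 convolution mixes each feature only with its spatial neighbours, which is why it can stand in for an explicit positional embedding:

```python
import torch
import torch.nn as nn


class MBConvSketch(nn.Module):
    """Illustrative MBConv block: pointwise expand -> depthwise conv -> pointwise project."""

    def __init__(self, dim_in, dim_out, expansion_rate=4, dropout=0.0):
        super().__init__()
        hidden_dim = int(expansion_rate * dim_out)
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, kernel_size=1),   # pointwise expansion
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1,
                      groups=hidden_dim),                    # depthwise: local spatial mixing
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim_out, kernel_size=1),  # pointwise projection
            nn.Dropout(dropout),
        )
        self.skip = nn.Identity() if dim_in == dim_out else nn.Conv2d(dim_in, dim_out, 1)

    def forward(self, x):  # x: (batch, channels, height, width)
        return self.net(x) + self.skip(x)


x = torch.randn(2, 64, 16, 16)
print(MBConvSketch(64, 64)(x).shape)  # torch.Size([2, 64, 16, 16])
```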
3 changes: 2 additions & 1 deletion training/train_muse.py
@@ -648,6 +648,8 @@ def prepare_inputs_and_labels(

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
    accelerator.print(torch.cuda.max_memory_allocated() / (1024 ** 3), "GiB allocated")

    if config.training.get("use_ema", False):
        ema.step(model.parameters())

@@ -676,7 +678,6 @@ def prepare_inputs_and_labels(
f"Batch (t): {batch_time_m.val:0.4f} "
f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
)

# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
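
The new `max_memory_allocated` print reports the peak allocation since process start, so the number only grows. If a per-log-window peak were wanted instead, the counter can be reset after each report; a small stand-alone sketch (the helper name and the reset behaviour are not part of this commit):

```python
import torch


def log_peak_memory(print_fn=print, reset=True):
    """Print the peak CUDA allocation in GiB; optionally reset the counter
    so the next report covers only the upcoming log window."""
    if not torch.cuda.is_available():
        return
    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print_fn(f"{peak_gib:.2f} GiB peak allocated")
    if reset:
        torch.cuda.reset_peak_memory_stats()


# Inside the training loop this would be called with accelerator.print:
# log_peak_memory(print_fn=accelerator.print)
log_peak_memory()
```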