Commit

Created config
isamu-isozaki committed Jul 25, 2023
1 parent 3a3881a commit b04369a
Showing 2 changed files with 116 additions and 26 deletions.
@@ -1,31 +1,33 @@
wandb:
- entity: null
+ entity: null
mode: "offline"

experiment:
project: "muse"
name: "imagenet-movq-conv-maxvit"
output_dir: "imagenet-movq-conv-maxvit"
name: "imagenet-text2image"
output_dir: "imagenet-text2image"
max_train_examples: 1281167 # total number of imagenet examples
max_eval_examples: 12800
save_every: 1000
eval_every: 1000
generate_every: 1000
- log_every: 30
+ log_every: 50
log_grad_norm_every: 500
resume_from_checkpoint: False
resume_lr_scheduler: True
num_nodes: null
num_gpus_per_node: null

model:
vq_model:
type: "movq"
pretrained: "openMUSE/movq-lion-high-res-f8-16384"

pretrained: "openMUSE/maskgit-vqgan-imagenet-f16-256"
type: "maskgit_vqgan"
text_encoder:
type: "t5"
pretrained: "openMUSE/t5-v1_1-large-enc"
pretrained: "google/t5-v1_1-large"

transformer:
- vocab_size: 16400 # (16384 + 1 = 16385 -> VQ codes + <mask>; use 16400 for even division by 8)
+ vocab_size: 1040
max_position_embeddings: 256
hidden_size: 1024
num_hidden_layers: 24
@@ -34,8 +36,8 @@ model:
add_cross_attention: True
encoder_hidden_size: 1024
project_encoder_hidden_states: False
- codebook_size: 16384
- num_vq_tokens: 1024
+ codebook_size: 1024
+ num_vq_tokens: 256
initializer_range: 0.02
norm_type: "rmsnorm"
layer_norm_eps: 1e-6
@@ -46,25 +48,23 @@ model:
use_bias: False
hidden_dropout: 0.0
attention_dropout: 0.0
- use_codebook_size_for_output: True
- use_conv_in_out: True
- patch_size: 2
- transformer_type: "maxvit"
+ transformer_type: 'maxvit'

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True
+ offline: True


dataset:
type: "classification"
params:
train_shards_path_or_url: "pipe:aws s3 cp s3://muse-datasets/imagenet-wds/imagenet-train-{000000..000320}.tar -"
eval_shards_path_or_url: "pipe:aws s3 cp s3://muse-datasets/imagenet-wds/imagenet-val-{000000..000012}.tar -"
imagenet_class_mapping_path: "/fsx/suraj/data/imagenet-class-mapping.json"
train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar"
eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar"
imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json"
dataset.params.validation_prompts_file: null
batch_size: ${training.batch_size}
shuffle_buffer_size: 1000
- num_workers: 4
+ num_workers: 2
resolution: 256
pin_memory: True
persistent_workers: True
@@ -76,7 +76,7 @@ dataset:


optimizer:
- name: fused_adamw
+ name: adamw
params: # default adamw params
learning_rate: 1e-4
scale_lr: False # scale learning rate by total batch size
@@ -90,12 +90,12 @@ lr_scheduler:
scheduler: "constant_with_warmup"
params:
learning_rate: ${optimizer.params.learning_rate}
- warmup_steps: 1000
+ warmup_steps: 2000


training:
- gradient_accumulation_steps: 2
- batch_size: 64
+ gradient_accumulation_steps: 1
+ batch_size: 16
mixed_precision: "no"
enable_tf32: True
use_ema: False
@@ -107,9 +107,8 @@ training:
label_smoothing: 0.0
max_grad_norm: null
guidance_scale: 2.0
- generation_timesteps: 8
+ generation_timesteps: 4
# related to vae code sampling
use_soft_code_target: False
use_stochastic_code: False
- soft_code_temp: 1.0

+ soft_code_temp: 1.0
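
A quick sanity check of the new numbers, sketched in shell arithmetic (not part of the commit; the f16 downsampling factor is read off the checkpoint name, the padding rationale follows the comment carried over from the old config, and the GPU count comes from the SLURM script below):

    # maskgit-vqgan-imagenet-f16-256 downsamples 256px images by a factor of 16,
    # so the transformer sees a 16x16 grid of VQ tokens.
    resolution=256; downsample=16
    grid=$(( resolution / downsample ))
    echo "num_vq_tokens = $(( grid * grid ))"   # 256, matches transformer.num_vq_tokens

    # 1024 codebook entries + 1 <mask> token = 1025, then rounded up for even
    # division by 8; the config uses 1040 (same trick as the old 16385 -> 16400).
    codebook=1024
    echo "vocab_size >= $(( codebook + 1 ))"    # config pads this to vocab_size: 1040

    # Effective global batch size on the single 4-GPU node requested below:
    echo "global batch = $(( 16 * 1 * 4 ))"     # batch_size * grad_accum * gpus_per_node = 64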
91 changes: 91 additions & 0 deletions slurm_scripts/imagenet_text2image_jewels.slurm
@@ -0,0 +1,91 @@
#!/bin/bash
#SBATCH --job-name=t2i_testing
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1  # crucial - only 1 task per node; torch.distributed.run spawns the per-GPU processes
#SBATCH --cpus-per-task=48
#SBATCH --gres=gpu:4
#SBATCH --exclusive
#SBATCH -A cstdl
#SBATCH --partition booster
#SBATCH --output=/p/home/jusers/isozaki1/juwels/%x-%j.out
#SBATCH --time=0:10:00

set -x -e

source /p/home/jusers/isozaki1/juwels/miniconda3/etc/profile.d/conda.sh
conda activate muse

echo "START TIME: $(date)"

MUSE_REPO=/p/home/jusers/isozaki1/juwels/open-muse
OUTPUT_DIR=/p/home/jusers/isozaki1/juwels/muse
LOG_PATH=$OUTPUT_DIR/main_log.txt

mkdir -p $OUTPUT_DIR
touch $LOG_PATH
pushd $MUSE_REPO

GPUS_PER_NODE=4
NNODES=$SLURM_NNODES

CMD=" \
training/train_muse.py config=configs/imagenet_text2image_jewels.yaml \
wandb.entity=isamu \
experiment.name=$(basename $OUTPUT_DIR) \
experiment.output_dir=$OUTPUT_DIR \
training.seed=9345104 \
experiment.num_nodes=$SLURM_NNODES
"

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

echo $CMD

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
# export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# # AWS specific
# export NCCL_PROTO=simple
# export RDMAV_FORK_SAFE=1
# export FI_EFA_FORK_SAFE=1
# export FI_EFA_USE_DEVICE_RDMA=1
# export FI_PROVIDER=efa
# export FI_LOG_LEVEL=1
# export NCCL_IB_DISABLE=1
# # # export NCCL_SOCKET_IFNAME=ens
# export PYTHONWARNINGS="ignore"
# export CXX=g++


# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"

# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH

echo "END TIME: $(date)"
