diff --git a/configs/imagenet_text2image_jewels.yaml b/configs/imagenet_text2image_jewels.yaml new file mode 100644 index 00000000..03ae1d35 --- /dev/null +++ b/configs/imagenet_text2image_jewels.yaml @@ -0,0 +1,112 @@ +wandb: + entity: null + mode: "offline" + +experiment: + project: "muse" + name: "imagenet-text2image" + output_dir: "imagenet-text2image" + max_train_examples: 1281167 # total number of imagenet examples + max_eval_examples: 12800 + save_every: 1000 + eval_every: 1000 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 500 + resume_from_checkpoint: False + resume_lr_scheduler: True + num_nodes: null + num_gpus_per_node: null + +model: + vq_model: + pretrained: "openMUSE/maskgit-vqgan-imagenet-f16-256" + type: "maskgit_vqgan" + text_encoder: + type: "t5" + pretrained: "google/t5-v1_1-large" + + transformer: + vocab_size: 1040 + max_position_embeddings: 256 + hidden_size: 1024 + num_hidden_layers: 24 + num_attention_heads: 16 + intermediate_size: 4096 + add_cross_attention: True + encoder_hidden_size: 1024 + project_encoder_hidden_states: False + codebook_size: 1024 + num_vq_tokens: 256 + initializer_range: 0.02 + norm_type: "rmsnorm" + layer_norm_eps: 1e-6 + use_normformer: False + use_encoder_layernorm: True + use_mlm_layer: True + use_mlm_layernorm: True + use_bias: False + hidden_dropout: 0.0 + attention_dropout: 0.0 + gradient_checkpointing: True + enable_xformers_memory_efficient_attention: False + offline: True + + +dataset: + type: "classification" + params: + train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar" + eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar" + imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json" + dataset.params.validation_prompts_file: null + batch_size: ${training.batch_size} + shuffle_buffer_size: 1000 + num_workers: 2 + resolution: 256 + pin_memory: True + persistent_workers: True + preprocessing: + max_seq_length: 16 + resolution: 256 + center_crop: True + random_flip: False + + +optimizer: + name: lion + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + + +lr_scheduler: + scheduler: "constant_with_warmup" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 2000 + + +training: + gradient_accumulation_steps: 1 + batch_size: 1 + mixed_precision: "no" + enable_tf32: True + use_ema: False + seed: 9345104 + max_train_steps: 200000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: null + guidance_scale: 2.0 + generation_timesteps: 4 + # related to vae code sampling + use_soft_code_target: False + use_stochastic_code: False + soft_code_temp: 1.0 diff --git a/configs/imagenet_text2image_max_vit_jewels.yaml b/configs/imagenet_text2image_max_vit_jewels.yaml new file mode 100644 index 00000000..2049d2ca --- /dev/null +++ b/configs/imagenet_text2image_max_vit_jewels.yaml @@ -0,0 +1,115 @@ +wandb: + entity: null + mode: "offline" + +experiment: + project: "muse" + name: "imagenet-text2image" + output_dir: "imagenet-text2image" + max_train_examples: 1281167 # total number of imagenet examples + max_eval_examples: 12800 + save_every: 1000 + eval_every: 1000 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 500 + resume_from_checkpoint: False + resume_lr_scheduler: True + num_nodes: null + 
num_gpus_per_node: null + +model: + vq_model: + pretrained: "openMUSE/maskgit-vqgan-imagenet-f16-256" + type: "maskgit_vqgan" + text_encoder: + type: "t5" + pretrained: "google/t5-v1_1-large" + + transformer: + vocab_size: 1040 + max_position_embeddings: 256 + hidden_size: 1024 + num_hidden_layers: 24 + num_attention_heads: 16 + intermediate_size: 4096 + add_cross_attention: True + encoder_hidden_size: 1024 + project_encoder_hidden_states: False + codebook_size: 1024 + num_vq_tokens: 256 + initializer_range: 0.02 + norm_type: "rmsnorm" + layer_norm_eps: 1e-6 + use_normformer: False + use_encoder_layernorm: True + use_mlm_layer: True + use_mlm_layernorm: True + use_bias: False + hidden_dropout: 0.0 + attention_dropout: 0.0 + transformer_type: 'maxvit' + window_size: 8 + + gradient_checkpointing: True + enable_xformers_memory_efficient_attention: False + offline: True + + +dataset: + type: "classification" + params: + train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar" + eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar" + imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json" + dataset.params.validation_prompts_file: null + batch_size: ${training.batch_size} + shuffle_buffer_size: 1000 + num_workers: 2 + resolution: 256 + pin_memory: True + persistent_workers: True + preprocessing: + max_seq_length: 16 + resolution: 256 + center_crop: True + random_flip: False + + +optimizer: + name: lion + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + + +lr_scheduler: + scheduler: "constant_with_warmup" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 2000 + + +training: + gradient_accumulation_steps: 1 + batch_size: 1 + mixed_precision: "no" + enable_tf32: True + use_ema: False + seed: 9345104 + max_train_steps: 200000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: null + guidance_scale: 2.0 + generation_timesteps: 4 + # related to vae code sampling + use_soft_code_target: False + use_stochastic_code: False + soft_code_temp: 1.0 \ No newline at end of file diff --git a/muse/modeling_transformer.py b/muse/modeling_transformer.py index 4a09cdb6..e1d59ea6 100644 --- a/muse/modeling_transformer.py +++ b/muse/modeling_transformer.py @@ -22,7 +22,7 @@ import numpy as np import torch import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, reduce from torch import nn from torch.utils.checkpoint import checkpoint from tqdm import tqdm @@ -184,7 +184,7 @@ def set_use_memory_efficient_attention_xformers( self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self.xformers_attention_op = attention_op - def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None): + def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, bias=None): if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers: raise ValueError("Memory efficient attention does not yet support encoder attention mask") @@ -202,7 +202,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m if self.use_memory_efficient_attention_xformers: attn_output = xops.memory_efficient_attention( - query, key, value, 
op=self.xformers_attention_op, p=self.attention_dropout if self.training else 0.0 + query, + key, + value, + op=self.xformers_attention_op, + p=self.attention_dropout if self.training else 0.0, + attn_bias=bias, ) attn_output = attn_output.view(batch, q_seq_len, self.hidden_size) else: @@ -210,12 +215,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m if encoder_attention_mask is not None: src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device) attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype) - attn_output = self.attention(query, key, value, attention_mask) + attn_output = self.attention(query, key, value, attention_mask, bias) attn_output = self.out(attn_output) return attn_output - def attention(self, query, key, value, attention_mask=None): + def attention(self, query, key, value, attention_mask=None, bias=None): batch, seq_len = query.shape[:2] kv_seq_len = key.shape[1] query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value)) # (B, nh, T, hs) @@ -227,6 +232,8 @@ def attention(self, query, key, value, attention_mask=None): alpha=1 / self.scale_attn, ) attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len) # -1 is kv_seq_len + if bias is not None: + attn_weights += bias # Apply the attention mask if attention_mask is not None: attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min) @@ -815,6 +822,7 @@ def __init__( cond_embed_dim=None, ffn_type="glu", use_bias=False, + **kwargs, ): super().__init__() @@ -901,6 +909,126 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m return hidden_states +class MaxVitTransformerLayer(TransformerLayer): + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + hidden_dropout=0.0, + attention_dropout=0.0, + norm_type="layernorm", + use_bias=False, + window_size=8, + mbconv_expansion_rate=4, + mbconv_shrinkage_rate=0.25, + embedding_size=256, + **kwargs, + ): + super().__init__( + hidden_size, + intermediate_size, + num_attention_heads, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + norm_type=norm_type, + use_bias=use_bias, + **kwargs, + ) + norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm + self.mb_conv = MBConv( + embedding_size, + embedding_size, + expansion_rate=mbconv_expansion_rate, + shrinkage_rate=mbconv_shrinkage_rate, + dropout=hidden_dropout, + ) + self.window_size = window_size + self.norm0 = norm_cls(hidden_size) + self.attn0 = MaxVitAttention( + hidden_size=hidden_size, + num_heads=num_attention_heads, + attention_dropout=attention_dropout, + window_size=window_size, + ) + self.norm1 = norm_cls(hidden_size) + # In lucidrian's code the implementation of feedforward is different + self.ff0 = FeedForward(hidden_size=hidden_size, intermediate_size=hidden_size, hidden_dropout=hidden_dropout) + self.norm2 = norm_cls(hidden_size) + self.attn1 = MaxVitAttention( + hidden_size=hidden_size, + num_heads=num_attention_heads, + attention_dropout=attention_dropout, + window_size=window_size, + ) + self.norm3 = norm_cls(hidden_size) + self.ff1 = FeedForward(hidden_size=hidden_size, intermediate_size=hidden_size, hidden_dropout=hidden_dropout) + + def attention(self, hidden_states): + # If you examine the rearranges before the first attention, we get self.window_size intervals to make a window_sizexwindow_size size grid which gives + # 
our local attention once positional embeddings are added to it + # However for the second one, we see that we pick one element, then take x // window_size steps then pick the next one + # This helps us make a "global" grid of window_size x window_size + hidden_states = self.mb_conv(hidden_states) + # block like attention(local attention) + hidden_states = rearrange( + hidden_states, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=self.window_size, w2=self.window_size + ) + hidden_states = self.norm0(hidden_states) + hidden_states = self.attn0(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.ff0(hidden_states) + hidden_states = rearrange(hidden_states, "b x y w1 w2 d -> b d (x w1) (y w2)") + # grid-like attention(global attention) + hidden_states = rearrange( + hidden_states, "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=self.window_size, w2=self.window_size + ) + hidden_states = self.norm2(hidden_states) + hidden_states = self.attn1(hidden_states) + hidden_states = self.norm3(hidden_states) + hidden_states = self.ff1(hidden_states) + hidden_states = rearrange(hidden_states, "b x y w1 w2 d -> b d (w1 x) (w2 y)") + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, cond_embeds=None): + residual = hidden_states + + hidden_states = self.attn_layer_norm(hidden_states) + if cond_embeds is not None: + hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds) + hidden_states = hidden_states.permute(0, 2, 1) + b, c, seq_length = hidden_states.shape + h, w = int(seq_length**0.5), int(seq_length**0.5) + hidden_states = hidden_states.view(b, c, h, w) + attention_output = self.attention(hidden_states) + attention_output = attention_output.view(b, c, seq_length) + attention_output = attention_output.permute(0, 2, 1) + if self.use_normformer: + attention_output = self.post_attn_layer_norm(attention_output) + + hidden_states = residual + attention_output + + if encoder_hidden_states is not None: + residual = hidden_states + # TODO: should norm be applied to encoder_hidden_states as well? 
+ hidden_states = self.crossattn_layer_norm(hidden_states) + if cond_embeds is not None: + hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds) + attention_output = self.crossattention( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + if self.use_normformer: + attention_output = self.post_crossattn_layer_norm(attention_output) + hidden_states = residual + attention_output + + residual = hidden_states + hidden_states = self.ffn(hidden_states, cond_embeds=cond_embeds) + hidden_states = residual + hidden_states + return hidden_states + + class Embed(nn.Module): def __init__( self, @@ -912,7 +1040,7 @@ def __init__( norm_type="layernorm", layer_norm_eps=1e-5, use_bias=False, - layer_norm_embedddings=False, + layer_norm_embeddings=False, use_embeddings_project=False, ): super().__init__() @@ -922,14 +1050,14 @@ def __init__( self.hidden_size = hidden_size self.hidden_dropout = hidden_dropout self.max_position_embeddings = max_position_embeddings - self.layer_norm_embedddings = layer_norm_embedddings + self.layer_norm_embeddings = layer_norm_embeddings self.use_embeddings_project = use_embeddings_project self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_size) self.position_embeddings = nn.Embedding(self.max_position_embeddings, self.embedding_size) self.dropout = nn.Dropout(self.hidden_dropout) - if layer_norm_embedddings: + if layer_norm_embeddings: norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm self.embeddings_ln = norm_cls(self.embedding_size, eps=layer_norm_eps) @@ -944,7 +1072,7 @@ def forward(self, input_ids): position_embeddings = self.position_embeddings(position_ids) input_embeddings = word_embeddings + position_embeddings - if self.layer_norm_embedddings: + if self.layer_norm_embeddings: input_embeddings = self.embeddings_ln(input_embeddings) if self.use_embeddings_project: @@ -992,7 +1120,7 @@ def __init__( max_position_embeddings=256, norm_type="layernorm", ln_elementwise_affine=True, - layer_norm_embedddings=False, + layer_norm_embeddings=False, layer_norm_eps=1e-5, use_position_embeddings=True, use_bias=False, @@ -1002,7 +1130,7 @@ def __init__( self.patch_size = patch_size self.max_position_embeddings = max_position_embeddings self.use_position_embeddings = use_position_embeddings - self.layer_norm_embedddings = layer_norm_embedddings + self.layer_norm_embeddings = layer_norm_embeddings self.embeddings = nn.Embedding(vocab_size, embedding_size) norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm @@ -1012,7 +1140,7 @@ def __init__( self.conv = nn.Conv2d(embedding_size * (patch_size**2), hidden_size, kernel_size=1, bias=use_bias) if use_position_embeddings: self.position_embeddings = nn.Embedding(self.max_position_embeddings, hidden_size) - if self.layer_norm_embedddings: + if self.layer_norm_embeddings: self.embeddings_ln = Norm2D( hidden_size, eps=layer_norm_eps, norm_type=norm_type, elementwise_affine=ln_elementwise_affine ) @@ -1032,7 +1160,7 @@ def forward(self, input_ids): position_ids = torch.arange(embeddings.shape[1])[None, :].to(input_ids.device) position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings - if self.layer_norm_embedddings: + if self.layer_norm_embeddings: embeddings = self.embeddings_ln(embeddings) return embeddings @@ -1106,9 +1234,12 @@ def __init__( codebook_size=1024, num_vq_tokens=256, num_classes=None, # set for 
class-conditioned generation + use_position_embeddings=False, use_codebook_size_for_output=False, use_conv_in_out=False, patch_size=1, + transformer_type="default", + window_size=4, **kwargs, ): super().__init__() @@ -1125,16 +1256,18 @@ def __init__( self.register_to_config(mask_token_id=vocab_size - 1) norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm + transformer_cls = TransformerLayer if transformer_type == "default" else MaxVitTransformerLayer if use_conv_in_out: self.embed = ConvEmbed( vocab_size, - embedding_size, + self.embedding_size, hidden_size, patch_size=patch_size, norm_type=norm_type, layer_norm_eps=layer_norm_eps, use_bias=use_bias, + use_position_embeddings=use_position_embeddings, ) else: self.embed = Embed( @@ -1155,7 +1288,7 @@ def __init__( self.transformer_layers = nn.ModuleList( [ - TransformerLayer( + transformer_cls( hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, num_attention_heads=self.num_attention_heads, @@ -1167,6 +1300,8 @@ def __init__( layer_norm_eps=layer_norm_eps, use_normformer=use_normformer, use_bias=use_bias, + embedding_size=self.embedding_size, + window_size=window_size, ) for _ in range(self.num_hidden_layers) ] @@ -1179,7 +1314,7 @@ def __init__( if use_conv_in_out: self.mlm_layer = ConvMlmLayer( self.output_size, - embedding_size, + self.embedding_size, hidden_size, patch_size=patch_size, norm_type=norm_type, @@ -1230,13 +1365,11 @@ def forward( ): if self.config.add_cross_attention and encoder_hidden_states is None: raise ValueError("If `add_cross_attention` is True, `encoder_hidden_states` should be provided.") - hidden_states = self.embed(input_ids) if encoder_hidden_states is not None and self.config.project_encoder_hidden_states: encoder_hidden_states = self.encoder_proj(encoder_hidden_states) encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) - # condition dropout for classifier free guidance if encoder_hidden_states is not None and self.training and cond_dropout_prob > 0.0: batch_size = encoder_hidden_states.shape[0] @@ -1485,11 +1618,10 @@ def __init__( codebook_size=1024, num_vq_tokens=256, num_classes=None, # set for class-conditioned generation - use_position_embeddings=False, use_codebook_size_for_output=False, patch_size=1, layer_norm_before_mlm=False, - layer_norm_embedddings=False, + layer_norm_embeddings=False, add_cond_embeds=False, cond_embed_dim=None, add_micro_cond_embeds=False, @@ -1499,6 +1631,7 @@ def __init__( use_empty_embeds_for_uncond=False, learn_uncond_embeds=False, use_vannilla_resblock=False, + transformer_type="default", ffn_type="glu", res_ffn_factor=4, **kwargs, @@ -1518,6 +1651,7 @@ def __init__( self.register_to_config(block_out_channels=tuple(block_out_channels)) norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm + transformer_cls = TransformerLayer if transformer_type == "default" else MaxVitTransformerLayer if block_has_attention is None: block_has_attention = [False] * len(block_out_channels) @@ -1553,10 +1687,9 @@ def __init__( block_out_channels[0], patch_size=patch_size, norm_type=norm_type, - layer_norm_embedddings=layer_norm_embedddings, + layer_norm_embeddings=layer_norm_embeddings, layer_norm_eps=layer_norm_eps, ln_elementwise_affine=ln_elementwise_affine, - use_position_embeddings=use_position_embeddings, use_bias=use_bias, ) @@ -1624,7 +1757,7 @@ def __init__( # Mid Transformer self.transformer_layers = nn.ModuleList( [ - TransformerLayer( + transformer_cls( 
hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, num_attention_heads=self.num_attention_heads, @@ -2225,3 +2358,155 @@ def generate2( if return_intermediate: return sampled_ids, intermediate return sampled_ids + + +# Taken and slightly adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/max_vit.py +# Originally proposed https://arxiv.org/abs/1709.01507 +# The main idea is without changing the size of the input, choose to prioritize some channels over others +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + hidden = reduce(x, "b c h w -> b c", "mean") + hidden = self.gate(hidden) + hidden = rearrange(hidden, "b c -> b c 1 1") + return x * hidden + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + + +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() > self.prob + return x * keep_mask / (1 - self.prob) + + +class MBConv(nn.Module): + def __init__( + self, + dim_in, + dim_out, + downsample=False, + expansion_rate=4, + shrinkage_rate=0.25, + dropout=0.0, + ): + super().__init__() + # One function of this mbconv layer argued in the paper is to provide conditional position encoding especially with the depthwise convolution + # so that we do not need explicit positional embeddings + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + self.net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim), + nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + nn.BatchNorm2d(dim_out), + ) + + if dim_in == dim_out and not downsample: + self.net = MBConvResidual(self.net, dropout=dropout) + + def forward(self, x): + return self.net(x) + + +class MaxVitAttention(Attention): + def __init__( + self, hidden_size, num_heads, window_size=8, encoder_hidden_size=None, attention_dropout=0.0, use_bias=False + ): + super().__init__( + hidden_size, + num_heads, + encoder_hidden_size=encoder_hidden_size, + attention_dropout=attention_dropout, + use_bias=use_bias, + ) + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.num_heads) + + # TODO: Maybe make this more comprehensible. This is basically positional embeddings for our grid + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos, indexing="ij")) + grid = rearrange(grid, "c i j -> (i j) c") + """ + grid is + tensor([[ 0, 0], + [ 0, 1], + [ 0, 2], + ..., + [window_size-1, window_size-1]]) + with shape [window_size**2, 2] + This is essentially 2d coordinates for window_size x window_size grid + """ + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(grid, "j ... 
-> 1 j ...") + rel_pos += window_size - 1 + """ + rel_pos has shape [window_size**2, window_wize**2, 2] + here rel_pos[i] = tensor([[24+(i // window_size), 24+(i % window_size)], + [24+(i // window_size), 23+(i % window_size)], + [24+(i // window_size), 22+(i % window_size)], + ..., + [ (i // window_size), 2+(i % window_size)], + [ (i // window_size), 1+(i % window_size)], + [ (i // window_size), (i % window_size)]]) + """ + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1) + """ + rel_pos_indices has shape (625, 625) + rel_pos_indices[i] = [i, i+1, i+2...i+window_size-1, i+2*window_size-1, i+2*window_size....] + """ + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None): + batch, height, width, window_height, window_width, _ = hidden_states.shape + # flatten + # Here, w1 and w2 are both window size so x will have size (b x y), window_size**2, d + hidden_states = rearrange(hidden_states, "b x y w1 w2 d -> (b x y) (w1 w2) d") + bias = self.rel_pos_bias(self.rel_pos_indices) + # shape is [window_size**2, window_size**2, self.num_heads] + bias = rearrange(bias, "i j h -> h i j") + # shape is [self.num_heads, window_size**2, window_size**2] + # the bias adds positional embeddings for each window size segment + out = super().forward( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + bias=bias, + ) + out = rearrange(out, "b (w1 w2) d -> b w1 w2 d", w1=window_height, w2=window_width) + + # combine heads out + return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width) diff --git a/slurm_scripts/imagenet_text2image_jewels.slurm b/slurm_scripts/imagenet_text2image_jewels.slurm new file mode 100644 index 00000000..7ef7eeca --- /dev/null +++ b/slurm_scripts/imagenet_text2image_jewels.slurm @@ -0,0 +1,91 @@ +#!/bin/bash +#SBATCH --job-name=t2i_testing +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
+#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:4 +#SBATCH --exclusive +#SBATCH -A cstdl +#SBATCH --partition booster +#SBATCH --output=/p/home/jusers/isozaki1/juwels/%x-%j.out +#SBATCH --time=0:10:00 + +set -x -e + +source /p/home/jusers/isozaki1/juwels/miniconda3/etc/profile.d/conda.sh +conda activate muse + +echo "START TIME: $(date)" + +MUSE_REPO=/p/home/jusers/isozaki1/juwels/open-muse +OUTPUT_DIR=/p/home/jusers/isozaki1/juwels/muse +LOG_PATH=$OUTPUT_DIR/main_log.txt + +mkdir -p $OUTPUT_DIR +touch $LOG_PATH +pushd $MUSE_REPO + +GPUS_PER_NODE=4 +NNODES=$SLURM_NNODES + +CMD=" \ + training/train_muse.py config=configs/imagenet_text2image_jewels.yaml \ + wandb.entity=isamu \ + experiment.name=$(basename $OUTPUT_DIR) \ + experiment.output_dir=$OUTPUT_DIR \ + training.seed=9345104 \ + experiment.num_nodes=$SLURM_NNODES + + " + +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +# export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# # AWS specific +# export NCCL_PROTO=simple +# export RDMAV_FORK_SAFE=1 +# export FI_EFA_FORK_SAFE=1 +# export FI_EFA_USE_DEVICE_RDMA=1 +# export FI_PROVIDER=efa +# export FI_LOG_LEVEL=1 +# export NCCL_IB_DISABLE=1 +# # # export NCCL_SOCKET_IFNAME=ens +# export PYTHONWARNINGS="ignore" +# export CXX=g++ + + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/training/data.py b/training/data.py index b56db6fb..cc82446f 100644 --- a/training/data.py +++ b/training/data.py @@ -179,7 +179,6 @@ def tokenize(imagenet_class_id): input_ids=tokenize, text_raw=lambda class_idx: self.class_mapping[str(class_idx)], ), - wds.to_tuple("image", "input_ids"), ] else: processing_pipeline = [
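
A note on the two new configs: several values are written once and referenced elsewhere through ${...} interpolation (dataset.params.batch_size points at training.batch_size, lr_scheduler.params.learning_rate points at optimizer.params.learning_rate). Below is a minimal sketch of how such references resolve, assuming the training script loads the file with OmegaConf and merges dotlist overrides like the ones passed in the slurm script; the YAML is abbreviated to the interpolated keys, it is not the full config.

# Minimal sketch of the OmegaConf-style interpolation used in the new YAML configs.
# Assumption: train_muse.py loads configs with omegaconf; only the interpolated keys are shown.
from omegaconf import OmegaConf

yaml_snippet = """
training:
  batch_size: 1
optimizer:
  params:
    learning_rate: 1e-4
dataset:
  params:
    batch_size: ${training.batch_size}
lr_scheduler:
  params:
    learning_rate: ${optimizer.params.learning_rate}
"""

cfg = OmegaConf.create(yaml_snippet)
# Interpolations resolve on access, so both values track their single definition.
assert cfg.dataset.params.batch_size == 1
assert cfg.lr_scheduler.params.learning_rate == cfg.optimizer.params.learning_rate

# Dotlist overrides (as passed on the command line in the slurm script) merge on top.
overrides = OmegaConf.from_dotlist(["training.seed=9345104", "experiment.num_nodes=1"])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg))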
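
The Attention changes above thread an optional additive bias through both code paths: the xformers call receives it as attn_bias, and the manual path adds it to the raw attention scores before masking and softmax. Here is a minimal stand-alone sketch of that manual path in plain PyTorch with hypothetical shapes; it is not the repository's Attention class.

# Sketch of additive attention bias: scores of shape (B, H, T, S) plus a bias that
# broadcasts over the batch dimension, as the relative-position bias from MaxVitAttention does.
import torch

def biased_attention(query, key, value, bias=None, attention_mask=None):
    # query/key/value: (B, H, T, head_dim); bias: broadcastable to (B, H, T, S)
    scale = query.shape[-1] ** 0.5
    scores = torch.matmul(query, key.transpose(-1, -2)) / scale  # (B, H, T, S)
    if bias is not None:
        scores = scores + bias  # positional bias added before masking and softmax
    if attention_mask is not None:
        # boolean mask, True = masked out, mirroring the masked_fill in Attention.attention
        scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
    weights = scores.softmax(dim=-1)
    return torch.matmul(weights, value)  # (B, H, T, head_dim)

B, H, T, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
rel_bias = torch.randn(H, T, T)  # one (T, T) bias table per head
out = biased_attention(q, k, v, bias=rel_bias)
print(out.shape)  # torch.Size([2, 4, 16, 32])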
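
The comment in MaxVitTransformerLayer.attention describes the two window partitions: block attention groups contiguous window_size x window_size tiles, while grid attention groups pixels sampled at a fixed stride across the whole map. A toy check of the two rearranges and their inverses, using a 16x16 map and window_size 8 to match num_vq_tokens=256 and the max_vit config:

# Sketch of the block vs. grid partitioning used in MaxVitTransformerLayer.attention.
# Toy sizes: feature map 16x16, window_size 8 -> a 2x2 arrangement of 8x8 windows.
import torch
from einops import rearrange

w = 8
x = torch.arange(16 * 16, dtype=torch.float32).reshape(1, 1, 16, 16)  # (b, d, h, w)

# Block (local) attention: contiguous w x w tiles become windows.
block = rearrange(x, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w)
print(block.shape)                # (1, 2, 2, 8, 8, 1)
print(block[0, 0, 0, :2, :2, 0])  # top-left tile: 0, 1 / 16, 17 (neighbouring pixels)

# Grid (global) attention: each window gathers pixels strided across the whole map.
grid = rearrange(x, "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w)
print(grid[0, 0, 0, :2, :2, 0])   # strided values 0, 2 / 32, 34 (spread over the map)

# Both partitions invert exactly, so the layer can chain them back to back.
assert torch.equal(rearrange(block, "b x y w1 w2 d -> b d (x w1) (y w2)"), x)
assert torch.equal(rearrange(grid, "b x y w1 w2 d -> b d (w1 x) (w2 y)"), x)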
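
The comment above SqueezeExcitation says the block reweights channels without changing the input's size. A toy check of that behaviour, assuming the class is importable from muse.modeling_transformer once this diff is applied; shapes and numbers are illustrative only.

# Toy check of squeeze-and-excitation channel gating: pool to (B, C), pass through a
# bottleneck MLP with a sigmoid, then rescale the input channels in place.
import torch
from muse.modeling_transformer import SqueezeExcitation  # added by this diff

se = SqueezeExcitation(dim=64, shrinkage_rate=0.25)
x = torch.randn(2, 64, 16, 16)  # (B, C, H, W) feature map
y = se(x)
print(y.shape)  # torch.Size([2, 64, 16, 16]) -- spatial size and channel count preserved

# The gate is a per-channel scalar in (0, 1) broadcast over H x W, so each channel
# is uniformly scaled up or down relative to the others.
ratio = (y / x)[0, :, 0, 0]
print(ratio.min().item(), ratio.max().item())  # both strictly between 0 and 1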
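
The docstrings inside MaxVitAttention walk through grid, rel_pos and rel_pos_indices for a large window; the same construction is easier to inspect at window_size 3. A worked example follows (only the window size differs from the diff, the arithmetic is identical), showing that every spatial offset gets its own slot in the learned per-head bias table.

# Worked example of the relative-position index construction in MaxVitAttention,
# shrunk to window_size = 3 so the tensors are small enough to print.
import torch
from einops import rearrange
from torch import nn

window_size = 3
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing="ij"))
grid = rearrange(grid, "c i j -> (i j) c")  # (9, 2): 2-D coordinate of each window cell

rel_pos = rearrange(grid, "i c -> i 1 c") - rearrange(grid, "j c -> 1 j c")
rel_pos += window_size - 1  # shift offsets from [-(ws-1), ws-1] into [0, 2*ws-2]
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

print(rel_pos_indices.shape)         # torch.Size([9, 9])
print(rel_pos_indices.max().item())  # 24 == (2*ws - 1)**2 - 1, so every offset has a slot
print(rel_pos_indices[0])            # tensor([12, 11, 10,  7,  6,  5,  2,  1,  0])
# Cell pairs with the same spatial offset share an index; the main diagonal
# (offset (0, 0)) always maps to 12 == (ws - 1) * (2*ws - 1) + (ws - 1).
print(torch.diag(rel_pos_indices))   # tensor([12, 12, 12, 12, 12, 12, 12, 12, 12])

# Each index selects a learned per-head value, giving one (9, 9) bias table per head.
num_heads = 4
rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, num_heads)
bias = rearrange(rel_pos_bias(rel_pos_indices), "i j h -> h i j")
print(bias.shape)                    # torch.Size([4, 9, 9])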
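
The slurm script starts one torch.distributed.run launcher per node, which spawns GPUS_PER_NODE worker processes and exports their coordinates as environment variables. The sketch below shows what each worker reads at startup; this is standard torchrun behaviour rather than code from this repository (with the 1 node x 4 GPUs requested above, WORLD_SIZE is 4 and LOCAL_RANK runs 0..3).

# Generic sketch of a worker spawned by torch.distributed.run; it only works when
# launched through torchrun (or this slurm script), which sets the variables read here.
import os
import torch
import torch.distributed as dist

def init_worker():
    local_rank = int(os.environ["LOCAL_RANK"])  # per-process index on this node
    world_size = int(os.environ["WORLD_SIZE"])  # nnodes * nproc_per_node
    rank = int(os.environ["RANK"])              # global rank across all nodes
    torch.cuda.set_device(local_rank)
    # MASTER_ADDR / MASTER_PORT are derived from the --rdzv_endpoint in the launcher
    # and picked up by the default env:// initialization.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, local_rank, world_size

if __name__ == "__main__":
    rank, local_rank, world_size = init_worker()
    print(f"rank {rank} (local {local_rank}) of {world_size} ready")
    dist.barrier()
    dist.destroy_process_group()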