diff --git a/training/train_muse.py b/training/train_muse.py
index 6f5a6b64..07e69178 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -363,6 +363,13 @@ def main():
     text_encoder.to(device=accelerator.device, dtype=weight_dtype)
     vq_model.to(device=accelerator.device)
 
+    if config.training.get("torch_compile_vqgan", False):
+        vq_encode = torch.compile(vq_model.encode, mode="reduce-overhead")
+        vq_decode = torch.compile(vq_model.decode_code, mode="reduce-overhead")
+    else:
+        vq_encode = vq_model.encode
+        vq_decode = vq_model.decode_code
+
     if config.training.overfit_one_batch:
         train_dataloader = [next(iter(train_dataloader))]
 
@@ -423,7 +430,7 @@ def prepare_inputs_and_labels(
                 pixel_values, temp=config.training.soft_code_temp, stochastic=config.training.use_stochastic_code
             )
         else:
-            image_tokens = vq_model.encode(pixel_values)[1]
+            image_tokens = vq_encode(pixel_values)[1]
             soft_targets = None
 
         encoder_hidden_states = text_encoder(input_ids)[0]
@@ -554,7 +561,7 @@ def prepare_inputs_and_labels(
 
             # Generate images
            if (global_step + 1) % config.experiment.generate_every == 0 and accelerator.is_main_process:
-                generate_images(model, vq_model, text_encoder, tokenizer, accelerator, config, global_step + 1)
+                generate_images(model, vq_decode, text_encoder, tokenizer, accelerator, config, global_step + 1)
 
             global_step += 1
             # TODO: Add generation
@@ -601,7 +608,7 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inp
 
 
 @torch.no_grad()
-def generate_images(model, vq_model, text_encoder, tokenizer, accelerator, config, global_step):
+def generate_images(model, vq_decode, text_encoder, tokenizer, accelerator, config, global_step):
     logger.info("Generating images...")
     model.eval()
     # fmt: off
@@ -634,7 +641,7 @@ def generate_images(model, vq_model, text_encoder, tokenizer, accelerator, confi
     # In the beginning of training, the model is not fully trained and the generated token ids can be out of range
     # so we clamp them to the correct range.
     gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1)
-    images = vq_model.decode_code(gen_token_ids)
+    images = vq_decode(gen_token_ids)
    model.train()
 
     # Convert to PIL images
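
Note on the pattern above, outside the diff: torch.compile returns a drop-in callable with the wrapped function's signature, which is why the compiled vq_encode/vq_decode can be threaded through prepare_inputs_and_labels and generate_images unchanged. A minimal, self-contained sketch of the same pattern (the ToyVQModel and use_compile flag are hypothetical stand-ins for the repo's vq_model and config lookup):

    import torch


    class ToyVQModel(torch.nn.Module):
        # Hypothetical stand-in for the real VQ-GAN.
        def encode(self, pixel_values):
            # Mirrors the real API shape: returns a tuple whose second
            # element is the token ids, hence the [1] in the diff.
            return pixel_values, pixel_values.argmax(dim=-1)

        def decode_code(self, token_ids):
            return token_ids.float()


    vq_model = ToyVQModel()
    # Stands in for config.training.get("torch_compile_vqgan", False).
    use_compile = True

    if use_compile:
        # "reduce-overhead" targets per-call launch cost (using CUDA graphs
        # where possible), a good fit for a frozen model called in a loop.
        vq_encode = torch.compile(vq_model.encode, mode="reduce-overhead")
        vq_decode = torch.compile(vq_model.decode_code, mode="reduce-overhead")
    else:
        vq_encode = vq_model.encode
        vq_decode = vq_model.decode_code

    # Both branches yield callables with identical signatures.
    image_tokens = vq_encode(torch.randn(2, 8, 16))[1]
    images = vq_decode(image_tokens)

Passing the callable vq_decode into generate_images (rather than the whole vq_model) keeps the compiled graph shared between the training loop and sampling, instead of recompiling inside the helper.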