Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch2.0 compile support #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add compile support
patil-suraj committed Apr 21, 2023
commit b6ba51825b39cb5125a2b16f82b17a94f5251524
15 changes: 11 additions & 4 deletions training/train_muse.py
Original file line number Diff line number Diff line change
@@ -363,6 +363,13 @@ def main():
text_encoder.to(device=accelerator.device, dtype=weight_dtype)
vq_model.to(device=accelerator.device)

if config.training.get("torch_compile_vqgan", False):
vq_encode = torch.compile(vq_model.encode, mode="reduce-overhead")
vq_decode = torch.compile(vq_model.decode_code, mode="reduce-overhead")
else:
vq_encode = vq_model.encode
vq_decode = vq_model.decode_code

if config.training.overfit_one_batch:
train_dataloader = [next(iter(train_dataloader))]

@@ -423,7 +430,7 @@ def prepare_inputs_and_labels(
pixel_values, temp=config.training.soft_code_temp, stochastic=config.training.use_stochastic_code
)
else:
image_tokens = vq_model.encode(pixel_values)[1]
image_tokens = vq_encode(pixel_values)[1]
soft_targets = None

encoder_hidden_states = text_encoder(input_ids)[0]
@@ -554,7 +561,7 @@ def prepare_inputs_and_labels(

# Generate images
if (global_step + 1) % config.experiment.generate_every == 0 and accelerator.is_main_process:
generate_images(model, vq_model, text_encoder, tokenizer, accelerator, config, global_step + 1)
generate_images(model, vq_decode, text_encoder, tokenizer, accelerator, config, global_step + 1)

global_step += 1
# TODO: Add generation
@@ -601,7 +608,7 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inp


@torch.no_grad()
def generate_images(model, vq_model, text_encoder, tokenizer, accelerator, config, global_step):
def generate_images(model, vq_decode, text_encoder, tokenizer, accelerator, config, global_step):
logger.info("Generating images...")
model.eval()
# fmt: off
@@ -634,7 +641,7 @@ def generate_images(model, vq_model, text_encoder, tokenizer, accelerator, confi
# In the beginning of training, the model is not fully trained and the generated token ids can be out of range
# so we clamp them to the correct range.
gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1)
images = vq_model.decode_code(gen_token_ids)
images = vq_decode(gen_token_ids)
model.train()

# Convert to PIL images