
Commit

style
patil-suraj committed Jun 29, 2023
1 parent 2f9dfd3 commit 3d9adf0
Showing 5 changed files with 44 additions and 65 deletions.
9 changes: 3 additions & 6 deletions benchmark/muse_chart.py
@@ -1,6 +1,7 @@
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import pandas as pd
from argparse import ArgumentParser

df = pd.read_csv("artifacts/all.csv")

@@ -20,11 +21,7 @@


def chart(device, component, compiled, plot_on, legend, y_axis_key, y_label, timesteps):
filter = (
(df["Device"] == device)
& (df["Component"] == component)
& (df["Compilation Type"] == compiled)
)
filter = (df["Device"] == device) & (df["Component"] == component) & (df["Compilation Type"] == compiled)

if timesteps is not None:
filter = filter & (df["Timesteps"] == timesteps)
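
For context, the filter reflowed in this hunk is a plain pandas boolean mask over the benchmark CSV; the change only joins the chained comparisons onto one line. A minimal sketch of the pattern, with made-up rows but the same column names as benchmark/muse_chart.py:

import pandas as pd

# Made-up data; only the column names mirror benchmark/muse_chart.py.
df = pd.DataFrame(
    {
        "Device": ["a100", "a100", "t4"],
        "Component": ["backbone", "vae", "backbone"],
        "Compilation Type": ["none", "none", "reduce-overhead"],
        "Timesteps": [12, 12, 12],
    }
)

# Same boolean-mask style as chart(): each comparison yields a boolean Series,
# and & combines them element-wise.
filter = (df["Device"] == "a100") & (df["Component"] == "backbone") & (df["Compilation Type"] == "none")
filter = filter & (df["Timesteps"] == 12)

print(df[filter])  # only the rows matching every condition
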
71 changes: 23 additions & 48 deletions benchmark/muse_perf.py
@@ -1,15 +1,16 @@
import torch
from torch.utils.benchmark import Timer, Compare
from muse.modeling_taming_vqgan import VQGANModel
from muse.modeling_transformer import MaskGiTUViT
from muse import PipelineMuse, PaellaVQModel
import csv
import multiprocessing
import traceback
from argparse import ArgumentParser
import csv
from diffusers import UNet2DConditionModel, AutoencoderKL, StableDiffusionPipeline

from transformers import CLIPTextModel, AutoTokenizer, CLIPTokenizer
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from torch.utils.benchmark import Compare, Timer
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer

from muse import PaellaVQModel, PipelineMuse
from muse.modeling_taming_vqgan import VQGANModel
from muse.modeling_transformer import MaskGiTUViT

torch.manual_seed(0)
torch.set_grad_enabled(False)
@@ -149,7 +150,7 @@ def main():
component,
timesteps,
mem_bytes,
iqr
iqr,
]
)

@@ -162,17 +163,11 @@


def muse_benchmark_transformer_backbone(in_queue, out_queue, timeout):
wrap_subprocess_fn(
in_queue, out_queue, timeout, _muse_benchmark_transformer_backbone
)
wrap_subprocess_fn(in_queue, out_queue, timeout, _muse_benchmark_transformer_backbone)


def _muse_benchmark_transformer_backbone(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(
device
)
def _muse_benchmark_transformer_backbone(device, dtype, compiled, batch_size, model, label, description, timesteps):
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(device)

tokenizer = AutoTokenizer.from_pretrained(model, subfolder="text_encoder")

@@ -198,9 +193,7 @@ def _muse_benchmark_transformer_backbone(
if compiled is not None:
transformer = torch.compile(transformer, mode=compiled)

image_tokens = torch.full(
(batch_size, 256), fill_value=5, dtype=torch.long, device=device
)
image_tokens = torch.full((batch_size, 256), fill_value=5, dtype=torch.long, device=device)

def benchmark_fn():
transformer(image_tokens, encoder_hidden_states=encoder_hidden_states)
@@ -226,19 +219,15 @@ def sd_benchmark_unet_backbone(in_queue, out_queue, timeout):
wrap_subprocess_fn(in_queue, out_queue, timeout, _sd_benchmark_unet_backbone)


def _sd_benchmark_unet_backbone(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
def _sd_benchmark_unet_backbone(device, dtype, compiled, batch_size, model, label, description, timesteps):
unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet")

unet = unet.to(device=device, dtype=dtype)

if compiled is not None:
unet = torch.compile(unet, mode=compiled)

text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(
device
)
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(device)

tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")

@@ -285,9 +274,7 @@ def muse_benchmark_vae(in_queue, out_queue, timeout):
wrap_subprocess_fn(in_queue, out_queue, timeout, _muse_benchmark_vae)


def _muse_benchmark_vae(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
def _muse_benchmark_vae(device, dtype, compiled, batch_size, model, label, description, timesteps):
vae_cls = model_config[model]["vae"]["cls"]
vae = vae_cls.from_pretrained(model, subfolder="vae")

@@ -296,9 +283,7 @@ def _muse_benchmark_vae(
if compiled is not None:
vae = torch.compile(vae, mode=compiled)

image_tokens = torch.full(
(batch_size, 256), fill_value=5, dtype=torch.long, device=device
)
image_tokens = torch.full((batch_size, 256), fill_value=5, dtype=torch.long, device=device)

def benchmark_fn():
vae.decode_code(image_tokens)
@@ -324,9 +309,7 @@ def sd_benchmark_vae(in_queue, out_queue, timeout):
wrap_subprocess_fn(in_queue, out_queue, timeout, _sd_benchmark_vae)


def _sd_benchmark_vae(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
def _sd_benchmark_vae(device, dtype, compiled, batch_size, model, label, description, timesteps):
vae = AutoencoderKL.from_pretrained(model, subfolder="vae")

vae = vae.to(device=device, dtype=dtype)
@@ -360,14 +343,10 @@ def muse_benchmark_full(in_queue, out_queue, timeout):
wrap_subprocess_fn(in_queue, out_queue, timeout, _muse_benchmark_full)


def _muse_benchmark_full(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
def _muse_benchmark_full(device, dtype, compiled, batch_size, model, label, description, timesteps):
tokenizer = AutoTokenizer.from_pretrained(model, subfolder="text_encoder")

text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(
device=device, dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(device=device, dtype=dtype)

vae_cls = model_config[model]["vae"]["cls"]
vae = vae_cls.from_pretrained(model, subfolder="vae")
@@ -415,14 +394,10 @@ def sd_benchmark_full(in_queue, out_queue, timeout):
wrap_subprocess_fn(in_queue, out_queue, timeout, _sd_benchmark_full)


def _sd_benchmark_full(
device, dtype, compiled, batch_size, model, label, description, timesteps
):
def _sd_benchmark_full(device, dtype, compiled, batch_size, model, label, description, timesteps):
tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")

text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(
device=device, dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder").to(device=device, dtype=dtype)

vae = AutoencoderKL.from_pretrained(model, subfolder="vae")

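The hunks above show the benchmark_fn definitions but not the timing harness itself, which comes from the Timer and Compare imports at the top of the file. A minimal sketch of how that API is typically driven (the label/description strings, the workload, and the memory readout here are assumptions, not lines from muse_perf.py):

import torch
from torch.utils.benchmark import Compare, Timer


def benchmark_fn():
    # Stand-in workload for the real transformer/unet/vae forward pass.
    torch.randn(512, 512) @ torch.randn(512, 512)


timer = Timer(
    stmt="benchmark_fn()",
    globals={"benchmark_fn": benchmark_fn},
    label="muse_perf sketch",        # assumed label
    description="matmul stand-in",   # assumed description
)
measurement = timer.blocked_autorange(min_run_time=1)

# Measurement exposes robust statistics; the iqr written to the CSV rows
# above presumably comes from a field like this.
print(measurement.median, measurement.iqr)

# Peak memory is read separately on CUDA, e.g.:
#     mem_bytes = torch.cuda.max_memory_allocated()

Compare([measurement]).print()
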
2 changes: 1 addition & 1 deletion muse/modeling_transformer.py
@@ -1298,7 +1298,7 @@ def generate2(

if return_intermediate:
intermediate = []

if guidance_schedule == "linear":
guidance_scales = torch.linspace(0, guidance_scale, timesteps)
else:
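
For reference, the linear guidance schedule touched in this hunk simply ramps the classifier-free guidance scale from 0 up to guidance_scale across the sampling steps. A small standalone sketch (the concrete numbers are only illustrative):

import torch

guidance_scale = 8.0  # illustrative value
timesteps = 12        # illustrative value

# "linear" schedule: no guidance at the first step, full scale at the last.
guidance_scales = torch.linspace(0, guidance_scale, timesteps)
print(guidance_scales)
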
21 changes: 14 additions & 7 deletions scripts/log_inpainting_images.py
@@ -1,14 +1,17 @@
import copy
import json
import os
from argparse import ArgumentParser
from itertools import islice

import numpy as np
import torch
import wandb
import numpy as np
from muse import PipelineMuseInpainting
from PIL import Image
import copy
import os

from muse import PipelineMuseInpainting


def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
@@ -31,14 +34,17 @@ def generate_and_log(args):
inputs = {"text": args.text}

mask = np.zeros((args.image_size // vae_scaling_factor, args.image_size // vae_scaling_factor))
mask[args.mask_start_x:args.mask_end_x, args.mask_start_y:args.mask_end_y] = 1
mask[args.mask_start_x : args.mask_end_x, args.mask_start_y : args.mask_end_y] = 1
mask = mask.reshape(-1)
mask = torch.tensor(mask).to(args.device, dtype=torch.bool)

image = Image.open(args.input_image).resize((args.image_size, args.image_size))

masked_image = copy.deepcopy(np.array(image))
masked_image[args.mask_start_x*vae_scaling_factor:args.mask_end_x*vae_scaling_factor, args.mask_start_y*vae_scaling_factor:args.mask_end_y*vae_scaling_factor] = 0
masked_image[
args.mask_start_x * vae_scaling_factor : args.mask_end_x * vae_scaling_factor,
args.mask_start_y * vae_scaling_factor : args.mask_end_y * vae_scaling_factor,
] = 0
masked_image = Image.fromarray(masked_image)
masked_image.save(os.path.join(args.output_dir, "segmented.jpg"))
images = pipe(
@@ -50,7 +56,7 @@
temperature=args.temperature,
use_maskgit_generate=not args.not_maskgit_generate,
num_images_per_prompt=args.num_generations,
image_size=args.image_size
image_size=args.image_size,
)

if args.is_class_conditioned:
@@ -62,6 +68,7 @@
for i, image in enumerate(images):
image.save(os.path.join(args.output_dir, f"output_{i}.jpg"))


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--is_class_conditioned", action="store_true")
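
The reflowed slices above carve a rectangular region out of a token-space mask and the matching, vae_scaling_factor-times-larger region out of the pixel image. A tiny sketch of the token-space half on a made-up 8x8 grid:

import numpy as np
import torch

# Made-up grid size and rectangle; the slicing pattern mirrors the script.
mask = np.zeros((8, 8))
mask[2:4, 3:6] = 1        # rows 2-3, columns 3-5 marked for inpainting
mask = mask.reshape(-1)   # flatten to one entry per token position
mask = torch.tensor(mask).to(dtype=torch.bool)

print(int(mask.sum()))    # 6 positions selected (2 rows x 3 columns)
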
6 changes: 3 additions & 3 deletions training/train_muse.py
@@ -172,9 +172,9 @@ def main():
)

if accelerator.distributed_type == DistributedType.DEEPSPEED:
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = (
config.training.batch_size
)
accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = config.training.batch_size

#####################################
# SETUP LOGGING, SEED and CONFIG #
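
For context, the hunk above only reflows how the DeepSpeed micro-batch size is written into the plugin's config dict under Accelerate. A minimal standalone sketch of that assignment (the Accelerator setup and the batch-size value are assumptions, not lines from train_muse.py):

from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    batch_size = 8  # stands in for config.training.batch_size
    accelerator.state.deepspeed_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ] = batch_size
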
