Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend VaeDecoderModel for flux compatibility #750

Merged
merged 7 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sharktank/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ gguf>=0.11.0
numpy

# Model deps.
huggingface-hub==0.22.2
huggingface-hub
transformers==4.40.0
datasets
einops

# Serving deps.
fastapi>=0.112.2
uvicorn>=0.30.6
2 changes: 2 additions & 0 deletions sharktank/sharktank/models/vae/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class HParams:
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.13025
use_post_quant_conv: bool = True
shift_factor: float = 0.0

def assert_default_values(self, attr_names: Sequence[str]):
for name in attr_names:
Expand Down
3 changes: 1 addition & 2 deletions sharktank/sharktank/models/vae/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
ResnetBlock2D,
Upsample2D,
GroupNormLayer,
AttentionLayer,
)
from .config import *

Expand Down Expand Up @@ -84,7 +83,6 @@ def forward(
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)

query = self.to_q(hidden_states)

if encoder_hidden_states is None:
Expand All @@ -110,6 +108,7 @@ def forward(
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, self.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = self.to_out(hidden_states)
Expand Down
28 changes: 23 additions & 5 deletions sharktank/sharktank/models/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from .layers import *
from sharktank.models.punet.layers import UpDownBlock2D, GroupNormLayer
from typing import Optional
from einops import rearrange
import math


class VaeDecoderModel(ThetaLayer):
Expand All @@ -23,12 +25,13 @@ def from_dataset(cls, ds: Dataset) -> "VaeDecoderModel":
hp = HParams.from_dict(ds.properties["hparams"])
return cls(hp, ds.root_theta)

def __init__(self, hp: HParams, theta: Theta):
def __init__(self, hp, theta: Theta):
super().__init__(theta)
self.hp = hp

# input conv
self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0))
if hp.use_post_quant_conv:
self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0))
self.conv_in = Conv2DLayer(theta("decoder")("conv_in"), padding=(1, 1))
# Mid
self.mid_block = self._create_mid_block(theta("decoder")("mid_block"))
Expand Down Expand Up @@ -71,9 +74,20 @@ def forward(
"latent_embeds": latent_embeds,
},
)
sample = 1 / self.hp.scaling_factor * sample
if not self.hp.use_post_quant_conv:
sample = rearrange(
sample,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(1024 / 16),
w=math.ceil(1024 / 16),
ph=2,
pw=2,
)
sample = sample / self.hp.scaling_factor + self.hp.shift_factor

if self.hp.use_post_quant_conv:
sample = self.post_quant_conv(sample)

sample = self.post_quant_conv(sample)
sample = self.conv_in(sample)
self.trace_golden("conv_in", sample)
# TODO add training and gradient checkpointing support
Expand All @@ -90,7 +104,11 @@ def forward(

sample = self.conv_act(sample)
sample = self.conv_out(sample)
sample = (sample / 2 + 0.5).clamp(0, 1)

if not self.hp.use_post_quant_conv:
sample = sample.clamp(-1, 1)
else:
sample = (sample / 2 + 0.5).clamp(0, 1)
return sample

def _create_mid_block(self, mid_block_theta: Theta) -> nn.Module:
Expand Down
32 changes: 31 additions & 1 deletion sharktank/sharktank/models/vae/tools/diffuser_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import argparse
import torch
from diffusers import AutoencoderKL
from einops import rearrange
import math


class VaeModel(torch.nn.Module):
Expand Down Expand Up @@ -51,6 +53,34 @@ def decode(self, inp):


def run_torch_vae(hf_model_name, example_input):

vae_model = VaeModel(hf_model_name)
return vae_model.decode(example_input)


# TODO Remove and integrate with VaeModel
class FluxAEWrapper(torch.nn.Module):
def __init__(self, height=1024, width=1024):
super().__init__()
self.ae = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)
self.height = height
self.width = width

def forward(self, z):
d_in = rearrange(
z,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(self.height / 16),
w=math.ceil(self.width / 16),
ph=2,
pw=2,
)
d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor
return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1)


def run_flux_vae(example_input, dtype):
# TODO add support for other height/width sizes
vae_model = FluxAEWrapper(1024, 1024).to(dtype)
return vae_model.forward(example_input)
28 changes: 22 additions & 6 deletions sharktank/sharktank/models/vae/tools/run_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
import numpy as np


def export_vae(model, sample_inputs, decomp_attn):
Expand Down Expand Up @@ -81,6 +82,19 @@ def main(argv):
action="store_true",
help="Compares results vs HF diffusers reference model",
)

parser.add_argument(
"--torch_model",
default="stabilityai/stable-diffusion-xl-base-1.0",
help="HF reference model id",
)

parser.add_argument(
"--sharktank_config",
default="sdxl",
help="Sharktank config providing hyperparamters [sdxl or flux]",
)

parser.add_argument(
"--decomp_attn",
action="store_true",
Expand All @@ -95,12 +109,13 @@ def main(argv):
ds.to(device=device)

mdl = VaeDecoderModel.from_dataset(ds)

# Run a step for debugging.
if args.inputs:
inputs = load_inputs(args.inputs, dtype=dtype, device=device, bs=args.bs)
else:
inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)
inputs = get_random_inputs(
dtype=dtype, device=device, bs=args.bs, config=args.sharktank_config
)

if args.export:
# TODO move export from a run_vae file
Expand All @@ -126,11 +141,12 @@ def main(argv):
intermediates_saver.save_file(args.save_intermediates_path)

if args.compare_vs_torch:
from .diffuser_ref import run_torch_vae
from .diffuser_ref import run_torch_vae, run_flux_vae

diffusers_results = run_torch_vae(
"stabilityai/stable-diffusion-xl-base-1.0", inputs
)
if args.sharktank_config == "flux":
diffusers_results = run_flux_vae(inputs, torch.bfloat16)
elif args.sharktank_config == "sdxl":
run_torch_vae(args.torch_model, inputs)
print("diffusers results:", diffusers_results)
torch.testing.assert_close(diffusers_results, results)

Expand Down
12 changes: 10 additions & 2 deletions sharktank/sharktank/models/vae/tools/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import torch


def get_random_inputs(dtype, device, bs: int = 2):
def get_random_inputs(dtype, device, bs: int = 2, config: str = "sdxl"):
height = 1024
width = 1024
return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)
if config == "sdxl":
print("sdxl returning inputs")
return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)
elif config == "flux":
print("flux returning inputs")
return torch.rand(bs, int(width * height / 256), 64, dtype=dtype).to(device)
else:
print("config: ", config)
raise AssertionError(f"{config} config not implmented [sdxl, flux] implemented")
5 changes: 4 additions & 1 deletion sharktank/sharktank/tools/import_hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def import_hf_dataset(
config_json_path: PathLike,
param_paths: list[PathLike],
output_irpa_file: Optional[PathLike] = None,
target_dtype=None,
) -> Optional[Dataset]:
import safetensors

Expand All @@ -50,7 +51,9 @@ def import_hf_dataset(
for params_path in param_paths:
with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
tensors = [
DefaultPrimitiveTensor(name=name, data=st.get_tensor(name))
DefaultPrimitiveTensor(
name=name, data=st.get_tensor(name).to(target_dtype)
)
for name in st.keys()
]

Expand Down
5 changes: 5 additions & 0 deletions sharktank/sharktank/types/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":

return squeeze(self, dim)

def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
from ..ops import squeeze

return squeeze(self, dim)

def transpose(self, dim0: int, dim1: int) -> "AnyTensor":
from ..ops import transpose

Expand Down
Loading
Loading