Skip to content

Commit

Permalink
use lazy import
Browse files — browse the repository at this point in the history
  • Loading branch information
makslevental committed Mar 17, 2023
1 parent 47e6ca8 commit 1a7c9c6
Show file tree
Hide file tree
Showing 9 changed files with 5,573 additions and 11,119 deletions.
115 changes: 74 additions & 41 deletions examples/unet.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,92 @@
import torch
import inspect
import re

import numpy as np
import torch

from pi.models.unet import UNet2DConditionModel
import torch_mlir

# Build a deliberately tiny UNet2DConditionModel configuration so the example
# runs quickly, then switch the model to inference mode.
unet_config = {
    "block_out_channels": (32, 64),
    "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
    "cross_attention_dim": 32,
    "attention_head_dim": 8,
    "out_channels": 4,
    "in_channels": 4,
    "layers_per_block": 2,
    "sample_size": 32,
}
unet = UNet2DConditionModel(**unet_config)
unet.eval()

# Fixed input dimensions used by the example forward pass below.
batch_size = 4
num_channels = 4
sizes = (32, 32)
from pi.lazy_importer.run_lazy_imports import do_package_imports, do_hand_imports
from pi.lazy_importer import lazy_imports


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Create a random contiguous float32 tensor of the given shape.

    Each element is drawn uniformly from ``[0, 1)`` via numpy's global
    generator and multiplied by ``scale``.

    Args:
        shape: Iterable of ints giving the tensor dimensions.
        scale: Multiplier applied to every sampled value.
        rng: Unused; kept only for signature compatibility with callers.
        name: Unused; kept only for signature compatibility with callers.

    Returns:
        A contiguous ``torch.float`` tensor with shape ``shape``.
    """
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    # One RNG draw per element so consumption of numpy's global stream matches
    # the original element-by-element implementation exactly.
    values = [np.random.random() * scale for _ in range(total_dims)]

    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


noise = floats_tensor((batch_size, num_channels) + sizes)
time_step = torch.tensor([10])
encoder_hidden_states = floats_tensor((batch_size, 4, 32))
def run(
    CTor,
    down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"),
    up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"),
):
    """Instantiate a tiny UNet via ``CTor`` and run one eval-mode forward pass.

    Args:
        CTor: A ``UNet2DConditionModel``-compatible constructor (the pi,
            diffusers, or linearized variant).
        down_block_types: Down-block type names forwarded to the constructor.
        up_block_types: Up-block type names forwarded to the constructor.

    Returns:
        The model's forward output. (Previously the result was computed and
        then discarded; returning it is backward compatible — no caller in
        this file uses the return value.)
    """
    unet = CTor(
        **{
            "block_out_channels": (32, 64),
            "down_block_types": down_block_types,
            "up_block_types": up_block_types,
            "cross_attention_dim": 32,
            "attention_head_dim": 8,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 2,
            "sample_size": 32,
        }
    )
    unet.eval()

    batch_size = 4
    num_channels = 4
    sizes = (32, 32)

    noise = floats_tensor((batch_size, num_channels) + sizes)
    time_step = torch.tensor([10])
    encoder_hidden_states = floats_tensor((batch_size, 4, 32))
    return unet(noise, time_step, encoder_hidden_states)


def make_linearized():
    """Record diffusers' UNet2DConditionModel through the lazy importer and
    emit the linearized module source ("unet_linearized").

    The model is exercised twice with different block-type combinations so
    both code paths are captured by the import recorder.
    """

    # Renamed from `filter` to avoid shadowing the builtin.
    def _module_filter(ret):
        # Only record objects originating from huggingface/torch/diffusers
        # packages; anything unresolvable (builtins, C extensions, objects
        # with no module) is skipped by returning None.
        try:
            return re.match(
                r"(huggingface|torch|diffusers)",
                inspect.getmodule(ret).__package__,
            )
        except Exception:
            # Narrowed from a bare `except:` — getmodule() may return None
            # and __package__ may be None, both of which raise here.
            return None

    lazy_imports.MODULE_TARGET = _module_filter

    def _inner():
        from diffusers import UNet2DConditionModel

        run(
            UNet2DConditionModel,
            down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"),
            up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"),
        )
        run(
            UNet2DConditionModel,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
        )

    prefix = "from pi.models.unet.prologue import CONFIG_NAME, LORA_WEIGHT_NAME"
    name = "unet_linearized"
    do_package_imports(_inner, prefix, name)


# NOTE(review): this span of the scraped diff interleaved deleted pre-commit
# lines (`output = unet(...)`, `torch.jit.trace/freeze`, `torch_mlir.compile`)
# into the body of the new function; those lines reference globals (`unet`,
# `noise`, `frozen`, `torch_mlir`) that no longer exist after this commit.
# Reconstructed below is the post-commit function only.
def run_linearized():
    """Smoke-test the generated linearized UNet by running one forward pass
    through it via :func:`run`."""
    # Imported lazily: the linearized module only exists after
    # make_linearized() has been executed at least once.
    from pi.models.unet import linearized

    run(linearized.UNet2DConditionModel)
# Script entry point: regenerate the linearized UNet module source.
if __name__ == "__main__":
    make_linearized()
Empty file added pi/lazy_importer/__init__.py
Empty file.
Loading

0 comments on commit 1a7c9c6

Please sign in to comment.