Skip to content

Commit

Permalink
build a "linearized" unet using https://github.com/jansel/pytorch-jit…
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Mar 16, 2023
1 parent ceb2af2 commit 41828ef
Show file tree
Hide file tree
Showing 12 changed files with 11,138 additions and 5,544 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Early days of a lightweight MLIR Python frontend with support for PyTorch (throu
Just

```shell
pip install - requirements.txt
pip install -r requirements.txt
pip install . --no-build-isolation
```

Expand Down
74 changes: 53 additions & 21 deletions examples/unet.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,59 @@
import pi
from pi import nn
from pi.mlir.utils import pipile
from pi.utils.annotations import annotate_args
import torch
import numpy as np

from pi.models.unet import UNet2DConditionModel
import torch_mlir

unet = UNet2DConditionModel(
**{
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
)
unet.eval()

batch_size = 4
num_channels = 4
sizes = (32, 32)


def floats_tensor(shape, scale=1.0, rng=None, name=None):

total_dims = 1
for dim in shape:
total_dims *= dim

values = []
for _ in range(total_dims):
values.append(np.random.random() * scale)

return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


noise = floats_tensor((batch_size, num_channels) + sizes)
time_step = torch.tensor([10])
encoder_hidden_states = floats_tensor((batch_size, 4, 32))

class MyUNet(nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel()
output = unet(noise, time_step, encoder_hidden_states)
print(output)

@annotate_args(
[
None,
([-1, -1, -1, -1], pi.float32, True),
]
)
def forward(self, x):
y = self.resnet(x)
return y
traced = torch.jit.trace(unet, (noise, time_step, encoder_hidden_states), strict=False)
frozen = torch.jit.freeze(traced)
print(frozen.graph)


test_module = MyUNet()
x = pi.randn((1, 3, 64, 64))
mlir_module = pipile(test_module, example_args=(x,))
print(mlir_module)
module = torch_mlir.compile(
frozen,
(noise, time_step, encoder_hidden_states),
use_tracing=True,
output_type=torch_mlir.OutputType.RAW,
)
with open("unet.mlir", "w") as f:
f.write(str(module))
Loading

0 comments on commit 41828ef

Please sign in to comment.