From b4cc54cf1722c4cf9b951a2d83423e72bd4f9aa4 Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Fri, 6 Dec 2024 07:16:42 -0800
Subject: [PATCH] Implement MMDIT Single block for Flux (#648)

---
 sharktank/sharktank/layers/__init__.py |   2 +-
 sharktank/sharktank/layers/mmdit.py    |  42 ++++++++++
 sharktank/sharktank/layers/testing.py  | 109 +++++++++++++++++++------
 sharktank/tests/layers/mmdit_test.py   |  28 ++++++-
 4 files changed, 153 insertions(+), 28 deletions(-)

diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py
index 620c15672..5828d2dd3 100644
--- a/sharktank/sharktank/layers/__init__.py
+++ b/sharktank/sharktank/layers/__init__.py
@@ -17,6 +17,6 @@
 from .ffn_block import FFN
 from .ffn_moe_block import FFNMOE
 from .mixture_of_experts_block import MoeBlock
-from .mmdit import MMDITDoubleBlock
+from .mmdit import MMDITDoubleBlock, MMDITSingleBlock
 
 from .configs import *
diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py
index 0b0750549..0c970ab35 100644
--- a/sharktank/sharktank/layers/mmdit.py
+++ b/sharktank/sharktank/layers/mmdit.py
@@ -144,3 +144,45 @@ def forward(
         txt = txt + txt_mod2.gate * txt_mlp_out3
 
         return img, txt
+
+
+class MMDITSingleBlock(ThetaLayer):
+    def __init__(self, theta, num_heads: int):
+        super().__init__(theta)
+
+        self.num_heads = num_heads
+        self.add_module("mod", ModulationLayer(theta("mod"), double=False))
+        self.add_module(
+            "attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6)
+        )
+        self.add_module(
+            "attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6)
+        )
+        self.add_module("attn_proj", LinearLayer(theta("attn.proj")))
+
+        self.add_module("linear1", LinearLayer(theta("linear1")))
+        self.add_module("linear2", LinearLayer(theta("linear2")))
+        # TODO: There should be a way to refactor out the following two constants and just reference model shapes
+        self.hidden_size = 3072
+        self.mlp_hidden_dim = 3072
+
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+        mod, _ = self.mod(vec)
+        x_norm = ops.layer_norm(x, None, None, eps=1e-6)
+        x_mod = (1 + mod.scale) * x_norm + mod.shift
+        x_lin = self.linear1(x_mod)
+        qkv, mlp = torch.split(
+            x_lin, [3 * self.hidden_size, 4 * self.mlp_hidden_dim], dim=-1
+        )
+
+        qkv_2 = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1)  #
+        qkv_3 = ops.permute(qkv_2, (2, 0, 3, 1, 4))
+        q, k, v = qkv_3
+        q, k = qk_norm(q, k, v, self.attn_norm_q, self.attn_norm_k)
+
+        # compute attention
+        attn = attention(q, k, v, pe=pe)
+        # compute activation in mlp stream, cat again and run second linear layer
+        gelu = ops.elementwise(F.gelu, mlp)
+        output = self.linear2(torch.cat((attn, gelu), 2))
+        return x + mod.gate * output
diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py
index a21d5bf85..74ba49624 100644
--- a/sharktank/sharktank/layers/testing.py
+++ b/sharktank/sharktank/layers/testing.py
@@ -51,80 +51,139 @@ def make_llama_attention_block_theta(
     )
 
 
-def make_mmdit_double_block_theta(dtype: torch.dtype | None = None) -> Theta:
+def make_mmdit_double_block_random_theta(
+    in_channels: int = 128,
+    hidden_size: int = 3072,
+    mlp_ratio: float = 4.0,
+    dtype: torch.dtype | None = None,
+) -> Theta:
+    in_channels = 128
+    hidden_size = 3072
+    mlp_ratio = 4.0
+    mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
+    mlp_hidden_size2 = int(mlp_ratio * hidden_size)
+    mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
     return Theta(
         {
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((128,), dtype=dtype) + data=make_rand_torch((in_channels,), dtype=dtype) ), "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((128,), dtype=dtype) + data=make_rand_torch((in_channels,), dtype=dtype) ), "img_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch((3072,), dtype=dtype) + data=make_rand_torch((hidden_size,), dtype=dtype) ), "img_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch((3072, 3072), dtype=dtype) + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), "img_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch((9216,), dtype=dtype) + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), "img_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch((9216, 3072), dtype=dtype) + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), "img_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch((12288), dtype=dtype) + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) ), "img_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch((12288, 3072), dtype=dtype) + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) ), "img_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch((3072), dtype=dtype) + data=make_rand_torch((hidden_size), dtype=dtype) ), "img_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch((3072, 12288), dtype=dtype) + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) ), "img_mod.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch((18432,), dtype=dtype) + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) ), "img_mod.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch((18432, 3072), dtype=dtype) + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((128,), dtype=dtype) + data=make_rand_torch((in_channels,), dtype=dtype) ), "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((128,), dtype=dtype) + data=make_rand_torch((in_channels,), dtype=dtype) ), "txt_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch((3072,), dtype=dtype) + data=make_rand_torch((hidden_size,), dtype=dtype) ), "txt_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch((3072, 3072), dtype=dtype) + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), "txt_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch((9216,), dtype=dtype) + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), "txt_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch((9216, 3072), dtype=dtype) + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), "txt_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch((12288), dtype=dtype) + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) ), "txt_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch((12288, 3072), dtype=dtype) + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) ), "txt_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch((3072), dtype=dtype) + data=make_rand_torch((hidden_size), dtype=dtype) ), "txt_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch((3072, 12288), dtype=dtype) + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) ), "txt_mod.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch((18432,), dtype=dtype) + 
+                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
             ),
             "txt_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((18432, 3072), dtype=dtype)
+                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+            ),
+        }
+    )
+
+
+def make_mmdit_single_block_random_theta(
+    in_channels: int = 128,
+    hidden_size: int = 3072,
+    mlp_ratio: float = 4.0,
+    dtype: torch.dtype | None = None,
+) -> Theta:
+    in_channels = 128
+    hidden_size = 3072
+    mlp_ratio = 4.0
+    mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
+    mlp_hidden_size2 = int((mlp_ratio + 1) * hidden_size)
+    mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
+    return Theta(
+        {
+            "attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((in_channels,), dtype=dtype)
+            ),
+            "attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((in_channels,), dtype=dtype)
+            ),
+            "attn.proj.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch((hidden_size,), dtype=dtype)
+            ),
+            "attn.proj.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
+            ),
+            "linear1.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
+            ),
+            "linear1.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+            ),
+            "linear2.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch((hidden_size), dtype=dtype)
+            ),
+            "linear2.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+            ),
+            "mod.lin.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
+            ),
+            "mod.lin.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
             ),
         }
     )
diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py
index 5bd5ce39a..d265b33d8 100644
--- a/sharktank/tests/layers/mmdit_test.py
+++ b/sharktank/tests/layers/mmdit_test.py
@@ -15,10 +15,12 @@
 from iree.turbine import aot
 from sharktank.layers import (
     MMDITDoubleBlock,
+    MMDITSingleBlock,
 )
 import sharktank.ops as ops
 from sharktank.layers.testing import (
-    make_mmdit_double_block_theta,
+    make_mmdit_double_block_random_theta,
+    make_mmdit_single_block_random_theta,
 )
 from sharktank.types.tensors import DefaultPrimitiveTensor
 
@@ -32,7 +34,7 @@ def setUp(self):
     def testDoubleExport(self):
 
-        theta = make_mmdit_double_block_theta()
+        theta = make_mmdit_double_block_random_theta()
         mmdit = MMDITDoubleBlock(
             theta=theta,
             num_heads=self.num_heads,
         )
@@ -53,6 +55,28 @@ def _(model, img, txt, vec, rot) -> torch.Tensor:
         output.verify()
         asm = str(output.mlir_module)
 
+    def testSingleExport(self):
+
+        theta = make_mmdit_single_block_random_theta()
+        mmdit = MMDITSingleBlock(
+            theta=theta,
+            num_heads=self.num_heads,
+        )
+
+        inp = torch.rand([self.batch_size, 1024, self.hidden_size])
+        vec = torch.rand([self.batch_size, self.hidden_size])
+        rot = torch.rand([self.batch_size, 1, 1024, 64, 2, 2])
+        mmdit.forward(inp, vec, rot)
+        fxb = aot.FxProgramsBuilder(mmdit)
+
+        @fxb.export_program(name="mmdit", args=(inp, vec, rot), strict=False)
+        def _(model, inp, vec, rot) -> torch.Tensor:
+            return model.forward(inp, vec, rot)
+
+        output = aot.export(fxb)
+        output.verify()
+        asm = str(output.mlir_module)
+
 
 if __name__ == "__main__":
     unittest.main()
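
For reference, the computation added in MMDITSingleBlock is the Flux-style single-stream block: the layer-normed input is modulated by (shift, scale) derived from the conditioning vector, a single fused projection (linear1) produces both the attention qkv and the MLP pre-activations, attention and GELU run in parallel, the two streams are concatenated and projected back through linear2, and the gated result is added to the residual. The plain-PyTorch sketch below restates that math for illustration only; it is not part of the patch. The helper name single_block_reference, the bare torch.nn.Linear modules (standing in for sharktank's LinearLayer and ModulationLayer), and F.scaled_dot_product_attention (standing in for sharktank's attention helper, which also applies the rotary embedding pe) are assumptions, and the per-head query/key RMS normalization done by qk_norm is omitted for brevity.

import torch
import torch.nn.functional as F


def single_block_reference(
    x: torch.Tensor,            # [batch, seq, hidden]
    vec: torch.Tensor,          # [batch, hidden] conditioning vector
    mod_lin: torch.nn.Linear,   # hidden -> 3 * hidden (shift, scale, gate)
    linear1: torch.nn.Linear,   # hidden -> 3 * hidden + mlp_hidden
    linear2: torch.nn.Linear,   # hidden + mlp_hidden -> hidden
    num_heads: int,
) -> torch.Tensor:
    hidden = x.shape[-1]
    mlp_hidden = linear1.out_features - 3 * hidden

    # Modulation: shift/scale/gate derived from the conditioning vector.
    shift, scale, gate = mod_lin(F.silu(vec))[:, None, :].chunk(3, dim=-1)
    x_mod = (1 + scale) * F.layer_norm(x, (hidden,), eps=1e-6) + shift

    # One fused projection feeds both the attention and the MLP stream.
    qkv, mlp = torch.split(linear1(x_mod), [3 * hidden, mlp_hidden], dim=-1)
    q, k, v = qkv.view(*x.shape[:2], 3, num_heads, -1).permute(2, 0, 3, 1, 4)

    # Plain attention; the patch's attention helper also applies pe (and qk_norm precedes it).
    attn = F.scaled_dot_product_attention(q, k, v)
    attn = attn.transpose(1, 2).reshape(*x.shape[:2], hidden)

    # Concatenate the two streams, project back, and apply the gated residual.
    out = linear2(torch.cat((attn, F.gelu(mlp)), dim=-1))
    return x + gate * out


# Example shapes consistent with the constants in the patch (hidden_size = 3072,
# mlp_ratio = 4.0); num_heads = 24 (head dim 128) is assumed for illustration:
#
#   single_block_reference(
#       torch.randn(1, 1024, 3072),
#       torch.randn(1, 3072),
#       torch.nn.Linear(3072, 3 * 3072),
#       torch.nn.Linear(3072, 3 * 3072 + 4 * 3072),
#       torch.nn.Linear(3072 + 4 * 3072, 3072),
#       num_heads=24,
#   )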