Implement MMDIT Single block for Flux (nod-ai#648)
KyleHerndon authored Dec 6, 2024
1 parent 1763a82 commit b4cc54c
Showing 4 changed files with 153 additions and 28 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -17,6 +17,6 @@
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import MoeBlock
- from .mmdit import MMDITDoubleBlock
+ from .mmdit import MMDITDoubleBlock, MMDITSingleBlock

from .configs import *
42 changes: 42 additions & 0 deletions sharktank/sharktank/layers/mmdit.py
@@ -144,3 +144,45 @@ def forward(
txt = txt + txt_mod2.gate * txt_mlp_out3

return img, txt


class MMDITSingleBlock(ThetaLayer):
def __init__(self, theta, num_heads: int):
super().__init__(theta)

self.num_heads = num_heads
self.add_module("mod", ModulationLayer(theta("mod"), double=False))
self.add_module(
"attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6)
)
self.add_module(
"attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6)
)
self.add_module("attn_proj", LinearLayer(theta("attn.proj")))

self.add_module("linear1", LinearLayer(theta("linear1")))
self.add_module("linear2", LinearLayer(theta("linear2")))
# TODO: There should be a way to refactor out the following two constants and just reference model shapes
self.hidden_size = 3072
self.mlp_hidden_dim = 3072

def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.mod(vec)
x_norm = ops.layer_norm(x, None, None, eps=1e-6)
x_mod = (1 + mod.scale) * x_norm + mod.shift
x_lin = self.linear1(x_mod)
qkv, mlp = torch.split(
x_lin, [3 * self.hidden_size, 4 * self.mlp_hidden_dim], dim=-1
)

qkv_2 = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1) #
qkv_3 = ops.permute(qkv_2, (2, 0, 3, 1, 4))
q, k, v = qkv_3
q, k = qk_norm(q, k, v, self.attn_norm_q, self.attn_norm_k)

# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
gelu = ops.elementwise(F.gelu, mlp)
output = self.linear2(torch.cat((attn, gelu), 2))
return x + mod.gate * output
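
A minimal usage sketch of the new block (the shapes mirror the single-block test added below; num_heads=24 and the batch size of 2 are assumptions, chosen so that hidden_size 3072 splits into 24 heads of width 128 to match the 128-wide query/key norms above):

    import torch
    from sharktank.layers import MMDITSingleBlock
    from sharktank.layers.testing import make_mmdit_single_block_random_theta

    theta = make_mmdit_single_block_random_theta()       # random weights sized for hidden_size 3072
    block = MMDITSingleBlock(theta=theta, num_heads=24)  # assumed head count: 3072 / 128 = 24
    x = torch.rand([2, 1024, 3072])                      # (batch, tokens, hidden_size)
    vec = torch.rand([2, 3072])                          # conditioning vector fed to the modulation layer
    pe = torch.rand([2, 1, 1024, 64, 2, 2])              # rotary positional embedding, as in the test
    out = block.forward(x, vec, pe)                      # residual output, same shape as x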
109 changes: 84 additions & 25 deletions sharktank/sharktank/layers/testing.py
@@ -51,80 +51,139 @@ def make_llama_attention_block_theta(
)


- def make_mmdit_double_block_theta(dtype: torch.dtype | None = None) -> Theta:
+ def make_mmdit_double_block_random_theta(
+     in_channels: int = 128,
+     hidden_size: int = 3072,
+     mlp_ratio: float = 4.0,
+     dtype: torch.dtype | None = None,
+ ) -> Theta:
in_channels = 128
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
mlp_hidden_size2 = int(mlp_ratio * hidden_size)
mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
return Theta(
{
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
- data=make_rand_torch((128,), dtype=dtype)
+ data=make_rand_torch((in_channels,), dtype=dtype)
),
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
- data=make_rand_torch((128,), dtype=dtype)
+ data=make_rand_torch((in_channels,), dtype=dtype)
),
"img_attn.proj.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((3072,), dtype=dtype)
+ data=make_rand_torch((hidden_size,), dtype=dtype)
),
"img_attn.proj.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((3072, 3072), dtype=dtype)
+ data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"img_attn.qkv.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((9216,), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"img_attn.qkv.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((9216, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
"img_mlp.0.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((12288), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
),
"img_mlp.0.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((12288, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
),
"img_mlp.2.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((3072), dtype=dtype)
+ data=make_rand_torch((hidden_size), dtype=dtype)
),
"img_mlp.2.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((3072, 12288), dtype=dtype)
+ data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"img_mod.lin.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((18432,), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
),
"img_mod.lin.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((18432, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
- data=make_rand_torch((128,), dtype=dtype)
+ data=make_rand_torch((in_channels,), dtype=dtype)
),
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
- data=make_rand_torch((128,), dtype=dtype)
+ data=make_rand_torch((in_channels,), dtype=dtype)
),
"txt_attn.proj.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((3072,), dtype=dtype)
+ data=make_rand_torch((hidden_size,), dtype=dtype)
),
"txt_attn.proj.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((3072, 3072), dtype=dtype)
+ data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"txt_attn.qkv.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((9216,), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"txt_attn.qkv.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((9216, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
"txt_mlp.0.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((12288), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
),
"txt_mlp.0.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((12288, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
),
"txt_mlp.2.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((3072), dtype=dtype)
+ data=make_rand_torch((hidden_size), dtype=dtype)
),
"txt_mlp.2.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((3072, 12288), dtype=dtype)
+ data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"txt_mod.lin.bias": DefaultPrimitiveTensor(
- data=make_rand_torch((18432,), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
),
"txt_mod.lin.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((18432, 3072), dtype=dtype)
+ data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
}
)
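
A quick sanity sketch for the renamed size constants above (the literals in the comments are the hard-coded values removed from the old version of this function):

    hidden_size = 3072
    mlp_ratio = 4.0
    assert int((mlp_ratio - 1) * hidden_size) == 9216       # qkv projection: 3 * hidden_size for q, k, v
    assert int(mlp_ratio * hidden_size) == 12288            # MLP hidden width: 4 * hidden_size
    assert int(2 * (mlp_ratio - 1) * hidden_size) == 18432  # double modulation: 2 * 3 * hidden_size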


def make_mmdit_single_block_random_theta(
in_channels: int = 128,
hidden_size: int = 3072,
mlp_ratio: float = 4.0,
dtype: torch.dtype | None = None,
) -> Theta:
in_channels = 128
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
mlp_hidden_size2 = int((mlp_ratio + 1) * hidden_size)
mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
return Theta(
{
"attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size,), dtype=dtype)
),
"attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"linear1.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
),
"linear1.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
"linear2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size), dtype=dtype)
),
"linear2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
}
)
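
The single-block sizes line up with the forward pass in mmdit.py above (a sketch with assumed local names; mlp_ratio 4.0, hidden_size 3072):

    hidden_size = 3072
    mlp_ratio = 4.0
    mod_size = int((mlp_ratio - 1) * hidden_size)         # 9216: shift, scale, gate for the single modulation
    linear1_out = int((2 * mlp_ratio - 1) * hidden_size)  # 21504: 3 * hidden_size (qkv) + 4 * hidden_size (mlp)
    linear2_in = int((mlp_ratio + 1) * hidden_size)       # 15360: hidden_size (attn) + 4 * hidden_size (gelu mlp)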
28 changes: 26 additions & 2 deletions sharktank/tests/layers/mmdit_test.py
@@ -15,10 +15,12 @@
from iree.turbine import aot
from sharktank.layers import (
MMDITDoubleBlock,
MMDITSingleBlock,
)
import sharktank.ops as ops
from sharktank.layers.testing import (
- make_mmdit_double_block_theta,
+ make_mmdit_double_block_random_theta,
+ make_mmdit_single_block_random_theta,
)
from sharktank.types.tensors import DefaultPrimitiveTensor

@@ -32,7 +34,7 @@ def setUp(self):

def testDoubleExport(self):

- theta = make_mmdit_double_block_theta()
+ theta = make_mmdit_double_block_random_theta()
mmdit = MMDITDoubleBlock(
theta=theta,
num_heads=self.num_heads,
@@ -53,6 +55,28 @@ def _(model, img, txt, vec, rot) -> torch.Tensor:
output.verify()
asm = str(output.mlir_module)

def testSingleExport(self):

theta = make_mmdit_single_block_random_theta()
mmdit = MMDITSingleBlock(
theta=theta,
num_heads=self.num_heads,
)

inp = torch.rand([self.batch_size, 1024, self.hidden_size])
vec = torch.rand([self.batch_size, self.hidden_size])
rot = torch.rand([self.batch_size, 1, 1024, 64, 2, 2])
mmdit.forward(inp, vec, rot)
fxb = aot.FxProgramsBuilder(mmdit)

@fxb.export_program(name="mmdit", args=(inp, vec, rot), strict=False)
def _(model, inp, vec, rot) -> torch.Tensor:
return model.forward(inp, vec, rot)

output = aot.export(fxb)
output.verify()
asm = str(output.mlir_module)


if __name__ == "__main__":
unittest.main()
