Commit 3d8cad8

A basic working version of the flux model (#663)
This version of the flux model should work, as it directly modifies the reference implementation, but could really use some refactoring, especially to reduce code duplication.

---------

Co-authored-by: Boian Petkantchin <[email protected]>
1 parent f7d2681 commit 3d8cad8
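
For orientation, the new model is configured through the FluxParams dataclass added in this commit. The sketch below is illustrative only: it copies the dataclass from the new file and fills it with flux-dev-style values taken from the upstream reference configuration (treat the exact numbers as assumptions; only hidden_size=3072 and num_heads=24 are confirmed by the test file in this commit), then checks the positional-embedding constraint that FluxModelV1.__init__ enforces.

    from dataclasses import dataclass


    @dataclass
    class FluxParams:  # copied from the new file in this commit
        in_channels: int
        out_channels: int
        vec_in_dim: int
        context_in_dim: int
        hidden_size: int
        mlp_ratio: float
        num_heads: int
        depth: int
        depth_single_blocks: int
        axes_dim: list[int]
        theta: int
        qkv_bias: bool
        guidance_embed: bool


    # flux-dev-style values from the upstream reference config (illustrative assumption).
    params = FluxParams(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    )

    # FluxModelV1.__init__ requires hidden_size % num_heads == 0 and
    # sum(axes_dim) == hidden_size // num_heads.
    pe_dim = params.hidden_size // params.num_heads
    assert params.hidden_size % params.num_heads == 0
    assert sum(params.axes_dim) == pe_dim  # 16 + 56 + 56 == 128
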

File tree

3 files changed: +538 -4 lines changed
+251 lines changed (new file)

@@ -0,0 +1,251 @@
# Copyright 2024 Advanced Micro Devices, Inc.
# Copyright 2024 Black Forest Labs. Inc. and Flux Authors
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Model adapted from black-forest-labs' flux implementation
https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
"""

import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...layers import *
from ...types import *
from ...utils.create_cache import *
from ... import ops

__all__ = [
    "FluxModelV1",
]

################################################################################
# Models
################################################################################


@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class FluxModelV1(ThetaLayer):
    """FluxModel adapted from Black Forest Lab's implementation."""

    def __init__(self, theta: Theta, params: FluxParams):
        super().__init__(
            theta,
        )

        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.add_module("img_in", LinearLayer(theta("img_in")))
        # TODO: Refactor this pattern to an MLPEmbedder like the src implementation
        self.add_module("time_in_0", LinearLayer(theta("time_in.0")))
        self.add_module("time_in_1", LinearLayer(theta("time_in.1")))
        self.add_module("vector_in_0", LinearLayer(theta("vector_in.0")))
        self.add_module("vector_in_1", LinearLayer(theta("vector_in.1")))
        self.guidance = False
        if params.guidance_embed:
            self.guidance = True
            self.add_module("guidance_in_0", LinearLayer(theta("guidance_in.0")))
            self.add_module("guidance_in_1", LinearLayer(theta("guidance_in.1")))
        self.add_module("txt_in", LinearLayer(theta("txt_in")))

        self.double_blocks = nn.ModuleList(
            [
                MMDITDoubleBlock(
                    theta("double_blocks", i),
                    self.num_heads,
                )
                for i in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                MMDITSingleBlock(
                    theta("single_blocks", i),
                    self.num_heads,
                )
                for i in range(params.depth_single_blocks)
            ]
        )

        self.add_module(
            "last_layer",
            LastLayer(theta("last_layer")),
        )

    def forward(
        self,
        img: AnyTensor,
        img_ids: AnyTensor,
        txt: AnyTensor,
        txt_ids: AnyTensor,
        timesteps: AnyTensor,
        y: AnyTensor,
        guidance: AnyTensor | None = None,
    ) -> AnyTensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        time_in_0 = self.time_in_0(timestep_embedding(timesteps, 256))
        time_in_silu = ops.elementwise(F.silu, time_in_0)
        vec = self.time_in_1(time_in_silu)
        if self.guidance:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            guidance_inp = timestep_embedding(guidance, 256)
            guidance0 = self.guidance_in_0(guidance_inp)
            guidance_silu = ops.elementwise(F.silu, guidance0)
            guidance_out = self.guidance_in_1(guidance_silu)
            vec = vec + guidance_out
        vector_in_0 = self.vector_in_0(y)
        vector_in_silu = ops.elementwise(F.silu, vector_in_0)
        vector_in_1 = self.vector_in_1(vector_in_silu)
        vec = vec + vector_in_1

        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.last_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img


################################################################################
# Layers
################################################################################


# TODO: Refactor these functions to other files. Rope can probably be merged with
# our rotary embedding layer, some of these functions are shared with layers/mmdit.py
def timestep_embedding(
    t: AnyTensor, dim, max_period=10000, time_factor: float = 1000.0
):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


def layer_norm(inp):
    weight = torch.ones(inp.shape)
    bias = torch.zeros(inp.shape)
    return ops.layer_norm(inp, weight, bias, eps=1e-6)


def qk_norm(q, k, v, rms_q, rms_k):
    return rms_q(q).to(v), rms_k(k).to(v)


def rope(pos: AnyTensor, dim: int, theta: int) -> AnyTensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    # out = out.view(out.shape[0], out.shape[1], out.shape[2], out.shape[3], 2, 2)
    out = out.view(out.shape[0], out.shape[1], out.shape[2], 2, 2)
    return out.float()


class EmbedND(torch.nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: AnyTensor) -> AnyTensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


class LastLayer(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
    ):
        super().__init__(theta)
        self.add_module("outlinear", LinearLayer(theta("outlinear")))
        self.add_module("ada_linear", LinearLayer(theta("ada_linear")))

    def forward(self, x: AnyTensor, vec: AnyTensor) -> AnyTensor:
        silu = ops.elementwise(F.silu, vec)
        lin = self.ada_linear(silu)
        shift, scale = lin.chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * layer_norm(x) + shift[:, None, :]
        x = self.outlinear(x)
        return x
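
The timestep_embedding and rope helpers added above are plain PyTorch and can be exercised standalone. Below is a minimal shape check; the two functions are inlined (copied from the file above, with the AnyTensor annotations dropped so the snippet has no sharktank dependency), and the driver sizes are illustrative, chosen to match the pe_dim = 128 and axes_dim = [16, 56, 56] split discussed earlier.

    import math

    import torch


    def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
        # Sinusoidal embedding, copied from the new file above.
        t = time_factor * t
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(t)
        return embedding


    def rope(pos, dim, theta):
        # 2x2 rotation matrices per position and frequency, copied from the new file above.
        assert dim % 2 == 0
        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**scale)
        out = torch.einsum("...n,d->...nd", pos, omega)
        out = torch.stack(
            [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
        )
        return out.view(out.shape[0], out.shape[1], out.shape[2], 2, 2).float()


    t = torch.linspace(0.0, 1.0, steps=2)                # two (possibly fractional) timesteps
    print(timestep_embedding(t, 256).shape)              # torch.Size([2, 256])

    pos = torch.arange(8, dtype=torch.float64)[None, :]  # (batch=1, 8 positions) for one axis
    print(rope(pos, 56, 10_000).shape)                   # torch.Size([1, 8, 28, 2, 2])
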

sharktank/tests/layers/mmdit_test.py

+16 -4 lines changed
@@ -17,16 +17,17 @@
     MMDITDoubleBlock,
     MMDITSingleBlock,
 )
-import sharktank.ops as ops
 from sharktank.layers.testing import (
     make_mmdit_double_block_random_theta,
     make_mmdit_single_block_random_theta,
 )
-from sharktank.types.tensors import DefaultPrimitiveTensor
+from sharktank.utils.testing import TempDirTestBase
+from sharktank.types import Dataset, Theta


-class MMDITTest(unittest.TestCase):
+class MMDITTest(TempDirTestBase):
     def setUp(self):
+        super().setUp()
         torch.manual_seed(12345)
         self.hidden_size = 3072
         self.num_heads = 24
@@ -35,6 +36,7 @@ def setUp(self):
     def testDoubleExport(self):

         theta = make_mmdit_double_block_random_theta()
+        theta = self.save_load_theta(theta)
         mmdit = MMDITDoubleBlock(
             theta=theta,
             num_heads=self.num_heads,
@@ -58,6 +60,7 @@ def _(model, img, txt, vec, rot) -> torch.Tensor:
     def testSingleExport(self):

         theta = make_mmdit_single_block_random_theta()
+        theta = self.save_load_theta(theta)
         mmdit = MMDITSingleBlock(
             theta=theta,
             num_heads=self.num_heads,
@@ -73,10 +76,19 @@ def testSingleExport(self):
         def _(model, inp, vec, rot) -> torch.Tensor:
             return model.forward(inp, vec, rot)

-        output = aot.export(fxb)
+        output = aot.export(fxb, import_symbolic_shape_expressions=True)
         output.verify()
         asm = str(output.mlir_module)

+    def save_load_theta(self, theta: Theta):
+        # Roundtrip to disk to avoid treating parameters as constants that would appear
+        # in the MLIR.
+        theta.rename_tensors_to_paths()
+        dataset = Dataset(root_theta=theta, properties={})
+        file_path = self._temp_dir / "parameters.irpa"
+        dataset.save(file_path)
+        return Dataset.load(file_path).root_theta
+

 if __name__ == "__main__":
     unittest.main()
