Merge branch 'nod-ai:main' into flux_vae
IanNod authored Jan 3, 2025
2 parents 3e7673c + d42cc29 commit 41a6fc7
Showing 10 changed files with 263 additions and 225 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-libshortfin.yml
@@ -101,7 +101,7 @@ jobs:
           repository: iree-org/iree
           path: ${{ env.IREE_REPO_DIR }}
           submodules: false
-          ref: iree-3.1.0rc20241204
+          ref: iree-3.1.0rc20241220
 
       - name: Initalize IREE submodules
         working-directory: ${{ env.IREE_REPO_DIR }}
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_asan-libshortfin.yml
@@ -106,7 +106,7 @@ jobs:
           repository: iree-org/iree
           path: ${{ env.IREE_SOURCE_DIR }}
           submodules: false
-          ref: iree-3.1.0rc20241204
+          ref: iree-3.1.0rc20241220
 
       - name: Initalize IREE submodules
         working-directory: ${{ env.IREE_SOURCE_DIR }}
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -54,7 +54,7 @@ jobs:
           repository: iree-org/iree
           path: ${{ env.IREE_REPO_DIR }}
           submodules: false
-          ref: iree-3.1.0rc20241204
+          ref: iree-3.1.0rc20241220
 
       - name: Initalize IREE submodules
         working-directory: ${{ env.IREE_REPO_DIR }}
55 changes: 48 additions & 7 deletions sharktank/sharktank/models/flux/flux.py
@@ -21,6 +21,7 @@
 from ...layers import *
 from ...types import *
 from ...utils.create_cache import *
+from ...utils.testing import make_rand_torch
 from ... import ops
 
 __all__ = [
@@ -196,13 +197,37 @@ def sample_inputs(
         if not (function is None or function == "forward"):
             raise ValueError(f'Only function "forward" is supported. Got "{function}"')
 
-        # TODO: do not hardcode these but derive the required shapes from the config.
-        img = torch.rand([batch_size, 1024, 64], dtype=self.dtype)
-        img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32)
-        txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype)
-        txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32)
+        # The allowed range of these values is dependent on the model size.
+        # They will not work for all variants, specifically toy-sized models.
+        output_img_height = 1024
+        output_img_width = 1024
+        output_img_channels = 3
+
+        img = self._get_noise(
+            batch_size, output_img_height, output_img_width, self.dtype
+        )
+
+        _, c, h, w = img.shape
+        img = img.reshape(batch_size, h * w // 4, c * 4)
+
+        img_ids = torch.zeros(h // 2, w // 2, output_img_channels)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = img_ids.reshape(1, h * w // 4, output_img_channels)
+        img_ids = img_ids.repeat(batch_size, 1, 1)
+
+        # T5 encoder output
+        txt_context_length = 512
+        txt_dims_per_token = 4096
+        txt = torch.rand([1, txt_context_length, txt_dims_per_token], dtype=self.dtype)
+        txt = txt.repeat(batch_size, 1, 1)
+        txt_ids = torch.zeros(batch_size, txt.shape[1], output_img_channels)
+
         timesteps = torch.rand([batch_size], dtype=self.dtype)
-        y = torch.rand([batch_size, 768], dtype=self.dtype)
+
+        # CLIP text model output
+        y = make_rand_torch([1, 768], dtype=self.dtype)
+        y = y.repeat(batch_size, 1)
 
         args = tuple()
         kwargs = OrderedDict(
@@ -217,10 +242,26 @@ def sample_inputs(
         )
 
         if self.guidance:
-            kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype)
+            kwargs["guidance"] = torch.full([batch_size], 3.5, dtype=self.dtype)
 
         return args, kwargs
 
+    def _get_noise(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+    ):
+        return torch.randn(
+            batch_size,
+            16,
+            # allow for packing
+            2 * math.ceil(height / 16),
+            2 * math.ceil(width / 16),
+            dtype=dtype,
+        )
+
     def _deduce_dtype(self) -> torch.dtype:
         dtype = self.theta("img_in.weight").dtype
         assert (
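Note: for reference, a standalone sketch of the shapes these new sample inputs take on, following the packing logic above (the 1024x1024 defaults and the 16-channel latent come straight from the diff; the script itself is illustrative and not part of the change):

```python
import math

import torch

batch_size = 2
height = width = 1024  # hardcoded output size from sample_inputs above

# _get_noise: 16 latent channels, spatial dims padded to a multiple of 2
# of the 16x-downscaled image ("allow for packing").
img = torch.randn(
    batch_size, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16)
)
_, c, h, w = img.shape                      # (2, 16, 128, 128)

# Pack 2x2 latent patches into the sequence dimension.
img = img.reshape(batch_size, h * w // 4, c * 4)
print(img.shape)                            # torch.Size([2, 4096, 64])

# One (unused, row, col) id triple per packed patch.
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] += torch.arange(h // 2)[:, None]
img_ids[..., 2] += torch.arange(w // 2)[None, :]
img_ids = img_ids.reshape(1, h * w // 4, 3).repeat(batch_size, 1, 1)
print(img_ids.shape)                        # torch.Size([2, 4096, 3])
```

This replaces the old fully random `img`/`img_ids` placeholders with inputs whose sequence length and positional ids are actually consistent with each other.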
26 changes: 15 additions & 11 deletions sharktank/sharktank/models/flux/testing.py
@@ -216,17 +216,8 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
     return Theta(tensor_dict)
 
 
-def export_dev_random_single_layer(
-    dtype: torch.dtype,
-    mlir_output_path: PathLike,
-    parameters_output_path: PathLike,
-    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
-):
-    rng_state = torch.get_rng_state()
-    torch.random.manual_seed(12345)
-
-    dtype = torch.bfloat16
-    params = FluxParams(
+def make_dev_single_layer_config():
+    return FluxParams(
         in_channels=64,
         out_channels=64,
         vec_in_dim=768,
@@ -241,6 +232,19 @@ def export_dev_random_single_layer(
         qkv_bias=True,
         guidance_embed=True,
     )
+
+
+def export_dev_random_single_layer(
+    dtype: torch.dtype,
+    mlir_output_path: PathLike,
+    parameters_output_path: PathLike,
+    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
+):
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    dtype = torch.bfloat16
+    params = make_dev_single_layer_config()
     theta = make_random_theta(params, dtype)
     flux = FluxModelV1(
         theta=theta,
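The extraction of `make_dev_single_layer_config` lets tests build the single-layer dev model without going through export. A hypothetical usage sketch (import paths assumed from the repository layout; the `params=` keyword on `FluxModelV1` is an assumption, since the constructor call is truncated in this diff):

```python
import torch

from sharktank.models.flux.flux import FluxModelV1
from sharktank.models.flux.testing import (
    make_dev_single_layer_config,
    make_random_theta,
)

# Build the same toy single-layer config the exporter uses, then a model
# with random bfloat16 weights, without touching MLIR/parameter export.
params = make_dev_single_layer_config()
theta = make_random_theta(params, torch.bfloat16)
model = FluxModelV1(theta=theta, params=params)  # params kwarg assumed
```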
35 changes: 26 additions & 9 deletions sharktank/sharktank/utils/iree.py
@@ -119,6 +119,24 @@ def bfloat16_device_array_to_torch(
     return torch.tensor(device_array.to_host())
 
 
+def torch_tensor_to_device_array(
+    tensor: torch.Tensor, device: iree.runtime.HalDevice
+) -> iree.runtime.DeviceArray:
+    if tensor.dtype == torch.bfloat16:
+        tensor_as_int16 = tensor.view(dtype=torch.int16)
+        device_array_as_int16 = iree.runtime.asdevicearray(
+            device, unbox_tensor(tensor_as_int16).to("cpu").numpy()
+        )
+        buffer_view = iree.runtime.HalBufferView(
+            buffer=device_array_as_int16._buffer_view.get_buffer(),
+            shape=device_array_as_int16._buffer_view.shape,
+            element_type=iree.runtime.HalElementType.BFLOAT_16,
+        )
+        return iree.runtime.DeviceArray(device, buffer_view)
+
+    return iree.runtime.asdevicearray(device, unbox_tensor(tensor).to("cpu").numpy())
+
+
 def run_iree_module_function(
     module: iree.runtime.VmModule,
     vm_context: iree.runtime.VmContext,
@@ -180,11 +198,7 @@ def prepare_iree_module_function_args(
                 ]
             )
         elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)):
-            res.append(
-                iree.runtime.asdevicearray(
-                    devices[0], unbox_tensor(arg).to("cpu").numpy()
-                )
-            )
+            res.append(torch_tensor_to_device_array(arg, devices[0]))
         else:
             assert isinstance(arg, collections.abc.Sequence)
             res.extend(prepare_iree_module_function_args(arg, devices))
@@ -200,24 +214,27 @@ def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]:
 def call_torch_module_function(
     module: torch.nn.Module,
     function_name: str,
-    kwargs: OrderedDict,
+    args: Optional[tuple[AnyTensor]] = None,
+    kwargs: Optional[OrderedDict] = None,
     trace_path_prefix: Optional[str] = None,
 ):
     """Call a torch module function with optional tracing.
     For tracing the arguments/results are flattened to match IREE's signature."""
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else OrderedDict()
     assert isinstance(
         kwargs, OrderedDict
     ), "Make sure when flattening the order is preserved"
     if trace_path_prefix is not None:
-        flat_args = flatten_for_iree_signature(kwargs)
+        flat_args = flatten_for_iree_signature([args, kwargs])
         for i, arg in enumerate(flat_args):
             np.save(
                 f"{trace_path_prefix}{function_name}_arg{i}.npy",
                 promote_bfloat16_to_float32(arg.to("cpu")).numpy(),
             )
-    res = getattr(module, function_name)(**kwargs)
+    res = getattr(module, function_name)(*args, **kwargs)
     if trace_path_prefix is not None:
-        flat_args = flatten_for_iree_signature(kwargs)
+        flat_args = flatten_for_iree_signature([args, kwargs])
         for i, arg in enumerate(flat_args):
             np.save(
                 f"{trace_path_prefix}{function_name}_arg{i}.npy",
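Two notes on the `iree.py` changes. First, `torch_tensor_to_device_array` works around numpy's lack of a bfloat16 dtype: the tensor's bits are copied host-side as int16, then the resulting device buffer is rewrapped with a BFLOAT_16 element type. A minimal round-trip sketch under that reading (the device URI is illustrative; `_buffer_view` is the same private attribute the diff itself relies on):

```python
import iree.runtime
import torch

device = iree.runtime.get_device("local-task")  # any HAL device works
t = torch.randn(2, 3, dtype=torch.bfloat16)

# numpy cannot hold bfloat16, so ship the raw bits as int16...
arr_i16 = iree.runtime.asdevicearray(device, t.view(dtype=torch.int16).numpy())

# ...then retag the same underlying buffer as bfloat16 on the IREE side.
bf16_view = iree.runtime.HalBufferView(
    buffer=arr_i16._buffer_view.get_buffer(),
    shape=arr_i16._buffer_view.shape,
    element_type=iree.runtime.HalElementType.BFLOAT_16,
)
arr_bf16 = iree.runtime.DeviceArray(device, bf16_view)
```

Second, `call_torch_module_function` now accepts positional arguments alongside the keyword dict and traces both. A hypothetical call (variable names are placeholders):

```python
from collections import OrderedDict

call_torch_module_function(
    module=model,
    function_name="forward",
    args=(img,),                             # new positional-args path
    kwargs=OrderedDict(timesteps=timesteps),
    trace_path_prefix="/tmp/flux_trace_",    # saves *_argN.npy per input
)
```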