From 33dc99897f4d196e0cc0c3ee6d53a91a006124ba Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Fri, 3 Jan 2025 10:50:34 -0800
Subject: [PATCH] Add xfail tests for single layer Flux transformer that verify
 IREE results against Torch (#741)

The test accuracy is not yet sufficient and needs further investigation.
The tests compare IREE bf16 and f32 results against Torch f32.

Refactor the sample input generation to produce noise images for a final
image size of 1024x1024 instead of 512x512.

Remove an unused, duplicated function for random Theta generation.
---
 sharktank/sharktank/models/flux/flux.py    |  55 +++-
 sharktank/sharktank/models/flux/testing.py |  26 +-
 sharktank/sharktank/utils/iree.py          |  35 ++-
 sharktank/tests/models/flux/flux_test.py   | 346 +++++++++------------
 4 files changed, 243 insertions(+), 219 deletions(-)

diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py
index 531083ae1..ea84f8f04 100644
--- a/sharktank/sharktank/models/flux/flux.py
+++ b/sharktank/sharktank/models/flux/flux.py
@@ -21,6 +21,7 @@
 from ...layers import *
 from ...types import *
 from ...utils.create_cache import *
+from ...utils.testing import make_rand_torch
 from ... import ops
 
 __all__ = [
@@ -196,13 +197,37 @@ def sample_inputs(
         if not (function is None or function == "forward"):
             raise ValueError(f'Only function "forward" is supported. Got "{function}"')
 
-        # TODO: do not hardcode these but derive the required shapes from the config.
-        img = torch.rand([batch_size, 1024, 64], dtype=self.dtype)
-        img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32)
-        txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype)
-        txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32)
+        # The allowed range of these values depends on the model size.
+        # They will not work for all variants, specifically toy-sized models.
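+        # As a worked example (derivable from _get_noise and the packing
+        # reshape below): a 1024x1024 output yields latent noise of shape
+        # [batch_size, 16, 128, 128], which packs into 2x2 patches as an
+        # img of shape [batch_size, 4096, 64]. The previously hardcoded
+        # [batch_size, 1024, 64] corresponded to a 512x512 output.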
+ output_img_height = 1024 + output_img_width = 1024 + output_img_channels = 3 + + img = self._get_noise( + batch_size, output_img_height, output_img_width, self.dtype + ) + + _, c, h, w = img.shape + img = img.reshape(batch_size, h * w // 4, c * 4) + + img_ids = torch.zeros(h // 2, w // 2, output_img_channels) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = img_ids.reshape(1, h * w // 4, output_img_channels) + img_ids = img_ids.repeat(batch_size, 1, 1) + + # T5 encoder output + txt_context_length = 512 + txt_dims_per_token = 4096 + txt = torch.rand([1, txt_context_length, txt_dims_per_token], dtype=self.dtype) + txt = txt.repeat(batch_size, 1, 1) + txt_ids = torch.zeros(batch_size, txt.shape[1], output_img_channels) + timesteps = torch.rand([batch_size], dtype=self.dtype) - y = torch.rand([batch_size, 768], dtype=self.dtype) + + # CLIP text model output + y = make_rand_torch([1, 768], dtype=self.dtype) + y = y.repeat(batch_size, 1) args = tuple() kwargs = OrderedDict( @@ -217,10 +242,26 @@ def sample_inputs( ) if self.guidance: - kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype) + kwargs["guidance"] = torch.full([batch_size], 3.5, dtype=self.dtype) return args, kwargs + def _get_noise( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + ): + return torch.randn( + batch_size, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + dtype=dtype, + ) + def _deduce_dtype(self) -> torch.dtype: dtype = self.theta("img_in.weight").dtype assert ( diff --git a/sharktank/sharktank/models/flux/testing.py b/sharktank/sharktank/models/flux/testing.py index e0354ff7b..6445366b2 100644 --- a/sharktank/sharktank/models/flux/testing.py +++ b/sharktank/sharktank/models/flux/testing.py @@ -216,17 +216,8 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype): return Theta(tensor_dict) -def export_dev_random_single_layer( - dtype: torch.dtype, - mlir_output_path: PathLike, - parameters_output_path: PathLike, - batch_sizes: list[int] = flux_transformer_default_batch_sizes, -): - rng_state = torch.get_rng_state() - torch.random.manual_seed(12345) - - dtype = torch.bfloat16 - params = FluxParams( +def make_dev_single_layer_config(): + return FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, @@ -241,6 +232,19 @@ def export_dev_random_single_layer( qkv_bias=True, guidance_embed=True, ) + + +def export_dev_random_single_layer( + dtype: torch.dtype, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + rng_state = torch.get_rng_state() + torch.random.manual_seed(12345) + + dtype = torch.bfloat16 + params = make_dev_single_layer_config() theta = make_random_theta(params, dtype) flux = FluxModelV1( theta=theta, diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index a9097cf06..1ed47f45d 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -119,6 +119,24 @@ def bfloat16_device_array_to_torch( return torch.tensor(device_array.to_host()) +def torch_tensor_to_device_array( + tensor: torch.Tensor, device: iree.runtime.HalDevice +) -> iree.runtime.DeviceArray: + if tensor.dtype == torch.bfloat16: + tensor_as_int16 = tensor.view(dtype=torch.int16) + device_array_as_int16 = iree.runtime.asdevicearray( + device, unbox_tensor(tensor_as_int16).to("cpu").numpy() + ) + buffer_view = 
iree.runtime.HalBufferView( + buffer=device_array_as_int16._buffer_view.get_buffer(), + shape=device_array_as_int16._buffer_view.shape, + element_type=iree.runtime.HalElementType.BFLOAT_16, + ) + return iree.runtime.DeviceArray(device, buffer_view) + + return iree.runtime.asdevicearray(device, unbox_tensor(tensor).to("cpu").numpy()) + + def run_iree_module_function( module: iree.runtime.VmModule, vm_context: iree.runtime.VmContext, @@ -180,11 +198,7 @@ def prepare_iree_module_function_args( ] ) elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): - res.append( - iree.runtime.asdevicearray( - devices[0], unbox_tensor(arg).to("cpu").numpy() - ) - ) + res.append(torch_tensor_to_device_array(arg, devices[0])) else: assert isinstance(arg, collections.abc.Sequence) res.extend(prepare_iree_module_function_args(arg, devices)) @@ -200,24 +214,27 @@ def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]: def call_torch_module_function( module: torch.nn.Module, function_name: str, - kwargs: OrderedDict, + args: Optional[tuple[AnyTensor]] = None, + kwargs: Optional[OrderedDict] = None, trace_path_prefix: Optional[str] = None, ): """Call a torch module function with optional tracing. For tracing the arguments/results are flattened to match IREE's signature.""" + args = args if args is not None else tuple() + kwargs = kwargs if kwargs is not None else OrderedDict() assert isinstance( kwargs, OrderedDict ), "Make sure when flattening the order is preserved" if trace_path_prefix is not None: - flat_args = flatten_for_iree_signature(kwargs) + flat_args = flatten_for_iree_signature([args, kwargs]) for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", promote_bfloat16_to_float32(arg.to("cpu")).numpy(), ) - res = getattr(module, function_name)(**kwargs) + res = getattr(module, function_name)(*args, **kwargs) if trace_path_prefix is not None: - flat_args = flatten_for_iree_signature(kwargs) + flat_args = flatten_for_iree_signature([args, kwargs]) for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index ee8e6d82e..34dedd4f1 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -5,216 +5,63 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import logging +import functools import unittest import torch import pytest -from sharktank.models.flux.flux import ( - FluxModelV1, - FluxParams, -) +import iree.compiler +from collections import OrderedDict from sharktank.models.flux.export import ( export_flux_transformer_from_hugging_face, + export_flux_transformer, ) -from sharktank.models.flux.testing import export_dev_random_single_layer -import sharktank.ops as ops -from sharktank.layers.testing import ( - make_rand_torch, +from sharktank.models.flux.testing import ( + export_dev_random_single_layer, + make_dev_single_layer_config, + make_random_theta, ) -from sharktank.types.tensors import DefaultPrimitiveTensor -from sharktank.types.theta import Dataset, Theta +from sharktank.models.flux.flux import FluxModelV1 from sharktank.utils.testing import TempDirTestBase -from sharktank.utils.hf_datasets import get_dataset +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +from sharktank import ops 
+from sharktank.transforms.dataset import set_float_dtype logging.basicConfig(level=logging.DEBUG) with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") - -# TODO: Refactor this to a function that generates random toy weights, possibly -# to another file -in_channels = 64 -in_channels2 = 128 -hidden_size = 3072 -mlp_ratio = 4.0 -mlp_hidden_size = int((mlp_ratio - 1) * hidden_size) -mlp_hidden_size2 = int(mlp_ratio * hidden_size) -mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size) -mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size) -mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size) -context_in_dim = 4096 -time_dim = 256 -vec_dim = 768 -patch_size = 1 -out_channels = 64 - - -def make_random_theta(dtype: torch.dtype): - return Theta( - { - "img_in.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, in_channels), dtype=dtype) - ), - "img_in.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "txt_in.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, context_in_dim), dtype=dtype) - ), - "txt_in.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "time_in.in_layer.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, time_dim), dtype=dtype) - ), - "time_in.in_layer.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "time_in.out_layer.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) - ), - "time_in.out_layer.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "vector_in.in_layer.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) - ), - "vector_in.in_layer.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "vector_in.out_layer.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) - ), - "vector_in.out_layer.bias": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "double_blocks.0.img_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) - ), - "double_blocks.0.img_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size,), dtype=dtype) - ), - "double_blocks.0.img_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) - ), - "double_blocks.0.img_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size2), dtype=dtype) - ), - "double_blocks.0.img_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) - ), - "double_blocks.0.img_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size), dtype=dtype) - ), - "double_blocks.0.img_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) - ), - "double_blocks.0.img_mod.lin.bias": DefaultPrimitiveTensor( - 
data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) - ), - "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) - ), - "double_blocks.0.txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "double_blocks.0.txt_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) - ), - "double_blocks.0.txt_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size,), dtype=dtype) - ), - "double_blocks.0.txt_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) - ), - "double_blocks.0.txt_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size2), dtype=dtype) - ), - "double_blocks.0.txt_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) - ), - "double_blocks.0.txt_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size), dtype=dtype) - ), - "double_blocks.0.txt_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) - ), - "double_blocks.0.txt_mod.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) - ), - "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) - ), - "single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( # - data=make_rand_torch((in_channels2,), dtype=dtype) - ), - "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size,), dtype=dtype) - ), - "single_blocks.0.attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) - ), - "single_blocks.0.linear1.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size5,), dtype=dtype) - ), - "single_blocks.0.linear1.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size5, hidden_size), dtype=dtype) - ), - "single_blocks.0.linear2.bias": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size), dtype=dtype) - ), - "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( - data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) - ), - "single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size,), dtype=dtype) - ), - "single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) - ), - "final_layer.linear.weight": DefaultPrimitiveTensor( # - data=make_rand_torch( - (patch_size * patch_size * out_channels, hidden_size), dtype=dtype - ) - ), - "final_layer.linear.bias": DefaultPrimitiveTensor( # - data=make_rand_torch( - (patch_size * patch_size * out_channels,), dtype=dtype - ) - ), - "final_layer.adaLN_modulation.1.weight": DefaultPrimitiveTensor( # - data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) - ), - "final_layer.adaLN_modulation.1.bias": DefaultPrimitiveTensor( # - 
data=make_rand_torch((hidden_size * 2,), dtype=dtype)
-            ),
-        }
-    )
+iree_compile_flags = [
+    "--iree-hal-target-device=hip",
+    "--iree-hip-target=gfx942",
+    "--iree-opt-const-eval=false",
+    "--iree-opt-strip-assertions=true",
+    "--iree-global-opt-propagate-transposes=true",
+    "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
+    "--iree-dispatch-creation-enable-aggressive-fusion=true",
+    "--iree-opt-aggressively-propagate-transposes=true",
+    "--iree-opt-outer-dim-concat=true",
+    "--iree-vm-target-truncate-unsupported-floats",
+    "--iree-llvmgpu-enable-prefetch=true",
+    "--iree-opt-data-tiling=false",
+    "--iree-codegen-gpu-native-math-precision=true",
+    "--iree-codegen-llvmgpu-use-vector-distribution",
+    "--iree-hip-waves-per-eu=2",
+    "--iree-execution-model=async-external",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
+]
 
 
 class FluxTest(TempDirTestBase):
     def setUp(self):
         super().setUp()
         torch.manual_seed(12345)
-        self.hidden_size = 3072
-        self.num_heads = 24
-        self.batch_size = 5
 
     def testExportDevRandomSingleLayerBf16(self):
         export_dev_random_single_layer(
@@ -224,6 +71,121 @@ def testExportDevRandomSingleLayerBf16(self):
             parameters_output_path=self._temp_dir / "parameters.irpa",
         )
 
+    def runCompareIreeAgainstTorchEager(
+        self, reference_model: FluxModelV1, target_dtype: torch.dtype
+    ):
+        target_theta = reference_model.theta.transform(
+            functools.partial(set_float_dtype, dtype=target_dtype)
+        )
+        target_torch_model = FluxModelV1(
+            theta=target_theta,
+            params=reference_model.params,
+        )
+
+        mlir_path = self._temp_dir / "model.mlir"
+        parameters_path = self._temp_dir / "parameters.irpa"
+        batch_size = 1
+        batch_sizes = [batch_size]
+        export_flux_transformer(
+            target_torch_model,
+            mlir_output_path=mlir_path,
+            parameters_output_path=parameters_path,
+            batch_sizes=batch_sizes,
+        )
+
+        iree_module_path = self._temp_dir / "model.vmfb"
+        iree.compiler.compile_file(
+            mlir_path,
+            output_file=iree_module_path,
+            extra_args=iree_compile_flags,
+        )
+
+        target_input_args, target_input_kwargs = target_torch_model.sample_inputs(
+            batch_size
+        )
+
+        def convert_target_to_reference_dtype(t: torch.Tensor) -> torch.Tensor:
+            if t.dtype == target_dtype:
+                return t.to(dtype=reference_model.dtype)
+            return t
+
+        reference_input_args = [
+            convert_target_to_reference_dtype(t) for t in target_input_args
+        ]
+        reference_input_kwargs = OrderedDict(
+            (k, convert_target_to_reference_dtype(t))
+            for k, t in target_input_kwargs.items()
+        )
+
+        reference_result_dict = call_torch_module_function(
+            module=reference_model,
+            function_name="forward",
+            args=reference_input_args,
+            kwargs=reference_input_kwargs,
+        )
+        expected_outputs = flatten_for_iree_signature(reference_result_dict)
+
+        iree_devices = get_iree_devices(driver="hip", device_count=1)
+        iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
+            module_path=iree_module_path,
+            devices=iree_devices,
+            parameters_path=parameters_path,
+        )
+        iree_args = prepare_iree_module_function_args(
+            args=flatten_for_iree_signature([target_input_args, target_input_kwargs]),
+            devices=iree_devices,
+        )
+
+        iree_result = iree_to_torch(
+            *run_iree_module_function(
+                module=iree_module,
+                vm_context=iree_vm_context,
+                args=iree_args,
+                driver="hip",
+                function_name=f"forward_bs{batch_size}",
+            )
+        )
+        actual_outputs = [
+            ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
+            for i in range(len(expected_outputs))
+        ]
+        # TODO: figure out a good 
metric. Probably per pixel comparison would be good + # enough. + torch.testing.assert_close(actual_outputs, expected_outputs) + + def runCompareDevRandomSingleLayerIreeAgainstTorchEager( + self, reference_dtype: torch.dtype, target_dtype: torch.dtype + ): + config = make_dev_single_layer_config() + + reference_theta = make_random_theta(config, reference_dtype) + reference_theta.rename_tensors_to_paths() + reference_model = FluxModelV1( + theta=reference_theta, + params=config, + ) + self.runCompareIreeAgainstTorchEager(reference_model, target_dtype) + + @pytest.mark.xfail( + raises=AssertionError, + reason="Accuracy is not good enough. The observed absolute error is 8976.53.", + ) + @with_flux_data + def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self): + self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( + reference_dtype=torch.float32, target_dtype=torch.bfloat16 + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason="Accuracy is probably not good enough. The observed absolute error is 73.25.", + ) + @with_flux_data + def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self): + self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( + reference_dtype=torch.float32, target_dtype=torch.float32 + ) + @with_flux_data def testExportSchnellTransformerFromHuggingFace(self): export_flux_transformer_from_hugging_face(
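
--
For reference, the comparison flow exercised by the new tests can also be run
in eager mode without IREE. This is a minimal sketch built only from helpers
this patch adds; calling forward(*args, **kwargs) mirrors how
call_torch_module_function invokes the model, and the rest is an assumption
rather than part of the patch:

    import torch
    from sharktank.models.flux.flux import FluxModelV1
    from sharktank.models.flux.testing import (
        make_dev_single_layer_config,
        make_random_theta,
    )

    # Single-layer dev config with random weights, as in the new tests.
    config = make_dev_single_layer_config()
    theta = make_random_theta(config, torch.float32)
    model = FluxModelV1(theta=theta, params=config)

    # sample_inputs() now derives shapes for a 1024x1024 output image.
    args, kwargs = model.sample_inputs(1)
    output = model.forward(*args, **kwargs)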