Add xfail tests for single layer Flux transformer that verify IREE results against Torch (#741)

The test accuracy is not of sufficient quality and needs further
investigation.
The tests compare IREE bf16 and f32 results against Torch f32.

Refactor the sample input generation so that it produces noise images
for a final size of 1024x1024 instead of 512x512.

Remove an unused, duplicated function for random Theta generation.
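
For context, a minimal, self-contained sketch of the xfail comparison pattern (illustrative only, not the test code from this commit; the real tests run the Flux transformer through IREE rather than a plain bf16 cast):

import pytest
import torch


@pytest.mark.xfail(reason="Accuracy is not of sufficient quality; see #741.")
def test_bf16_result_matches_f32_reference():
    torch.manual_seed(12345)
    model = torch.nn.Linear(64, 64)
    x = torch.rand(4, 64)
    expected = model(x)  # Torch f32 reference
    # Stand-in for the IREE bf16 path: run the same module in bf16.
    actual = model.to(torch.bfloat16)(x.to(torch.bfloat16)).to(torch.float32)
    # Zero tolerance exposes the precision gap, so the test xfails.
    torch.testing.assert_close(actual, expected, atol=0.0, rtol=0.0)
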
sogartar authored and eagarvey-amd committed Jan 8, 2025
1 parent f2c0016 commit 33dc998
Showing 4 changed files with 243 additions and 219 deletions.
55 changes: 48 additions & 7 deletions sharktank/sharktank/models/flux/flux.py
@@ -21,6 +21,7 @@
 from ...layers import *
 from ...types import *
 from ...utils.create_cache import *
+from ...utils.testing import make_rand_torch
 from ... import ops
 
 __all__ = [
@@ -196,13 +197,37 @@ def sample_inputs(
         if not (function is None or function == "forward"):
             raise ValueError(f'Only function "forward" is supported. Got "{function}"')
 
-        # TODO: do not hardcode these but derive the required shapes from the config.
-        img = torch.rand([batch_size, 1024, 64], dtype=self.dtype)
-        img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32)
-        txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype)
-        txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32)
+        # The allowed range of these values is dependent on the model size.
+        # They will not work for all variants, specifically toy-sized models.
+        output_img_height = 1024
+        output_img_width = 1024
+        output_img_channels = 3
+
+        img = self._get_noise(
+            batch_size, output_img_height, output_img_width, self.dtype
+        )
+
+        _, c, h, w = img.shape
+        img = img.reshape(batch_size, h * w // 4, c * 4)
+
+        img_ids = torch.zeros(h // 2, w // 2, output_img_channels)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = img_ids.reshape(1, h * w // 4, output_img_channels)
+        img_ids = img_ids.repeat(batch_size, 1, 1)
+
+        # T5 encoder output
+        txt_context_length = 512
+        txt_dims_per_token = 4096
+        txt = torch.rand([1, txt_context_length, txt_dims_per_token], dtype=self.dtype)
+        txt = txt.repeat(batch_size, 1, 1)
+        txt_ids = torch.zeros(batch_size, txt.shape[1], output_img_channels)
+
         timesteps = torch.rand([batch_size], dtype=self.dtype)
-        y = torch.rand([batch_size, 768], dtype=self.dtype)
+
+        # CLIP text model output
+        y = make_rand_torch([1, 768], dtype=self.dtype)
+        y = y.repeat(batch_size, 1)
 
         args = tuple()
         kwargs = OrderedDict(
@@ -217,10 +242,26 @@
         )
 
         if self.guidance:
-            kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype)
+            kwargs["guidance"] = torch.full([batch_size], 3.5, dtype=self.dtype)
 
         return args, kwargs
 
+    def _get_noise(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+    ):
+        return torch.randn(
+            batch_size,
+            16,
+            # allow for packing
+            2 * math.ceil(height / 16),
+            2 * math.ceil(width / 16),
+            dtype=dtype,
+        )
+
     def _deduce_dtype(self) -> torch.dtype:
         dtype = self.theta("img_in.weight").dtype
         assert (
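
As a sanity check on the new sample inputs (an illustrative calculation, not part of the diff): for the 1024x1024 target, _get_noise yields a [batch, 16, 128, 128] latent, and the packing reshape produces 4096 tokens of 64 channels, versus the 1024 tokens the old hardcoded 512x512 shapes implied.

import math
import torch

batch_size, height, width = 1, 1024, 1024
latent = torch.randn(
    batch_size,
    16,
    2 * math.ceil(height / 16),  # 128
    2 * math.ceil(width / 16),  # 128
)
# Same packing reshape as sample_inputs: 128 * 128 / 4 = 4096 tokens,
# each with 16 * 4 = 64 channels.
_, c, h, w = latent.shape
packed = latent.reshape(batch_size, h * w // 4, c * 4)
assert packed.shape == (batch_size, 4096, 64)
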
26 changes: 15 additions & 11 deletions sharktank/sharktank/models/flux/testing.py
@@ -216,17 +216,8 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
     return Theta(tensor_dict)
 
 
-def export_dev_random_single_layer(
-    dtype: torch.dtype,
-    mlir_output_path: PathLike,
-    parameters_output_path: PathLike,
-    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
-):
-    rng_state = torch.get_rng_state()
-    torch.random.manual_seed(12345)
-
-    dtype = torch.bfloat16
-    params = FluxParams(
+def make_dev_single_layer_config():
+    return FluxParams(
         in_channels=64,
         out_channels=64,
         vec_in_dim=768,
@@ -241,6 +232,19 @@
         qkv_bias=True,
         guidance_embed=True,
     )
+
+
+def export_dev_random_single_layer(
+    dtype: torch.dtype,
+    mlir_output_path: PathLike,
+    parameters_output_path: PathLike,
+    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
+):
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    dtype = torch.bfloat16
+    params = make_dev_single_layer_config()
     theta = make_random_theta(params, dtype)
     flux = FluxModelV1(
         theta=theta,
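
A usage sketch for the refactored export helper (the module path matches the file above; the output paths are placeholder assumptions). Note that the function currently pins dtype to bfloat16 internally regardless of the argument:

import torch

from sharktank.models.flux.testing import export_dev_random_single_layer

export_dev_random_single_layer(
    dtype=torch.bfloat16,
    mlir_output_path="/tmp/flux_transformer_single_layer.mlir",
    parameters_output_path="/tmp/flux_transformer_single_layer.irpa",
)
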
35 changes: 26 additions & 9 deletions sharktank/sharktank/utils/iree.py
@@ -119,6 +119,24 @@ def bfloat16_device_array_to_torch(
     return torch.tensor(device_array.to_host())
 
 
+def torch_tensor_to_device_array(
+    tensor: torch.Tensor, device: iree.runtime.HalDevice
+) -> iree.runtime.DeviceArray:
+    if tensor.dtype == torch.bfloat16:
+        tensor_as_int16 = tensor.view(dtype=torch.int16)
+        device_array_as_int16 = iree.runtime.asdevicearray(
+            device, unbox_tensor(tensor_as_int16).to("cpu").numpy()
+        )
+        buffer_view = iree.runtime.HalBufferView(
+            buffer=device_array_as_int16._buffer_view.get_buffer(),
+            shape=device_array_as_int16._buffer_view.shape,
+            element_type=iree.runtime.HalElementType.BFLOAT_16,
+        )
+        return iree.runtime.DeviceArray(device, buffer_view)
+
+    return iree.runtime.asdevicearray(device, unbox_tensor(tensor).to("cpu").numpy())
+
+
 def run_iree_module_function(
     module: iree.runtime.VmModule,
     vm_context: iree.runtime.VmContext,
@@ -180,11 +198,7 @@ def prepare_iree_module_function_args(
                 ]
             )
         elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)):
-            res.append(
-                iree.runtime.asdevicearray(
-                    devices[0], unbox_tensor(arg).to("cpu").numpy()
-                )
-            )
+            res.append(torch_tensor_to_device_array(arg, devices[0]))
         else:
             assert isinstance(arg, collections.abc.Sequence)
             res.extend(prepare_iree_module_function_args(arg, devices))
@@ -200,24 +214,27 @@ def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]:
 def call_torch_module_function(
     module: torch.nn.Module,
     function_name: str,
-    kwargs: OrderedDict,
+    args: Optional[tuple[AnyTensor]] = None,
+    kwargs: Optional[OrderedDict] = None,
     trace_path_prefix: Optional[str] = None,
 ):
     """Call a torch module function with optional tracing.
     For tracing the arguments/results are flattened to match IREE's signature."""
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else OrderedDict()
     assert isinstance(
         kwargs, OrderedDict
     ), "Make sure when flattening the order is preserved"
     if trace_path_prefix is not None:
-        flat_args = flatten_for_iree_signature(kwargs)
+        flat_args = flatten_for_iree_signature([args, kwargs])
         for i, arg in enumerate(flat_args):
             np.save(
                 f"{trace_path_prefix}{function_name}_arg{i}.npy",
                 promote_bfloat16_to_float32(arg.to("cpu")).numpy(),
             )
-    res = getattr(module, function_name)(**kwargs)
+    res = getattr(module, function_name)(*args, **kwargs)
     if trace_path_prefix is not None:
-        flat_args = flatten_for_iree_signature(kwargs)
+        flat_args = flatten_for_iree_signature([args, kwargs])
         for i, arg in enumerate(flat_args):
             np.save(
                 f"{trace_path_prefix}{function_name}_arg{i}.npy",
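
The new torch_tensor_to_device_array works around NumPy's lack of a bfloat16 dtype: the tensor's bytes are uploaded as int16 and the resulting buffer view is reinterpreted with a bf16 element type, so no values are converted. A round-trip sketch (assuming IREE's local-task driver is available and that bfloat16_device_array_to_torch, the existing counterpart in this module, inverts the reinterpretation exactly):

import iree.runtime
import torch

from sharktank.utils.iree import (
    bfloat16_device_array_to_torch,
    torch_tensor_to_device_array,
)

device = iree.runtime.get_device("local-task")
t = torch.randn(2, 3, dtype=torch.bfloat16)
arr = torch_tensor_to_device_array(t, device)  # bytes cross as int16, viewed as bf16
back = bfloat16_device_array_to_torch(arr)
assert torch.equal(back, t)  # a pure reinterpretation should round-trip exactly

The call_torch_module_function change is complementary: the Torch reference can now take the same positional args tuple that prepare_iree_module_function_args flattens for IREE, keeping the traced _arg{i}.npy files aligned between the Torch and IREE runs.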
[Diff for the fourth changed file did not load.]