From a1e632a280ae160f53ffff4ecb0c7ec87b83dddf Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin@amd.com>
Date: Thu, 2 Jan 2025 09:12:50 -0800
Subject: [PATCH] Add CLI script exporting CLIP Toy model IREE test data (#672)

This provides an easy way to export test data that will be used in IREE
to guard against regressions.

E.g.
```
python -m sharktank.models.clip.export_toy_text_model_iree_test_data \
  --output-dir=clip_toy_text_model
```
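
The same data can also be exported programmatically through the new helpers in
sharktank/sharktank/models/clip/testing.py. A minimal sketch, assuming
ClipTextModel is importable from sharktank.models.clip.clip (mirroring the
import used inside testing.py) and using illustrative output file names rather
than the paths the CLI produces:

```
import torch

from sharktank.models.clip.clip import ClipTextModel
from sharktank.models.clip.testing import (
    clip_toy_text_model_config,
    export_clip_text_model_iree_test_data,
    make_clip_text_model_random_theta,
    make_random_input_token_sequences,
)

# Build a toy-sized reference model with random weights.
config = clip_toy_text_model_config(torch.float32)
model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)
input_ids = make_random_input_token_sequences(batch_size=4, config=config)

# Export MLIR, IREE parameters, the sample input, and the expected output
# for a single bfloat16 target variant. Output paths here are placeholders.
export_clip_text_model_iree_test_data(
    reference_model=model,
    target_dtype=torch.bfloat16,
    input_ids=input_ids,
    target_mlir_output_path="clip_toy_text_model_bf16.mlir",
    target_iree_parameters_output_path="clip_toy_text_model_bf16_parameters.irpa",
    input_ids_output_path="input_ids.irpa",
    expected_last_hidden_state_output_path="expected_last_hidden_state_f32.irpa",
)
```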

Refactor some of the existing tests to reuse the new export logic.
---
 sharktank/sharktank/models/clip/export.py     |  25 ++-
 .../export_toy_text_model_iree_test_data.py   |  29 ++++
 sharktank/sharktank/models/clip/testing.py    | 161 +++++++++++++++++-
 sharktank/sharktank/utils/io.py               |  25 ++-
 sharktank/tests/models/clip/clip_test.py      | 145 ++++++----------
 5 files changed, 280 insertions(+), 105 deletions(-)
 create mode 100644 sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py

diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py
index 3cae3f4c4..95dbdacad 100644
--- a/sharktank/sharktank/models/clip/export.py
+++ b/sharktank/sharktank/models/clip/export.py
@@ -11,8 +11,8 @@
     CLIPEncoderLayer as HfCLIPEncoderLayer,
     CLIPEncoder as HfCLIPEncoder,
 )
-from os import PathLike
 import torch
+from os import PathLike
 
 from ...types.theta import Theta, Dataset, torch_module_to_theta
 from ...layers.configs import ClipTextConfig
@@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
     return Dataset(properties=model.config.to_properties(), root_theta=model.theta)
 
 
+def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike):
+    dataset = clip_text_model_to_dataset(model)
+    dataset.save(output_path)
+
+
 def export_clip_text_model_dataset_from_hugging_face(
-    model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
-    output_path: Union[str, PathLike],
+    model_or_name_or_path: Union[PathLike, transformers.CLIPTextModel],
+    output_path: PathLike,
     dtype: Optional[torch.dtype] = None,
 ):
     if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
@@ -99,3 +104,17 @@ def _(
 
     output = export(fxb, import_symbolic_shape_expressions=True)
     output.save_mlir(mlir_output_path)
+
+
+def export_clip_text_model_to_iree(
+    model: ClipTextModel,
+    batch_sizes: list[int],
+    mlir_output_path: PathLike,
+    parameters_output_path: PathLike,
+):
+    export_clip_text_model_iree_parameters(model, parameters_output_path)
+    export_clip_text_model_mlir(
+        model=parameters_output_path,
+        batch_sizes=batch_sizes,
+        mlir_output_path=mlir_output_path,
+    )
diff --git a/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
new file mode 100644
index 000000000..979bc3255
--- /dev/null
+++ b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
@@ -0,0 +1,29 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import ArgumentParser
+from typing import Optional
+from pathlib import Path
+
+from .testing import export_clip_toy_text_model_default_iree_test_data
+
+
+def main(args: Optional[list[str]] = None):
+    parser = ArgumentParser(
+        description=(
+            "Export test data for toy-sized CLIP text model."
+            " This program MLIR, parameters sample input and expected output."
+            " Exports float32 and bfloat16 model variants."
+            " The expected output is always in float32 precision."
+        )
+    )
+    parser.add_argument("--output-dir", type=str, default="clip_toy_text_model")
+    args = parser.parse_args(args=args)
+    export_clip_toy_text_model_default_iree_test_data(Path(args.output_dir))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py
index 87634c220..852da8a18 100644
--- a/sharktank/sharktank/models/clip/testing.py
+++ b/sharktank/sharktank/models/clip/testing.py
@@ -4,14 +4,167 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...layers.configs.llm_configs import ClipTextConfig
-from ...types.theta import Theta
-from .export import hugging_face_clip_text_model_to_theta
+import functools
 import torch
+from os import PathLike, makedirs
+from typing import Union, Optional
+from copy import copy
+from iree.turbine.aot.params import ParameterArchiveBuilder
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from .clip import ClipTextModel
+from ...types.theta import Theta, Dataset
+from ...types.tensors import dtype_to_serialized_short_name
+from ...utils.io import save_tensor_as_irpa
+from .export import (
+    clip_text_model_to_dataset,
+    hugging_face_clip_text_model_to_theta,
+    export_clip_text_model_to_iree,
+)
+from ...transforms.dataset import set_float_dtype
+
+
+def clip_toy_text_model_config(dtype: Optional[torch.dtype] = None) -> ClipTextConfig:
+    num_attention_heads = 5
+    vocab_size = 11
+    return ClipTextConfig(
+        vocab_size=vocab_size,
+        hidden_size=13 * num_attention_heads,
+        intermediate_size=7,
+        projection_dim=3,
+        num_attention_heads=num_attention_heads,
+        max_position_embeddings=17,
+        layer_norm_eps=1e-4,
+        num_hidden_layers=2,
+        bos_token_id=vocab_size - 2,
+        eos_token_id=vocab_size - 1,
+        dtype=dtype,
+    )
+
+
+def export_clip_toy_text_model_default_iree_test_data(output_dir: PathLike):
+    makedirs(output_dir, exist_ok=True)
+
+    # We want to always export the same data without interfering with the RNG state
+    # for the rest of the program.
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    reference_dtype = torch.float32
+    target_dtypes = [torch.float32, torch.bfloat16]
+    target_iree_parameters_output_paths = []
+    target_mlir_output_paths = []
+    batch_size = 4
+    for dtype in target_dtypes:
+        prefix = output_dir / dtype_to_serialized_short_name(dtype)
+        target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
+        target_mlir_output_paths.append(f"{prefix}.mlir")
+    call_prefix = output_dir / f"forward_bs{batch_size}"
+    input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
+    expected_last_hidden_state_output_path = (
+        f"{call_prefix}_expected_result0_last_hidden_state_"
+        f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
+    )
+    export_clip_toy_text_model_iree_test_data(
+        reference_dtype=reference_dtype,
+        target_dtypes=target_dtypes,
+        batch_size=batch_size,
+        input_ids_output_path=input_ids_output_path,
+        expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
+        target_iree_parameters_output_paths=target_iree_parameters_output_paths,
+        target_mlir_output_paths=target_mlir_output_paths,
+    )
+
+    torch.set_rng_state(rng_state)
+
+
+def export_clip_toy_text_model_iree_test_data(
+    reference_dtype: torch.dtype,
+    target_dtypes: list[torch.dtype],
+    batch_size: int,
+    target_iree_parameters_output_paths: list[PathLike],
+    target_mlir_output_paths: list[PathLike],
+    input_ids_output_path: PathLike,
+    expected_last_hidden_state_output_path: PathLike,
+):
+    reference_config = clip_toy_text_model_config(reference_dtype)
+    input_ids = make_random_input_token_sequences(
+        batch_size=batch_size, config=reference_config
+    )
+    reference_theta = make_clip_text_model_random_theta(reference_config)
+    reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
+    for i, (
+        target_dtype,
+        target_iree_parameters_output_path,
+        target_mlir_output_path,
+    ) in enumerate(
+        zip(
+            target_dtypes,
+            target_iree_parameters_output_paths,
+            target_mlir_output_paths,
+            strict=True,
+        )
+    ):
+        current_input_ids_output_path = None
+        current_expected_last_hidden_state_output_path = None
+        if i == 0:
+            current_input_ids_output_path = input_ids_output_path
+            current_expected_last_hidden_state_output_path = (
+                expected_last_hidden_state_output_path
+            )
+        export_clip_text_model_iree_test_data(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            target_iree_parameters_output_path=target_iree_parameters_output_path,
+            target_mlir_output_path=target_mlir_output_path,
+            input_ids_output_path=current_input_ids_output_path,
+            expected_last_hidden_state_output_path=current_expected_last_hidden_state_output_path,
+        )
+
+
+def export_clip_text_model_iree_test_data(
+    reference_model: ClipTextModel,
+    target_dtype: torch.dtype,
+    input_ids: torch.LongTensor,
+    target_mlir_output_path: PathLike,
+    target_iree_parameters_output_path: PathLike,
+    input_ids_output_path: Optional[PathLike] = None,
+    expected_last_hidden_state_output_path: Optional[PathLike] = None,
+):
+    batch_size = input_ids.shape[0]
+    reference_dataset = clip_text_model_to_dataset(reference_model)
+    target_config = copy(reference_model.config)
+    target_config.dtype = target_dtype
+    target_dataset = Dataset(
+        root_theta=reference_dataset.root_theta.transform(
+            functools.partial(set_float_dtype, dtype=target_dtype)
+        ),
+        properties=target_config.to_properties(),
+    )
+    target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
+    export_clip_text_model_to_iree(
+        target_model,
+        batch_sizes=[batch_size],
+        mlir_output_path=target_mlir_output_path,
+        parameters_output_path=target_iree_parameters_output_path,
+    )
+
+    if input_ids_output_path is not None:
+        save_tensor_as_irpa(input_ids, input_ids_output_path)
+
+    if expected_last_hidden_state_output_path is None:
+        return
+
+    expected_last_hidden_state = reference_model(input_ids=input_ids)[
+        "last_hidden_state"
+    ]
+    save_tensor_as_irpa(
+        expected_last_hidden_state, expected_last_hidden_state_output_path
+    )
 
 
 def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
-    from transformers import CLIPTextConfig as HfCLIPTextConfig
     from transformers import CLIPTextModel as HfCLIPTextModel
 
     hf_config = config.to_hugging_face_clip_text_model_config()
diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py
index ac2480846..12c9c002b 100644
--- a/sharktank/sharktank/utils/io.py
+++ b/sharktank/sharktank/utils/io.py
@@ -5,10 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from pathlib import Path
+import torch
+from os import PathLike
 
-from iree.turbine.aot import (
-    ParameterArchiveBuilder,
-)
+from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive
 
 
 class ShardedArchiveBuilder(ParameterArchiveBuilder):
@@ -49,3 +49,22 @@ def path_for_rank(path: Path, rank: int):
           /tmp/foobar.rank0.irpa
         """
         return path.with_suffix(f".rank{rank}{path.suffix}")
+
+
+def save_tensor_as_irpa(tensor: torch.Tensor, path: PathLike):
+    """Save a single tensor into an IRPA file."""
+    param_builder = ParameterArchiveBuilder()
+    param_builder.add_tensor("", tensor)
+    param_builder.save(path)
+
+
+def load_irpa_as_tensor(path: PathLike, **kwargs) -> torch.Tensor:
+    """Load a tensor from an IRPA file that holds only one tensor."""
+    params = ParameterArchive(path, **kwargs)
+    items = params.items()
+    if len(items) != 1:
+        raise ValueError(
+            f'Expected exactly one tensor in IRPA file "{path}",'
+            f" but found {len(items)} entries."
+        )
+    return items[0][1].as_tensor()
diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py
index 99af4ba6f..704333c90 100644
--- a/sharktank/tests/models/clip/clip_test.py
+++ b/sharktank/tests/models/clip/clip_test.py
@@ -8,6 +8,7 @@
 import functools
 import iree.compiler
 import os
+from pathlib import Path
 from parameterized import parameterized
 from copy import copy
 import pytest
@@ -47,18 +48,18 @@
     test_prompts,
 )
 from sharktank.models.clip.export import (
-    export_clip_text_model_mlir,
     export_clip_text_model_dataset_from_hugging_face,
     hugging_face_clip_attention_to_theta,
     hugging_face_clip_encoder_layer_to_theta,
     hugging_face_clip_encoder_to_theta,
     hugging_face_clip_text_model_to_dataset,
     hugging_face_clip_text_model_to_theta,
-    clip_text_model_to_dataset,
 )
 from sharktank.models.clip.testing import (
     make_random_input_token_sequences,
     make_clip_text_model_random_theta,
+    export_clip_text_model_iree_test_data,
+    clip_toy_text_model_config,
 )
 from sharktank.models.clip import (
     ClipAttention,
@@ -72,13 +73,15 @@
 with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')")
 
 
-@pytest.mark.usefixtures("caching", "path_prefix")
+@pytest.mark.usefixtures("path_prefix")
 class ClipTextIreeTest(TempDirTestBase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
         if self.path_prefix is None:
-            self.path_prefix = f"{self._temp_dir}/"
+            self.path_prefix = self._temp_dir
+        else:
+            self.path_prefix = Path(self.path_prefix)
 
     @with_clip_data
     def testSmokeExportLargeF32FromHuggingFace(self):
@@ -90,12 +93,20 @@ def testSmokeExportLargeF32FromHuggingFace(self):
             huggingface_repo_id,
         ).download()
         target_dtype_name = dtype_to_serialized_short_name(torch.float32)
-        target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_text_model_{target_dtype_name}"
+        target_model_path_prefix = (
+            self.path_prefix
+            / f"{huggingface_repo_id_as_path}_text_model_{target_dtype_name}"
+        )
         output_path = f"{target_model_path_prefix}.irpa"
         export_clip_text_model_dataset_from_hugging_face(
             huggingface_repo_id, output_path
         )
 
+    def testSmokeExportToyIreeTestData(self):
+        from sharktank.models.clip.export_toy_text_model_iree_test_data import main
+
+        main([f"--output-dir={self.path_prefix/'clip_toy_text_model'}"])
+
     @with_clip_data
     def testCompareLargeIreeF32AgainstTorchEagerF32(self):
         self.runTestCompareIreeAgainstPretrainedTorchEager(
@@ -141,43 +152,31 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
         )
         target_dtype_name = dtype_to_serialized_short_name(target_dtype)
         reference_model_path_prefix = (
-            f"{self.path_prefix}{file_artifact_prefix_name}_{reference_dtype_name}"
+            self.path_prefix / f"{file_artifact_prefix_name}_{reference_dtype_name}"
         )
         target_model_path_prefix = (
-            f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}"
-        )
-
-        target_config = copy(reference_model.config)
-        target_config.dtype = target_dtype
-        reference_dataset = clip_text_model_to_dataset(reference_model)
-        target_dataset = Dataset(
-            root_theta=reference_dataset.root_theta.transform(
-                functools.partial(set_float_dtype, dtype=target_config.dtype)
-            ),
-            properties=target_config.to_properties(),
+            self.path_prefix / f"{file_artifact_prefix_name}_{target_dtype_name}"
         )
 
         parameters_path = f"{target_model_path_prefix}.irpa"
-        if not self.caching or not os.path.exists(parameters_path):
-            target_dataset.save(parameters_path)
-
-        dataset = Dataset.load(parameters_path)
-        target_config = ClipTextConfig.from_properties(dataset.properties)
         input_args = OrderedDict([("input_ids", input_ids)])
         batch_size = input_ids.shape[0]
-
         mlir_path = f"{target_model_path_prefix}.mlir"
-        if not self.caching or not os.path.exists(mlir_path):
-            export_clip_text_model_mlir(
-                parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
-            )
+
+        export_clip_text_model_iree_test_data(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            target_mlir_output_path=mlir_path,
+            target_iree_parameters_output_path=parameters_path,
+        )
+
         iree_module_path = f"{target_model_path_prefix}.vmfb"
-        if not self.caching or not os.path.exists(iree_module_path):
-            iree.compiler.compile_file(
-                mlir_path,
-                output_file=iree_module_path,
-                extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
-            )
+        iree.compiler.compile_file(
+            mlir_path,
+            output_file=iree_module_path,
+            extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
+        )
 
         reference_result_dict = call_torch_module_function(
             module=reference_model,
@@ -211,11 +210,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
             for i in range(len(expected_outputs))
         ]
 
-        actual_last_hidden_states = actual_outputs[0]
-        expected_last_hidden_states = expected_outputs[0]
+        actual_last_hidden_state = actual_outputs[0]
+        expected_last_hidden_state = expected_outputs[0]
 
         assert_text_encoder_state_close(
-            actual_last_hidden_states, expected_last_hidden_states, atol
+            actual_last_hidden_state, expected_last_hidden_state, atol
         )
 
     def runTestCompareRandomModelIreeAgainstTorch(
@@ -243,21 +242,7 @@ def runTestCompareToyModelIreeAgainstTorch(
         self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
     ):
         batch_size = 4
-        num_attention_heads = 5
-        vocab_size = 11
-        reference_config = ClipTextConfig(
-            vocab_size=vocab_size,
-            hidden_size=13 * num_attention_heads,
-            intermediate_size=7,
-            projection_dim=3,
-            num_attention_heads=num_attention_heads,
-            max_position_embeddings=17,
-            layer_norm_eps=1e-4,
-            num_hidden_layers=2,
-            bos_token_id=vocab_size - 2,
-            eos_token_id=vocab_size - 1,
-            dtype=reference_dtype,
-        )
+        reference_config = clip_toy_text_model_config(reference_dtype)
         file_artifact_prefix_name = "clip_text_model_toy"
         self.runTestCompareRandomModelIreeAgainstTorch(
             reference_config=reference_config,
@@ -404,21 +389,9 @@ def testCompareEagerToySizedModelAgainstTransformers(
     ):
         torch.set_default_dtype(reference_dtype)
         batch_size = 19
-        tgt_len = 23
-        num_attention_heads = 5
         vocab_size = 11
-        reference_config = transformers.CLIPTextConfig(
-            vocab_size=vocab_size,
-            hidden_size=13 * num_attention_heads,
-            intermediate_size=7,
-            projection_dim=3,
-            num_attention_heads=num_attention_heads,
-            layer_norm_eps=1e-4,
-            num_hidden_layers=2,
-            final_layer_norm=1e-3,
-            bos_token_id=vocab_size - 2,
-            eos_token_id=vocab_size - 1,
-        )
+        config = clip_toy_text_model_config()
+        reference_config = config.to_hugging_face_clip_text_model_config()
         reference_model = HfCLIPTextModel(
             reference_config,
         )
@@ -432,7 +405,9 @@ def testCompareEagerToySizedModelAgainstTransformers(
         )
         model = ClipTextModel(theta, config)
 
-        input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len])
+        input_ids = torch.randint(
+            low=0, high=vocab_size, size=[batch_size, config.max_position_embeddings]
+        )
 
         expected_outputs = reference_model(input_ids=input_ids)
 
@@ -471,16 +446,10 @@ def testCompareEagerToySizedModelAgainstTransformers(
     ):
         torch.set_default_dtype(reference_dtype)
         batch_size = 19
-        tgt_len = 23
+        config = clip_toy_text_model_config()
+        reference_config = config.to_hugging_face_clip_text_model_config()
+        tgt_len = config.max_position_embeddings
         src_len = tgt_len
-        num_attention_heads = 2
-        reference_config = transformers.CLIPTextConfig(
-            vocab_size=11,
-            hidden_size=13 * num_attention_heads,
-            intermediate_size=7,
-            projection_dim=3,
-            num_attention_heads=num_attention_heads,
-        )
         reference_model = HfCLIPAttention(
             reference_config,
         )
@@ -495,7 +464,7 @@ def testCompareEagerToySizedModelAgainstTransformers(
         model = ClipAttention(theta, config)
 
         reference_hidden_states = make_rand_torch(
-            shape=[batch_size, tgt_len, reference_config.hidden_size],
+            shape=[batch_size, tgt_len, config.hidden_size],
             dtype=reference_dtype,
         )
         reference_attention_mask = make_random_mask(
@@ -551,17 +520,10 @@ def testCompareEagerToySizedModelAgainstTransformers(
     ):
         torch.set_default_dtype(reference_dtype)
         batch_size = 19
-        tgt_len = 23
+        config = clip_toy_text_model_config()
+        reference_config = config.to_hugging_face_clip_text_model_config()
+        tgt_len = config.max_position_embeddings
         src_len = tgt_len
-        num_attention_heads = 2
-        reference_config = transformers.CLIPTextConfig(
-            vocab_size=11,
-            hidden_size=13 * num_attention_heads,
-            intermediate_size=7,
-            projection_dim=3,
-            num_attention_heads=num_attention_heads,
-            layer_norm_eps=1e-4,
-        )
         reference_model = HfCLIPEncoderLayer(
             reference_config,
         )
@@ -634,15 +596,8 @@ def testCompareEagerToySizedModelAgainstTransformers(
         batch_size = 19
         tgt_len = 23
         src_len = tgt_len
-        num_attention_heads = 5
-        reference_config = transformers.CLIPTextConfig(
-            vocab_size=11,
-            hidden_size=13 * num_attention_heads,
-            intermediate_size=7,
-            projection_dim=3,
-            num_attention_heads=num_attention_heads,
-            layer_norm_eps=1e-4,
-            num_hidden_layers=2,
+        reference_config = (
+            clip_toy_text_model_config().to_hugging_face_clip_text_model_config()
         )
         reference_model = HfCLIPEncoder(
             reference_config,