diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 3c0db1590..35a2ee570 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -189,6 +189,7 @@ class T5Config:
     is_encoder_decoder: bool = True
     is_decoder: bool = False
     vocab_size: int = 32128
+    context_length: int = 512
     d_model: int = 512
     d_kv: int = 64
     d_ff: int = 2048
@@ -206,6 +207,7 @@ class T5Config:
     pad_token_id: int = 0
     eos_token_id: int = 1
     decoder_start_token_id: int = 0
+    context_length_padding_block_size: int = 16
 
     def __post_init__(self):
         self.is_gated_act = self.feed_forward_proj.startswith("gated-")
@@ -226,6 +228,7 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
     )
 
     gguf_to_config_names_map = {
+        "t5.context_length": ["context_length"],
         "t5.embedding_length": ["d_model"],
         "t5.feed_forward_length": ["d_ff"],
         "t5.block_count": ["num_layers", "num_decoder_layers"],
@@ -245,6 +248,8 @@
             for config_name in config_names
         }
     )
+    if "tokenizer.ggml.tokens" in properties:
+        all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
     all_kwargs.update(kwargs)
 
     return T5Config(**all_kwargs)
diff --git a/sharktank/sharktank/models/t5/__init__.py b/sharktank/sharktank/models/t5/__init__.py
new file mode 100644
index 000000000..7c7e76704
--- /dev/null
+++ b/sharktank/sharktank/models/t5/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .t5 import *
+from .export import *
diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py
new file mode 100644
index 000000000..7bd5eef3d
--- /dev/null
+++ b/sharktank/sharktank/models/t5/export.py
@@ -0,0 +1,97 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Union
+from pathlib import Path
+import torch
+
+from .t5 import T5Config, T5Encoder
+from ...types import Dataset
+from iree.turbine.aot import FxProgramsBuilder, export
+
+__all__ = [
+    "export_encoder_mlir",
+    "export_encoder_iree_parameters",
+    "prune_decoder_parameters",
+]
+
+
+def export_encoder_mlir(
+    model: Union[T5Encoder, Path, str],
+    batch_sizes: list[int],
+    mlir_output_path: str,
+):
+    """
+    Args:
+      model: either the torch module or path to GGUF/IRPA.
+    """
+    if isinstance(model, (Path, str)):
+        dataset = Dataset.load(model)
+        config = T5Config.from_gguf_properties(
+            dataset.properties,
+            # TODO: add this property to our HuggingFace-to-GGUF conversion script.
+            # We currently use llama.cpp's converter and it cannot make a distinction
+            # between T5 V1 and V1.1.
+            # V1 uses ReLU and V1.1 uses gated GeLU.
+            feed_forward_proj="gated-gelu",
+        )
+        model = T5Encoder(theta=dataset.root_theta, config=config)
+
+    fxb = FxProgramsBuilder(model)
+
+    for batch_size in batch_sizes:
+        sample_inputs = model.sample_inputs(batch_size)
+
+        context_length_dim_idx = 1
+        assert (
+            sample_inputs["input_ids"].shape[context_length_dim_idx]
+            % config.context_length_padding_block_size
+            == 0
+        )
+        context_length_block_dim_max = (
+            sample_inputs["input_ids"].shape[context_length_dim_idx]
+            // config.context_length_padding_block_size
+        )
+        context_length_block_dim = torch.export.Dim(
+            "block", max=context_length_block_dim_max
+        )
+        context_length_dim = (
+            config.context_length_padding_block_size * context_length_block_dim
+        )
+        dynamic_shapes = {"input_ids": {context_length_dim_idx: context_length_dim}}
+
+        @fxb.export_program(
+            name=f"forward_bs{batch_size}",
+            args=tuple(sample_inputs.values()),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+        def _(
+            model,
+            input_ids,
+        ):
+            return model(input_ids)
+
+    output = export(fxb, import_symbolic_shape_expressions=True)
+    output.save_mlir(mlir_output_path)
+
+
+def prune_decoder_parameters(dataset: Dataset):
+    # Remove decoder tensors/parameters if present.
+    try:
+        del dataset.root_theta.tree["dec"]
+    except KeyError:
+        pass
+    try:
+        del dataset.properties["t5.decoder_start_token_id"]
+    except KeyError:
+        pass
+
+
+def export_encoder_iree_parameters(model_path: str, output_path: str):
+    dataset = Dataset.load(model_path)
+    prune_decoder_parameters(dataset)
+    dataset.save(output_path)
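# --- Illustrative usage sketch; not part of the diff above. ---
# Shows how the two new entry points in sharktank/sharktank/models/t5/export.py
# are intended to be combined: prune decoder weights into an IRPA parameter
# archive and emit MLIR with the dynamic, block-padded context-length dimension.
# The model path below is a hypothetical placeholder.
from sharktank.models.t5 import export_encoder_mlir, export_encoder_iree_parameters

gguf_path = "/models/t5-v1_1-small-f32.gguf"  # hypothetical input model
export_encoder_iree_parameters(gguf_path, "t5_encoder_fp32.irpa")
export_encoder_mlir(gguf_path, batch_sizes=[4], mlir_output_path="t5_encoder_fp32.mlir")
# The exported program is named "forward_bs4" and accepts input_ids whose
# sequence length is any multiple of context_length_padding_block_size (16),
# up to context_length.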
diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py
index 3f2e09b79..4ae9108d5 100644
--- a/sharktank/sharktank/models/t5/t5.py
+++ b/sharktank/sharktank/models/t5/t5.py
@@ -26,8 +26,20 @@
 )
 
 from ... import ops
 from ...types.theta import Theta
+from ...types.tensors import AnyTensor
 from ...layers import FFN, T5Config
 
+__all__ = [
+    "T5Config",
+    "T5LayerFF",
+    "T5Attention",
+    "T5SelfAttention",
+    "T5CrossAttention",
+    "T5Block",
+    "T5Stack",
+    "T5Encoder",
+]
+
 logger = logging.getLogger(__name__)
 
@@ -1044,6 +1056,22 @@ def __init__(self, theta: Theta, config: T5Config):
             theta=theta, config=encoder_config, embed_tokens=self.token_embedding
         )
 
+    @property
+    def config(self):
+        return self.encoder.config
+
+    def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
+        return OrderedDict(
+            [
+                (
+                    "input_ids",
+                    torch.empty(
+                        size=[batch_size, self.config.context_length], dtype=torch.long
+                    ),
+                )
+            ]
+        )
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
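# --- Illustrative usage sketch; not part of the diff above. ---
# The new T5Encoder.config property and sample_inputs() give export and test
# code a single source of truth for input shapes: [batch_size, context_length]
# int64 token ids. A minimal eager check, assuming a hypothetical GGUF path:
from sharktank.types import Dataset
from sharktank.models.t5 import T5Config, T5Encoder

dataset = Dataset.load("/models/t5-v1_1-small-f32.gguf")  # hypothetical path
config = T5Config.from_gguf_properties(dataset.properties, feed_forward_proj="gated-gelu")
model = T5Encoder(theta=dataset.root_theta, config=config)
sample = model.sample_inputs(batch_size=2)
# sample_inputs reports the static shape used when exporting; the exported
# program then relaxes the sequence dimension into a dynamic block-padded dim.
assert sample["input_ids"].shape == (2, config.context_length)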
diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py
index 44674bc83..9a7dcf1ee 100644
--- a/sharktank/sharktank/types/gguf_interop/base.py
+++ b/sharktank/sharktank/types/gguf_interop/base.py
@@ -11,7 +11,7 @@
 import numpy as np
 import torch
 
-from gguf import GGUFReader, GGUFValueType
+from gguf import GGUFReader, GGUFValueType, ReaderField
 
 from iree.turbine.aot import (
     ExternalTensorTrait,
@@ -44,12 +44,26 @@ def _sanitize_scalar(scalar):
     return scalar
 
 
+def _load_array(field: ReaderField) -> list:
+    if len(field.types) != 2:
+        raise ValueError(f"Unsupported array type {field.types}")
+    element_type = field.types[1]
+    if element_type == GGUFValueType.STRING:
+        return [
+            str(bytes(field.parts[parts_index]), encoding="utf8")
+            for parts_index in field.data
+        ]
+    elif element_type in GGUFReader.gguf_scalar_to_np:
+        return [
+            _sanitize_scalar(field.parts[parts_index][0]) for parts_index in field.data
+        ]
+    else:
+        raise ValueError(f"Unsupported array element type {element_type}")
+
+
 def _load_properties(reader: GGUFReader) -> dict[str, Any]:
-    # TODO: Figure out what to do with tables.
-    tables: dict[str, Any] = {}
     properties: dict[str, Any] = {
         "schema": "GGUF",
-        # "tables": tables,
     }
 
     # Extract hyper-parameters. Adapted from gguf-dump.py
@@ -60,8 +74,10 @@ def _load_properties(reader: GGUFReader) -> dict[str, Any]:
             properties[field.name] = str(bytes(field.parts[-1]), encoding="utf8")
         elif field.types[0] in reader.gguf_scalar_to_np:
             properties[field.name] = _sanitize_scalar(field.parts[-1][0])
+        elif field.types[0] == GGUFValueType.ARRAY:
+            properties[field.name] = _load_array(field)
         else:
-            tables[field.name] = field.parts
+            raise ValueError(f"Invalid field type {field.types}.")
 
     return properties
 
diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py
index 377a0b1ba..d5976ec48 100644
--- a/sharktank/sharktank/utils/iree.py
+++ b/sharktank/sharktank/utils/iree.py
@@ -6,7 +6,6 @@
 
 import iree.runtime
 from typing import List, Tuple, Optional, Union
-from copy import deepcopy
 from pathlib import Path
 import torch
 import numpy as np
@@ -91,14 +90,7 @@ def run_iree_module_function(
     )
     if trace_path_prefix is not None:
         for i, arg in enumerate(args):
-            # iree.runtime.DeviceArray.to_host() will cache the result and reuse it.
-            # In the meantime the "actual" device array may have changed.
-            # It kinda assumes immutable arrays.
-            # This should probably not be its behavior.
-            # See https://github.com/iree-org/iree/issues/18870.
-            # deepcopy also returns an numpy ndarray instead of DeviceArray.
-            arg_copy = deepcopy(arg)
-            np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg_copy)
+            np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host())
     results = invoker(*args)
     if isinstance(results, iree.runtime.DeviceArray):
         results = (results,)
@@ -107,10 +99,10 @@
         for i, arg in enumerate(args):
             np.save(
                 f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
-                deepcopy(arg),
+                arg.to_host(),
             )
         for i, arg in enumerate(results):
-            np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", deepcopy(arg))
+            np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host())
     return results
 
 
@@ -197,4 +189,4 @@ def call_torch_module_function(
 
 
 def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
-    return [torch.tensor(deepcopy(tensor)) for tensor in tensors]
+    return [torch.tensor(tensor.to_host()) for tensor in tensors]
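# --- Illustrative sketch; not part of the diff above. ---
# With GGUFValueType.ARRAY metadata now parsed instead of stashed in an unused
# "tables" dict, string tables such as the tokenizer vocabulary surface as plain
# Python lists, which is what lets from_gguf_properties() derive vocab_size
# automatically. The GGUF path below is a hypothetical placeholder.
from sharktank.types import Dataset

dataset = Dataset.load("/models/t5-v1_1-small-f32.gguf")  # hypothetical path
tokens = dataset.properties["tokenizer.ggml.tokens"]  # list[str] after this change
vocab_size = len(tokens)  # matches the vocab_size that from_gguf_properties derives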
diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py
index 3da914556..076404e5d 100644
--- a/sharktank/tests/models/t5/t5_test.py
+++ b/sharktank/tests/models/t5/t5_test.py
@@ -14,19 +14,34 @@
     T5EncoderModel as ReferenceT5EncoderModel,
     T5Config as ReferenceT5Config,
 )
+import os
+from collections import OrderedDict
 import pytest
 import torch
 from unittest import TestCase
+from parameterized import parameterized
 from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset
-from sharktank.models.t5.t5 import (
+from sharktank.models.t5 import (
     T5Attention,
     T5SelfAttention,
     T5Config,
     T5Encoder,
     T5LayerFF,
+    export_encoder_mlir,
+    export_encoder_iree_parameters,
 )
-from sharktank.utils.testing import make_rand_torch
+from sharktank.utils.testing import make_rand_torch, TempDirTestBase
 from sharktank.utils.hf_datasets import get_dataset
+from sharktank.utils.iree import (
+    get_iree_devices,
+    load_iree_module,
+    run_iree_module_function,
+    prepare_iree_module_function_args,
+    call_torch_module_function,
+    flatten_for_iree_signature,
+    iree_to_torch,
+)
+import iree.compiler
 
 
 with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')")
@@ -37,8 +52,16 @@ def make_random_mask(shape: tuple[int], dtype: torch.dtype):
     return mask
 
 
+test_prompts = [
+    "Studies have been shown that owning a dog is good for you",
+    "The horse went into the river",
+    "We need at least one sentence long enough so that it spans more than one padding block which by default is of size 16.",
+    "Make the batch size 4",
+]
+
+
 @pytest.mark.usefixtures("get_model_artifacts")
-class T5EncoderTest(TestCase):
+class T5EncoderEagerTest(TestCase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
@@ -55,10 +78,7 @@ def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace(
         reference_model.eval()
 
         input_ids = tokenizer(
-            [
-                "Studies have been shown that owning a dog is good for you",
-                "The horse went into the river",
-            ],
+            test_prompts,
             return_tensors="pt",
             padding=True,
         ).input_ids
@@ -70,7 +90,6 @@ def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace(
         dataset = Dataset.load(target_model_path)
         config = T5Config.from_gguf_properties(
             dataset.properties,
-            vocab_size=tokenizer.vocab_size,
             feed_forward_proj="gated-gelu",
         )
         model = T5Encoder(theta=dataset.root_theta, config=config)
@@ -89,6 +108,103 @@ def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self):
         self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl")
 
 
+@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix")
+class T5EncoderIreeTest(TempDirTestBase):
+    def setUp(self):
+        super().setUp()
+        if self.path_prefix is None:
+            self.path_prefix = f"{self._temp_dir}/"
+
+    @parameterized.expand(
+        [
+            "google/t5-v1_1-small",
+            "google/t5-v1_1-xxl",
+        ]
+    )
+    @with_t5_data
+    def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
+        get_dataset(
+            huggingface_repo_id,
+        ).download()
+        tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id)
+
+        huggingface_repo_id_as_path = (
+            f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}"
+        )
+        source_model_name = f"{huggingface_repo_id_as_path}_fp32_model"
+        source_model_path = getattr(self, source_model_name)
+
+        dataset = Dataset.load(source_model_path)
+        config = T5Config.from_gguf_properties(
+            dataset.properties,
+            feed_forward_proj="gated-gelu",
+        )
+
+        input_ids = tokenizer(
+            test_prompts,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=config.context_length_padding_block_size,
+        ).input_ids
+        input_args = OrderedDict([("input_ids", input_ids)])
+        batch_size = input_ids.shape[0]
+
+        reference_model = T5Encoder(theta=dataset.root_theta, config=config)
+        reference_result = flatten_for_iree_signature(
+            call_torch_module_function(
+                module=reference_model,
+                function_name="forward",
+                kwargs=input_args,
+                trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_",
+            )
+        )
+
+        mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir"
+        if not self.caching or not os.path.exists(mlir_path):
+            export_encoder_mlir(
+                source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+            )
+        iree_module_path = (
+            f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb"
+        )
+        if not self.caching or not os.path.exists(iree_module_path):
+            iree.compiler.compile_file(
+                mlir_path,
+                output_file=iree_module_path,
+                extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
+            )
+
+        parameters_path = (
+            f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa"
+        )
+        if not self.caching or not os.path.exists(parameters_path):
+            export_encoder_iree_parameters(source_model_path, parameters_path)
+
+        iree_devices = get_iree_devices(driver="hip", device_count=1)
+        iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
+            module_path=iree_module_path,
+            devices=iree_devices,
+            parameters_path=parameters_path,
+        )
+        iree_args = prepare_iree_module_function_args(
+            args=flatten_for_iree_signature(input_args), devices=iree_devices
+        )
+        iree_result = iree_to_torch(
+            *run_iree_module_function(
+                module=iree_module,
+                vm_context=iree_vm_context,
+                args=iree_args,
+                driver="hip",
+                function_name=f"forward_bs{batch_size}",
+                trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_",
+            )
+        )
+
+        torch.testing.assert_close(
+            reference_result, iree_result, atol=1e-4, rtol=2.0e-3
+        )
+
+
 class T5AttentionTest(TestCase):
     def setUp(self):
         super().setUp()
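# --- Illustrative sketch; not part of the diff above. ---
# Why the IREE test passes pad_to_multiple_of: the exported program declares the
# input_ids sequence dimension as context_length_padding_block_size * Dim("block")
# (16 by default), so host-side inputs must be padded to a multiple of that block
# size and must not exceed context_length. A standard Hugging Face tokenizer call
# satisfies this constraint:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
batch = tokenizer(
    ["The horse went into the river"],
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=16,  # default T5Config.context_length_padding_block_size
)
assert batch.input_ids.shape[1] % 16 == 0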