Add T5 encoder exporting to MLIR and numerics verification with IREE (#573)

This change adds export functionality and a numerics test for the T5 v1.1
encoder, covering the Small and XXL variants.
The numerical tolerance in the test is `atol=1e-4, rtol=2.0e-3`.
I observed a relative error of `1.47e-3` on the XXL variant, which is probably
acceptable.
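
A minimal sketch of the kind of tolerance check the test applies; the helper
name and tensor arguments here are illustrative assumptions, not the actual
test code:

```python
import torch

def check_encoder_numerics(iree_result: torch.Tensor, reference: torch.Tensor):
    # Compare the IREE output against the eager torch reference using the
    # tolerances quoted above; raises AssertionError if they diverge.
    torch.testing.assert_close(iree_result, reference, atol=1e-4, rtol=2.0e-3)
```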
    
Added parsing of GGUF arrays when loading properties. This is used to deduce
the vocabulary size instead of having to provide it manually. It could also be
used to load the tokenizer from GGUF or IRPA.
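
As a rough sketch of how the pieces fit together (the file paths below are
hypothetical, and `feed_forward_proj` must still be passed explicitly because
the GGUF converter does not distinguish T5 v1 from v1.1), loading a GGUF
dataset, deducing `vocab_size` from `tokenizer.ggml.tokens`, and exporting the
encoder might look like:

```python
from sharktank.types import Dataset
from sharktank.layers import T5Config
from sharktank.models.t5 import export_encoder_mlir, export_encoder_iree_parameters

# Hypothetical path; substitute a real GGUF/IRPA file.
gguf_path = "/models/t5-v1_1-small.gguf"

dataset = Dataset.load(gguf_path)
# vocab_size is now deduced from len(properties["tokenizer.ggml.tokens"])
# when that array is present, so it no longer has to be provided manually.
config = T5Config.from_gguf_properties(
    dataset.properties,
    feed_forward_proj="gated-gelu",  # T5 v1.1 uses gated GeLU
)

# Export the encoder to MLIR and save decoder-pruned IREE parameters.
export_encoder_mlir(gguf_path, batch_sizes=[4], mlir_output_path="t5_encoder.mlir")
export_encoder_iree_parameters(gguf_path, "t5_encoder.irpa")
```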
sogartar authored Nov 21, 2024
1 parent c8af5a1 commit dbff2e5
Showing 7 changed files with 287 additions and 25 deletions.
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -189,6 +189,7 @@ class T5Config:
is_encoder_decoder: bool = True
is_decoder: bool = False
vocab_size: int = 32128
context_length: int = 512
d_model: int = 512
d_kv: int = 64
d_ff: int = 2048
@@ -206,6 +207,7 @@ class T5Config:
pad_token_id: int = 0
eos_token_id: int = 1
decoder_start_token_id: int = 0
context_length_padding_block_size: int = 16

def __post_init__(self):
self.is_gated_act = self.feed_forward_proj.startswith("gated-")
@@ -226,6 +228,7 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
)

gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
"t5.feed_forward_length": ["d_ff"],
"t5.block_count": ["num_layers", "num_decoder_layers"],
@@ -245,6 +248,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
for config_name in config_names
}
)
if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)
8 changes: 8 additions & 0 deletions sharktank/sharktank/models/t5/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .t5 import *
from .export import *
97 changes: 97 additions & 0 deletions sharktank/sharktank/models/t5/export.py
@@ -0,0 +1,97 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
from pathlib import Path
import torch

from .t5 import T5Config, T5Encoder
from ...types import Dataset
from iree.turbine.aot import FxProgramsBuilder, export

__all__ = [
"export_encoder_mlir",
"export_encoder_iree_parameters",
"prune_decoder_parameters",
]


def export_encoder_mlir(
model: Union[T5Encoder, Path, str],
batch_sizes: list[int],
mlir_output_path: str,
):
"""
Args:
model: either the torch module or path to GGUF/IRPA.
"""
if isinstance(model, (Path, str)):
dataset = Dataset.load(model)
config = T5Config.from_gguf_properties(
dataset.properties,
# TODO: add this property to our HuggingFace-to-GGUF conversion script.
# We currently use llama.cpp's converter, which cannot distinguish
# between T5 v1 and v1.1.
# V1 uses ReLU and V1.1 uses gated GeLU.
feed_forward_proj="gated-gelu",
)
model = T5Encoder(theta=dataset.root_theta, config=config)

fxb = FxProgramsBuilder(model)

for batch_size in batch_sizes:
sample_inputs = model.sample_inputs(batch_size)

context_length_dim_idx = 1
assert (
sample_inputs["input_ids"].shape[context_length_dim_idx]
% config.context_length_padding_block_size
== 0
)
context_length_block_dim_max = (
sample_inputs["input_ids"].shape[context_length_dim_idx]
// config.context_length_padding_block_size
)
context_length_block_dim = torch.export.Dim(
"block", max=context_length_block_dim_max
)
context_length_dim = (
config.context_length_padding_block_size * context_length_block_dim
)
dynamic_shapes = {"input_ids": {context_length_dim_idx: context_length_dim}}

@fxb.export_program(
name=f"forward_bs{batch_size}",
args=tuple(sample_inputs.values()),
dynamic_shapes=dynamic_shapes,
strict=False,
)
def _(
model,
input_ids,
):
return model(input_ids)

output = export(fxb, import_symbolic_shape_expressions=True)
output.save_mlir(mlir_output_path)


def prune_decoder_parameters(dataset: Dataset):
# Remove decoder tensors/parameters if present.
try:
del dataset.root_theta.tree["dec"]
except KeyError:
pass
try:
del dataset.properties["t5.decoder_start_token_id"]
except KeyError:
pass


def export_encoder_iree_parameters(model_path: str, output_path: str):
dataset = Dataset.load(model_path)
prune_decoder_parameters(dataset)
dataset.save(output_path)
28 changes: 28 additions & 0 deletions sharktank/sharktank/models/t5/t5.py
@@ -26,8 +26,20 @@
)
from ... import ops
from ...types.theta import Theta
from ...types.tensors import AnyTensor
from ...layers import FFN, T5Config

__all__ = [
"T5Config",
"T5LayerFF",
"T5Attention",
"T5SelfAttention",
"T5CrossAttention",
"T5Block",
"T5Stack",
"T5Encoder",
]

logger = logging.getLogger(__name__)


@@ -1044,6 +1056,22 @@ def __init__(self, theta: Theta, config: T5Config):
theta=theta, config=encoder_config, embed_tokens=self.token_embedding
)

@property
def config(self):
return self.encoder.config

def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
return OrderedDict(
[
(
"input_ids",
torch.empty(
size=[batch_size, self.config.context_length], dtype=torch.long
),
)
]
)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
26 changes: 21 additions & 5 deletions sharktank/sharktank/types/gguf_interop/base.py
@@ -11,7 +11,7 @@
import numpy as np
import torch

from gguf import GGUFReader, GGUFValueType
from gguf import GGUFReader, GGUFValueType, ReaderField

from iree.turbine.aot import (
ExternalTensorTrait,
@@ -44,12 +44,26 @@ def _sanitize_scalar(scalar):
return scalar


def _load_array(field: ReaderField) -> list:
if len(field.types) != 2:
raise ValueError(f"Unsupported array type {field.types}")
element_type = field.types[1]
if element_type == GGUFValueType.STRING:
return [
str(bytes(field.parts[parts_index]), encoding="utf8")
for parts_index in field.data
]
elif element_type in GGUFReader.gguf_scalar_to_np:
return [
_sanitize_scalar(field.parts[parts_index][0]) for parts_index in field.data
]
else:
raise ValueError(f"Unsupported array element type f{element_type}")


def _load_properties(reader: GGUFReader) -> dict[str, Any]:
# TODO: Figure out what to do with tables.
tables: dict[str, Any] = {}
properties: dict[str, Any] = {
"schema": "GGUF",
# "tables": tables,
}

# Extract hyper-parameters. Adapted from gguf-dump.py
@@ -60,8 +74,10 @@ def _load_properties(reader: GGUFReader) -> dict[str, Any]:
properties[field.name] = str(bytes(field.parts[-1]), encoding="utf8")
elif field.types[0] in reader.gguf_scalar_to_np:
properties[field.name] = _sanitize_scalar(field.parts[-1][0])
elif field.types[0] == GGUFValueType.ARRAY:
properties[field.name] = _load_array(field)
else:
tables[field.name] = field.parts
raise ValueError(f"Invalid field type.")
return properties


16 changes: 4 additions & 12 deletions sharktank/sharktank/utils/iree.py
@@ -6,7 +6,6 @@

import iree.runtime
from typing import List, Tuple, Optional, Union
from copy import deepcopy
from pathlib import Path
import torch
import numpy as np
@@ -91,14 +90,7 @@ def run_iree_module_function(
)
if trace_path_prefix is not None:
for i, arg in enumerate(args):
# iree.runtime.DeviceArray.to_host() will cache the result and reuse it.
# In the meantime the "actual" device array may have changed.
# It kinda assumes immutable arrays.
# This should probably not be its behavior.
# See https://github.com/iree-org/iree/issues/18870.
# deepcopy also returns an numpy ndarray instead of DeviceArray.
arg_copy = deepcopy(arg)
np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg_copy)
np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host())
results = invoker(*args)
if isinstance(results, iree.runtime.DeviceArray):
results = (results,)
@@ -107,10 +99,10 @@
for i, arg in enumerate(args):
np.save(
f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
deepcopy(arg),
arg.to_host(),
)
for i, arg in enumerate(results):
np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", deepcopy(arg))
np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host())
return results


@@ -197,4 +189,4 @@ def call_torch_module_function(


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
return [torch.tensor(deepcopy(tensor)) for tensor in tensors]
return [torch.tensor(tensor.to_host()) for tensor in tensors]
