Add T5 encoder exporting to MLIR and numerics verification with IREE #573

Merged (1 commit) on Nov 21, 2024
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -189,6 +189,7 @@ class T5Config:
is_encoder_decoder: bool = True
is_decoder: bool = False
vocab_size: int = 32128
context_length: int = 512
d_model: int = 512
d_kv: int = 64
d_ff: int = 2048
@@ -206,6 +207,7 @@ class T5Config:
pad_token_id: int = 0
eos_token_id: int = 1
decoder_start_token_id: int = 0
context_length_padding_block_size: int = 16

def __post_init__(self):
self.is_gated_act = self.feed_forward_proj.startswith("gated-")
@@ -226,6 +228,7 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
)

gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
"t5.feed_forward_length": ["d_ff"],
"t5.block_count": ["num_layers", "num_decoder_layers"],
@@ -245,6 +248,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
for config_name in config_names
}
)
if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)
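The new fields come straight from GGUF metadata. A minimal sketch of deriving a config from a GGUF file (the path is hypothetical, and feed_forward_proj still has to be supplied by hand, as noted in export.py below):

from sharktank.layers import T5Config
from sharktank.types import Dataset

dataset = Dataset.load("/tmp/t5-v1_1-small.gguf")  # hypothetical path
config = T5Config.from_gguf_properties(
    dataset.properties,
    feed_forward_proj="gated-gelu",  # not recoverable from GGUF metadata yet
)
# vocab_size is now derived from the tokenizer table when one is present.
assert config.vocab_size == len(dataset.properties["tokenizer.ggml.tokens"])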
8 changes: 8 additions & 0 deletions sharktank/sharktank/models/t5/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .t5 import *
from .export import *
97 changes: 97 additions & 0 deletions sharktank/sharktank/models/t5/export.py
@@ -0,0 +1,97 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
from pathlib import Path
import torch

from .t5 import T5Config, T5Encoder
from ...types import Dataset
from iree.turbine.aot import FxProgramsBuilder, export

__all__ = [
"export_encoder_mlir",
"export_encoder_iree_parameters",
"prune_decoder_parameters",
]


def export_encoder_mlir(
model: Union[T5Encoder, Path, str],
batch_sizes: list[int],
mlir_output_path: str,
):
"""
Args:
model: either the torch module or path to GGUF/IRPA.
"""
if isinstance(model, (Path, str)):
dataset = Dataset.load(model)
config = T5Config.from_gguf_properties(
dataset.properties,
        # TODO: add this property to our HuggingFace-to-GGUF conversion script.
        # We currently use llama.cpp's converter, and it cannot distinguish
        # between T5 V1 and V1.1: V1 uses ReLU, while V1.1 uses gated GeLU.
feed_forward_proj="gated-gelu",
)
model = T5Encoder(theta=dataset.root_theta, config=config)

fxb = FxProgramsBuilder(model)

for batch_size in batch_sizes:
sample_inputs = model.sample_inputs(batch_size)

context_length_dim_idx = 1
assert (
sample_inputs["input_ids"].shape[context_length_dim_idx]
% config.context_length_padding_block_size
== 0
)
context_length_block_dim_max = (
sample_inputs["input_ids"].shape[context_length_dim_idx]
// config.context_length_padding_block_size
)
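        # torch.export cannot directly constrain a dynamic dim to "multiple of
        # block_size", so the sequence length is expressed as
        # block_size * Dim("block"): the block count is the dynamic part, and
        # the multiplication bakes the divisibility constraint into the
        # exported program.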
context_length_block_dim = torch.export.Dim(
"block", max=context_length_block_dim_max
)
context_length_dim = (
config.context_length_padding_block_size * context_length_block_dim
)
dynamic_shapes = {"input_ids": {context_length_dim_idx: context_length_dim}}

@fxb.export_program(
name=f"forward_bs{batch_size}",
args=tuple(sample_inputs.values()),
dynamic_shapes=dynamic_shapes,
strict=False,
)
def _(
model,
input_ids,
):
return model(input_ids)

output = export(fxb, import_symbolic_shape_expressions=True)
output.save_mlir(mlir_output_path)


def prune_decoder_parameters(dataset: Dataset):
# Remove decoder tensors/parameters if present.
try:
del dataset.root_theta.tree["dec"]
except KeyError:
pass
try:
del dataset.properties["t5.decoder_start_token_id"]
except KeyError:
pass


def export_encoder_iree_parameters(model_path: str, output_path: str):
dataset = Dataset.load(model_path)
prune_decoder_parameters(dataset)
dataset.save(output_path)
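Taken together, the two entry points cover the whole export flow. A minimal sketch with hypothetical paths and batch sizes (not part of this change):

from sharktank.models.t5 import export_encoder_mlir, export_encoder_iree_parameters

gguf_path = "/tmp/t5-v1_1-small.gguf"  # hypothetical input
export_encoder_mlir(gguf_path, batch_sizes=[1, 4], mlir_output_path="/tmp/t5_encoder.mlir")
export_encoder_iree_parameters(gguf_path, "/tmp/t5_encoder.irpa")

The resulting MLIR can then be compiled with IREE and run against the pruned parameter file.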
28 changes: 28 additions & 0 deletions sharktank/sharktank/models/t5/t5.py
@@ -26,8 +26,20 @@
)
from ... import ops
from ...types.theta import Theta
from ...types.tensors import AnyTensor
from ...layers import FFN, T5Config

__all__ = [
"T5Config",
"T5LayerFF",
"T5Attention",
"T5SelfAttention",
"T5CrossAttention",
"T5Block",
"T5Stack",
"T5Encoder",
]

logger = logging.getLogger(__name__)


@@ -1044,6 +1056,22 @@ def __init__(self, theta: Theta, config: T5Config):
theta=theta, config=encoder_config, embed_tokens=self.token_embedding
)

@property
def config(self):
return self.encoder.config

def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
return OrderedDict(
[
(
"input_ids",
torch.empty(
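                        # NOTE: torch.empty does not initialize memory, so
                        # these ids are arbitrary; replace them with valid
                        # token ids whenever real model outputs matter.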
size=[batch_size, self.config.context_length], dtype=torch.long
),
)
]
)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
26 changes: 21 additions & 5 deletions sharktank/sharktank/types/gguf_interop/base.py
@@ -11,7 +11,7 @@
import numpy as np
import torch

-from gguf import GGUFReader, GGUFValueType
+from gguf import GGUFReader, GGUFValueType, ReaderField

from iree.turbine.aot import (
ExternalTensorTrait,
@@ -44,12 +44,26 @@ def _sanitize_scalar(scalar):
return scalar


def _load_array(field: ReaderField) -> list:
if len(field.types) != 2:
raise ValueError(f"Unsupported array type {field.types}")
element_type = field.types[1]
if element_type == GGUFValueType.STRING:
return [
str(bytes(field.parts[parts_index]), encoding="utf8")
for parts_index in field.data
]
elif element_type in GGUFReader.gguf_scalar_to_np:
return [
_sanitize_scalar(field.parts[parts_index][0]) for parts_index in field.data
]
else:
raise ValueError(f"Unsupported array element type f{element_type}")


def _load_properties(reader: GGUFReader) -> dict[str, Any]:
-    # TODO: Figure out what to do with tables.
-    tables: dict[str, Any] = {}
    properties: dict[str, Any] = {
        "schema": "GGUF",
-        # "tables": tables,
    }

# Extract hyper-parameters. Adapted from gguf-dump.py
@@ -60,8 +74,10 @@ def _load_properties(reader: GGUFReader) -> dict[str, Any]:
properties[field.name] = str(bytes(field.parts[-1]), encoding="utf8")
elif field.types[0] in reader.gguf_scalar_to_np:
properties[field.name] = _sanitize_scalar(field.parts[-1][0])
+        elif field.types[0] == GGUFValueType.ARRAY:
+            properties[field.name] = _load_array(field)
        else:
-            tables[field.name] = field.parts
+            raise ValueError(f"Invalid field type {field.types}.")
return properties


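With array decoding in place, string arrays such as the tokenizer vocabulary surface as ordinary Python lists in the properties dict, which is what from_gguf_properties uses to derive vocab_size. A hypothetical round trip (importing the private helper purely for illustration; the path is illustrative):

from gguf import GGUFReader
from sharktank.types.gguf_interop.base import _load_properties

reader = GGUFReader("/tmp/t5-v1_1-small.gguf")  # hypothetical path
properties = _load_properties(reader)
print(len(properties["tokenizer.ggml.tokens"]))  # vocabulary size, e.g. 32128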
16 changes: 4 additions & 12 deletions sharktank/sharktank/utils/iree.py
@@ -6,7 +6,6 @@

import iree.runtime
from typing import List, Tuple, Optional, Union
-from copy import deepcopy
from pathlib import Path
import torch
import numpy as np
@@ -91,14 +90,7 @@ def run_iree_module_function(
)
if trace_path_prefix is not None:
for i, arg in enumerate(args):
-            # iree.runtime.DeviceArray.to_host() will cache the result and reuse it.
-            # In the meantime the "actual" device array may have changed.
-            # It kinda assumes immutable arrays.
-            # This should probably not be its behavior.
-            # See https://github.com/iree-org/iree/issues/18870.
-            # deepcopy also returns an numpy ndarray instead of DeviceArray.
-            arg_copy = deepcopy(arg)
-            np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg_copy)
+            np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host())
results = invoker(*args)
if isinstance(results, iree.runtime.DeviceArray):
results = (results,)
@@ -107,10 +99,10 @@
for i, arg in enumerate(args):
np.save(
f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
-                deepcopy(arg),
+                arg.to_host(),
            )
        for i, arg in enumerate(results):
-            np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", deepcopy(arg))
+            np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host())
return results


@@ -197,4 +189,4 @@ def call_torch_module_function(


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
-    return [torch.tensor(deepcopy(tensor)) for tensor in tensors]
+    return [torch.tensor(tensor.to_host()) for tensor in tensors]
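These helpers underpin the numerics verification: IREE results are converted back to torch tensors and compared against the eager model's outputs. A minimal sketch of such a check, assuming iree_to_torch above is in scope; the helper name and tolerances are illustrative:

import torch

def assert_iree_matches_torch(iree_results, torch_results, atol=1e-4, rtol=1e-4):
    # Convert each IREE DeviceArray to a torch tensor and compare elementwise.
    for actual, expected in zip(iree_to_torch(*iree_results), torch_results):
        torch.testing.assert_close(actual, expected.to(dtype=actual.dtype), atol=atol, rtol=rtol)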