
Add T5 LM v1.1 encoder #550

Merged (12 commits) on Nov 20, 2024
.github/workflows/ci-sharktank.yml (+59 −0)
@@ -70,3 +70,62 @@ jobs:
        if: ${{ !cancelled() }}
        run: |
          pytest -n 4 sharktank/


  test_with_data:
    name: "Data-dependent Tests"
    strategy:
      matrix:
        version: [3.11]
        runs-on: [llama-mi300x-3]
      fail-fast: false
    runs-on: ${{matrix.runs-on}}
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
      HF_HOME: "/data/huggingface"
      SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
    steps:
      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

      - name: Install sharktank deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non-default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/

          # Install the latest iree-turbine.
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"

          # Try with the latest IREE nightly releases, not what iree-turbine pins.
          # We could also pin to a known working or stable version.
          # This should eventually stabilize. Do the best we can for now.
          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
            iree-base-compiler \
            iree-base-runtime

      - name: Run tests
        run: |
          pytest \
            --with-t5-data \
            sharktank/tests/models/t5/t5_test.py
sharktank/conftest.py (+42 −0)
@@ -88,6 +88,16 @@ def pytest_addoption(parser):
help="Enable all llama benchmarking tests",
)

parser.addoption(
"--with-t5-data",
action="store_true",
default=False,
help=(
"Enable tests that use T5 data like models that is not a part of the source "
"code. The user is expected to provide the data"
),
)

# TODO: Remove all hardcoded paths in CI tests
parser.addoption(
"--llama3-8b-tokenizer-path",
@@ -133,6 +143,28 @@ def pytest_addoption(parser):
help="Llama3.1 405b fp8 model path",
)

# To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py.
# https://github.com/ggerganov/llama.cpp/blob/9abe9eeae98b11fa93b82632b264126a010225ff/convert_hf_to_gguf.py
# E.g.
# git lfs install
# git clone https://huggingface.co/google/t5-v1_1-small
# convert_hf_to_gguf.py \
# --outfile t5-v1_1-small.gguf \
# --outtype=f32 \
# t5-v1_1-small
parser.addoption(
"--google-t5-v1-1-small-fp32-model-path",
type=Path,
default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",

[Review thread on this default path]
Collaborator: We don't hardcode any llama model/tokenizer paths here anymore. You can pass it as an arg directly to pytest.
Contributor (author): Why not have a default that allows you to run pytest sharktank/tests if you have the data at the default paths?
Contributor: For someone running these tests on a machine that does not have the data in the default paths, we should have at least a comment with a link or something for how/where to get this data.
Contributor (author): I added a comment.

help="Google T5 v1.1 small fp32 model path",
)
parser.addoption(
"--google-t5-v1-1-xxl-fp32-model-path",
type=Path,
default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
help="Google T5 v1.1 XXL fp32 model path",
)

parser.addoption(
"--baseline-perplexity-scores",
type=Path,
@@ -256,6 +288,16 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(

[Review thread on this fixture]
Contributor: Do we need two separate model_path flags for small vs. xxl models instead of just calling this twice from our tests?
Contributor (author, @sogartar, Nov 18, 2024): This follows the nomenclature already accepted for the Llama variants. I think we will get a lot more variants, like fp16 and other quantizations. We probably want sane defaults for all files so that you can run pytest sharktank/tests if you already have the files in their expected places. It is important to have a simple command to run all tests.

        request,
        "--google-t5-v1-1-small-fp32-model-path",
        "google__t5_v1_1_small_fp32_model",
    )
    model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
        request,
        "--google-t5-v1-1-xxl-fp32-model-path",
        "google__t5_v1_1_xxl_fp32_model",
    )
    return model_path
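
For context, a hypothetical sketch of how a test could consume these fixture values. It assumes get_model_artifacts is registered as a pytest fixture elsewhere in conftest.py and that it returns the model_path dict built above; the actual consumption pattern in t5_test.py may differ.

    # Hypothetical test; the fixture name "get_model_artifacts" and the dict key
    # are assumptions based on the code above, not taken from t5_test.py.
    def test_t5_v1_1_small_gguf_is_available(get_model_artifacts):
        gguf_path = get_model_artifacts["google__t5_v1_1_small_fp32_model_path"]
        assert gguf_path.exists(), f"Expected T5 GGUF at {gguf_path}"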


sharktank/sharktank/layers/configs/llm_configs.py (+71 −2)
@@ -14,11 +14,11 @@
(and indeed, can bootstrap these off of GGUF files).
"""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional
import torch

__all__ = ["LlamaHParams", "LlamaModelConfig"]
__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]


@dataclass
@@ -179,3 +179,72 @@ class LlamaModelConfig:
    # be the difference of many gigabytes of static data being embedded in
    # the program and not.
    static_tables: bool = True


@dataclass
class T5Config:
    return_dict: bool = True
    output_hidden_states: bool = False
    output_attentions: bool = False
    is_encoder_decoder: bool = True
    is_decoder: bool = False
    vocab_size: int = 32128
    d_model: int = 512
    d_kv: int = 64
    d_ff: int = 2048
    num_layers: int = 6
    num_decoder_layers: int = 6
    num_heads: int = 8
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
    layer_norm_epsilon: float = 1e-6
    feed_forward_proj: str = "relu"
    is_gated_act: bool = field(init=False)
    activation_dtype: torch.dtype = torch.float32
    dense_act_fn: str = field(init=False)
    use_cache: bool = True
    pad_token_id: int = 0
    eos_token_id: int = 1
    decoder_start_token_id: int = 0

    def __post_init__(self):
        self.is_gated_act = self.feed_forward_proj.startswith("gated-")
        self.dense_act_fn = (
            self.feed_forward_proj.split("-")[1]
            if "-" in self.feed_forward_proj
            else self.feed_forward_proj
        )
        if self.dense_act_fn == "gelu":
            self.dense_act_fn = "gelu_new"

    @staticmethod
    def from_gguf_properties(properties: dict[str, Any], **kwargs):
        assert properties["general.architecture"] == "t5"
        assert (
            properties["t5.attention.layer_norm_epsilon"]
            == properties["t5.attention.layer_norm_rms_epsilon"]
        )

        gguf_to_config_names_map = {
            "t5.embedding_length": ["d_model"],
            "t5.feed_forward_length": ["d_ff"],
            "t5.block_count": ["num_layers", "num_decoder_layers"],
            "t5.attention.head_count": ["num_heads"],
            "t5.attention.key_length": ["d_kv"],
            "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
            "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
            "t5.decoder_start_token_id": ["decoder_start_token_id"],
            "tokenizer.ggml.eos_token_id": ["eos_token_id"],
            "tokenizer.ggml.padding_token_id": ["pad_token_id"],
        }
        all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
        all_kwargs.update(
            {
                config_name: properties[gguf_name]
                for gguf_name, config_names in gguf_to_config_names_map.items()
                for config_name in config_names
            }
        )
        all_kwargs.update(kwargs)

        return T5Config(**all_kwargs)
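
For reference, a minimal sketch of how the new config plumbing fits together. The properties dict below is hypothetical (hand-written rather than read from an actual GGUF file), its keys simply follow gguf_to_config_names_map above, and the import path is assumed from the file location; the assertions reflect the __post_init__ derivation.

    from sharktank.layers.configs.llm_configs import T5Config

    # Hypothetical GGUF metadata for a small T5 checkpoint; real values would be
    # read out of the .gguf file produced by convert_hf_to_gguf.py.
    properties = {
        "general.architecture": "t5",
        "t5.embedding_length": 512,
        "t5.feed_forward_length": 1024,
        "t5.block_count": 8,
        "t5.attention.head_count": 6,
        "t5.attention.key_length": 64,
        "t5.attention.layer_norm_epsilon": 1e-6,
        "t5.attention.layer_norm_rms_epsilon": 1e-6,
        "t5.attention.relative_buckets_count": 32,
        "t5.decoder_start_token_id": 0,
        "tokenizer.ggml.eos_token_id": 1,
        "tokenizer.ggml.padding_token_id": 0,
    }

    # vocab_size and feed_forward_proj are not covered by the map above, so they
    # must be supplied as kwargs; "gated-gelu" exercises the __post_init__ logic.
    config = T5Config.from_gguf_properties(
        properties, vocab_size=32128, feed_forward_proj="gated-gelu"
    )
    assert config.is_gated_act
    assert config.dense_act_fn == "gelu_new"
    assert config.num_layers == config.num_decoder_layers == 8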
sharktank/sharktank/layers/ffn_block.py (+20 −8)
@@ -4,11 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional
from typing import Optional, Callable

import torch
import torch.nn.functional as F
from .. import ops
from ..types import AnyTensor

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -22,18 +23,29 @@ class FFN(ThetaLayer):
     def __init__(
         self,
         theta: Theta,
+        is_gated: bool = True,
+        activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
     ):
         super().__init__(theta)
 
-        self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
+        self.is_gated = is_gated
+        self.activation_fn = activation_fn
+        if self.is_gated:
+            self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
         self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
         self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
 
     def forward(
         self,
-        h: torch.Tensor,
-    ):
-        ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
-        ffn_up = self.ffn_up(h)
-        ffn_down = self.ffn_down(ffn_gate * ffn_up)
-        return ffn_down
+        h: AnyTensor,
+    ) -> AnyTensor:
+        if self.is_gated:
+            ffn_gate = ops.elementwise(self.activation_fn, self.ffn_gate(h))
+            ffn_up = self.ffn_up(h)
+            ffn_down = self.ffn_down(ffn_gate * ffn_up)
+            return ffn_down
+        else:
+            h = self.ffn_up(h)
+            h = ops.elementwise(self.activation_fn, h)
+            h = self.ffn_down(h)
+            return h
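
A small usage sketch of the generalized layer, illustrating the new is_gated and activation_fn parameters. The helper below and its theta argument are illustrative only: building a Theta that holds the required ffn_up/ffn_down tensors (plus ffn_gate for the gated case) is outside this diff, and the import path is assumed from the file location.

    import torch.nn.functional as F

    from sharktank.layers.ffn_block import FFN

    def build_ffn_variants(theta):
        # theta must provide "ffn_up" and "ffn_down" sub-thetas, plus "ffn_gate"
        # when the gated path is used.

        # Default behavior, unchanged by this PR:
        # ffn_down(silu(ffn_gate(h)) * ffn_up(h)).
        gated = FFN(theta)

        # New non-gated path with a configurable activation:
        # ffn_down(relu(ffn_up(h))).
        ungated = FFN(theta, is_gated=False, activation_fn=F.relu)

        return gated, ungated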