Skip to content

Commit 2a79eda

Browse files
Merge branch 'main' into perplexity-vmfb
2 parents b220688 + f925a5b commit 2a79eda

19 files changed

+102
-93
lines changed

.github/workflows/ci-sharktank.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ on:
66
paths:
77
- '.github/workflows/ci-sharktank.yml'
88
- 'sharktank/**'
9-
- '*requirements.txt'
9+
- '*requirements*.txt'
1010
push:
1111
branches:
1212
- main
1313
paths:
1414
- '.github/workflows/ci-sharktank.yml'
1515
- 'sharktank/**'
16-
- '*requirements.txt'
16+
- '*requirements*.txt'
1717

1818
concurrency:
1919
# A PR number if a pull request and otherwise the commit hash. This cancels
@@ -52,7 +52,7 @@ jobs:
5252
id: cache-pip
5353
with:
5454
path: ${{ env.PIP_CACHE_DIR }}
55-
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
55+
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
5656

5757
- name: Install pip deps
5858
run: |

.github/workflows/ci_eval.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
id: cache-pip
4747
with:
4848
path: ${{ env.PIP_CACHE_DIR }}
49-
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
49+
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
5050

5151
- name: Install sharktank deps
5252
run: |

requirements-dev.txt

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Used for managing pre-commit flows.
2+
pre-commit
3+
4+
# Type checking
5+
mypy==1.8.0
6+
types-requests==2.31.0.20240125
7+
8+
# Testing
9+
pytest==8.0.0
10+
pytest-xdist==3.5.0

requirements.txt

+4-30
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,4 @@
1-
# Runtime deps.
2-
gguf==0.6.0
3-
numpy==1.26.3
4-
onnx==1.15.0
5-
6-
# Model deps.
7-
huggingface-hub==0.22.2
8-
transformers==4.40.0
9-
sentencepiece==0.2.0
10-
11-
# It is expected that you have installed a PyTorch version/variant specific
12-
# to your needs, so we only include a minimum version spec.
13-
# TODO: Use a versioned release once 2.3.0 drops.
14-
torch>=2.3.0.dev1
15-
16-
# Used for managing pre-commit flows.
17-
pre-commit
18-
19-
# Type checking
20-
mypy==1.8.0
21-
types-requests==2.31.0.20240125
22-
23-
# Testing
24-
parameterized
25-
pytest==8.0.0
26-
pytest-xdist==3.5.0
27-
28-
# Serving deps.
29-
fastapi==0.112.2
30-
uvicorn==0.30.6
1+
-r sharktank/requirements.txt
2+
-r sharktank/requirements-tests.txt
3+
-r shortfin/requirements-tests.txt
4+
-r requirements-dev.txt

sharktank/requirements-tests.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
datasets==3.0.0
2+
parameterized
3+
pytest==8.0.0

sharktank/requirements.txt

+16-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,16 @@
1-
gguf
1+
# Runtime deps.
2+
gguf==0.6.0
3+
numpy==1.26.3
4+
5+
# Model deps.
6+
huggingface-hub==0.22.2
7+
transformers==4.40.0
8+
datasets
9+
10+
# It is expected that you have installed a PyTorch version/variant specific
11+
# to your needs, so we only include a minimum version spec.
12+
torch>=2.3.0
13+
14+
# Serving deps.
15+
fastapi==0.112.2
16+
uvicorn==0.30.6

sharktank/setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def initialize_options(self):
9999
extras_require={
100100
"testing": [
101101
f"pytest{get_version_spec('pytest')}",
102-
f"pytest-xdist{get_version_spec('pytest-xdist')}",
103102
],
104103
},
105104
cmdclass={"build": BuildCommand},

sharktank/sharktank/examples/export_paged_llm_v1.py

+29-20
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,38 @@ def setup_cache(model, shard_count):
116116
page_count=hp.context_length // llama_config.block_seq_stride
117117
)
118118
page_dim = torch.export.Dim("page")
119+
119120
dynamic_shapes = [{0: page_dim}]
121+
unpacked = cache_state
122+
arg_affinities = {}
123+
shard_dim = None
124+
125+
# Need to unpack that state when sharded
126+
if llama_config.tensor_parallelism_size > 1:
127+
shard_dim = cache_state[0].shard_dim
128+
129+
unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
130+
dynamic_shapes = [
131+
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
132+
]
133+
134+
for i in range(llama_config.tensor_parallelism_size):
135+
arg_affinities[i] = DeviceAffinity(str(i))
136+
137+
return unpacked, shard_dim, dynamic_shapes, arg_affinities
138+
120139
elif model.config.kv_cache_type == "direct":
121140
cache_state = model.cache.allocate(bs=1)
122141
# Direct cache dimensions:
123142
# 2 * transformer_block_count of...
124143
# [bs, seq_length, attn_head_count, attn_head_dim]
125144
dynamic_shapes = [None]
145+
arg_affinities = {}
146+
shard_dim = None
147+
return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
126148
else:
127149
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")
128150

129-
unpacked = cache_state
130-
dynamic_shapes = dynamic_shapes
131-
arg_affinities = {}
132-
shard_dim = None
133-
134-
# Need to unpack that state when sharded
135-
if llama_config.tensor_parallelism_size > 1:
136-
shard_dim = cache_state[0].shard_dim
137-
138-
unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
139-
dynamic_shapes = [
140-
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
141-
]
142-
143-
for i in range(llama_config.tensor_parallelism_size):
144-
arg_affinities[i] = DeviceAffinity(str(i))
145-
146-
return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities
147-
148151
def repack_cache(cache, shard_dim):
149152
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
150153

@@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int):
184187
arg_device=arg_affinities,
185188
)
186189
def _(model, tokens, seq_lens, seq_block_ids, cs):
187-
cache_tensors = torch.unbind(cs)
190+
if (
191+
model.config.tensor_parallelism_size == 1
192+
and model.config.kv_cache_type == "direct"
193+
):
194+
cache_tensors = torch.unbind(cs)
195+
else:
196+
cache_tensors = cs
188197

189198
sl = tokens.shape[1]
190199
input_mask = model.input_mask(seq_lens, sl)

sharktank/sharktank/layers/causal_llm.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,9 @@ def input_mask(
9595

9696
def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
9797
dtype = self.attention_dtype
98-
numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
99-
numeric_mask.masked_fill_(
100-
boolean_input_mask, self._maximally_negative_value(dtype)
101-
)
98+
numeric_mask = torch.where(
99+
boolean_input_mask, self._maximally_negative_value(dtype), 0
100+
).to(dtype)
102101
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)
103102

104103
def attention_mask(
@@ -127,9 +126,10 @@ def attention_mask(
127126
dtype = self.attention_dtype
128127
_, batch_seq_len = input_mask.shape
129128
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
130-
boolean_mask = causal_mask + input_mask[:, None, None, :]
131-
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
132-
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
129+
boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
130+
numeric_mask = torch.where(
131+
boolean_mask, self._maximally_negative_value(dtype), 0
132+
).to(dtype)
133133
return numeric_mask.to(self.device)
134134

135135
def extract_tokens_from_logits(

shortfin/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
cmake_minimum_required(VERSION 3.29)
88

shortfin/build_tools/cmake/shortfin_library.cmake

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
set(SHORTFIN_DEFAULT_COPTS
88
# General clang and GCC options application to C and C++.

shortfin/dev_me.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#!/usr/bin/env python3
22
# Copyright 2024 Advanced Micro Devices, Inc.
33
#
4-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
5-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
6-
# Apache-2.0 WITH LLVM-exception
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
5+
# See https://llvm.org/LICENSE.txt for license information.
6+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
77

88
# dev_me.py
99
#

shortfin/python/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
# shortfin publishes multiple python packages: - _shortfin: Trampoline
88
# __init__.py which looks at environment variables to load an appropriate native

shortfin/src/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
add_subdirectory(shortfin)
88

shortfin/src/shortfin/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
add_subdirectory(array)
88
add_subdirectory(local)

shortfin/src/shortfin/array/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
shortfin_cc_component(
88
NAME

shortfin/src/shortfin/local/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
add_subdirectory(systems)
88

shortfin/src/shortfin/local/systems/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
set(_SYSTEM_COMPONENTS)
88

shortfin/src/shortfin/support/CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2024 Advanced Micro Devices, Inc.
22
#
3-
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
4-
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
5-
# Apache-2.0 WITH LLVM-exception
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
shortfin_cc_component(
88
NAME

0 commit comments

Comments
 (0)