
Commit 1aeb3a8

Split sharded Llama dataset exporting and loading in export scripts (#327)
Separate the two steps. We need the exported IRPA files for the IREE module anyway.
1 parent d3778ed commit 1aeb3a8

3 files changed: +49 -37 lines changed

sharktank/sharktank/examples/export_paged_llm_v1.py (+13 -18)
@@ -59,31 +59,26 @@ def main():
         default="decomposed",
         choices=["decomposed", "torch_sdpa"],
     )
-    parser.add_argument(
-        "--tensor-parallelism-size",
-        type=int,
-        default=1,
-        help="How many devices are involved for tensor parallel sharding.",
-    )
 
     args = cli.parse(parser)
     dataset_type = cli.get_input_data_files(args)
     dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
     dataset = cli.get_input_dataset(args)
 
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
-    llama_config = LlamaModelConfig(hp)
-    if args.tensor_parallelism_size > 1:
-        dataset.root_theta = shard_theta(dataset.root_theta, llama_config)
-    llama_config.use_hf = False
-    llama_config.static_tables = False  # Rely on the compiler for hoisting tables.
-    llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
-    llama_config.attention_kernel = args.attention_kernel
-
-    # This is a bit gross and should be changed in the future. Best Idea I had so far.
-    attn_q_weight = dataset.root_theta.tensor("blk")["0"]["attn_q"]["weight"]
-    if isinstance(attn_q_weight, SplitPrimitiveTensor):
-        llama_config.tensor_parallelism_size = attn_q_weight.shard_count
+    tensor_parallelism_size = (
+        dataset.properties["tensor_parallelism_size"]
+        if "tensor_parallelism_size" in dataset.properties
+        else 1
+    )
+    llama_config = LlamaModelConfig(
+        hp,
+        tensor_parallelism_size=tensor_parallelism_size,
+        use_hf=False,
+        static_tables=False,  # Rely on the compiler for hoisting tables.
+        kv_cache_type="direct" if args.bs == [1] else "paged",
+        attention_kernel=args.attention_kernel,
+    )
 
     if llama_config.hp.expert_count:
         if llama_config.hp.model_arch == "grok":
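For orientation (not part of the commit): export_paged_llm_v1.py no longer accepts --tensor-parallelism-size and no longer shards the theta itself; it infers the sharding degree from the dataset it loads. Below is a minimal sketch of that lookup; resolve_tensor_parallelism is a hypothetical helper name used only for illustration.

def resolve_tensor_parallelism(properties: dict) -> int:
    # Datasets written by shard_llm_dataset.py record their shard count under
    # "tensor_parallelism_size"; unsharded datasets omit the key, so default to 1.
    return properties.get("tensor_parallelism_size", 1)


# Unsharded dataset -> 1; a dataset sharded for 8 devices -> 8.
assert resolve_tensor_parallelism({}) == 1
assert resolve_tensor_parallelism({"tensor_parallelism_size": 8}) == 8

The diff spells out the same default with an explicit membership check; dict.get is used here only to keep the sketch short.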

sharktank/sharktank/examples/sharding/shard_llm_dataset.py (+13 -6)
@@ -10,7 +10,8 @@
 weights of an LLM by converting the RHS of all eligible layers to a sharded
 form.
 """
-from ...transforms.dataset import MmtRHSShardingTransform
+from ...models.llama.sharding import shard_theta
+from ...layers import LlamaHParams, LlamaModelConfig
 from ...types import *
 
 
@@ -21,16 +22,22 @@ def main(raw_args=None):
     cli.add_input_dataset_options(parser)
     cli.add_output_dataset_options(parser)
     parser.add_argument(
-        "--num-shards", type=int, required=True, help="Number of shards to split"
+        "--tensor-parallelism-size",
+        type=int,
+        required=True,
+        help="Number of shards to split",
     )
     args = cli.parse(parser, args=raw_args)
     dataset = cli.get_input_dataset(args)
 
-    tr = MmtRHSShardingTransform(
-        r"^blk\.[0-9]+\.(attn_k|attn_q|attn_v|ffn_gate|ffn_up|ffn_down)\.weight$",
-        num_shards=8,
+    hp = LlamaHParams.from_gguf_props(dataset.properties)
+    llama_config = LlamaModelConfig(
+        hp, tensor_parallelism_size=args.tensor_parallelism_size
     )
-    dataset.transform(tr)
+    sharded_theta = shard_theta(dataset.root_theta, llama_config)
+    sharded_theta.rename_tensors_to_paths()
+    dataset.root_theta = sharded_theta
+    dataset.properties["tensor_parallelism_size"] = args.tensor_parallelism_size
     dataset.save(args.output_irpa_file, io_report_callback=print)
 
 
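Usage sketch (not part of the commit): sharding now runs once, up front, and the resulting IRPA file carries its own tensor_parallelism_size. Driving the script programmatically might look like the following; the --irpa-file input flag is an assumption about what cli.add_input_dataset_options provides, since only --output-irpa-file and --tensor-parallelism-size are visible in this diff, and the paths are placeholders.

from sharktank.examples.sharding import shard_llm_dataset

# Step 1: shard the unsharded dataset once and persist it as an IRPA file.
shard_llm_dataset.main(
    [
        "--irpa-file", "/models/llama_unsharded.irpa",   # assumed input flag
        "--output-irpa-file", "/models/llama_tp8.irpa",  # placeholder path
        "--tensor-parallelism-size", "8",
    ]
)

# Step 2: point export_paged_llm_v1.py at llama_tp8.irpa. The export script now
# reads tensor_parallelism_size from the file's properties, so the export step
# needs no sharding flag of its own.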

sharktank/tests/transforms/dataset_transforms_test.py (+23 -13)
@@ -19,18 +19,28 @@
 from sharktank.utils.testing import MainRunnerTestBase
 
 
-class MmtRHSShardingTransformTest(MainRunnerTestBase):
-    def testPrimitive(self):
+class DatasetShardingTransformTest(MainRunnerTestBase):
+    def testShardLlmDataset(self):
         orig_pts = [
             DefaultPrimitiveTensor(
                 name="blk.1.attn_k.weight", data=torch.randn([32, 128])
             ),
             DefaultPrimitiveTensor(
                 name="blk.2.attn_q.weight", data=torch.randn([48, 64])
             ),
-            DefaultPrimitiveTensor(name="other", data=torch.randn([2, 2])),
         ]
-        ds_orig = Dataset({}, Theta(orig_pts))
+        ds_orig = Dataset(
+            {
+                "general.architecture": "llm",
+                "llm.attention.head_count": 1,
+                "llm.context_length": 2,
+                "llm.embedding_length": 3,
+                "llm.block_count": 4,
+                "llm.feed_forward_length": 5,
+                "llm.attention.layer_norm_rms_epsilon": 0.1,
+            },
+            Theta(orig_pts),
+        )
         input_path = self.save_dataset(ds_orig, "input")
         output_path = self.get_irpa_path("output")
         from sharktank.examples.sharding import shard_llm_dataset
@@ -41,38 +51,38 @@ def testPrimitive(self):
             input_path,
             "--output-irpa-file",
             output_path,
-            "--num-shards",
+            "--tensor-parallelism-size",
             8,
         )
         ds_tran = Dataset.load(output_path, mmap=False)
 
+        ds_tran.properties["tensor_parallelism_size"] = 8
+
         # Verify.
         flat_sts = ds_tran.root_theta.flatten()
-        self.assertEqual(3, len(flat_sts))
+        self.assertEqual(2, len(flat_sts))
         st_1 = flat_sts["blk.1.attn_k.weight"]
         st_2 = flat_sts["blk.2.attn_q.weight"]
-        pt_3 = flat_sts["other"]
         self.assertIsInstance(st_1, SplitPrimitiveTensor)
         self.assertIsInstance(st_2, SplitPrimitiveTensor)
-        self.assertIsInstance(pt_3, DefaultPrimitiveTensor)
         self.assertListEqual(st_1.shape, [32, 128])
         self.assertListEqual(st_2.shape, [48, 64])
 
         # Verify component shapes for st_1.
         self.assertEqual(8, len(st_1.shards))
-        self.assertTrue(all(pt.shape == [32, 16] for pt in st_1.shards))
+        self.assertTrue(all(pt.shape == [4, 128] for pt in st_1.shards))
         self.assertTrue(
-            all(list(pt.as_torch().shape) == [32, 16] for pt in st_1.shards)
+            all(list(pt.as_torch().shape) == [4, 128] for pt in st_1.shards)
         )
 
         # Verify component shapes for st_2.
         self.assertEqual(8, len(st_2.shards))
-        self.assertTrue(all(pt.shape == [48, 8] for pt in st_2.shards))
-        self.assertTrue(all(list(pt.as_torch().shape) == [48, 8] for pt in st_2.shards))
+        self.assertTrue(all(pt.shape == [6, 64] for pt in st_2.shards))
+        self.assertTrue(all(list(pt.as_torch().shape) == [6, 64] for pt in st_2.shards))
 
         # Verify contents for one shard for sanity.
         new_t = st_1.shards[0].as_torch()
-        torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(16, dim=1)[0])
+        torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(4, dim=0)[0])
 
 
 if __name__ == "__main__":