19
19
from sharktank .utils .testing import MainRunnerTestBase
20
20
21
21
22
- class MmtRHSShardingTransformTest (MainRunnerTestBase ):
23
- def testPrimitive (self ):
22
+ class DatasetShardingTransformTest (MainRunnerTestBase ):
23
+ def testShardLlmDataset (self ):
24
24
orig_pts = [
25
25
DefaultPrimitiveTensor (
26
26
name = "blk.1.attn_k.weight" , data = torch .randn ([32 , 128 ])
27
27
),
28
28
DefaultPrimitiveTensor (
29
29
name = "blk.2.attn_q.weight" , data = torch .randn ([48 , 64 ])
30
30
),
31
- DefaultPrimitiveTensor (name = "other" , data = torch .randn ([2 , 2 ])),
32
31
]
33
- ds_orig = Dataset ({}, Theta (orig_pts ))
32
+ ds_orig = Dataset (
33
+ {
34
+ "general.architecture" : "llm" ,
35
+ "llm.attention.head_count" : 1 ,
36
+ "llm.context_length" : 2 ,
37
+ "llm.embedding_length" : 3 ,
38
+ "llm.block_count" : 4 ,
39
+ "llm.feed_forward_length" : 5 ,
40
+ "llm.attention.layer_norm_rms_epsilon" : 0.1 ,
41
+ },
42
+ Theta (orig_pts ),
43
+ )
34
44
input_path = self .save_dataset (ds_orig , "input" )
35
45
output_path = self .get_irpa_path ("output" )
36
46
from sharktank .examples .sharding import shard_llm_dataset
@@ -41,38 +51,38 @@ def testPrimitive(self):
41
51
input_path ,
42
52
"--output-irpa-file" ,
43
53
output_path ,
44
- "--num-shards " ,
54
+ "--tensor-parallelism-size " ,
45
55
8 ,
46
56
)
47
57
ds_tran = Dataset .load (output_path , mmap = False )
48
58
59
+ ds_tran .properties ["tensor_parallelism_size" ] = 8
60
+
49
61
# Verify.
50
62
flat_sts = ds_tran .root_theta .flatten ()
51
- self .assertEqual (3 , len (flat_sts ))
63
+ self .assertEqual (2 , len (flat_sts ))
52
64
st_1 = flat_sts ["blk.1.attn_k.weight" ]
53
65
st_2 = flat_sts ["blk.2.attn_q.weight" ]
54
- pt_3 = flat_sts ["other" ]
55
66
self .assertIsInstance (st_1 , SplitPrimitiveTensor )
56
67
self .assertIsInstance (st_2 , SplitPrimitiveTensor )
57
- self .assertIsInstance (pt_3 , DefaultPrimitiveTensor )
58
68
self .assertListEqual (st_1 .shape , [32 , 128 ])
59
69
self .assertListEqual (st_2 .shape , [48 , 64 ])
60
70
61
71
# Verify component shapes for st_1.
62
72
self .assertEqual (8 , len (st_1 .shards ))
63
- self .assertTrue (all (pt .shape == [32 , 16 ] for pt in st_1 .shards ))
73
+ self .assertTrue (all (pt .shape == [4 , 128 ] for pt in st_1 .shards ))
64
74
self .assertTrue (
65
- all (list (pt .as_torch ().shape ) == [32 , 16 ] for pt in st_1 .shards )
75
+ all (list (pt .as_torch ().shape ) == [4 , 128 ] for pt in st_1 .shards )
66
76
)
67
77
68
78
# Verify component shapes for st_2.
69
79
self .assertEqual (8 , len (st_2 .shards ))
70
- self .assertTrue (all (pt .shape == [48 , 8 ] for pt in st_2 .shards ))
71
- self .assertTrue (all (list (pt .as_torch ().shape ) == [48 , 8 ] for pt in st_2 .shards ))
80
+ self .assertTrue (all (pt .shape == [6 , 64 ] for pt in st_2 .shards ))
81
+ self .assertTrue (all (list (pt .as_torch ().shape ) == [6 , 64 ] for pt in st_2 .shards ))
72
82
73
83
# Verify contents for one shard for sanity.
74
84
new_t = st_1 .shards [0 ].as_torch ()
75
- torch .testing .assert_close (new_t , orig_pts [0 ].as_torch ().split (16 , dim = 1 )[0 ])
85
+ torch .testing .assert_close (new_t , orig_pts [0 ].as_torch ().split (4 , dim = 0 )[0 ])
76
86
77
87
78
88
if __name__ == "__main__" :
0 commit comments