
Commit 2057fae

basic sharding support for quant tensors
1 parent 6f3f8c7 commit 2057fae

7 files changed: +233 −46
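
What the commit enables, as a minimal sketch that mirrors the new test added to sharktank/tests/ops/sharded_test.py further down this page. The import paths (sharktank.types, sharktank.types.layouts) are assumptions for illustration, not part of this diff: a PlanarQuantizedTensor can now be passed straight to SplitPrimitiveTensor, which splits it through the new ops.split override and runs the matmul shard-by-shard.

import torch

from sharktank import ops
from sharktank.types import PlanarQuantizedTensor, SplitPrimitiveTensor
from sharktank.types.layouts import BlockScaledI4Layout

# Build a block-scaled i4 quantized RHS, as in the moved test below.
a = torch.rand([4, 16, 3200], dtype=torch.float32) / 256.0
d = torch.rand([3200, 100, 1], dtype=torch.float32) / 256.0
qs = (torch.rand([3200, 100, 16]) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=torch.float32) + 16.0
rhs_pqt = PlanarQuantizedTensor(
    shape=[3200, 3200],
    layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)

# New in this commit: a quantized tensor can be the `ts=` source of a
# SplitPrimitiveTensor; each shard is itself a PlanarQuantizedTensor.
rhs_sharded = SplitPrimitiveTensor(shard_dim=0, ts=rhs_pqt, shard_count=2)

sharded_result = ops.matmul(a, rhs_sharded, transpose_rhs=True)
result = ops.sharded_cat(sharded_result)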

sharktank/sharktank/ops/default_impls.py

+2 −2

@@ -438,8 +438,8 @@ def to_default(tensor: Tensor, *args, **kwargs):
     return unbox_tensor(tensor).to(*args, **kwargs)
 
 
-@transfer_to_logical_device.override(Tensor)
-def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
+@transfer_to_logical_device.override(AllOfType(AnyTensor, QuantizedTensor))
+def transfer_to_logical_device_default(tensor, ordinal: int):
     return iree.turbine.ops.iree.transfer_to_logical_device(
         f"{ordinal}", unbox_tensor(tensor)
     )

sharktank/sharktank/ops/sharded_impls.py

+83

@@ -15,6 +15,8 @@
     AnyTensor,
     DefaultPrimitiveTensor,
     InferenceTensor,
+    QuantizedTensor,
+    PlanarQuantizedTensor,
     PrimitiveTensor,
     ReplicatedTensor,
     ShardedTensor,
@@ -28,6 +30,8 @@
 from .signatures import *
 from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
 from ..utils import longest_equal_range
+from ..utils.math import ceildiv
+from sharktank.types.tensors import REGISTERED_LAYOUT_CLASSES
 
 
 @all_gather.override(SplitPrimitiveTensor)
@@ -1264,3 +1268,82 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive
     res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
     assert math.prod(res.shape) == math.prod(tensor.shape)
     return res
+
+
+@split.override(QuantizedTensor)
+def split_QuantizedTensor(tensor: QuantizedTensor, split_size_or_sections, dim):
+    tensors = []
+    unpacked = tensor.unpack()
+    num_outputs = ceildiv(unpacked._qs.shape[dim], split_size_or_sections)
+    new_shape = unpacked._shape
+    new_shape[dim] = split_size_or_sections
+    new_qs = torch.split(unpacked._qs, split_size_or_sections, dim)
+    if unpacked._d.ndim > 0:
+        new_d = torch.split(unpacked._d, split_size_or_sections, dim)
+    if unpacked.serialized_name() == "SuperBlockOffsetScaled_4_6_Layout":
+        new_dmin = torch.split(unpacked._dmin, split_size_or_sections, dim)
+        new_sb_scales_high = torch.split(
+            unpacked._sb_scales_high, split_size_or_sections, dim
+        )
+        new_sb_scales_low = torch.split(
+            unpacked._sb_scales_low, split_size_or_sections, dim
+        )
+        new_sb_mins_high = torch.split(
+            unpacked._sb_mins_high, split_size_or_sections, dim
+        )
+        new_sb_mins_low = torch.split(
+            unpacked._sb_mins_low, split_size_or_sections, dim
+        )
+        for i in range(num_outputs):
+            layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
+            new_layout = layout_clazz(
+                shape=new_shape,
+                d=new_d[i],
+                dmin=new_dmin[i],
+                sb_scales_high=new_sb_scales_high[i],
+                sb_scales_low=new_sb_scales_low[i],
+                sb_mins_high=new_sb_mins_high[i],
+                sb_mins_low=new_sb_mins_low[i],
+                qs=new_qs[i],
+            )
+            new_tensor = tensor.__class__
+            new_tensor_layout = new_layout.create(
+                new_layout.shape, new_layout.metadata, new_layout.planes
+            )
+            new_tensor = tensor.__class__(
+                shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
+            )
+            tensors.append(new_tensor)
+    else:
+        if split_size_or_sections > unpacked._qs.shape[dim]:
+            raise ValueError("split size greater than tensor dim")
+
+        if unpacked._m is not None:
+            if unpacked._m.ndim > 0:
+                new_m = torch.split(unpacked._m, split_size_or_sections, dim)
+        for i in range(num_outputs):
+            layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
+            if unpacked._m is not None:
+                if unpacked._d.ndim > 0:
+                    new_layout = layout_clazz(
+                        shape=new_shape, d=new_d[i], qs=new_qs[i], m=new_m[i]
+                    )
+                else:
+                    new_layout = layout_clazz(
+                        shape=new_shape, d=unpacked._d, qs=new_qs[i], m=unpacked._m
+                    )
+            else:
+                if unpacked._d.ndim > 0:
+                    new_layout = layout_clazz(shape=new_shape, d=new_d[i], qs=new_qs[i])
+                else:
+                    new_layout = layout_clazz(
+                        shape=new_shape, d=unpacked._d, qs=new_qs[i]
+                    )
+            new_tensor_layout = new_layout.create(
+                new_layout.shape, new_layout.metadata, new_layout.planes
+            )
+            new_tensor = tensor.__class__(
+                shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
+            )
+            tensors.append(new_tensor)
+    return tensors
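
A minimal sketch of calling the new override directly, assuming the same import paths as the sketch near the top of the page and the BlockScaledI4Layout constructor used in the tests: each returned tensor wraps a new layout whose planes (d, qs, m) were cut with torch.split along dim.

import torch

from sharktank import ops
from sharktank.types import PlanarQuantizedTensor
from sharktank.types.layouts import BlockScaledI4Layout

d = torch.rand([3200, 100, 1], dtype=torch.float32)
qs = (torch.rand([3200, 100, 16]) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=torch.float32) + 16.0
pqt = PlanarQuantizedTensor(
    shape=[3200, 3200],
    layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)

# Dispatches to split_QuantizedTensor above: the d/qs/m planes are each split
# into chunks of 1600 along dim 0 and rewrapped as new PlanarQuantizedTensors.
halves = ops.split(pqt, 1600, 0)
assert len(halves) == 2  # each half reports shape [1600, 3200]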

sharktank/sharktank/ops/signatures.py

+40 −3

@@ -11,7 +11,15 @@
 import torch
 import numbers
 from torch import Tensor, dtype
-from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor
+from ..types import (
+    AnyTensor,
+    ShardedTensor,
+    Theta,
+    sharding,
+    InferenceTensor,
+    QuantizedTensor,
+    PlanarQuantizedTensor,
+)
 from numbers import Number
 
 from ._registry import *
@@ -59,6 +67,7 @@
     "unshard",
     "unsqueeze",
     "view",
+    "split",
 ]
 
 IntOrSequenceInt = Union[int, Sequence[int]]
@@ -976,14 +985,18 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
 
 
 @overridable
-def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
+def transfer_to_logical_device(
+    tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], ordinal: int
+) -> Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor]:
     """Transfer the tensor to a device with ordinal `ordinal`."""
     ...
 
 
 @transfer_to_logical_device.trampoline
 def _transfer_to_logical_device_trampoline(
-    d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
+    d: SignatureDispatcher,
+    tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor],
+    ordinal: int,
 ):
     tensors = (tensor,)
     for override in d.find_overrides(tensors):
@@ -1085,3 +1098,27 @@ def _view_trampoline(
             return override, result
     else:
         d.fail(tensors)
+
+
+@overridable
+def split(
+    tensor: QuantizedTensor, split_size_or_sections: List[int], dim: int
+) -> [QuantizedTensor]:
+    """See torch.Tensor.split"""
+    ...
+
+
+@split.trampoline
+def _split_trampoline(
+    d: SignatureDispatcher,
+    tensor: QuantizedTensor,
+    split_size_or_sections: List[int],
+    dim: int,
+) -> [QuantizedTensor]:
+    tensors = (tensor,)
+    for override in d.find_overrides(tensors):
+        result = override(tensor, split_size_or_sections, dim)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
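
The new split signature uses the same overridable/trampoline machinery as the other ops: the trampoline packs the tensor argument, find_overrides selects an implementation by type, and a NotImplemented result falls through to the next candidate. A hypothetical sketch, not part of this commit, of registering a second override purely to illustrate that dispatch pattern; the ReplicatedTensor behavior shown here is invented for the example.

import torch

from sharktank import ops
from sharktank.types import ReplicatedTensor


@ops.split.override(ReplicatedTensor)
def split_replicated(tensor: ReplicatedTensor, split_size_or_sections, dim):
    # Hypothetical: split every replica with torch.split and regroup the
    # pieces so each output is again a ReplicatedTensor.
    per_shard = [
        torch.split(shard.as_torch(), split_size_or_sections, dim)
        for shard in tensor.shards
    ]
    return [ReplicatedTensor(ts=list(pieces)) for pieces in zip(*per_shard)]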

sharktank/sharktank/types/sharding.py

+11 −9

@@ -60,15 +60,17 @@ def __init__(self, *args, **kwargs):
         for k, v in d.items():
             d[k] = tree.map_nodes(
                 tree=v,
-                f=lambda x: x
-                if isinstance(
-                    x,
-                    (
-                        TensorSharding,
-                        ThetaSharding,
-                    ),
-                )
-                else ThetaSharding(x),
+                f=lambda x: (
+                    x
+                    if isinstance(
+                        x,
+                        (
+                            TensorSharding,
+                            ThetaSharding,
+                        ),
+                    )
+                    else ThetaSharding(x)
+                ),
             )
         super().__init__(d)
 
sharktank/sharktank/types/tensors.py

+20 −11

@@ -581,6 +581,11 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
         """
         return self.to_planar().add_to_archive(builder)
 
+    def split(self, split_size_or_sections: [int], dim: int) -> "[QuantizedTensor]":
+        from ..ops import split
+
+        return split(self, split_size_or_sections, dim)
+
 
 @register_inference_tensor
 class PlanarQuantizedTensor(QuantizedTensor):
@@ -764,12 +769,14 @@ def __init__(
         assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim)
         super().__init__(name=name, shape=shape, shard_dim=shard_dim)
         self._shards: tuple[DefaultPrimitiveTensor] = tuple(
-            DefaultPrimitiveTensor(
-                name=f"{name}.shard.{i}",
-                data=t,
+            (
+                DefaultPrimitiveTensor(
+                    name=f"{name}.shard.{i}",
+                    data=t,
+                )
+                if isinstance(t, torch.Tensor)
+                else t
             )
-            if isinstance(t, torch.Tensor)
-            else t
             for i, t in enumerate(ts)
         )
 
@@ -941,7 +948,7 @@ def __init__(
             will be split along dimension `shard_dim` into `shard_count`
            number of pieces.
         """
-        if isinstance(ts, torch.Tensor):
+        if isinstance(ts, torch.Tensor) or isinstance(ts, InferenceTensor):
             from ..ops import transfer_to_logical_device
 
             assert shard_count is not None
@@ -1082,12 +1089,14 @@ def __init__(
             assert shape == list(shard.shape)
 
         self._shards: tuple[DefaultPrimitiveTensor] = tuple(
-            DefaultPrimitiveTensor(
-                name=f"{name}.shard.{i}",
-                data=t,
+            (
+                DefaultPrimitiveTensor(
+                    name=f"{name}.shard.{i}",
+                    data=t,
+                )
+                if isinstance(t, torch.Tensor)
+                else t
             )
-            if isinstance(t, torch.Tensor)
-            else t
             for i, t in enumerate(ts)
         )
 
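Two behavior changes land here: QuantizedTensor gains a split method that simply forwards to ops.split, and SplitPrimitiveTensor.__init__ now accepts an InferenceTensor for ts, splitting it into ceildiv(shape[shard_dim], shard_count)-sized pieces and transferring each shard to its logical device. A minimal sketch, reusing the rhs_pqt built in the first sketch on this page:

# `rhs_pqt` is the [3200, 3200] BlockScaledI4Layout tensor from the first sketch.
pieces = rhs_pqt.split(1600, 0)  # forwards to ops.split -> two PlanarQuantizedTensors
sharded = SplitPrimitiveTensor(shard_dim=0, ts=rhs_pqt, shard_count=2)
assert len(pieces) == len(sharded.shards) == 2
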
sharktank/tests/ops/ops_test.py

−20

@@ -194,26 +194,6 @@ def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
             ops.custom_impls.matmul_generic_tensor_block_scaled,
         )
 
-    def testTorchImplTransposedQuantizedRHS_BlockScaledOffsetI4(self):
-        ops._registry._test_enable_last_op_dispatch(True)
-        a_dtype = torch.float32
-        d_dtype = torch.float32
-        ref_dtype = torch.float32
-        a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
-        d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
-        qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
-        m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
-        rhs_pqt = PlanarQuantizedTensor(
-            shape=[3200, 3200],
-            layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
-        )
-        result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
-        # Just verifying dispatch. Numerics are tested at the kernel level.
-        self.assertIs(
-            ops._registry._test_get_last_op_dispatch(),
-            ops.custom_impls.matmul_generic_tensor_block_scaled_i4,
-        )
-
     # TODO: mmt_super_block_scaled_offset_q4_unsigned
 
 
sharktank/tests/ops/sharded_test.py

+77 −1

@@ -27,7 +27,6 @@ def testAllGather(self):
 
         sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
         actual_result = ops.all_gather(sharded)
-
        for shard in actual_result.shards:
             torch.testing.assert_close(shard.as_torch(), expected_result)
 
@@ -770,6 +769,83 @@ def testSameSplitLhsAndRhsBatchDim(self):
         actual_result = unbox_tensor(ops.unshard(sharded_result))
         torch.testing.assert_close(actual_result, expected_result)
 
+    def testTranposedQuantizedRHSSharded_BlockScaledOffsetI4(self):
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
+        d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
+        qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
+        m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200],
+            layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
+        )
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+
+    def testTorchImplTransposedQuantizedRHSSharded_BlockScaledLayout(self):
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
+        d = torch.rand([3200, 100, 1], dtype=d_dtype) * 64
+        qs = (torch.rand([3200, 100, 32], dtype=ref_dtype) * 32.0).to(torch.int8)
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200], layout=BlockScaledLayout([3200, 3200], d, qs)
+        )
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+
+    def testTorchImplTransposedQuantizedRHSSharded_TensorScaledLayout(self):
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
+        d = torch.tensor(5.1, dtype=d_dtype)  # torch.rand([3200], dtype=d_dtype)
+        qs = (torch.rand([3200, 3200], dtype=ref_dtype) * 32.0).to(torch.int8)
+        m = torch.tensor(
+            16.0, dtype=d_dtype
+        )  # torch.rand([3200], dtype=d_dtype) + 16.0
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200],
+            layout=TensorScaledLayout(shape=[3200, 3200], d=d, qs=qs, m=m),
+        )
+        print("a shape:, ", a.shape)
+        print("rhs_pqt.shape: ", rhs_pqt.shape)
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+
 
 class ReplicateTest(unittest.TestCase):
     def testReplicateReplicated(self):
