@@ -27,7 +27,6 @@ def testAllGather(self):

        sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
        actual_result = ops.all_gather(sharded)
-
        for shard in actual_result.shards:
            torch.testing.assert_close(shard.as_torch(), expected_result)

@@ -770,6 +769,83 @@ def testSameSplitLhsAndRhsBatchDim(self):
        actual_result = unbox_tensor(ops.unshard(sharded_result))
        torch.testing.assert_close(actual_result, expected_result)

+    def testTransposedQuantizedRHSSharded_BlockScaledOffsetI4(self):
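+        # Sharding the block-scaled offset-i4 quantized RHS along dim 0 should
+        # reproduce the unsharded transposed matmul result.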
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
+        d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
+        qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
+        m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200],
+            layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
+        )
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+
+    def testTorchImplTransposedQuantizedRHSSharded_BlockScaledLayout(self):
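+        # Same check for the plain block-scaled layout: the sharded matmul
+        # must match the unsharded reference.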
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
+        d = torch.rand([3200, 100, 1], dtype=d_dtype) * 64
+        qs = (torch.rand([3200, 100, 32], dtype=ref_dtype) * 32.0).to(torch.int8)
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200], layout=BlockScaledLayout([3200, 3200], d, qs)
+        )
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+
+    def testTorchImplTransposedQuantizedRHSSharded_TensorScaledLayout(self):
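+        # Same check for a tensor-scaled layout with a scalar scale and offset.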
+        ops._registry._test_enable_last_op_dispatch(True)
+        a_dtype = torch.float32
+        d_dtype = torch.float32
+        ref_dtype = torch.float32
+        a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
+        d = torch.tensor(5.1, dtype=d_dtype)  # torch.rand([3200], dtype=d_dtype)
+        qs = (torch.rand([3200, 3200], dtype=ref_dtype) * 32.0).to(torch.int8)
+        m = torch.tensor(
+            16.0, dtype=d_dtype
+        )  # torch.rand([3200], dtype=d_dtype) + 16.0
+        rhs_pqt = PlanarQuantizedTensor(
+            shape=[3200, 3200],
+            layout=TensorScaledLayout(shape=[3200, 3200], d=d, qs=qs, m=m),
+        )
+        print("a.shape: ", a.shape)
+        print("rhs_pqt.shape: ", rhs_pqt.shape)
+        expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
+
+        shard_count = 2
+        rhs_pqt_sharded = SplitPrimitiveTensor(
+            shard_dim=0, ts=rhs_pqt, shard_count=shard_count
+        )
+
+        sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
+        actual_result = ops.sharded_cat(sharded_result)
+
+        torch.testing.assert_close(actual_result, expected_result)
+

class ReplicateTest(unittest.TestCase):
    def testReplicateReplicated(self):