diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index dc4bb587b..08ed8f7f0 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -3,7 +3,7 @@ from bitblas import tvm from tvm.tir.function import TensorIntrin from tvm.script import tir as T -from typing import Dict, Literal +from typing import Dict, Literal, List from bitblas.quantization import ( _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, @@ -769,6 +769,7 @@ def get_fast_decode_intrin( with_scale=False, with_zeros=False, zeros_mode="original", + storage_scope="local", ): """ loops extent is the number of elements to be decoded in one stage @@ -814,7 +815,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -822,7 +823,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) with T.block("root"): @@ -846,7 +847,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, + offset_factor=n_storage_elems, ) Decompressed = T.match_buffer( decompressed, @@ -854,7 +856,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, + offset_factor=loops_extent, ) with T.block("root"): @@ -863,8 +866,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: T.call_extern( "handle", func_name, - Compressed.data, - Decompressed.data, + Compressed.access_ptr("r"), + Decompressed.access_ptr("w"), loops_extent, ) @@ -878,7 +881,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.hand n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -886,7 +889,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.hand loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -920,7 +923,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.hand n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -928,7 +931,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.hand loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -988,7 +991,7 @@ def fast_decode_desc( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -996,7 +999,7 @@ def fast_decode_desc( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1004,7 +1007,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Zeros = T.match_buffer( zeros, @@ -1012,7 +1015,7 @@ def fast_decode_desc( 1, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) with T.block("root"): T.reads(*get_dequantize_buffers_list( @@ -1053,7 +1056,7 @@ def fast_decode_impl( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1061,7 +1064,7 @@ def fast_decode_impl( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1071,7 +1074,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s0], - scope="local", + scope=storage_scope, ) Zeros = T.match_buffer( zeros, @@ -1081,7 +1084,7 @@ def fast_decode_impl( dtype=storage_dtype, offset_factor=1, strides=[s1], - scope="local", + scope=storage_scope, ) with T.block("root"): T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) @@ -1128,7 +1131,7 @@ def fast_decode_desc( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1136,7 +1139,7 @@ def fast_decode_desc( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1192,7 +1195,7 @@ def fast_decode_impl( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1200,7 +1203,7 @@ def fast_decode_impl( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1238,353 +1241,83 @@ def fast_decode_impl( return fast_decode_desc, fast_decode_impl -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_uint32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="uint32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) +# Define the intrin definitions +intrin_definitions = [ + # (source_bit, storage_dtype, target_dtype, loops_extent, storage_scope, source_format, with_scale, with_zeros, zeros_mode) + (4, "int8", "float16", 8, "local", "uint", False, False, "original"), + (4, "int8", "float16", 8, "warp", "uint", False, False, "original"), + (2, "int8", "float16", 8, "local", "uint", False, False, "original"), + (1, "int8", "float16", 8, "local", "uint", False, False, "original"), + (4, "int32", "float16", 8, "local", "uint", False, False, "original"), + (4, "int32", "float16", 8, "local", "uint", True, False, "original"), + (4, "uint32", "float16", 8, "local", "uint", False, False, "original"), + (4, "uint32", "float16", 8, "local", "uint", True, False, "original"), + (4, "int8", "float16", 8, "local", "uint", True, False, "original"), + (4, "int8", "float16", 8, "local", "uint", True, True, "original"), + (4, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (4, "int8", "float16", 8, "local", "uint", True, True, "quantized"), + (2, "int8", "float16", 8, "local", "uint", True, False, "original"), + (2, "int8", "float16", 8, "local", "uint", True, True, "original"), + (2, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (2, "int8", "float16", 8, "local", "uint", True, True, "quantized"), + (1, "int8", "float16", 8, "local", "uint", True, False, "original"), + (1, "int8", "float16", 8, "local", "uint", True, True, "original"), + (1, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (4, "int8", "int8", 8, "local", "uint", False, False, "original"), + (4, "int8", "int8", 16, "local", "uint", False, False, "original"), + (2, "int8", "int8", 16, "local", "uint", False, False, "original"), + (2, "int8", "int8", 16, "local", "int", False, False, "original"), + (1, "int8", "int8", 16, "local", "uint", False, False, "original"), + (1, "int8", "int8", 16, "local", "int", False, False, "original"), + (4, "int8", "float16", 8, "local", "int", False, False, "original"), + (4, "int8", "float16", 8, "local", "int", True, False, "original"), + (2, "int8", "float16", 8, "local", "int", False, False, "original"), + (2, "int8", "float16", 8, "local", "int", True, False, "original"), + (1, "int8", "float16", 8, "local", "int", False, False, "original"), +] + + +# Register the intrin +def initialize_tensor_intrin(): + registered_intrins: List[str] = [] + for params in intrin_definitions: + # Repack from the params + source_bit, storage_dtype, target_dtype, loops_extent, storage_scope, source_format, with_scale, with_zeros, zeros_mode = params + + # Construct the name + name_parts = [ + "lop3_fast_decode", f"{source_format[0]}{source_bit}", f"to_{storage_dtype}", + f"to_{target_dtype}", f"l{loops_extent}" + ] + if with_scale: + name_parts.append("scale") + if with_zeros: + name_parts.extend(["zeros", zeros_mode]) + if storage_scope == "warp": + name_parts.append("warp") + + name = "_".join(part for part in name_parts if part) + "_" + + # Get intrin desc and implementation + intrin = get_fast_decode_intrin( + source_bit=source_bit, + storage_dtype=storage_dtype, + source_format=source_format, + target_dtype=target_dtype, + loops_extent=loops_extent, + with_scale=with_scale, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + storage_scope=storage_scope) + + # Register the intrin + TensorIntrin.register(name, *intrin) + registered_intrins.append(name) + + return registered_intrins + + +registered_intrins = initialize_tensor_intrin() def get_lop3_intrin_group( @@ -1595,6 +1328,7 @@ def get_lop3_intrin_group( with_scaling: bool = False, with_zeros: bool = False, zeros_mode: Literal["original", "rescale", "quantized"] = "original", + storage_scope: str = "local", ) -> Dict[str, str]: """ This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. @@ -1615,6 +1349,15 @@ def get_lop3_intrin_group( with_scale : bool, optional A boolean parameter that indicates whether scaling should be applied. By default, it is False. + with_zeros : bool, optional + A boolean parameter that indicates whether zeros should be used. By default, it is False. + + zeros_mode : Literal["original", "rescale", "quantized"], optional + The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original". + + storage_scope : Literal["local", "warp"], optional + The scope of the storage. It can be either "local" or "warp". By default, it is "local". + Returns ------- Dict[str, str] @@ -1630,11 +1373,13 @@ def get_lop3_intrin_group( raise ValueError("Invalid source_format. Expected 'int' or 'uint'.") source_symbol = "i" if source_format == "int" else "u" - _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{target_dtype}_l{loop_extent}_" + _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{out_dtype}_l{loop_extent}_" if with_scaling: _intrin += "scale_" if with_zeros: _intrin += f"zeros_{zeros_mode}_" + if storage_scope == "warp": + _intrin += "warp_" import_c_map = { "i4_to_f16": decode_i4_to_f16,