
Lint Fix
LeiWang1999 committed Jul 5, 2024
1 parent ef7e515 commit 4dc2c10
Showing 1 changed file with 73 additions and 52 deletions.
125 changes: 73 additions & 52 deletions testing/python/operators/test_general_matmul_ops.py
@@ -12,9 +12,10 @@ def get_codegen_result(ops):
    code = ops.get_source()
    return code


# fmt: off
def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
                           group_size, with_scaling, with_zeros, zeros_mode):

    matmul_config = MatmulConfig(
        M=M,
@@ -34,23 +35,36 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
    matmul = Matmul(config=matmul_config, enable_tuning=False)
    assert get_codegen_result(matmul)


def test_matmul_codegen_default():
    matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1,
                           False, False, None)
    matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False,
                           -1, False, False, None)
    matmul_codegen_default(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False,
                           False, None)
    matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False,
                           False, None)
    matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           False, False, None)
    matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1,
                           False, False, None)
    matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           True, False, None)
    matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           True, True, "original")
    matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           False, False, None)
    matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1,
                           False, False, None)
    matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           True, False, None)
    matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                           True, True, "original")


def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
                    group_size, with_scaling, with_zeros, zeros_mode):

    matmul_config = MatmulConfig(
        M=M,
@@ -73,30 +87,34 @@ def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,


def test_matmul_finetune():
    matmul_finetune(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False,
                    False, None)
    matmul_finetune(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1,
                    False, False, None)
    matmul_finetune(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False,
                    None)
    matmul_finetune(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False,
                    None)
    matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False,
                    False, None)
    matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False,
                    False, None)
    matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True,
                    False, None)
    matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True,
                    True, "original")
    matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False,
                    False, None)
    matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False,
                    False, None)
    matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True,
                    False, None)
    matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True,
                    True, "original")


def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
                         group_size, with_scaling, with_zeros, zeros_mode):
    import torch
    torch.random.manual_seed(0)
    import numpy as np
@@ -185,22 +203,24 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,


def test_matmul_torch_forward():
    matmul_torch_forward(1, 1024, 1024, "float16", "int4", "float16", "float16", "nt", None, None,
                         None, None, None)
    matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         False, False, None)
    matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1,
                         False, False, None)
    matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         True, False, None)
    matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         True, True, "original")
    matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         False, False, None)
    matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1,
                         False, False, None)
    matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         True, False, None)
    matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1,
                         True, True, "original")


def matmul_transform_weight(
@@ -259,6 +279,7 @@ def test_matmul_transform_weight():
    matmul_transform_weight(768, 768, 768, "float16", "uint4", "float16", "float16", False)
    matmul_transform_weight(768, 768, 768, "float16", "int4", "float16", "float16", False)


# fmt: on
if __name__ == "__main__":
    bitblas.testing.main()
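
For readers who want to exercise a single configuration outside the test harness, a minimal standalone sketch follows. It mirrors the helpers above: `enable_tuning=False` and `get_source()` appear verbatim in this file, while the `MatmulConfig` keyword names beyond `M`, `N`, and `K` are inferred from the helper signatures (the constructor body is collapsed in this diff), so treat them as assumptions rather than the definitive API.

# Minimal sketch, assuming the MatmulConfig/Matmul API exercised by the
# helpers above; keyword names beyond M/N/K are inferred, not confirmed.
from bitblas import Matmul, MatmulConfig

config = MatmulConfig(
    M=1,
    N=768,
    K=768,
    A_dtype="float16",   # activation dtype, as in the fp16 x uint4 cases
    W_dtype="uint4",     # quantized weight dtype
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",         # the layout used throughout these tests
    with_bias=False,
)
matmul = Matmul(config=config, enable_tuning=False)  # skip tuning, as the tests do
print(matmul.get_source())  # dump the generated kernel source for inspection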
