Commit
[CI] Edit the notify setting in our CI (#76)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci
LeiWang1999 committed Jul 5, 2024
1 parent 5391d12 commit 60caba6
Showing 4 changed files with 152 additions and 72 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/ci.yml
@@ -64,4 +64,13 @@ jobs:
         run: |
           source bitblas_ci/bin/activate
           cd testing/python
-          python -m pytest
+          python -m pytest
+
+  # Control notifications
+  notify:
+    runs-on: self-hosted
+    needs: [format-check, build-test]
+    if: failure()
+    steps:
+      - name: Notification
+        run: echo "Jobs failed, but no email will be sent."
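Note on the mechanism: because the new notify job declares `needs: [format-check, build-test]` and guards itself with `if: failure()`, it runs only when at least one of those two jobs has failed, and its single step merely echoes a message rather than invoking any mail-sending action — which is how this commit disables the failure email for CI.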
2 changes: 1 addition & 1 deletion bitblas/ops/impl/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .lop3_permutate_impl import tir_interleave_weight
+from .lop3_permutate_impl import tir_interleave_weight  # noqa: F401
20 changes: 20 additions & 0 deletions bitblas/ops/impl/base.py
@@ -0,0 +1,20 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from abc import ABC, abstractmethod
+
+
+# TODO: Refactor all the tir script implementations to use this base class
+# Abstract base class for TIR script emitters
+class TIRScriptEmitter(ABC):
+
+    @abstractmethod
+    def emit(self):
+        raise NotImplementedError
+
+
+# Abstract base class for TIR script selectors
+class TIRScriptSelector(ABC):
+
+    @abstractmethod
+    def select(self):
+        raise NotImplementedError
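To illustrate the contract these base classes impose, here is a minimal sketch of a concrete emitter. The `IdentityEmitter` name and its trivial copy kernel are hypothetical, not part of this commit; only the `TIRScriptEmitter` interface and the TVM calls mirror the code in this diff.

# Hypothetical example (not part of this commit): a concrete emitter that
# satisfies the TIRScriptEmitter interface with a trivial elementwise copy.
from bitblas import tvm
from tvm import te
from bitblas.ops.impl.base import TIRScriptEmitter


class IdentityEmitter(TIRScriptEmitter):

    def __init__(self, n=1024):
        self.n = n  # static output length for this toy kernel

    def emit(self):
        # Same emit() contract as BatchMatMulEmitter below: build TE
        # expressions, lower to a PrimFunc, and wrap it in an IRModule.
        A = te.placeholder((self.n,), name="A", dtype="float16")
        B = te.compute((self.n,), lambda i: A[i], name="B")
        func = te.create_prim_func([A, B])
        return tvm.IRModule.from_expr(func)


mod = IdentityEmitter(4096).emit()  # returns a tvm.IRModule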
191 changes: 121 additions & 70 deletions bitblas/ops/impl/batch_matmul_impl.py
@@ -4,62 +4,127 @@
 from bitblas import tvm
 from tvm import te
 from bitblas.ops.operator import TransformKind
+from .base import TIRScriptEmitter, TIRScriptSelector


-def matmul_nt(
-    Batch,
-    M,
-    N,
-    K,
-    in_dtype="float16",
-    out_dtype="float16",
-    accum_dtype="float16",
-    with_bias=False,
-):
-    if not isinstance(M, int):
-        M = tvm.te.var("m")
-    A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype)
-    B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype)
-    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)
-
-    # Describe the matrix multiplication in TE
-    k = te.reduce_axis((0, K), name="k")
-    C = te.compute(
-        (Batch, M, N),
-        lambda b, i, j: te.sum(
-            A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k),
-        name="C",
-    )
-    last_output = C
-    if accum_dtype != out_dtype:
-        D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D")
-        last_output = D
-
-    if with_bias:
-        E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E")
-        last_output = E
-
-    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]
-
-    func = te.create_prim_func(args)
-
-    return tvm.IRModule.from_expr(func)
-
-
-def matmul(
-    Batch,
-    M,
-    N,
-    K,
-    in_dtype="float16",
-    out_dtype="float16",
-    accum_dtype="float16",
-    with_bias=False,
-    layout="nt",
-):
-    if layout == "nn":
-        raise ValueError("Currently only support layout=nt")
-    return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias)
+class BatchMatMulEmitter(TIRScriptEmitter):
+
+    def __init__(
+        self,
+        batch,
+        M,
+        N,
+        K,
+        in_dtype="float16",
+        out_dtype="float16",
+        accum_dtype="float16",
+        with_bias=False,
+        layout="nt",
+    ):
+        self.batch = batch
+        self.M = self._validate_dimension(M, "M")
+        self.N = self._validate_dimension(N, "N")
+        self.K = self._validate_dimension(K, "K")
+        self.in_dtype = in_dtype
+        self.out_dtype = out_dtype
+        self.accum_dtype = accum_dtype
+        self.with_bias = with_bias
+        self.layout = layout
+        self._validate_layout()
+
+    @staticmethod
+    def _validate_dimension(dim, name):
+        if not isinstance(dim, int):
+            return tvm.te.var(name.lower())
+        return dim
+
+    def _validate_layout(self):
+        if self.layout not in ["nn", "nt"]:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        if self.layout == "nn":
+            raise ValueError("Currently only support layout=nt")
+
+    def _create_placeholders(self):
+        A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype)
+        B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype)
+        Bias = te.placeholder(
+            (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None
+        return A, B, Bias
+
+    def _compute_matmul(self, A, B):
+        k = te.reduce_axis((0, self.K), name="k")
+        C = te.compute(
+            (self.batch, self.M, self.N),
+            lambda b, i, j: te.sum(
+                A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k),
+            name="C",
+        )
+        return C
+
+    def _apply_bias(self, C, Bias):
+        if self.with_bias:
+            return te.compute((self.batch, self.M, self.N),
+                              lambda b, i, j: C[b, i, j] + Bias[j],
+                              name="E")
+        return C
+
+    def _convert_dtype(self, tensor):
+        if self.accum_dtype != self.out_dtype:
+            return te.compute((self.batch, self.M, self.N),
+                              lambda b, i, j: tensor[b, i, j].astype(self.out_dtype),
+                              name="D")
+        return tensor
+
+    def emit(self):
+        A, B, Bias = self._create_placeholders()
+        C = self._compute_matmul(A, B)
+        last_output = self._convert_dtype(C)
+        if self.with_bias:
+            last_output = self._apply_bias(last_output, Bias)
+
+        args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output]
+        func = te.create_prim_func(args)
+        return tvm.IRModule.from_expr(func)
+
+
+class BatchMatMulSelector(TIRScriptSelector):
+
+    def __init__(self,
+                 propagate_a: TransformKind = TransformKind.NonTransform,
+                 propagate_b: TransformKind = TransformKind.NonTransform):
+        self.propagate_a = propagate_a
+        self.propagate_b = propagate_b
+
+    def select(
+        self,
+        batch=1,
+        M=None,
+        N=16384,
+        K=16384,
+        in_dtype="float16",
+        out_dtype="float16",
+        accum_dtype="float16",
+        with_bias=False,
+        layout="nt",
+    ):
+        if layout == "nn":
+            if self.propagate_a or self.propagate_b:
+                raise ValueError(
+                    "Currently only support propagate_a=False and propagate_b=False for layout=nn")
+            return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias,
+                                      layout).emit()
+        elif layout == "nt":
+            if self.propagate_a and self.propagate_b:
+                raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
+            elif self.propagate_a:
+                raise ValueError("Currently only support propagate_a=False for layout=nt")
+            elif self.propagate_b:
+                raise ValueError("Currently only support propagate_b=False for layout=nt")
+            else:
+                return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype,
+                                          with_bias, layout).emit()
+        else:
+            raise ValueError(f"Unsupported layout: {layout}")


@@ -75,19 +75,5 @@ def select_implementation(
     propagate_a: TransformKind = TransformKind.NonTransform,
     propagate_b: TransformKind = TransformKind.NonTransform,
 ):
-    if layout == "nn":
-        if propagate_a or propagate_b:
-            raise ValueError(
-                "Currently only support propagate_a=False and propagate_b=False for layout=nn")
-        return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
-    elif layout == "nt":
-        if propagate_a and propagate_b:
-            raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
-        elif propagate_a:
-            raise ValueError("Currently only support propagate_a=False for layout=nt")
-        elif propagate_b:
-            raise ValueError("Currently only support propagate_b=False for layout=nt")
-        else:
-            return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
-    else:
-        raise ValueError(f"Unsupported layout: {layout}")
+    selector = BatchMatMulSelector(propagate_a, propagate_b)
+    return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
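As a usage sketch of the refactored entry point: the keyword names for the collapsed part of the `select_implementation` signature are inferred from the call above and from `BatchMatMulSelector.select`, so treat them as assumptions rather than repo-verified API.

# Usage sketch, inferred from this diff rather than taken from the repo docs.
from bitblas.ops.impl.batch_matmul_impl import select_implementation

# select_implementation now builds a BatchMatMulSelector and dispatches to
# BatchMatMulEmitter(...).emit(); M=None becomes a symbolic te.var("m").
mod = select_implementation(
    Batch=1,
    M=None,
    N=16384,
    K=16384,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    layout="nt",
)
print(mod)  # prints the generated tvm.IRModule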
