[CI] Edit the notify setting in our CI #76

Merged
merged 4 commits into from Jul 5, 2024
11 changes: 10 additions & 1 deletion .github/workflows/ci.yml
@@ -64,4 +64,13 @@ jobs:
        run: |
          source bitblas_ci/bin/activate
          cd testing/python
          python -m pytest
          python -m pytest

  # Control notifications
  notify:
    runs-on: self-hosted
    needs: [format-check, build-test]
    if: failure()
    steps:
      - name: Notification
        run: echo "Jobs failed, but no email will be sent."
2 changes: 1 addition & 1 deletion bitblas/ops/impl/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .lop3_permutate_impl import tir_interleave_weight
from .lop3_permutate_impl import tir_interleave_weight # noqa: F401
20 changes: 20 additions & 0 deletions bitblas/ops/impl/base.py
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod


# TODO: Refactor all the tir script implementations to use this base class
# Abstract base class for TIR script emitters
class TIRScriptEmitter(ABC):

    @abstractmethod
    def emit(self):
        raise NotImplementedError


# Abstract base class for TIR script selectors
class TIRScriptSelector(ABC):

    @abstractmethod
    def select(self):
        raise NotImplementedError
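
For context, a minimal sketch of how a concrete emitter is expected to subclass the new base class (SquareEmitter and its trivial computation are hypothetical and only for illustration; BatchMatMulEmitter below is the real subclass this PR introduces):

```python
# Minimal sketch, assuming the base.py added in this PR is importable as
# bitblas.ops.impl.base. SquareEmitter is a hypothetical example, not part of the PR.
from bitblas import tvm
from tvm import te

from bitblas.ops.impl.base import TIRScriptEmitter


class SquareEmitter(TIRScriptEmitter):
    """Hypothetical emitter: hold parameters in __init__, build the IRModule in emit()."""

    def __init__(self, n, dtype="float16"):
        self.n = n
        self.dtype = dtype

    def emit(self):
        # Describe a trivial elementwise computation in TE and wrap it into an
        # IRModule, mirroring what BatchMatMulEmitter.emit() does below.
        A = te.placeholder((self.n,), name="A", dtype=self.dtype)
        B = te.compute((self.n,), lambda i: A[i] * A[i], name="B")
        func = te.create_prim_func([A, B])
        return tvm.IRModule.from_expr(func)


print(SquareEmitter(1024).emit())  # prints the generated TIR script
```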
191 changes: 121 additions & 70 deletions bitblas/ops/impl/batch_matmul_impl.py
@@ -4,62 +4,127 @@
from bitblas import tvm
from tvm import te
from bitblas.ops.operator import TransformKind
from .base import TIRScriptEmitter, TIRScriptSelector


def matmul_nt(
    Batch,
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
):
    if not isinstance(M, int):
        M = tvm.te.var("m")
    A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype)
    B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    # Describe the matrix multiplication in TE
    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (Batch, M, N),
        lambda b, i, j: te.sum(
            A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)

    return tvm.IRModule.from_expr(func)


def matmul(
    Batch,
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    layout="nt",
):
    if layout == "nn":
        raise ValueError("Currently only support layout=nt")
    return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias)
class BatchMatMulEmitter(TIRScriptEmitter):

    def __init__(
        self,
        batch,
        M,
        N,
        K,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
        with_bias=False,
        layout="nt",
    ):
        self.batch = batch
        self.M = self._validate_dimension(M, "M")
        self.N = self._validate_dimension(N, "N")
        self.K = self._validate_dimension(K, "K")
        self.in_dtype = in_dtype
        self.out_dtype = out_dtype
        self.accum_dtype = accum_dtype
        self.with_bias = with_bias
        self.layout = layout
        self._validate_layout()

    @staticmethod
    def _validate_dimension(dim, name):
        if not isinstance(dim, int):
            return tvm.te.var(name.lower())
        return dim

    def _validate_layout(self):
        if self.layout not in ["nn", "nt"]:
            raise ValueError(f"Unsupported layout: {self.layout}")
        if self.layout == "nn":
            raise ValueError("Currently only support layout=nt")

    def _create_placeholders(self):
        A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype)
        B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype)
        Bias = te.placeholder(
            (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None
        return A, B, Bias

    def _compute_matmul(self, A, B):
        k = te.reduce_axis((0, self.K), name="k")
        C = te.compute(
            (self.batch, self.M, self.N),
            lambda b, i, j: te.sum(
                A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k),
            name="C",
        )
        return C

    def _apply_bias(self, C, Bias):
        if self.with_bias:
            return te.compute((self.batch, self.M, self.N),
                              lambda b, i, j: C[b, i, j] + Bias[j],
                              name="E")
        return C

    def _convert_dtype(self, tensor):
        if self.accum_dtype != self.out_dtype:
            return te.compute((self.batch, self.M, self.N),
                              lambda b, i, j: tensor[b, i, j].astype(self.out_dtype),
                              name="D")
        return tensor

    def emit(self):
        A, B, Bias = self._create_placeholders()
        C = self._compute_matmul(A, B)
        last_output = self._convert_dtype(C)
        if self.with_bias:
            last_output = self._apply_bias(last_output, Bias)

        args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output]
        func = te.create_prim_func(args)
        return tvm.IRModule.from_expr(func)


class BatchMatMulSelector(TIRScriptSelector):

    def __init__(self,
                 propagate_a: TransformKind = TransformKind.NonTransform,
                 propagate_b: TransformKind = TransformKind.NonTransform):
        self.propagate_a = propagate_a
        self.propagate_b = propagate_b

    def select(
        self,
        batch=1,
        M=None,
        N=16384,
        K=16384,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
        with_bias=False,
        layout="nt",
    ):
        if layout == "nn":
            if self.propagate_a or self.propagate_b:
                raise ValueError(
                    "Currently only support propagate_a=False and propagate_b=False for layout=nn")
            return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias,
                                      layout).emit()
        elif layout == "nt":
            if self.propagate_a and self.propagate_b:
                raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
            elif self.propagate_a:
                raise ValueError("Currently only support propagate_a=False for layout=nt")
            elif self.propagate_b:
                raise ValueError("Currently only support propagate_b=False for layout=nt")
            else:
                return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype,
                                          with_bias, layout).emit()
        else:
            raise ValueError(f"Unsupported layout: {layout}")


def select_implementation(
@@ -75,19 +140,5 @@ def select_implementation(
    propagate_a: TransformKind = TransformKind.NonTransform,
    propagate_b: TransformKind = TransformKind.NonTransform,
):
    if layout == "nn":
        if propagate_a or propagate_b:
            raise ValueError(
                "Currently only support propagate_a=False and propagate_b=False for layout=nn")
        return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
    elif layout == "nt":
        if propagate_a and propagate_b:
            raise ValueError("Currently only support propagate_a or propagate_b for layout=nt")
        elif propagate_a:
            raise ValueError("Currently only support propagate_a=False for layout=nt")
        elif propagate_b:
            raise ValueError("Currently only support propagate_b=False for layout=nt")
        else:
            return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
    else:
        raise ValueError(f"Unsupported layout: {layout}")
    selector = BatchMatMulSelector(propagate_a, propagate_b)
    return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
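
For reference, a minimal usage sketch of the refactored module (the shapes below are illustrative; the select_implementation signature shown in the hunk context is unchanged, so existing callers should be unaffected):

```python
# Minimal sketch, assuming bitblas with this PR (and its TVM build) is installed.
from bitblas.ops.impl.batch_matmul_impl import BatchMatMulEmitter, select_implementation

# Drive the new emitter directly ...
ir_module = BatchMatMulEmitter(
    batch=1, M=128, N=256, K=512,
    in_dtype="float16", out_dtype="float16", accum_dtype="float16",
    with_bias=False, layout="nt",
).emit()

# ... or go through the original entry point, which now delegates to BatchMatMulSelector.
ir_module_via_selector = select_implementation(
    Batch=1, M=128, N=256, K=512,
    in_dtype="float16", out_dtype="float16", accum_dtype="float16",
    with_bias=False, layout="nt",
)

# Either way the result is a TVM IRModule describing the batched GEMM
# C[b, i, j] = sum_k A[b, i, k] * B[b, j, k] in the "nt" layout.
print(ir_module)
```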