[Dev] Improve Dequant performance on CUDA Simt #189

Merged · 4 commits · Sep 20, 2024
Changes from all commits
7 changes: 5 additions & 2 deletions bitblas/cache/operator.py
@@ -143,13 +143,16 @@ def _load_operator(self, config_path, target):
with open(full_path) as f:
config = json.load(f)
elif file.endswith(".tar"):
rt_mod = tvm.runtime.load_module(full_path)
try:
rt_mod = tvm.runtime.load_module(full_path)
except Exception as e:
logger.error(f"Failed to load runtime module from {full_path}: {e}")
elif file == BITBLAS_WRAPPED_COMPILED_NAME:
libpath = full_path
elif file == BITBLAS_WRAPPED_SOURCE_NAME:
srcpath = full_path

if mapping and config and rt_mod:
if mapping and config:
self._instantiate_and_add_operator(mapping, config, rt_mod, srcpath, libpath, target)

def _instantiate_and_add_operator(self, mapping, config, rt_mod, srcpath, libpath, target):
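Note: the change above makes deserialization of the cached .tar runtime module non-fatal; a load failure is logged and the operator is still instantiated from its mapping and config alone. A minimal standalone sketch of that pattern, with a hypothetical helper name around the real tvm.runtime.load_module call:

import logging
import tvm

logger = logging.getLogger(__name__)

def try_load_rt_mod(full_path):
    """Return the deserialized runtime module, or None if loading fails."""
    rt_mod = None
    try:
        rt_mod = tvm.runtime.load_module(full_path)
    except Exception as e:  # e.g. an arch/ABI mismatch between the cache and the device
        logger.error(f"Failed to load runtime module from {full_path}: {e}")
    return rt_mod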
44 changes: 28 additions & 16 deletions bitblas/gpu/matmul.py
@@ -464,8 +464,8 @@ def check_weight_decode_info(weight_decode_info):
block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
thread_row_tiles = config.thread[0] // (config.step[0])
thread_col_tiles = config.thread[1] // (config.step[1])
vthread_row_tiles = (config.step[0]) # expand vtrhead to avoid load band conflict
vthread_col_tiles = (config.step[1]) # expand vtrhead to avoid load band conflict
vthread_row_tiles = (config.step[0]) # expand vthread to avoid load band conflict
vthread_col_tiles = (config.step[1]) # expand vthread to avoid load band conflict
chunk = config.rstep[0]
shared_scope = config.shared_scope

@@ -489,22 +489,19 @@ def find_valid_number(k, chunk, magic=16):
return None # If no such number is found

K = func.buffer_map[func.params[0]].shape[-1]
# This is hack to handle unaligned K and BK
BK = find_valid_number(K, chunk)

# Align Factor (Notes: This is also a hack.)
align_factor = 4 # used to expand the vectorization factor
sch.pad_einsum(
main_block,
[1, BM, BN, BK],
)
batch, y, x, k = sch.get_loops(main_block)
by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles])
bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles])
ko, ki = sch.split(k, factors=[None, BK])
sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
by = sch.fuse(batch, by)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(vy, "vthread.y")
sch.bind(vx, "vthread.x")
ko, ki, kii = sch.split(k, factors=[None, (BK // align_factor), align_factor])
sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, kii, yi, xi)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
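
Note: the k loop is now split one level deeper, so the innermost kii loop has extent align_factor and stays contiguous for vectorized loads. A minimal sketch of that split on a toy PrimFunc; the kernel, shapes, and concrete factors below are hypothetical, and only the split/reorder pattern mirrors the diff:

import tvm
from tvm.script import tir as T

@T.prim_func
def matmul(A: T.Buffer((128, 128), "float16"),
           B: T.Buffer((128, 128), "float16"),
           C: T.Buffer((128, 128), "float16")):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float16(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

sch = tvm.tir.Schedule(matmul)
i, j, k = sch.get_loops(sch.get_block("C"))
BK, align_factor = 32, 4  # assumed chunk along K and vectorization unit
ko, ki, kii = sch.split(k, factors=[None, BK // align_factor, align_factor])
sch.reorder(ko, ki, kii)  # already in this order; shown to mirror the reorder above
print(sch.mod.script())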

@@ -514,7 +511,7 @@ def prod(iterable):
l2g = sch.cache_write(main_block, 0, "local")
sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)

def _cooperative_fetch(index, vec_len):
def _cooperative_fetch(index, vec_len, align_factor=2):
block = sch.cache_read(main_block, index, "shared")
num_loops = len(sch.get_loops(block))
block_local = sch.cache_read(main_block, index, "local")
@@ -543,9 +540,10 @@ def is_trivial_load(block):
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

fused = sch.fuse(*sch.get_loops(block_local)[-2:])
_, vec = sch.split(
sch.fuse(*sch.get_loops(block_local)[-2:]),
[None, vec_len // prod(config.step)],
fused,
[None, align_factor],
)
sch.vectorize(vec)

@@ -562,7 +560,7 @@ def is_trivial_load(block):
# otherwise cooperative fetch in shared memory.
vectorize = config.vectorize.get(_buffer_name, 1)

_cooperative_fetch(i, vec_len=vectorize)
_cooperative_fetch(i, vec_len=vectorize, align_factor=align_factor)

def decode_fetch_to_shared(block, idx):
# step1. create memory hierarchy
@@ -652,10 +650,24 @@ def get_idx():

_ = decode_fetch_to_shared(main_block, 1)

def fetch_to_local(block, index, align_factor=2):
# read_b to load
block_local = sch.cache_read(block, index, "local")
sch.compute_at(block_local, ki, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(block_local)[-2:])
_, vec = sch.split(
fused,
[None, align_factor],
)
sch.vectorize(vec)
return block_local

fetch_to_local(main_block, 1, align_factor=align_factor)

auto_inline_consumer_chain(sch, l2g)

_, vec = sch.split(
sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)])
l2g_vec = get_coalesced_veclen(sch.get(l2g))
_, vec = sch.split(sch.fuse(*sch.get_loops(l2g)[-2:]), [None, l2g_vec])
sch.vectorize(vec)

sch.decompose_reduction(main_block, ko)
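Note: the new fetch_to_local helper and the reworked _cooperative_fetch share one pattern: stage an operand through registers with cache_read(..., "local"), fuse the two innermost copy loops, split by align_factor, and vectorize the inner loop. A self-contained sketch of that pattern on a toy elementwise kernel; the PrimFunc, shapes, and names are hypothetical:

import tvm
from tvm.script import tir as T

@T.prim_func
def scale(A: T.Buffer((64, 64), "float16"), B: T.Buffer((64, 64), "float16")):
    for i, j in T.grid(64, 64):
        with T.block("scale"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float16(2)

sch = tvm.tir.Schedule(scale)
block = sch.get_block("scale")
align_factor = 2  # same default as the helpers in the diff
A_local = sch.cache_read(block, 0, "local")      # read A through registers
i, _ = sch.get_loops(block)
sch.compute_at(A_local, i, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(A_local)[-2:])   # fuse the two innermost copy loops
_, vec = sch.split(fused, [None, align_factor])
sch.vectorize(vec)                               # vectorized register fill
print(sch.mod.script())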
10 changes: 6 additions & 4 deletions bitblas/ops/operator.py
@@ -373,10 +373,12 @@ def __call__(self, *args: Any) -> Any:
def update_func(self, func: PrimFunc):
self.prim_func_mod["main"] = func

def update_runtime_module(self, rt_mod, srcpath=None, libpath=None):
self.rt_mod = rt_mod
self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10)
self.torch_func = to_pytorch_func(rt_mod)
def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None):
if rt_mod is not None:
self.rt_mod = rt_mod
self.time_evaluator = rt_mod.time_evaluator(
rt_mod.entry_name, self.arch.device, number=10)
self.torch_func = to_pytorch_func(rt_mod)
if srcpath is not None:
assert self.lib_generator is not None, "lib_generator is not initialized"
self.lib_generator.set_src_path(srcpath)
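Note: with rt_mod now optional, a cached operator can be rehydrated from the wrapped source and compiled library alone when the serialized TVM module is unavailable. A usage sketch; the op instance and the paths are hypothetical:

# Library-only path: rt_mod, time_evaluator and torch_func stay untouched.
op.update_runtime_module(srcpath="/path/to/wrapper_source.cu",
                         libpath="/path/to/wrapper_compiled.so")

# Full path, when the .tar runtime module deserialized successfully:
op.update_runtime_module(rt_mod, srcpath=srcpath, libpath=libpath)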