diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 597b2a34f..0e7ecaa54 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -143,13 +143,16 @@ def _load_operator(self, config_path, target): with open(full_path) as f: config = json.load(f) elif file.endswith(".tar"): - rt_mod = tvm.runtime.load_module(full_path) + try: + rt_mod = tvm.runtime.load_module(full_path) + except Exception as e: + logger.error(f"Failed to load runtime module from {full_path}: {e}") elif file == BITBLAS_WRAPPED_COMPILED_NAME: libpath = full_path elif file == BITBLAS_WRAPPED_SOURCE_NAME: srcpath = full_path - if mapping and config and rt_mod: + if mapping and config: self._instantiate_and_add_operator(mapping, config, rt_mod, srcpath, libpath, target) def _instantiate_and_add_operator(self, mapping, config, rt_mod, srcpath, libpath, target): diff --git a/bitblas/gpu/matmul.py b/bitblas/gpu/matmul.py index 4d21cdb46..0cfd65dc4 100644 --- a/bitblas/gpu/matmul.py +++ b/bitblas/gpu/matmul.py @@ -464,8 +464,8 @@ def check_weight_decode_info(weight_decode_info): block_col_warps = config.block[1] // (config.thread[1] * config.step[1]) thread_row_tiles = config.thread[0] // (config.step[0]) thread_col_tiles = config.thread[1] // (config.step[1]) - vthread_row_tiles = (config.step[0]) # expand vtrhead to avoid load band conflict - vthread_col_tiles = (config.step[1]) # expand vtrhead to avoid load band conflict + vthread_row_tiles = (config.step[0]) # expand vthread to avoid load band conflict + vthread_col_tiles = (config.step[1]) # expand vthread to avoid load band conflict chunk = config.rstep[0] shared_scope = config.shared_scope @@ -489,8 +489,10 @@ def find_valid_number(k, chunk, magic=16): return None # If no such number is found K = func.buffer_map[func.params[0]].shape[-1] + # This is a hack to handle unaligned K and BK BK = find_valid_number(K, chunk) - + # Alignment factor (Note: this is also a hack.) 
+ align_factor = 4 # used to expand the vectorization factor sch.pad_einsum( main_block, [1, BM, BN, BK], @@ -498,13 +500,8 @@ def find_valid_number(k, chunk, magic=16): batch, y, x, k = sch.get_loops(main_block) by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles]) bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles]) - ko, ki = sch.split(k, factors=[None, BK]) - sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) - by = sch.fuse(batch, by) - sch.bind(bx, "blockIdx.x") - sch.bind(by, "blockIdx.y") - sch.bind(vy, "vthread.y") - sch.bind(vx, "vthread.x") + ko, ki, kii = sch.split(k, factors=[None, (BK // align_factor), align_factor]) + sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, kii, yi, xi) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -514,7 +511,7 @@ def prod(iterable): l2g = sch.cache_write(main_block, 0, "local") sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) - def _cooperative_fetch(index, vec_len): + def _cooperative_fetch(index, vec_len, align_factor=2): block = sch.cache_read(main_block, index, "shared") num_loops = len(sch.get_loops(block)) block_local = sch.cache_read(main_block, index, "local") @@ -543,9 +540,10 @@ def is_trivial_load(block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") + fused = sch.fuse(*sch.get_loops(block_local)[-2:]) _, vec = sch.split( - sch.fuse(*sch.get_loops(block_local)[-2:]), - [None, vec_len // prod(config.step)], + fused, + [None, align_factor], ) sch.vectorize(vec) @@ -562,7 +560,7 @@ def is_trivial_load(block): # otherwise cooperative fetch in shared memory. vectorize = config.vectorize.get(_buffer_name, 1) - _cooperative_fetch(i, vec_len=vectorize) + _cooperative_fetch(i, vec_len=vectorize, align_factor=align_factor) def decode_fetch_to_shared(block, idx): # step1. 
create memory hierarchy @@ -652,10 +650,24 @@ def get_idx(): _ = decode_fetch_to_shared(main_block, 1) + def fetch_to_local(block, index, align_factor=2): + # read_b to load + block_local = sch.cache_read(block, index, "local") + sch.compute_at(block_local, ki, preserve_unit_loops=True) + fused = sch.fuse(*sch.get_loops(block_local)[-2:]) + _, vec = sch.split( + fused, + [None, align_factor], + ) + sch.vectorize(vec) + return block_local + + fetch_to_local(main_block, 1, align_factor=align_factor) + auto_inline_consumer_chain(sch, l2g) - _, vec = sch.split( - sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)]) + l2g_vec = get_coalesced_veclen(sch.get(l2g)) + _, vec = sch.split(sch.fuse(*sch.get_loops(l2g)[-2:]), [None, l2g_vec]) sch.vectorize(vec) sch.decompose_reduction(main_block, ko) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index a94da9969..c8a9cb08a 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -373,10 +373,12 @@ def __call__(self, *args: Any) -> Any: def update_func(self, func: PrimFunc): self.prim_func_mod["main"] = func - def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): - self.rt_mod = rt_mod - self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) - self.torch_func = to_pytorch_func(rt_mod) + def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None): + if rt_mod is not None: + self.rt_mod = rt_mod + self.time_evaluator = rt_mod.time_evaluator( + rt_mod.entry_name, self.arch.device, number=10) + self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: assert self.lib_generator is not None, "lib_generator is not initialized" self.lib_generator.set_src_path(srcpath)