Skip to content

Commit

Permalink
fix for legalize
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Aug 7, 2024
1 parent 7a16e5a commit 8c666ac
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
11 changes: 10 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class MatmulConfig(OperatorConfig):
None # propagate_b is a flag to control the ladder permutation
)

optimize_stratety: OptimizeStrategy = OptimizeStrategy.ContigousBatching
optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.ContigousBatching

def __legalize_dynamic_symbolic(self, M):
return tuple(self.M) if isinstance(self.M, list) else self.M
Expand All @@ -98,6 +98,11 @@ def __legalize_propagate(self, propagate):

return propagate

def __legalize_optimize_strategy(self, optimize_stratety):
if isinstance(optimize_stratety, int):
return OptimizeStrategy(optimize_stratety)
return optimize_stratety

def __initialize_propagate(self, propagate_a: Optional[TransformKind],
propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
Expand Down Expand Up @@ -181,6 +186,10 @@ def __post_init__(self):
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# set optimize_stratety to legal value
object.__setattr__(self, "optimize_stratety",
self.__legalize_optimize_strategy(self.optimize_stratety))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
self.__initialize_propagate(self.propagate_a, self.propagate_b)
Expand Down
8 changes: 7 additions & 1 deletion testing/python/cache/test_operator_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
target = bitblas.utils.auto_detect_nvidia_target()
bitblas.set_log_level("DEBUG")


def get_codegen_result(ops, target):
code = ops.get_source(target=target)
return code
Expand Down Expand Up @@ -245,6 +246,9 @@ def test_global_cache_save_to_database(
assert success

database_path = "/tmp/.tmp_bitblas_cache.db"
# clean the database if exists
if os.path.exists(database_path):
os.remove(database_path)
global_operator_cache.save_into_database(database_path, target=target)
assert os.path.exists(database_path)
global_operator_cache.clear()
Expand Down Expand Up @@ -281,4 +285,6 @@ def test_global_cache_save_to_database(

# fmt: on
if __name__ == "__main__":
bitblas.testing.main()
# bitblas.testing.main()
test_global_cache_save_to_database(1, 1024, 1024, "float16", "float16", "float16", False, False,
False, "nt", False)

0 comments on commit 8c666ac

Please sign in to comment.