From 8c666acd9579c83caa06283c5fd54232ca85ebca Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 7 Aug 2024 07:24:02 +0000 Subject: [PATCH] fix for legalize --- bitblas/ops/general_matmul/__init__.py | 11 ++++++++++- testing/python/cache/test_operator_cache.py | 8 +++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index d384ee83d..9b8256ded 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -85,7 +85,7 @@ class MatmulConfig(OperatorConfig): None # propagate_b is a flag to control the ladder permutation ) - optimize_stratety: OptimizeStrategy = OptimizeStrategy.ContigousBatching + optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.ContigousBatching def __legalize_dynamic_symbolic(self, M): return tuple(self.M) if isinstance(self.M, list) else self.M @@ -98,6 +98,11 @@ def __legalize_propagate(self, propagate): return propagate + def __legalize_optimize_strategy(self, optimize_stratety): + if isinstance(optimize_stratety, int): + return OptimizeStrategy(optimize_stratety) + return optimize_stratety + def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 @@ -181,6 +186,10 @@ def __post_init__(self): object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) + # set optimize_stratety to legal value + object.__setattr__(self, "optimize_stratety", + self.__legalize_optimize_strategy(self.optimize_stratety)) + # This is hack to legalize propagate_a and b # TODO(lei): should be removed in the future when tc+br template is ready. self.__initialize_propagate(self.propagate_a, self.propagate_b) diff --git a/testing/python/cache/test_operator_cache.py b/testing/python/cache/test_operator_cache.py index 51a155f7f..7d609c7c8 100644 --- a/testing/python/cache/test_operator_cache.py +++ b/testing/python/cache/test_operator_cache.py @@ -10,6 +10,7 @@ target = bitblas.utils.auto_detect_nvidia_target() bitblas.set_log_level("DEBUG") + def get_codegen_result(ops, target): code = ops.get_source(target=target) return code @@ -245,6 +246,9 @@ def test_global_cache_save_to_database( assert success database_path = "/tmp/.tmp_bitblas_cache.db" + # clean the database if exists + if os.path.exists(database_path): + os.remove(database_path) global_operator_cache.save_into_database(database_path, target=target) assert os.path.exists(database_path) global_operator_cache.clear() @@ -281,4 +285,6 @@ def test_global_cache_save_to_database( # fmt: on if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + test_global_cache_save_to_database(1, 1024, 1024, "float16", "float16", "float16", False, False, + False, "nt", False)