diff --git a/python/__init__.py b/python/__init__.py index ee62a437..eab3adb4 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -5,10 +5,9 @@ import tempfile from pathlib import Path -from triton.common.backend import BaseBackend, register_backend +from triton.common.backend import BaseBackend, compute_core_version_key, register_backend from triton.compiler.make_launcher import make_so_cache_key from triton.runtime.cache import get_cache_manager -from triton.runtime.jit import version_key def _get_triton_shared_opt_path() -> str: @@ -274,6 +273,7 @@ class TritonSharedRefCPUBackend(BaseBackend): def __init__(self, device_type: str) -> None: super(TritonSharedRefCPUBackend, self).__init__(device_type) + self.version_key def add_stages(self, arch, extern_libs, stages): filter_in_stages = ["ast", "ttir"] @@ -300,7 +300,12 @@ def add_meta_info(self, ir, module, next_module, metadata, asm): def get_driver(self): return None - + + def get_version_key(self): + if self.version_key is None: + self.version_key = compute_core_version_key() + return self.version_key + def get_stream(self, idx=None) -> int: # Returns int to make Triton happy. return 0 @@ -336,7 +341,7 @@ def get_architecture_descriptor(self, **kwargs): def make_launcher_stub(self, name, signature, constants, ids): # name of files that are cached - so_cache_key = make_so_cache_key(version_key(), signature, constants, ids) + so_cache_key = make_so_cache_key(self.version_key, signature, constants, ids) so_cache_manager = get_cache_manager(so_cache_key) so_name = f"{name}.py" # retrieve stub from cache if it exists