Skip to content

Commit

Permalink
fix import missing for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanfz98 committed Nov 14, 2023
1 parent 98a6e19 commit 8e77393
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import tempfile
from pathlib import Path

from triton.common.backend import BaseBackend, register_backend
from triton.common.backend import BaseBackend, compute_core_version_key, register_backend
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.jit import version_key


def _get_triton_shared_opt_path() -> str:
Expand Down Expand Up @@ -274,6 +273,7 @@ class TritonSharedRefCPUBackend(BaseBackend):

def __init__(self, device_type: str) -> None:
super(TritonSharedRefCPUBackend, self).__init__(device_type)
self.version_key

def add_stages(self, arch, extern_libs, stages):
filter_in_stages = ["ast", "ttir"]
Expand All @@ -300,7 +300,12 @@ def add_meta_info(self, ir, module, next_module, metadata, asm):

def get_driver(self):
return None


def get_version_key(self):
if self.version_key is None:
self.version_key = compute_core_version_key()
return self.version_key

def get_stream(self, idx=None) -> int:
# Returns int to make Triton happy.
return 0
Expand Down Expand Up @@ -336,7 +341,7 @@ def get_architecture_descriptor(self, **kwargs):

def make_launcher_stub(self, name, signature, constants, ids):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids)
so_cache_key = make_so_cache_key(self.version_key, signature, constants, ids)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.py"
# retrieve stub from cache if it exists
Expand Down

0 comments on commit 8e77393

Please sign in to comment.