Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix llvm dialect #120

Merged
merged 7 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/extras/dialects/ext/llvm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
amdgcn.py
72 changes: 6 additions & 66 deletions mlir/extras/dialects/ext/llvm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,13 @@
import warnings

# noinspection PyUnresolvedReferences
from .....dialects.llvm import *
from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType

try:
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
from llvm import types_
from llvm.context import context as llvm_context
except ImportError:
warnings.warn(
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
)
from .....ir import Type, Value

ValueRef = Value

def llvm_ptr_t():
return Type.parse("!llvm.ptr")


def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
if F16Type.isinstance(mlir_type):
return types_.half_type_in_context(llvm_ctx)
if F32Type.isinstance(mlir_type):
return types_.float_type_in_context(llvm_ctx)
if F64Type.isinstance(mlir_type):
return types_.double_type_in_context(llvm_ctx)
if BF16Type.isinstance(mlir_type):
return types_.b_float_type_in_context(llvm_ctx)
if IntegerType.isinstance(mlir_type):
return types_.int_type_in_context(llvm_ctx, mlir_type.width)

raise NotImplementedError(f"{mlir_type} is not supported")


def llvm_type_str_to_mlir_type(llvm_type: str):
if llvm_type.startswith("<"):
return Type.parse(f"vector{llvm_type}")
if llvm_type == "float":
return F32Type.get()
raise NotImplementedError(f"{llvm_type} is not supported")


_call_intrinsic = call_intrinsic


def call_intrinsic(*args, **kwargs):
intr_id = kwargs.pop("intr_id")
intr_name = kwargs.pop("intr_name")
mlir_ret_type = kwargs.pop("return_type", None)
if mlir_ret_type:
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])

is_overloaded = kwargs.pop("is_overloaded", None)
if is_overloaded is None:
is_overloaded = intrinsic_is_overloaded(intr_id)
with llvm_context() as ctx:
types = []
if is_overloaded:
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)

ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
mlir_ret_type = None
if ret_type_str:
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)

return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])


call_intrinsic_ = call_intrinsic

from . import amdgcn
try:
from . import amdgcn
except ImportError:
pass
Loading