Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

amdgcn for llvm dialect #116

Merged
merged 11 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ concurrency:

env:
SYSTEM_VERSION_COMPAT: 0
PIP_FIND_LINKS: "https://github.com/llvm/eudsl/releases/expanded_assets/latest https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest"

jobs:

Expand All @@ -43,6 +44,8 @@ jobs:
- os: macos-14
py_version: "3.9"

name: "${{ matrix.os }}-${{ matrix.py_version }}"

steps:
- name: Checkout
uses: actions/checkout@v2
Expand All @@ -56,7 +59,7 @@ jobs:
- name: Install and configure
shell: bash
run: |
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
pip install .[test,mlir] -v

- name: Test
shell: bash
Expand Down Expand Up @@ -95,7 +98,6 @@ jobs:
- name: Install and configure
shell: bash
run: |
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
pip install .[test,mlir] -v
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v

Expand Down Expand Up @@ -133,7 +135,7 @@ jobs:
run: |

pip install jupyter
pip install -q mlir-python-bindings -f https://makslevental.github.io/wheels
pip install -q mlir-python-bindings
pip install -q .

sed -i.bak 's/OUTPUT_TIMEOUT = 10/OUTPUT_TIMEOUT = 100/g' \
Expand Down
8 changes: 0 additions & 8 deletions mlir/extras/dialects/ext/llvm.py

This file was deleted.

73 changes: 73 additions & 0 deletions mlir/extras/dialects/ext/llvm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import warnings

# noinspection PyUnresolvedReferences
from .....dialects.llvm import *
from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType

try:
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
from llvm import types_
from llvm.context import context as llvm_context
except ImportError:
warnings.warn(
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
)


def llvm_ptr_t():
return Type.parse("!llvm.ptr")


def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
if F16Type.isinstance(mlir_type):
return types_.half_type_in_context(llvm_ctx)
if F32Type.isinstance(mlir_type):
return types_.float_type_in_context(llvm_ctx)
if F64Type.isinstance(mlir_type):
return types_.double_type_in_context(llvm_ctx)
if BF16Type.isinstance(mlir_type):
return types_.b_float_type_in_context(llvm_ctx)
if IntegerType.isinstance(mlir_type):
return types_.int_type_in_context(llvm_ctx, mlir_type.width)

raise NotImplementedError(f"{mlir_type} is not supported")


def llvm_type_str_to_mlir_type(llvm_type: str):
if llvm_type.startswith("<"):
return Type.parse(f"vector{llvm_type}")
if llvm_type == "float":
return F32Type.get()
raise NotImplementedError(f"{llvm_type} is not supported")


_call_intrinsic = call_intrinsic


def call_intrinsic(*args, **kwargs):
intr_id = kwargs.pop("intr_id")
intr_name = kwargs.pop("intr_name")
mlir_ret_type = kwargs.pop("return_type", None)
if mlir_ret_type:
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])

is_overloaded = kwargs.pop("is_overloaded", None)
if is_overloaded is None:
is_overloaded = intrinsic_is_overloaded(intr_id)
with llvm_context() as ctx:
types = []
if is_overloaded:
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)

ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
mlir_ret_type = None
if ret_type_str:
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)

return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])


call_intrinsic_ = call_intrinsic

from . import amdgcn
Loading