Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: stablehlo + apple metal #434

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ jobs:
name: pjrt_plugin_xla_cpu-darwin-aarch64
path: pjrt_plugin_xla_cpu.dylib
if-no-files-found: error
pjrt-plugin-apple-metal:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2
- name: Build or fetch Apple Metal PJRT plugin
run: |
prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64
curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \
-o "$prefix.zip"
unzip "$prefix.zip"
mv "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" pjrt_plugin_apple_metal.dylib
- name: Upload binary
uses: actions/upload-artifact@v4
with:
name: pjrt_plugin_apple_metal
path: pjrt_plugin_apple_metal.dylib
if-no-files-found: error
pjrt-plugin-xla-cuda-linux-x86_64:
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -181,6 +200,27 @@ jobs:
name: tests-xla-cpu-darwin-aarch64
path: test/xla-cpu/tests-xla-cpu.tar.gz
if-no-files-found: error
build-tests-apple-metal:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- name: Install build dependencies
run: |
brew install chezscheme
git clone https://github.com/stefan-hoeck/idris2-pack.git
(cd idris2-pack && make micropack SCHEME=chez)
~/.pack/bin/pack switch HEAD
- name: Build tests
working-directory: test/apple-metal
run: |
SPIDR_INSTALL_SUPPORT_LIBS=false ~/.pack/bin/pack --no-prompt build apple-metal.ipkg
tar cfz tests-apple-metal.tar.gz -C build/exec .
- name: Upload tests
uses: actions/upload-artifact@v4
with:
name: tests-apple-metal
path: test/apple-metal/tests-apple-metal.tar.gz
if-no-files-found: error
build-tests-xla-cuda-linux-x86_64:
runs-on: ubuntu-latest
container: ghcr.io/stefan-hoeck/idris2-pack
Expand Down Expand Up @@ -239,6 +279,25 @@ jobs:
run: |
tar xfz tests-xla-cpu.tar.gz && rm tests-xla-cpu.tar.gz
./test
test-apple-metal:
needs:
- pjrt-darwin-aarch64
- pjrt-plugin-apple-metal
- build-tests-apple-metal
runs-on: macos-latest
steps:
- name: Download artifacts
uses: actions/download-artifact@v4
with:
pattern: "{libc_xla-darwin-aarch64,pjrt_plugin_apple_metal,tests-apple-metal}"
merge-multiple: true
- name: Install runtime dependencies
run: |
brew install chezscheme
- name: Run tests
run: |
tar xfz tests-apple-metal.tar.gz && rm tests-apple-metal.tar.gz
./test
test-xla-cuda-linux-x86_64:
needs:
- pjrt-linux-x86_64
Expand Down
5 changes: 5 additions & 0 deletions pack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ type = "local"
path = ""
ipkg = "test/runner/runner.ipkg"

[custom.all.pjrt-plugin-apple-metal]
type = "local"
path = ""
ipkg = "pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg"

[custom.all.pjrt-plugin-xla-cpu]
type = "local"
path = ""
Expand Down
2 changes: 1 addition & 1 deletion pjrt-plugins/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# PJRT Plugins

A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md) and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own.
A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md), [Apple Metal](apple-metal/README.md), and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own.

## How to integrate your own plugin

Expand Down
30 changes: 30 additions & 0 deletions pjrt-plugins/apple-metal/PjrtPluginAppleMetal.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{--
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module PjrtPluginAppleMetal

import System.FFI

import public Compiler.Xla.PJRT.C.PjrtCApi
import public Device

%foreign "C:GetPjrtApi,pjrt_plugin_apple_metal"
prim__getPjrtApi : PrimIO AnyPtr

export
device : Pjrt Device
device = do
api <- MkPjrtApi <$> primIO prim__getPjrtApi
MkDevice api <$> pjrtClientCreate api
10 changes: 10 additions & 0 deletions pjrt-plugins/apple-metal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# PJRT plugin for Apple Metal

This is the PJRT plugin for Apple Metal, which provides hardware acceleration with GPU on Apple silicon (AArch64, ARM64).

## Install

Run
```
pack install pjrt-plugin-apple-metal
```
33 changes: 33 additions & 0 deletions pjrt-plugins/apple-metal/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/sh -e

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
rev=$(cat XLA_VERSION)

osu="$(uname)"
case $osu in
'Linux')
os=linux
arch=x86_64
ext=so
;;
'Darwin')
os=darwin
arch=aarch64
ext=dylib
;;
*)
echo "OS $osu not handled"
exit 1
;;
esac

xla_dir=$(mktemp -d)
install_xla "$rev" "$xla_dir"
(
cd "$xla_dir"
./configure.py --backend=CPU --os=$os
bazel build //xla/pjrt/c:pjrt_c_api_cpu_plugin.so
)
mv "$xla_dir/bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so" "pjrt_plugin_xla_cpu-$os-$arch.$ext"
11 changes: 11 additions & 0 deletions pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pjrt-plugin-apple-metal
version = 0.0.1

depends = spidr
modules = PjrtPluginAppleMetal

brief = "XLA PJRT plugin for Apple Metal."
readme = "README.md"
license = "Apache License, Version 2.0"

postinstall = "./postinstall.sh"
25 changes: 25 additions & 0 deletions pjrt-plugins/apple-metal/postinstall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/sh -e

if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."

os="$(uname)"
case $os in
'Darwin')
;;
*)
echo "WARNING: OS $os not supported, unable to fetch supporting libraries."
exit 0
;;
esac

prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64
curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \
-o "$prefix.zip"
unzip "$prefix.zip"
libdir="$(idris2 --libdir)/pjrt-plugin-xla-cpu-0.0.1/lib"
mkdir -p libdir
mv "$prefix/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" "$libdir/pjrt_plugin_apple_metal.dylib"
rm -rf "$prefix.zip" $prefix
10 changes: 10 additions & 0 deletions spidr/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,31 @@ cc_binary(
linkshared = True,
linkstatic = True,
srcs = [
"//src/mlir/IR",
"//src/stablehlo/dialect",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
deps = [
"//src/mlir/IR",
"//src/stablehlo/dialect",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
)
2 changes: 1 addition & 1 deletion spidr/backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.14
0.0.15
4 changes: 4 additions & 0 deletions spidr/backend/src/ffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ extern "C" {
return ptr == nullptr;
}

string* string_new() {
return reinterpret_cast<string*>(new std::string());
}

void string_delete(string* s) {
delete reinterpret_cast<std::string*>(s);
}
Expand Down
12 changes: 12 additions & 0 deletions spidr/backend/src/mlir/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
cc_library(
name = "IR",
linkstatic = True,
alwayslink = True,
srcs = glob(["*.cpp"]),
hdrs = glob(["*.h"]),
deps = [
"@llvm-project//mlir:IR",
"//src",
],
visibility = ["//visibility:public"],
)
18 changes: 18 additions & 0 deletions spidr/backend/src/mlir/IR/BuiltinOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
extern "C" {
struct ModuleOp;
}
28 changes: 28 additions & 0 deletions spidr/backend/src/mlir/IR/DialectRegistry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "mlir/IR/DialectRegistry.h"

#include "DialectRegistry.h"

extern "C" {
DialectRegistry* DialectRegistry_new() {
return reinterpret_cast<DialectRegistry*>(new mlir::DialectRegistry());
}

void DialectRegistry_delete(DialectRegistry* s) {
delete reinterpret_cast<mlir::DialectRegistry*>(s);
}
}
18 changes: 18 additions & 0 deletions spidr/backend/src/mlir/IR/DialectRegistry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
extern "C" {
struct DialectRegistry;
}
34 changes: 34 additions & 0 deletions spidr/backend/src/mlir/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "mlir/IR/MLIRContext.h"

#include "DialectRegistry.h"
#include "MLIRContext.h"

extern "C" {
MLIRContext* MLIRContext_new() {
return reinterpret_cast<MLIRContext*>(new mlir::MLIRContext);
}

void MLIRContext_delete(MLIRContext* s) {
delete reinterpret_cast<mlir::MLIRContext*>(s);
}

void MLIRContext_appendDialectRegistry(MLIRContext& s, DialectRegistry& registry) {
auto& registry_ = reinterpret_cast<mlir::DialectRegistry&>(registry);
reinterpret_cast<mlir::MLIRContext&>(s).appendDialectRegistry(registry_);
}
}
18 changes: 18 additions & 0 deletions spidr/backend/src/mlir/IR/MLIRContext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
extern "C" {
struct MLIRContext;
}
Loading
Loading