diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 30cc6c6be..4a018d196 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -112,6 +112,25 @@ jobs: name: pjrt_plugin_xla_cpu-darwin-aarch64 path: pjrt_plugin_xla_cpu.dylib if-no-files-found: error + pjrt-plugin-apple-metal: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 + - name: Build or fetch Apple Metal PJRT plugin + run: | + prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64 + curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \ + -o "$prefix.zip" + unzip "$prefix.zip" + mv "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" pjrt_plugin_apple_metal.dylib + - name: Upload binary + uses: actions/upload-artifact@v4 + with: + name: pjrt_plugin_apple_metal + path: pjrt_plugin_apple_metal.dylib + if-no-files-found: error pjrt-plugin-xla-cuda-linux-x86_64: runs-on: ubuntu-latest steps: @@ -181,6 +200,27 @@ jobs: name: tests-xla-cpu-darwin-aarch64 path: test/xla-cpu/tests-xla-cpu.tar.gz if-no-files-found: error + build-tests-apple-metal: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + - name: Install build dependencies + run: | + brew install chezscheme + git clone https://github.com/stefan-hoeck/idris2-pack.git + (cd idris2-pack && make micropack SCHEME=chez) + ~/.pack/bin/pack switch HEAD + - name: Build tests + working-directory: test/apple-metal + run: | + SPIDR_INSTALL_SUPPORT_LIBS=false ~/.pack/bin/pack --no-prompt build apple-metal.ipkg + tar cfz tests-apple-metal.tar.gz -C build/exec . + - name: Upload tests + uses: actions/upload-artifact@v4 + with: + name: tests-apple-metal + path: test/apple-metal/tests-apple-metal.tar.gz + if-no-files-found: error build-tests-xla-cuda-linux-x86_64: runs-on: ubuntu-latest container: ghcr.io/stefan-hoeck/idris2-pack @@ -239,6 +279,25 @@ jobs: run: | tar xfz tests-xla-cpu.tar.gz && rm tests-xla-cpu.tar.gz ./test + test-apple-metal: + needs: + - pjrt-darwin-aarch64 + - pjrt-plugin-apple-metal + - build-tests-apple-metal + runs-on: macos-latest + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: "{libc_xla-darwin-aarch64,pjrt_plugin_apple_metal,tests-apple-metal}" + merge-multiple: true + - name: Install runtime dependencies + run: | + brew install chezscheme + - name: Run tests + run: | + tar xfz tests-apple-metal.tar.gz && rm tests-apple-metal.tar.gz + ./test test-xla-cuda-linux-x86_64: needs: - pjrt-linux-x86_64 diff --git a/pack.toml b/pack.toml index 7e1218ff7..33edbc328 100644 --- a/pack.toml +++ b/pack.toml @@ -8,6 +8,11 @@ type = "local" path = "" ipkg = "test/runner/runner.ipkg" +[custom.all.pjrt-plugin-apple-metal] +type = "local" +path = "" +ipkg = "pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg" + [custom.all.pjrt-plugin-xla-cpu] type = "local" path = "" diff --git a/pjrt-plugins/README.md b/pjrt-plugins/README.md index bc0d70f5f..8d977e52f 100644 --- a/pjrt-plugins/README.md +++ b/pjrt-plugins/README.md @@ -1,6 +1,6 @@ # PJRT Plugins -A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md) and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own. +A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md), [Apple Metal](apple-metal/README.md), and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own. ## How to integrate your own plugin diff --git a/pjrt-plugins/apple-metal/PjrtPluginAppleMetal.idr b/pjrt-plugins/apple-metal/PjrtPluginAppleMetal.idr new file mode 100644 index 000000000..f9e1c33e6 --- /dev/null +++ b/pjrt-plugins/apple-metal/PjrtPluginAppleMetal.idr @@ -0,0 +1,30 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module PjrtPluginAppleMetal + +import System.FFI + +import public Compiler.Xla.PJRT.C.PjrtCApi +import public Device + +%foreign "C:GetPjrtApi,pjrt_plugin_apple_metal" +prim__getPjrtApi : PrimIO AnyPtr + +export +device : Pjrt Device +device = do + api <- MkPjrtApi <$> primIO prim__getPjrtApi + MkDevice api <$> pjrtClientCreate api diff --git a/pjrt-plugins/apple-metal/README.md b/pjrt-plugins/apple-metal/README.md new file mode 100644 index 000000000..fda57042c --- /dev/null +++ b/pjrt-plugins/apple-metal/README.md @@ -0,0 +1,10 @@ +# PJRT plugin for Apple Metal + +This is the PJRT plugin for Apple Metal, which provides hardware acceleration with GPU on Apple silicon (AArch64, ARM64). + +## Install + +Run +``` +pack install pjrt-plugin-apple-metal +``` diff --git a/pjrt-plugins/apple-metal/build.sh b/pjrt-plugins/apple-metal/build.sh new file mode 100644 index 000000000..094d8c41d --- /dev/null +++ b/pjrt-plugins/apple-metal/build.sh @@ -0,0 +1,33 @@ +#!/bin/sh -e + +script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) +cd "$script_dir/../.." +. ./dev.sh +rev=$(cat XLA_VERSION) + +osu="$(uname)" +case $osu in + 'Linux') + os=linux + arch=x86_64 + ext=so + ;; + 'Darwin') + os=darwin + arch=aarch64 + ext=dylib + ;; + *) + echo "OS $osu not handled" + exit 1 + ;; +esac + +xla_dir=$(mktemp -d) +install_xla "$rev" "$xla_dir" +( + cd "$xla_dir" + ./configure.py --backend=CPU --os=$os + bazel build //xla/pjrt/c:pjrt_c_api_cpu_plugin.so +) +mv "$xla_dir/bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so" "pjrt_plugin_xla_cpu-$os-$arch.$ext" diff --git a/pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg b/pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg new file mode 100644 index 000000000..79183d236 --- /dev/null +++ b/pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg @@ -0,0 +1,11 @@ +package pjrt-plugin-apple-metal +version = 0.0.1 + +depends = spidr +modules = PjrtPluginAppleMetal + +brief = "XLA PJRT plugin for Apple Metal." +readme = "README.md" +license = "Apache License, Version 2.0" + +postinstall = "./postinstall.sh" diff --git a/pjrt-plugins/apple-metal/postinstall.sh b/pjrt-plugins/apple-metal/postinstall.sh new file mode 100755 index 000000000..7df8c4669 --- /dev/null +++ b/pjrt-plugins/apple-metal/postinstall.sh @@ -0,0 +1,25 @@ +#!/bin/sh -e + +if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi + +script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) +cd "$script_dir/../.." + +os="$(uname)" +case $os in + 'Darwin') + ;; + *) + echo "WARNING: OS $os not supported, unable to fetch supporting libraries." + exit 0 + ;; +esac + +prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64 +curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \ + -o "$prefix.zip" +unzip "$prefix.zip" +libdir="$(idris2 --libdir)/pjrt-plugin-xla-cpu-0.0.1/lib" +mkdir -p libdir +mv "$prefix/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" "$libdir/pjrt_plugin_apple_metal.dylib" +rm -rf "$prefix.zip" $prefix diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index ba1a39bfa..cff2e8398 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,21 +12,31 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/mlir/IR", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], deps = [ + "//src/mlir/IR", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], ) diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index 9789c4ccb..ceddfb28f 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.14 +0.0.15 diff --git a/spidr/backend/src/ffi.cpp b/spidr/backend/src/ffi.cpp index 2a77d13ac..f940b423e 100644 --- a/spidr/backend/src/ffi.cpp +++ b/spidr/backend/src/ffi.cpp @@ -29,6 +29,10 @@ extern "C" { return ptr == nullptr; } + string* string_new() { + return reinterpret_cast(new std::string()); + } + void string_delete(string* s) { delete reinterpret_cast(s); } diff --git a/spidr/backend/src/mlir/IR/BUILD b/spidr/backend/src/mlir/IR/BUILD new file mode 100644 index 000000000..f034b361b --- /dev/null +++ b/spidr/backend/src/mlir/IR/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:IR", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/mlir/IR/BuiltinOps.h b/spidr/backend/src/mlir/IR/BuiltinOps.h new file mode 100644 index 000000000..0fb5ccbec --- /dev/null +++ b/spidr/backend/src/mlir/IR/BuiltinOps.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct ModuleOp; +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.cpp b/spidr/backend/src/mlir/IR/DialectRegistry.cpp new file mode 100644 index 000000000..dfc543d57 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/DialectRegistry.h" + +#include "DialectRegistry.h" + +extern "C" { + DialectRegistry* DialectRegistry_new() { + return reinterpret_cast(new mlir::DialectRegistry()); + } + + void DialectRegistry_delete(DialectRegistry* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.h b/spidr/backend/src/mlir/IR/DialectRegistry.h new file mode 100644 index 000000000..58c7ab272 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct DialectRegistry; +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.cpp b/spidr/backend/src/mlir/IR/MLIRContext.cpp new file mode 100644 index 000000000..9083b03e7 --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.cpp @@ -0,0 +1,34 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/MLIRContext.h" + +#include "DialectRegistry.h" +#include "MLIRContext.h" + +extern "C" { + MLIRContext* MLIRContext_new() { + return reinterpret_cast(new mlir::MLIRContext); + } + + void MLIRContext_delete(MLIRContext* s) { + delete reinterpret_cast(s); + } + + void MLIRContext_appendDialectRegistry(MLIRContext& s, DialectRegistry& registry) { + auto& registry_ = reinterpret_cast(registry); + reinterpret_cast(s).appendDialectRegistry(registry_); + } +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.h b/spidr/backend/src/mlir/IR/MLIRContext.h new file mode 100644 index 000000000..efa58bc0c --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct MLIRContext; +} diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD new file mode 100644 index 000000000..8cfce51e4 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -0,0 +1,14 @@ +cc_library( + name = "dialect", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:IR", + "@stablehlo//:register", + "@stablehlo//:stablehlo_serialization", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/stablehlo/dialect/Register.cpp b/spidr/backend/src/stablehlo/dialect/Register.cpp new file mode 100644 index 000000000..505668a34 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Register.h" + +#include "../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllDialects(DialectRegistry& registry) { + mlir::stablehlo::registerAllDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp new file mode 100644 index 000000000..87c066038 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -0,0 +1,57 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/Bytecode/BytecodeWriter.h" +#include "stablehlo/dialect/Serialization.h" +#include "stablehlo/dialect/Version.h" + +#include "../../mlir/IR/BuiltinOps.h" +#include "../../ffi.h" + +extern "C" { + int serializePortableArtifact(ModuleOp& module, string& str) { + auto& module_ = reinterpret_cast(module); + auto& str_ = reinterpret_cast(str); + +// std::string s; +// llvm::raw_string_ostream os0(s); +// module_.print(os0); +// printf("serializePortableArtifact ...\n"); +// printf("... debug print:\n"); +// printf("%s\n", s.c_str()); + + llvm::raw_string_ostream os(str_); +// if (mlir::writeBytecodeToFile(module_, os).failed()) { +// return (int) false; +// } + +// printf("... serialization:\n"); +// printf("%s\n", str_.c_str()); + auto version = mlir::vhlo::Version::getMinimumVersion().toString(); + auto result = mlir::stablehlo::serializePortableArtifact(module_, version, os); + return (int) result.succeeded(); + } + + string* printModule(ModuleOp& module) { + auto& module_ = reinterpret_cast(module); + auto str = new std::string(); + llvm::raw_string_ostream os(*str); + module_.print(os); + + printf("... debug print:\n"); + printf("%s\n", str->c_str()); + return reinterpret_cast(str); + } +} diff --git a/spidr/backend/src/xla/hlo/builder/BUILD b/spidr/backend/src/xla/hlo/builder/BUILD index e729f1eef..48be5352b 100644 --- a/spidr/backend/src/xla/hlo/builder/BUILD +++ b/spidr/backend/src/xla/hlo/builder/BUILD @@ -8,6 +8,7 @@ cc_library( "@xla//xla/hlo/builder:xla_builder", "//src", "//src/xla", + "//src/xla/service", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index 1cba3a527..8695e74ee 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -14,8 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" #include "../../../ffi.h" +#include "../../service/hlo.proto.h" +#include "../../shape.h" #include "xla_computation.h" extern "C" { @@ -23,9 +26,8 @@ extern "C" { delete reinterpret_cast(s); } - string* XlaComputation_SerializeAsString(XlaComputation* s) { + HloModuleProto* XlaComputation_proto(XlaComputation* s) { auto s_ = reinterpret_cast(s); - auto serialized = s_->proto().SerializeAsString(); - return reinterpret_cast(new std::string(serialized)); + return reinterpret_cast(new xla::HloModuleProto(s_->proto())); } } diff --git a/spidr/backend/src/xla/hlo/translate/BUILD b/spidr/backend/src/xla/hlo/translate/BUILD new file mode 100644 index 000000000..75212dc84 --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "translate", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/translate:stablehlo", + "//src/mlir/IR", + "//src/xla/service", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp new file mode 100644 index 000000000..e358aeca7 --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -0,0 +1,31 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/BuiltinOps.h" +#include "xla/service/hlo.pb.h" +#include "xla/hlo/translate/stablehlo.h" + +#include "../../service/hlo.proto.h" +#include "../../../mlir/IR/BuiltinOps.h" +#include "../../../mlir/IR/MLIRContext.h" + +extern "C" { + ModuleOp* ConvertHloToStablehlo(MLIRContext& ctx, HloModuleProto* hlo_module) { + auto& ctx_ = reinterpret_cast(ctx); + auto hlo_module_ = reinterpret_cast(hlo_module); + auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); + return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); + } +} diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD new file mode 100644 index 000000000..e7f37a3c4 --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp new file mode 100644 index 000000000..eb9319d4d --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/mlir_hlo/mhlo/IR/register.h" + +#include "../../../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllMhloDialects(DialectRegistry& registry) { + mlir::mhlo::registerAllMhloDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp index 18290860c..eeb4f8950 100644 --- a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -127,7 +127,7 @@ extern "C" { } PJRT_Program* PJRT_Program_new(char* code, size_t code_size) { - auto format = pjrt::kHloFormat; + auto format = pjrt::kMlirFormat; return new PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, .extension_start = nullptr, @@ -159,7 +159,11 @@ extern "C" { } PJRT_Error* pjrt_client_compile(PJRT_Api* api, PJRT_Client_Compile_Args* args) { - return api->PJRT_Client_Compile(args); +// printf("pjrt_client_compile ...\n"); +// printf("... args->program->code\n"); +// printf("%.*s\n", args->program->code_size, args->program->code); + auto res = api->PJRT_Client_Compile(args); + return res; } PJRT_LoadedExecutable_Destroy_Args* PJRT_LoadedExecutable_Destroy_Args_new( diff --git a/spidr/backend/src/xla/service/BUILD b/spidr/backend/src/xla/service/BUILD new file mode 100644 index 000000000..38c7f2d3e --- /dev/null +++ b/spidr/backend/src/xla/service/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "service", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/service", + "//src/xla", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/service/hlo.proto.cpp b/spidr/backend/src/xla/service/hlo.proto.cpp new file mode 100644 index 000000000..195d17026 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.cpp @@ -0,0 +1,31 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/service/hlo.pb.h" +// #include "xla/service/..." // try to import from some random place + +#include "../../ffi.h" +#include "hlo.proto.h" + +extern "C" { + string* HloModuleProto_SerializeAsString(HloModuleProto& s) { + auto s_ = reinterpret_cast(s); + return reinterpret_cast(new std::string(s_.SerializeAsString())); + } + + void HloModuleProto_delete(HloModuleProto* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/xla/service/hlo.proto.h b/spidr/backend/src/xla/service/hlo.proto.h new file mode 100644 index 000000000..336bbeaf3 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModuleProto; +} diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index fa7b670c0..a0dfb6bce 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,6 +8,11 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.MLIR.IR.BuiltinOps, + Compiler.MLIR.IR.DialectRegistry, + Compiler.MLIR.IR.MLIRContext, + Compiler.StableHLO.Dialect.Register, + Compiler.StableHLO.Dialect.Serialization, Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, @@ -16,8 +21,11 @@ modules = Compiler.Xla.HLO.Builder.Lib.PRNG, Compiler.Xla.HLO.Builder.XlaBuilder, Compiler.Xla.HLO.Builder.XlaComputation, + Compiler.Xla.HLO.Translate.StableHLO, + Compiler.Xla.MLIRHLO.MHLO.IR.Register, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, + Compiler.Xla.Service.HloProto, Compiler.Xla.Literal, Compiler.Xla.Shape, Compiler.Xla.ShapeUtil, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 6c26186f5..e719cd255 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -26,6 +26,11 @@ import Data.List.Elem import Compiler.Expr import Compiler.FFI import Compiler.LiteralRW +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.IR.MLIRContext +import Compiler.StableHLO.Dialect.Register +import Compiler.StableHLO.Dialect.Serialization import Compiler.Xla.Client.ExecutableBuildOptions import Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.Xla.HLO.Builder.Lib.Constants @@ -34,8 +39,11 @@ import Compiler.Xla.HLO.Builder.Lib.Matrix import Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.HLO.Builder.XlaComputation +import Compiler.Xla.HLO.Translate.StableHLO +import Compiler.Xla.MLIRHLO.MHLO.IR.Register import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable +import Compiler.Xla.Service.HloProto import Compiler.Xla.Literal import Compiler.Xla.Shape import Compiler.Xla.ShapeUtil @@ -46,17 +54,21 @@ import Types import Util import Device +import System + export data Err = OutOfBounds Nat Nat | ValueNotFound Nat | PjrtErr PjrtError + | SerializationError String export Show Err where show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}" show (ValueNotFound idx) = "Value not found at index \{show idx}" - show (PjrtErr err)= show err + show (PjrtErr err) = show err + show (SerializationError err) = "SerializationError: \{err}" public export 0 ErrIO : Type -> Type @@ -223,11 +235,19 @@ execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ V execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f + dialectRegistry <- mkDialectRegistry + registerAllMhloDialects dialectRegistry + registerAllDialects dialectRegistry + mlirCtx <- mkMLIRContext + stablehlo <- convertHloToStablehlo mlirCtx !(proto computation) + appendDialectRegistry mlirCtx dialectRegistry + Just code <- serializePortableArtifact stablehlo | Nothing => throwE (SerializationError "Failed to serialize StableHLO") + -- code <- printModule stablehlo + executableBuildOptions <- mkExecutableBuildOptions + compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) + program <- mkPjrtProgram code bimapEitherT PjrtErr id $ do - code <- serializeAsString computation - executableBuildOptions <- mkExecutableBuildOptions - compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) - loadedExec <- pjrtClientCompile api client !(mkPjrtProgram code) compileOptions + loadedExec <- pjrtClientCompile api client program compileOptions free code free compileOptions delete executableBuildOptions diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index aec92c193..307b9574b 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -31,6 +31,10 @@ namespace CharArray free : HasIO io => CharArray -> io () free (MkCharArray arr _) = free $ prim__forgetPtr arr +export +%foreign (libxla "string_new") +prim__stringNew : PrimIO AnyPtr + export %foreign (libxla "string_delete") prim__stringDelete : AnyPtr -> PrimIO () @@ -47,6 +51,15 @@ export %foreign (libxla "idx") prim__index : Int -> AnyPtr -> AnyPtr +||| Deletes the `string`. It is up to the caller to `free` the `CharArray`. +export +stringToCharArray : HasIO io => AnyPtr -> io CharArray +stringToCharArray str = do + data' <- primIO $ prim__stringData str + let size = prim__stringSize str + primIO $ prim__stringDelete str + pure (MkCharArray data' size) + export cIntToBool : Int -> Bool cIntToBool 0 = False diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr new file mode 100644 index 000000000..9f06c7327 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -0,0 +1,20 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.BuiltinOps + +public export +data ModuleOp = MkModuleOp AnyPtr -- need to GC diff --git a/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr new file mode 100644 index 000000000..329b35308 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.DialectRegistry + +import Compiler.FFI + +public export +data DialectRegistry = MkDialectRegistry GCAnyPtr + +%foreign (libxla "DialectRegistry_new") +prim__mkDialectRegistry : PrimIO AnyPtr + +%foreign (libxla "DialectRegistry_delete") +prim__deleteDialectRegistry : AnyPtr -> PrimIO () + +export +mkDialectRegistry : HasIO io => io DialectRegistry +mkDialectRegistry = do + registry <- primIO prim__mkDialectRegistry + registry <- onCollectAny registry (primIO . prim__deleteDialectRegistry) + pure (MkDialectRegistry registry) diff --git a/spidr/src/Compiler/MLIR/IR/MLIRContext.idr b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr new file mode 100644 index 000000000..645a21e56 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr @@ -0,0 +1,44 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.MLIRContext + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +public export +data MLIRContext = MkMLIRContext GCAnyPtr + +%foreign (libxla "MLIRContext_new") +prim__mkMLIRContext : PrimIO AnyPtr + +%foreign (libxla "MLIRContext_delete") +prim__deleteMLIRContext : AnyPtr -> PrimIO () + +export +mkMLIRContext : HasIO io => io MLIRContext +mkMLIRContext = do + ctx <- primIO prim__mkMLIRContext + ctx <- onCollectAny ctx (primIO . prim__deleteMLIRContext) + pure (MkMLIRContext ctx) + +%foreign (libxla "MLIRContext_appendDialectRegistry") +prim__appendDialectRegistry : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +appendDialectRegistry : HasIO io => MLIRContext -> DialectRegistry -> io () +appendDialectRegistry (MkMLIRContext ctx) (MkDialectRegistry registry) = + primIO $ prim__appendDialectRegistry ctx registry diff --git a/spidr/src/Compiler/StableHLO/Dialect/Register.idr b/spidr/src/Compiler/StableHLO/Dialect/Register.idr new file mode 100644 index 000000000..e51220ac2 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllDialects") +prim__registerAllDialects : GCAnyPtr -> PrimIO () + +export +registerAllDialects : HasIO io => DialectRegistry -> io () +registerAllDialects (MkDialectRegistry reg) = primIO $ prim__registerAllDialects reg diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr new file mode 100644 index 000000000..4b8d55e02 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -0,0 +1,39 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Serialization + +import Compiler.MLIR.IR.BuiltinOps +import Compiler.FFI + +%foreign (libxla "serializePortableArtifact") +prim__serializePortableArtifact : AnyPtr -> AnyPtr -> PrimIO Int + +export +serializePortableArtifact : HasIO io => ModuleOp -> io (Maybe CharArray) +serializePortableArtifact (MkModuleOp moduleOp) = do + str <- primIO prim__stringNew + ok <- primIO $ prim__serializePortableArtifact moduleOp str + case cIntToBool ok of + True => Just <$> stringToCharArray str + False => free str >> pure Nothing + +%foreign (libxla "printModule") +prim__printModule : AnyPtr -> PrimIO AnyPtr + +export +printModule : HasIO io => ModuleOp -> io CharArray +printModule (MkModuleOp moduleOp) = primIO (prim__printModule moduleOp) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index 1e35ba4dc..a9a4c455f 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -17,6 +17,8 @@ limitations under the License. module Compiler.Xla.HLO.Builder.XlaComputation import Compiler.FFI +import Compiler.Xla.Shape +import Compiler.Xla.Service.HloProto public export data XlaComputation : Type where @@ -27,18 +29,14 @@ prim__delete : AnyPtr -> PrimIO () export delete : AnyPtr -> IO () -delete = primIO . prim__delete +delete = primIO . XlaComputation.prim__delete -export -%foreign (libxla "XlaComputation_SerializeAsString") -prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO AnyPtr +%foreign (libxla "XlaComputation_proto") +prim__xlaComputationProto : GCAnyPtr -> PrimIO AnyPtr -||| It is up to the caller to deallocate the CharArray. export -serializeAsString : HasIO io => XlaComputation -> io CharArray -serializeAsString (MkXlaComputation computation) = do - str <- primIO $ prim__xlaComputationSerializeAsString computation - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) +proto : HasIO io => XlaComputation -> io HloModuleProto +proto (MkXlaComputation comp) = do + proto <- primIO $ prim__xlaComputationProto comp + proto <- onCollectAny proto (primIO . HloProto.prim__delete) + pure (MkHloModuleProto proto) diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr new file mode 100644 index 000000000..b58bbe055 --- /dev/null +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -0,0 +1,31 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.HLO.Translate.StableHLO + +import Compiler.FFI +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext +import Compiler.Xla.Service.HloProto + +%foreign (libxla "ConvertHloToStablehlo") +prim__convertHloToStablehlo : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +convertHloToStablehlo : HasIO io => MLIRContext -> HloModuleProto -> io ModuleOp +convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do + moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto + pure (MkModuleOp moduleOp) diff --git a/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr new file mode 100644 index 000000000..77f82fd41 --- /dev/null +++ b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.MLIRHLO.MHLO.IR.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllMhloDialects") +prim__registerAllMhloDialects : GCAnyPtr -> PrimIO () + +export +registerAllMhloDialects : HasIO io => DialectRegistry -> io () +registerAllMhloDialects (MkDialectRegistry reg) = primIO $ prim__registerAllMhloDialects reg diff --git a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr index 734f9fa16..24f0b7f9f 100644 --- a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr +++ b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr @@ -73,9 +73,9 @@ export Show PjrtError where show e = let code = case e.code of - Nothing => "not found" + Nothing => "unknown" Just c => show c - in "PjrtError \{show e.message} (code \{code})" + in "PjrtError (error code \{code})\n\{e.message}" %foreign (libxla "PJRT_Error_Destroy_Args_new") prim__mkPjrtErrorDestroyArgs : AnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr index 987cb1fdd..acdb83406 100644 --- a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr +++ b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr @@ -36,12 +36,8 @@ mkCompileOptions (MkExecutableBuildOptions executableBuildOptions) = do %foreign (libxla "CompileOptions_SerializeAsString") prim__compileOptionsSerializeAsString : GCAnyPtr -> PrimIO AnyPtr -||| It is up to the caller to deallocate the CharArray. +||| It is up to the caller to `free` the `CharArray`. export serializeAsString : HasIO io => CompileOptions -> io CharArray -serializeAsString (MkCompileOptions options) = do - str <- primIO $ prim__compileOptionsSerializeAsString options - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) +serializeAsString (MkCompileOptions options) = + primIO (prim__compileOptionsSerializeAsString options) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr new file mode 100644 index 000000000..9e7ce2b2c --- /dev/null +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.Service.HloProto + +import Compiler.FFI + +public export +data HloModuleProto = MkHloModuleProto GCAnyPtr + +%foreign (libxla "HloModuleProto_SerializeAsString") +prim__hloModuleProtoSerializeAsString : GCAnyPtr -> PrimIO AnyPtr + +export +%foreign (libxla "HloModuleProto_delete") +prim__delete : AnyPtr -> PrimIO () + +||| It is up to the caller to `free` the `CharArray`. +export +serializeAsString : HasIO io => HloModuleProto -> io CharArray +serializeAsString (MkHloModuleProto proto) = + primIO (prim__hloModuleProtoSerializeAsString proto) >>= stringToCharArray diff --git a/test/apple-metal/Main.idr b/test/apple-metal/Main.idr new file mode 100644 index 000000000..db388c772 --- /dev/null +++ b/test/apple-metal/Main.idr @@ -0,0 +1,25 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Main + +import System + +import TestRunner +import PjrtPluginAppleMetal + +partial +main : IO () +main = eitherT (die . show) run device diff --git a/test/apple-metal/apple-metal.ipkg b/test/apple-metal/apple-metal.ipkg new file mode 100644 index 000000000..dd8efc992 --- /dev/null +++ b/test/apple-metal/apple-metal.ipkg @@ -0,0 +1,8 @@ +package apple-metal + +depends = + pjrt-plugin-apple-metal, + runner + +executable = test +main = Main diff --git a/test/xla-cpu/XlaCpu.idr b/test/xla-cpu/Main.idr similarity index 97% rename from test/xla-cpu/XlaCpu.idr rename to test/xla-cpu/Main.idr index 2e5d7b972..854d1eae4 100644 --- a/test/xla-cpu/XlaCpu.idr +++ b/test/xla-cpu/Main.idr @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module XlaCpu +module Main import System diff --git a/test/xla-cpu/xla-cpu.ipkg b/test/xla-cpu/xla-cpu.ipkg index 24255b025..39fd35065 100644 --- a/test/xla-cpu/xla-cpu.ipkg +++ b/test/xla-cpu/xla-cpu.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCpu +main = Main diff --git a/test/xla-cuda/XlaCuda.idr b/test/xla-cuda/Main.idr similarity index 97% rename from test/xla-cuda/XlaCuda.idr rename to test/xla-cuda/Main.idr index 422589049..4a727f497 100644 --- a/test/xla-cuda/XlaCuda.idr +++ b/test/xla-cuda/Main.idr @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module XlaCuda +module Main import System diff --git a/test/xla-cuda/xla-cuda.ipkg b/test/xla-cuda/xla-cuda.ipkg index 66c3f269b..9d76e1994 100644 --- a/test/xla-cuda/xla-cuda.ipkg +++ b/test/xla-cuda/xla-cuda.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCuda +main = Main