Skip to content

Commit

Permalink
Ka2 (#498)
Browse files Browse the repository at this point in the history
* Fix kernel abstractions with Reactant GPU

* fixup

* fix

* wip

* continued fixes

* fixup

* fix

* fix

* continue

* fix

* more bump

* Attempt bump

* fix build

* more fix

* Update WORKSPACE

* Update Project.toml

* Update cuda.jl

* Update cuda.jl

* Update cuda.jl

* Delete src/mlir/Dialects/MosaicTPU.jl
  • Loading branch information
wsmoses authored Jan 8, 2025
1 parent 47f363b commit ff229a2
Show file tree
Hide file tree
Showing 17 changed files with 11,302 additions and 2,008 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.36"
Reactant_jll = "0.0.37"
Scratch = "1.2"
SpecialFunctions = "2"
Statistics = "1.10"
Expand Down
103 changes: 64 additions & 39 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

using namespace mlir;
using namespace llvm;
using namespace xla;
Expand Down Expand Up @@ -432,11 +435,26 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
SMDiagnostic Err;
auto llvmModule =
llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
if (!llvmModule) {
std::string err_str;
llvm::raw_string_ostream err_stream(err_str);
Err.print(/*ProgName=*/"LLVMToMLIR", err_stream);
err_stream.flush();
if (ReactantThrowError) {
llvm::errs() << lmod << "\n";
ReactantThrowError(err_str.c_str());
return wrap((mlir::ModuleOp)nullptr);
}
}
mlir::MLIRContext &context = *unwrap(cctx);
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context,
/*emitExpensiveWarnings*/ false,
/*dropDICompositeElements*/ false)
.release();
if (!res) {
llvm::errs() << lmod << "\n";
ReactantThrowError("Could not translate LLVM IR to MLIR Module");
}
return wrap(res);
}

Expand Down Expand Up @@ -530,6 +548,8 @@ extern "C" void RegisterDialects(MlirContext cctx) {
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::enzyme::EnzymeDialect>();
context.loadDialect<mlir::enzymexla::EnzymeXLADialect>();
context.loadDialect<mlir::triton::TritonDialect>();
context.loadDialect<mlir::tpu::TPUDialect>();
context.loadDialect<mlir::tensor::TensorDialect>();
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
Expand Down Expand Up @@ -1130,9 +1150,10 @@ extern "C" const ifrt::Sharding *ifrt_array_sharding(ifrt::Array *array) {
return &(array->sharding());
}

// C entry point: fetches the layout of `array` and hands it back as a raw
// pointer. The result of array->layout() is first passed through
// MyValueOrThrow (defined elsewhere in this file), then release() is called on
// the unwrapped handle.
// NOTE(review): nothing here frees the returned pointer; presumably the caller
// takes ownership -- confirm.
extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) {
  auto layout_handle = MyValueOrThrow(array->layout());
  return layout_handle.release();
}
// TODO(@mofeng): the layout is now a shared_ptr, so the release()-based wrapper below is commented out until you port it.
// extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) {
// return MyValueOrThrow(array->layout()).release();
// }

// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays
// TODO xla::ifrt::Array::FullyReplicatedShard
Expand Down Expand Up @@ -1380,25 +1401,27 @@ ifrt_executable_output_shardings(ifrt::Executable *executable) {
return std::make_tuple(shardings.value().size(), shardings.value().data());
}

// C entry point: returns the parameter layouts of `executable` as a
// (count, array) pair.
//
// The result of executable->GetParameterLayouts() is passed through
// MyValueOrThrow (defined elsewhere in this file). A new[]-allocated array of
// `count` raw pointers is then filled by calling release() on each entry.
// NOTE(review): nothing in this function frees the array or the released
// pointers; presumably the caller owns both -- confirm.
extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_executable_parameter_layouts(ifrt::Executable *executable) {
  auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
  const size_t count = layouts.size();
  auto layouts_ptr = new xla::PjRtLayout *[count];
  // Use an unsigned index that matches size(): an `int` index mixes signed and
  // unsigned in the comparison and cannot represent counts above INT_MAX.
  for (size_t i = 0; i < count; i++) {
    layouts_ptr[i] = layouts[i].release();
  }
  return std::make_tuple(count, layouts_ptr);
}
// TODO(@mofeng): the layout is now a shared_ptr, so the release()-based wrapper below is commented out until you port it.
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_executable_parameter_layouts(ifrt::Executable *executable) {
// auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

// C entry point: returns the output layouts of `executable` as a
// (count, array) pair.
//
// The result of executable->GetOutputLayouts() is passed through
// MyValueOrThrow (defined elsewhere in this file). A new[]-allocated array of
// `count` raw pointers is then filled by calling release() on each entry.
// NOTE(review): nothing in this function frees the array or the released
// pointers; presumably the caller owns both -- confirm.
extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_executable_output_layouts(ifrt::Executable *executable) {
  auto layouts = MyValueOrThrow(executable->GetOutputLayouts());
  const size_t count = layouts.size();
  auto layouts_ptr = new xla::PjRtLayout *[count];
  // Use an unsigned index that matches size(): an `int` index mixes signed and
  // unsigned in the comparison and cannot represent counts above INT_MAX.
  for (size_t i = 0; i < count; i++) {
    layouts_ptr[i] = layouts[i].release();
  }
  return std::make_tuple(count, layouts_ptr);
}
// TODO(@mofeng): the layout is now a shared_ptr, so the release()-based wrapper below is commented out until you port it.
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_executable_output_layouts(ifrt::Executable *executable) {
// auto layouts = MyValueOrThrow(executable->GetOutputLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

extern "C" std::tuple<size_t, xla::HloModule **>
ifrt_executable_hlo_modules(ifrt::Executable *executable) {
Expand Down Expand Up @@ -1491,25 +1514,27 @@ ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable *executable) {
return std::make_tuple(shardings.value().size(), shardings.value().data());
}

// C entry point: returns the parameter layouts of a loaded `executable` as a
// (count, array) pair.
//
// The result of executable->GetParameterLayouts() is passed through
// MyValueOrThrow (defined elsewhere in this file). A new[]-allocated array of
// `count` raw pointers is then filled by calling release() on each entry.
// NOTE(review): nothing in this function frees the array or the released
// pointers; presumably the caller owns both -- confirm.
extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) {
  auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
  const size_t count = layouts.size();
  auto layouts_ptr = new xla::PjRtLayout *[count];
  // Use an unsigned index that matches size(): an `int` index mixes signed and
  // unsigned in the comparison and cannot represent counts above INT_MAX.
  for (size_t i = 0; i < count; i++) {
    layouts_ptr[i] = layouts[i].release();
  }
  return std::make_tuple(count, layouts_ptr);
}
// TODO(@mofeng): the layout is now a shared_ptr, so the release()-based wrapper below is commented out until you port it.
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) {
// auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

// C entry point: returns the output layouts of a loaded `executable` as a
// (count, array) pair.
//
// The result of executable->GetOutputLayouts() is passed through
// MyValueOrThrow (defined elsewhere in this file). A new[]-allocated array of
// `count` raw pointers is then filled by calling release() on each entry.
// NOTE(review): nothing in this function frees the array or the released
// pointers; presumably the caller owns both -- confirm.
extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable *executable) {
  auto layouts = MyValueOrThrow(executable->GetOutputLayouts());
  const size_t count = layouts.size();
  auto layouts_ptr = new xla::PjRtLayout *[count];
  // Use an unsigned index that matches size(): an `int` index mixes signed and
  // unsigned in the comparison and cannot represent counts above INT_MAX.
  for (size_t i = 0; i < count; i++) {
    layouts_ptr[i] = layouts[i].release();
  }
  return std::make_tuple(count, layouts_ptr);
}
// TODO(@mofeng): the layout is now a shared_ptr, so the release()-based wrapper below is commented out until you port it.
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable *executable) {
// auto layouts = MyValueOrThrow(executable->GetOutputLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

extern "C" std::tuple<size_t, xla::HloModule **>
ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable *executable) {
Expand Down
32 changes: 32 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,9 @@ cc_library(
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/log:globals",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
"@jax//jaxlib/triton:triton_dialect_capi_objects",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
"@xla//xla/stream_executor/cuda:all_runtime",
Expand Down Expand Up @@ -681,6 +684,35 @@ gentbl_cc_library(
tblgen = "//:mlir-jl-tblgen",
)

# Runs //:mlir-jl-tblgen with the jl-op-defs generator over tpu.td to
# produce TPU.jl.
gentbl_cc_library(
    name = "TPUJLIncGen",
    tbl_outs = [
        (
            [
                "--generator=jl-op-defs",
                "--disable-module-wrap=0",
            ],
            "TPU.jl",
        ),
    ],
    tblgen = "//:mlir-jl-tblgen",
    td_file = "@jax//jaxlib/mosaic:dialect/tpu/tpu.td",
    deps = [
        "@jax//jaxlib/mosaic:tpu_td_files",
    ],
)

# Runs //:mlir-jl-tblgen with the jl-op-defs generator over triton.td to
# produce Triton.jl.
gentbl_cc_library(
    name = "TritonJLIncGen",
    tbl_outs = [
        (
            [
                "--generator=jl-op-defs",
                "--disable-module-wrap=0",
            ],
            "Triton.jl",
        ),
    ],
    tblgen = "//:mlir-jl-tblgen",
    td_file = "@jax//jaxlib/triton:triton.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@triton//:td_files",
    ],
)

gentbl_cc_library(
name = "StableHLOJLIncGen",
tbl_outs = [(
Expand Down
37 changes: 36 additions & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "a8451f97231186cf13fed2d5d801dc05156d76ed"
ENZYMEXLA_COMMIT = "85612ea74731f02aa4e30800038e065912d37ae2"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down Expand Up @@ -94,6 +94,41 @@ LLVM_TARGETS = select({
"//conditions:default": ["AMDGPU", "NVPTX"],
}) + ["AArch64", "X86", "ARM"]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
# LLVM_SHA256 = ""
# http_archive(
# name = "llvm-raw",
# build_file_content = "# empty",
# sha256 = LLVM_SHA256,
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
# )
#
#
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# maybe(
# http_archive,
# name = "llvm_zlib",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
# strip_prefix = "zlib-ng-2.0.7",
# urls = [
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
# ],
# )
#
# maybe(
# http_archive,
# name = "llvm_zstd",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
# strip_prefix = "zstd-1.5.2",
# urls = [
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
# ],
# )

http_archive(
name = "jax",
sha256 = JAX_SHA256,
Expand Down
2 changes: 2 additions & 0 deletions deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ for file in [
"Nvvm.jl",
"Gpu.jl",
"Affine.jl",
"TPU.jl",
"Triton.jl"
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
Loading

0 comments on commit ff229a2

Please sign in to comment.