Skip to content

Commit

Permalink
Continued fixes: register the Triton dialect, generate Triton.jl bindings, and repair CUDA kernel-argument lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 7, 2025
1 parent 6dd132b commit f603206
Show file tree
Hide file tree
Showing 11 changed files with 11,367 additions and 13 deletions.
3 changes: 3 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;
using namespace llvm;
using namespace xla;
Expand Down Expand Up @@ -543,6 +545,7 @@ extern "C" void RegisterDialects(MlirContext cctx) {
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::enzyme::EnzymeDialect>();
context.loadDialect<mlir::enzymexla::EnzymeXLADialect>();
context.loadDialect<mlir::triton::TritonDialect>();
context.loadDialect<mlir::tensor::TensorDialect>();
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
Expand Down
18 changes: 17 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ cc_library(
"@com_google_absl//absl/log:globals",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@jax//jaxlib/mosaic:tpu_dialect_capi",
"@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
"@jax//jaxlib/triton:triton_dialect_capi_objects",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
"@xla//xla/stream_executor/cuda:all_runtime",
Expand Down Expand Up @@ -697,6 +698,21 @@ gentbl_cc_library(
tblgen = "//:mlir-jl-tblgen",
)

# Generate Julia bindings ("Triton.jl") for the MLIR Triton dialect.
# Runs the project's custom mlir-jl-tblgen with the jl-op-defs generator
# over JAX's triton.td TableGen definitions; the resulting file is copied
# into src/mlir/Dialects by make-bindings.jl.
gentbl_cc_library(
name = "TritonJLIncGen",
tbl_outs = [(
# --disable-module-wrap=0 presumably keeps the module wrapper enabled
# in the emitted Julia file — TODO confirm against mlir-jl-tblgen flags.
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Triton.jl"
)
],
td_file = "@jax//jaxlib/triton:triton.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@triton//:td_files",
],
tblgen = "//:mlir-jl-tblgen",
)

gentbl_cc_library(
name = "StableHLOJLIncGen",
tbl_outs = [(
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ LLVM_TARGETS = select({
"//conditions:default": ["AMDGPU", "NVPTX"],
}) + ["AArch64", "X86", "ARM"]

LLVM_COMMIT = "b5f21671ef04984bc00770263234dfb94833a274"
LLVM_COMMIT = "5d633fab8679e4f6ebbbed9751a3ee34337c6f92"
LLVM_SHA256 = ""
http_archive(
name = "llvm-raw",
Expand Down
3 changes: 2 additions & 1 deletion deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ for file in [
"Nvvm.jl",
"Gpu.jl",
"Affine.jl",
"MosaicTPU.jl"
"MosaicTPU.jl",
"Triton.jl"
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
16 changes: 8 additions & 8 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ function to_bytes(x)
sz = sizeof(x)
ref = Ref(x)
GC.@preserve ref begin
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{UInt8}(undef, sz)
ptr = Base.reinterpret(Ptr{Int8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{Int8}(undef, sz)
for i in 1:sz
@inbounds vec[i] = Base.unsafe_load(ptr, i)
end
Expand Down Expand Up @@ -423,7 +423,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))


c1 = MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1))
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
if sizeof(a) == 0
continue
Expand Down Expand Up @@ -455,15 +454,16 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end

# TODO check for only integer and explicitly non cutraced types
@show "Warning: using fallback for kernel argument type: $(Core.Typeof(a))"
@show "Warning: using fallback for kernel argument type conversion for argument of type $(Core.Typeof(a)), if this contains a CuTracedArray this will segfault"
MLIR.IR.block!(wrapbody) do
argty = MLIR.IR.input(gpu_function_type, argidx)
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, argidx-1))
argidx += 1
alloc = MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(elem_type))
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0)))

sz = sizeof(a)
array_ty = MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(UInt8), sz)
cdata = MLIR.Dialects.llvm.mlir_constant(; res=array_type, value=MLIR.IR.Attribute(to_bytes(a)))
array_ty = MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)
cdata = MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.Attribute(to_bytes(a)))
MLIR.Dialects.llvm.store(cdata, alloc)
argres = MLIR.Dialects.llvm.load(alloc; res=argty)
push!(wrapargs, argres)
Expand Down
Loading

0 comments on commit f603206

Please sign in to comment.