-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Ops.hlo_call(::String, args...) #358
Changes from 4 commits
0c4c708
910d141
134f92b
53ca1dd
e37852e
3559cfb
5031e2b
e601f8e
dd22990
68138f6
8eb71cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1046,4 +1046,72 @@ function compare( | |
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) | ||
end | ||
|
||
""" | ||
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray} | ||
|
||
Given a MLIR module given as a string and containing a single function, | ||
calls the given function with the provided arguments and return a tuple for each result of the call. | ||
|
||
```julia-repl | ||
julia> Reactant.@jit( | ||
Ops.hlo_call( | ||
\"\"\" | ||
module { | ||
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> | ||
return %0 : tensor<3xf32> | ||
} | ||
} | ||
\"\"\", | ||
Reactant.to_rarray(Float32[1, 2, 3]), | ||
Reactant.to_rarray(Float32[1, 2, 3]), | ||
) | ||
) | ||
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) | ||
``` | ||
""" | ||
function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__)) | ||
new_mod = parse(MLIR.IR.Module, code) | ||
body = MLIR.IR.body(new_mod) | ||
fn = MLIR.IR.first_op(body) | ||
MLIR.IR.rmfromparent!(fn) | ||
Pangoraw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
current_module = MLIR.IR.mmodule() | ||
top_level_block = MLIR.IR.body(current_module) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we also need to mark all fn's as private, as well as make sure to move all fns in the module (e.g. the main function could call something) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we have some utilities here: https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/compile_with_xla.cc#L190 and https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/primitives.py#L811 in Enzyme-JaX for explicitly making we can do all the things There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, do you know if we can encounter ops other than There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it’s fine to assume func for now but if desired we could generalize to function interface or whatnot |
||
|
||
orig_name = String(MLIR.IR.attr(fn, "sym_name")) | ||
name = orig_name * "_" * string(gensym()) | ||
|
||
push!(top_level_block, fn) | ||
|
||
ftype_attr = MLIR.IR.attr(fn, "function_type") | ||
ftype = MLIR.IR.Type(ftype_attr) | ||
|
||
Pangoraw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MLIR.IR.attr!(fn, "sym_name", MLIR.IR.Attribute(name)) | ||
|
||
@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays" | ||
@assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of arguments for function $orig_name" | ||
|
||
operands = [a.mlir_data for a in args] | ||
call = MLIR.Dialects.func.call( | ||
operands; | ||
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], | ||
callee=MLIR.IR.FlatSymbolRefAttribute(name), | ||
location, | ||
) | ||
|
||
return ntuple(MLIR.IR.nresults(call)) do i | ||
out = MLIR.IR.result(call, i) | ||
ty = MLIR.IR.type(out) | ||
sz = MLIR.IR.size(ty) | ||
T = MLIR.IR.julia_type(eltype(ty)) | ||
N = length(sz) | ||
if N == 0 | ||
Reactant.TracedRNumber{T}((), out) | ||
else | ||
Reactant.TracedRArray{T,N}((), out, sz) | ||
end | ||
end | ||
end | ||
|
||
end # module Ops |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -866,3 +866,25 @@ end | |
z = ConcreteRArray([1e-8, 0.001, 2.0]) | ||
@test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z) | ||
end | ||
|
||
@testset "hlo_call" begin | ||
x = Float32[1.0, 2.0, 50.0] | ||
y = Float32[-4.0, 0.001, 2.0] | ||
x_reactant = Reactant.to_rarray(x) | ||
y_reactant = Reactant.to_rarray(y) | ||
|
||
@test Reactant.@jit( | ||
Ops.hlo_call( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test with multiple functions in the module. and can you also add a test with two (different) hlo calls that happen to contain functions of the same name (to make sure we do the symbol rename properly) |
||
""" | ||
module { | ||
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> | ||
return %0 : tensor<3xf32> | ||
} | ||
} | ||
""", | ||
x_reactant, | ||
y_reactant, | ||
) | ||
)[1] ≈ x .+ y | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if the first op is not
main
? like what if we the code was traced by us and we added some function barriers.maybe we can add a kwarg for selecting the target function (and default it to
main
), so we just iterate over the ops doingfirst(Iterators.filter(op -> String(IR.attr(op, "sym_name")) == target_fn, OperationIterator(body))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is inside the code that was given by the caller. Currently, the expectation is that there is only one function inside the given module. We can surely revisit that with a keyword indeed, or a tuple as for Core.llvmcall
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I would instead ideally have a kwargument fn=main, and we extract that fn as the top level one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a
func_name::String
kwarg for this