Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ops.hlo_call(::String, args...) #358

Merged
merged 11 commits into from
Dec 10, 2024
68 changes: 68 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,72 @@ function compare(
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end

"""
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray}

Given a MLIR module given as a string and containing a single function,
calls the given function with the provided arguments and return a tuple for each result of the call.

```julia-repl
julia> Reactant.@jit(
Ops.hlo_call(
\"\"\"
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
\"\"\",
Reactant.to_rarray(Float32[1, 2, 3]),
Reactant.to_rarray(Float32[1, 2, 3]),
)
)
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
```
"""
function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
new_mod = parse(MLIR.IR.Module, code)
body = MLIR.IR.body(new_mod)
fn = MLIR.IR.first_op(body)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if the first op is not main? like what if we the code was traced by us and we added some function barriers.

maybe we can add a kwarg for selecting the target function (and default it to main), so we just iterate over the ops doing first(Iterators.filter(op -> String(IR.attr(op, "sym_name")) == target_fn, OperationIterator(body))

Copy link
Collaborator Author

@Pangoraw Pangoraw Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is inside the code that was given by the caller. Currently, the expectation is that there is only one function inside the given module. We can surely revisit that with a keyword indeed, or a tuple as for Core.llvmcall

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I would instead ideally have a kwargument fn=main, and we extract that fn as the top level one

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a func_name::String kwarg for this

MLIR.IR.rmfromparent!(fn)
Pangoraw marked this conversation as resolved.
Show resolved Hide resolved

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need to mark all fn's as private, as well as make sure to move all fns in the module (e.g. the main function could call something)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, do you know if we can encounter ops other than func.func (maybe gpu.func in the future?) and what to do with them ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it’s fine to assume func for now but if desired we could generalize to function interface or whatnot


orig_name = String(MLIR.IR.attr(fn, "sym_name"))
name = orig_name * "_" * string(gensym())

push!(top_level_block, fn)

ftype_attr = MLIR.IR.attr(fn, "function_type")
ftype = MLIR.IR.Type(ftype_attr)

Pangoraw marked this conversation as resolved.
Show resolved Hide resolved
MLIR.IR.attr!(fn, "sym_name", MLIR.IR.Attribute(name))

@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays"
@assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of arguments for function $orig_name"

operands = [a.mlir_data for a in args]
call = MLIR.Dialects.func.call(
operands;
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
callee=MLIR.IR.FlatSymbolRefAttribute(name),
location,
)

return ntuple(MLIR.IR.nresults(call)) do i
out = MLIR.IR.result(call, i)
ty = MLIR.IR.type(out)
sz = MLIR.IR.size(ty)
T = MLIR.IR.julia_type(eltype(ty))
N = length(sz)
if N == 0
Reactant.TracedRNumber{T}((), out)
else
Reactant.TracedRArray{T,N}((), out, sz)
end
end
end

end # module Ops
4 changes: 4 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ function make_tracer(
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)

if TT === Module || TT === String
return prev
end

if ismutabletype(TT)
y = ccall(:jl_new_struct_uninit, Any, (Any,), TT)
seen[prev] = y
Expand Down
22 changes: 22 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,25 @@ end
z = ConcreteRArray([1e-8, 0.001, 2.0])
@test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z)
end

@testset "hlo_call" begin
x = Float32[1.0, 2.0, 50.0]
y = Float32[-4.0, 0.001, 2.0]
x_reactant = Reactant.to_rarray(x)
y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(
Ops.hlo_call(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with multiple functions in the module.

and can you also add a test with two (different) hlo calls that happen to contain functions of the same name (to make sure we do the symbol rename properly)

"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x_reactant,
y_reactant,
)
)[1] ≈ x .+ y
end
Loading