Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ops.hlo_call(::String, args...) #358

Merged
merged 11 commits into from
Dec 10, 2024
121 changes: 121 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,125 @@ function compare(
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end

# Derive a module-unique symbol name by appending the `_hlo_call_` marker
# and the module hash suffix to the original function name.
_hlo_call_name(orig_name, module_suffix) = string(orig_name, "_hlo_call_", module_suffix)

"""
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}

Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
with the provided arguments and return a tuple for each result of the call.

```julia-repl
julia> Reactant.@jit(
Ops.hlo_call(
\"\"\"
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
\"\"\",
Reactant.to_rarray(Float32[1, 2, 3]),
Reactant.to_rarray(Float32[1, 2, 3]),
)
)
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
```
"""
function hlo_call(
code,
args...;
func_name="main",
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
)
module_suffix = string(hash(code); base=16)
name_to_call = _hlo_call_name(func_name, module_suffix)

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need to mark all fn's as private, as well as make sure to move all fns in the module (e.g. the main function could call something)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, do you know if we can encounter ops other than func.func (maybe gpu.func in the future?) and what to do with them ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it’s fine to assume func for now but if desired we could generalize to function interface or whatnot


symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())

fn = MLIR.IR.lookup(
MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
)
if isnothing(fn)
new_mod = parse(MLIR.IR.Module, code)
new_mod_op = MLIR.IR.Operation(new_mod)
body = MLIR.IR.body(new_mod)

operations = collect(MLIR.IR.OperationIterator(body))
for op in operations
if MLIR.IR.name(op) == "func.func"
fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
if fn_name == func_name
fn = op
end

new_name = _hlo_call_name(fn_name, module_suffix)
res = MLIR.IR.LogicalResult(
MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
fn_name, new_name, new_mod_op
),
)
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"

# Set function private
MLIR.IR.attr!(
op,
MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
MLIR.IR.Attribute("private"),
)

# Change function name
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
end

MLIR.IR.rmfromparent!(op)
push!(top_level_block, op)
end
end

if isnothing(fn)
error("hlo_call: could not find function $func_name in the provided module")
end

ftype_attr = MLIR.IR.attr(fn, "function_type")
ftype = MLIR.IR.Type(ftype_attr)

Pangoraw marked this conversation as resolved.
Show resolved Hide resolved
@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"

for (i, arg) in enumerate(args)
expected_type = MLIR.IR.input(ftype, i)
arg_type = MLIR.IR.type(arg.mlir_data)
@assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
end

operands = [a.mlir_data for a in args]
call = MLIR.Dialects.func.call(
operands;
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
location,
)

return ntuple(MLIR.IR.nresults(call)) do i
out = MLIR.IR.result(call, i)
ty = MLIR.IR.type(out)
sz = MLIR.IR.size(ty)
T = MLIR.IR.julia_type(eltype(ty))
N = length(sz)
if N == 0
Reactant.TracedRNumber{T}((), out)
else
Reactant.TracedRArray{T,N}((), out, sz)
end
end
end

end # module Ops
4 changes: 4 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ function make_tracer(
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)

if TT === Module || TT === String
return prev
end

if ismutabletype(TT)
y = ccall(:jl_new_struct_uninit, Any, (Any,), TT)
seen[prev] = y
Expand Down
14 changes: 11 additions & 3 deletions src/mlir/IR/SymbolTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st
Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol.
If the symbol cannot be found, returns `nothing`.
"""
lookup(st::SymbolTable, name::AbstractString) =
Operation(API.mlirSymbolTableLookup(st, name))
Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name)
function lookup(st::SymbolTable, name::AbstractString)
    raw = API.mlirSymbolTableLookup(st, name)
    # A null pointer means the symbol is absent from the table.
    raw.ptr == C_NULL && return nothing
    return Operation(raw, false)
end
function Base.getindex(st::SymbolTable, name::AbstractString)
    op = lookup(st, name)
    op === nothing && throw(KeyError(name))
    return op
end

"""
push!(symboltable, operation)
Expand Down
75 changes: 75 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,78 @@ end
z = ConcreteRArray([1e-8, 0.001, 2.0])
@test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z)
end

@testset "hlo_call" begin
    x = Float32[1.0, 2.0, 50.0]
    y = Float32[-4.0, 0.001, 2.0]
    x_reactant = Reactant.to_rarray(x)
    y_reactant = Reactant.to_rarray(y)

    mlir_code = """
    module {
      func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
        return %0 : tensor<3xf32>
      }
    }
    """
    result, = Reactant.@jit Ops.hlo_call(mlir_code, x_reactant, y_reactant)
    @test result ≈ x .+ y
end

# Invokes the same hlo_call three times in a row; since the code string is
# identical each time, the callee should only be emitted once into the module.
function f_repeat(x, y)
    mlir_code = """
    module {
      func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
        return %0 : tensor<3xf32>
      }
    }
    """
    for _ in 1:3
        x, = Ops.hlo_call(mlir_code, x, y; func_name="my_add")
    end
    return x
end

@testset "hlo_call: repeat" begin
    x = Reactant.to_rarray(randn(Float32, 3))
    y = Reactant.to_rarray(randn(Float32, 3))
    mod = Reactant.@code_hlo optimize = false f_repeat(x, y)
    hlo_ir = repr(mod)

    first_add = findfirst("stablehlo.add", hlo_ir)
    @test !isnothing(first_add)

    # Despite three calls, the add must appear only once in the emitted IR.
    @test isnothing(findnext("stablehlo.add", hlo_ir, last(first_add) + 1))
end

@testset "hlo_call: multiple functions" begin
    # The module's @main calls a sibling @add; both must be imported and
    # renamed consistently.
    mlir_code = """
    module {
      func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
        return %0 : tensor<3xf32>
      }
      func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
        return %0 : tensor<3xf32>
      }
    }
    """
    a = Reactant.to_rarray(Float32[1, 2, 3])
    b = Reactant.to_rarray(Float32[1, 2, 3])
    result, = Reactant.@jit Ops.hlo_call(mlir_code, a, b)
    @test result ≈ Float32[2, 4, 6]
end
Loading