From 0c4c708c27ecf391a558d5650f09fe790d956e76 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 13:23:24 +0100
Subject: [PATCH 01/11] special case String and Module in make_tracer

---
 src/Tracing.jl | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/Tracing.jl b/src/Tracing.jl
index 58db7352c..4ea8172aa 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -284,6 +284,10 @@ function make_tracer(
     @assert Base.isconcretetype(RT)
     nf = fieldcount(RT)
 
+    if TT === Module || TT === String
+        return prev
+    end
+
     if ismutabletype(TT)
         y = ccall(:jl_new_struct_uninit, Any, (Any,), TT)
         seen[prev] = y
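
Background: `make_tracer` reconstructs concrete types field-by-field via `jl_new_struct_uninit`, which is meaningless for `Module` and `String`; the guard returns them untouched instead. A hypothetical value that would hit this path while being traced (the struct and field names below are illustrative, not part of the patch):

```julia
using Reactant

# A container mixing a String (left as-is by the special case) with
# numeric data (converted to a device array by tracing).
struct Tagged{T}
    tag::String
    data::T
end

t = Tagged("weights", Float32[1, 2, 3])
traced = Reactant.to_rarray(t)  # recurses into `Tagged`; `tag` passes through unchanged
```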
From 910d141a565b48ebefb10fd181a7efdf8a214f82 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 13:23:37 +0100
Subject: [PATCH 02/11] implement Ops.hlo_call

---
 src/Ops.jl  | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 test/ops.jl | 22 +++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/src/Ops.jl b/src/Ops.jl
index 2148cb5eb..bf93dbdb2 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1046,4 +1046,72 @@ function compare(
     return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
 end
 
+"""
+    Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray}
+
+Given an MLIR module provided as a string and containing a single function,
+calls that function with the provided arguments and returns a tuple with one element per result of the call.
+
+```julia-repl
+julia> Reactant.@jit(
+           Ops.hlo_call(
+               \"\"\"
+               module {
+                 func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+                   %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+                   return %0 : tensor<3xf32>
+                 }
+               }
+               \"\"\",
+               Reactant.to_rarray(Float32[1, 2, 3]),
+               Reactant.to_rarray(Float32[1, 2, 3]),
+           )
+       )
+(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
+```
+"""
+function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
+    new_mod = parse(MLIR.IR.Module, code)
+    body = MLIR.IR.body(new_mod)
+    fn = MLIR.IR.first_op(body)
+    MLIR.IR.rmfromparent!(fn)
+
+    current_module = MLIR.IR.mmodule()
+    top_level_block = MLIR.IR.body(current_module)
+
+    orig_name = String(MLIR.IR.attr(fn, "sym_name"))
+    name = orig_name * "_" * string(gensym())
+
+    push!(top_level_block, fn)
+
+    ftype_attr = MLIR.IR.attr(fn, "function_type")
+    ftype = MLIR.IR.Type(ftype_attr)
+
+    MLIR.IR.attr!(fn, "sym_name", MLIR.IR.Attribute(name))
+
+    @assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays"
+    @assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of argument for function $orig_name"
+
+    operands = [a.mlir_data for a in args]
+    call = MLIR.Dialects.func.call(
+        operands;
+        result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
+        callee=MLIR.IR.FlatSymbolRefAttribute(name),
+        location,
+    )
+
+    return ntuple(MLIR.IR.nresults(call)) do i
+        out = MLIR.IR.result(call, i)
+        ty = MLIR.IR.type(out)
+        sz = MLIR.IR.size(ty)
+        T = MLIR.IR.julia_type(eltype(ty))
+        N = length(sz)
+        if N == 0
+            Reactant.TracedRNumber{T}((), out)
+        else
+            Reactant.TracedRArray{T,N}((), out, sz)
+        end
+    end
+end
+
 end # module Ops
diff --git a/test/ops.jl b/test/ops.jl
index 0437b2723..18efd06af 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -866,3 +866,25 @@ end
     z = ConcreteRArray([1e-8, 0.001, 2.0])
     @test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z)
 end
+
+@testset "hlo_call" begin
+    x = Float32[1.0, 2.0, 50.0]
+    y = Float32[-4.0, 0.001, 2.0]
+    x_reactant = Reactant.to_rarray(x)
+    y_reactant = Reactant.to_rarray(y)
+
+    @test Reactant.@jit(
+        Ops.hlo_call(
+            """
+module {
+  func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+    %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+    return %0 : tensor<3xf32>
+  }
+}
+""",
+            x_reactant,
+            y_reactant,
+        )
+    )[1] ≈ x .+ y
+end
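
Because `hlo_call` returns ordinary traced values, its results compose with regular traced Julia code. A sketch of mixing the docstring's add kernel with plain broadcasting (the wrapper function and the `.* 2` post-processing are illustrative, not from the patch):

```julia
using Reactant
using Reactant: Ops

# Calls the 3-element add kernel, then post-processes the traced result
# with normal Julia broadcasting inside the same compiled function.
function scaled_add(x, y)
    s, = Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        x,
        y,
    )
    return s .* 2
end

a = Reactant.to_rarray(Float32[1, 2, 3])
b = Reactant.to_rarray(Float32[4, 5, 6])
Reactant.@jit scaled_add(a, b)  # ≈ Float32[10, 14, 18]
```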
""" -lookup(st::SymbolTable, name::AbstractString) = - Operation(API.mlirSymbolTableLookup(st, name)) -Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name) +function lookup(st::SymbolTable, name::AbstractString) + raw_op = API.mlirSymbolTableLookup(st, name) + if raw_op.ptr == C_NULL + nothing + else + Operation(raw_op, false) + end +end +function Base.getindex(st::SymbolTable, name::AbstractString) + @something(lookup(st, name), throw(KeyError(name))) +end """ push!(symboltable, operation) From 3559cfb9e8f20c839a116b552482bde8f13cd4a9 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 10 Dec 2024 19:18:42 +0100 Subject: [PATCH 06/11] cache and more validation, also specify name to call --- src/Ops.jl | 76 ++++++++++++++++++++++++++++++++++++++++++----------- test/ops.jl | 32 ++++++++++++++++++++++ 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 6c60be2a1..95ee388fd 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1046,11 +1046,16 @@ function compare( return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) end +# Generate a unique name given a module hash and a function name. +function _hlo_call_name(orig_name, module_suffix) + return orig_name * "_hlo_call_" * module_suffix +end + """ - Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray} + Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray} -Given a MLIR module given as a string and containing a single function, -calls the given function with the provided arguments and return a tuple for each result of the call. +Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main") +with the provided arguments and return a tuple for each result of the call. 
From 3559cfb9e8f20c839a116b552482bde8f13cd4a9 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 19:18:42 +0100
Subject: [PATCH 06/11] cache and more validation, also specify name to call

---
 src/Ops.jl  | 76 ++++++++++++++++++++++++++++++++++++++++++-----------
 test/ops.jl | 32 ++++++++++++++++++++++
 2 files changed, 93 insertions(+), 15 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 6c60be2a1..95ee388fd 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1046,11 +1046,16 @@ function compare(
     return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
 end
 
+# Generate a unique name given a module hash and a function name.
+function _hlo_call_name(orig_name, module_suffix)
+    return orig_name * "_hlo_call_" * module_suffix
+end
+
 """
-    Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray}
+    Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
 
-Given an MLIR module provided as a string and containing a single function,
-calls that function with the provided arguments and returns a tuple with one element per result of the call.
+Given an MLIR module provided as a string, calls the function identified by the `func_name` keyword parameter (default "main")
+with the provided arguments and returns a tuple with one element per result of the call.
 
 ```julia-repl
 julia> Reactant.@jit(
@@ -1070,33 +1075,74 @@ julia> Reactant.@jit(
 (ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
 ```
 """
-function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
-    new_mod = parse(MLIR.IR.Module, code)
-    body = MLIR.IR.body(new_mod)
-    fn = MLIR.IR.first_op(body)
-    MLIR.IR.rmfromparent!(fn)
+function hlo_call(
+    code,
+    args...;
+    func_name="main",
+    location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
+)
+    module_suffix = string(hash(code); base=16)
+    name_to_call = _hlo_call_name(func_name, module_suffix)
 
     current_module = MLIR.IR.mmodule()
     top_level_block = MLIR.IR.body(current_module)
 
-    orig_name = String(MLIR.IR.attr(fn, "sym_name"))
-    name = orig_name * "_" * string(gensym())
+    symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
 
-    push!(top_level_block, fn)
+    fn = MLIR.IR.lookup(
+        MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
+    )
+    if isnothing(fn)
+        new_mod = parse(MLIR.IR.Module, code)
+        body = MLIR.IR.body(new_mod)
+
+        for op in MLIR.IR.OperationIterator(body)
+            MLIR.IR.rmfromparent!(op)
+
+            # Set function private
+            MLIR.IR.attr!(
+                op,
+                MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+                MLIR.IR.Attribute("private"),
+            )
+
+            fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
+            if fn_name == func_name
+                fn = op
+            end
+
+            # Change function name
+            MLIR.IR.attr!(
+                op,
+                symbol_attr_name,
+                MLIR.IR.Attribute(_hlo_call_name(fn_name, module_suffix)),
+            )
+
+            push!(top_level_block, op)
+        end
+    end
+
+    if isnothing(fn)
+        error("hlo_call: could not find function $func_name in the provided module")
+    end
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
 
-    MLIR.IR.attr!(fn, "sym_name", MLIR.IR.Attribute(name))
-
     @assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays"
-    @assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of arguments for function $orig_name"
+    @assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of arguments for function $func_name"
+
+    for (i, arg) in enumerate(args)
+        expected_type = MLIR.IR.input(ftype, i)
+        arg_type = MLIR.IR.type(arg.mlir_data)
+        @assert expected_type == arg_type "argument #$i has the wrong type (expected $expected_type, got $arg_type)"
+    end
 
     operands = [a.mlir_data for a in args]
     call = MLIR.Dialects.func.call(
         operands;
         result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
-        callee=MLIR.IR.FlatSymbolRefAttribute(name),
+        callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
         location,
     )
 
diff --git a/test/ops.jl b/test/ops.jl
index ddca11db4..fab7f3e5e 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -888,3 +888,35 @@ end
         )
     )[1] ≈ x .+ y
 end
+
+function f_repeat(x, y)
+    for _ in 1:3
+        x, = Ops.hlo_call(
+            """
+            module {
+              func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+                %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+                return %0 : tensor<3xf32>
+              }
+            }
+            """,
+            x,
+            y;
+            func_name="my_add",
+        )
+    end
+    return x
+end
+
+@testset "hlo_call: repeat" begin
+    x = Reactant.to_rarray(randn(Float32, 3))
+    y = Reactant.to_rarray(randn(Float32, 3))
+    mod = Reactant.@code_hlo optimize = false f_repeat(x, y)
+    hlo_ir = repr(mod)
+
+    add_pos = findfirst("stablehlo.add", hlo_ir)
+    @test !isnothing(add_pos)
+
+    add_pos = findfirst("stablehlo.add", hlo_ir[last(add_pos):end])
+    @test isnothing(add_pos)
+end
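
The cache key is a hash of the MLIR source text. A pure-Julia sketch of the naming scheme, mirroring `_hlo_call_name` above (the `"main"` function name matches the default keyword):

```julia
code = "module { }"

module_suffix = string(hash(code); base=16)
name_to_call = "main" * "_hlo_call_" * module_suffix

# The suffix is deterministic for a given source string, so tracing the same
# `hlo_call` again finds the already-imported function in the symbol table
# instead of parsing and pushing a second copy of the module.
@assert string(hash(code); base=16) == module_suffix
```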
From 5031e2b6adf80b49884122dba3260bfbb6e0af49 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 19:21:53 +0100
Subject: [PATCH 07/11] error if not func.func

---
 src/Ops.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/Ops.jl b/src/Ops.jl
index 95ee388fd..601cfad1e 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1097,6 +1097,8 @@ function hlo_call(
         body = MLIR.IR.body(new_mod)
 
         for op in MLIR.IR.OperationIterator(body)
+            @assert MLIR.IR.name(op) == "func.func" "hlo_call: the given module should only contain `func.func` operations"
+
             MLIR.IR.rmfromparent!(op)
 
             # Set function private

From e601f8e1327914f8b62e94028cd5f81a51e2ee88 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 19:30:45 +0100
Subject: [PATCH 08/11] only do special things for func.func

---
 src/Ops.jl | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 601cfad1e..551cf4434 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1097,29 +1097,29 @@ function hlo_call(
         body = MLIR.IR.body(new_mod)
 
         for op in MLIR.IR.OperationIterator(body)
-            @assert MLIR.IR.name(op) == "func.func" "hlo_call: the given module should only contain `func.func` operations"
-
             MLIR.IR.rmfromparent!(op)
 
-            # Set function private
-            MLIR.IR.attr!(
-                op,
-                MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
-                MLIR.IR.Attribute("private"),
-            )
-
-            fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
-            if fn_name == func_name
-                fn = op
+            if MLIR.IR.name(op) == "func.func"
+                fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
+                if fn_name == func_name
+                    fn = op
+                end
+
+                # Set function private
+                MLIR.IR.attr!(
+                    op,
+                    MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+                    MLIR.IR.Attribute("private"),
+                )
+
+                # Change function name
+                MLIR.IR.attr!(
+                    op,
+                    symbol_attr_name,
+                    MLIR.IR.Attribute(_hlo_call_name(fn_name, module_suffix)),
+                )
             end
 
-            # Change function name
-            MLIR.IR.attr!(
-                op,
-                symbol_attr_name,
-                MLIR.IR.Attribute(_hlo_call_name(fn_name, module_suffix)),
-            )
-
             push!(top_level_block, op)
         end
     end
From dd229903615e0bdb110b9991d976c8bb2d03e591 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 19:58:45 +0100
Subject: [PATCH 09/11] symbol_rename

---
 src/Ops.jl  | 27 ++++++++++++++-----------
 test/ops.jl | 21 +++++++++++++++++++++
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 551cf4434..b78e0e239 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1094,17 +1094,25 @@ function hlo_call(
     )
     if isnothing(fn)
         new_mod = parse(MLIR.IR.Module, code)
+        new_mod_op = MLIR.IR.Operation(new_mod)
         body = MLIR.IR.body(new_mod)
 
-        for op in MLIR.IR.OperationIterator(body)
-            MLIR.IR.rmfromparent!(op)
-
+        operations = collect(MLIR.IR.OperationIterator(body))
+        for op in operations
             if MLIR.IR.name(op) == "func.func"
                 fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
                 if fn_name == func_name
                     fn = op
                 end
 
+                new_name = _hlo_call_name(fn_name, module_suffix)
+                res = MLIR.IR.LogicalResult(
+                    MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
+                        fn_name, new_name, new_mod_op
+                    ),
+                )
+                @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
+
                 # Set function private
                 MLIR.IR.attr!(
                     op,
@@ -1113,13 +1121,10 @@ function hlo_call(
                 )
 
                 # Change function name
-                MLIR.IR.attr!(
-                    op,
-                    symbol_attr_name,
-                    MLIR.IR.Attribute(_hlo_call_name(fn_name, module_suffix)),
-                )
+                MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
             end
 
+            MLIR.IR.rmfromparent!(op)
             push!(top_level_block, op)
         end
     end
@@ -1131,13 +1136,13 @@ function hlo_call(
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
 
-    @assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays"
-    @assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of arguments for function $func_name"
+    @assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
+    @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
 
     for (i, arg) in enumerate(args)
         expected_type = MLIR.IR.input(ftype, i)
         arg_type = MLIR.IR.type(arg.mlir_data)
-        @assert expected_type == arg_type "argument #$i has the wrong type (expected $expected_type, got $arg_type)"
+        @assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
     end
 
     operands = [a.mlir_data for a in args]
diff --git a/test/ops.jl b/test/ops.jl
index fab7f3e5e..8bcb759d4 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -920,3 +920,24 @@ end
     add_pos = findfirst("stablehlo.add", hlo_ir[last(add_pos):end])
     @test isnothing(add_pos)
 end
+
+@testset "hlo_call: multiple functions" begin
+    @test Reactant.@jit(
+        Ops.hlo_call(
+            """
+            module {
+              func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+                %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+                return %0 : tensor<3xf32>
+              }
+              func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+                %0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+                return %0 : tensor<3xf32>
+              }
+            }
+            """,
+            Reactant.to_rarray(Float32[1, 2, 3]),
+            Reactant.to_rarray(Float32[1, 2, 3]),
+        )
+    )[1] ≈ Float32[2, 4, 6]
+end
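
`mlirSymbolTableReplaceAllSymbolUses` is what keeps internal calls intact: renaming only the `sym_name` attribute of `@add` would leave the `func.call @add` site inside `@main` pointing at a symbol that no longer exists. A sketch of the rename step in isolation, assuming an active MLIR context and that `code` holds the two-function module from the test above (the `"add_renamed"` target is illustrative):

```julia
using Reactant: MLIR

new_mod = parse(MLIR.IR.Module, code)
new_mod_op = MLIR.IR.Operation(new_mod)

# Rewrite every use of @add inside the parsed module first; only after this
# is it safe to change @add's own `sym_name` attribute.
res = MLIR.IR.LogicalResult(
    MLIR.API.mlirSymbolTableReplaceAllSymbolUses("add", "add_renamed", new_mod_op)
)
@assert res == MLIR.IR.success()
```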
From 68138f625e8a3afd09856ef666cc573818607fac Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 20:03:07 +0100
Subject: [PATCH 10/11] add multiple call test

---
 test/ops.jl | 44 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 38 insertions(+), 6 deletions(-)

diff --git a/test/ops.jl b/test/ops.jl
index 8bcb759d4..0b7c33566 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -130,7 +130,7 @@ end
     ]
     x = ConcreteRArray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0])
     @test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0] ==
-        @jit Ops.clamp(_min, x, _max)
+          @jit Ops.clamp(_min, x, _max)
 end
 end
 
@@ -166,7 +166,7 @@ end
         0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
     ])
     @test [1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im] ≈
-        @jit Ops.cosine(x)
+          @jit Ops.cosine(x)
 end
 end
 
@@ -216,7 +216,7 @@ end
    # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
    @test sum(a .* b) ≈ @jit f1(a, b)
    @test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b))) ≈
-        @jit fouter(a, b)
+          @jit fouter(a, b)
    @test a .* b ≈ @jit fouter_batch1(a, b)
 end
 
@@ -415,7 +415,7 @@ end
     # on unsigned integers: (1) bitcast, (2) change sign and (3) bitcast
     x = ConcreteRArray(UInt[0, 1, 10])
     @test reinterpret(UInt, Base.checked_neg.(reinterpret.(Int, Array(x)))) ==
-        @jit Ops.negate(x)
+          @jit Ops.negate(x)
 
     x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0])
     @test [1.0, 0.0, -1.0, -10.0] ≈ @jit Ops.negate(x)
@@ -639,7 +639,7 @@ end
         0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
     ])
     @test [0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im] ≈
-        @jit Ops.sine(x)
+          @jit Ops.sine(x)
 end
 end
 
@@ -847,7 +847,7 @@ end
     x = ConcreteRArray([-1.0, 0.0, 1.0, 1.0, 2.5])
     m = ConcreteRArray([3.0, 3.0, 2.0, 3.0, 4.0])
     @test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x)) ≈
-        @jit Ops.polygamma(m, x)
+          @jit Ops.polygamma(m, x)
 end
 end
 
@@ -941,3 +941,35 @@ end
         )
     )[1] ≈ Float32[2, 4, 6]
 end
+
+function f_multiple_hlo_calls(x, y)
+    x, = Ops.hlo_call(
+        """
+        module {
+          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+            %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+            return %0 : tensor<3xf32>
+          }
+        }
+        """, x, y,
+    )
+    return Ops.hlo_call(
+        """
+        module {
+          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+            %0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
+            return %0 : tensor<3xf32>
+          }
+        }
+        """, x, y
+    )
+end
+
+@testset "hlo_call: multiple hlo_calls" begin
+    x = Float32[1.0, 2.0, 50.0]
+    y = Float32[-4.0, 0.001, 2.0]
+    x_reactant = Reactant.to_rarray(x)
+    y_reactant = Reactant.to_rarray(y)
+
+    @test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈ (x .+ y) .* y
+end
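
Each distinct source string hashes to its own suffix, so both modules above can coexist after import. A sketch of inspecting the merged IR, reusing `f_multiple_hlo_calls` from the test (the expected symbol names follow the `_hlo_call_` scheme from patch 06):

```julia
x = Reactant.to_rarray(randn(Float32, 3))
y = Reactant.to_rarray(randn(Float32, 3))

mod = Reactant.@code_hlo optimize = false f_multiple_hlo_calls(x, y)
# The printed module should contain two private functions with distinct hash
# suffixes (e.g. `main_hlo_call_<hash1>` and `main_hlo_call_<hash2>`), one per
# distinct source string, alongside the outer entry point.
```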
From 8eb71cc9f64ecad760765ee30a4c3bc27eba011e Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Tue, 10 Dec 2024 20:28:12 +0100
Subject: [PATCH 11/11] rename then remove from parsed module

---
 src/Ops.jl  |  2 ++
 test/ops.jl | 28 ++++++++++++++--------------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index b78e0e239..748d06f66 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1123,7 +1123,9 @@ function hlo_call(
                 # Change function name
                 MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
             end
+        end
 
+        for op in operations
             MLIR.IR.rmfromparent!(op)
             push!(top_level_block, op)
         end
diff --git a/test/ops.jl b/test/ops.jl
index 0b7c33566..0600d0b86 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -130,7 +130,7 @@ end
     ]
     x = ConcreteRArray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0])
     @test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0] ==
-          @jit Ops.clamp(_min, x, _max)
+        @jit Ops.clamp(_min, x, _max)
 end
 end
 
@@ -166,7 +166,7 @@ end
         0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
     ])
     @test [1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im] ≈
-          @jit Ops.cosine(x)
+        @jit Ops.cosine(x)
 end
 end
 
@@ -216,7 +216,7 @@ end
    # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
    @test sum(a .* b) ≈ @jit f1(a, b)
    @test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b))) ≈
-          @jit fouter(a, b)
+        @jit fouter(a, b)
    @test a .* b ≈ @jit fouter_batch1(a, b)
 end
 
@@ -415,7 +415,7 @@ end
     # on unsigned integers: (1) bitcast, (2) change sign and (3) bitcast
     x = ConcreteRArray(UInt[0, 1, 10])
     @test reinterpret(UInt, Base.checked_neg.(reinterpret.(Int, Array(x)))) ==
-          @jit Ops.negate(x)
+        @jit Ops.negate(x)
 
     x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0])
     @test [1.0, 0.0, -1.0, -10.0] ≈ @jit Ops.negate(x)
@@ -639,7 +639,7 @@ end
         0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
     ])
     @test [0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im] ≈
-          @jit Ops.sine(x)
+        @jit Ops.sine(x)
 end
 end
 
@@ -847,7 +847,7 @@ end
     x = ConcreteRArray([-1.0, 0.0, 1.0, 1.0, 2.5])
     m = ConcreteRArray([3.0, 3.0, 2.0, 3.0, 4.0])
     @test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x)) ≈
-          @jit Ops.polygamma(m, x)
+        @jit Ops.polygamma(m, x)
 end
 end
 
@@ -926,14 +926,14 @@ end
         Ops.hlo_call(
             """
             module {
-              func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
-                %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
-                return %0 : tensor<3xf32>
-              }
               func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
                 %0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
                 return %0 : tensor<3xf32>
               }
+              func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+                %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
+                return %0 : tensor<3xf32>
+              }
             }
             """,
@@ -951,17 +951,21 @@ end
             return %0 : tensor<3xf32>
           }
         }
-        """, x, y,
+        """,
+        x,
+        y,
     )
     return Ops.hlo_call(
         """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
-        """, x, y
+        """,
+        x,
+        y,
    )
 end
 