-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Ops.hlo_call(::String, args...) #358
Changes from 9 commits
0c4c708
910d141
134f92b
53ca1dd
e37852e
3559cfb
5031e2b
e601f8e
dd22990
68138f6
8eb71cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -866,3 +866,78 @@ end | |||||||||||||||||||||||||||||||||||||||||||||||||||||
z = ConcreteRArray([1e-8, 0.001, 2.0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
@test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@testset "hlo_call" begin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = Float32[1.0, 2.0, 50.0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
y = Float32[-4.0, 0.001, 2.0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_reactant = Reactant.to_rarray(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_reactant = Reactant.to_rarray(y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@test Reactant.@jit( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ops.hlo_call( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test with multiple functions in the module. and can you also add a test with two (different) hlo calls that happen to contain functions of the same name (to make sure we do the symbol rename properly) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
module { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return %0 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_reactant, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_reactant, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
)[1] ≈ x .+ y | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
function f_repeat(x, y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for _ in 1:3 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x, = Ops.hlo_call( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
module { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also add a version of this where the two definitions are different. just because if we fix caching then we might not actually not emit it twice (and thus not check things) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a test with the same name but different definitions: Lines 945 to 970 in 8eb71cc
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return %0 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
y; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
func_name="my_add", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@testset "hlo_call: repeat" begin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = Reactant.to_rarray(randn(Float32, 3)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
y = Reactant.to_rarray(randn(Float32, 3)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
mod = Reactant.@code_hlo optimize = false f_repeat(x, y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
hlo_ir = repr(mod) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
add_pos = findfirst("stablehlo.add", hlo_ir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
@test !isnothing(add_pos) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
add_pos = findfirst("stablehlo.add", hlo_ir[last(add_pos):end]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
@test isnothing(add_pos) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@testset "hlo_call: multiple functions" begin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
@test Reactant.@jit( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ops.hlo_call( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
module { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return %0 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
%0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return %0 : tensor<3xf32> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Reactant.to_rarray(Float32[1, 2, 3]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Reactant.to_rarray(Float32[1, 2, 3]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
)[1] ≈ Float32[2, 4, 6] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also need to mark all fn's as private, as well as make sure to move all fns in the module (e.g. the main function could call something)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have some utilities here: https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/compile_with_xla.cc#L190 and https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/primitives.py#L811 in Enzyme-JaX for explicitly making we can do all the things
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, do you know if we can encounter ops other than
func.func
(maybegpu.func
in the future?) and what to do with them ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it’s fine to assume func for now but if desired we could generalize to function interface or whatnot