feat: support lowering custom fp types #596
base: main
Conversation
src/TracedRArray.jl
Outdated
@@ -507,7 +509,7 @@ end
 # we need to override the outer copy method to make sure we never fall back to scalar
 # iteration (see, e.g., CUDA.jl#145)
 function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
-    fn = if bc.f isa Type && bc.f <: ReactantPrimitive
+    fn = if bc.f isa Type && is_reactant_primitive(bc.f)
We should probably be more careful with this. For example, Ops.add won't necessarily do what one wants.
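To make the concern concrete (a sketch, not code from this PR): a type like Double64 from DoubleFloats.jl implements + via error-free transformations, so mapping its + straight to Ops.add / stablehlo.add on the underlying primitive would silently change the arithmetic. The TwoSum step that would be lost:

# Illustrative only (not from this PR): a Double64-style `+` is a compensated
# TwoSum, not a single hardware add, so lowering such a type's `+` to a single
# `stablehlo.add` on the mapped primitive would drop the correction term.
function two_sum(a::Float64, b::Float64)
    s = a + b                     # rounded sum
    v = s - a
    e = (a - (s - v)) + (b - v)   # exact rounding error of `s`
    return s, e                   # s + e == a + b exactly
end

two_sum(1.0, 1e-20)               # (1.0, 1e-20): the 1e-20 correction would be lost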
For the custom interface, we probably need (a rough sketch follows below):
- is_reactant_primitive
- primitive_type
- DenseElementsAttribute
- a way to map mlir_type to julia type
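A hypothetical sketch of what that interface could look like for a package-defined 8-bit float; the function names follow the list above, while `MyF8`, the signatures, and the return values are assumptions for illustration, not the PR's implementation:

# Hypothetical interface sketch; `MyF8` and all signatures are illustrative.
struct MyF8 <: AbstractFloat
    bits::UInt8
end

# 1. Opt the type into tracing as a primitive.
is_reactant_primitive(::Type{MyF8}) = true

# 2. Name the MLIR element type it lowers to (placeholder: a symbol tag here,
#    an MLIR type object in a real implementation).
primitive_type(::Type{MyF8}) = :f8E4M3FN

# 3. Build dense constant data from host arrays (here: just the raw bytes that a
#    DenseElementsAttribute would wrap).
dense_elements(x::AbstractArray{MyF8}) = reinterpret(UInt8, x)

# 4. The inverse map, MLIR element type back to the Julia type, so traced
#    results can be rewrapped on the host.
julia_type(::Val{:f8E4M3FN}) = MyF8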
The core idea is that we can allow support for external packages like Float8s.jl, https://github.com/JuliaMath/DoubleFloats.jl, etc. without having to pull them into Reactant.
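For instance, the wiring could live in a package extension on the external package's side (a sketch under the assumption that such an extension is declared via [weakdeps]/[extensions] in its Project.toml; the hook and constant names are the ones used in the example further down this thread):

# Hypothetical ext/Float8sReactantExt.jl inside Float8s.jl, loaded only when
# Reactant is also present in the environment.
module Float8sReactantExt

using Float8s: Float8_4
using Reactant

# Map the external 8-bit float onto Reactant's f8e4m3fnuz primitive.
Reactant.to_reactant_primitive_type(::Type{Float8_4}) = Reactant.F8E4M3FNUZ

end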
(force-pushed from 92cc29d to 89ef696)
using Float8s, Reactant
Reactant.to_reactant_primitive_type(::Type{Float8_4}) = Reactant.F8E4M3FNUZ
x = rand(Float32, 10, 3) .|> Float8_4
x_ra = Reactant.to_rarray(x)
@code_hlo .+(x_ra, x_ra)

module {
func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> tensor<3x10xf8E4M3FN> {
%0 = stablehlo.add %arg0, %arg0 : tensor<3x10xf8E4M3FN>
return %0 : tensor<3x10xf8E4M3FN>
}
}

noice
(force-pushed from d68cda9 to e60e7cf)
julia> @code_hlo optimize=false sum(x_ra)
module {
func.func private @identity_broadcast_scalar(%arg0: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
return %arg0 : tensor<f8E4M3FN>
}
func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f16>
%1 = stablehlo.convert %cst : (tensor<f16>) -> tensor<f8E4M3FN>
%2 = enzyme.batch @identity_broadcast_scalar(%0) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
%3 = stablehlo.convert %2 : tensor<10x3xf8E4M3FN>
%4 = stablehlo.reduce(%3 init: %1) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x3xf8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
%5 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
return %4, %5 : tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>
}
}
julia> @code_hlo optimize=false .+(x_ra, 1)
module {
func.func private @"+_broadcast_scalar"(%arg0: tensor<f8E4M3FN>, %arg1: tensor<i64>) -> (tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>) {
%0 = stablehlo.convert %arg1 : (tensor<i64>) -> tensor<f8E4M3FN>
%1 = stablehlo.add %arg0, %0 : tensor<f8E4M3FN>
return %1, %arg0, %arg1 : tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>
}
func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
%c = stablehlo.constant dense<1> : tensor<10x3xi64>
%1:3 = enzyme.batch @"+_broadcast_scalar"(%0, %c) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>, tensor<10x3xi64>) -> (tensor<10x3xf8E4M3FN>, tensor<10x3xf8E4M3FN>, tensor<10x3xi64>)
%2 = stablehlo.convert %1#0 : tensor<10x3xf8E4M3FN>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
%4 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
return %3, %4 : tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>
}
}

But post-optimization StableHLO seems to crash.
I also removed some of the primitives from ReactantPrimitive since they are not part of the StableHLO spec. I can restore them if we need them. Specifically:
Open an issue for that on Enzyme-JAX.