feat: support lowering custom fp types #596
Conversation
Force-pushed from 92cc29d to 89ef696
using Float8s, Reactant
Reactant.to_reactant_primitive_type(::Type{Float8_4}) = Reactant.F8E4M3FNUZ
x = rand(Float32, 10, 3) .|> Float8_4
x_ra = Reactant.to_rarray(x)
@code_hlo .+(x_ra, x_ra)

module {
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> tensor<3x10xf8E4M3FN> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<3x10xf8E4M3FN>
    return %0 : tensor<3x10xf8E4M3FN>
  }
}
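A follow-on usage sketch (not from the PR itself): executing the traced broadcast and copying the result back to the host, assuming Reactant's @jit macro and Array conversion handle the custom element type.

# Sketch: run the traced broadcast and bring the result back to the host.
# Assumes @jit and Array(...) support the custom fp element type.
y_ra = Reactant.@jit .+(x_ra, x_ra)
y = Array(y_ra)  # host array of Float8_4 values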
noice
Force-pushed from d68cda9 to e60e7cf
julia> @code_hlo optimize=false sum(x_ra)
module {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
    return %arg0 : tensor<f8E4M3FN>
  }
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f16>
    %1 = stablehlo.convert %cst : (tensor<f16>) -> tensor<f8E4M3FN>
    %2 = enzyme.batch @identity_broadcast_scalar(%0) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %3 = stablehlo.convert %2 : tensor<10x3xf8E4M3FN>
    %4 = stablehlo.reduce(%3 init: %1) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x3xf8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
    %5 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    return %4, %5 : tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>
  }
}
julia> @code_hlo optimize=false .+(x_ra, 1)
module {
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f8E4M3FN>, %arg1: tensor<i64>) -> (tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>) {
    %0 = stablehlo.convert %arg1 : (tensor<i64>) -> tensor<f8E4M3FN>
    %1 = stablehlo.add %arg0, %0 : tensor<f8E4M3FN>
    return %1, %arg0, %arg1 : tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>
  }
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %c = stablehlo.constant dense<1> : tensor<10x3xi64>
    %1:3 = enzyme.batch @"+_broadcast_scalar"(%0, %c) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>, tensor<10x3xi64>) -> (tensor<10x3xf8E4M3FN>, tensor<10x3xf8E4M3FN>, tensor<10x3xi64>)
    %2 = stablehlo.convert %1#0 : tensor<10x3xf8E4M3FN>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    %4 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    return %3, %4 : tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>
  }
}

But the post-optimization StableHLO seems to crash.
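A reproducer sketch (an assumption, not quoted from the thread): the unoptimized IR above lowers cleanly, so re-running the same queries with the default optimization pipeline enabled should be what hits the crash.

# Sketch: same queries with optimizations left on (the default).
@code_hlo sum(x_ra)
@code_hlo .+(x_ra, 1)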
I also removed some of the primitives from ReactantPrimitive, since they are not part of the StableHLO spec. I can restore them if we need them. Specifically:
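If one of the removed primitives turns out to be needed, the user-level hook from the top of the thread should serve as an escape hatch. A hypothetical sketch (MyF8 is a placeholder type, and the enum name Reactant.F8E5M2 is an assumption modeled on Reactant.F8E4M3FNUZ above):

# Hypothetical: map another custom 8-bit float onto a StableHLO-spec type.
Reactant.to_reactant_primitive_type(::Type{MyF8}) = Reactant.F8E5M2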
Open an issue for that on Enzyme-JAX.
Force-pushed from 438185e to a489f5e
The no_nan passes' failure needs a patch in Enzyme-JAX.
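A reproducer sketch until that patch lands (assuming the compile macros expose a no_nan flag matching the pass name, which is an assumption):

# Sketch: enabling no_nan should exercise the failing passes;
# leaving it at the default (false) sidesteps them for now.
@code_hlo no_nan=true sum(x_ra)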
@avik-pal take a double check here, but this seems fine to merge?
Mac doesn't seem to like the float8 tests. I will disable those for Mac.
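A sketch of one way to gate those tests (the file name is a placeholder, not the PR's actual diff):

# Sketch: skip the float8 tests on macOS using Base's Sys.isapple().
if !Sys.isapple()
    include("float8_tests.jl")  # hypothetical test file
end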
Force-pushed from 3d899d7 to c09599d