feat: support lowering custom fp types #596

Draft: wants to merge 5 commits into main


avik-pal
Collaborator

No description provided.

@@ -507,7 +509,7 @@ end
 # we need to override the outer copy method to make sure we never fall back to scalar
 # iteration (see, e.g., CUDA.jl#145)
 function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
-    fn = if bc.f isa Type && bc.f <: ReactantPrimitive
+    fn = if bc.f isa Type && is_reactant_primitive(bc.f)
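The diff swaps a subtype check against the `ReactantPrimitive` union for a call to `is_reactant_primitive`. One reason a function is the more extensible choice is that a `Union` is closed, while any package can add a method to a function. The minimal Julia sketch below shows that pattern. All names in it (`BuiltinPrimitive`, `is_primitive`, `MyFloat8`) are hypothetical and are not Reactant's actual definitions.

```julia
# Hypothetical closed set of built-in element types (stand-in for a Union like ReactantPrimitive).
const BuiltinPrimitive = Union{Bool, Int8, Int16, Int32, Int64, Float16, Float32, Float64}

# Function-based trait: false by default, true for every member of the Union above.
is_primitive(::Type) = false
is_primitive(::Type{<:BuiltinPrimitive}) = true

# A downstream package can define its own 8-bit float type...
struct MyFloat8 <: AbstractFloat
    bits::UInt8
end

# ...and opt in by adding a method, without editing the Union.
is_primitive(::Type{MyFloat8}) = true
```

Marking a type as a primitive only says it can be stored as a tensor element. It does not by itself make operations such as `Ops.add` meaningful for that type, which is the concern raised in review.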
Member


We should probably be more careful with this. For example, `Ops.add` won't necessarily do what one wants.

Collaborator Author


For a custom type interface, we probably need:

  1. is_reactant_primitive
  2. primitive_type
  3. DenseElementsAttribute
  4. a way to map an `mlir_type` back to a Julia type
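Items 2 and 4 form a two-way mapping between Julia element types and MLIR element type names. The sketch below shows one way to keep both directions extensible from a downstream package. `primitive_type`, `MLIR_TO_JULIA`, `julia_type` and `ExtFloat8` are illustrative names, not Reactant's API, and item 3 (`DenseElementsAttribute`) is not covered here.

```julia
# Hypothetical sketch of items 2 and 4; names are illustrative, not Reactant's API.

# Item 2: forward map from a Julia element type to an MLIR element type name.
primitive_type(::Type{Float32}) = "f32"
primitive_type(::Type{Float64}) = "f64"
primitive_type(::Type{Int64}) = "i64"
primitive_type(::Type{T}) where {T} = error("no MLIR element type registered for $T")

# Item 4: reverse map from an MLIR element type name back to a Julia type.
const MLIR_TO_JULIA = Dict{String,DataType}("f32" => Float32, "f64" => Float64, "i64" => Int64)

julia_type(mlir_name::AbstractString) =
    get(() -> error("unknown MLIR element type: $mlir_name"), MLIR_TO_JULIA, mlir_name)

# A downstream package providing an 8-bit float would opt in on both sides.
struct ExtFloat8 <: AbstractFloat
    bits::UInt8
end
primitive_type(::Type{ExtFloat8}) = "f8E4M3FN"
MLIR_TO_JULIA["f8E4M3FN"] = ExtFloat8
```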

Collaborator Author


The core idea is to support external packages like Float8s.jl, https://github.com/JuliaMath/DoubleFloats.jl, etc. without having to pull them into Reactant.

Two resolved review threads on src/TracedUtils.jl.
@avik-pal
Collaborator Author

using Float8s, Reactant

Reactant.to_reactant_primitive_type(::Type{Float8_4}) = Reactant.F8E4M3FNUZ

x = rand(Float32, 10, 3) .|> Float8_4

x_ra = Reactant.to_rarray(x)

@code_hlo .+(x_ra, x_ra)
module {
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> tensor<3x10xf8E4M3FN> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<3x10xf8E4M3FN>
    return %0 : tensor<3x10xf8E4M3FN>
  }
}
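In the output above, the Julia array `x` has size `(10, 3)`, but the function argument is `tensor<3x10xf8E4M3FN>`, so the dimensions appear reversed. The memory of a column-major Julia array of size `(d1, ..., dn)` can be read as a row-major tensor of shape `(dn, ..., d1)` without copying. The sketch below assembles such a type string. `mlir_tensor_type` is a hypothetical helper, not a Reactant function.

```julia
# Hypothetical helper: build an MLIR ranked-tensor type string for a Julia array size.
# Dimensions are reversed because a column-major array of size (d1, ..., dn) has the
# same memory layout as a row-major tensor of shape (dn, ..., d1).
function mlir_tensor_type(dims::Tuple{Vararg{Int}}, elt::AbstractString)
    isempty(dims) && return "tensor<$elt>"   # rank-0 tensor, e.g. the result of a full reduction
    return "tensor<" * join(reverse(dims), "x") * "x" * elt * ">"
end
```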

@wsmoses
Member

wsmoses commented Jan 23, 2025

noice

@avik-pal
Collaborator Author

julia> @code_hlo optimize=false sum(x_ra)
module {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
    return %arg0 : tensor<f8E4M3FN>
  }
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f16>
    %1 = stablehlo.convert %cst : (tensor<f16>) -> tensor<f8E4M3FN>
    %2 = enzyme.batch @identity_broadcast_scalar(%0) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %3 = stablehlo.convert %2 : tensor<10x3xf8E4M3FN>
    %4 = stablehlo.reduce(%3 init: %1) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x3xf8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
    %5 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    return %4, %5 : tensor<f8E4M3FN>, tensor<3x10xf8E4M3FN>
  }
}
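In the unoptimized `sum` IR, the reduction's initial value `%1` is a zero built as an `f16` constant and then converted to `f8E4M3FN`. `stablehlo.reduce` then folds `stablehlo.add` over dimensions `[0, 1]`; MLIR numbers dimensions from 0. The plain-Julia sketch below captures those reference semantics. `reduce_add` is a hypothetical helper, and it indexes the array the same way as the tensor, with no dimension reversal.

```julia
# Reference semantics (sketch) of `stablehlo.reduce(%x init: %init) applies stablehlo.add
# across dimensions = dims0`, with dims0 given in MLIR's 0-based numbering.
function reduce_add(x::AbstractArray{T}, init::T, dims0::Vector{Int}) where {T}
    dims = Tuple(d + 1 for d in dims0)          # convert 0-based dims to Julia's 1-based dims
    r = reduce(+, x; dims = dims, init = init)  # Julia keeps reduced dims with size 1
    return dropdims(r; dims = dims)             # drop them; reducing all dims gives a rank-0 array
end
```

Reducing over every dimension yields a rank-0 array, matching the `tensor<f8E4M3FN>` result type above.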

julia> @code_hlo optimize=false .+(x_ra, 1)
module {
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f8E4M3FN>, %arg1: tensor<i64>) -> (tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>) {
    %0 = stablehlo.convert %arg1 : (tensor<i64>) -> tensor<f8E4M3FN>
    %1 = stablehlo.add %arg0, %0 : tensor<f8E4M3FN>
    return %1, %arg0, %arg1 : tensor<f8E4M3FN>, tensor<f8E4M3FN>, tensor<i64>
  }
  func.func @main(%arg0: tensor<3x10xf8E4M3FN>) -> (tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf8E4M3FN>) -> tensor<10x3xf8E4M3FN>
    %c = stablehlo.constant dense<1> : tensor<10x3xi64>
    %1:3 = enzyme.batch @"+_broadcast_scalar"(%0, %c) {batch_shape = array<i64: 10, 3>} : (tensor<10x3xf8E4M3FN>, tensor<10x3xi64>) -> (tensor<10x3xf8E4M3FN>, tensor<10x3xf8E4M3FN>, tensor<10x3xi64>)
    %2 = stablehlo.convert %1#0 : tensor<10x3xf8E4M3FN>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    %4 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<10x3xf8E4M3FN>) -> tensor<3x10xf8E4M3FN>
    return %3, %4 : tensor<3x10xf8E4M3FN>, tensor<3x10xf8E4M3FN>
  }
}
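In the `.+(x_ra, 1)` IR, the `i64` operand is converted to `f8E4M3FN` before `stablehlo.add`, so the result keeps the float element type. That matches ordinary Julia promotion, where mixing a float type with an `Int` yields the float type. For a custom type, this behavior presumably depends on the promotion rules its package defines. The sketch below uses a hypothetical `Tiny8` type; the thread does not show whether Float8s.jl defines exactly this rule.

```julia
# Hypothetical 8-bit float type standing in for a package type such as Float8_4.
struct Tiny8 <: AbstractFloat
    bits::UInt8
end

# Promotion rule: mixing Tiny8 with any Integer promotes to Tiny8, which is the
# direction of the `stablehlo.convert` (i64 -> f8E4M3FN) seen in the IR above.
Base.promote_rule(::Type{Tiny8}, ::Type{<:Integer}) = Tiny8
```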

But with optimization enabled, the pass pipeline seems to crash:

Unknown stablehlo type to parse data from
: f8E4M3FN

[115730] signal 11 (1): Segmentation fault
in expression starting at REPL[7]:1
_ZNK4mlir17DenseElementsAttr7getTypeEv at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir17DenseElementsAttr11resizeSplatENS_10ShapedTypeE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZNK12_GLOBAL__N_115ConvertSimplify15matchAndRewriteEN4mlir9stablehlo9ConvertOpERNS1_15PatternRewriterE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEEENKUlvE_clEv at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_126GreedyPatternRewriteDriver15processWorklistEv at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir21applyPatternsGreedilyERNS_6RegionERKNS_23FrozenRewritePatternSetENS_19GreedyRewriteConfigEPb at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15ApplyPatternsOp10applyToOneERNS0_17TransformRewriterEPNS_9OperationERNS0_21ApplyToEachResultListERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail20applyTransformToEachINS0_15ApplyPatternsOpERN4llvm14iterator_rangeINS4_20filter_iterator_implIPKPNS_9OperationEZNKS0_14TransformState13getPayloadOpsENS_5ValueEEUlS8_E_St26bidirectional_iterator_tagEEEEEENS_27DiagnosedSilenceableFailureET_RNS0_17TransformRewriterEOT0_RNS4_15SmallVectorImplINS0_21ApplyToEachResultListEEERSB_ at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformEachOpTraitINS0_15ApplyPatternsOpEE5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15ApplyPatternsOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZL18applySequenceBlockRN4mlir5BlockENS_9transform22FailurePropagationModeERNS2_14TransformStateERNS2_16TransformResultsE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15NamedSequenceOp5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15NamedSequenceOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15applyTransformsEPNS_9OperationENS0_20TransformOpInterfaceERKNS_11RaggedArrayIN4llvm12PointerUnionIJS2_NS_9AttributeENS_5ValueEEEEEERKNS0_16TransformOptionsEbNS5_12function_refIFvRNS0_14TransformStateEEEENSG_IFNS5_13LogicalResultESI_EEE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform27applyTransformNamedSequenceENS_11RaggedArrayIN4llvm12PointerUnionIJPNS_9OperationENS_9AttributeENS_5ValueEEEEEENS0_20TransformOpInterfaceENS_8ModuleOpERKNS0_16TransformOptionsE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_115InterpreterPass14runOnOperationEv at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor11runPipelineERNS_13OpPassManagerEPNS_9OperationENS_15AnalysisManagerEbjPNS_16PassInstrumentorEPKNS_19PassInstrumentation18PipelineParentInfoE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/.julia/artifacts/cdb4f20067cc2d97beff8bb166b75a71ef300805/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/software/lux/Reactant.jl/src/mlir/libMLIR_h.jl:5867 [inlined]
run! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#2 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:335
run_pass_pipeline! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:330 [inlined]
#compile_mlir!#9 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:468
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:426 [inlined]
#7 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:362 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x78314b5cdea6)
#compile_mlir#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:360
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:356
unknown function (ip: 0x78314b5c7d4d)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10119 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_gCry3.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_14727 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_gCry3.so (unknown line)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73609.1 at /home/avikpal/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/builder-demeter6-3/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x7831d4b44e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 34824781 (Pool: 34823443; Big: 1338); GC: 32
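The trace points at the `ConvertSimplify` rewrite pattern. That pattern calls `DenseElementsAttr::resizeSplat` while rewriting a `stablehlo.convert`, so the crash appears to come from constant-folding a conversion whose element type is `f8E4M3FN`. Folding such a conversion means interpreting raw FP8 bit patterns.

As an illustration of what decoding involves, the plain-Julia decoder below is unrelated to the MLIR implementation and uses the layout MLIR documents for `f8E4M3FN`:

- 1 sign bit
- 4 exponent bits with bias 7
- 3 mantissa bits
- no infinities
- NaN encoded only as `S.1111.111`

`decode_e4m3fn` is a hypothetical name.

```julia
# Decode one E4M3FN byte to Float32 (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits).
# The format has no infinities; only the pattern S.1111.111 encodes NaN.
function decode_e4m3fn(b::UInt8)::Float32
    sign = (b >> 7) == 0x01 ? -1f0 : 1f0
    expo = Int((b >> 3) & 0x0f)
    mant = Int(b & 0x07)
    if expo == 15 && mant == 7
        return NaN32                                                  # the single NaN pattern
    elseif expo == 0
        return sign * Float32(mant) / 8f0 * 2f0^-6                    # subnormal: 0.mmm * 2^(1 - 7)
    else
        return sign * (1f0 + Float32(mant) / 8f0) * 2f0^(expo - 7)    # normal: 1.mmm * 2^(e - 7)
    end
end
```

The largest finite value, `0x7E` (`0.1111.110`), decodes to 448.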

@avik-pal
Collaborator Author

avik-pal commented Jan 23, 2025

I also removed some types from ReactantPrimitive because they are not part of the StableHLO spec; I can restore them if we need them. Specifically:

  • Int128
  • UInt128
  • Complex types other than ComplexF32 and ComplexF64
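StableHLO's element types include signed and unsigned integers up to 64 bits, and complex numbers built only from `f32` or `f64`. That is why the three entries above fall outside the spec. The sketch below is a closed set of Julia built-in types that stays within it. The name `StableHLOBuiltin` is hypothetical, and this is not Reactant's actual `ReactantPrimitive` definition.

```julia
# Hypothetical closed set of Julia built-in types that StableHLO element types can represent.
# Excluded: 128-bit integers (StableHLO integers stop at 64 bits) and complex types other
# than ComplexF32/ComplexF64. Package-provided types (bf16, FP8) are left out of this sketch.
const StableHLOBuiltin = Union{
    Bool,
    Int8, Int16, Int32, Int64,
    UInt8, UInt16, UInt32, UInt64,
    Float16, Float32, Float64,
    ComplexF32, ComplexF64,
}
```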

@wsmoses
Member

wsmoses commented Jan 23, 2025

Open an issue for that on Enzyme-JAX.

@avik-pal
Collaborator Author

xref EnzymeAD/Enzyme-JAX#264

Labels: None yet
Projects: None yet

2 participants