@trace function_call() to introduce function barrier #346
Comments
I've had to do something similar manually in #344, but I have some problems at the end:
Aside from that, we must check how well using
Autodiff is already supported. It should not be too hard to support batching if it isn't supported yet (basically, apply enzyme.batch to the call).
This looks like the pattern we emit for a reshape. Are you sure it is related to the added calls?
The inline pass has parameters that we can tweak. Also, IIRC, this pass was run before autodiff because autodiff of func.call was not yet supported.
Mmm, you're right. Using my PR #344 on this code:

```julia
using Reactant
using YaoBlocks

θ = ConcreteRNumber(rand())
f(x) = mat(ComplexF64, Rz(x))
@code_hlo optimize = false f(θ)
```

the last line generates the following MLIR:

```mlir
module {
  func.func @rz_Float64_ComplexF64(%arg0: tensor<f64>) -> tensor<2x2xcomplex<f64>> {
    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<2x2xcomplex<f64>>
    %cst_0 = stablehlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f64>>
    %0 = stablehlo.convert %arg0 : (tensor<f64>) -> tensor<complex<f64>>
    %1 = stablehlo.multiply %cst_0, %0 : tensor<complex<f64>>
    %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>>
    %2 = stablehlo.divide %1, %cst_1 : tensor<complex<f64>>
    %3 = stablehlo.exponential %2 : tensor<complex<f64>>
    %4 = chlo.conj %3 : tensor<complex<f64>> -> tensor<complex<f64>>
    %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<complex<f64>>) -> tensor<1x1xcomplex<f64>>
    %c = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<1> : tensor<i64>
    %6 = stablehlo.subtract %c, %c_2 : tensor<i64>
    %c_3 = stablehlo.constant dense<1> : tensor<i64>
    %c_4 = stablehlo.constant dense<1> : tensor<i64>
    %7 = stablehlo.subtract %c_3, %c_4 : tensor<i64>
    %8 = stablehlo.dynamic_update_slice %cst, %5, %6, %7 : (tensor<2x2xcomplex<f64>>, tensor<1x1xcomplex<f64>>, tensor<i64>, tensor<i64>) -> tensor<2x2xcomplex<f64>>
    %9 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<complex<f64>>) -> tensor<1x1xcomplex<f64>>
    %c_5 = stablehlo.constant dense<2> : tensor<i64>
    %c_6 = stablehlo.constant dense<1> : tensor<i64>
    %10 = stablehlo.subtract %c_5, %c_6 : tensor<i64>
    %c_7 = stablehlo.constant dense<2> : tensor<i64>
    %c_8 = stablehlo.constant dense<1> : tensor<i64>
    %11 = stablehlo.subtract %c_7, %c_8 : tensor<i64>
    %12 = stablehlo.dynamic_update_slice %8, %9, %10, %11 : (tensor<2x2xcomplex<f64>>, tensor<1x1xcomplex<f64>>, tensor<i64>, tensor<i64>) -> tensor<2x2xcomplex<f64>>
    return %12 : tensor<2x2xcomplex<f64>>
  }
  func.func @main(%arg0: tensor<f64>) -> (tensor<2x2xcomplex<f64>>, tensor<f64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
    %1 = call @rz_Float64_ComplexF64(%0) : (tensor<f64>) -> tensor<2x2xcomplex<f64>>
    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xcomplex<f64>>) -> tensor<2x2xcomplex<f64>>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
    return %2, %3 : tensor<2x2xcomplex<f64>>, tensor<f64>
  }
}
```

which doesn't have the
It's probably me (Tenet) or Yao again; I'll have to revise that.
Yeah, but I was wondering if there's a way to mark a
I don't understand; the forward/reverse rules for
I think @wsmoses has been working on something like this (llvm/llvm-project#117392).
Yes, they are implemented. I'm not sure whether this is up to date in Reactant_jll.
Cool!
Yes, it's in the latest Reactant_jll. The current Enzyme commit in
Yeah, I need to finish up that LLVM PR, but I got distracted kernel'ing.
Currently, calling the same function multiple times generates its IR multiple times, because tracing goes through the function body again at every call.
One idea to circumvent this problem is to update the @trace macro to emit an MLIR function call as the resulting op instead of tracing through the function. The cache for this must be specific: it should check the values in Julia land and the MLIR types, so that functions are cached within a single trace.
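A minimal sketch of the caching idea above, in plain Julia with no Reactant dependency. TracedArg, TraceContext, cache_key and traced_call! are hypothetical stand-ins, not Reactant API: a traced argument is reduced to its MLIR type string, and "tracing" a callee only records a placeholder func.func. The cache key combines the Julia function, the non-traced (Julia-land) arguments, and the MLIR types of the traced arguments; the cache lives in a per-trace context, so it is scoped to a single trace.

```julia
# Hypothetical stand-in for a traced value: only its MLIR type matters for caching.
struct TracedArg
    mlir_type::String
end

# Hypothetical per-trace state; a fresh context per trace scopes the cache to that trace.
struct TraceContext
    cache::Dict{Any,String}   # (f, Julia-land args, MLIR types) => emitted func.func name
    funcs::Vector{String}     # func.func definitions emitted so far
    ops::Vector{String}       # call ops emitted in the caller body
end

TraceContext() = TraceContext(Dict{Any,String}(), String[], String[])

# Split the arguments into Julia-land values and MLIR types of traced values.
function cache_key(f, args::Tuple)
    julia_vals = Tuple(a for a in args if !(a isa TracedArg))
    mlir_types = Tuple(a.mlir_type for a in args if a isa TracedArg)
    return (f, julia_vals, mlir_types)
end

# Emit a call op, creating a new func.func only on a cache miss.
function traced_call!(ctx::TraceContext, f, args...)
    name = get!(ctx.cache, cache_key(f, args)) do
        fname = string(nameof(f), "_", length(ctx.funcs))
        # A real implementation would trace f's body here; this sketch records a stub.
        push!(ctx.funcs, "func.func @$fname")
        fname
    end
    push!(ctx.ops, "call @$name")
    return name
end

# Usage: the body of rz is irrelevant because the sketch never traces it.
rz(θ, n) = nothing

ctx = TraceContext()
traced_call!(ctx, rz, TracedArg("tensor<f64>"), 2)  # miss: emits func.func @rz_0
traced_call!(ctx, rz, TracedArg("tensor<f64>"), 2)  # hit: reuses @rz_0
traced_call!(ctx, rz, TracedArg("tensor<f32>"), 2)  # new MLIR type: emits @rz_1
println(ctx.funcs)
println(ctx.ops)
```

Keying on MLIR types rather than on runtime values lets repeated calls with different data share one func.func, while a change in a Julia-land argument (n above) still produces a separate specialization.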