Skip to content

Commit e48006a

Browse files
committed
feat: Ops.lu
1 parent d27baa4 commit e48006a

File tree

4 files changed

+50
-30
lines changed

4 files changed

+50
-30
lines changed

src/Compiler.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,7 @@ function compile_mlir!(
12731273
end
12741274

12751275
blas_int_width = sizeof(BLAS.BlasInt) * 8
1276-
lower_factorization_pass = "lower-factorization{backend=$backend \
1276+
lower_factorization_pass = "lower-enzymexla-linalg{backend=$backend \
12771277
blas_int_width=$blas_int_width}"
12781278

12791279
if optimize === :all

src/Ops.jl

+42
Original file line numberDiff line numberDiff line change
@@ -2846,4 +2846,46 @@ end
28462846
]
28472847
end
28482848

2849+
"""
    lu(x::TracedRArray, [pT=Int32]; location)

Emit an `enzymexla.linalg_lu` operation on `x` and wrap its four results as traced
values.

The last two dimensions of `x` are treated as the matrix dimensions; any leading
dimensions are carried through unchanged as batch dimensions.

# Arguments
- `x`: traced array with at least two dimensions.
- `pT`: element type used for the pivot, permutation and info results (default `Int32`).

# Keywords
- `location`: MLIR location attached to the emitted operation.

# Returns
A tuple `(res, ipiv, perm, info)`:
- `res`: same size as `x`, element type `unwrapped_eltype(T)`.
- `ipiv`: batch dimensions followed by `min(m, n)`, where `m` and `n` are the sizes of
  the last two dimensions of `x`.
- `perm`: batch dimensions followed by `m`.
- `info`: a `TracedRNumber` when `x` is two-dimensional, otherwise a `TracedRArray`
  shaped like the batch dimensions.

NOTE(review): the numerical meaning of each result (packed factors, pivot convention,
`info` codes) is defined by the `enzymexla.linalg_lu` op and is not visible here.
"""
@noinline function lu(
    x::TracedRArray{T},
    ::Type{pT}=Int32;
    location=mlir_stacktrace("lu", @__FILE__, @__LINE__),
) where {T,pT}
    @assert ndims(x) >= 2

    rank = ndims(x)
    nrows = size(x, rank - 1)
    ncols = size(x, rank)
    val_eltype = unwrapped_eltype(T)
    idx_eltype = unwrapped_eltype(pT)

    # Result shapes: everything except the trailing two dims is a batch dimension.
    full_dims = collect(Int64, size(x))
    batch_dims = full_dims[1:(end - 2)]
    ipiv_dims = vcat(batch_dims, min(nrows, ncols))
    perm_dims = vcat(batch_dims, nrows)

    # Small local helper to build the MLIR tensor type for a shape/eltype pair.
    tensor_of(dims, elty) = MLIR.IR.TensorType(dims, MLIR.IR.Type(elty))

    op = MLIR.Dialects.enzymexla.linalg_lu(
        x.mlir_data;
        output=tensor_of(full_dims, val_eltype),
        pivots=tensor_of(ipiv_dims, idx_eltype),
        permutation=tensor_of(perm_dims, idx_eltype),
        info=tensor_of(batch_dims, idx_eltype),
        location,
    )

    res = TracedRArray{val_eltype,rank}((), MLIR.IR.result(op, 1), size(x))
    ipiv = TracedRArray{idx_eltype,rank - 1}((), MLIR.IR.result(op, 2), ipiv_dims)
    perm = TracedRArray{idx_eltype,rank - 1}((), MLIR.IR.result(op, 3), perm_dims)

    # With no batch dimensions the info result is a scalar, so wrap it as a number.
    info_value = MLIR.IR.result(op, 4)
    info = if rank == 2
        TracedRNumber{idx_eltype}((), info_value)
    else
        TracedRArray{idx_eltype,rank - 2}((), info_value, batch_dims)
    end

    return (res, ipiv, perm, info)
end
2890+
28492891
end # module Ops

src/mlir/Dialects/EnzymeXLA.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,14 @@ function kernel_call(
189189
end
190190

191191
function linalg_lu(
192-
input::Value; output::IR.Type, pivots::IR.Type, info::IR.Type, location=Location()
192+
input::Value;
193+
output::IR.Type,
194+
pivots::IR.Type,
195+
permutation::IR.Type,
196+
info::IR.Type,
197+
location=Location(),
193198
)
194-
op_ty_results = IR.Type[output, pivots, info]
199+
op_ty_results = IR.Type[output, pivots, permutation, info]
195200
operands = Value[input,]
196201
owned_regions = Region[]
197202
successors = Block[]

src/stdlibs/LinearAlgebra.jl

-27
Original file line numberDiff line numberDiff line change
@@ -488,31 +488,4 @@ function LinearAlgebra.dot(x::AnyTracedRVector, y::AnyTracedRVector)
488488
return TracedRNumber{unwrapped_eltype(res)}((), res.mlir_data)
489489
end
490490

491-
# Factorizations
492-
"""
    LinearAlgebra.lu!(A::TracedRArray{T,2}, ::RowMaximum; check=true, allowsingular=false)

Emit an `enzymexla.linalg_lu` operation for the traced matrix `A`, store the op's first
result back into `A`, and return
`(UnitLowerTriangular(A), UpperTriangular(A), pivots, info)`.

`pivots` is a length-`min(m, n)` traced `Int32` vector and `info` a traced `Int32`
scalar, where `(m, n) = size(A)`.

The `check` and `allowsingular` keywords are accepted but currently ignored.
"""
function LinearAlgebra.lu!(
    A::TracedRArray{T,2},
    ::LinearAlgebra.RowMaximum;
    check::Bool=true,
    allowsingular::Bool=false,
) where {T}
    # TODO: stop ignoring `check` and `allowsingular`
    nrows, ncols = size(A)
    npivots = min(nrows, ncols)

    input_value = materialize_traced_array(A).mlir_data
    output_type = MLIR.IR.TensorType(
        collect(Int64, size(A)), MLIR.IR.Type(unwrapped_eltype(A))
    )
    pivots_type = MLIR.IR.TensorType(Int64[npivots], MLIR.IR.Type(Int32))
    info_type = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Int32))

    lu_op = MLIR.Dialects.enzymexla.linalg_lu(
        input_value; output=output_type, pivots=pivots_type, info=info_type
    )

    # The factorization is "in place": A now refers to the op's first result.
    set_mlir_data!(A, MLIR.IR.result(lu_op, 1))
    pivots = TracedRArray{Int32,1}((), MLIR.IR.result(lu_op, 2), (npivots,))
    info = TracedRNumber{Int32}((), MLIR.IR.result(lu_op, 3))

    # XXX: `info` needs to be a BLASInt for LU, so a `LinearAlgebra.LU` object is not
    # constructed here; the triangular views, pivots and info are returned as a tuple.
    return (UnitLowerTriangular(A), UpperTriangular(A), pivots, info)
end
517-
518491
end

0 commit comments

Comments
 (0)