Skip to content

Commit 5d505f6

Browse files
committed
feat: add triangular_solve op
1 parent c1c582f commit 5d505f6

File tree

1 file changed

+72
-16
lines changed

1 file changed

+72
-16
lines changed

src/Ops.jl

+72-16
Original file line numberDiff line numberDiff line change
@@ -2846,6 +2846,16 @@ end
28462846
]
28472847
end
28482848

2849+
"""
2850+
lu(
2851+
x::TracedRArray{T},
2852+
::Type{pT}=Int32;
2853+
location=mlir_stacktrace("lu", @__FILE__, @__LINE__)
2854+
) where {T,pT}
2855+
2856+
Compute the row maximum pivoted LU factorization of `x` and return the factors `LU`,
2857+
`ipiv`, `permutation` tensor, and `info`.
2858+
"""
28492859
@noinline function lu(
28502860
x::TracedRArray{T},
28512861
::Type{pT}=Int32;
@@ -2862,30 +2872,76 @@ end
28622872
op = MLIR.Dialects.enzymexla.linalg_lu(
28632873
x.mlir_data;
28642874
output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
2865-
pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(unwrapped_eltype(pT))),
2866-
permutation=MLIR.IR.TensorType(
2867-
permutation_shape, MLIR.IR.Type(unwrapped_eltype(pT))
2868-
),
2869-
info=MLIR.IR.TensorType(info_shape, MLIR.IR.Type(unwrapped_eltype(pT))),
2875+
pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
2876+
permutation=MLIR.IR.TensorType(permutation_shape, MLIR.IR.Type(pT)),
2877+
info=MLIR.IR.TensorType(info_shape, MLIR.IR.Type(pT)),
28702878
location,
28712879
)
28722880

2873-
res = TracedRArray{unwrapped_eltype(T),ndims(x)}((), MLIR.IR.result(op, 1), size(x))
2874-
ipiv = TracedRArray{unwrapped_eltype(pT),ndims(x) - 1}(
2875-
(), MLIR.IR.result(op, 2), pivots_shape
2876-
)
2877-
perm = TracedRArray{unwrapped_eltype(pT),ndims(x) - 1}(
2878-
(), MLIR.IR.result(op, 3), permutation_shape
2879-
)
2881+
res = TracedRArray{T,ndims(x)}((), MLIR.IR.result(op, 1), size(x))
2882+
ipiv = TracedRArray{pT,ndims(x) - 1}((), MLIR.IR.result(op, 2), pivots_shape)
2883+
perm = TracedRArray{pT,ndims(x) - 1}((), MLIR.IR.result(op, 3), permutation_shape)
28802884

28812885
if ndims(x) == 2
2882-
info = TracedRNumber{unwrapped_eltype(pT)}((), MLIR.IR.result(op, 4))
2886+
info = TracedRNumber{pT}((), MLIR.IR.result(op, 4))
28832887
else
2884-
info = TracedRArray{unwrapped_eltype(pT),ndims(x) - 2}(
2885-
(), MLIR.IR.result(op, 4), info_shape
2886-
)
2888+
info = TracedRArray{pT,ndims(x) - 2}((), MLIR.IR.result(op, 4), info_shape)
28872889
end
28882890
return (res, ipiv, perm, info)
28892891
end
28902892

2893+
"""
    triangular_solve(
        a::TracedRArray{T,N},
        b::TracedRArray{T,M};
        left_side::Bool=true,
        lower::Bool,
        unit_diagonal::Bool=false,
        transpose_a::Symbol=:N,
        location=mlir_stacktrace("triangular_solve", @__FILE__, @__LINE__),
    ) where {T,N,M}

Emit a `stablehlo.triangular_solve` operation on `a` and `b` and return its result.

# Arguments
- `a`: array of rank `N >= 2` whose last two dimensions must be equal. Any leading
  `N - 2` dimensions must match the leading dimensions of `b`.
- `b`: array of rank `N`, or of rank `N - 1`. A rank `N - 1` input is reshaped to rank
  `N` by inserting a singleton dimension (last when `left_side` is `true`,
  second-to-last otherwise), and the result is reshaped back to rank `N - 1`.

# Keywords
- `left_side`, `lower`, `unit_diagonal`: forwarded unchanged as attributes of the
  StableHLO operation (see the StableHLO specification for their exact semantics).
  `lower` has no default and must be supplied.
- `transpose_a`: one of `:N`, `:T` or `:C`, mapped to the StableHLO transpose kinds
  `NO_TRANSPOSE`, `TRANSPOSE` and `ADJOINT` respectively.
- `location`: MLIR location attached to every emitted operation.

# Returns
A `TracedRArray{T,N}` when `b` has rank `N`; otherwise the result of reshaping the
solution to the leading dimensions of `a` followed by the non-singleton trailing
dimension of the reshaped `b`.
"""
@noinline function triangular_solve(
    a::TracedRArray{T,N},
    b::TracedRArray{T,M};
    left_side::Bool=true,
    lower::Bool,
    unit_diagonal::Bool=false,
    transpose_a::Symbol=:N,
    location=mlir_stacktrace("triangular_solve", @__FILE__, @__LINE__),
) where {T,N,M}
    @assert N >= 2
    # Validate the transpose selector before any operation is emitted. The membership
    # operator is required: without it the tuple is parsed as the assertion message and
    # the `Symbol` itself is used as the (non-Boolean) condition.
    @assert transpose_a in (:N, :T, :C)

    # Promote a rank `N - 1` right-hand side to rank `N` by adding a singleton dimension.
    if M == N - 1
        if left_side
            b = Ops.reshape(b, size(b)..., 1; location)
        else
            b = Ops.reshape(b, size(b)[1:(M - 1)]..., 1, size(b, M); location)
        end
    end

    # Check the rank first so that the size comparisons below index valid dimensions.
    @assert N == ndims(b)
    # `N - left_side` selects dimension `N - 1` for a left-side solve and `N` otherwise.
    @assert size(a, N - 1) == size(a, N) == size(b, N - left_side)
    @assert size(a)[1:(N - 2)] == size(b)[1:(N - 2)]

    transpose_kind = if transpose_a === :N
        "NO_TRANSPOSE"
    elseif transpose_a === :T
        "TRANSPOSE"
    else
        "ADJOINT"
    end
    transpose_attr = MLIR.API.stablehloTransposeAttrGet(MLIR.IR.context(), transpose_kind)

    op = stablehlo.triangular_solve(
        a.mlir_data,
        b.mlir_data;
        left_side=MLIR.IR.Attribute(left_side),
        lower=MLIR.IR.Attribute(lower),
        unit_diagonal=MLIR.IR.Attribute(unit_diagonal),
        transpose_a=transpose_attr,
        location,
    )

    result_shape = size(a)[1:(N - 2)]
    if left_side
        result_shape = (result_shape..., size(a, N), size(b, N))
    else
        result_shape = (result_shape..., size(b, N - 1), size(a, N - 1))
    end

    sol = TracedRArray{T,N}((), MLIR.IR.result(op, 1), result_shape)
    N == M && return sol

    # Drop the singleton dimension that was inserted for a rank `N - 1` right-hand side.
    return Ops.reshape(
        sol, Int64[size(a)[1:(N - 2)]..., size(b, N - left_side)]; location
    )
end
2946+
28912947
end # module Ops

0 commit comments

Comments
 (0)