@@ -2846,6 +2846,16 @@ end
2846
2846
]
2847
2847
end
2848
2848
2849
+ """
2850
+ lu(
2851
+ x::TracedRArray{T},
2852
+ ::Type{pT}=Int32;
2853
+ location=mlir_stacktrace("lu", @__FILE__, @__LINE__)
2854
+ ) where {T,pT}
2855
+
2856
+ Compute the row maximum pivoted LU factorization of `x` and return the factors `LU`,
2857
+ `ipiv`, `permutation` tensor, and `info`.
2858
+ """
2849
2859
@noinline function lu (
2850
2860
x:: TracedRArray{T} ,
2851
2861
:: Type{pT} = Int32;
@@ -2862,30 +2872,76 @@ end
2862
2872
op = MLIR. Dialects. enzymexla. linalg_lu (
2863
2873
x. mlir_data;
2864
2874
output= MLIR. IR. TensorType (output_shape, MLIR. IR. Type (unwrapped_eltype (T))),
2865
- pivots= MLIR. IR. TensorType (pivots_shape, MLIR. IR. Type (unwrapped_eltype (pT))),
2866
- permutation= MLIR. IR. TensorType (
2867
- permutation_shape, MLIR. IR. Type (unwrapped_eltype (pT))
2868
- ),
2869
- info= MLIR. IR. TensorType (info_shape, MLIR. IR. Type (unwrapped_eltype (pT))),
2875
+ pivots= MLIR. IR. TensorType (pivots_shape, MLIR. IR. Type (pT)),
2876
+ permutation= MLIR. IR. TensorType (permutation_shape, MLIR. IR. Type (pT)),
2877
+ info= MLIR. IR. TensorType (info_shape, MLIR. IR. Type (pT)),
2870
2878
location,
2871
2879
)
2872
2880
2873
- res = TracedRArray {unwrapped_eltype(T),ndims(x)} ((), MLIR. IR. result (op, 1 ), size (x))
2874
- ipiv = TracedRArray {unwrapped_eltype(pT),ndims(x) - 1} (
2875
- (), MLIR. IR. result (op, 2 ), pivots_shape
2876
- )
2877
- perm = TracedRArray {unwrapped_eltype(pT),ndims(x) - 1} (
2878
- (), MLIR. IR. result (op, 3 ), permutation_shape
2879
- )
2881
+ res = TracedRArray {T,ndims(x)} ((), MLIR. IR. result (op, 1 ), size (x))
2882
+ ipiv = TracedRArray {pT,ndims(x) - 1} ((), MLIR. IR. result (op, 2 ), pivots_shape)
2883
+ perm = TracedRArray {pT,ndims(x) - 1} ((), MLIR. IR. result (op, 3 ), permutation_shape)
2880
2884
2881
2885
if ndims (x) == 2
2882
- info = TracedRNumber {unwrapped_eltype(pT) } ((), MLIR. IR. result (op, 4 ))
2886
+ info = TracedRNumber {pT } ((), MLIR. IR. result (op, 4 ))
2883
2887
else
2884
- info = TracedRArray {unwrapped_eltype(pT),ndims(x) - 2} (
2885
- (), MLIR. IR. result (op, 4 ), info_shape
2886
- )
2888
+ info = TracedRArray {pT,ndims(x) - 2} ((), MLIR. IR. result (op, 4 ), info_shape)
2887
2889
end
2888
2890
return (res, ipiv, perm, info)
2889
2891
end
2890
2892
2893
+ @noinline function triangular_solve (
2894
+ a:: TracedRArray{T,N} ,
2895
+ b:: TracedRArray{T,M} ;
2896
+ left_side:: Bool = true ,
2897
+ lower:: Bool ,
2898
+ unit_diagonal:: Bool = false ,
2899
+ transpose_a:: Symbol = :N ,
2900
+ location= mlir_stacktrace (" triangular_solve" , @__FILE__ , @__LINE__ ),
2901
+ ) where {T,N,M}
2902
+ @assert N >= 2
2903
+
2904
+ if M == N - 1
2905
+ if left_side
2906
+ b = Ops. reshape (b, size (b)... , 1 ; location)
2907
+ else
2908
+ b = Ops. reshape (b, size (b)[1 : (M - 1 )]. .. , 1 , size (b, M); location)
2909
+ end
2910
+ end
2911
+
2912
+ @assert size (a, N - 1 ) == size (a, N) == size (b, N - left_side)
2913
+ @assert N == ndims (b)
2914
+ @assert transpose_a ∈ (:N , :T , :C )
2915
+ @assert size (a)[1 : (N - 2 )] == size (b)[1 : (N - 2 )]
2916
+
2917
+ transpose_attr = if transpose_a === :N
2918
+ MLIR. API. stablehloTransposeAttrGet (MLIR. IR. context (), " NO_TRANSPOSE" )
2919
+ elseif transpose_a === :T
2920
+ MLIR. API. stablehloTransposeAttrGet (MLIR. IR. context (), " TRANSPOSE" )
2921
+ else
2922
+ MLIR. API. stablehloTransposeAttrGet (MLIR. IR. context (), " ADJOINT" )
2923
+ end
2924
+
2925
+ op = stablehlo. triangular_solve (
2926
+ a. mlir_data,
2927
+ b. mlir_data;
2928
+ left_side= MLIR. IR. Attribute (left_side),
2929
+ lower= MLIR. IR. Attribute (lower),
2930
+ unit_diagonal= MLIR. IR. Attribute (unit_diagonal),
2931
+ transpose_a= transpose_attr,
2932
+ location,
2933
+ )
2934
+
2935
+ result_shape = size (a)[1 : (N - 2 )]
2936
+ if left_side
2937
+ result_shape = (result_shape... , size (a, N), size (b, N))
2938
+ else
2939
+ result_shape = (result_shape... , size (b, N - 1 ), size (a, N - 1 ))
2940
+ end
2941
+
2942
+ sol = TracedRArray {T,N} ((), MLIR. IR. result (op, 1 ), result_shape)
2943
+ N == M && return sol
2944
+ return Ops. reshape (sol, Int64[size (a)[1 : (N - 2 )]. .. , size (b, N - left_side)]; location)
2945
+ end
2946
+
2891
2947
end # module Ops
0 commit comments