diff --git a/src/Ops.jl b/src/Ops.jl index 18ab2d7d4b..e7ffe033ae 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -950,28 +950,49 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -# sorting ops -# TODO need to trace over `comparator` -# function sort( -# x::TracedRArray{T,N}; -# comparator, -# dimension=-1, -# is_stable=false, -# location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -# ) where {T,N} -# dimension = MLIR.IR.Attribute(dimension) -# is_stable = MLIR.IR.Attribute(is_stable) -# res = MLIR.IR.result( -# stablehlo.sort( -# x.mlir_data; -# result=mlir_type(TracedRArray{T,N}, size(x)), -# dimension, -# is_stable, -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +@noinline function sort( + x::TracedRArray{T,N}; + comparator, + dimension=1, + is_stable=false, + location=mlir_stacktrace("sort", @__FILE__, @__LINE__), +) where {T,N} + #C4: + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + + (a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0))) + func = Reactant.TracedUtils.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2] + @assert MLIR.IR.nregions(func) == 1 + fn_name = String( + MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) + ) + #C5: + @assert fn_name == "main" "$comparator: no function generated" + ftype_attr = MLIR.IR.attr(func, "function_type") + ftype = MLIR.IR.Type(ftype_attr) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( + "$comparator return type is not tensor" + ) + + comparator = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + + dimension = MLIR.IR.Attribute(dimension - 1) + is_stable = MLIR.IR.Attribute(is_stable) + + res = MLIR.IR.result( + stablehlo.sort( + [x.mlir_data]; + result_0=[mlir_type(TracedRArray{T,N}, size(x))], + dimension, + is_stable, + comparator, + location, + ), + ) + return TracedRArray{T,N}((), res, size(x)) +end @noinline function top_k( x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) diff --git a/test/Project.toml b/test/Project.toml index d8861a1aae..aa1291edfe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,6 +43,7 @@ NNlib = "0.9.26" OneHotArrays = "0.2.6" Optimisers = "0.4" Random = "1.10" +Random123 = "1" SafeTestsets = "0.1" SpecialFunctions = "2.4" Statistics = "1.10" diff --git a/test/ops.jl b/test/ops.jl index 82ec4cc8b8..561e2911ba 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -688,6 +688,18 @@ end end end +@testset "sort" begin + basic_sort(x, dimension) = Reactant.Ops.sort(x; comparator=(a, b) -> a < b, dimension) + for i in 1:3 + t_size = tuple(fill(10, (i,))...) + x = Reactant.to_rarray(randn(t_size)) + + for j in 1:i + @test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(x, j) + end + end +end + @testset "slice" begin x = ConcreteRArray([1, 2, 3, 4]) @test [2, 3] == @jit Ops.slice(x, [2], [3])