Skip to content

Commit 9ece827

Browse files
committed
fix test GPU
1 parent 7040632 commit 9ece827

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/Ops.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -961,13 +961,13 @@ end
961961
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
962962

963963
(a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0)))
964-
func = Reactant.TracedUtils.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2]
964+
func = Reactant.TracedUtils.make_mlir_fn(comparator, (a, b), (), "comparator"; no_args_in_result=true, return_dialect=:stablehlo)[2]
965965
@assert MLIR.IR.nregions(func) == 1
966966
fn_name = String(
967967
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
968968
)
969969
#C5:
970-
@assert fn_name == "main" "$comparator: no function generated"
970+
@assert fn_name == "comparator" "$comparator: no function generated"
971971
ftype_attr = MLIR.IR.attr(func, "function_type")
972972
ftype = MLIR.IR.Type(ftype_attr)
973973
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error(

test/ops.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,10 +692,11 @@ end
692692
basic_sort(x, dimension) = Reactant.Ops.sort(x; comparator=(a, b) -> a < b, dimension)
693693
for i in 1:3
694694
t_size = tuple(fill(10, (i,))...)
695-
x = Reactant.to_rarray(randn(t_size))
695+
x = randn(t_size)
696+
xa = Reactant.to_rarray(x)
696697

697698
for j in 1:i
698-
@test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(x, j)
699+
@test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(xa, j)
699700
end
700701
end
701702
end

0 commit comments

Comments
 (0)