Skip to content

Commit 3570c2b

Browse files
mkittigdallenalimilan
authored
Fix argmax, findmax, findXwithfirst, and expand testing (#99)
* Fix argmin and argmax * Revert bqd fix * Add two small tests * Moar tests to figure out what's wrong * Fix tests * Fix comparison order * Add more vectors to test and test vectors * last fix * One comp only * Add tests * Improve test set * Integrate test suggestions by @Seelengrab * Fix argmax tests with function * Fix findmax * Test collect * Run some tests only on 1.6 or greater * Fix static if * Add more tests * Update chainedvector.jl * Update src/chainedvector.jl * Update test/chainedvector.jl Co-authored-by: Milan Bouchet-Valat <[email protected]> * Revert back to approx in sum test --------- Co-authored-by: Guillaume Dalle <[email protected]> Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent d13560a commit 3570c2b

File tree

2 files changed

+164
-1
lines changed

2 files changed

+164
-1
lines changed

src/chainedvector.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,10 @@ function Base.findmax(f::F, x::ChainedVector) where {F}
809809
cleanup!(x) # get rid of any empty arrays
810810
i = 1
811811
y = f(x.arrays[1][1])
812-
return findXwithfirst(!isless, f, x, y, i)
812+
# x > y iff y < x for a well ordered set
813+
# nb. isgreater = !isless is not correct. That is `>=`
814+
isgreater(x, y) = isless(y, x)
815+
return findXwithfirst(isgreater, f, x, y, i)
813816
end
814817

815818
function Base.findmin(f::F, x::ChainedVector) where {F}

test/chainedvector.jl

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testset "ChainedVector" begin
22

3+
# identity checks
34
x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
45
@test x == 1:10
56
@test length(x) == 10
@@ -55,6 +56,7 @@
5556
insert!(x, 1, 2)
5657
@test x[1] == 2
5758

59+
5860
x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
5961
y = ChainedVector([[11,12,13], [14,15,16], [17,18,19,20]])
6062

@@ -434,6 +436,164 @@
434436
@test (rand(10,10) * v) isa ChainedVector
435437
end
436438

439+
@testset "ChainedVectors on Generated Vectors" begin
440+
#=
441+
# Use to generate text below
442+
function test_vector_generator(;
443+
lengths = rand(0:5, 5),
444+
possible_values = 1:100,
445+
)
446+
values = rand(possible_values, sum(lengths))
447+
remaining_values = copy(values)
448+
arrays = map(lengths) do length
449+
result = remaining_values[1:length]
450+
remaining_values = @view remaining_values[length+1:end]
451+
return result
452+
end
453+
return ChainedVector(arrays) => values
454+
end
455+
function Base.show(io::IO, cv::ChainedVector)
456+
print(io, "ChainedVector(")
457+
show(io, cv.arrays)
458+
print(io, ")")
459+
end
460+
for i in 1:10
461+
test_vector_generator() |>
462+
repr |>
463+
x->replace(x, "=>" => "=>\n ") |>
464+
x->println(x,",")
465+
end
466+
=#
467+
468+
# Pairs of test vectors
469+
# Some were inspired by https://github.com/JuliaData/SentinelArrays.jl/issues/97
470+
int_vectors = [
471+
ChainedVector([[100, 20], [10, 30, 70, 40], [50], Int[], [60, 90, 80]]) =>
472+
[100, 20, 10, 30, 70, 40, 50, 60, 90, 80],
473+
ChainedVector([[2,1,3], [4,5,6], [7,8,10,9]]) =>
474+
[2, 1, 3, 4, 5, 6, 7, 8, 10, 9],
475+
ChainedVector([[18, 70, 92, 15, 65], [25, 14, 95, 54, 57]]) =>
476+
[18, 70, 92, 15, 65, 25, 14, 95, 54, 57],
477+
ChainedVector([[2, 34], [61, 8, 71], [65, 81, 51], [48, 93, 48, 94], [59, 15, 16, 56, 83]]) =>
478+
[2, 34, 61, 8, 71, 65, 81, 51, 48, 93, 48, 94, 59, 15, 16, 56, 83],
479+
ChainedVector([[23, 97, 70, 70], [4, 4], [61, 17], [95, 84, 91]]) =>
480+
[23, 97, 70, 70, 4, 4, 61, 17, 95, 84, 91],
481+
ChainedVector([[61, 23, 67, 61], [27, 19, 100], [26, 95], [2, 27, 63], [51, 52, 25]]) =>
482+
[61, 23, 67, 61, 27, 19, 100, 26, 95, 2, 27, 63, 51, 52, 25],
483+
ChainedVector([[25, 6, 94], [50], [63, 1, 76], [96, 6]]) =>
484+
[25, 6, 94, 50, 63, 1, 76, 96, 6],
485+
ChainedVector([[98, 5, 94], [82, 60], [58, 46, 13, 62, 48]]) =>
486+
[98, 5, 94, 82, 60, 58, 46, 13, 62, 48],
487+
ChainedVector([[28, 26], [21, 18, 64, 15], [11, 81, 17, 90], [29], [16, 67, 34, 84]]) =>
488+
[28, 26, 21, 18, 64, 15, 11, 81, 17, 90, 29, 16, 67, 34, 84],
489+
ChainedVector([[95, 15, 49, 31, 63], [79, 88], [76], [87, 52], [86, 50, 68, 61]]) =>
490+
[95, 15, 49, 31, 63, 79, 88, 76, 87, 52, 86, 50, 68, 61],
491+
ChainedVector([[71], [96, 84], [88, 3], [76, 47]]) =>
492+
[71, 96, 84, 88, 3, 76, 47],
493+
ChainedVector([[7, 21, 31], [45], [53, 53]]) =>
494+
[7, 21, 31, 45, 53, 53],
495+
ChainedVector([[24, 28, 75, 42], [7, 38, 59, 10], [21, 30, 14], [8, 39], [13, 68, 42]]) =>
496+
[24, 28, 75, 42, 7, 38, 59, 10, 21, 30, 14, 8, 39, 13, 68, 42],
497+
]
498+
floating_point_vectors = [
499+
ChainedVector([[2.1, -4.6, -2.5], [-5.0, 6.4, 2.0, -0.5], [-6.1, -7.6, -3.2, -4.7, 4.3], [-1.7, 6.4, -8.9, -7.4], [-7.7, -1.4, 3.1, 4.5]]) =>
500+
[2.1, -4.6, -2.5, -5.0, 6.4, 2.0, -0.5, -6.1, -7.6, -3.2, -4.7, 4.3, -1.7, 6.4, -8.9, -7.4, -7.7, -1.4, 3.1, 4.5],
501+
ChainedVector([[-8.5, -1.2, -3.8, 7.5], [8.2, 7.5, -5.3], [-2.7, 0.6, -6.2, 6.1, 1.4]]) =>
502+
[-8.5, -1.2, -3.8, 7.5, 8.2, 7.5, -5.3, -2.7, 0.6, -6.2, 6.1, 1.4],
503+
ChainedVector([[-7.2], [8.1, 2.3, 7.5], [-8.4, -5.7]]) =>
504+
[-7.2, 8.1, 2.3, 7.5, -8.4, -5.7],
505+
ChainedVector([[-3.7, 7.8, -5.0], [0.1], [5.0, -4.1], [-1.6, -0.9, 8.7, -7.8]]) =>
506+
[-3.7, 7.8, -5.0, 0.1, 5.0, -4.1, -1.6, -0.9, 8.7, -7.8],
507+
ChainedVector([[8.6, -2.0], [8.0, 3.4, 3.3], [1.0], [5.4, -2.6, -4.7, 4.4, 4.4], [7.9]]) =>
508+
[8.6, -2.0, 8.0, 3.4, 3.3, 1.0, 5.4, -2.6, -4.7, 4.4, 4.4, 7.9],
509+
ChainedVector([[7.6, 5.9], [7.9, -8.8, -1.5, -0.4, 6.0], [-5.1, -0.4, 4.4, 7.3]]) =>
510+
[7.6, 5.9, 7.9, -8.8, -1.5, -0.4, 6.0, -5.1, -0.4, 4.4, 7.3],
511+
ChainedVector([[3.2, -3.2, 1.2, -1.2, -2.1], [0.5], [6.2], [2.9], [-8.1, 5.8, 4.8, -3.4, -3.1]]) =>
512+
[3.2, -3.2, 1.2, -1.2, -2.1, 0.5, 6.2, 2.9, -8.1, 5.8, 4.8, -3.4, -3.1],
513+
ChainedVector([[-8.0, -1.9, -5.1, -1.4, -8.3], [5.1, -3.7, 6.3, -4.8, -3.3], [-7.0], [-2.4, 4.0, -3.7], [-6.6, -6.9, 2.5, -1.3]]) =>
514+
[-8.0, -1.9, -5.1, -1.4, -8.3, 5.1, -3.7, 6.3, -4.8, -3.3, -7.0, -2.4, 4.0, -3.7, -6.6, -6.9, 2.5, -1.3],
515+
ChainedVector([[-7.5], [-1.5, -5.8, 8.4], [-8.4, -1.9, 2.3, -0.8, -8.5], [0.2, 0.5, -7.4, 2.1, -3.9]]) =>
516+
[-7.5, -1.5, -5.8, 8.4, -8.4, -1.9, 2.3, -0.8, -8.5, 0.2, 0.5, -7.4, 2.1, -3.9],
517+
ChainedVector([[3.9, -8.9], [-0.3, 0.0, 7.3], [-2.9, 8.6, 5.8, 0.5], [0.0, -4.5, 3.3, 0.4, -3.2]]) =>
518+
[3.9, -8.9, -0.3, 0.0, 7.3, -2.9, 8.6, 5.8, 0.5, 0.0, -4.5, 3.3, 0.4, -3.2],
519+
]
520+
rational_vectors = [
521+
ChainedVector(Vector{Rational{Int64}}[[1, 1//2, 1//2, 4//5], [7//10], [1, 1//5, 7//10, 3//10, 1], [3//5], [1]]) =>
522+
Rational{Int64}[1, 1//2, 1//2, 4//5, 7//10, 1, 1//5, 7//10, 3//10, 1, 3//5, 1],
523+
ChainedVector(Vector{Rational{Int64}}[[1//5], [1, 4//5, 1//5], [3//5, 7//10, 3//5], [9//10, 1//5, 7//10, 1//2], [1//2, 7//10, 9//10, 3//5, 7//10]]) =>
524+
Rational{Int64}[1//5, 1, 4//5, 1//5, 3//5, 7//10, 3//5, 9//10, 1//5, 7//10, 1//2, 1//2, 7//10, 9//10, 3//5, 7//10],
525+
ChainedVector(Vector{Rational{Int64}}[[7//10, 1, 1//5, 1//2, 2//5], [1//5, 4//5, 1//2, 1//5], [3//10, 3//10, 1//2], [3//10, 1//10, 4//5, 3//5], [2//5, 7//10, 1, 3//10, 3//10]]) =>
526+
Rational{Int64}[7//10, 1, 1//5, 1//2, 2//5, 1//5, 4//5, 1//2, 1//5, 3//10, 3//10, 1//2, 3//10, 1//10, 4//5, 3//5, 2//5, 7//10, 1, 3//10, 3//10],
527+
ChainedVector(Vector{Rational{Int64}}[[1//10, 4//5], [1//2], [1//10], [4//5, 1, 3//5, 9//10, 9//10]]) =>
528+
Rational{Int64}[1//10, 4//5, 1//2, 1//10, 4//5, 1, 3//5, 9//10, 9//10],
529+
ChainedVector(Vector{Rational{Int64}}[[3//10, 1, 9//10, 3//5], [1, 1], [1, 4//5, 3//5, 9//10]]) =>
530+
Rational{Int64}[3//10, 1, 9//10, 3//5, 1, 1, 1, 4//5, 3//5, 9//10],
531+
ChainedVector(Vector{Rational{Int64}}[[3//10, 7//10], [4//5], [4//5, 1, 1//10, 9//10], [1, 1, 4//5]]) =>
532+
Rational{Int64}[3//10, 7//10, 4//5, 4//5, 1, 1//10, 9//10, 1, 1, 4//5],
533+
ChainedVector(Vector{Rational{Int64}}[[2//5], [3//5, 9//10, 7//10, 9//10], [1//2, 1, 1//10, 1//5], [1//5, 4//5, 7//10, 2//5]]) =>
534+
Rational{Int64}[2//5, 3//5, 9//10, 7//10, 9//10, 1//2, 1, 1//10, 1//5, 1//5, 4//5, 7//10, 2//5],
535+
ChainedVector(Vector{Rational{Int64}}[[7//10], [3//5, 1//5, 2//5, 3//5, 4//5], [4//5], [7//10, 3//5, 7//10, 7//10, 1//10]]) =>
536+
Rational{Int64}[7//10, 3//5, 1//5, 2//5, 3//5, 4//5, 4//5, 7//10, 3//5, 7//10, 7//10, 1//10],
537+
ChainedVector(Vector{Rational{Int64}}[[1//2, 1, 1//2, 9//10, 2//5], [9//10, 1//2, 3//5], [4//5, 7//10], [3//10, 2//5], [9//10, 1]]) =>
538+
Rational{Int64}[1//2, 1, 1//2, 9//10, 2//5, 9//10, 1//2, 3//5, 4//5, 7//10, 3//10, 2//5, 9//10, 1],
539+
ChainedVector(Vector{Rational{Int64}}[[9//10, 3//10, 1//10, 2//5], [4//5], [9//10, 2//5]]) =>
540+
Rational{Int64}[9//10, 3//10, 1//10, 2//5, 4//5, 9//10, 2//5],
541+
542+
]
543+
@testset for (x,y) in Iterators.flatten([int_vectors, floating_point_vectors, rational_vectors])
544+
@test copy(x) == y
545+
@test collect(x) == y
546+
@test length(x) == length(y)
547+
# Floating point tests fail if this is not approx
548+
# See https://github.com/JuliaData/SentinelArrays.jl/pull/99#issuecomment-2171005657
549+
@test sum(x) sum(y)
550+
@test findmax(x) == findmax(y)
551+
@test findmin(x) == findmin(y)
552+
@test maximum(x) == maximum(y)
553+
@test minimum(x) == minimum(y)
554+
@test argmax(x) == argmax(y)
555+
@test argmin(x) == argmin(y)
556+
@test all(>(0),x) == all(>(0),y)
557+
@test any(>(0),x) == any(>(0),y)
558+
@test any(<(0),x) == any(<(0),y)
559+
@test count(>(0),x) == count(>(0),y)
560+
@test count(<(0),x) == count(<(0),y)
561+
@test extrema(inv, x) == extrema(inv, y)
562+
@static if VERSION v"1.6"
563+
@test findmax(x->x+1, x) == findmax(x->x+1, y)
564+
@test findmin(x->x-1, x) == findmin(x->x-1, y)
565+
@test findfirst(isodd, x) == findfirst(isodd, y)
566+
@test findfirst(iseven, x) == findfirst(iseven ,y)
567+
@test findlast(isodd, x) == findlast(isodd, y)
568+
@test findlast(iseven, x) == findlast(iseven ,y)
569+
@test findall(iseven, x) == findall(iseven ,y)
570+
@test findnext(isodd, x, 5) == findnext(isodd, y, 5)
571+
@test findprev(isodd, x, 5) == findprev(isodd, y, 5)
572+
end
573+
@test let (val, idx) = findmax(x)
574+
max_val = maximum(x)
575+
val == max_val == x[idx]
576+
end
577+
@test let (val, idx) = findmin(x)
578+
min_val = minimum(x)
579+
val == min_val == x[idx]
580+
end
581+
@test x[argmax(x)] == maximum(x)
582+
@test x[argmin(x)] == minimum(x)
583+
@test let (val, idx) = findmax(inv, x)
584+
max_val = maximum(inv, x)
585+
val == max_val == inv(x[idx])
586+
end
587+
@test let (val, idx) = findmin(inv, x)
588+
min_val = minimum(inv, x)
589+
val == min_val == inv(x[idx])
590+
end
591+
@test inv(argmax(inv, x)) == maximum(inv, x)
592+
@test inv(argmin(inv, x)) == minimum(inv, x)
593+
end
594+
end
595+
596+
437597
@testset "iteration protocol on ChainedVector" begin
438598
for len in 0:6
439599
cv = ChainedVector([1:len])

0 commit comments

Comments
 (0)