Skip to content

Commit a86cb62

Browse files
authored
fix #39203, 2-arg findmax should return index instead of value (#41076)
1 parent 355b66a commit a86cb62

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

base/reduce.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -771,11 +771,11 @@ minimum(a; kw...) = mapreduce(identity, min, a; kw...)
771771
## findmax, findmin, argmax & argmin
772772

773773
"""
774-
findmax(f, domain) -> (f(x), x)
774+
findmax(f, domain) -> (f(x), index)
775775
776-
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
777-
value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there
778-
are multiple maximal points, then the first one will be returned.
776+
Returns a pair of a value in the codomain (outputs of `f`) and the index of
777+
the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is maximised.
778+
If there are multiple maximal points, then the first one will be returned.
779779
780780
`domain` must be a non-empty iterable.
781781
@@ -788,20 +788,20 @@ Values are compared with `isless`.
788788
789789
```jldoctest
790790
julia> findmax(identity, 5:9)
791-
(9, 9)
791+
(9, 5)
792792
793793
julia> findmax(-, 1:10)
794794
(-1, 1)
795795
796-
julia> findmax(first, [(1, :a), (2, :b), (2, :c)])
797-
(2, (2, :b))
796+
julia> findmax(first, [(1, :a), (3, :b), (3, :c)])
797+
(3, 2)
798798
799799
julia> findmax(cos, 0:π/2:2π)
800-
(1.0, 0.0)
800+
(1.0, 1)
801801
```
802802
"""
803-
findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
804-
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)
803+
findmax(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmax, pairs(domain) )
804+
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)
805805

806806
"""
807807
findmax(itr) -> (x, index)
@@ -826,14 +826,14 @@ julia> findmax([1, 7, 7, NaN])
826826
```
827827
"""
828828
findmax(itr) = _findmax(itr, :)
829-
_findmax(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmax, pairs(a) )
829+
_findmax(a, ::Colon) = findmax(identity, a)
830830

831831
"""
832-
findmin(f, domain) -> (f(x), x)
832+
findmin(f, domain) -> (f(x), index)
833833
834-
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
835-
value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there
836-
are multiple minimal points, then the first one will be returned.
834+
Returns a pair of a value in the codomain (outputs of `f`) and the index of
835+
the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is minimised.
836+
If there are multiple minimal points, then the first one will be returned.
837837
838838
`domain` must be a non-empty iterable.
839839
@@ -846,21 +846,21 @@ are multiple minimal points, then the first one will be returned.
846846
847847
```jldoctest
848848
julia> findmin(identity, 5:9)
849-
(5, 5)
849+
(5, 1)
850850
851851
julia> findmin(-, 1:10)
852852
(-10, 10)
853853
854-
julia> findmin(first, [(1, :a), (1, :b), (2, :c)])
855-
(1, (1, :a))
854+
julia> findmin(first, [(2, :a), (2, :b), (3, :c)])
855+
(2, 1)
856856
857857
julia> findmin(cos, 0:π/2:2π)
858-
(-1.0, 3.141592653589793)
858+
(-1.0, 3)
859859
```
860860
861861
"""
862-
findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
863-
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)
862+
findmin(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmin, pairs(domain) )
863+
_rf_findmin((fm, im), (fx, ix)) = isgreater(fm, fx) ? (fx, ix) : (fm, im)
864864

865865
"""
866866
findmin(itr) -> (x, index)
@@ -885,7 +885,7 @@ julia> findmin([1, 7, 7, NaN])
885885
```
886886
"""
887887
findmin(itr) = _findmin(itr, :)
888-
_findmin(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmin, pairs(a) )
888+
_findmin(a, ::Colon) = findmin(identity, a)
889889

890890
"""
891891
argmax(f, domain)
@@ -909,7 +909,7 @@ julia> argmax(cos, 0:π/2:2π)
909909
0.0
910910
```
911911
"""
912-
argmax(f, domain) = findmax(f, domain)[2]
912+
argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]
913913

914914
"""
915915
argmax(itr)
@@ -962,7 +962,7 @@ julia> argmin(acos, 0:0.1:1)
962962
1.0
963963
```
964964
"""
965-
argmin(f, domain) = findmin(f, domain)[2]
965+
argmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)[2]
966966

967967
"""
968968
argmin(itr)

test/reduce.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,22 +391,22 @@ end
391391

392392
@testset "findmin(f, domain)" begin
393393
@test findmin(-, 1:10) == (-10, 10)
394-
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
395-
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
396-
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
397-
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
398-
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
399-
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
394+
@test findmin(identity, [1, 2, 3, missing]) === (missing, 4)
395+
@test findmin(identity, [1, NaN, 3, missing]) === (missing, 4)
396+
@test findmin(identity, [1, missing, NaN, 3]) === (missing, 2)
397+
@test findmin(identity, [1, NaN, 3]) === (NaN, 2)
398+
@test findmin(identity, [1, 3, NaN]) === (NaN, 3)
399+
@test findmin(cos, 0:π/2:2π) == (-1.0, 3)
400400
end
401401

402402
@testset "findmax(f, domain)" begin
403403
@test findmax(-, 1:10) == (-1, 1)
404-
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
405-
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
406-
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
407-
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
408-
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
409-
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
404+
@test findmax(identity, [1, 2, 3, missing]) === (missing, 4)
405+
@test findmax(identity, [1, NaN, 3, missing]) === (missing, 4)
406+
@test findmax(identity, [1, missing, NaN, 3]) === (missing, 2)
407+
@test findmax(identity, [1, NaN, 3]) === (NaN, 2)
408+
@test findmax(identity, [1, 3, NaN]) === (NaN, 3)
409+
@test findmax(cos, 0:π/2:2π) == (1.0, 1)
410410
end
411411

412412
@testset "argmin(f, domain)" begin

0 commit comments

Comments
 (0)