-
Notifications
You must be signed in to change notification settings - Fork 194
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This generic method takes iterators of vectors and supports skipping missing values. It is a more general version of `pairwise` in Distances.jl. Since methods are compatible, both packages can override a common empty function defined in another API package.
- Loading branch information
Showing
4 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,5 @@ levelsmap | |
indexmap | ||
indicatormat | ||
StatsBase.midpoints | ||
pairwise | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Fill `res` with `f(x[i], y[j])` for every row `i` and column `j`, handing the
# vectors to `f` untouched (no special treatment of `missing` entries).
# When `symmetric` is true, only the cells on or above the diagonal are computed;
# the cells below it are then copied from their mirror image.
# Returns `res`.
function _pairwise!(::Val{:none}, res::AbstractMatrix, f, x, y, symmetric::Bool)
    nrows, ncols = size(res)
    for col in 1:ncols
        # In symmetric mode rows below the diagonal are left to the mirroring pass
        lastrow = symmetric ? min(col, nrows) : nrows
        for row in 1:lastrow
            xi, yj = x[row], y[col]
            # For performance, the diagonal of a correlation matrix is special-cased
            if f === cor && row == col && xi === yj
                if isconcretetype(eltype(res))
                    res[row, col] = 1
                else
                    # With a non-concrete eltype, a literal 1 may not be converted to
                    # the right type and the final matrix would keep an abstract
                    # eltype, so the unit is derived from the actual result
                    # (missings are propagated through this branch, NaNs are ignored)
                    res[row, col] = one(f(xi, yj))
                end
            else
                res[row, col] = f(xi, yj)
            end
        end
    end
    if symmetric
        # Mirror the computed upper triangle into the lower one
        for col in 1:ncols
            for row in (col + 1):nrows
                res[row, col] = res[col, row]
            end
        end
    end
    return res
end
|
||
# Fill `res` with `f` applied to each pair `(x[i], y[j])` after dropping, for that
# pair only, the entries that are `missing` in either of the two vectors.
# When `symmetric` is true, only the cells on or above the diagonal are computed;
# the cells below it are then copied from their mirror image.
# Returns `res`.
function _pairwise!(::Val{:pairwise}, res::AbstractMatrix, f, x, y, symmetric::Bool)
    nrows, ncols = size(res)
    for col in 1:ncols
        yj = y[col]
        # Mask of the non-missing entries of the column vector, reused for every row
        ykeep = .!ismissing.(yj)
        lastrow = symmetric ? min(col, nrows) : nrows
        for row in 1:lastrow
            xi = x[row]
            if xi === yj
                # Same vector on both sides: its own mask is all that is needed
                ysel = view(yj, ykeep)
                if !(f === cor && row == col)
                    res[row, col] = f(ysel, ysel)
                elseif isconcretetype(eltype(res))
                    # For performance, the diagonal of a correlation matrix is
                    # special-cased (missings and NaNs are ignored)
                    res[row, col] = 1
                else
                    # With a non-concrete eltype, a literal 1 may not be converted to
                    # the right type and the final matrix would keep an abstract eltype
                    res[row, col] = one(f(ysel, ysel))
                end
            else
                # Keep only the positions observed in both vectors
                keep = .!ismissing.(xi) .& ykeep
                res[row, col] = f(view(xi, keep), view(yj, keep))
            end
        end
    end
    if symmetric
        # Mirror the computed upper triangle into the lower one
        for col in 1:ncols
            for row in (col + 1):nrows
                res[row, col] = res[col, row]
            end
        end
    end
    return res
end
|
||
# Fill `res` with `f` applied to each pair `(x[i], y[j])` after dropping the
# positions at which any vector of `x` or `y` holds a `missing` value.
# The filtered vectors are then handed to the `Val(:none)` method, whose return
# value is returned.
function _pairwise!(::Val{:listwise}, res::AbstractMatrix, f, x, y, symmetric::Bool)
    nrows, ncols = size(res)

    # Accumulate the mask of positions observed in every vector
    complete = .!ismissing.(x[1])
    for row in 2:nrows
        complete .= complete .& .!ismissing.(x[row])
    end
    # When `y` is the very same collection, its vectors have already been scanned
    if x !== y
        for col in 1:ncols
            complete .= complete .& .!ismissing.(y[col])
        end
    end

    # Computing integer indices once for all vectors is faster
    kept = findall(complete)
    # TODO: check whether wrapping views in a custom array type which asserts
    # that entries cannot be `missing` (similar to `skipmissing`)
    # could offer better performance
    xkept = [view(xi, kept) for xi in x]
    ykept = [view(yi, kept) for yi in y]
    return _pairwise!(Val(:none), res, f, xkept, ykept, symmetric)
end
|
||
# Shared implementation behind `pairwise`: validates the inputs, allocates the
# result matrix and fills it with the `_pairwise!` method selected by `skipmissing`
# (`:none`, `:pairwise` or `:listwise`; any other value throws an `ArgumentError`).
#
# `x` and `y` are iterators of vectors. Anything that is not an `AbstractArray`,
# `Tuple` or `NamedTuple` is collected exactly once, up front, and every later
# step (index checks, sizes, computation) works on the collected values. This
# keeps iterators usable when they do not define `length` or cannot be traversed
# several times with the same outcome. When `y` is the same object as `x`, the
# collected `x` is reused for `y` so both sides hold the very same vectors.
#
# Throws an `ArgumentError` if `symmetric` is `true` while `x !== y`, or if the
# input vectors do not all have the same indices.
function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing}
    x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
    y′ = if y === x
        x′
    elseif y isa Union{AbstractArray, Tuple, NamedTuple}
        y
    else
        collect(y)
    end

    inds = keys(first(x′))
    if symmetric && x !== y
        throw(ArgumentError("symmetric=true only makes sense passing " *
                            "a single set of variables (x === y)"))
    end
    for xi in x′
        keys(xi) == inds ||
            throw(ArgumentError("All input vectors must have the same indices"))
    end
    for yi in y′
        keys(yi) == inds ||
            throw(ArgumentError("All input vectors must have the same indices"))
    end
    m = length(x′)
    n = length(y′)

    # Element type when vectors are passed as they are, and when missing values
    # have been skipped beforehand
    T = Core.Compiler.return_type(f, Tuple{eltype(x′), eltype(y′)})
    Tsm = Core.Compiler.return_type((a, b) -> f(disallowmissing(a), disallowmissing(b)),
                                    Tuple{eltype(x′), eltype(y′)})

    if skipmissing === :none
        res = Matrix{T}(undef, m, n)
        _pairwise!(Val(:none), res, f, x′, y′, symmetric)
    elseif skipmissing === :pairwise
        res = Matrix{Tsm}(undef, m, n)
        _pairwise!(Val(:pairwise), res, f, x′, y′, symmetric)
    elseif skipmissing === :listwise
        res = Matrix{Tsm}(undef, m, n)
        _pairwise!(Val(:listwise), res, f, x′, y′, symmetric)
    else
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise"))
    end

    # identity.(res) lets broadcasting compute a concrete element type
    # TODO: using promote_type rather than typejoin (which broadcast uses) would make sense
    # Once identity.(res) is inferred automatically (JuliaLang/julia#30485),
    # the assertion can be removed
    @static if VERSION >= v"1.6.0-DEV"
        U = Base.Broadcast.promote_typejoin_union(Union{T, Tsm})
        return (isconcretetype(eltype(res)) ? res : identity.(res))::Matrix{<:U}
    else
        return (isconcretetype(eltype(res)) ? res : identity.(res))
    end
end
|
||
"""
    pairwise(f, x[, y]; symmetric::Bool=false, skipmissing::Symbol=:none)

Return a matrix holding the result of applying `f` to all possible pairs
of vectors in iterators `x` and `y`. Rows correspond to
vectors in `x` and columns to vectors in `y`. If `y` is omitted then a
square matrix crossing `x` with itself is returned.

As a special case, if `f` is `cor`, diagonal cells are set to 1 even in
the presence of `NaN` or `Inf` entries (but `missing` is propagated unless
`skipmissing` is different from `:none`).

# Keyword arguments
- `symmetric::Bool=false`: If `true`, `f` is only called to compute
  the upper triangle of the matrix (diagonal included), and these values
  are copied to fill the lower triangle. Only possible when `y` is omitted
  or is the same object as `x`; an `ArgumentError` is thrown otherwise.
  This defaults to `true` when `f` is `cor` or `cov` and `y` is omitted.
- `skipmissing::Symbol=:none`: If `:none` (the default), missing values
  in input vectors are passed to `f` without any modification.
  Use `:pairwise` to skip entries with a `missing` value in either
  of the two vectors passed to `f` for a given pair of vectors in `x` and `y`.
  Use `:listwise` to skip entries with a `missing` value in any of the
  vectors in `x` or `y`; note that this is likely to drop a large part of
  entries. Any other value throws an `ArgumentError`.
"""
pairwise(f, x, y=x; symmetric::Bool=false, skipmissing::Symbol=:none) =
    _pairwise(Val(skipmissing), f, x, y, symmetric)
|
||
# Dedicated method for `cov`: since cov(x) is faster than cov(x, x), a vector
# paired with itself is routed to the single-argument form
function pairwise(::typeof(cov), x, y; symmetric::Bool=false, skipmissing::Symbol=:none)
    covpair(a, b) = a === b ? cov(a) : cov(a, b)
    return pairwise(covpair, x, y; symmetric=symmetric, skipmissing=skipmissing)
end
|
||
# Single-collection method for `cor`: crosses `x` with itself, with `symmetric`
# defaulting to `true`
function pairwise(::typeof(cor), x; symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(cor, x, x; symmetric=symmetric, skipmissing=skipmissing)
end
|
||
# Single-collection method for `cov`: crosses `x` with itself, with `symmetric`
# defaulting to `true`
function pairwise(::typeof(cov), x; symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(cov, x, x; symmetric=symmetric, skipmissing=skipmissing)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
using StatsBase | ||
using Test, Random, Statistics, LinearAlgebra | ||
using Missings | ||
|
||
# `≅` compares with `isequal`, under which `missing` and `NaN` equal themselves
const ≅ = isequal

# Fixed seed so that the random fixtures below are reproducible
Random.seed!(1)

# Plain wrapper around `cor`, to avoid using the specialized method
function arbitrary_fun(x, y)
    return cor(x, y)
end
|
||
# Every group below runs once per function: a plain wrapper (generic code path),
# `cor` and `cov` (specialized methods).
@testset "pairwise with $f" for f in (arbitrary_fun, cor, cov)
    @testset "basic interface" begin
        xs = [rand(10) for _ in 1:4]
        ys = [rand(Float32, 10) for _ in 1:5]
        # to test case where inference of returned eltype fails
        zs = [Vector{Any}(rand(Float32, 10)) for _ in 1:5]

        out = @inferred pairwise(f, xs, ys)
        @test out isa Matrix{Float64}
        @test out == [f(a, b) for a in xs, b in ys]

        out = pairwise(f, ys, zs)
        @test out isa Matrix{Float32}
        @test out == [f(a, b) for a in ys, b in zs]

        # Vectors of different precision held in an abstractly typed container
        out = pairwise(f, Any[[1.0, 2.0, 3.0], [1.0f0, 3.0f0, 10.5f0]])
        mixed = ([1.0, 2.0, 3.0], [1.0f0, 3.0f0, 10.5f0])
        @test out isa Matrix{AbstractFloat}
        @test out == [f(a, b) for a in mixed, b in mixed]
        @test typeof.(out) == [Float64 Float64; Float64 Float32]

        @inferred pairwise(f, xs, ys)

        @test_throws ArgumentError pairwise(f, [Int[]], [Int[]])
    end

    @testset "missing values handling interface" begin
        xmiss = [ifelse.(rand(100) .> 0.9, missing, rand(100)) for _ in 1:4]
        ymiss = [ifelse.(rand(100) .> 0.9, missing, rand(Float32, 100)) for _ in 1:4]
        zmiss = [ifelse.(rand(100) .> 0.9, missing, rand(Float32, 100)) for _ in 1:4]

        # Without skipping, every cell is expected to be `missing`
        out = pairwise(f, xmiss, ymiss)
        @test out isa Matrix{Missing}
        @test out ≅ [missing for a in xmiss, b in ymiss]

        out = pairwise(f, xmiss, ymiss, skipmissing=:pairwise)
        @test out isa Matrix{Float64}
        @test out ≈ [f(collect.(skipmissings(a, b))...) for a in xmiss, b in ymiss] rtol=1e-6

        out = pairwise(f, ymiss, zmiss, skipmissing=:pairwise)
        @test out isa Matrix{Float32}
        @test out ≈ [f(collect.(skipmissings(a, b))...) for a in ymiss, b in zmiss] rtol=1e-6

        # Positions at which none of the vectors holds a missing value
        complete = mapreduce(v -> .!ismissing.(v),
                             (l, r) -> l .& r,
                             [xmiss; ymiss])
        out = pairwise(f, xmiss, ymiss, skipmissing=:listwise)
        @test out isa Matrix{Float64}
        @test out ≈ [f(view(a, complete), view(b, complete)) for a in xmiss, b in ymiss] rtol=1e-6

        if VERSION >= v"1.6.0-DEV"
            # inference of cor fails so use an inferrable function
            # to check that pairwise itself is inferrable
            VecsM = Vector{Vector{Union{Float64, Missing}}}
            for sm in (:none, :pairwise, :listwise)
                g(x, y=x) = pairwise((x, y) -> x[1] * y[1], x, y, skipmissing=sm)
                @test Core.Compiler.return_type(g, Tuple{VecsM}) ==
                    Core.Compiler.return_type(g, Tuple{VecsM, VecsM}) ==
                    Matrix{<: Union{Float64, Missing}}
                if sm in (:pairwise, :listwise)
                    @test_broken Core.Compiler.return_type(g, Tuple{VecsM}) ==
                        Core.Compiler.return_type(g, Tuple{VecsM, VecsM}) ==
                        Matrix{Float64}
                end
            end
        end

        @test_throws ArgumentError pairwise(f, xmiss, ymiss, skipmissing=:something)

        # variable with only missings
        allmiss = [fill(missing, 10), rand(10)]
        observed = [rand(10), rand(10)]

        out = pairwise(f, allmiss, observed)
        @test out isa Matrix{Union{Float64, Missing}}
        @test out ≅ [f(a, b) for a in allmiss, b in observed]

        if VERSION >= v"1.5" # Fails with UndefVarError on Julia 1.0
            @test_throws ArgumentError pairwise(f, allmiss, observed, skipmissing=:pairwise)
            @test_throws ArgumentError pairwise(f, allmiss, observed, skipmissing=:listwise)
        end
    end

    @testset "iterators" begin
        # Generators must give the same result as their collected counterparts
        xgen = (v for v in [rand(10) for _ in 1:4])
        ygen = (v for v in [rand(10) for _ in 1:4])

        @test pairwise(f, xgen, ygen) == pairwise(f, collect(xgen), collect(ygen))
        @test pairwise(f, xgen) == pairwise(f, collect(xgen))
    end

    @testset "two-argument method" begin
        xs = [rand(10) for _ in 1:4]
        @test pairwise(f, xs) == pairwise(f, xs, xs)
    end

    @testset "symmetric" begin
        xs = [rand(10) for _ in 1:4]
        ys = [rand(10) for _ in 1:4]
        # The symmetric result must match the upper triangle of the full computation
        @test pairwise(f, xs, xs, symmetric=true) ==
            pairwise(f, xs, symmetric=true) ==
            Symmetric(pairwise(f, xs, xs), :U)
        @test_throws ArgumentError pairwise(f, xs, ys, symmetric=true)
    end

    @testset "cor corner cases" begin
        # Integer inputs must give a Float64 output
        ints = ([1, 2, 3], [1, 5, 2])
        out = pairwise(cor, [[1, 2, 3], [1, 5, 2]])
        @test out isa Matrix{Float64}
        @test out == [cor(a, b) for a in ints, b in ints]

        # NaNs are ignored for the diagonal
        out = pairwise(cor, [[1, 2, NaN], [1, 5, 2]])
        @test out isa Matrix{Float64}
        @test out ≅ [1.0 NaN; NaN 1.0]

        # missings are propagated even for the diagonal
        out = pairwise(cor, [[1, 2, 7], [1, 5, missing]])
        @test out isa Matrix{Union{Float64, Missing}}
        @test out ≅ [1.0 missing; missing missing]

        for sm in (:pairwise, :listwise)
            out = pairwise(cor, [[1, 2, NaN, 4], [1, 5, 5, missing]], skipmissing=sm)
            @test out isa Matrix{Float64}
            @test out ≅ [1.0 NaN; NaN 1.0]
        end
    end
end