Skip to content

Commit

Permalink
Add pairwise (#627)
Browse files Browse the repository at this point in the history
This generic method takes iterators of vectors and supports skipping
missing values.
It is a more general version of `pairwise` in Distances.jl.
Since methods are compatible, both packages can override a common empty
function defined in StatsAPI.
  • Loading branch information
nalimilan authored May 2, 2021
1 parent 45d65ec commit d18762c
Show file tree
Hide file tree
Showing 6 changed files with 583 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
DataAPI = "1"
DataStructures = "0.10, 0.11, 0.12, 0.13, 0.14, 0.17, 0.18"
Missings = "0.3, 0.4, 1.0"
SortingAlgorithms = "0.3, 1.0"
StatsAPI = "1"
julia = "1"

[extras]
Expand Down
2 changes: 2 additions & 0 deletions docs/src/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ levelsmap
indexmap
indicatormat
StatsBase.midpoints
pairwise
pairwise!
```
4 changes: 4 additions & 0 deletions src/StatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import LinearAlgebra: BlasReal, BlasFloat
import Statistics: mean, mean!, var, varm, varm!, std, stdm, cov, covm,
cor, corm, cov2cor!, unscaled_covzm, quantile, sqrt!,
median, middle
import StatsAPI: pairwise, pairwise!

## tackle compatibility issues

Expand Down Expand Up @@ -157,6 +158,8 @@ export
indexmap, # construct a map from element to index
levelsmap, # construct a map from n unique elements to [1, ..., n]
indicatormat, # construct indicator matrix
pairwise, # pairwise application of functions
pairwise!, # in-place pairwise application of functions

# statistical models
CoefTable,
Expand Down Expand Up @@ -228,6 +231,7 @@ include("signalcorr.jl")
include("partialcor.jl")
include("empirical.jl")
include("hist.jl")
include("pairwise.jl")
include("misc.jl")

include("sampling.jl")
Expand Down
313 changes: 313 additions & 0 deletions src/pairwise.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
# Fill `dest` with `f(xi, yj)` for every pair of entries of `x` and `y`, handing
# the entries to `f` untouched (no special treatment of missing values).
#
# `dest[i, j]` holds the result for the `i`-th entry of `x` and the `j`-th entry
# of `y`. When `symmetric` is `true`, `f` is only evaluated for cells with
# `i <= j`; the remaining cells are then copied from their mirror `dest[j, i]`.
# Returns `dest`.
function _pairwise!(::Val{:none}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    @inbounds for (row, xrow) in enumerate(x)
        for (col, ycol) in enumerate(y)
            if symmetric && row > col
                continue
            end

            # For performance, the diagonal of a `cor` matrix is special-cased
            diagcor = f === cor && eltype(dest) !== Union{} &&
                      row == col && xrow === ycol
            # TODO: float() will not be needed after JuliaLang/Statistics.jl#61
            dest[row, col] = diagcor ? float(cor(xrow)) : f(xrow, ycol)
        end
    end
    if symmetric
        nrows, ncols = size(dest)
        # Mirror the computed cells (row <= col) into the cells below the diagonal
        @inbounds for col in 1:ncols
            for row in (col + 1):nrows
                dest[row, col] = dest[col, row]
            end
        end
    end
    return dest
end

"""
    check_vectors(x, y, skipmissing::Symbol)

Validate the entries of `x` and `y` before missing values are skipped.

Throw an `ArgumentError` if any entry of `x` or `y` is not an `AbstractVector`
(the message mentions the `skipmissing` mode), or if the entries of `x` and `y`
do not all share the same indices (as given by `keys`). Return `nothing` when
the inputs are valid, including when `x` or `y` is empty.
"""
function check_vectors(x, y, skipmissing::Symbol)
    if !(all(xi -> xi isa AbstractVector, x) && all(yi -> yi isa AbstractVector, y))
        throw(ArgumentError("All entries in x and y must be vectors " *
                            "when skipmissing=:$skipmissing"))
    end

    # Every vector is compared against one reference (the first entry found), so a
    # lone vector on one side is still checked against the vectors on the other
    # side. Entries are visited by iteration rather than by `x[i]`, which makes no
    # assumption about how the containers themselves are indexed.
    refinds = nothing
    for entries in (x, y), v in entries
        inds = keys(v)
        if refinds === nothing
            refinds = inds
        elseif !(inds == refinds)
            throw(ArgumentError("All input vectors must have the same indices"))
        end
    end
    return nothing
end

# Fill `dest` with `f` applied to every pair of vectors taken from `x` and `y`,
# dropping, separately for each pair, the positions at which either of the two
# vectors holds `missing`.
#
# Inputs are validated with `check_vectors` first. When `symmetric` is `true`,
# `f` is only evaluated for cells with `i <= j` and the remaining cells are
# copied from their mirror `dest[j, i]`. Returns `dest`.
function _pairwise!(::Val{:pairwise}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    check_vectors(x, y, :pairwise)
    @inbounds for (col, ycol) in enumerate(y)
        # Mask of non-missing positions in `ycol`, reused for every row
        ykeep = .!ismissing.(ycol)
        for (row, xrow) in enumerate(x)
            (symmetric && row > col) && continue

            if xrow !== ycol
                # Keep only positions where both vectors are non-missing
                keep = .!ismissing.(xrow) .& ykeep
                xsub = view(xrow, keep)
                ysub = view(ycol, keep)
                dest[row, col] = f(xsub, ysub)
                continue
            end

            # Same vector on both sides: one view is built and passed twice
            ysub = view(ycol, ykeep)
            # For performance, diagonal is special-cased
            if f === cor && eltype(dest) !== Union{} && row == col
                # TODO: float() will not be needed after JuliaLang/Statistics.jl#61
                dest[row, col] = float(cor(xrow))
            else
                dest[row, col] = f(ysub, ysub)
            end
        end
    end
    if symmetric
        nrows, ncols = size(dest)
        # Mirror the computed cells (row <= col) into the cells below the diagonal
        @inbounds for col in 1:ncols
            for row in (col + 1):nrows
                dest[row, col] = dest[col, row]
            end
        end
    end
    return dest
end

# Fill `dest` with `f` applied to every pair of vectors taken from `x` and `y`,
# after dropping every position at which *any* vector in `x` or `y` holds
# `missing` (listwise deletion).
#
# Inputs are validated with `check_vectors` first. The filtered views are then
# handed to the `Val(:none)` method, which does the actual filling (including
# the handling of `symmetric`). Returns `dest`.
function _pairwise!(::Val{:listwise}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    check_vectors(x, y, :listwise)

    # With no entries on one side there is no cell to fill; returning here also
    # avoids calling `first` on an empty collection, which would throw.
    (isempty(x) || isempty(y)) && return dest

    # Positions which are non-missing in all vectors seen so far
    nminds = .!ismissing.(first(x))
    @inbounds for xi in Iterators.drop(x, 1)
        nminds .&= .!ismissing.(xi)
    end
    # When `y` is the very same object as `x` its vectors were already scanned
    if x !== y
        @inbounds for yj in y
            nminds .&= .!ismissing.(yj)
        end
    end

    # Computing integer indices once for all vectors is faster
    nminds′ = findall(nminds)
    # TODO: check whether wrapping views in a custom array type which asserts
    # that entries cannot be `missing` (similar to `skipmissing`)
    # could offer better performance
    return _pairwise!(Val(:none), f, dest,
                      [view(xi, nminds′) for xi in x],
                      [view(yi, nminds′) for yi in y],
                      symmetric)
end

# Validate arguments and dispatch to the `_pairwise!` method implementing the
# requested `skipmissing` mode.
#
# `x` and `y` are used as-is when they are an `AbstractArray`, `Tuple` or
# `NamedTuple`, and are collected otherwise. Returns `dest`.
#
# Throws:
# - `ArgumentError` if `skipmissing` is not `:none`, `:pairwise` or `:listwise`,
#   or if the indices of `dest` do not start at 1;
# - `DimensionMismatch` if `dest` is not of size `(length(x), length(y))`.
function _pairwise!(f, dest::AbstractMatrix, x, y;
                    symmetric::Bool=false, skipmissing::Symbol=:none)
    if !(skipmissing in (:none, :pairwise, :listwise))
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise"))
    end

    x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
    # When `y` is the very same object as `x`, reuse `x′` instead of collecting a
    # second time: a single-pass iterator would already be exhausted, and keeping
    # `x′ === y′` lets the listwise method skip scanning the same vectors twice.
    y′ = if y === x
        x′
    elseif y isa Union{AbstractArray, Tuple, NamedTuple}
        y
    else
        collect(y)
    end
    m = length(x′)
    n = length(y′)

    size(dest) != (m, n) &&
        throw(DimensionMismatch("dest has dimensions $(size(dest)) but expected ($m, $n)"))

    # The filling loops index `dest` from 1, so offset axes are rejected
    Base.has_offset_axes(dest) &&
        throw(ArgumentError("dest indices must start at 1"))

    return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric)
end

# Allocate the result matrix for `pairwise` and fill it through `_pairwise!`.
#
# `x` and `y` are used as-is when they are an `AbstractArray`, `Tuple` or
# `NamedTuple`, and are collected otherwise. The element type of the result is
# taken from inference when it is concrete; otherwise it is narrowed to the
# promoted type of the values actually computed.
#
# Throws an `ArgumentError` if `skipmissing` is not `:none`, `:pairwise` or
# `:listwise`.
function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing}
    x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
    # When `y` is the very same object as `x`, reuse `x′` instead of collecting a
    # second time: a single-pass iterator would already be exhausted (giving an
    # empty `y′`), and `x′ === y′` keeps the identity of the two arguments.
    y′ = if y === x
        x′
    elseif y isa Union{AbstractArray, Tuple, NamedTuple}
        y
    else
        collect(y)
    end
    m = length(x′)
    n = length(y′)

    # Inferred return type of `f` on the entries as they are...
    T = Core.Compiler.return_type(f, Tuple{eltype(x′), eltype(y′)})
    # ...and on the entries once `disallowmissing` has been applied to them
    Tsm = Core.Compiler.return_type((xi, yj) -> f(disallowmissing(xi), disallowmissing(yj)),
                                    Tuple{eltype(x′), eltype(y′)})

    if skipmissing === :none
        dest = Matrix{T}(undef, m, n)
    elseif skipmissing in (:pairwise, :listwise)
        dest = Matrix{Tsm}(undef, m, n)
    else
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise"))
    end

    # Preserve inferred element type
    isempty(dest) && return dest

    _pairwise!(f, dest, x′, y′, symmetric=symmetric, skipmissing=skipmissing)

    if isconcretetype(eltype(dest))
        return dest
    else
        # Final eltype depends on actual contents (consistent with map and broadcast)
        U = mapreduce(typeof, promote_type, dest)
        # V is inferred (contrary to U), but it only gives an upper bound for U
        V = promote_type(T, Tsm)
        return convert(Matrix{U}, dest)::Matrix{<:V}
    end
end

"""
pairwise!(f, dest::AbstractMatrix, x[, y];
symmetric::Bool=false, skipmissing::Symbol=:none)
Store in matrix `dest` the result of applying `f` to all possible pairs
of entries in iterators `x` and `y`, and return it. Rows correspond to
entries in `x` and columns to entries in `y`, and `dest` must therefore
be of size `length(x) × length(y)`.
If `y` is omitted then `x` is crossed with itself.
As a special case, if `f` is `cor`, diagonal cells for which entries
from `x` and `y` are identical (according to `===`) are set to one even
in the presence `missing`, `NaN` or `Inf` entries.
# Keyword arguments
- `symmetric::Bool=false`: If `true`, `f` is only called to compute
for the lower triangle of the matrix, and these values are copied
to fill the upper triangle. Only allowed when `y` is omitted.
Defaults to `true` when `f` is `cor` or `cov`.
- `skipmissing::Symbol=:none`: If `:none` (the default), missing values
in inputs are passed to `f` without any modification.
Use `:pairwise` to skip entries with a `missing` value in either
of the two vectors passed to `f` for a given pair of vectors in `x` and `y`.
Use `:listwise` to skip entries with a `missing` value in any of the
vectors in `x` or `y`; note that this might drop a large part of entries.
Only allowed when entries in `x` and `y` are vectors.
# Examples
```jldoctest
julia> using StatsBase, Statistics
julia> dest = zeros(3, 3);
julia> x = [1 3 7
2 5 6
3 8 4
4 6 2];
julia> pairwise!(cor, dest, eachcol(x));
julia> dest
3×3 Matrix{Float64}:
1.0 0.744208 -0.989778
0.744208 1.0 -0.68605
-0.989778 -0.68605 1.0
julia> y = [1 3 missing
2 5 6
3 missing 2
4 6 2];
julia> pairwise!(cor, dest, eachcol(y), skipmissing=:pairwise);
julia> dest
3×3 Matrix{Float64}:
1.0 0.928571 -0.866025
0.928571 1.0 -1.0
-0.866025 -1.0 1.0
```
"""
function pairwise!(f, dest::AbstractMatrix, x, y=x;
symmetric::Bool=false, skipmissing::Symbol=:none)
if symmetric && x !== y
throw(ArgumentError("symmetric=true only makes sense passing " *
"a single set of variables (x === y)"))
end

return _pairwise!(f, dest, x, y, symmetric=symmetric, skipmissing=skipmissing)
end

"""
pairwise(f, x[, y];
symmetric::Bool=false, skipmissing::Symbol=:none)
Return a matrix holding the result of applying `f` to all possible pairs
of entries in iterators `x` and `y`. Rows correspond to
entries in `x` and columns to entries in `y`. If `y` is omitted then a
square matrix crossing `x` with itself is returned.
As a special case, if `f` is `cor`, diagonal cells for which entries
from `x` and `y` are identical (according to `===`) are set to one even
in the presence `missing`, `NaN` or `Inf` entries.
# Keyword arguments
- `symmetric::Bool=false`: If `true`, `f` is only called to compute
for the lower triangle of the matrix, and these values are copied
to fill the upper triangle. Only allowed when `y` is omitted.
Defaults to `true` when `f` is `cor` or `cov`.
- `skipmissing::Symbol=:none`: If `:none` (the default), missing values
in inputs are passed to `f` without any modification.
Use `:pairwise` to skip entries with a `missing` value in either
of the two vectors passed to `f` for a given pair of vectors in `x` and `y`.
Use `:listwise` to skip entries with a `missing` value in any of the
vectors in `x` or `y`; note that this might drop a large part of entries.
Only allowed when entries in `x` and `y` are vectors.
# Examples
```jldoctest
julia> using StatsBase, Statistics
julia> x = [1 3 7
2 5 6
3 8 4
4 6 2];
julia> pairwise(cor, eachcol(x))
3×3 Matrix{Float64}:
1.0 0.744208 -0.989778
0.744208 1.0 -0.68605
-0.989778 -0.68605 1.0
julia> y = [1 3 missing
2 5 6
3 missing 2
4 6 2];
julia> pairwise(cor, eachcol(y), skipmissing=:pairwise)
3×3 Matrix{Float64}:
1.0 0.928571 -0.866025
0.928571 1.0 -1.0
-0.866025 -1.0 1.0
```
"""
function pairwise(f, x, y=x; symmetric::Bool=false, skipmissing::Symbol=:none)
if symmetric && x !== y
throw(ArgumentError("symmetric=true only makes sense passing " *
"a single set of variables (x === y)"))
end

return _pairwise(Val(skipmissing), f, x, y, symmetric)
end

# Covariance helper used by the `cov` methods of `pairwise`/`pairwise!`:
# when both arguments are the same object, the one-argument `cov(x)` is used
# since it is faster than `cov(x, x)`.
function _cov(x, y)
    x === y && return cov(x)
    return cov(x, y)
end

# `cov` methods: the work is delegated to the generic methods with `_cov` in
# place of `cov`. When `y` is omitted, `x` is crossed with itself and
# `symmetric` defaults to `true`.
function pairwise!(::typeof(cov), dest::AbstractMatrix, x, y;
                   symmetric::Bool=false, skipmissing::Symbol=:none)
    return pairwise!(_cov, dest, x, y; symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cov), x, y;
                  symmetric::Bool=false, skipmissing::Symbol=:none)
    return pairwise(_cov, x, y; symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise!(::typeof(cov), dest::AbstractMatrix, x;
                   symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise!(_cov, dest, x, x; symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cov), x;
                  symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(_cov, x, x; symmetric=symmetric, skipmissing=skipmissing)
end

# `cor` methods for a single set of variables: `x` is crossed with itself and
# `symmetric` defaults to `true`.
function pairwise!(::typeof(cor), dest::AbstractMatrix, x;
                   symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise!(cor, dest, x, x; symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cor), x;
                  symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(cor, x, x; symmetric=symmetric, skipmissing=skipmissing)
end
Loading

0 comments on commit d18762c

Please sign in to comment.