diff --git a/Project.toml b/Project.toml index 3cfde3b..e51fa0c 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" [compat] -julia = "1" StatsAPI = "1" +julia = "1" [extras] OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" diff --git a/src/generic.jl b/src/generic.jl index 66f01c8..a74c8ef 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -43,6 +43,58 @@ _eltype(::Type{Union{Missing, T}}) where {T} = Union{Missing, T} __eltype(::Base.HasEltype, a) = _eltype(eltype(a)) __eltype(::Base.EltypeUnknown, a) = _eltype(typeof(first(a))) + +abstract type AbstractEvaluateStrategy end +struct Broadcasting <: AbstractEvaluateStrategy end +struct MapReduce1 <: AbstractEvaluateStrategy end + +""" + evaluate_strategy(d::PreMetric, a, b) + +Infer the optimal strategy to evaluate `d(a, b)`. + +Currently, two strategies are provided in Distances: + +- `Broadcasting`: evaluate each pair by broadcasting, this is usually performant for arrays + that has slow scalar indexing, e.g., `CUDA.CuArray`. But it introduces an extra memory + allocation due to large intermediate result. For `Euclidean`, this is almost equivalent + to `sqrt(sum(abs2, a - b))`. +- `MapReduce1`: use a single-thread version of mapreduce. This has minimal memory allocation. + For `Euclidean`, this is almost equivalent to `sqrt(mapreduce((x,y)->abs2(x-y), +, a, b))`. + +# Example + +This function is non-exported. Packages that provides custom array types can provide +specializations for this function and could implement their own evaluation strategy +for specific (pre)metric types. + +For example, these delegates distance evaluation for `CuArray` to the `Broadcasting` strategy. + +```julia +evaluate_strategy(d::Distances.UnionMetrics, a::CuArray, b::CuArray) = Distances.Broadcasting() +evaluate_strategy(d::Distances.UnionMetrics, a, b::CuArray) = Distances.Broadcasting() +evaluate_strategy(d::Distances.UnionMetrics, a::CuArray, b) = Distances.Broadcasting() +``` + +!!! note + Currently, only `Distances.UnionMetrics` respect the result of this function. + +Adding a new implementation strategy for `UnionMetrics` types can be done by adding new methods +to `Distances._evaluate`. For example, + +```julia +struct AnotherFancyStrategy <: Distances.AbstractEvaluateStrategy +function Distances._evaluate(::AnotherFancyStrategy, d::UnionMetrics, a, b, p) + # implementation details +end +``` + +The `_evaluate` function belongs to implementation detail that normal users shouldn't +call directly. But it is considered as a stable API so package developer can add new +strategy implementation to it. +""" +evaluate_strategy(::PreMetric, ::Any, ::Any) = MapReduce1() + # Generic column-wise evaluation """ diff --git a/src/metrics.jl b/src/metrics.jl index 693c76f..067711e 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -220,13 +220,27 @@ result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, ::Nothing) where {Ta,Tb} result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} = typeof(_evaluate(dist, oneunit(Ta), oneunit(Tb), oneunit(_eltype(p)))) -Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b) - _evaluate(d, a, b, parameters(d)) +Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p=parameters(d)) + _evaluate(evaluate_strategy(d, a, b), d, a, b, p) +end +for M in (metrics..., weightedmetrics...) + @eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b) end # breaks the implementation into eval_start, eval_op, eval_reduce and eval_end +function _evaluate(::Broadcasting, d::UnionMetrics, a, b, ::Nothing) + map_op(x,y) = eval_op(d, x, y) + reduce_op(x, y) = eval_reduce(d, x, y) + eval_end(d, reduce(reduce_op, map_op.(a, b); init=eval_start(d, a, b))) +end +function _evaluate(::Broadcasting, d::UnionMetrics, a, b, p) + map_op(x,y,p) = eval_op(d, x, y, p) + reduce_op(x, y) = eval_reduce(d, x, y) + eval_end(d, reduce(reduce_op, map_op.(a, b, p); init=eval_start(d, a, b))) +end +_evaluate(::AbstractEvaluateStrategy, d::UnionMetrics, a, b, p) = error("Not implemented.") -Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing) +Base.@propagate_inbounds function _evaluate(::MapReduce1, d::UnionMetrics, a, b, ::Nothing) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -239,7 +253,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing) end return eval_end(d, s) end -Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing) +Base.@propagate_inbounds function _evaluate(::MapReduce1, d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -263,7 +277,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b end end -Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p) +Base.@propagate_inbounds function _evaluate(::MapReduce1, d::UnionMetrics, a, b, p) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -279,7 +293,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p) end return eval_end(d, s) end -Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray) +Base.@propagate_inbounds function _evaluate(::MapReduce1, d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -308,8 +322,8 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b end end -_evaluate(dist::UnionMetrics, a::Number, b::Number, ::Nothing) = eval_end(dist, eval_op(dist, a, b)) -function _evaluate(dist::UnionMetrics, a::Number, b::Number, p) +_evaluate(::MapReduce1, dist::UnionMetrics, a::Number, b::Number, ::Nothing) = eval_end(dist, eval_op(dist, a, b)) +function _evaluate(::MapReduce1, dist::UnionMetrics, a::Number, b::Number, p) length(p) != 1 && throw(DimensionMismatch("inputs are scalars but parameters have length $(length(p)).")) eval_end(dist, eval_op(dist, a, b, first(p))) end @@ -324,10 +338,6 @@ _eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} = eval_reduce(::UnionMetrics, s1, s2) = s1 + s2 eval_end(::UnionMetrics, s) = s -for M in (metrics..., weightedmetrics...) - @eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b, parameters(dist)) -end - # Euclidean @inline eval_op(::Euclidean, ai, bi) = abs2(ai - bi) eval_end(::Euclidean, s) = sqrt(s) @@ -373,7 +383,14 @@ totalvariation(a, b) = TotalVariation()(a, b) @inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi) @inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2) # if only NaN, will output NaN -Base.@propagate_inbounds eval_start(::Chebyshev, a, b) = abs(first(a) - first(b)) +Base.@propagate_inbounds function eval_start(d::Chebyshev, a, b) + T = result_type(d, a, b) + if any(isnan, a) || any(isnan, b) + return convert(T, NaN) + else + zero(T) # lower bound of chebyshev distance + end +end chebyshev(a, b) = Chebyshev()(a, b) # Minkowski diff --git a/test/test_dists.jl b/test/test_dists.jl index fedb397..c1c6eda 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -23,7 +23,14 @@ end function test_metricity(dist, x, y, z) @testset "Test metricity of $(typeof(dist))" begin - @test dist(x, y) == evaluate(dist, x, y) + d = dist(x, y) + @test d == evaluate(dist, x, y) + if d isa Distances.UnionMetrics + # currently only UnionMetrics supports this strategy trait + d_vec = Distances._evaluate(Distances.Broadcasting(), dist, x, y, Distances.parameters(dist)) + d_scalar = Distances._evaluate(Distances.MapReduce1(), dist, x, y, Distances.parameters(dist)) + @test d_vec ≈ d_scalar + end dxy = dist(x, y) dxz = dist(x, z)