diff --git a/src/scalarstats.jl b/src/scalarstats.jl
index 924d681cf..1e4406c8a 100644
--- a/src/scalarstats.jl
+++ b/src/scalarstats.jl
@@ -261,36 +261,101 @@ realXcY(x::Real, y::Real) = x*y
 realXcY(x::Complex, y::Complex) = real(x)*real(y) + imag(x)*imag(y)
 
 """
-    sem(x)
+    sem(x; mean=nothing)
+    sem(x::AbstractArray[, weights::AbstractWeights]; mean=nothing)
 
-Return the standard error of the mean of collection `x`,
-i.e. `sqrt(var(x, corrected=true) / length(x))`.
+Return the standard error of the mean for a collection `x`.
+A pre-computed `mean` may be provided.
+
+When not using weights, this is the (sample) standard deviation
+divided by the square root of the sample size. If weights are used, the
+variance of the sample mean is calculated as follows:
+
+* `AnalyticWeights`: Not implemented.
+* `FrequencyWeights`: ``\\frac{\\sum_{i=1}^n w_i (x_i - \\bar{x})^2}{(\\sum w_i) (\\sum w_i - 1)}``
+* `ProbabilityWeights`: ``\\frac{n}{n-1} \\frac{\\sum_{i=1}^n w_i^2 (x_i - \\bar{x})^2}{\\left( \\sum w_i \\right)^2}``
+
+The standard error is then the square root of the above quantities.
+
+# References
+
+Carl-Erik Särndal, Bengt Swensson, Jan Wretman (1992). Model Assisted Survey Sampling.
+New York: Springer. pp. 51-53.
 """
-function sem(x)
-    y = iterate(x)
-    if y === nothing
+function sem(x; mean=nothing)
+    if isempty(x)
+        # Return the NaN of the type that we would get for a nonempty x
         T = eltype(x)
-        # Return the NaN of the type that we would get, had this collection
-        # contained any elements (this is consistent with std)
-        return oftype(sqrt((abs2(zero(T)) + abs2(zero(T)))/2), NaN)
-    end
-    count = 1
-    value, state = y
-    y = iterate(x, state)
-    # Use Welford algorithm as seen in (among other places)
-    # Knuth's TAOCP, Vol 2, page 232, 3rd edition.
-    M = value / 1
-    S = real(zero(M))
-    while y !== nothing
+        _mean = mean === nothing ? zero(T) / 1 : mean
+        z = abs2(zero(T) - _mean)
+        return oftype((z + z) / 2, NaN)
+    elseif mean === nothing
+        n = 0
+        y = iterate(x)
+        value, state = y
+        # Use Welford algorithm as seen in (among other places)
+        # Knuth's TAOCP, Vol 2, page 232, 3rd edition.
+        _mean = value / 1
+        sse = real(zero(_mean))
+        while y !== nothing
+            value, state = y
+            y = iterate(x, state)
+            n += 1
+            new_mean = _mean + (value - _mean) / n
+            sse += realXcY(value - _mean, value - new_mean)
+            _mean = new_mean
+        end
+    else
+        n = 1
+        y = iterate(x)
         value, state = y
-        y = iterate(x, state)
-        count += 1
-        new_M = M + (value - M) / count
-        S = S + realXcY(value - M, value - new_M)
-        M = new_M
+        sse = abs2(value - mean)
+        while (y = iterate(x, state)) !== nothing
+            value, state = y
+            n += 1
+            sse += abs2(value - mean)
+        end
+    end
+    variance = sse / (n - 1)
+    return sqrt(variance / n)
+end
+
+function sem(x::AbstractArray; mean=nothing)
+    if isempty(x)
+        # Return the NaN of the type that we would get for a nonempty x
+        T = eltype(x)
+        _mean = mean === nothing ? zero(T) / 1 : mean
+        z = abs2(zero(T) - _mean)
+        return oftype((z + z) / 2, NaN)
+    end
+    return sqrt(var(x; mean=mean, corrected=true) / length(x))
+end
+
+function sem(x::AbstractArray, weights::UnitWeights; mean=nothing)
+    if length(x) ≠ length(weights)
+        throw(DimensionMismatch("array and weights do not have the same length"))
+    end
+    return sem(x; mean=mean)
+end
+
+
+# Weighted methods for the above
+sem(x::AbstractArray, weights::FrequencyWeights; mean=nothing) =
+    sqrt(var(x, weights; mean=mean, corrected=true) / sum(weights))
+
+function sem(x::AbstractArray, weights::ProbabilityWeights; mean=nothing)
+    if isempty(x)
+        # Return the NaN of the type that we would get for a nonempty x
+        return var(x, weights; mean=mean, corrected=true) / 0
+    else
+        _mean = mean === nothing ? Statistics.mean(x, weights) : mean
+        # sum of squared errors = sse
+        sse = sum(Broadcast.instantiate(Broadcast.broadcasted(x, weights) do x_i, w
+            return abs2(w * (x_i - _mean))
+        end))
+        n = count(!iszero, weights)
+        return sqrt(sse * n / (n - 1)) / sum(weights)
     end
-    var = S / (count - 1)
-    return sqrt(var/count)
 end
 
 # Median absolute deviation
diff --git a/test/scalarstats.jl b/test/scalarstats.jl
index 6b561a130..4d5054918 100644
--- a/test/scalarstats.jl
+++ b/test/scalarstats.jl
@@ -107,10 +107,36 @@ z2 = [8. 2. 3. 1.; 24. 10. -1. -1.; 20. 12. 1. -2.]
 @test variation([1:5;]) ≈ 0.527046276694730
 @test variation(skipmissing([missing; 1:5; missing])) ≈ 0.527046276694730
 
-@test sem([1:5;]) ≈ 0.707106781186548
-@test sem(skipmissing([missing; 1:5; missing])) ≈ 0.707106781186548
-@test sem(Int[]) === NaN
-@test sem(skipmissing(Union{Int,Missing}[missing, missing])) === NaN
+@test @inferred(sem([1:5;])) ≈ 0.707106781186548
+@test @inferred(sem(skipmissing([missing; 1:5; missing]))) ≈ 0.707106781186548
+@test @inferred(sem(skipmissing([missing; 1:5; missing]), mean=3.0)) ≈ 0.707106781186548
+@test @inferred(sem([1:5;], UnitWeights{Int}(5))) ≈ 0.707106781186548
+@test @inferred(sem([1:5;], UnitWeights{Int}(5); mean=mean(1:5))) ≈ 0.707106781186548
+@test_throws DimensionMismatch sem(1:5, UnitWeights{Int}(4))
+@test @inferred(sem([1:5;], ProbabilityWeights([1:5;]))) ≈ 0.6166 rtol=.001
+μ = mean(1:5, ProbabilityWeights([1:5;]))
+@test @inferred(sem([1:5;], ProbabilityWeights([1:5;]); mean=μ)) ≈ 0.6166 rtol=.001
+@test @inferred(sem([10; 1:5;], ProbabilityWeights([0; 1:5;]); mean=μ)) ≈ 0.6166 rtol=.001
+x = sort!(vcat([5:-1:i for i in 1:5]...))
+μ = mean(x)
+@test @inferred(sem([1:5;], FrequencyWeights([1:5;]))) ≈ sem(x)
+@test @inferred(sem([1:5;], FrequencyWeights([1:5;]); mean=μ)) ≈ sem(x)
+
+@inferred sem([1:5f0;]; mean=μ) ≈ sem(x)
+@inferred sem([1:5f0;], ProbabilityWeights([1:5;]); mean=μ)
+@inferred sem([1:5f0;], FrequencyWeights([1:5;]); mean=μ)
+# Broken: Bug to do with Statistics.jl's implementation of `var`
+# @inferred sem([1:5f0;], UnitWeights{Int}(5); mean=μ)
+
+@test @inferred(isnan(sem(Int[])))
+@test @inferred(isnan(sem(Int[], FrequencyWeights(Int[]))))
+@test @inferred(isnan(sem(Int[], ProbabilityWeights(Int[]))))
+
+@test @inferred(isnan(sem(Int[]; mean=0f0)))
+@test @inferred(isnan(sem(Int[], FrequencyWeights(Int[]); mean=0f0)))
+@test @inferred(isnan(sem(Int[], ProbabilityWeights(Int[]); mean=0f0)))
+
+@test @inferred(isnan(sem(skipmissing(Union{Int,Missing}[missing, missing]))))
 @test_throws MethodError sem(Any[])
 @test_throws MethodError sem(skipmissing([missing]))
 