Skip to content

Commit

Permalink
Weighted sem (#754)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParadaCarleton authored Feb 6, 2022
1 parent 7fca6e8 commit afd0bb7
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 29 deletions.
115 changes: 90 additions & 25 deletions src/scalarstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,36 +421,101 @@ realXcY(x::Real, y::Real) = x*y
# Complex case: real part of `x * conj(y)`, i.e. the sum of the products of
# the real parts and of the imaginary parts.
function realXcY(x::Complex, y::Complex)
    real_term = real(x) * real(y)
    imag_term = imag(x) * imag(y)
    return real_term + imag_term
end

"""
    sem(x; mean=nothing)
    sem(x::AbstractArray[, weights::AbstractWeights]; mean=nothing)

Return the standard error of the mean for a collection `x`.
A pre-computed `mean` may be provided.

When not using weights, this is the (sample) standard deviation
divided by the square root of the sample size, i.e.
`sqrt(var(x, corrected=true) / length(x))`. If weights are used, the
variance of the sample mean is calculated as follows:
* `AnalyticWeights`: Not implemented.
* `FrequencyWeights`: ``\\frac{\\sum_{i=1}^n w_i (x_i - \\bar{x})^2}{(\\sum w_i) (\\sum w_i - 1)}``
* `ProbabilityWeights`: ``\\frac{n}{n-1} \\frac{\\sum_{i=1}^n w_i^2 (x_i - \\bar{x})^2}{\\left( \\sum w_i \\right)^2}``
The standard error is then the square root of the above quantities.
# References
Carl-Erik Särndal, Bengt Swensson, Jan Wretman (1992). Model Assisted Survey Sampling.
New York: Springer. pp. 51-53.
"""
function sem(x; mean=nothing)
    if isempty(x)
        # Return the NaN of the type that we would get for a nonempty x
        T = eltype(x)
        _mean = mean === nothing ? zero(T) / 1 : mean
        z = abs2(zero(T) - _mean)
        return oftype((z + z) / 2, NaN)
    elseif mean === nothing
        # No mean supplied: accumulate the running mean and the sum of squared
        # errors (`sse`) in a single pass over the iterator.
        n = 0
        y = iterate(x)
        value, state = y
        # Use Welford algorithm as seen in (among other places)
        # Knuth's TAOCP, Vol 2, page 232, 3rd edition.
        _mean = value / 1
        sse = real(zero(_mean))
        while y !== nothing
            value, state = y
            y = iterate(x, state)
            n += 1
            new_mean = _mean + (value - _mean) / n
            sse += realXcY(value - _mean, value - new_mean)
            _mean = new_mean
        end
    else
        # Mean supplied by the caller: sum the squared deviations from it
        # directly, counting elements as we go.
        n = 1
        y = iterate(x)
        value, state = y
        sse = abs2(value - mean)
        while (y = iterate(x, state)) !== nothing
            value, state = y
            n += 1
            sse += abs2(value - mean)
        end
    end
    # Corrected (n - 1) sample variance, then the standard error of the mean.
    variance = sse / (n - 1)
    return sqrt(variance / n)
end

# Array method: the standard error is the square root of the corrected variance
# divided by the number of elements; `var` receives the optional pre-computed mean.
function sem(x::AbstractArray; mean=nothing)
    if !isempty(x)
        return sqrt(var(x; mean=mean, corrected=true) / length(x))
    end
    # Empty input: produce a NaN of the type a nonempty `x` would have yielded.
    elty = eltype(x)
    center = mean === nothing ? zero(elty) / 1 : mean
    sq = abs2(zero(elty) - center)
    return oftype((sq + sq) / 2, NaN)
end

# `UnitWeights` method: the weights are only used for a length check; the result
# is the unweighted `sem` of `x`.
# Throws `DimensionMismatch` when `x` and `weights` differ in length.
function sem(x::AbstractArray, weights::UnitWeights; mean=nothing)
    if length(x) != length(weights)
        throw(DimensionMismatch("array and weights do not have the same length"))
    end
    return sem(x; mean=mean)
end


# Weighted methods for the above
# `FrequencyWeights` method: corrected weighted variance divided by the total
# weight, then square-rooted.
function sem(x::AbstractArray, weights::FrequencyWeights; mean=nothing)
    weighted_var = var(x, weights; mean=mean, corrected=true)
    return sqrt(weighted_var / sum(weights))
end

# `ProbabilityWeights` method: sqrt(n / (n - 1) * sum((w_i * (x_i - mean))^2)) / sum(w),
# where `n` is the number of nonzero weights.
function sem(x::AbstractArray, weights::ProbabilityWeights; mean=nothing)
    if isempty(x)
        # Return the NaN of the type that we would get for a nonempty x
        return var(x, weights; mean=mean, corrected=true) / 0
    else
        # The `mean` keyword shadows the function of the same name, so the
        # weighted mean is requested through its module-qualified name.
        _mean = mean === nothing ? Statistics.mean(x, weights) : mean
        # sum of squared errors = sse, reduced over a `Broadcast.broadcasted`
        # object rather than a dotted (materializing) broadcast
        sse = sum(Broadcast.instantiate(Broadcast.broadcasted(x, weights) do x_i, w
            return abs2(w * (x_i - _mean))
        end))
        # Only observations with a nonzero weight count towards the sample size.
        n = count(!iszero, weights)
        return sqrt(sse * n / (n - 1)) / sum(weights)
    end
end

# Median absolute deviation
Expand Down
34 changes: 30 additions & 4 deletions test/scalarstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,36 @@ z2 = [8. 2. 3. 1.; 24. 10. -1. -1.; 20. 12. 1. -2.]
# Coefficient of variation, with and without missing values skipped.
@test variation([1:5;]) ≈ 0.527046276694730
@test variation(skipmissing([missing; 1:5; missing])) ≈ 0.527046276694730

# Unweighted standard error of the mean.
@test sem([1:5;]) ≈ 0.707106781186548
@test sem(skipmissing([missing; 1:5; missing])) ≈ 0.707106781186548
@test sem(Int[]) === NaN
@test sem(skipmissing(Union{Int,Missing}[missing, missing])) === NaN
@test @inferred(sem([1:5;])) ≈ 0.707106781186548
@test @inferred(sem(skipmissing([missing; 1:5; missing]))) ≈ 0.707106781186548
@test @inferred(sem(skipmissing([missing; 1:5; missing]), mean=3.0)) ≈ 0.707106781186548

# Unit weights must match the unweighted result and validate lengths.
@test @inferred(sem([1:5;], UnitWeights{Int}(5))) ≈ 0.707106781186548
@test @inferred(sem([1:5;], UnitWeights{Int}(5); mean=mean(1:5))) ≈ 0.707106781186548
@test_throws DimensionMismatch sem(1:5, UnitWeights{Int}(4))

# Probability weights; a zero-weight observation must not change the result.
@test @inferred(sem([1:5;], ProbabilityWeights([1:5;]))) ≈ 0.6166 rtol=.001
μ = mean(1:5, ProbabilityWeights([1:5;]))
@test @inferred(sem([1:5;], ProbabilityWeights([1:5;]); mean=μ)) ≈ 0.6166 rtol=.001
@test @inferred(sem([10; 1:5;], ProbabilityWeights([0; 1:5;]); mean=μ)) ≈ 0.6166 rtol=.001

# Frequency weights must agree with the explicitly repeated sample.
x = sort!(vcat([5:-1:i for i in 1:5]...))
μ = mean(x)
@test @inferred(sem([1:5;], FrequencyWeights([1:5;]))) ≈ sem(x)
@test @inferred(sem([1:5;], FrequencyWeights([1:5;]); mean=μ)) ≈ sem(x)

# Inference checks with a Float32 array and a Float64 mean.
@inferred sem([1:5f0;]; mean=μ) ≈ sem(x)
@inferred sem([1:5f0;], ProbabilityWeights([1:5;]); mean=μ)
@inferred sem([1:5f0;], FrequencyWeights([1:5;]); mean=μ)
# Broken: Bug to do with Statistics.jl's implementation of `var`
# @inferred sem([1:5f0;], UnitWeights{Int}(5); mean=μ)

# Empty inputs yield NaN, with or without a pre-computed mean.
@test @inferred(isnan(sem(Int[])))
@test @inferred(isnan(sem(Int[], FrequencyWeights(Int[]))))
@test @inferred(isnan(sem(Int[], ProbabilityWeights(Int[]))))

@test @inferred(isnan(sem(Int[]; mean=0f0)))
@test @inferred(isnan(sem(Int[], FrequencyWeights(Int[]); mean=0f0)))
@test @inferred(isnan(sem(Int[], ProbabilityWeights(Int[]); mean=0f0)))

@test @inferred(isnan(sem(skipmissing(Union{Int,Missing}[missing, missing]))))
@test_throws MethodError sem(Any[])
@test_throws MethodError sem(skipmissing([missing]))

Expand Down

0 comments on commit afd0bb7

Please sign in to comment.