Skip to content

Commit

Permalink
Limit changes to more general argument types
Browse files Browse the repository at this point in the history
  • Loading branch information
kagalenko-m-b committed Dec 4, 2021
1 parent 6b9d2cc commit 2efdf68
Showing 1 changed file with 56 additions and 35 deletions.
91 changes: 56 additions & 35 deletions src/signalcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ end

default_autolags(lx::Int) = 0 : default_laglen(lx)

_autodot(x::AbstractVector{<:RealFP}, lx::Int, l::Int) = dot(x, 1:(lx-l), x, (1+l):lx)
_autodot(x::AbstractVector, lx::Int, l::Int) = dot(view(x, 1:(lx-l)), view(x, (1+l):lx))


## autocov
"""
autocov!(r, x, lags; demean=true)
Expand All @@ -59,31 +61,34 @@ where each column in the result will correspond to a column in `x`.
The output is not normalized. See [`autocor!`](@ref) for a method with normalization.
"""
function autocov!(
r::AbstractVector,
x::AbstractVector,
lags::AbstractVector{<:Integer},
z::Vector=Vector{typeof(zero(eltype(x)) / 1)}(undef, length(x));
demean::Bool=true
)
function autocov!(r::AbstractVector, x::AbstractVector, lags::AbstractVector{<:Integer}; demean::Bool=true)
lx = length(x)
m = length(lags)
length(r) == length(z) == m || throw(DimensionMismatch())
length(r) == m || throw(DimensionMismatch())
check_lags(lx, lags)
demean ? z .= x .- mean(x) : copyto!(z, x)
@inbounds for (k, lags_k) in zip(eachindex(r), lags) # foreach lag value
r[k] = _autodot(z, lx, lags_k) / lx

T = typeof(zero(eltype(x)) / 1)
z::Vector{T} = demean ? x .- mean(x) : x
for k = 1 : m # foreach lag value
r[k] = _autodot(z, lx, lags[k]) / lx
end
return r
end

function autocov!(
r::AbstractMatrix, x::AbstractMatrix, lags::AbstractVector{<:Integer}; demean::Bool=true
)
function autocov!(r::AbstractMatrix, x::AbstractMatrix, lags::AbstractVector{<:Integer}; demean::Bool=true)
lx = size(x, 1)
ns = size(x, 2)
m = length(lags)
size(r) == (m, ns) || throw(DimensionMismatch())
check_lags(lx, lags)

T = typeof(zero(eltype(x)) / 1)
z = Vector{T}(undef, size(x, 1))
for n in 1:size(x, 2)
autocov!(view(r, :, n), view(x, :, n), lags, z; demean)
z = Vector{T}(undef, lx)
for j = 1 : ns
demean_col!(z, x, j, demean)
for k = 1 : m
r[k,j] = _autodot(z, lx, lags[k]) / lx
end
end
return r
end
Expand Down Expand Up @@ -134,27 +139,36 @@ where each column in the result will correspond to a column in `x`.
The output is normalized by the variance of `x`, i.e. so that the lag 0
autocorrelation is 1. See [`autocov!`](@ref) for the unnormalized form.
"""
function autocor!(
r::AbstractVector,
x::AbstractVector,
lags::AbstractVector{<:Integer},
z=zeros(typeof(zero(eltype(x)) / 1), length(x));
demean::Bool=true
)
autocov!(view(r, 1:1), x, 0:0, z; demean)
zz = r[1]
autocov!(r, x, lags, z; demean)
ldiv!(zz, r)
function autocor!(r::AbstractVector, x::AbstractVector, lags::AbstractVector{<:Integer}; demean::Bool=true)
lx = length(x)
m = length(lags)
length(r) == m || throw(DimensionMismatch())
check_lags(lx, lags)

T = typeof(zero(eltype(x)) / 1)
z::Vector{T} = demean ? x .- mean(x) : x
zz = dot(z, z)
for k = 1 : m # foreach lag value
r[k] = _autodot(z, lx, lags[k]) / zz
end
return r
end

function autocor!(
r::AbstractMatrix, x::AbstractMatrix, lags::AbstractVector{<:Integer}; demean::Bool=true
)
T = typeof(zero(eltype(x))/1)
z = Vector{T}(undef, size(x, 1))
for n in 1:size(x, 2)
autocor!(view(r, :, n), view(x, :, n), lags, z; demean=demean)
function autocor!(r::AbstractMatrix, x::AbstractMatrix, lags::AbstractVector{<:Integer}; demean::Bool=true)
lx = size(x, 1)
ns = size(x, 2)
m = length(lags)
size(r) == (m, ns) || throw(DimensionMismatch())
check_lags(lx, lags)

T = typeof(zero(eltype(x)) / 1)
z = Vector{T}(undef, lx)
for j = 1 : ns
demean_col!(z, x, j, demean)
zz = dot(z, z)
for k = 1 : m
r[k,j] = _autodot(z, lx, lags[k]) / zz
end
end
return r
end
Expand Down Expand Up @@ -199,6 +213,13 @@ autocor(x::AbstractVecOrMat; demean::Bool=true) =

default_crosslags(lx::Int) = (l=default_laglen(lx); -l:l)

function _crossdot(x::AbstractVector{T}, y::AbstractVector{T}, lx::Int, l::Int) where {T<:RealFP}
if l >= 0
dot(x, 1:(lx-l), y, (1+l):lx)
else
dot(x, (1-l):lx, y, 1:(lx+l))
end
end
function _crossdot(x::AbstractVector, y::AbstractVector, lx::Int, l::Int)
if l >= 0
dot(view(x, 1:(lx-l)), view(y, (1+l):lx))
Expand Down

0 comments on commit 2efdf68

Please sign in to comment.