Skip to content

Commit

Permalink
make indicators nothing instead of tuple of nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
Datseris committed Dec 17, 2023
1 parent c71f523 commit 7d1da2d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
16 changes: 9 additions & 7 deletions src/analysis/sliding_window.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ function SlidingWindowConfig(
)
indicators, change_metrics = sanitycheck_metrics(indicators, change_metrics)
# Last step: precomputable functions, if any
indicators = map(f -> precompute(f, 1:T(width_ind)), indicators)
if !isnothing(indicators)
indicators = map(f -> precompute(f, 1:T(width_ind)), indicators)
end
change_metrics = map(f -> precompute(f, 1:T(width_cha)), change_metrics)

return SlidingWindowConfig(
Expand All @@ -77,13 +79,13 @@ function SlidingWindowConfig(
end

function sanitycheck_metrics(indicators, change_metrics)
if !(indicators isa Tuple)
indicators = (indicators,)
end
if !(change_metrics isa Tuple)
change_metrics = (change_metrics,)
end
if length(change_metrics) length(indicators) && indicators !== (nothing, )
if indicators isa Function
indicators = (indicators, )
end
if !isnothing(indicators) && (length(change_metrics) length(indicators))
throw(ArgumentError("The amount of change metrics and indicators must match."))
end
return indicators, change_metrics
Expand All @@ -97,7 +99,7 @@ end
function estimate_indicator_changes(config::SlidingWindowConfig, x, t = eachindex(x))
(; indicators, change_metrics) = config
# initialize time vectors
if indicators === (nothing, )
if isnothing(indicators)
# Skip indicators if they are nothing
t_indicator = t
else
Expand All @@ -118,7 +120,7 @@ function estimate_indicator_changes(config::SlidingWindowConfig, x, t = eachinde

for i in 1:n_metrics
# estimate indicator timeseries
if indicators !== (nothing, )
if !isnothing(indicators)
z = view(x_indicator, :, i)
windowmap!(indicators[i], z, x;
width = config.width_ind, stride = config.stride_ind
Expand Down
9 changes: 4 additions & 5 deletions src/significance/surrogates_significance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function significant_transitions(res::SlidingWindowResults, signif::SurrogatesSi
(; x, x_indicator, x_change) = res
(; indicators, change_metrics, width_ind, stride_ind, width_cha, stride_cha) = res.config
(; surrogate, n, tail, rng, p, pvalues) = signif
n_ind = length(indicators)
n_ind = length(change_metrics)
sanitycheck_tail(tail, n_ind)

# Init pvalues
Expand All @@ -75,15 +75,14 @@ function significant_transitions(res::SlidingWindowResults, signif::SurrogatesSi
seeds = rand(rng, 1:typemax(Int), Threads.nthreads())
sgens = [surrogenerator(x, surrogate, Xoshiro(seed)) for seed in seeds]
# Dummy vals for surrogate parallelization
if indicators !== (nothing, )
if !isnothing(indicators)
indicator_dummys = [x_indicator[:, 1] for _ in 1:Threads.nthreads()]
else
indicator_dummys = [copy(x) for _ in 1:Threads.nthreads()]
end
change_dummys = [x_change[:, 1] for _ in 1:Threads.nthreads()]

Threads.@threads for _ in 1:n

id = Threads.threadid()
s = sgens[id]()
change_dummy = change_dummys[id]
Expand Down Expand Up @@ -172,8 +171,8 @@ function sanitycheck_tail(tail, n_ind)
end
end

function choose_metrics(indicators::Tuple, change_metrics, tail, i::Int)
ind = indicators === (nothing, ) ? nothing : indicators[i]
function choose_metrics(indicators, change_metrics, tail, i::Int)
ind = isnothing(indicators) ? nothing : indicators[i]
tai = tail isa Symbol ? tail : tail[i]
return ind, change_metrics[i], tai
end
Expand Down

0 comments on commit 7d1da2d

Please sign in to comment.