Skip to content

Commit

Permalink
Remove recursive_top and modify A in place (factor 2)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Jan 8, 2025
1 parent 39863f4 commit cfdbcbd
Showing 1 changed file with 21 additions and 54 deletions.
75 changes: 21 additions & 54 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,61 +29,38 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
end

chunks = _partition(outs[N]^(ins[N] - marg), Threads.nthreads())
ins2 = ins
G2 = G #workaround for https://github.com/JuliaLang/julia/issues/15276
tasks = map(chunks) do chunk
Threads.@spawn _local_bound_correlation_recursive_top(chunk, ins2, G2; marg)
Threads.@spawn _local_bound_correlation_recursive!(copy(G2), chunk, marg)
end
score = maximum(fetch.(tasks))
return score
end

Base.@propagate_inbounds function _local_bound_correlation_recursive_top(
function _local_bound_correlation_recursive!(
A::Array{T,N},
chunk,
m::NTuple{N,Int},
A::Array{T,N};
marg = true
marg = true,
m = size(A),
tmp = [zeros(T, m[1:i]...) for i 1:N-1],
offset = [zeros(T, m[1:i]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 2:N],
) where {T<:Real,N}
tmp = [zeros(T, m[1:i]) for i 1:N-1]
offset = [zeros(T, m[1:i]) for i 1:N-1]
ind = [zeros(Int8, m[i] - marg) for i 2:N]
digits!(ind[N-1], chunk[1] - 1; base = 2)
tmp_end::Array{T,N - 1} = tmp[N-1]
offset_end::Array{T,N - 1} = offset[N-1]
_compute_offset!(offset_end, A, marg)
score = typemin(T)
A2 = 2 * A
@inbounds for _ chunk[1]:chunk[2]
tmp_end .= offset_end
_tensor_contraction!(tmp_end, A2, ind[N-1], marg)
@views temp_score =
_local_bound_correlation_recursive(tmp_end, marg, m[1:N-1], tmp[1:N-2], offset[1:N-2], ind[1:N-2])
if temp_score > score
score = temp_score
tmp_end::Array{T,N-1} = tmp[N-1]
offset_end::Array{T,N-1} = offset[N-1]
sum!(offset_end, A)
A .*= 2
if marg
for ci CartesianIndices(offset_end)
offset_end[ci] -= A[ci, 1]
end
_update_odometer!(ind[N-1], 2)
end
return score
end

Base.@propagate_inbounds function _local_bound_correlation_recursive(
A::Array{T,N},
marg,
m,
tmp,
offset,
ind
) where {T<:Real,N}
tmp_end::Array{T,N - 1} = tmp[N-1]
offset_end::Array{T,N - 1} = offset[N-1]
_compute_offset!(offset_end, A, marg)
offset_end .*= -1
score = typemin(T)
A2 = 2 * A
@inbounds for _ 0:2^(m[N]-marg)-1
for _ chunk[1]:chunk[2]
tmp_end .= offset_end
_tensor_contraction!(tmp_end, A2, ind[N-1], marg)
@views temp_score =
_local_bound_correlation_recursive(tmp_end, marg, m[1:N-1], tmp[1:N-2], offset[1:N-2], ind[1:N-2])
_tensor_contraction!(tmp_end, A, ind[N-1], marg)
@views temp_score = _local_bound_correlation_recursive!(tmp_end, (0, 2^(m[N-1]-marg)-1), marg, m[1:N-1], tmp[1:N-2], offset[1:N-2], ind[1:N-2])
if temp_score > score
score = temp_score
end
Expand All @@ -92,7 +69,7 @@ Base.@propagate_inbounds function _local_bound_correlation_recursive(
return score
end

Base.@propagate_inbounds function _local_bound_correlation_recursive(A::Vector, marg, m, tmp, offset, ind)
function _local_bound_correlation_recursive!(A::Vector, chunk, marg, m, tmp, offset, ind)
score = marg ? A[1] : abs(A[1])
for x 2:m[1]
score += abs(A[x])
Expand All @@ -101,7 +78,7 @@ Base.@propagate_inbounds function _local_bound_correlation_recursive(A::Vector,
end

function _tensor_contraction!(tmp, A::Array{T,N}, ind, marg) where {T<:Number,N}
@inbounds for x eachindex(ind)
for x eachindex(ind)
if ind[x] == 1
for ci CartesianIndices(tmp)
tmp[ci] += A[ci, x+marg]
Expand All @@ -110,16 +87,6 @@ function _tensor_contraction!(tmp, A::Array{T,N}, ind, marg) where {T<:Number,N}
end
end

function _compute_offset!(offset_end, A, marg)
sum!(offset_end, A)
if marg
for ci CartesianIndices(offset_end)
offset_end[ci] -= 2 * A[ci, 1]
end
end
offset_end .*= -1
end

function _local_bound_probability(G::Array{T,N2}) where {T<:Real,N2}
@assert iseven(N2)
N = N2 ÷ 2
Expand Down

0 comments on commit cfdbcbd

Please sign in to comment.