Skip to content

Commit

Permalink
parallelize recursive version
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jan 8, 2025
1 parent 1831478 commit c7d3047
Showing 1 changed file with 76 additions and 86 deletions.
162 changes: 76 additions & 86 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ Reference: Araújo, Hirsch, and Quintino, [arXiv:2005.13418](https://arxiv.org/a
"""
function local_bound(G::Array{T,N}; correlation::Bool = N < 4, marg::Bool = true) where {T<:Real,N}
if correlation
#return _local_bound_correlation(G; marg)
return _local_bound_correlation_recursive(G, marg)
return _local_bound_correlation(G; marg)
else
return _local_bound_probability(G)
end
Expand All @@ -28,110 +27,85 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
ins::NTuple{N,Int} = ins[perm]
G = permutedims(G, perm)
end
squareG = reshape(G, ins[1], prod(ins[2:N]))

chunks = _partition(prod((outs .^ (ins.-marg))[2:N]), Threads.nthreads())
ins2 = ins #workaround for https://github.com/JuliaLang/julia/issues/15276
chunks = _partition(outs[N]^(ins[N] - marg), Threads.nthreads())
ins2 = ins
G2 = G #workaround for https://github.com/JuliaLang/julia/issues/15276
tasks = map(chunks) do chunk
Threads.@spawn _local_bound_correlation_core(chunk, ins2, squareG; marg)
Threads.@spawn _local_bound_correlation_recursive_top(chunk, ins2, G2; marg)
end
score = maximum(fetch.(tasks))
return score
end

function _local_bound_correlation_core(chunk, ins::NTuple{2,Int}, squareG::Array{T,2}; marg::Bool = true) where {T<:Real}
ia, ib = ins
score = typemin(T)
ind = Vector{Int8}(undef, ib - marg)
digits!(ind, chunk[1] - 1; base = 2)
bx = zeros(T, ia)
offset = zeros(T, ia)
@views sum!(offset, squareG[:, marg+1:ib])
marg && @views offset .-= squareG[:, 1]
offset .*= -1
squareG2 = 2*squareG #necessary because of multithreading
@inbounds for _ chunk[1]:chunk[2]
bx .= offset
for y marg+1:ib
if ind[y-marg] == 1
@views bx .+= squareG2[:, y]
end
end
temp_score = marg ? bx[1] : abs(bx[1])
for x 2:ia
temp_score += abs(bx[x])
end
score = max(score, temp_score)
_update_odometer!(ind, 2)
end
return score
end

function _local_bound_correlation_core(chunk, ins::NTuple{N,Int}, squareG::Array{T,2}; marg::Bool = true) where {T<:Real,N}
ia = ins[1]
Base.@propagate_inbounds function _local_bound_correlation_recursive_top(
chunk,
m::NTuple{N,Int},
A::Array{T,N};
marg = true
) where {T<:Real,N}
tmp = [zeros(T, m[1:i]...) for i 1:N-1]
offset = [zeros(T, m[1:i]...) for i 1:N-1]
ind = [zeros(Int8, m[i] - marg) for i 2:N]
digits!(ind[N-1], chunk[1] - 1; base = 2)
ax = [ones(T, m[i]) for i 2:N]
tmp_end::Array{T,N - 1} = tmp[N-1]
offset_end::Array{T,N - 1} = offset[N-1]
sum!(offset_end, A)
offset_end .*= -1
score = typemin(T)
ind = Vector{Int8}(undef, sum(ins[2:N]) - marg * (N - 1))
digits!(ind, chunk[1] - 1; base = 2)
sumsizes = [1; cumsum(collect(ins[2:N]) .- marg) .+ 1]
prodsizes = ones(Int, N - 1)
for i 2:N-1
prodsizes[i] = prod(ins[2:i])
end
linearindex_offset = 1 - sum(prodsizes) # to avoid y.I .- 1
linearindex(v) = linearindex_offset + dot(v, prodsizes)
tmp = zeros(T, ia)
by = [ones(T, ins[i]) for i 2:N]
ins_region = CartesianIndices(ins[2:N])
A2 = 2 * A
CI = CartesianIndices(tmp_end)
@inbounds for _ chunk[1]:chunk[2]
tmp .= 0
for i 2:N
@views by[i-1][marg+1:ins[i]] .= 2 .* ind[sumsizes[i-1]:sumsizes[i]-1] .- 1
end
for y ins_region
b = prod(by[i][y[i]] for i 1:N-1)
lin_by = linearindex(y.I)
for x 1:ia
tmp[x] += squareG[x, lin_by] * b
end
end
temp_score = marg ? tmp[1] : abs(tmp[1])
for x 2:ia
temp_score += abs(tmp[x])
@views ax[N-1][marg+1:end] .= ind[N-1]
tmp_end .= offset_end
_tensor_contraction!(tmp_end, A2, ax[N-1], CI)
@views temp_score = _local_bound_correlation_recursive(
tmp_end,
marg,
m[1:N-1],
tmp[1:N-2],
offset[1:N-2],
ind[1:N-2],
ax[1:N-2]
)
if temp_score > score
score = temp_score
end
score = max(score, temp_score)
_update_odometer!(ind, 2)
end
return score
end

Base.@propagate_inbounds function _local_bound_correlation_recursive(A::Vector{T}, marg, m, tmp, offset, ind, ax) where {T<:Real}
score = marg ? A[1] : abs(A[1])
for x 2:m[1]
score += abs(A[x])
_update_odometer!(ind[N-1], 2)
end
return score
end

Base.@propagate_inbounds function _local_bound_correlation_recursive(
A::Array{T,N},
marg = true,
m = size(A),
tmp = [zeros(T, m[1:i]...) for i 1:N-1],
offset = [zeros(T, m[1:i]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 2:N],
ax = [ones(T, m[i]) for i 2:N],
marg,
m,
tmp,
offset,
ind,
ax
) where {T<:Real,N}
tmp_end::Array{T,N-1} = tmp[N-1]
offset_end::Array{T,N-1} = offset[N-1]
tmp_end::Array{T,N - 1} = tmp[N-1]
offset_end::Array{T,N - 1} = offset[N-1]
sum!(offset_end, A)
offset_end .*= -1
score = typemin(T)
A2 = 2*A
A2 = 2 * A
CI = CartesianIndices(tmp_end)
@inbounds for _ 0:2^(m[N]-marg)-1
@views ax[N-1][marg+1:end] .= ind[N-1]
tmp_end .= offset_end
_tensor_contraction!(tmp_end, A2, ax[N-1])
@views temp_score = _local_bound_correlation_recursive(tmp_end, marg, m[1:N-1], tmp[1:N-2], offset[1:N-2], ind[1:N-2], ax[1:N-2])
_tensor_contraction!(tmp_end, A2, ax[N-1], CI)
@views temp_score = _local_bound_correlation_recursive(
tmp_end,
marg,
m[1:N-1],
tmp[1:N-2],
offset[1:N-2],
ind[1:N-2],
ax[1:N-2]
)
if temp_score > score
score = temp_score
end
Expand All @@ -140,10 +114,26 @@ Base.@propagate_inbounds function _local_bound_correlation_recursive(
return score
end

function _tensor_contraction!(tmp, A::Array{T,N}, ax) where {T<:Number,N}
# Base case of the recursion: `A` is a plain vector.
#
# Returns the sum of the absolute values of `A[1], …, A[m[1]]`. When `marg` is
# `true`, `A[1]` is added with its sign instead of its absolute value.
#
# `tmp`, `offset`, `ind` and `ax` are not used here; they are accepted so that
# this method has the same positional signature as the `Array{T,N}` method.
Base.@propagate_inbounds function _local_bound_correlation_recursive(
    A::Vector{T},
    marg,
    m,
    tmp,
    offset,
    ind,
    ax
) where {T<:Real}
    # First entry: signed when `marg` is set, absolute value otherwise.
    score = marg ? A[1] : abs(A[1])
    # Remaining entries always contribute their absolute value; the upper
    # limit comes from `m[1]`, not from `length(A)`.
    for x in 2:m[1]
        score += abs(A[x])
    end
    return score
end

function _tensor_contraction!(tmp, A::Array{T,N}, ax, CI) where {T<:Number,N}
@inbounds for x eachindex(ax)
if ax[x] == 1
for ci CartesianIndices(tmp)
for ci CI
tmp[ci] += A[ci, x]
end
end
Expand Down Expand Up @@ -201,7 +191,7 @@ function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N
base = reduce(vcat, [fill(outs[i], ins[i]) for i 2:N])
ind = _digits(chunk[1] - 1; base)
Galice = zeros(T, outs[1] * ins[1])
sumins = zeros(Int, N-1)
sumins = zeros(Int, N - 1)
for i 2:N-1
sumins[i] = sum(ins[2:i])
end
Expand Down

0 comments on commit c7d3047

Please sign in to comment.