Skip to content

Commit

Permalink
Add _local_bound_correlation_recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Jan 3, 2025
1 parent 04829be commit a68014a
Showing 1 changed file with 60 additions and 3 deletions.
63 changes: 60 additions & 3 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
return score
end

function _local_bound_correlation_core(chunk, ins::NTuple{2,Int}, squareG::Array{T,2}; marg::Bool = true) where {T}
function _local_bound_correlation_core(chunk, ins::NTuple{2,Int}, squareG::Array{T,2}; marg::Bool = true) where {T<:Real}
ia, ib = ins
score = typemin(T)
ind = Vector{Int8}(undef, ib - marg)
Expand Down Expand Up @@ -66,10 +66,10 @@ function _local_bound_correlation_core(chunk, ins::NTuple{2,Int}, squareG::Array
return score
end

function _local_bound_correlation_core(chunk, ins::NTuple{N,Int}, squareG::Array{T,2}; marg::Bool = true) where {T,N}
function _local_bound_correlation_core(chunk, ins::NTuple{N,Int}, squareG::Array{T,2}; marg::Bool = true) where {T<:Real,N}
ia = ins[1]
score = typemin(T)
ind = Vector{Int8}(undef, sum(ins[2:N]) - marg*(N-1))
ind = Vector{Int8}(undef, sum(ins[2:N]) - marg * (N - 1))
digits!(ind, chunk[1] - 1; base = 2)
sumsizes = [1; cumsum(collect(ins[2:N]) .- marg) .+ 1]
prodsizes = ones(Int, N - 1)
Expand Down Expand Up @@ -103,6 +103,63 @@ function _local_bound_correlation_core(chunk, ins::NTuple{N,Int}, squareG::Array
return score
end

function _local_bound_correlation_recursive(
A::Array{T,2};
marg = true,
N = 2,
m = size(A),
tmp = [zeros(T, m[i+1:N]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 1:N-1],
ax = [ones(T, m[i]) for i 1:N-1],
) where {T<:Real}
tmp1::Vector{T} = tmp[1]
score = typemin(T)
for _ 0:2^(m[1]-marg)-1
@views ax[1][marg+1:end] .= 2 .* ind[1] .- 1
mul!(tmp1, A', ax[1])
temp_score = marg ? tmp1[1] : abs(tmp1[1])
for x 2:m[2]
temp_score += abs(tmp1[x])
end
if temp_score > score
score = temp_score
end
_update_odometer!(ind[1], 2)
end
return score
end

function _local_bound_correlation_recursive(
A::Array{T,N};
marg = true,
m = size(A),
tmp = [zeros(T, m[i+1:N]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 1:N-1],
ax = [ones(T, m[i]) for i 1:N-1],
) where {T<:Real,N}
tmp1::Array{T,N-1} = tmp[1]
score = typemin(T)
for _ 0:2^(m[1]-marg)-1
@views ax[1][marg+1:end] .= 2 .* ind[1] .- 1
_tensor_contraction!(tmp1, A, ax[1])
temp_score = _local_bound_correlation_recursive(tmp1; marg, m = m[2:N], tmp = tmp[2:N-1], ind = ind[2:N-1], ax = ax[2:N-1])
if temp_score > score
score = temp_score
end
_update_odometer!(ind[1], 2)
end
return score
end

# among ci/x orders in the loop and in the indexing,
# this is the fastest contraction, hence the enumeration order
function _tensor_contraction!(tmp, A::Array{T,N}, ax) where {T<:Real,N}
tmp .= 0
for ci CartesianIndices(tmp), x eachindex(ax)
tmp[ci] += A[x, ci] * ax[x]
end
end

function _local_bound_probability(G::Array{T,N2}) where {T<:Real,N2}
@assert iseven(N2)
N = N2 ÷ 2
Expand Down

0 comments on commit a68014a

Please sign in to comment.