Skip to content

Commit

Permalink
Switch party order in recursive function
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Jan 8, 2025
1 parent 4decb43 commit 358391d
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,18 @@ end
# A::Matrix{T},
# marg = true,
# m = size(A),
# tmp = [zeros(T, m[2])],
# ind = [zeros(Int8, m[1] - marg)],
# ax = [ones(T, m[1])],
# tmp = [zeros(T, m[1])],
# ind = [zeros(Int8, m[2] - marg)],
# ax = [ones(T, m[2])],
# ) where {T<:Real}
# tmp1::Vector{T} = tmp[1]
# tmp_end::Vector{T} = tmp[1]
# score = typemin(T)
# @inbounds for _ ∈ 0:2^(m[1]-marg)-1
# @inbounds for _ ∈ 0:2^(m[2]-marg)-1
# @views ax[1][marg+1:end] .= 2 .* ind[1] .- 1
# mul!(tmp1, A', ax[1])
# temp_score = marg ? tmp1[1] : abs(tmp1[1])
# for x ∈ 2:m[2]
# temp_score += abs(tmp1[x])
# mul!(tmp_end, A, ax[1])
# temp_score = marg ? tmp_end[1] : abs(tmp_end[1])
# for x ∈ 2:m[1]
# temp_score += abs(tmp_end[x])
# end
# if temp_score > score
# score = temp_score
Expand All @@ -141,34 +141,34 @@ Base.@propagate_inbounds function _local_bound_correlation_recursive(
A::Array{T,N},
marg = true,
m = size(A),
tmp = [zeros(T, m[i+1:N]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 1:N-1],
ax = [ones(T, m[i]) for i 1:N-1],
tmp = [zeros(T, m[1:i]...) for i 1:N-1],
ind = [zeros(Int8, m[i] - marg) for i 2:N],
ax = [ones(T, m[i]) for i 2:N],
) where {T<:Real,N}
tmp1::Array{T,N-1} = tmp[1]
tmp_end::Array{T,N-1} = tmp[N-1]
score = typemin(T)
@inbounds for _ 0:2^(m[1]-marg)-1
@views ax[1][marg+1:end] .= 2 .* ind[1] .- 1
_tensor_contraction!(tmp1, A, ax[1])
@views temp_score = _local_bound_correlation_recursive(tmp1, marg, m[2:N], tmp[2:N-1], ind[2:N-1], ax[2:N-1])
@inbounds for _ 0:2^(m[N]-marg)-1
@views ax[N-1][marg+1:end] .= 2 .* ind[N-1] .- 1
_tensor_contraction!(tmp_end, A, ax[N-1])
@views temp_score = _local_bound_correlation_recursive(tmp_end, marg, m[1:N-1], tmp[1:N-2], ind[1:N-2], ax[1:N-2])
if temp_score > score
score = temp_score
end
_update_odometer!(ind[1], 2)
_update_odometer!(ind[N-1], 2)
end
return score
end

function _tensor_contraction!(tmp, A::Matrix{T}, ax::Vector{T}) where {T<:Real}
@inbounds mul!(tmp, A', ax)
@inbounds mul!(tmp, A, ax)
end

# among ci/x orders in the loop and in the indexing,
# this is the fastest contraction, hence the enumeration order
function _tensor_contraction!(tmp, A::Array{T,N}, ax::Vector{T}) where {T<:Real,N}
tmp .= 0
@inbounds for ci CartesianIndices(tmp), x eachindex(ax)
tmp[ci] += A[x, ci] * ax[x]
@inbounds for x in eachindex(ax), ci in CartesianIndices(tmp)
tmp[ci] += A[ci, x] * ax[x]
end
end

Expand Down

0 comments on commit 358391d

Please sign in to comment.