Skip to content

Commit

Permalink
fix bug in nonlocal: everything must be big endian
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jan 3, 2025
1 parent 28d7be4 commit 8d9ae1f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
24 changes: 12 additions & 12 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ end

function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N}
score = typemin(T)
bases = reduce(vcat, [fill(outs[i], ins[i]) for i 2:length(ins)])
ind = _digits_mixed_basis(chunk[1] - 1, bases)
base = reduce(vcat, [fill(outs[i], ins[i]) for i 2:N])
ind = _digits(chunk[1] - 1; base)
Galice = zeros(T, outs[1] * ins[1])
sizes = (outs[2:N]..., ins[2:N]...)
prodsizes = [prod(sizes[1:i-1]) for i 1:2N-2]
Expand All @@ -173,7 +173,7 @@ function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N
@views sum!(Galice, squareG[:, offset_ind])
temp_score = _maxcols!(Galice, outs[1], ins[1])
score = max(score, temp_score)
_update_odometer!(ind, bases)
_update_odometer!(ind, base)
end
return score
end
Expand Down Expand Up @@ -226,22 +226,22 @@ function _partition(n::T, k::T) where {T<:Integer}
return parts
end

function _digits_mixed_basis(ind, bases)
N = length(bases)
function _digits(ind; base)
N = length(base)
digits = zeros(Int, N)
@inbounds for i N:-1:1
digits[i] = mod(ind, bases[i])
ind = div(ind, bases[i])
@inbounds for i 1:N
digits[i] = ind % base[i]
ind = ind ÷ base[i]
end
return digits
end

function _update_odometer!(ind::AbstractVector{<:Integer}, bases::AbstractVector{<:Integer})
function _update_odometer!(ind::AbstractVector{<:Integer}, base::AbstractVector{<:Integer})
ind[1] += 1
d = length(ind)

@inbounds for i 1:d
if ind[i] bases[i]
if ind[i] base[i]
ind[i] = 0
i < d ? ind[i+1] += 1 : return
else
Expand All @@ -250,12 +250,12 @@ function _update_odometer!(ind::AbstractVector{<:Integer}, bases::AbstractVector
end
end

function _update_odometer!(ind::AbstractVector{<:Integer}, bases::Integer)
function _update_odometer!(ind::AbstractVector{<:Integer}, base::Integer)
ind[1] += 1
d = length(ind)

@inbounds for i 1:d
if ind[i] bases
if ind[i] base
ind[i] = 0
i < d ? ind[i+1] += 1 : return
else
Expand Down
2 changes: 1 addition & 1 deletion test/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@test local_bound(gyni(Int, 3)) == 1
@test local_bound(gyni(Int, 4)) == 1
for T [Float64, Double64, Float128, BigFloat]
fp1 = rand(T, 2, 2, 3, 4) # randn would mess things up
fp1 = randn(T, 2, 2, 3, 4)
fp2 = permutedims(fp1, (2, 1, 4, 3))
fc1 = tensor_correlation(fp1)
fc2 = tensor_correlation(fp2)
Expand Down

0 comments on commit 8d9ae1f

Please sign in to comment.