From 8d9ae1f9982755355d0ee85b7ce38b339923ae22 Mon Sep 17 00:00:00 2001 From: araujoms Date: Fri, 3 Jan 2025 11:46:01 +0100 Subject: [PATCH] fix bug in nonlocal: everything must be big endian --- src/nonlocal.jl | 24 ++++++++++++------------ test/nonlocal.jl | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/nonlocal.jl b/src/nonlocal.jl index e7549b4..c939b23 100644 --- a/src/nonlocal.jl +++ b/src/nonlocal.jl @@ -148,8 +148,8 @@ end function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N} score = typemin(T) - bases = reduce(vcat, [fill(outs[i], ins[i]) for i ∈ 2:length(ins)]) - ind = _digits_mixed_basis(chunk[1] - 1, bases) + base = reduce(vcat, [fill(outs[i], ins[i]) for i ∈ 2:N]) + ind = _digits(chunk[1] - 1; base) Galice = zeros(T, outs[1] * ins[1]) sizes = (outs[2:N]..., ins[2:N]...) prodsizes = [prod(sizes[1:i-1]) for i ∈ 1:2N-2] @@ -173,7 +173,7 @@ function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N @views sum!(Galice, squareG[:, offset_ind]) temp_score = _maxcols!(Galice, outs[1], ins[1]) score = max(score, temp_score) - _update_odometer!(ind, bases) + _update_odometer!(ind, base) end return score end @@ -226,22 +226,22 @@ function _partition(n::T, k::T) where {T<:Integer} return parts end -function _digits_mixed_basis(ind, bases) - N = length(bases) +function _digits(ind; base) + N = length(base) digits = zeros(Int, N) - @inbounds for i ∈ N:-1:1 - digits[i] = mod(ind, bases[i]) - ind = div(ind, bases[i]) + @inbounds for i ∈ 1:N + digits[i] = ind % base[i] + ind = ind ÷ base[i] end return digits end -function _update_odometer!(ind::AbstractVector{<:Integer}, bases::AbstractVector{<:Integer}) +function _update_odometer!(ind::AbstractVector{<:Integer}, base::AbstractVector{<:Integer}) ind[1] += 1 d = length(ind) @inbounds for i ∈ 1:d - if ind[i] ≥ bases[i] + if ind[i] ≥ base[i] ind[i] = 0 i < d ? ind[i+1] += 1 : return else @@ -250,12 +250,12 @@ function _update_odometer!(ind::AbstractVector{<:Integer}, bases::AbstractVector end end -function _update_odometer!(ind::AbstractVector{<:Integer}, bases::Integer) +function _update_odometer!(ind::AbstractVector{<:Integer}, base::Integer) ind[1] += 1 d = length(ind) @inbounds for i ∈ 1:d - if ind[i] ≥ bases + if ind[i] ≥ base ind[i] = 0 i < d ? ind[i+1] += 1 : return else diff --git a/test/nonlocal.jl b/test/nonlocal.jl index 54140ab..2704fd2 100644 --- a/test/nonlocal.jl +++ b/test/nonlocal.jl @@ -8,7 +8,7 @@ @test local_bound(gyni(Int, 3)) == 1 @test local_bound(gyni(Int, 4)) == 1 for T ∈ [Float64, Double64, Float128, BigFloat] - fp1 = rand(T, 2, 2, 3, 4) # randn would mess things up + fp1 = randn(T, 2, 2, 3, 4) fp2 = permutedims(fp1, (2, 1, 4, 3)) fc1 = tensor_correlation(fp1) fc2 = tensor_correlation(fp2)