diff --git a/src/nonlocal.jl b/src/nonlocal.jl index d816b2c..26545bc 100644 --- a/src/nonlocal.jl +++ b/src/nonlocal.jl @@ -20,7 +20,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re outs = fill(2, N) ins = size(G) - num_strategies = outs .^ ins + num_strategies = outs .^ (ins .- 1) largest_party = argmax(num_strategies) if largest_party != 1 perm = [largest_party; 2:largest_party-1; 1; largest_party+1:N] @@ -29,7 +29,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re end squareG = reshape(G, ins[1], prod(ins[2:N])) - chunks = _partition(prod((outs .^ ins)[2:N]), Threads.nthreads()) + chunks = _partition(prod((outs .^ (ins .- marg))[2:N]), Threads.nthreads()) ins2 = ins squareG2 = squareG #workaround for https://github.com/JuliaLang/julia/issues/15276 tasks = map(chunks) do chunk @@ -42,18 +42,48 @@ end function _local_bound_correlation_single(chunk, ins::NTuple{2,Int}, squareG::Array{T,2}; marg::Bool = true) where {T} ia, ib = ins score = typemin(T) - ind = digits(chunk[1] - 1; base = 2, pad = ib) - offset = Vector(1 .+ 2 * (0:ib-1)) + ind = digits(chunk[1] - 1; base = 2, pad = ib - marg) tmp = zeros(T, ia) - ax = zeros(T, ia) - if marg - ax[1] = 1 + ax = ones(T, ia) + by = ones(T, ib) + @inbounds for _ ∈ chunk[1]:chunk[2] + by[marg+1:ib] .= 2 .* ind .- 1 + mul!(tmp, squareG, by) + for x ∈ marg+1:ia + ax[x] = tmp[x] > zero(T) ? one(T) : -one(T) + end + temp_score = dot(ax, tmp) + score = max(score, temp_score) + _update_odometer!(ind, 2) + end + return score +end + +function _local_bound_correlation_single(chunk, ins::NTuple{N,Int}, squareG::Array{T,2}; marg::Bool = true) where {T,N} + score = typemin(T) + ind = digits(chunk[1] - 1; base = 2, pad = sum(ins[2:N] .- marg)) + sumsizes = [1; cumsum(collect(ins[2:N] .- marg)) .+ 1] + prodsizes = ones(Int, N - 1) + for i ∈ 1:N-1 + prodsizes[i] = prod(ins[2:i]) end - by = zeros(T, ib) + linearindex(v) = 1 + dot(v, prodsizes) + tmp = zeros(T, ins[1]) + ax = ones(T, ins[1]) + by = [ones(T, ins[i]) for i ∈ 2:N] @inbounds for _ ∈ chunk[1]:chunk[2] - by .= 2 .* ind .- 1 - @views mul!(tmp, squareG, by) - for x in marg+1:ia + tmp .= 0 + for i ∈ 2:N + by[i-1][marg+1:ins[i]] .= 2 .* ind[sumsizes[i-1]:sumsizes[i]-1] .- 1 + end + for y ∈ CartesianIndices(ins[2:N]) + b = prod(by[i][y[i]] for i ∈ 1:N-1) + lin_by = linearindex(y.I .- 1) + for x ∈ 1:ins[1] + tmp[x] += squareG[x, lin_by] * b + end + end + for x ∈ marg+1:ins[1] ax[x] = tmp[x] > zero(T) ? one(T) : -one(T) end temp_score = dot(ax, tmp) @@ -114,10 +144,9 @@ end function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N} score = typemin(T) - bases = reduce(vcat, [outs[i] * ones(Int, ins[i]) for i ∈ 2:length(ins)]) + bases = reduce(vcat, [fill(outs[i], ins[i]) for i ∈ 2:length(ins)]) ind = _digits_mixed_basis(chunk[1] - 1, bases) Galice = zeros(T, outs[1] * ins[1]) - b = zeros(Int, N - 1) sizes = (outs[2:N]..., ins[2:N]...) prodsizes = ones(Int, 2 * (N - 1)) for i ∈ 1:length(prodsizes) @@ -135,8 +164,9 @@ function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple for i ∈ 1:N-1 by[i+N-1] = y[i] - 1 end + lin_by = linearindex(by) for i ∈ 1:outs[1]*ins[1] - Galice[i] += squareG[i, linearindex(by)] + Galice[i] += squareG[i, lin_by] end end temp_score = _maxcols!(Galice, outs[1], ins[1]) diff --git a/test/nonlocal.jl b/test/nonlocal.jl index 394dc9c..54140ab 100644 --- a/test/nonlocal.jl +++ b/test/nonlocal.jl @@ -15,6 +15,13 @@ @test local_bound(fp1) ≈ local_bound(fp2) @test local_bound(fc1) ≈ local_bound(fc2) @test local_bound(fc1) ≈ local_bound(fp1) + fp1 = rand(T, 2, 2, 2, 3, 4, 5) + fp2 = permutedims(fp1, (3, 2, 1, 6, 5, 4)) + fc1 = tensor_correlation(fp1) + fc2 = tensor_correlation(fp2) + @test local_bound(fp1) ≈ local_bound(fp2) + @test local_bound(fc1) ≈ local_bound(fc2) + @test local_bound(fc1) ≈ local_bound(fp1) end @test local_bound([1 1; 1 -1]; marg = false) == 2