Skip to content

Commit

Permalink
Add multipartite _local_bound_correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Jan 2, 2025
1 parent 31d4eec commit 5072d49
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
58 changes: 44 additions & 14 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
outs = fill(2, N)
ins = size(G)

num_strategies = outs .^ ins
num_strategies = outs .^ (ins .- 1)
largest_party = argmax(num_strategies)
if largest_party != 1
perm = [largest_party; 2:largest_party-1; 1; largest_party+1:N]
Expand All @@ -29,7 +29,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
end
squareG = reshape(G, ins[1], prod(ins[2:N]))

chunks = _partition(prod((outs .^ ins)[2:N]), Threads.nthreads())
chunks = _partition(prod((outs .^ (ins .- marg))[2:N]), Threads.nthreads())
ins2 = ins
squareG2 = squareG #workaround for https://github.com/JuliaLang/julia/issues/15276
tasks = map(chunks) do chunk
Expand All @@ -42,18 +42,48 @@ end
function _local_bound_correlation_single(chunk, ins::NTuple{2,Int}, squareG::Array{T,2}; marg::Bool = true) where {T}
ia, ib = ins
score = typemin(T)
ind = digits(chunk[1] - 1; base = 2, pad = ib)
offset = Vector(1 .+ 2 * (0:ib-1))
ind = digits(chunk[1] - 1; base = 2, pad = ib - marg)
tmp = zeros(T, ia)
ax = zeros(T, ia)
if marg
ax[1] = 1
ax = ones(T, ia)
by = ones(T, ib)
@inbounds for _ chunk[1]:chunk[2]
by[marg+1:ib] .= 2 .* ind .- 1
mul!(tmp, squareG, by)
for x marg+1:ia
ax[x] = tmp[x] > zero(T) ? one(T) : -one(T)
end
temp_score = dot(ax, tmp)
score = max(score, temp_score)
_update_odometer!(ind, 2)
end
return score
end

function _local_bound_correlation_single(chunk, ins::NTuple{N,Int}, squareG::Array{T,2}; marg::Bool = true) where {T,N}
score = typemin(T)
ind = digits(chunk[1] - 1; base = 2, pad = sum(ins[2:N] .- marg))
sumsizes = [1; cumsum(collect(ins[2:N] .- marg)) .+ 1]
prodsizes = ones(Int, N - 1)
for i 1:N-1
prodsizes[i] = prod(ins[2:i])
end
by = zeros(T, ib)
linearindex(v) = 1 + dot(v, prodsizes)
tmp = zeros(T, ins[1])
ax = ones(T, ins[1])
by = [ones(T, ins[i]) for i 2:N]
@inbounds for _ chunk[1]:chunk[2]
by .= 2 .* ind .- 1
@views mul!(tmp, squareG, by)
for x in marg+1:ia
tmp .= 0
for i 2:N
by[i-1][marg+1:ins[i]] .= 2 .* ind[sumsizes[i-1]:sumsizes[i]-1] .- 1
end
for y CartesianIndices(ins[2:N])
b = prod(by[i][y[i]] for i 1:N-1)
lin_by = linearindex(y.I .- 1)
for x 1:ins[1]
tmp[x] += squareG[x, lin_by] * b
end
end
for x marg+1:ins[1]
ax[x] = tmp[x] > zero(T) ? one(T) : -one(T)
end
temp_score = dot(ax, tmp)
Expand Down Expand Up @@ -114,10 +144,9 @@ end

function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N}
score = typemin(T)
bases = reduce(vcat, [outs[i] * ones(Int, ins[i]) for i 2:length(ins)])
bases = reduce(vcat, [fill(outs[i], ins[i]) for i 2:length(ins)])
ind = _digits_mixed_basis(chunk[1] - 1, bases)
Galice = zeros(T, outs[1] * ins[1])
b = zeros(Int, N - 1)
sizes = (outs[2:N]..., ins[2:N]...)
prodsizes = ones(Int, 2 * (N - 1))
for i 1:length(prodsizes)
Expand All @@ -135,8 +164,9 @@ function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple
for i 1:N-1
by[i+N-1] = y[i] - 1
end
lin_by = linearindex(by)
for i 1:outs[1]*ins[1]
Galice[i] += squareG[i, linearindex(by)]
Galice[i] += squareG[i, lin_by]
end
end
temp_score = _maxcols!(Galice, outs[1], ins[1])
Expand Down
7 changes: 7 additions & 0 deletions test/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
@test local_bound(fp1) local_bound(fp2)
@test local_bound(fc1) local_bound(fc2)
@test local_bound(fc1) local_bound(fp1)
fp1 = rand(T, 2, 2, 2, 3, 4, 5)
fp2 = permutedims(fp1, (3, 2, 1, 6, 5, 4))
fc1 = tensor_correlation(fp1)
fc2 = tensor_correlation(fp2)
@test local_bound(fp1) local_bound(fp2)
@test local_bound(fc1) local_bound(fc2)
@test local_bound(fc1) local_bound(fp1)
end
@test local_bound([1 1; 1 -1]; marg = false) == 2

Expand Down

0 comments on commit 5072d49

Please sign in to comment.