Skip to content

Commit

Permalink
type stability
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jan 2, 2025
1 parent eaee958 commit 2ee5774
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function _local_bound_correlation(G::Array{T,N}; marg::Bool = true) where {T<:Re
end
squareG = reshape(G, ins[1], prod(ins[2:N]))

chunks = _partition(prod((outs .^ (ins .- marg))[2:N]), Threads.nthreads())
chunks = _partition(prod((outs .^ (ins.-marg))[2:N]), Threads.nthreads())
ins2 = ins
squareG2 = squareG #workaround for https://github.com/JuliaLang/julia/issues/15276
tasks = map(chunks) do chunk
Expand Down Expand Up @@ -104,27 +104,25 @@ function _local_bound_probability(G::Array{T,N2}) where {T<:Real,N2}
largest_party = argmax(num_strategies)
if largest_party != 1
vperm = [largest_party; 2:largest_party-1; 1; largest_party+1:N]
outs = outs[vperm]
ins = ins[vperm]
bigperm::NTuple{N2,Int} = Tuple([vperm; vperm .+ N])
G = permutedims(G, bigperm)
outs::NTuple{N,Int} = outs[vperm]
ins::NTuple{N,Int} = ins[vperm]
G = permutedims(G, [vperm; vperm .+ N])
end
perm::NTuple{N2,Int} = Tuple([1; N + 1; 2:N; N+2:2N])
permutedG = permutedims(G, perm)
permutedG = permutedims(G, [1; N + 1; 2:N; N+2:2N])
squareG = reshape(permutedG, outs[1] * ins[1], prod(outs[2:N]) * prod(ins[2:N]))

chunks = _partition(prod((outs .^ ins)[2:N]), Threads.nthreads())
outs2 = outs
ins2 = ins
squareG2 = squareG #workaround for https://github.com/JuliaLang/julia/issues/15276
tasks = map(chunks) do chunk
Threads.@spawn _local_bound_probability_single(chunk, outs2, ins2, squareG2)
Threads.@spawn _local_bound_probability_core(chunk, outs2, ins2, squareG2)
end
score = maximum(fetch.(tasks))
return score
end

function _local_bound_probability_single2(chunk, outs::NTuple{2,Int}, ins::NTuple{2,Int}, squareG::Array{T,2}) where {T}
function _local_bound_probability_core(chunk, outs::NTuple{2,Int}, ins::NTuple{2,Int}, squareG::Array{T,2}) where {T}
oa, ob = outs
ia, ib = ins
score = typemin(T)
Expand All @@ -142,7 +140,7 @@ function _local_bound_probability_single2(chunk, outs::NTuple{2,Int}, ins::NTupl
return score
end

function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N}
function _local_bound_probability_core(chunk, outs::NTuple{N,Int}, ins::NTuple{N,Int}, squareG::Array{T,2}) where {T,N}
score = typemin(T)
bases = reduce(vcat, [fill(outs[i], ins[i]) for i 2:length(ins)])
ind = _digits_mixed_basis(chunk[1] - 1, bases)
Expand Down

0 comments on commit 2ee5774

Please sign in to comment.