Skip to content

Commit

Permalink
optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jan 2, 2025
1 parent 5072d49 commit eaee958
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function _local_bound_probability(G::Array{T,N2}) where {T<:Real,N2}
return score
end

function _local_bound_probability_single(chunk, outs::NTuple{2,Int}, ins::NTuple{2,Int}, squareG::Array{T,2}) where {T}
function _local_bound_probability_single2(chunk, outs::NTuple{2,Int}, ins::NTuple{2,Int}, squareG::Array{T,2}) where {T}
oa, ob = outs
ia, ib = ins
score = typemin(T)
Expand Down Expand Up @@ -154,29 +154,34 @@ function _local_bound_probability_single(chunk, outs::NTuple{N,Int}, ins::NTuple
end
linearindex(v) = 1 + dot(v, prodsizes)
by = zeros(Int, 2 * (N - 1))
ins_region = CartesianIndices(ins[2:N])
offset_ind = zeros(Int, prod(ins[2:N]))
@inbounds for _ chunk[1]:chunk[2]
fill!(Galice, 0)
for y CartesianIndices(ins[2:N])
counter = 0
for y ins_region
by[1] = ind[y[1]]
for i 2:length(y)
by[i] = ind[y[i]+ins[i]]
end
for i 1:N-1
by[i+N-1] = y[i] - 1
end
lin_by = linearindex(by)
for i 1:outs[1]*ins[1]
Galice[i] += squareG[i, lin_by]
end
counter += 1
offset_ind[counter] = linearindex(by)
end
@views sum!(Galice, squareG[:, offset_ind])
temp_score = _maxcols!(Galice, outs[1], ins[1])
score = max(score, temp_score)
_update_odometer!(ind, bases)
end
return score
end

#sum(maximum(v, dims = 1)), with v interpreted as a oa x ia matrix
"""
_maxcols!(A::Array, n::Integer, m::Integer)
Computes `sum(maximum(A, dims = 1))`, with `A` interpreted as an `n` by `m` matrix. `A` is destroyed.
"""
function _maxcols!(v, oa, ia)
for x 1:ia
for a 2:oa
Expand Down Expand Up @@ -223,7 +228,7 @@ end
function _digits_mixed_basis(ind, bases)
N = length(bases)
digits = zeros(Int, N)
for i N:-1:1
@inbounds for i N:-1:1
digits[i] = mod(ind, bases[i])
ind = div(ind, bases[i])
end
Expand All @@ -234,7 +239,7 @@ function _update_odometer!(ind::AbstractVector{<:Integer}, bases::AbstractVector
ind[1] += 1
d = length(ind)

for i 1:d
@inbounds for i 1:d
if ind[i] bases[i]
ind[i] = 0
i < d ? ind[i+1] += 1 : return
Expand All @@ -248,7 +253,7 @@ function _update_odometer!(ind::AbstractVector{<:Integer}, bases::Integer)
ind[1] += 1
d = length(ind)

for i 1:d
@inbounds for i 1:d
if ind[i] bases
ind[i] = 0
i < d ? ind[i+1] += 1 : return
Expand Down

0 comments on commit eaee958

Please sign in to comment.