Skip to content

Commit

Permalink
remove allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jan 1, 2025
1 parent d18760d commit 32eb613
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ function _local_bound_single(chunk, outs::NTuple{2, Int}, ins::NTuple{2, Int}, s
ind = digits(chunk[1] - 1; base = ob, pad = ib)
offset = Vector(1 .+ ob * (0:(ib - 1)))
offset_ind = zeros(Int, ib)
Galice = zeros(T, oa * ia, 1)
maxvec = zeros(T, 1, ia)
Galice = zeros(T, oa * ia)
@inbounds for _ in chunk[1]:chunk[2]
offset_ind .= ind .+ offset
@views sum!(Galice, squareG[:, offset_ind])
squareGalice = reshape(Galice, oa, ia)
temp_score = sum(maximum!(maxvec, squareGalice))
temp_score = _maxcols!(Galice, oa, ia)
score = max(score, temp_score)
_update_odometer!(ind, ob)
end
Expand All @@ -63,8 +61,7 @@ function _local_bound_single(chunk, outs::NTuple{N, Int}, ins::NTuple{N, Int}, s
score = typemin(T)
bases = reduce(vcat, [outs[i] * ones(Int, ins[i]) for i in 2:length(ins)])
ind = _digits_mixed_basis(chunk[1] - 1, bases)
Galice = zeros(T, outs[1] * ins[1], 1)
maxvec = zeros(T, 1, ins[1])
Galice = zeros(T, outs[1] * ins[1])
b = zeros(Int, N - 1)
sizes = (outs[2:N]..., ins[2:N]...)
prodsizes = ones(Int, 2 * (N - 1))
Expand All @@ -83,19 +80,34 @@ function _local_bound_single(chunk, outs::NTuple{N, Int}, ins::NTuple{N, Int}, s
for i in 1:(N - 1)
by[i + N - 1] = y[i] - 1
end
for x in 1:ins[1], a in 1:outs[1]
Galice[a + (x - 1) * outs[1]] += squareG[a + (x - 1) * outs[1], linearindex(by)]
for i in 1:outs[1]*ins[1]
Galice[i] += squareG[i, linearindex(by)]
end
end
squareGalice = reshape(Galice, outs[1], ins[1])
temp_score = sum(maximum!(maxvec, squareGalice))
temp_score = _maxcols!(Galice, outs[1], ins[1])
score = max(score, temp_score)
_update_odometer!(ind, bases)
end

return score
end

# Computes sum(maximum(m, dims = 1)), where m is `v` read as a column-major oa x ia matrix.
# Works in place: the first entry of every column of `v` is overwritten with that
# column's maximum; the remaining entries are left untouched.
function _maxcols!(v, oa, ia)
    # Pass 1: fold each column's maximum into the column's first slot.
    for col in 0:(ia - 1)
        head = 1 + col * oa
        for row in (head + 1):(head + oa - 1)
            if v[head] < v[row]
                v[head] = v[row]
            end
        end
    end
    # Pass 2: add up the per-column maxima, left to right.
    total = v[1]
    for col in 1:(ia - 1)
        total += v[1 + col * oa]
    end
    return total
end

"""
partition(n::Integer, k::Integer)
Expand Down

0 comments on commit 32eb613

Please sign in to comment.