diff --git a/src/nonlocal.jl b/src/nonlocal.jl index eddcace..ae6fc89 100644 --- a/src/nonlocal.jl +++ b/src/nonlocal.jl @@ -45,13 +45,11 @@ function _local_bound_single(chunk, outs::NTuple{2, Int}, ins::NTuple{2, Int}, s ind = digits(chunk[1] - 1; base = ob, pad = ib) offset = Vector(1 .+ ob * (0:(ib - 1))) offset_ind = zeros(Int, ib) - Galice = zeros(T, oa * ia, 1) - maxvec = zeros(T, 1, ia) + Galice = zeros(T, oa * ia) @inbounds for _ in chunk[1]:chunk[2] offset_ind .= ind .+ offset @views sum!(Galice, squareG[:, offset_ind]) - squareGalice = reshape(Galice, oa, ia) - temp_score = sum(maximum!(maxvec, squareGalice)) + temp_score = _maxcols!(Galice, oa, ia) score = max(score, temp_score) _update_odometer!(ind, ob) end @@ -63,8 +61,7 @@ function _local_bound_single(chunk, outs::NTuple{N, Int}, ins::NTuple{N, Int}, s score = typemin(T) bases = reduce(vcat, [outs[i] * ones(Int, ins[i]) for i in 2:length(ins)]) ind = _digits_mixed_basis(chunk[1] - 1, bases) - Galice = zeros(T, outs[1] * ins[1], 1) - maxvec = zeros(T, 1, ins[1]) + Galice = zeros(T, outs[1] * ins[1]) b = zeros(Int, N - 1) sizes = (outs[2:N]..., ins[2:N]...) prodsizes = ones(Int, 2 * (N - 1)) @@ -83,12 +80,11 @@ function _local_bound_single(chunk, outs::NTuple{N, Int}, ins::NTuple{N, Int}, s for i in 1:(N - 1) by[i + N - 1] = y[i] - 1 end - for x in 1:ins[1], a in 1:outs[1] - Galice[a + (x - 1) * outs[1]] += squareG[a + (x - 1) * outs[1], linearindex(by)] + for i in 1:outs[1]*ins[1] + Galice[i] += squareG[i, linearindex(by)] end end - squareGalice = reshape(Galice, outs[1], ins[1]) - temp_score = sum(maximum!(maxvec, squareGalice)) + temp_score = _maxcols!(Galice, outs[1], ins[1]) score = max(score, temp_score) _update_odometer!(ind, bases) end @@ -96,6 +92,22 @@ function _local_bound_single(chunk, outs::NTuple{N, Int}, ins::NTuple{N, Int}, s return score end +#sum(maximum(v, dims = 1)), with v interpreted as a oa x ia matrix +function _maxcols!(v, oa, ia) + for x = 1:ia + for a = 2:oa + if v[a + (x-1)*oa] > v[1 + (x-1)*oa] + v[1 + (x-1)*oa] = v[a + (x-1)*oa] + end + end + end + temp_score = v[1] + for x = 2:ia + temp_score += v[1 + (x-1)*oa] + end + return temp_score +end + """ partition(n::Integer, k::Integer)