Skip to content

Commit

Permalink
Fix for iterators with shape
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Sep 26, 2019
1 parent c771e5d commit 63ff9cc
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export DenseAxisArray, SparseAxisArray
include("DenseAxisArray.jl")
include("SparseAxisArray.jl")
include("generate_container.jl")
include("vectorized_product_iterator.jl")
include("nested_iterator.jl")
include("container.jl")
include("macro.jl")

Expand Down
2 changes: 0 additions & 2 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

include("nested_iterator.jl")

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
Expand Down
12 changes: 6 additions & 6 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const ArrayIndices{N} = Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}}
const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}}
container(f::Function, indices) = container(f, indices, default_container(indices))
default_container(::ArrayIndices) = Array
function container(f::Function, indices::ArrayIndices, ::Type{Array})
Expand All @@ -10,14 +10,14 @@ function _oneto(indices)
end
error("Index set for array is not one-based interval.")
end
function container(f::Function, indices::Iterators.ProductIterator,
function container(f::Function, indices::VectorizedProductIterator,
::Type{Array})
container(f, Iterators.ProductIterator(_oneto.(indices.iterators)), Array)
container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array)
end
default_container(::Iterators.ProductIterator) = DenseAxisArray
function container(f::Function, indices::Iterators.ProductIterator,
default_container(::VectorizedProductIterator) = DenseAxisArray
function container(f::Function, indices::VectorizedProductIterator,
::Type{DenseAxisArray})
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...)
return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...)
end
default_container(::NestedIterator) = SparseAxisArray
function container(f::Function, indices,
Expand Down
6 changes: 3 additions & 3 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ function _build_ref_sets(_error::Function, expr)
esc_idxvars = esc.(idxvars)
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
if condition == :()
indices = :(Containers.NestedIterator(($(idxfuns...),)))
indices = :(Containers.nested($(idxfuns...)))
else
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition)))
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun))
indices = :(Containers.nested($(idxfuns...); condition = $condition_fun))
end
else
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...))))
indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...)))
end
return idxvars, indices
end
Expand Down
4 changes: 3 additions & 1 deletion src/Containers/nested_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ struct NestedIterator{T}
iterators::T # Tuple of functions
condition::Function
end
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true)
function nested(iterators...; condition = (args...) -> true)
return NestedIterator(iterators, condition)
end
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
Expand Down
32 changes: 32 additions & 0 deletions src/Containers/vectorized_product_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
struct VectorizedProductIterator{T}
prod::Iterators.ProductIterator{T}
end
Same as `Base.Iterators.ProductIterator` except that it is independent
on the `IteratorSize` of the elements of `prod.iterators`.
For instance:
* `size(Iterators.product(1, 2))` is `tuple()` while
`size(VectorizedProductIterator(1, 2))` is `(1, 1)`.
* `size(Iterators.product(ones(2, 3)))` is `(2, 3)` while
`size(VectorizedProductIterator(ones(2, 3)))` is `(1, 1)`.
"""
struct VectorizedProductIterator{T}
prod::Iterators.ProductIterator{T}
end
function vectorized_product(iterators...)
return VectorizedProductIterator(Iterators.product(iterators...))
end
function Base.IteratorSize(::Type{<:VectorizedProductIterator{<:Tuple{Vararg{Any, N}}}}) where N
return Base.HasShape{N}()
end
Base.IteratorEltype(::Type{<:VectorizedProductIterator}) = Base.EltypeUnknown()
Base.size(it::VectorizedProductIterator) = _prod_size(it.prod.iterators)
_prod_size(::Tuple{}) = ()
_prod_size(t::Tuple) = (length(t[1]), _prod_size(Base.tail(t))...)
Base.axes(it::VectorizedProductIterator) = _prod_indices(it.prod.iterators)
_prod_indices(::Tuple{}) = ()
_prod_indices(t::Tuple) = (Base.OneTo(length(t[1])), _prod_indices(Base.tail(t))...)
Base.ndims(it::VectorizedProductIterator) = length(axes(it))
Base.length(it::VectorizedProductIterator) = prod(size(it))
Base.iterate(it::VectorizedProductIterator, args...) = iterate(it.prod, args...)
2 changes: 1 addition & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ above but does it without using the `@variable` macro
x = JuMP.Containers.container(i -> begin
info = VariableInfo(false, NaN, true, ub[i], false, NaN, false, NaN, false, false)
x[i] = JuMP.add_variable(model, JuMP.build_variable(error, info), "x[\$i]")
end, Base.Iterators.product(keys(ub)))
end, JuMP.Containers.vectorized_product(keys(ub)))
# output
1-dimensional DenseAxisArray{VariableRef,1,...} with index sets:
Expand Down
6 changes: 6 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ using JuMP.Containers
@test x isa Containers.DenseAxisArray{Int, 1}
Containers.@container(x[i = 2:3, j = 1:2], i + j)
@test x isa Containers.DenseAxisArray{Int, 2}
Containers.@container(x[4], 0.0)
@test x isa Containers.DenseAxisArray{Float64, 1}
Containers.@container(x[4, 5], 0)
@test x isa Containers.DenseAxisArray{Int, 2}
Containers.@container(x[4, 1:3, 5], 0)
@test x isa Containers.DenseAxisArray{Int, 3}
end
@testset "SparseAxisArray" begin
Containers.@container(x[i = 1:3, j = 1:i], i + j)
Expand Down

0 comments on commit 63ff9cc

Please sign in to comment.