diff --git a/src/Containers/Containers.jl b/src/Containers/Containers.jl index 32d11c4f7fb..7f91aca1362 100644 --- a/src/Containers/Containers.jl +++ b/src/Containers/Containers.jl @@ -26,6 +26,8 @@ export DenseAxisArray, SparseAxisArray include("DenseAxisArray.jl") include("SparseAxisArray.jl") include("generate_container.jl") +include("vectorized_product_iterator.jl") +include("nested_iterator.jl") include("container.jl") include("macro.jl") diff --git a/src/Containers/SparseAxisArray.jl b/src/Containers/SparseAxisArray.jl index 001bdcaa522..bb746b230ad 100644 --- a/src/Containers/SparseAxisArray.jl +++ b/src/Containers/SparseAxisArray.jl @@ -3,8 +3,6 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -include("nested_iterator.jl") - """ struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N} data::Dict{K,T} diff --git a/src/Containers/container.jl b/src/Containers/container.jl index bef17066c5e..5a9ac5ddb9e 100644 --- a/src/Containers/container.jl +++ b/src/Containers/container.jl @@ -1,4 +1,4 @@ -const ArrayIndices{N} = Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}} +const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}} container(f::Function, indices) = container(f, indices, default_container(indices)) default_container(::ArrayIndices) = Array function container(f::Function, indices::ArrayIndices, ::Type{Array}) @@ -10,14 +10,14 @@ function _oneto(indices) end error("Index set for array is not one-based interval.") end -function container(f::Function, indices::Iterators.ProductIterator, +function container(f::Function, indices::VectorizedProductIterator, ::Type{Array}) - container(f, Iterators.ProductIterator(_oneto.(indices.iterators)), Array) + container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array) end -default_container(::Iterators.ProductIterator) = DenseAxisArray -function container(f::Function, indices::Iterators.ProductIterator, +default_container(::VectorizedProductIterator) = DenseAxisArray +function container(f::Function, indices::VectorizedProductIterator, ::Type{DenseAxisArray}) - return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...) + return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...) end default_container(::NestedIterator) = SparseAxisArray function container(f::Function, indices, diff --git a/src/Containers/macro.jl b/src/Containers/macro.jl index 14380b36f30..68df9f63d66 100644 --- a/src/Containers/macro.jl +++ b/src/Containers/macro.jl @@ -135,13 +135,13 @@ function _build_ref_sets(_error::Function, expr) esc_idxvars = esc.(idxvars) idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)] if condition == :() - indices = :(Containers.NestedIterator(($(idxfuns...),))) + indices = :(Containers.nested($(idxfuns...))) else condition_fun = :(($(esc_idxvars...),) -> $(esc(condition))) - indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun)) + indices = :(Containers.nested($(idxfuns...); condition = $condition_fun)) end else - indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...)))) + indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...))) end return idxvars, indices end diff --git a/src/Containers/nested_iterator.jl b/src/Containers/nested_iterator.jl index 706a130e459..165f0b5bb65 100644 --- a/src/Containers/nested_iterator.jl +++ b/src/Containers/nested_iterator.jl @@ -23,7 +23,9 @@ struct NestedIterator{T} iterators::T # Tuple of functions condition::Function end -NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true) +function nested(iterators...; condition = (args...) -> true) + return NestedIterator(iterators, condition) +end Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown() function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state) diff --git a/src/Containers/vectorized_product_iterator.jl b/src/Containers/vectorized_product_iterator.jl new file mode 100644 index 00000000000..2e23960fd00 --- /dev/null +++ b/src/Containers/vectorized_product_iterator.jl @@ -0,0 +1,32 @@ +""" + struct VectorizedProductIterator{T} + prod::Iterators.ProductIterator{T} + end + +Same as `Base.Iterators.ProductIterator` except that it is independent +on the `IteratorSize` of the elements of `prod.iterators`. +For instance: +* `size(Iterators.product(1, 2))` is `tuple()` while + `size(VectorizedProductIterator(1, 2))` is `(1, 1)`. +* `size(Iterators.product(ones(2, 3)))` is `(2, 3)` while + `size(VectorizedProductIterator(ones(2, 3)))` is `(1, 1)`. +""" +struct VectorizedProductIterator{T} + prod::Iterators.ProductIterator{T} +end +function vectorized_product(iterators...) + return VectorizedProductIterator(Iterators.product(iterators...)) +end +function Base.IteratorSize(::Type{<:VectorizedProductIterator{<:Tuple{Vararg{Any, N}}}}) where N + return Base.HasShape{N}() +end +Base.IteratorEltype(::Type{<:VectorizedProductIterator}) = Base.EltypeUnknown() +Base.size(it::VectorizedProductIterator) = _prod_size(it.prod.iterators) +_prod_size(::Tuple{}) = () +_prod_size(t::Tuple) = (length(t[1]), _prod_size(Base.tail(t))...) +Base.axes(it::VectorizedProductIterator) = _prod_indices(it.prod.iterators) +_prod_indices(::Tuple{}) = () +_prod_indices(t::Tuple) = (Base.OneTo(length(t[1])), _prod_indices(Base.tail(t))...) +Base.ndims(it::VectorizedProductIterator) = length(axes(it)) +Base.length(it::VectorizedProductIterator) = prod(size(it)) +Base.iterate(it::VectorizedProductIterator, args...) = iterate(it.prod, args...) diff --git a/src/macros.jl b/src/macros.jl index 15661ba9485..0dc89e969e0 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -1098,7 +1098,7 @@ above but does it without using the `@variable` macro x = JuMP.Containers.container(i -> begin info = VariableInfo(false, NaN, true, ub[i], false, NaN, false, NaN, false, false) x[i] = JuMP.add_variable(model, JuMP.build_variable(error, info), "x[\$i]") - end, Base.Iterators.product(keys(ub))) + end, JuMP.Containers.vectorized_product(keys(ub))) # output 1-dimensional DenseAxisArray{VariableRef,1,...} with index sets: diff --git a/test/Containers/macro.jl b/test/Containers/macro.jl index 8e953ae94ed..7fe774a8170 100644 --- a/test/Containers/macro.jl +++ b/test/Containers/macro.jl @@ -14,6 +14,12 @@ using JuMP.Containers @test x isa Containers.DenseAxisArray{Int, 1} Containers.@container(x[i = 2:3, j = 1:2], i + j) @test x isa Containers.DenseAxisArray{Int, 2} + Containers.@container(x[4], 0.0) + @test x isa Containers.DenseAxisArray{Float64, 1} + Containers.@container(x[4, 5], 0) + @test x isa Containers.DenseAxisArray{Int, 2} + Containers.@container(x[4, 1:3, 5], 0) + @test x isa Containers.DenseAxisArray{Int, 3} end @testset "SparseAxisArray" begin Containers.@container(x[i = 1:3, j = 1:i], i + j)