Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create containers with map instead of for loops #2070

Merged
merged 9 commits into from
Oct 4, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/src/constraints.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ following.
One way of adding a group of constraints compactly is the following:
```jldoctest constraint_arrays; setup=:(model=Model(); @variable(model, x))
julia> @constraint(model, con[i = 1:3], i * x <= i + 1)
3-element Array{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,1}:
3-element Array{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},1}:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty bad printing experience.

con[1] : x <= 2.0
con[2] : 2 x <= 3.0
con[3] : 3 x <= 4.0
Expand All @@ -264,7 +264,7 @@ julia> con[1]
con[1] : x <= 2.0

julia> con[2:3]
2-element Array{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,1}:
2-element Array{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},1}:
con[2] : 2 x <= 3.0
con[3] : 3 x <= 4.0
```
Expand All @@ -273,7 +273,7 @@ Anonymous containers can also be constructed by dropping the name (e.g. `con`)
before the square brackets:
```jldoctest constraint_arrays
julia> @constraint(model, [i = 1:2], i * x <= i + 1)
2-element Array{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,1}:
2-element Array{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},1}:
x <= 2.0
2 x <= 3.0
```
Expand All @@ -294,10 +294,10 @@ variables.

```jldoctest constraint_jumparrays; setup=:(model=Model(); @variable(model, x))
julia> @constraint(model, con[i = 1:2, j = 2:3], i * x <= j + 1)
2-dimensional DenseAxisArray{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,2,...} with index sets:
Dimension 1, 1:2
2-dimensional DenseAxisArray{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},2,...} with index sets:
Dimension 1, Base.OneTo(2)
Dimension 2, 2:3
And data, a 2×2 Array{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,2}:
And data, a 2×2 Array{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},2}:
con[1,2] : x <= 3.0 con[1,3] : x <= 4.0
con[2,2] : 2 x <= 3.0 con[2,3] : 2 x <= 4.0
```
Expand All @@ -311,7 +311,7 @@ similar to the [syntax for constructing](@ref variable_sparseaxisarrays) a

```jldoctest constraint_jumparrays; setup=:(model=Model(); @variable(model, x))
julia> @constraint(model, con[i = 1:2, j = 1:2; i != j], i * x <= j + 1)
JuMP.Containers.SparseAxisArray{ConstraintRef{Model,C,Shape} where Shape<:AbstractShape where C,2,Tuple{Any,Any}} with 2 entries:
JuMP.Containers.SparseAxisArray{ConstraintRef{Model,MathOptInterface.ConstraintIndex{MathOptInterface.ScalarAffineFunction{Float64},MathOptInterface.LessThan{Float64}},ScalarShape},2,Tuple{Int64,Int64}} with 2 entries:
[1, 2] = con[1,2] : x <= 3.0
[2, 1] = con[2,1] : 2 x <= 2.0
```
Expand Down
6 changes: 3 additions & 3 deletions docs/src/variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ return a `DenseAxisArray`. For example:
```jldoctest variables_jump_arrays; setup=:(model=Model())
julia> @variable(model, x[1:2, [:A,:B]])
2-dimensional DenseAxisArray{VariableRef,2,...} with index sets:
Dimension 1, 1:2
Dimension 1, Base.OneTo(2)
Dimension 2, Symbol[:A, :B]
And data, a 2×2 Array{VariableRef,2}:
x[1,A] x[1,B]
Expand Down Expand Up @@ -371,7 +371,7 @@ For example, this applies when indices have a dependence upon previous
indices (called *triangular indexing*). JuMP supports this as follows:
```jldoctest; setup=:(model=Model())
julia> @variable(model, x[i=1:2, j=i:2])
JuMP.Containers.SparseAxisArray{VariableRef,2,Tuple{Any,Any}} with 3 entries:
JuMP.Containers.SparseAxisArray{VariableRef,2,Tuple{Int64,Int64}} with 3 entries:
[1, 2] = x[1,2]
[2, 2] = x[2,2]
[1, 1] = x[1,1]
Expand All @@ -382,7 +382,7 @@ syntax appends a comparison check that depends upon the named indices and is
separated from the indices by a semi-colon (`;`). For example:
```jldoctest; setup=:(model=Model())
julia> @variable(model, x[i=1:4; mod(i, 2)==0])
JuMP.Containers.SparseAxisArray{VariableRef,1,Tuple{Any}} with 2 entries:
JuMP.Containers.SparseAxisArray{VariableRef,1,Tuple{Int64}} with 2 entries:
[4] = x[4]
[2] = x[2]
```
Expand Down
4 changes: 4 additions & 0 deletions src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@ export DenseAxisArray, SparseAxisArray
include("DenseAxisArray.jl")
include("SparseAxisArray.jl")
include("generate_container.jl")
include("vectorized_product_iterator.jl")
include("nested_iterator.jl")
include("container.jl")
include("macro.jl")

end
41 changes: 41 additions & 0 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}}
blegat marked this conversation as resolved.
Show resolved Hide resolved
container(f::Function, indices) = container(f, indices, default_container(indices))
default_container(::ArrayIndices) = Array
function container(f::Function, indices::ArrayIndices, ::Type{Array})
return map(I -> f(I...), indices)
end
function _oneto(indices)
if indices isa UnitRange{Int} && indices == 1:length(indices)
return Base.OneTo(length(indices))
end
error("Index set for array is not one-based interval.")
end
function container(f::Function, indices::VectorizedProductIterator,
blegat marked this conversation as resolved.
Show resolved Hide resolved
::Type{Array})
container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array)
end
default_container(::VectorizedProductIterator) = DenseAxisArray
function container(f::Function, indices::VectorizedProductIterator,
::Type{DenseAxisArray})
return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...)
end
default_container(::NestedIterator) = SparseAxisArray
function container(f::Function, indices,
::Type{SparseAxisArray})
mappings = map(I -> I => f(I...), indices)
data = Dict(mappings)
if length(mappings) != length(data)
unique_indices = Set()
duplicate = nothing
for index in indices
if index in unique_indices
duplicate = index
break
end
push!(unique_indices, index)
end
# TODO compute idx
error("Repeated index ", duplicate, ". Index sets must have unique elements.")
end
return SparseAxisArray(Dict(data))
end
187 changes: 187 additions & 0 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
using Base.Meta

"""
_extract_kw_args(args)

Process the arguments to a macro, separating out the keyword arguments.
Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `parse_container`.
"""
function _extract_kw_args(args)
kw_args = filter(x -> isexpr(x, :(=)) && x.args[1] != :container , collect(args))
flat_args = filter(x->!isexpr(x, :(=)), collect(args))
requested_container = :Auto
for kw in args
if isexpr(kw, :(=)) && kw.args[1] == :container
requested_container = kw.args[2]
end
end
return flat_args, kw_args, requested_container
end

function _try_parse_idx_set(arg::Expr)
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively.
if arg.head === :kw || arg.head === :(=)
@assert length(arg.args) == 2
return true, arg.args[1], arg.args[2]
elseif isexpr(arg, :call) && arg.args[1] === :in
return true, arg.args[2], arg.args[3]
else
return false, nothing, nothing
end
end
function _explicit_oneto(index_set)
s = Meta.isexpr(index_set,:escape) ? index_set.args[1] : index_set
if Meta.isexpr(s,:call) && length(s.args) == 3 && s.args[1] == :(:) && s.args[2] == 1
return :(Base.OneTo($index_set))
else
return index_set
end
end

function _expr_is_splat(ex::Expr)
if ex.head == :(...)
return true
elseif ex.head == :escape
return _expr_is_splat(ex.args[1])
end
return false
end
_expr_is_splat(::Any) = false

"""
_parse_ref_sets(expr::Expr)

Helper function for macros to construct container objects. Takes an `Expr` that
specifies the container, e.g. `:(x[i=1:3,[:red,:blue],k=S; i+k <= 6])`, and
returns:

1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]`
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present
"""
function _parse_ref_sets(_error::Function, expr::Expr)
c = copy(expr)
idxvars = Any[]
idxsets = Any[]
# On 0.7, :(t[i;j]) is a :ref, while t[i,j;j] is a :typed_vcat.
blegat marked this conversation as resolved.
Show resolved Hide resolved
# In both cases :t is the first arg.
if isexpr(c, :typed_vcat) || isexpr(c, :ref)
popfirst!(c.args)
end
condition = :()
if isexpr(c, :vcat) || isexpr(c, :typed_vcat)
# Parameters appear as plain args at the end.
if length(c.args) > 2
_error("Unsupported syntax $c.")
elseif length(c.args) == 2
condition = pop!(c.args)
end # else no condition.
elseif isexpr(c, :ref) || isexpr(c, :vect)
# Parameters appear at the front.
if isexpr(c.args[1], :parameters)
if length(c.args[1].args) != 1
_error("Invalid syntax: $c. Multiple semicolons are not " *
"supported.")
end
condition = popfirst!(c.args).args[1]
end
end
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) || isexpr(c, :ref)
if isexpr(c.args[1], :parameters)
@assert length(c.args[1].args) == 1
condition = popfirst!(c.args).args[1]
end # else no condition.
end

for s in c.args
parse_done = false
if isa(s, Expr)
parse_done, idxvar, _idxset = _try_parse_idx_set(s::Expr)
if parse_done
idxset = esc(_idxset)
end
end
if !parse_done # No index variable specified
idxvar = gensym()
idxset = esc(s)
end
push!(idxvars, idxvar)
push!(idxsets, idxset)
end
return idxvars, idxsets, condition
end
_parse_ref_sets(_error::Function, expr) = (Any[], Any[], :())

"""
_build_ref_sets(_error::Function, expr)

Helper function for macros to construct container objects. Takes an `Expr` that
specifies the container, e.g. `:(x[i=1:3,[:red,:blue],k=S; i+k <= 6])`, and
returns:

1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
2. `indices`: Iterators over the indices indexing, e.g.
`Constainers.NestedIterators((1:3, [:red,:blue], S), (i, ##..., k) -> i + k <= 6)`.
"""
function _build_ref_sets(_error::Function, expr)
idxvars, idxsets, condition = _parse_ref_sets(_error, expr)
if any(_expr_is_splat.(idxsets))
_error("cannot use splatting operator `...` in the definition of an index set.")
end
has_dependent = has_dependent_sets(idxvars, idxsets)
if has_dependent || condition != :()
esc_idxvars = esc.(idxvars)
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
if condition == :()
indices = :(Containers.nested($(idxfuns...)))
else
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition)))
indices = :(Containers.nested($(idxfuns...); condition = $condition_fun))
end
else
indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...)))
end
return idxvars, indices
end

function container_code(idxvars, indices, code, requested_container)
if isempty(idxvars)
return code
end
if !(requested_container in [:Auto, :Array, :DenseAxisArray, :SparseAxisArray])
# We do this two-step interpolation, first into the string, and then
# into the expression because interpolating into a string inside an
# expression has scoping issues.
error_message = "Invalid container type $requested_container. Must be " *
"Auto, Array, DenseAxisArray, or SparseAxisArray."
return :(error($error_message))
end
if requested_container == :DenseAxisArray
requested_container = :(JuMP.Containers.DenseAxisArray)
elseif requested_container == :SparseAxisArray
requested_container = :(JuMP.Containers.SparseAxisArray)
end
esc_idxvars = esc.(idxvars)
func = :(($(esc_idxvars...),) -> $code)
if requested_container == :Auto
return :(Containers.container($func, $indices))
else
return :(Containers.container($func, $indices, $requested_container))
end
end
function parse_container(_error, var, value, requested_container)
idxvars, indices = _build_ref_sets(_error, var)
return container_code(idxvars, indices, value, requested_container)
end

macro container(args...)
blegat marked this conversation as resolved.
Show resolved Hide resolved
args, kw_args, requested_container = _extract_kw_args(args)
@assert length(args) == 2
@assert isempty(kw_args)
var, value = args
name = var.args[1]
code = parse_container(error, var, esc(value), requested_container)
return :($(esc(name)) = $code)
end
67 changes: 67 additions & 0 deletions src/Containers/nested_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
struct NestedIterator{T}
iterators::T # Tuple of functions
condition::Function
end

Iterators over the tuples that are produced by a nested for loop.
For instance, if `length(iterators) == 3` , this corresponds to the tuples
`(i1, i2, i3)` produced by:
```
for i1 in iterators[1]()
for i2 in iterator[2](i1)
for i3 in iterator[3](i1, i2)
if condition(i1, i2, i3)
# produces (i1, i2, i3)
end
end
end
end
```
"""
struct NestedIterator{T}
blegat marked this conversation as resolved.
Show resolved Hide resolved
iterators::T # Tuple of functions
condition::Function
end
function nested(iterators...; condition = (args...) -> true)
return NestedIterator(iterators, condition)
end
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
if elem_state === nothing
return nothing
end
elem, state = elem_state
elems_states = first_iterate(
it, i + 1, (elems..., elem),
(states..., (iterator, state, elem)))
if elems_states !== nothing
return elems_states
end
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state))
end
function first_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
if it.condition(elems...)
return elems, states
else
return nothing
end
end
iterator = it.iterators[i](elems...)
return next_iterate(it, i, elems, states, iterator, iterate(iterator))
end
function tail_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
return nothing
end
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states)
if next !== nothing
return next
end
iterator = states[i][1]
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2]))
end
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple())
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states)
Loading