-
-
Notifications
You must be signed in to change notification settings - Fork 399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create containers with map instead of for loops #2070
Merged
Merged
Changes from 4 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
8ce7535
Move Containers tests to Containers folder
blegat 05f4242
Create containers with map instead of for loops
blegat c771e5d
Update docstrings
blegat 63ff9cc
Fix for iterators with shape
blegat 974dfee
Address comments
blegat 61d5e1a
Add a longer comment
blegat bc881f7
Add axis_constraints.jl benchmark
blegat 37f9a0c
Faster SparseAxisArrays construction in macros
blegat 1192fd8
Type stability fixes in NestedIterator
blegat File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}} | ||
blegat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
container(f::Function, indices) = container(f, indices, default_container(indices)) | ||
default_container(::ArrayIndices) = Array | ||
function container(f::Function, indices::ArrayIndices, ::Type{Array}) | ||
return map(I -> f(I...), indices) | ||
end | ||
function _oneto(indices) | ||
if indices isa UnitRange{Int} && indices == 1:length(indices) | ||
return Base.OneTo(length(indices)) | ||
end | ||
error("Index set for array is not one-based interval.") | ||
end | ||
function container(f::Function, indices::VectorizedProductIterator, | ||
blegat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
::Type{Array}) | ||
container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array) | ||
end | ||
default_container(::VectorizedProductIterator) = DenseAxisArray | ||
function container(f::Function, indices::VectorizedProductIterator, | ||
::Type{DenseAxisArray}) | ||
return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...) | ||
end | ||
default_container(::NestedIterator) = SparseAxisArray | ||
function container(f::Function, indices, | ||
::Type{SparseAxisArray}) | ||
mappings = map(I -> I => f(I...), indices) | ||
data = Dict(mappings) | ||
if length(mappings) != length(data) | ||
unique_indices = Set() | ||
duplicate = nothing | ||
for index in indices | ||
if index in unique_indices | ||
duplicate = index | ||
break | ||
end | ||
push!(unique_indices, index) | ||
end | ||
# TODO compute idx | ||
error("Repeated index ", duplicate, ". Index sets must have unique elements.") | ||
end | ||
return SparseAxisArray(Dict(data)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
using Base.Meta | ||
|
||
""" | ||
_extract_kw_args(args) | ||
|
||
Process the arguments to a macro, separating out the keyword arguments. | ||
Return a tuple of (flat_arguments, keyword arguments, and requested_container), | ||
where `requested_container` is a symbol to be passed to `parse_container`. | ||
""" | ||
function _extract_kw_args(args) | ||
kw_args = filter(x -> isexpr(x, :(=)) && x.args[1] != :container , collect(args)) | ||
flat_args = filter(x->!isexpr(x, :(=)), collect(args)) | ||
requested_container = :Auto | ||
for kw in args | ||
if isexpr(kw, :(=)) && kw.args[1] == :container | ||
requested_container = kw.args[2] | ||
end | ||
end | ||
return flat_args, kw_args, requested_container | ||
end | ||
|
||
function _try_parse_idx_set(arg::Expr) | ||
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and | ||
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively. | ||
if arg.head === :kw || arg.head === :(=) | ||
@assert length(arg.args) == 2 | ||
return true, arg.args[1], arg.args[2] | ||
elseif isexpr(arg, :call) && arg.args[1] === :in | ||
return true, arg.args[2], arg.args[3] | ||
else | ||
return false, nothing, nothing | ||
end | ||
end | ||
function _explicit_oneto(index_set) | ||
s = Meta.isexpr(index_set,:escape) ? index_set.args[1] : index_set | ||
if Meta.isexpr(s,:call) && length(s.args) == 3 && s.args[1] == :(:) && s.args[2] == 1 | ||
return :(Base.OneTo($index_set)) | ||
else | ||
return index_set | ||
end | ||
end | ||
|
||
function _expr_is_splat(ex::Expr) | ||
if ex.head == :(...) | ||
return true | ||
elseif ex.head == :escape | ||
return _expr_is_splat(ex.args[1]) | ||
end | ||
return false | ||
end | ||
_expr_is_splat(::Any) = false | ||
|
||
""" | ||
_parse_ref_sets(expr::Expr) | ||
|
||
Helper function for macros to construct container objects. Takes an `Expr` that | ||
specifies the container, e.g. `:(x[i=1:3,[:red,:blue],k=S; i+k <= 6])`, and | ||
returns: | ||
|
||
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]` | ||
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]` | ||
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present | ||
""" | ||
function _parse_ref_sets(_error::Function, expr::Expr) | ||
c = copy(expr) | ||
idxvars = Any[] | ||
idxsets = Any[] | ||
# On 0.7, :(t[i;j]) is a :ref, while t[i,j;j] is a :typed_vcat. | ||
blegat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# In both cases :t is the first arg. | ||
if isexpr(c, :typed_vcat) || isexpr(c, :ref) | ||
popfirst!(c.args) | ||
end | ||
condition = :() | ||
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) | ||
# Parameters appear as plain args at the end. | ||
if length(c.args) > 2 | ||
_error("Unsupported syntax $c.") | ||
elseif length(c.args) == 2 | ||
condition = pop!(c.args) | ||
end # else no condition. | ||
elseif isexpr(c, :ref) || isexpr(c, :vect) | ||
# Parameters appear at the front. | ||
if isexpr(c.args[1], :parameters) | ||
if length(c.args[1].args) != 1 | ||
_error("Invalid syntax: $c. Multiple semicolons are not " * | ||
"supported.") | ||
end | ||
condition = popfirst!(c.args).args[1] | ||
end | ||
end | ||
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) || isexpr(c, :ref) | ||
if isexpr(c.args[1], :parameters) | ||
@assert length(c.args[1].args) == 1 | ||
condition = popfirst!(c.args).args[1] | ||
end # else no condition. | ||
end | ||
|
||
for s in c.args | ||
parse_done = false | ||
if isa(s, Expr) | ||
parse_done, idxvar, _idxset = _try_parse_idx_set(s::Expr) | ||
if parse_done | ||
idxset = esc(_idxset) | ||
end | ||
end | ||
if !parse_done # No index variable specified | ||
idxvar = gensym() | ||
idxset = esc(s) | ||
end | ||
push!(idxvars, idxvar) | ||
push!(idxsets, idxset) | ||
end | ||
return idxvars, idxsets, condition | ||
end | ||
_parse_ref_sets(_error::Function, expr) = (Any[], Any[], :()) | ||
|
||
""" | ||
_build_ref_sets(_error::Function, expr) | ||
|
||
Helper function for macros to construct container objects. Takes an `Expr` that | ||
specifies the container, e.g. `:(x[i=1:3,[:red,:blue],k=S; i+k <= 6])`, and | ||
returns: | ||
|
||
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]` | ||
2. `indices`: Iterators over the indices indexing, e.g. | ||
`Constainers.NestedIterators((1:3, [:red,:blue], S), (i, ##..., k) -> i + k <= 6)`. | ||
""" | ||
function _build_ref_sets(_error::Function, expr) | ||
idxvars, idxsets, condition = _parse_ref_sets(_error, expr) | ||
if any(_expr_is_splat.(idxsets)) | ||
_error("cannot use splatting operator `...` in the definition of an index set.") | ||
end | ||
has_dependent = has_dependent_sets(idxvars, idxsets) | ||
if has_dependent || condition != :() | ||
esc_idxvars = esc.(idxvars) | ||
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)] | ||
if condition == :() | ||
indices = :(Containers.nested($(idxfuns...))) | ||
else | ||
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition))) | ||
indices = :(Containers.nested($(idxfuns...); condition = $condition_fun)) | ||
end | ||
else | ||
indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...))) | ||
end | ||
return idxvars, indices | ||
end | ||
|
||
function container_code(idxvars, indices, code, requested_container) | ||
if isempty(idxvars) | ||
return code | ||
end | ||
if !(requested_container in [:Auto, :Array, :DenseAxisArray, :SparseAxisArray]) | ||
# We do this two-step interpolation, first into the string, and then | ||
# into the expression because interpolating into a string inside an | ||
# expression has scoping issues. | ||
error_message = "Invalid container type $requested_container. Must be " * | ||
"Auto, Array, DenseAxisArray, or SparseAxisArray." | ||
return :(error($error_message)) | ||
end | ||
if requested_container == :DenseAxisArray | ||
requested_container = :(JuMP.Containers.DenseAxisArray) | ||
elseif requested_container == :SparseAxisArray | ||
requested_container = :(JuMP.Containers.SparseAxisArray) | ||
end | ||
esc_idxvars = esc.(idxvars) | ||
func = :(($(esc_idxvars...),) -> $code) | ||
if requested_container == :Auto | ||
return :(Containers.container($func, $indices)) | ||
else | ||
return :(Containers.container($func, $indices, $requested_container)) | ||
end | ||
end | ||
function parse_container(_error, var, value, requested_container) | ||
idxvars, indices = _build_ref_sets(_error, var) | ||
return container_code(idxvars, indices, value, requested_container) | ||
end | ||
|
||
macro container(args...) | ||
blegat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
args, kw_args, requested_container = _extract_kw_args(args) | ||
@assert length(args) == 2 | ||
@assert isempty(kw_args) | ||
var, value = args | ||
name = var.args[1] | ||
code = parse_container(error, var, esc(value), requested_container) | ||
return :($(esc(name)) = $code) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
struct NestedIterator{T} | ||
iterators::T # Tuple of functions | ||
condition::Function | ||
end | ||
|
||
Iterators over the tuples that are produced by a nested for loop. | ||
For instance, if `length(iterators) == 3` , this corresponds to the tuples | ||
`(i1, i2, i3)` produced by: | ||
``` | ||
for i1 in iterators[1]() | ||
for i2 in iterator[2](i1) | ||
for i3 in iterator[3](i1, i2) | ||
if condition(i1, i2, i3) | ||
# produces (i1, i2, i3) | ||
end | ||
end | ||
end | ||
end | ||
``` | ||
""" | ||
struct NestedIterator{T} | ||
blegat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
iterators::T # Tuple of functions | ||
condition::Function | ||
end | ||
function nested(iterators...; condition = (args...) -> true) | ||
return NestedIterator(iterators, condition) | ||
end | ||
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown() | ||
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown() | ||
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state) | ||
if elem_state === nothing | ||
return nothing | ||
end | ||
elem, state = elem_state | ||
elems_states = first_iterate( | ||
it, i + 1, (elems..., elem), | ||
(states..., (iterator, state, elem))) | ||
if elems_states !== nothing | ||
return elems_states | ||
end | ||
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state)) | ||
end | ||
function first_iterate(it::NestedIterator, i, elems, states) | ||
if i > length(it.iterators) | ||
if it.condition(elems...) | ||
return elems, states | ||
else | ||
return nothing | ||
end | ||
end | ||
iterator = it.iterators[i](elems...) | ||
return next_iterate(it, i, elems, states, iterator, iterate(iterator)) | ||
end | ||
function tail_iterate(it::NestedIterator, i, elems, states) | ||
if i > length(it.iterators) | ||
return nothing | ||
end | ||
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states) | ||
if next !== nothing | ||
return next | ||
end | ||
iterator = states[i][1] | ||
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2])) | ||
end | ||
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple()) | ||
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a pretty bad printing experience.