From a0af39df82fce30cce0821f4f08ec7b13d710b01 Mon Sep 17 00:00:00 2001 From: Miles Lubin Date: Mon, 12 Nov 2018 20:51:17 -0500 Subject: [PATCH] Performance improvements in linear_terms and quad_terms (#1604) * :zap: Performance improvements in linear_terms and quad_terms * drop unneeded <:Any --- src/aff_expr.jl | 23 +++++++++++++++++++---- src/quad_expr.jl | 26 +++++++++++++++++++++----- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/aff_expr.jl b/src/aff_expr.jl index fb5502593eb..aed52013423 100644 --- a/src/aff_expr.jl +++ b/src/aff_expr.jl @@ -139,14 +139,29 @@ if VERSION < v"0.7-" Base.done( lti::LinearTermIterator, state::Int) = done(lti.aff.terms, state) Base.next( lti::LinearTermIterator, state::Int) = reorder_iterator(next(lti.aff.terms, state)...) else - reorder_iterator(::Nothing) = nothing - reorder_iterator(t::Tuple{Pair,Int}) = ((first(t).second, first(t).first), last(t)) - Base.iterate(lti::LinearTermIterator) = reorder_iterator(iterate(lti.aff.terms)) + reverse_pair_to_tuple(p::Pair) = (p.second, p.first) + function Base.iterate(lti::LinearTermIterator) + ret = iterate(lti.aff.terms) + if ret === nothing + return nothing + else + return reverse_pair_to_tuple(ret[1]), ret[2] + end + end function Base.iterate(lti::LinearTermIterator, state) - reorder_iterator(iterate(lti.aff.terms, state)) + ret = iterate(lti.aff.terms, state) + if ret === nothing + return nothing + else + return reverse_pair_to_tuple(ret[1]), ret[2] + end end end Base.length(lti::LinearTermIterator) = length(lti.aff.terms) +function Base.eltype(lti::LinearTermIterator{GenericAffExpr{C, V}} + ) where {C, V} + return Tuple{C, V} +end """ add_to_expression!(expression, terms...) diff --git a/src/quad_expr.jl b/src/quad_expr.jl index 6bc7271d915..511b02f3f6a 100644 --- a/src/quad_expr.jl +++ b/src/quad_expr.jl @@ -94,7 +94,7 @@ linear_terms(quad::GenericQuadExpr) = LinearTermIterator(quad.aff) struct QuadTermIterator{GQE<:GenericQuadExpr} quad::GQE end - +# TODO: rename to quad_terms """ quadterms(quad::GenericQuadExpr{C, V}) @@ -112,15 +112,31 @@ if VERSION < v"0.7-" Base.done(qti::QuadTermIterator, state::Int) = done(qti.quad.terms, state) Base.next(qti::QuadTermIterator, state::Int) = reorder_iterator(next(qti.quad.terms, state)...) else - function reorder_iterator(t::Tuple{Pair{<:UnorderedPair,<:Any},Int}) - ((first(t).second, first(t).first.a, first(t).first.b), last(t)) + function reorder_and_flatten(p::Pair{<:UnorderedPair}) + return (p.second, p.first.a, p.first.b) + end + function Base.iterate(qti::QuadTermIterator) + ret = iterate(qti.quad.terms) + if ret === nothing + return nothing + else + return reorder_and_flatten(ret[1]), ret[2] + end end - Base.iterate(qti::QuadTermIterator) = reorder_iterator(iterate(qti.quad.terms)) function Base.iterate(qti::QuadTermIterator, state) - reorder_iterator(iterate(qti.quad.terms, state)) + ret = iterate(qti.quad.terms, state) + if ret === nothing + return nothing + else + return reorder_and_flatten(ret[1]), ret[2] + end end end Base.length(qti::QuadTermIterator) = length(qti.quad.terms) +function Base.eltype(qti::QuadTermIterator{GenericQuadExpr{C, V}} + ) where {C, V} + return Tuple{C, V, V} +end function add_to_expression!(quad::GenericQuadExpr{C,V}, new_coef::C, new_var1::V, new_var2::V) where {C,V} # Node: OrderedDict updates the *key* as well. That is, if there was a