Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[breaking] Cleanup the Context object: remove id_hash, simplify dictionaries #645

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions docs/src/examples/mixed_integer/aux_files/antidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ export antidiag

### Diagonal
### Represents the kth diagonal of an mxn matrix as a (min(m, n) - k) x 1 vector
struct AntidiagAtom <: Convex.AbstractExpr
mutable struct AntidiagAtom <: Convex.AbstractExpr
head::Symbol
id_hash::UInt64
children::Tuple{Convex.AbstractExpr}
size::Tuple{Int,Int}
k::Int
Expand All @@ -27,13 +26,7 @@ struct AntidiagAtom <: Convex.AbstractExpr
error("Bounds error in calling diag")
end
children = (x,)
return new(
:antidiag,
hash((children, k)),
children,
(minimum(x.size) - k, 1),
k,
)
return new(:antidiag, children, (minimum(x.size) - k, 1), k)
end
end

Expand Down
13 changes: 6 additions & 7 deletions docs/src/manual/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,15 @@ this. To do so, we define

```@example 1
using Convex

# Must be mutable! Otherwise variables with the same size/value would be treated as the same object.
mutable struct ProbabilityVector <: Convex.AbstractVariable
head::Symbol
id_hash::UInt64
size::Tuple{Int, Int}
size::Tuple{Int,Int}
value::Union{Convex.Value,Nothing}
vexity::Convex.Vexity
function ProbabilityVector(d)
this = new(:ProbabilityVector, 0, (d,1), nothing, Convex.AffineVexity())
this.id_hash = objectid(this)
this
return new(:ProbabilityVector, (d, 1), nothing, Convex.AffineVexity())
end
end

Expand All @@ -165,8 +164,8 @@ solve!(prob, SCS.Optimizer)
evaluate(p) # [1.0, 0.0, 0.0]
```

Subtypes of `AbstractVariable` must have the fields `head`, `id_hash`, and
`size`, and `id_hash` must be populated as shown in the example. Then they must also
Subtypes of `AbstractVariable` must have the fields `head` and
`size`. Then they must also

* either have a field `value`, or implement [`Convex._value`](@ref) and
[`Convex.set_value!`](@ref)
Expand Down
23 changes: 12 additions & 11 deletions src/Context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,20 @@ mutable struct Context{T,M}
model::M

# Used for populating variable values after solving
var_id_to_moi_indices::OrderedCollections.OrderedDict{
UInt64,
var_to_moi_indices::IdDict{
Any,
Union{
Vector{MOI.VariableIndex},
Tuple{Vector{MOI.VariableIndex},Vector{MOI.VariableIndex}},
},
}
# `id_hash` -> `AbstractVariable`
id_to_variables::OrderedCollections.OrderedDict{UInt64,Any}

# Used for populating constraint duals
constr_to_moi_inds::IdDict{Any,Any}

detected_infeasible_during_formulation::Ref{Bool}
detected_infeasible_during_formulation::Bool

# Cache
# conic_form_cache::DataStructures.WeakKeyIdDict{Any, Any}
conic_form_cache::IdDict{Any,Any}
end

Expand All @@ -39,8 +36,13 @@ function Context{T}(optimizer_factory; add_cache::Bool = false) where {T}
end
return Context{T,typeof(model)}(
model,
OrderedCollections.OrderedDict{UInt64,Vector{MOI.VariableIndex}}(),
OrderedCollections.OrderedDict{UInt64,Any}(),
IdDict{
Any,
Union{
Vector{MOI.VariableIndex},
Tuple{Vector{MOI.VariableIndex},Vector{MOI.VariableIndex}},
},
}(),
IdDict{Any,Any}(),
false,
IdDict{Any,Any}(),
Expand All @@ -49,10 +51,9 @@ end

function Base.empty!(context::Context)
MOI.empty!(context.model)
empty!(context.var_id_to_moi_indices)
empty!(context.id_to_variables)
empty!(context.var_to_moi_indices)
empty!(context.constr_to_moi_inds)
context.detected_infeasible_during_formulation[] = false
context.detected_infeasible_during_formulation = false
empty!(context.conic_form_cache)
return
end
9 changes: 4 additions & 5 deletions src/MOI_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

struct Optimizer{T,M} <: MOI.AbstractOptimizer
context::Context{T,M}
moi_to_convex::OrderedCollections.OrderedDict{MOI.VariableIndex,UInt64}
moi_to_convex::OrderedCollections.OrderedDict{MOI.VariableIndex,Any}
convex_to_moi::Dict{UInt64,Vector{MOI.VariableIndex}}
constraint_map::Vector{MOI.ConstraintIndex}
function Optimizer(context::Context{T,M}) where {T,M}
Expand Down Expand Up @@ -47,9 +47,8 @@ end

function _add_variable(model::Optimizer, vi::MOI.VariableIndex)
var = Variable()
model.moi_to_convex[vi] = var.id_hash
model.context.var_id_to_moi_indices[var.id_hash] = [vi]
model.context.id_to_variables[var.id_hash] = var
model.moi_to_convex[vi] = var
model.context.var_to_moi_indices[var] = [vi]
return
end

Expand Down Expand Up @@ -129,7 +128,7 @@ function _expr(::Optimizer, v::Value)
end

function _expr(model::Optimizer, x::MOI.VariableIndex)
return model.context.id_to_variables[model.moi_to_convex[x]]
return model.moi_to_convex[x]
end

function _expr(model::Optimizer, f::MOI.AbstractScalarFunction)
Expand Down
12 changes: 2 additions & 10 deletions src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ end

mutable struct Constant{T<:Real} <: AbstractExpr
head::Symbol
id_hash::UInt64
value::Union{Matrix{T},SPARSE_MATRIX{T}}
size::Tuple{Int,Int}
sign::Sign
Expand All @@ -47,13 +46,7 @@ mutable struct Constant{T<:Real} <: AbstractExpr
if x isa Complex || x isa AbstractArray{<:Complex}
throw(DomainError(x, "Constant expects real values"))
end
return new{eltype(x)}(
:constant,
objectid(x),
_matrix(x),
_size(x),
sign,
)
return new{eltype(x)}(:constant, _matrix(x), _size(x), sign)
end
end
function Constant(x::Value, check_sign::Bool = true)
Expand All @@ -63,13 +56,12 @@ end

mutable struct ComplexConstant{T<:Real} <: AbstractExpr
head::Symbol
id_hash::UInt64
size::Tuple{Int,Int}
real_constant::Constant{T}
imag_constant::Constant{T}
function ComplexConstant(re::Constant{T}, im::Constant{T}) where {T}
size(re) == size(im) || error("size mismatch")
return new{T}(:complex_constant, rand(UInt64), size(re), re, im)
return new{T}(:complex_constant, size(re), re, im)
end

# function ComplexConstant(re::Constant{S1}, im::Constant{S2}) where {S1,S2}
Expand Down
2 changes: 1 addition & 1 deletion src/constraints/GenericConstraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end
function _add_constraint!(context::Context, c::GenericConstraint)
if vexity(c.child) == ConstVexity()
if !is_feasible(evaluate(c.child), c.set, CONSTANT_CONSTRAINT_TOL[])
context.detected_infeasible_during_formulation[] = true
context.detected_infeasible_during_formulation = true
end
return
end
Expand Down
3 changes: 1 addition & 2 deletions src/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ const Value = Union{Number,AbstractArray}
# We commandeer `==` to create a constraint.
# Therefore we define `isequal` to still have a notion of equality
# (Normally `isequal` falls back to `==`, so we need to provide a method).
# All `AbstractExpr` (Constraints are not AbstractExpr's!) are compared by value, except for AbstractVariables,
# which use their `id_hash` field.
# All `AbstractExpr` (Constraints are not AbstractExpr's!) are compared by value, except for AbstractVariables, which are compared by `===` (objectid).
function Base.isequal(x::AbstractExpr, y::AbstractExpr)
if typeof(x) != typeof(y)
return false
Expand Down
8 changes: 3 additions & 5 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ end

function _add_variable_primal_start(context::Convex.Context{T}) where {T}
attr = MOI.VariablePrimalStart()
for (id, moi_indices) in context.var_id_to_moi_indices
x = context.id_to_variables[id]
for (x, moi_indices) in context.var_to_moi_indices
if x.value === nothing
continue
elseif moi_indices isa Tuple # Variable is complex
Expand Down Expand Up @@ -97,7 +96,7 @@ function solve!(
if warmstart && MOI.supports(context.model, attr, MOI.VariableIndex)
_add_variable_primal_start(context)
end
if context.detected_infeasible_during_formulation[]
if context.detected_infeasible_during_formulation
p.status = MOI.INFEASIBLE
else
MOI.optimize!(context.model)
Expand All @@ -108,8 +107,7 @@ function solve!(
@warn "Problem wasn't solved optimally" status = p.status
end
primal_status = MOI.get(context.model, MOI.PrimalStatus())
for (id, indices) in context.var_id_to_moi_indices
var = context.id_to_variables[id]
for (var, indices) in context.var_to_moi_indices
if vexity(var) == ConstVexity()
continue # Fixed variable
elseif primal_status == MOI.NO_SOLUTION
Expand Down
8 changes: 4 additions & 4 deletions src/utilities/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using .TreePrint
"""
show_id(io::IO, x::Union{AbstractVariable}; digits = 3)

Print a truncated version of the objects `id_hash` field.
Print a truncated version of the object's id.

## Example

Expand All @@ -19,12 +19,12 @@ julia> Convex.show_id(stdout, x)
id: 163…906
```
"""
function show_id(io::IO, x::Union{AbstractVariable}; digits = MAXDIGITS[])
function show_id(io::IO, x::AbstractVariable; digits = MAXDIGITS[])
return print(io, show_id(x; digits = digits))
end

function show_id(x::Union{AbstractVariable}; digits = MAXDIGITS[])
hash_str = string(x.id_hash)
function show_id(x::AbstractVariable; digits = MAXDIGITS[])
hash_str = string(objectid(x))
if length(hash_str) > (2 * digits + 1)
return "id: " * first(hash_str, digits) * "…" * last(hash_str, digits)
else
Expand Down
16 changes: 4 additions & 12 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ integer-valued (`IntVar`), or binary (`BinVar`).
"""
abstract type AbstractVariable <: AbstractExpr end

An `AbstractVariable` should have `head` field, an `id_hash` field
and a `size` field to conform to the `AbstractExpr` interface, and
An `AbstractVariable` should have `head` field, and a `size` field to conform to the `AbstractExpr` interface, and
implement methods (or use the field-access fallbacks) for

* [`_value`](@ref), [`set_value!`](@ref): get or set the numeric value of the variable.
Expand Down Expand Up @@ -188,10 +187,10 @@ function free!(x::AbstractVariable)
end

function Base.isequal(x::AbstractVariable, y::AbstractVariable)
return x.id_hash == y.id_hash
return x === y
end

Base.hash(x::AbstractVariable, h::UInt) = hash(x.id_hash, h)
Base.hash(x::AbstractVariable, h::UInt) = hash(objectid(x), h)

iscomplex(x::Sign) = x == ComplexSign()

Expand All @@ -204,8 +203,6 @@ iscomplex(::Union{Real,AbstractArray{<:Real}}) = false
mutable struct Variable <: AbstractVariable
# Every `AbstractExpr` has a `head`; for a Variable it is set to `:variable`.
head::Symbol
# A unique identifying hash used for caching.
id_hash::UInt64
# The current value of the variable. Defaults to `nothing` until the
# variable has been [`fix!`](@ref)'d to a particular value, or the
# variable has been used in a problem which has been solved, at which
Expand Down Expand Up @@ -243,18 +240,15 @@ mutable struct Variable <: AbstractVariable
),
)
end
this = new(
return new(
:variable,
0,
nothing,
size,
AffineVexity(),
sign,
Constraint[],
vartype,
)
this.id_hash = objectid(this)
return this
end
end

Expand All @@ -274,7 +268,6 @@ Variable(vartype::VarType) = Variable((1, 1), NoSign(), vartype)

mutable struct ComplexVariable <: AbstractVariable
head::Symbol
id_hash::UInt64
size::Tuple{Int,Int}
value::Union{Value,Nothing}
vexity::Vexity
Expand All @@ -283,7 +276,6 @@ mutable struct ComplexVariable <: AbstractVariable
function ComplexVariable(size::Tuple{Int,Int} = (1, 1))
return new(
:ComplexVariable,
rand(UInt64),
size,
nothing,
AffineVexity(),
Expand Down
18 changes: 2 additions & 16 deletions src/variable_template.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
# It might be useful to get a direct VOV sometimes...
function _template(a::AbstractVariable, context::Context{T}) where {T}
first_cache = false
var_inds = get!(context.var_id_to_moi_indices, a.id_hash) do
var_inds = get!(context.var_to_moi_indices, a) do
first_cache = true
return add_variables!(context.model, a)
end

context.id_to_variables[a.id_hash] = a

# we only want this to run once, when the variable is first added,
# and after `var_id_to_moi_indices` is populated
# and after `var_to_moi_indices` is populated
if first_cache
if sign(a) == Positive()
add_constraint!(context, a >= 0)
Expand Down Expand Up @@ -97,18 +95,6 @@ accessed in `context[a]`, otherwise, it is created by calling
with the same expression does not create a duplicate one.
"""
function conic_form!(context::Context, a::AbstractExpr)

# Nicer implementation
d = context.conic_form_cache
return get!(() -> new_conic_form!(context, a), d, a)

# Avoid closure
# wkh = context.conic_form_cache
# default = () -> new_conic_form!(context, a)
# key = a
# return Base.@lock wkh.lock begin
# get!(default, wkh.ht, DataStructures.WeakRefForWeakDict(key))
# end

return
end
Loading
Loading