Skip to content

Commit

Permalink
Use OrderedDict for Array Parameter Supports (infiniteopt#357)
Browse files Browse the repository at this point in the history
* Use OrderedDict for array parameters

* update doctest
  • Loading branch information
pulsipher authored Jul 30, 2024
1 parent 3a85a93 commit c1f2545
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 71 deletions.
8 changes: 4 additions & 4 deletions docs/src/guide/parameter.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ the distribution.
```jldoctest macro_define
julia> supports(θ)
2×3 Matrix{Float64}:
-0.353007 0.679107 0.586617
-0.190712 1.17155 0.420496
0.679107 -0.353007 0.586617
1.17155 -0.190712 0.420496
```
We refer to groups of parameters defined this way as dependent infinite
parameters. In principle, nonrandom infinite parameter types can be made
Expand Down Expand Up @@ -518,8 +518,8 @@ julia> fill_in_supports!(ξ, num_supports = 3)
julia> supports(ξ)
2×3 Matrix{Float64}:
-0.353007 0.679107 0.586617
-0.190712 1.17155 0.420496
0.679107 -0.353007 0.586617
1.17155 -0.190712 0.420496
```
Note that [`fill_in_supports!`](@ref) only fill in supports for parameters with no
associated supports. To modify the supports of parameters already associated
Expand Down
69 changes: 38 additions & 31 deletions src/array_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ function _process_supports(
_error("Support violates the infinite domain.")
end
supps = round.(supps, sigdigits = sig_digits)
return Dict{Vector{Float64}, Set{DataType}}(supps => Set([UserDefined]))
return DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
supps => Set([UserDefined])
)
end

# Vector{Vector{<:Real}}
Expand All @@ -185,8 +187,9 @@ function _process_supports(
_error("Supports violate the infinite domain.")
end
supps = round.(supps, sigdigits = sig_digits)
return Dict{Vector{Float64}, Set{DataType}}(s => Set([UserDefined])
for s in eachcol(supps))
return DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
s => Set([UserDefined]) for s in eachcol(supps)
)
end

## Use dispatch to make the formatting of the derivative method vector
Expand Down Expand Up @@ -235,14 +238,17 @@ function _build_parameters(
supp_dict = _process_supports(_error, supports, domain, sig_digits)
# we want to generate supports
elseif !iszero(num_supports)
supps, label = generate_support_values(domain,
num_supports = num_supports,
sig_digits = sig_digits)
supp_dict = Dict{Vector{Float64}, Set{DataType}}(s => Set([label])
for s in eachcol(supps))
supps, label = generate_support_values(
domain,
num_supports = num_supports,
sig_digits = sig_digits
)
supp_dict = DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
s => Set([label]) for s in eachcol(supps)
)
# no supports are specified
else
supp_dict = Dict{Vector{Float64}, Set{DataType}}()
supp_dict = DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}()
end
# check the derivative methods
methods = _process_derivative_methods(_error, derivative_method, domains)
Expand Down Expand Up @@ -351,32 +357,27 @@ end
# PARAMETER DEPENDENCIES
################################################################################
# Extend _infinite_variable_dependencies
function _infinite_variable_dependencies(pref::DependentParameterRef
)::Vector{InfiniteVariableIndex}
function _infinite_variable_dependencies(pref::DependentParameterRef)
return _data_object(pref).infinite_var_indices
end

# Extend _parameter_function_dependencies
function _parameter_function_dependencies(pref::DependentParameterRef
)::Vector{ParameterFunctionIndex}
function _parameter_function_dependencies(pref::DependentParameterRef)
return _data_object(pref).parameter_func_indices
end

# Extend _measure_dependencies
function _measure_dependencies(pref::DependentParameterRef
)::Vector{MeasureIndex}
function _measure_dependencies(pref::DependentParameterRef)
return _data_object(pref).measure_indices[_param_index(pref)]
end

# Extend _constraint_dependencies
function _constraint_dependencies(pref::DependentParameterRef
)::Vector{InfOptConstraintIndex}
function _constraint_dependencies(pref::DependentParameterRef)
return _data_object(pref).constraint_indices[_param_index(pref)]
end

# Extend _derivative_dependencies
function _derivative_dependencies(pref::DependentParameterRef
)::Vector{DerivativeIndex}
function _derivative_dependencies(pref::DependentParameterRef)
return _data_object(pref).derivative_indices[_param_index(pref)]
end

Expand Down Expand Up @@ -741,8 +742,7 @@ function _update_parameter_domain(
pref::DependentParameterRef,
new_domain::InfiniteArrayDomain
)
old_params = core_object(pref)
new_supports = Dict{Vector{Float64}, Set{DataType}}()
new_supports = DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}()
sig_figs = significant_digits(pref)
methods = _derivative_methods(pref)
new_params = DependentParameters(new_domain, new_supports, sig_figs, methods)
Expand Down Expand Up @@ -790,8 +790,10 @@ function set_infinite_domain(
"a measure.")
end
param_idx = _param_index(pref)
new_domain = CollectionDomain([i != param_idx ? collection_domains(old_domain)[i] : domain
for i in eachindex(collection_domains(old_domain))])
new_domain = CollectionDomain(
[i != param_idx ? collection_domains(old_domain)[i] : domain
for i in eachindex(collection_domains(old_domain))]
)
_update_parameter_domain(pref, new_domain)
return
end
Expand Down Expand Up @@ -1168,8 +1170,9 @@ function _update_parameter_supports(
label::Type{<:AbstractSupportLabel}
)
domain = _parameter_domain(first(prefs))
new_supps = Dict{Vector{Float64}, Set{DataType}}(s => Set([label])
for s in eachcol(supports))
new_supps = DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
s => Set([label]) for s in eachcol(supports)
)
sig_figs = significant_digits(first(prefs))
methods = _derivative_methods(first(prefs))
new_params = DependentParameters(domain, new_supps, sig_figs, methods)
Expand Down Expand Up @@ -1438,9 +1441,11 @@ function generate_and_add_supports!(
domain::InfiniteArrayDomain;
num_supports::Int = DefaultNumSupports
)
new_supps, label = generate_supports(domain,
num_supports = num_supports,
sig_digits = significant_digits(first(prefs)))
new_supps, label = generate_supports(
domain,
num_supports = num_supports,
sig_digits = significant_digits(first(prefs))
)
add_supports(Collections.vectorize(prefs), new_supps, check = false,
label = label)
return
Expand All @@ -1453,9 +1458,11 @@ function generate_and_add_supports!(
method::Type{<:AbstractSupportLabel};
num_supports::Int = DefaultNumSupports
)
new_supps, label = generate_supports(domain, method,
num_supports = num_supports,
sig_digits = significant_digits(first(prefs)))
new_supps, label = generate_supports(
domain, method,
num_supports = num_supports,
sig_digits = significant_digits(first(prefs))
)
add_supports(Collections.vectorize(prefs), new_supps, check = false,
label = label)
return
Expand Down
4 changes: 2 additions & 2 deletions src/datatypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ A `DataType` for storing a collection of dependent infinite parameters.
**Fields**
- `domain::T`: The infinite domain that characterizes the parameters.
- `supports::Dict{Vector{Float64}, Set{DataType}}`: Support dictionary where keys
- `supports::DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}`: Support dictionary where keys
are supports and the values are the set of labels for each support.
- `sig_digits::Int`: The number of significant digits used to round the support values.
- `derivative_methods::Vector{M}`: The derivative evaluation methods associated with
Expand All @@ -532,7 +532,7 @@ A `DataType` for storing a collection of dependent infinite parameters.
struct DependentParameters{T <: InfiniteArrayDomain,
M <: NonGenerativeDerivativeMethod} <: InfOptParameter
domain::T
supports::Dict{Vector{Float64}, Set{DataType}} # Support to label set
supports::DataStructures.OrderedDict{Vector{Float64}, Set{DataType}} # Support to label set
sig_digits::Int
derivative_methods::Vector{M}
end
Expand Down
4 changes: 2 additions & 2 deletions test/TranscriptionOpt/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ end
add_supports(par, 0.6, label = InternalLabel)
@test IOTO.set_parameter_supports(tb, m) isa Nothing
expected = ([[0., 0.], [1., 1.], [NaN, NaN]], [0., 0.3, 0.6, 0.8, 1., NaN])
@test isequal(sort.(IOTO.transcription_data(tb).supports), expected)
@test isequal(IOTO.transcription_data(tb).supports, expected)
@test IOTO.has_internal_supports(tb)
expected = ([Set([UniformGrid]), Set([UniformGrid]), Set{DataType}()],
[Set([UserDefined]), Set([InternalGaussLobatto]),
Expand All @@ -361,7 +361,7 @@ end
# test parameter_supports
@testset "parameter_supports" begin
expected = ([[0., 0.], [1., 1.], [NaN, NaN]], [0., 0.3, 0.6, 0.8, 1., NaN])
@test isequal(sort.(IOTO.parameter_supports(tb)), expected)
@test isequal(IOTO.parameter_supports(tb), expected)
end
# test support_index_iterator with 1 argument
@testset "support_index_iterator (1 Arg)" begin
Expand Down
42 changes: 21 additions & 21 deletions test/TranscriptionOpt/transcribe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@
@test IOTO.transcription_variable(x, tb) isa Vector{VariableRef}
@test IOTO.transcription_variable(y, tb) isa Matrix{VariableRef}
@test name(IOTO.transcription_variable(x, tb)[1]) == "x(0.0)"
@test name(IOTO.transcription_variable(y, tb)[2, 1]) in ["y(1.0, [0.0, 0.0])", "y(1.0, [1.0, 1.0])"]
@test name(IOTO.transcription_variable(y, tb)[2, 1]) == "y(1.0, [0.0, 0.0])"
@test has_lower_bound(IOTO.transcription_variable(x)[1])
@test is_binary(IOTO.transcription_variable(y, tb)[2])
@test is_fixed(IOTO.transcription_variable(y, tb)[4])
@test is_integer(IOTO.transcription_variable(x, tb)[2])
@test sort!(vec(start_value.(IOTO.transcription_variable(y, tb)))) == [0., 1, 2, 3]
@test start_value.(IOTO.transcription_variable(y, tb)) == [0. 2; 1 3]
@test supports(x) == [(0,), (1,)]
@test length(supports(y)) == 4
@test supports(y) == [(0, [0, 0]) (0, [1, 1]); (1, [0, 0]) (1, [1, 1])]
end
# test _format_derivative_info
@testset "_format_derivative_info" begin
Expand Down Expand Up @@ -109,12 +109,12 @@
@test name(IOTO.transcription_variable(dx, tb)[1]) == "d/dpar[x(par)](0.0)"
@test name(IOTO.transcription_variable(dx3, tb)[1]) == "d^3/dpar^3[x(par)](0.0)"
@test name(IOTO.transcription_variable(deriv(dx, par), tb)[1]) == "d²/dpar²[x(par)](0.0)"
possible = [Sys.iswindows() ? "d/dpar[y(par, pars)](1.0, [$i, $i])" : "∂/∂par[y(par, pars)](1.0, [$i, $i])" for i in [0.0, 1.0]]
@test name(IOTO.transcription_variable(dy, tb)[2, 1]) in possible
possible = Sys.iswindows() ? "d/dpar[y(par, pars)](1.0, [0.0, 0.0])" : "∂/∂par[y(par, pars)](1.0, [0.0, 0.0])"
@test name(IOTO.transcription_variable(dy, tb)[2, 1]) == possible
@test has_lower_bound(IOTO.transcription_variable(dx, tb)[1])
@test sort!(vec(start_value.(IOTO.transcription_variable(dy, tb)))) == [0., 1, 2, 3]
@test start_value.(IOTO.transcription_variable(dy, tb)) == [0. 2; 1 3]
@test supports(dx) == [(0,), (1,)]
@test length(supports(dy)) == 4
@test supports(dy) == [(0, [0, 0]) (0, [1, 1]); (1, [0, 0]) (1, [1, 1])]
end
# test _set_semi_infinite_variable_mapping
@testset "_set_semi_infinite_variable_mapping" begin
Expand Down Expand Up @@ -177,7 +177,7 @@
@test IOTO.transcription_variable(x0, tb) == IOTO.lookup_by_support(x, tb, [0.])
@test IOTO.transcription_variable(y0, tb) == IOTO.lookup_by_support(y, tb, [0., 0., 0.])
@test name(IOTO.transcription_variable(x0, tb)) == "x(0.0)"
@test name(IOTO.transcription_variable(y0, tb))[1:8] == "y(0.0, ["
@test name(IOTO.transcription_variable(y0, tb)) == "y(0.0, [0.0, 0.0])"
@test lower_bound(IOTO.transcription_variable(x0, tb)) == 0
@test is_integer(IOTO.transcription_variable(x0, tb))
@test lower_bound(IOTO.transcription_variable(y0, tb)) == 0
Expand Down Expand Up @@ -219,7 +219,7 @@ end
@test IOTO.transcription_variable(meas4) isa AffExpr
@test supports(meas1) == ()
@test supports(meas2) == ()
@test sort!(supports(meas3)) == [(0.,), (1., )]
@test supports(meas3) == [(0.,), (1., )]
end
# test transcribe_objective!
@testset "transcribe_objective!" begin
Expand Down Expand Up @@ -375,22 +375,22 @@ end
@test length(IOTO.transcription_constraint(c7)) == 6
@test IOTO.transcription_constraint(c8) isa ConstraintRef
# test the info constraint supports
expected = [([0., 0.], 0.5), ([0., 0.], 1.), ([1., 1.], 0.), ([1., 1.], 0.5), ([1., 1.], 1.)]
@test sort(supports(LowerBoundRef(x))) == expected
@test sort(supports(UpperBoundRef(x))) == expected
@test sort(supports(IntegerRef(x))) == expected
expected = [([1.0, 1.0], 0.0), ([0.0, 0.0], 0.5), ([1.0, 1.0], 0.5), ([0.0, 0.0], 1.0), ([1.0, 1.0], 1.0)]
@test supports(LowerBoundRef(x)) == expected
@test supports(UpperBoundRef(x)) == expected
@test supports(IntegerRef(x)) == expected
@test supports(FixRef(x0)) == ()
@test supports(UpperBoundRef(yf)) == ()
@test supports(BinaryRef(z)) == ()
# test the constraint supports
expected = [([0., 0.], 0.), ([0., 0.], 0.5), ([0., 0.], 1.), ([1., 1.], 0.), ([1., 1.], 0.5), ([1., 1.], 1.)]
@test sort(vec(supports(c1))) == expected
expected = [([0., 0.], 0.) ([0., 0.], 0.5) ([0., 0.], 1.); ([1., 1.], 0.) ([1., 1.], 0.5) ([1., 1.], 1.)]
@test supports(c1) == expected
@test supports(c2) == (0.,)
@test supports(c3) == ([1., 1.], 1.)
@test supports(c4) == [(0.0,), (0.5,)]
@test supports(c5) == ()
@test sort(vec(supports(c6))) == expected
@test sort(vec(supports(c7))) == expected
@test supports(c6) == expected
@test supports(c7) == expected
@test supports(c8) == ()
end
end
Expand Down Expand Up @@ -506,19 +506,19 @@ end
@test IOTO.transcription_variable(x) isa Vector{VariableRef}
@test IOTO.transcription_variable(y) isa Matrix{VariableRef}
@test name(IOTO.transcription_variable(x)[1]) == "x(0.0)"
@test name(IOTO.transcription_variable(y)[3])[1:8] == "y(0.0, ["
@test name(IOTO.transcription_variable(y)[1, 2]) == "y(0.0, [1.0, 1.0])"
@test has_lower_bound(IOTO.transcription_variable(x)[1])
@test is_binary(IOTO.transcription_variable(y)[2])
@test is_fixed(IOTO.transcription_variable(y)[4])
@test is_integer(IOTO.transcription_variable(x)[2])
@test start_value(IOTO.transcription_variable(y)[1]) == 0.
@test supports(x) == [(0.,), (1.,)]
@test length(supports(y)) == 4
@test supports(y) == [(0.0, [0.0, 0.0]) (0.0, [1.0, 1.0]); (1.0, [0.0, 0.0]) (1.0, [1.0, 1.0])]
# test point variables
@test IOTO.transcription_variable(x0) isa VariableRef
@test IOTO.transcription_variable(y0) isa VariableRef
@test name(IOTO.transcription_variable(x0)) == "x(0.0)"
@test name(IOTO.transcription_variable(y0))[1:8] == "y(0.0, ["
@test name(IOTO.transcription_variable(y0)) == "y(0.0, [0.0, 0.0])"
@test has_lower_bound(IOTO.transcription_variable(x0))
@test is_integer(IOTO.transcription_variable(x0))
@test has_lower_bound(IOTO.transcription_variable(y0))
Expand Down Expand Up @@ -551,7 +551,7 @@ end
@test constraint_object(IOTO.transcription_constraint(c6)).func == [zt, wt]
@test IOTO.transcription_constraint(c5) isa Vector{ConstraintRef}
@test name(IOTO.transcription_constraint(c2)) == "c2"
@test name(IOTO.transcription_constraint(c1)) in ["c1[1, 1]", "c1[1, 2]"]
@test name(IOTO.transcription_constraint(c1)) == "c1[1, 1]"
@test supports(c1) == (0., [0., 0.])
@test IOTO.transcription_constraint(c7) isa ConstraintRef
@test isequal(constraint_object(IOTO.transcription_constraint(c7)).func, gr(zt) - 2.)
Expand Down
Loading

0 comments on commit c1f2545

Please sign in to comment.