Skip to content

Commit

Permalink
Merge pull request #18 from JuliaDecisionFocusedLearning/argmax-encoder
Browse files Browse the repository at this point in the history
Store true encoder in Argmax and Ranking benchmarks
  • Loading branch information
BatyLeo authored Dec 23, 2024
2 parents 8748674 + 247a9b7 commit f1efe57
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
25 changes: 16 additions & 9 deletions src/Argmax/Argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ Basic benchmark problem with an argmax as the CO algorithm.
# Fields
$TYPEDFIELDS
"""
struct ArgmaxBenchmark <: AbstractBenchmark
struct ArgmaxBenchmark{E} <: AbstractBenchmark
"instances dimension, total number of classes"
instance_dim::Int
"number of features"
nb_features::Int
"true mapping between features and costs"
encoder::E
end

function Base.show(io::IO, bench::ArgmaxBenchmark)
Expand All @@ -27,8 +29,15 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
)
end

function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5)
return ArgmaxBenchmark(instance_dim, nb_features)
"""
$TYPEDSIGNATURES
Custom constructor for [`ArgmaxBenchmark`](@ref).
"""
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
Random.seed!(seed)
model = Chain(Dense(nb_features => 1; bias=false), vec)
return ArgmaxBenchmark(instance_dim, nb_features, model)
end

"""
Expand Down Expand Up @@ -59,12 +68,10 @@ Generate a dataset of labeled instances for the argmax problem.
function Utils.generate_dataset(
bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
)
(; instance_dim, nb_features) = bench
(; instance_dim, nb_features, encoder) = bench
rng = MersenneTwister(seed)
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
costs = mapping.(features)
# solutions = one_hot_argmax.(costs)
costs = encoder.(features)
noisy_solutions = [
one_hot_argmax+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
]
Expand All @@ -79,9 +86,9 @@ $TYPEDSIGNATURES
Initialize a linear model for `bench` using `Flux`.
"""
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=0)
Random.seed!(seed)
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=nothing)
(; nb_features) = bench
Random.seed!(seed)
return Chain(Dense(nb_features => 1; bias=false), vec)
end

Expand Down
25 changes: 16 additions & 9 deletions src/Ranking/Ranking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ Basic benchmark problem with ranking as the CO algorithm.
# Fields
$TYPEDFIELDS
"""
struct RankingBenchmark <: AbstractBenchmark
struct RankingBenchmark{E} <: AbstractBenchmark
"instances dimension, total number of classes"
instance_dim::Int
"number of features"
nb_features::Int
"true mapping between features and costs"
encoder::E
end

function Base.show(io::IO, bench::RankingBenchmark)
Expand All @@ -27,8 +29,15 @@ function Base.show(io::IO, bench::RankingBenchmark)
)
end

function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5)
return RankingBenchmark(instance_dim, nb_features)
"""
$TYPEDSIGNATURES
Custom constructor for [`RankingBenchmark`](@ref).
"""
function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
Random.seed!(seed)
model = Chain(Dense(nb_features => 1; bias=false), vec)
return RankingBenchmark(instance_dim, nb_features, model)
end

"""
Expand Down Expand Up @@ -57,12 +66,10 @@ Generate a dataset of labeled instances for the ranking problem.
function Utils.generate_dataset(
bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
)
(; instance_dim, nb_features) = bench
(; instance_dim, nb_features, encoder) = bench
rng = MersenneTwister(seed)
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
costs = mapping.(features)
# solutions = ranking.(costs)
costs = encoder.(features)
noisy_solutions = [
ranking.+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
]
Expand All @@ -77,9 +84,9 @@ $TYPEDSIGNATURES
Initialize a linear model for `bench` using `Flux`.
"""
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=0)
Random.seed!(seed)
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=nothing)
(; nb_features) = bench
Random.seed!(seed)
return Chain(Dense(nb_features => 1; bias=false), vec)
end

Expand Down

0 comments on commit f1efe57

Please sign in to comment.