-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from JuliaDecisionFocusedLearning/more-tests
Add more tests for better coverage
- Loading branch information
Showing
6 changed files
with
93 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,35 @@ | ||
module FixedSizeShortestPathTest | ||
|
||
using DecisionFocusedLearningBenchmarks.FixedSizeShortestPath | ||
|
||
# using Flux | ||
# using InferOpt | ||
# using ProgressMeter | ||
# using UnicodePlots | ||
# using Zygote | ||
|
||
bench = FixedSizeShortestPathBenchmark() | ||
|
||
(; features, costs, solutions) = generate_dataset(bench) | ||
|
||
model = generate_statistical_model(bench) | ||
maximizer = generate_maximizer(bench) | ||
|
||
# perturbed = PerturbedAdditive(maximizer; nb_samples=10, ε=0.1) | ||
# fyl = FenchelYoungLoss(perturbed) | ||
|
||
# opt_state = Flux.setup(Adam(), model) | ||
# loss_history = Float64[] | ||
# gap_history = Float64[] | ||
# E = 100 | ||
# @showprogress for epoch in 1:E | ||
# loss = 0.0 | ||
# for (x, y) in zip(features, solutions) | ||
# val, grads = Flux.withgradient(model) do m | ||
# θ = m(x) | ||
# fyl(θ, y) | ||
# end | ||
# loss += val | ||
# Flux.update!(opt_state, model, grads[1]) | ||
# end | ||
# push!(loss_history, loss ./ E) | ||
# push!( | ||
# gap_history, compute_gap(bench, model, features, costs, solutions, maximizer) .* 100 | ||
# ) | ||
# end | ||
|
||
# println(lineplot(loss_history; title="Loss")) | ||
# println(lineplot(gap_history; title="Gap")) | ||
|
||
@testitem "FixedSizeShortestPath" begin | ||
using DecisionFocusedLearningBenchmarks.FixedSizeShortestPath | ||
using Graphs | ||
|
||
p = 5 | ||
grid_size = (5, 5) | ||
A = (grid_size[1] - 1) * grid_size[2] + grid_size[1] * (grid_size[2] - 1) | ||
b = FixedSizeShortestPathBenchmark(; p=p, grid_size=grid_size) | ||
|
||
@test nv(b.graph) == grid_size[1] * grid_size[2] | ||
@test ne(b.graph) == A | ||
|
||
dataset = generate_dataset(b, 50) | ||
model = generate_statistical_model(b) | ||
maximizer = generate_maximizer(b) | ||
|
||
gap = compute_gap(b, dataset, model, maximizer) | ||
@test gap >= 0 | ||
|
||
for sample in dataset | ||
x = sample.x | ||
θ_true = sample.θ | ||
y_true = sample.y | ||
@test all(θ_true .< 0) | ||
@test size(x) == (p,) | ||
@test length(θ_true) == A | ||
@test length(y_true) == A | ||
@test isnothing(sample.instance) | ||
@test all(y_true .== maximizer(θ_true)) | ||
θ = model(x) | ||
@test length(θ) == length(θ_true) | ||
y = maximizer(θ) | ||
@test length(y) == length(y_true) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,29 @@ | ||
@testitem "Portfolio Optimization" begin | ||
using DecisionFocusedLearningBenchmarks | ||
using InferOpt | ||
using Flux | ||
using Zygote | ||
|
||
b = PortfolioOptimizationBenchmark() | ||
d = 50 | ||
p = 5 | ||
b = PortfolioOptimizationBenchmark(; d=d, p=p) | ||
|
||
dataset = generate_dataset(b, 100) | ||
dataset = generate_dataset(b, 50) | ||
model = generate_statistical_model(b) | ||
maximizer = generate_maximizer(b) | ||
|
||
# train_dataset, test_dataset = dataset[1:50], dataset[50:100] | ||
# X_train = train_dataset.features | ||
# Y_train = train_dataset.solutions | ||
|
||
# perturbed_maximizer = PerturbedAdditive(maximizer; ε=0.1, nb_samples=1) | ||
# loss = FenchelYoungLoss(perturbed_maximizer) | ||
|
||
# starting_gap = compute_gap(b, test_dataset, model, maximizer) | ||
|
||
# opt_state = Flux.setup(Adam(), model) | ||
# loss_history = Float64[] | ||
# for epoch in 1:50 | ||
# val, grads = Flux.withgradient(model) do m | ||
# sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset) | ||
# end | ||
# Flux.update!(opt_state, model, grads[1]) | ||
# push!(loss_history, val) | ||
# end | ||
|
||
# final_gap = compute_gap(b, test_dataset, model, maximizer) | ||
|
||
# @test loss_history[end] < loss_history[1] | ||
# @test final_gap < starting_gap / 10 | ||
for sample in dataset | ||
x = sample.x | ||
θ_true = sample.θ | ||
y_true = sample.y | ||
@test size(x) == (p,) | ||
@test length(θ_true) == d | ||
@test length(y_true) == d | ||
@test isnothing(sample.instance) | ||
@test all(y_true .== maximizer(θ_true)) | ||
|
||
θ = model(x) | ||
@test length(θ) == d | ||
|
||
y = maximizer(θ) | ||
@test length(y) == d | ||
@test sum(y) <= 1 | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,37 @@ | ||
@testitem "Subset selection" begin | ||
using DecisionFocusedLearningBenchmarks | ||
using InferOpt | ||
using Flux | ||
using UnicodePlots | ||
using Zygote | ||
|
||
b = SubsetSelectionBenchmark() | ||
n = 25 | ||
k = 5 | ||
|
||
dataset = generate_dataset(b, 500) | ||
model = generate_statistical_model(b) | ||
maximizer = generate_maximizer(b) | ||
|
||
# train_dataset, test_dataset = dataset[1:450], dataset[451:500] | ||
# X_train = train_dataset.features | ||
# Y_train = train_dataset.solutions | ||
|
||
# perturbed_maximizer = PerturbedAdditive(maximizer; ε=1.0, nb_samples=100) | ||
# loss = FenchelYoungLoss(perturbed_maximizer) | ||
b = SubsetSelectionBenchmark(; n=n, k=k) | ||
|
||
# starting_gap = compute_gap(b, test_dataset, model, maximizer) | ||
io = IOBuffer() | ||
show(io, b) | ||
@test String(take!(io)) == "SubsetSelectionBenchmark(n=25, k=5)" | ||
|
||
# opt_state = Flux.setup(Adam(0.1), model) | ||
# loss_history = Float64[] | ||
# for epoch in 1:50 | ||
# val, grads = Flux.withgradient(model) do m | ||
# sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset) | ||
# end | ||
# Flux.update!(opt_state, model, grads[1]) | ||
# push!(loss_history, val) | ||
# end | ||
|
||
# final_gap = compute_gap(b, test_dataset, model, maximizer) | ||
dataset = generate_dataset(b, 50) | ||
model = generate_statistical_model(b) | ||
maximizer = generate_maximizer(b) | ||
|
||
# lineplot(loss_history) | ||
# @test loss_history[end] < loss_history[1] | ||
# @test final_gap < starting_gap / 10 | ||
for (i, sample) in enumerate(dataset) | ||
x = sample.x | ||
θ_true = sample.θ | ||
y_true = sample.y | ||
@test size(x) == (n,) | ||
@test length(θ_true) == n | ||
@test length(y_true) == n | ||
@test isnothing(sample.instance) | ||
@test all(y_true .== maximizer(θ_true)) | ||
|
||
# Features and true weights should be equal | ||
@test all(θ_true .== x) | ||
|
||
θ = model(x) | ||
@test length(θ) == n | ||
|
||
y = maximizer(θ) | ||
@test length(y) == n | ||
@test sum(y) == k | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters