Skip to content

Commit 1df4801

Browse files
committed
Add GA with local improvment
1 parent ee7ddfa commit 1df4801

File tree

3 files changed

+75
-10
lines changed

3 files changed

+75
-10
lines changed

include/simulation/RoutingSolver.hpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,17 @@ class RoutingSolver {
7373
double start_sampling_range_fraction,
7474
double abort_sampling_range_fraction, double contr_coeff);
7575

76+
std::tuple<SolverResult, ScoreHistory>
77+
local_search(RoutingSolutionSet starting_solution,
78+
double start_sampling_range_fraction,
79+
double abort_sampling_range_fraction, double contr_coeff,
80+
const std::function<double(void)>& rng01);
81+
7682
std::optional<SolverResult>
7783
greedy_solution(std::chrono::milliseconds per_train_stall_time);
7884

7985
std::tuple<std::optional<SolverResult>, ScoreHistory>
80-
genetic_search(GeneticParams params);
86+
genetic_search(GeneticParams params, bool local_improv = false);
8187

8288
// GA Helpers
8389
struct MiddleCost {
@@ -100,6 +106,11 @@ class RoutingSolver {
100106
const RoutingSolutionSet& X2,
101107
const std::function<double(void)>& rnd01);
102108

109+
RoutingSolutionSet
110+
crossover_local_improv(const RoutingSolutionSet& X1,
111+
const RoutingSolutionSet& X2,
112+
const std::function<double(void)>& rnd01);
113+
103114
double calculate_SO_total_fitness(const GA_Type::thisChromosomeType& X);
104115

105116
void SO_report_generation(

src/simulation/RoutingSolver.cpp

+57-4
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,41 @@ cda_rail::sim::RoutingSolver::local_search(
263263
return {last_result, hist};
264264
}
265265

266+
std::tuple<cda_rail::sim::SolverResult, cda_rail::sim::ScoreHistory>
267+
cda_rail::sim::RoutingSolver::local_search(
268+
cda_rail::sim::RoutingSolutionSet starting_solution,
269+
double start_sampling_range_fraction, double abort_sampling_range_fraction,
270+
double contr_coeff, const std::function<double(void)>& rng01) {
271+
std::chrono::steady_clock::time_point initial_time =
272+
std::chrono::steady_clock::now();
273+
274+
ScoreHistory hist;
275+
double sampling_range_fraction = start_sampling_range_fraction;
276+
SolverResult last_result{instance, starting_solution};
277+
double last_score = last_result.get_score();
278+
hist.push_back({std::chrono::duration_cast<std::chrono::milliseconds>(
279+
std::chrono::steady_clock::now() - initial_time),
280+
last_score});
281+
282+
while (sampling_range_fraction > abort_sampling_range_fraction) {
283+
RoutingSolutionSet new_sol = last_result.get_solutions();
284+
new_sol.perturb(instance, sampling_range_fraction, rng01);
285+
SolverResult new_result{instance, new_sol};
286+
287+
if (double new_score = new_result.get_score(); new_score < last_score) {
288+
last_result = new_result;
289+
last_score = new_score;
290+
hist.push_back({std::chrono::duration_cast<std::chrono::milliseconds>(
291+
std::chrono::steady_clock::now() - initial_time),
292+
last_score});
293+
} else {
294+
sampling_range_fraction = sampling_range_fraction * contr_coeff;
295+
}
296+
}
297+
298+
return {last_result, hist};
299+
}
300+
266301
std::optional<cda_rail::sim::SolverResult>
267302
cda_rail::sim::RoutingSolver::greedy_solution(
268303
std::chrono::milliseconds per_train_stall_time) {
@@ -320,7 +355,8 @@ cda_rail::sim::RoutingSolver::greedy_solution(
320355

321356
std::tuple<std::optional<cda_rail::sim::SolverResult>,
322357
cda_rail::sim::ScoreHistory>
323-
cda_rail::sim::RoutingSolver::genetic_search(GeneticParams params) {
358+
cda_rail::sim::RoutingSolver::genetic_search(GeneticParams params,
359+
bool local_improv) {
324360
/**
325361
* Genetic algorithm for entire solution sets
326362
*/
@@ -342,9 +378,16 @@ cda_rail::sim::RoutingSolver::genetic_search(GeneticParams params) {
342378
std::placeholders::_2);
343379
ga_obj.mutate = std::bind(&RoutingSolver::mutate, this, std::placeholders::_1,
344380
std::placeholders::_2, std::placeholders::_3);
345-
ga_obj.crossover =
346-
std::bind(&RoutingSolver::crossover, this, std::placeholders::_1,
347-
std::placeholders::_2, std::placeholders::_3);
381+
382+
if (local_improv) {
383+
ga_obj.crossover = std::bind(&RoutingSolver::crossover_local_improv, this,
384+
std::placeholders::_1, std::placeholders::_2,
385+
std::placeholders::_3);
386+
} else {
387+
ga_obj.crossover =
388+
std::bind(&RoutingSolver::crossover, this, std::placeholders::_1,
389+
std::placeholders::_2, std::placeholders::_3);
390+
}
348391

349392
std::chrono::steady_clock::time_point starting_time =
350393
std::chrono::steady_clock::now();
@@ -406,6 +449,16 @@ cda_rail::sim::RoutingSolutionSet cda_rail::sim::RoutingSolver::crossover(
406449
return X_new;
407450
}
408451

452+
cda_rail::sim::RoutingSolutionSet
453+
cda_rail::sim::RoutingSolver::crossover_local_improv(
454+
const RoutingSolutionSet& X1, const RoutingSolutionSet& X2,
455+
const std::function<double(void)>& rnd01) {
456+
auto X_new = crossover(X1, X2, rnd01);
457+
auto res = local_search(X_new, 0.05, 0.01, 0.95, rnd01);
458+
459+
return std::get<0>(res).get_solutions();
460+
}
461+
409462
double cda_rail::sim::RoutingSolver::calculate_SO_total_fitness(
410463
const GA_Type::thisChromosomeType& X) {
411464
return X.middle_costs.score;

test/test_simulation.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ TEST(Simulation, LocalSearch) {
331331
sim::RoutingSolver solver{instance};
332332

333333
sim::RoutingSolutionSet solution_set{instance, rng_engine};
334-
auto res = solver.local_search(solution_set, 0.5, 1e-5, 0.95);
334+
auto res = solver.local_search(solution_set, 0.1, 0.01, 0.95);
335335
}
336336

337337
TEST(Simulation, RandomLocalSearch) {
@@ -378,11 +378,11 @@ TEST(Simulation, GeneticSearch) {
378378
"./example-networks-unidirec/SimpleNetwork/timetable/", network);
379379

380380
cda_rail::sim::GeneticParams ga_params{
381-
.is_multithread = false,
382-
.population = 1000,
381+
.is_multithread = true,
382+
.population = 100,
383383
.gen_max = 20,
384384
.stall_max = 5,
385-
.n_elite = 100,
385+
.n_elite = 10,
386386
.xover_frac = 0.7,
387387
.mut_rate = 0.1,
388388
};
@@ -391,7 +391,8 @@ TEST(Simulation, GeneticSearch) {
391391
sim::RoutingSolver solver{instance};
392392

393393
sim::RoutingSolutionSet solution_set{instance, rng_engine};
394-
auto res = solver.genetic_search(ga_params);
394+
auto res = solver.genetic_search(ga_params);
395+
auto res2 = solver.genetic_search(ga_params, true);
395396
}
396397

397398
// TODO: test for invariance of solution after being repaired and used again

0 commit comments

Comments
 (0)