diff --git a/autotm/algorithms_for_tuning/genetic_algorithm/genetic_algorithm.py b/autotm/algorithms_for_tuning/genetic_algorithm/genetic_algorithm.py index 04990ff..9aac782 100755 --- a/autotm/algorithms_for_tuning/genetic_algorithm/genetic_algorithm.py +++ b/autotm/algorithms_for_tuning/genetic_algorithm/genetic_algorithm.py @@ -9,7 +9,8 @@ from autotm.algorithms_for_tuning.genetic_algorithm.ga import GA from autotm.algorithms_for_tuning.genetic_algorithm.surrogate import Surrogate from autotm.algorithms_for_tuning.individuals import IndividualBuilder -from autotm.fitness.estimator import ComputableFitnessEstimator, SurrogateEnabledComputableFitnessEstimator +from autotm.fitness.estimator import ComputableFitnessEstimator, SurrogateEnabledComputableFitnessEstimator, \ + DistributedSurrogateEnabledComputableFitnessEstimator from autotm.fitness.tm import fit_tm, TopicModel from autotm.utils import make_log_config_dict @@ -47,6 +48,7 @@ def get_best_individual( quiet_log: bool = False, statistics_collector: Optional[StatisticsCollector] = None, individual_type: str = "regular", + fitness_estimator_type: str = "local", # distributed **kwargs ): """ @@ -103,15 +105,31 @@ def get_best_individual( cross_alpha = float(cross_alpha) ibuilder = IndividualBuilder(individual_type) - fitness_estimator = \ - SurrogateEnabledComputableFitnessEstimator( + + if fitness_estimator_type == "local" and surrogate_name: + fitness_estimator = SurrogateEnabledComputableFitnessEstimator( + ibuilder, + Surrogate(surrogate_name), + "type1", + SPEEDUP, + num_fitness_evaluations, + statistics_collector + ) + elif fitness_estimator_type == "local": + fitness_estimator = ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector) + elif fitness_estimator_type == "distributed" and surrogate_name: + fitness_estimator = DistributedSurrogateEnabledComputableFitnessEstimator( ibuilder, Surrogate(surrogate_name), "type1", SPEEDUP, num_fitness_evaluations, statistics_collector - ) if surrogate_name else ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector) + ) + elif fitness_estimator_type == "distributed": + fitness_estimator = ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector) + else: + raise ValueError("Incorrect settings") g = GA( dataset=dataset, @@ -174,7 +192,8 @@ def run_algorithm( use_nelder_mead_in_selector: bool = False, train_option: str = "offline", quiet_log: bool = False, - individual_type: str = "regular" + individual_type: str = "regular", + fitness_estimator_type: str = "local" ) -> TopicModel: best_individual = get_best_individual(dataset, data_path, exp_id, topic_count, num_individuals, num_iterations, num_fitness_evaluations, mutation_type, crossover_type, selection_type, @@ -182,7 +201,8 @@ def run_algorithm( gpr_kernel, gpr_alpha, gpr_normalize_y, use_pipeline, use_nelder_mead_in_mutation, use_nelder_mead_in_crossover, use_nelder_mead_in_selector, train_option, quiet_log, - individual_type=individual_type) + individual_type=individual_type, + fitness_estimator_type=fitness_estimator_type) best_topic_model = fit_tm( preproc_data_path=data_path, diff --git a/autotm/fitness/estimator.py b/autotm/fitness/estimator.py index f4a4ac9..870e316 100644 --- a/autotm/fitness/estimator.py +++ b/autotm/fitness/estimator.py @@ -278,3 +278,27 @@ def __init__(self, if calc_scheme not in self.SUPPORTED_CALC_SCHEMES: raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}") super().__init__(ibuilder, num_fitness_evaluations, statistics_collector) + + +class DistributedSurrogateEnabledComputableFitnessEstimator( + DistributedComputableFitnessEstimator, + SurrogateEnabledFitnessEstimatorMixin +): + def __init__(self, + ibuilder: IndividualBuilder, + surrogate: Surrogate, + calc_scheme: str, + speedup: bool = True, + num_fitness_evaluations: Optional[int] = None, + statistics_collector: Optional[StatisticsCollector] = None): + self.ibuilder = ibuilder + self.surrogate = surrogate + self.calc_scheme = calc_scheme + self.speedup = speedup + + self.all_params: List[AbstractParams] = [] + self.all_fitness: List[float] = [] + + if calc_scheme not in self.SUPPORTED_CALC_SCHEMES: + raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}") + super().__init__(ibuilder, num_fitness_evaluations, statistics_collector) diff --git a/tests/integration/test_fit_predict.py b/tests/integration/test_fit_predict.py index e7a9c05..e98d54b 100644 --- a/tests/integration/test_fit_predict.py +++ b/tests/integration/test_fit_predict.py @@ -42,6 +42,7 @@ def test_fit_predict(pytestconfig): alg_params={ "num_iterations": 2, "num_individuals": 4, + "use_pipeline": False, "use_nelder_mead_in_mutation": False, "use_nelder_mead_in_crossover": False, "use_nelder_mead_in_selector": False,