Skip to content

Commit

Permalink
add integration of local and distributed (old) modes
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jun 3, 2024
1 parent f2e6728 commit 2031499
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from autotm.algorithms_for_tuning.genetic_algorithm.ga import GA
from autotm.algorithms_for_tuning.genetic_algorithm.surrogate import Surrogate
from autotm.algorithms_for_tuning.individuals import IndividualBuilder
from autotm.fitness.estimator import ComputableFitnessEstimator, SurrogateEnabledComputableFitnessEstimator
from autotm.fitness.estimator import ComputableFitnessEstimator, SurrogateEnabledComputableFitnessEstimator, \
DistributedSurrogateEnabledComputableFitnessEstimator
from autotm.fitness.tm import fit_tm, TopicModel
from autotm.utils import make_log_config_dict

Expand Down Expand Up @@ -47,6 +48,7 @@ def get_best_individual(
quiet_log: bool = False,
statistics_collector: Optional[StatisticsCollector] = None,
individual_type: str = "regular",
fitness_estimator_type: str = "local", # distributed
**kwargs
):
"""
Expand Down Expand Up @@ -103,15 +105,31 @@ def get_best_individual(
cross_alpha = float(cross_alpha)

ibuilder = IndividualBuilder(individual_type)
fitness_estimator = \
SurrogateEnabledComputableFitnessEstimator(

if fitness_estimator_type == "local" and surrogate_name:
fitness_estimator = SurrogateEnabledComputableFitnessEstimator(
ibuilder,
Surrogate(surrogate_name),
"type1",
SPEEDUP,
num_fitness_evaluations,
statistics_collector
)
elif fitness_estimator_type == "local":
fitness_estimator = ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector)
elif fitness_estimator_type == "distributed" and surrogate_name:
fitness_estimator = DistributedSurrogateEnabledComputableFitnessEstimator(
ibuilder,
Surrogate(surrogate_name),
"type1",
SPEEDUP,
num_fitness_evaluations,
statistics_collector
) if surrogate_name else ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector)
)
elif fitness_estimator_type == "distributed":
fitness_estimator = ComputableFitnessEstimator(ibuilder, num_fitness_evaluations, statistics_collector)
else:
raise ValueError("Incorrect settings")

g = GA(
dataset=dataset,
Expand Down Expand Up @@ -174,15 +192,17 @@ def run_algorithm(
use_nelder_mead_in_selector: bool = False,
train_option: str = "offline",
quiet_log: bool = False,
individual_type: str = "regular"
individual_type: str = "regular",
fitness_estimator_type: str = "local"
) -> TopicModel:
best_individual = get_best_individual(dataset, data_path, exp_id, topic_count, num_individuals, num_iterations,
num_fitness_evaluations, mutation_type, crossover_type, selection_type,
elem_cross_prob, cross_alpha, best_proc, log_file, tag, surrogate_name,
gpr_kernel, gpr_alpha, gpr_normalize_y, use_pipeline,
use_nelder_mead_in_mutation, use_nelder_mead_in_crossover,
use_nelder_mead_in_selector, train_option, quiet_log,
individual_type=individual_type)
individual_type=individual_type,
fitness_estimator_type=fitness_estimator_type)

best_topic_model = fit_tm(
preproc_data_path=data_path,
Expand Down
24 changes: 24 additions & 0 deletions autotm/fitness/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,27 @@ def __init__(self,
if calc_scheme not in self.SUPPORTED_CALC_SCHEMES:
raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}")
super().__init__(ibuilder, num_fitness_evaluations, statistics_collector)


class DistributedSurrogateEnabledComputableFitnessEstimator(
    DistributedComputableFitnessEstimator,
    SurrogateEnabledFitnessEstimatorMixin,
):
    """Fitness estimator combining the distributed estimator with the surrogate mixin.

    The constructor stores the surrogate-related settings on the instance,
    starts empty histories of evaluated parameters and fitness values,
    validates the requested calculation scheme, and only then delegates to
    the parent initializer.
    """

    def __init__(
        self,
        ibuilder: IndividualBuilder,
        surrogate: Surrogate,
        calc_scheme: str,
        speedup: bool = True,
        num_fitness_evaluations: Optional[int] = None,
        statistics_collector: Optional[StatisticsCollector] = None,
    ):
        """Set up surrogate state and initialize the parent estimator.

        :param ibuilder: builder of individuals; stored here and forwarded
            to the parent initializer.
        :param surrogate: surrogate model object kept on the instance.
        :param calc_scheme: name of the surrogate calculation scheme; must be
            one of ``self.SUPPORTED_CALC_SCHEMES`` (defined outside this
            block, presumably by the mixin).
        :param speedup: flag stored as ``self.speedup``.
        :param num_fitness_evaluations: forwarded to the parent initializer.
        :param statistics_collector: forwarded to the parent initializer.
        :raises ValueError: if ``calc_scheme`` is not a supported scheme.
        """
        # Surrogate-related configuration is attached to the instance first.
        self.ibuilder = ibuilder
        self.surrogate = surrogate
        self.calc_scheme = calc_scheme
        self.speedup = speedup

        # Histories start empty for every new estimator.
        self.all_params: List[AbstractParams] = []
        self.all_fitness: List[float] = []

        # The scheme is checked only after the attributes are set because the
        # error message reads ``self.calc_scheme``.
        scheme_is_supported = calc_scheme in self.SUPPORTED_CALC_SCHEMES
        if not scheme_is_supported:
            raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}")

        super().__init__(ibuilder, num_fitness_evaluations, statistics_collector)
1 change: 1 addition & 0 deletions tests/integration/test_fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_fit_predict(pytestconfig):
alg_params={
"num_iterations": 2,
"num_individuals": 4,
"use_pipeline": False,
"use_nelder_mead_in_mutation": False,
"use_nelder_mead_in_crossover": False,
"use_nelder_mead_in_selector": False,
Expand Down

0 comments on commit 2031499

Please sign in to comment.