From 0f3d56746503d2fb8dbc9dc001b01d3f3f83d772 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Fri, 19 Jan 2024 19:23:57 +0100 Subject: [PATCH] Improve performance in parallel case (#9) --- src/tranquilo/acceptance_decision.py | 30 +++++++++++----- src/tranquilo/filter_points.py | 6 ++-- src/tranquilo/options.py | 2 ++ tests/test_acceptance_decision.py | 3 +- tests/test_filter_points.py | 53 ++++++++++++++++++++++++++++ 5 files changed, 83 insertions(+), 11 deletions(-) diff --git a/src/tranquilo/acceptance_decision.py b/src/tranquilo/acceptance_decision.py index 62b8a949..534b3397 100644 --- a/src/tranquilo/acceptance_decision.py +++ b/src/tranquilo/acceptance_decision.py @@ -16,7 +16,10 @@ from tranquilo.options import AcceptanceOptions -def get_acceptance_decider(acceptance_decider, acceptance_options): +def get_acceptance_decider( + acceptance_decider, + acceptance_options, +): func_dict = { "classic": _accept_classic, "naive_noisy": accept_naive_noisy, @@ -92,11 +95,11 @@ def accept_classic_line_search( state, history, *, + speculative_sampling_radius_factor, wrapped_criterion, min_improvement, batch_size, sample_points, - search_radius_factor, rng, ): # ================================================================================== @@ -144,11 +147,12 @@ def accept_classic_line_search( if n_unallocated_evals > 0: speculative_xs = _generate_speculative_sample( new_center=candidate_x, - search_radius_factor=search_radius_factor, + radius_factor=speculative_sampling_radius_factor, trustregion=state.trustregion, sample_points=sample_points, n_points=n_unallocated_evals, history=history, + line_search_xs=line_search_xs, rng=rng, ) else: @@ -427,7 +431,14 @@ def calculate_rho(actual_improvement, expected_improvement): def _generate_speculative_sample( - new_center, trustregion, sample_points, n_points, history, search_radius_factor, rng + new_center, + trustregion, + sample_points, + n_points, + history, + line_search_xs, + radius_factor, + rng, ): """Generative a speculative sample. @@ -437,8 +448,8 @@ def _generate_speculative_sample( sample_points (callable): Function to sample points. n_points (int): Number of points to sample. history (History): Tranquilo history. - search_radius_factor (float): Factor to multiply the trust region radius by to - get the search radius. + radius_factor (float): Factor to multiply the trust region radius by to get the + radius of the region from which to draw the speculative sample. rng (np.random.Generator): Random number generator. Returns: @@ -446,14 +457,17 @@ def _generate_speculative_sample( """ search_region = trustregion._replace( - center=new_center, radius=search_radius_factor * trustregion.radius + center=new_center, radius=radius_factor * trustregion.radius ) old_indices = history.get_x_indices_in_region(search_region) old_xs = history.get_xs(old_indices) - model_xs = old_xs + if line_search_xs is not None: + model_xs = np.row_stack([old_xs, line_search_xs]) + else: + model_xs = old_xs new_xs = sample_points( search_region, diff --git a/src/tranquilo/filter_points.py b/src/tranquilo/filter_points.py index dc3441fd..318d3b15 100644 --- a/src/tranquilo/filter_points.py +++ b/src/tranquilo/filter_points.py @@ -50,8 +50,10 @@ def keep_all(xs, indices): return xs, indices -def drop_excess(xs, indices, state, target_size): - n_to_drop = max(0, len(xs) - target_size) +def drop_excess(xs, indices, state, target_size, n_max_factor): + filter_target_size = int(np.floor(target_size * n_max_factor)) + + n_to_drop = max(0, len(xs) - filter_target_size) if n_to_drop: xs, indices = drop_worst_points(xs, indices, state, n_to_drop) diff --git a/src/tranquilo/options.py b/src/tranquilo/options.py index 2a92516a..dbe0312c 100644 --- a/src/tranquilo/options.py +++ b/src/tranquilo/options.py @@ -140,6 +140,7 @@ class AcceptanceOptions(NamedTuple): n_min: int = 4 n_max: int = 50 min_improvement: float = 0.0 + speculative_sampling_radius_factor: float = 0.75 class StagnationOptions(NamedTuple): @@ -179,6 +180,7 @@ class VarianceEstimatorOptions(NamedTuple): class FilterOptions(NamedTuple): strictness: float = 1e-10 shape: str = "sphere" + n_max_factor: int = 3 class SamplerOptions(NamedTuple): diff --git a/tests/test_acceptance_decision.py b/tests/test_acceptance_decision.py index 062e772b..2be2bab9 100644 --- a/tests/test_acceptance_decision.py +++ b/tests/test_acceptance_decision.py @@ -185,7 +185,8 @@ def test_generate_speculative_sample(): sample_points=get_sampler("random_hull"), n_points=3, history=history, - search_radius_factor=1.0, + radius_factor=1.0, + line_search_xs=None, rng=np.random.default_rng(1234), ) diff --git a/tests/test_filter_points.py b/tests/test_filter_points.py index 98e5c343..59ac427b 100644 --- a/tests/test_filter_points.py +++ b/tests/test_filter_points.py @@ -1,4 +1,5 @@ from tranquilo.filter_points import get_sample_filter +from tranquilo.filter_points import drop_worst_points from tranquilo.tranquilo import State from tranquilo.region import Region from numpy.testing import assert_array_equal as aae @@ -46,3 +47,55 @@ def test_keep_all(): got_xs, got_idxs = filter(xs=xs, indices=indices, state=None) aae(got_xs, xs) aae(got_idxs, indices) + + +def test_drop_worst_point(state): + xs = np.array( + [ + [1, 1.1], # should be dropped + [1, 1.2], + [1, 1], # center (needs to have index=2) + [3, 3], # should be dropped + ] + ) + + got_xs, got_indices = drop_worst_points( + xs, indices=np.arange(4), state=state, n_to_drop=2 + ) + + expected_xs = np.array( + [ + [1, 1.2], + [1, 1], + ] + ) + expected_indices = np.array([1, 2]) + + aae(got_xs, expected_xs) + aae(got_indices, expected_indices) + + +def test_drop_excess(state): + filter = get_sample_filter("drop_excess", user_options={"n_max_factor": 1.0}) + + xs = np.array( + [ + [1, 1.1], # should be dropped + [1, 1.2], + [1, 1], # center (needs to have index=2) + [3, 3], # should be dropped + ] + ) + + got_xs, got_indices = filter(xs, indices=np.arange(4), state=state, target_size=2) + + expected_xs = np.array( + [ + [1, 1.2], + [1, 1], + ] + ) + expected_indices = np.array([1, 2]) + + aae(got_xs, expected_xs) + aae(got_indices, expected_indices)