Skip to content

Commit

Permalink
Workaround isort & black clash
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Jul 26, 2023
1 parent b6ee392 commit fcc16e2
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions docs/notebooks/multi_trust_region.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import trieste
from trieste.acquisition import ParallelContinuousThompsonSampling
from trieste.acquisition.optimizer import automatic_optimizer_selector
from trieste.acquisition.rule import EfficientGlobalOptimization, MultiTrustRegionBox
from trieste.acquisition.rule import MultiTrustRegionBox
from trieste.acquisition.utils import split_acquisition_function_calls
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.experimental.plotting import plot_regret
Expand Down Expand Up @@ -74,10 +74,12 @@ def obj_fun(
)
model = GaussianProcessRegression(gpflow_model)

base_rule = EfficientGlobalOptimization(
base_rule = trieste.acquisition.rule.EfficientGlobalOptimization(
builder=ParallelContinuousThompsonSampling(),
num_query_points=num_query_points,
optimizer=split_acquisition_function_calls(automatic_optimizer_selector, split_size=100_000),
optimizer=split_acquisition_function_calls(
automatic_optimizer_selector, split_size=100_000
),
)

acq_rule = MultiTrustRegionBox(base_rule, number_of_tr=num_query_points)
Expand All @@ -89,7 +91,9 @@ def obj_fun(
# %%
color = cm.rainbow(np.linspace(0, 1, num_query_points))

Xplot, xx, yy = create_grid(mins=search_space.lower, maxs=search_space.upper, grid_density=90)
Xplot, xx, yy = create_grid(
mins=search_space.lower, maxs=search_space.upper, grid_density=90
)
ff = obj_fun(Xplot).numpy()

for step in range(num_steps):
Expand All @@ -115,7 +119,9 @@ def obj_fun(
ask_tell.dataset.query_points[:, 1].numpy(),
color="blue",
)
ax[0, 0].scatter(new_points[:, 0].numpy(), new_points[:, 1].numpy(), color="red")
ax[0, 0].scatter(
new_points[:, 0].numpy(), new_points[:, 1].numpy(), color="red"
)

state = ask_tell.acquisition_state
assert state is not None
Expand Down Expand Up @@ -181,7 +187,9 @@ def obj_fun(
best_found_truth_idx = tf.squeeze(tf.argmin(ground_truth_regret, axis=0))

fig, ax = plt.subplots()
plot_regret(ground_truth_regret.numpy(), ax, num_init=10, idx_best=best_found_truth_idx)
plot_regret(
ground_truth_regret.numpy(), ax, num_init=10, idx_best=best_found_truth_idx
)

ax.set_yscale("log")
ax.set_ylabel("Regret")
Expand Down

0 comments on commit fcc16e2

Please sign in to comment.