From 60264ac4319342263a98f1079457ed24c81bbeaf Mon Sep 17 00:00:00 2001 From: maffettone Date: Fri, 12 Apr 2024 19:32:47 -0700 Subject: [PATCH] fix: circular grid search for SVA --- pdf_agents/scientific_value.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pdf_agents/scientific_value.py b/pdf_agents/scientific_value.py index 64e669c..a394fd5 100644 --- a/pdf_agents/scientific_value.py +++ b/pdf_agents/scientific_value.py @@ -8,11 +8,12 @@ from botorch import fit_gpytorch_mll from botorch.acquisition import UpperConfidenceBound, qUpperConfidenceBound from botorch.models import SingleTaskGP -from botorch.optim import optimize_acqf +from botorch.optim import optimize_acqf # noqa: F401 from gpytorch.mlls import ExactMarginalLogLikelihood from scipy.spatial import distance_matrix from .agents import PDFBaseAgent +from .utils import make_wafer_grid_list logger = getLogger("pdf_agents.scientific_value") @@ -140,9 +141,19 @@ def ask(self, batch_size: int = 1): if batch_size == 1 else qUpperConfidenceBound(gp, beta=self.ucb_beta).to(self.device) ) - candidates, acq_value = optimize_acqf( - acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples - ) + + # Override the optimization to use a grid search over a circular wafer + # candidates, acq_value = optimize_acqf( + # acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples + # ) + grid = torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[ + :, None, : + ] + acq_grid = acq(grid) + top_indicies = torch.argsort(acq_grid, descending=True, dim=0)[:batch_size] + candidates = grid[top_indicies].squeeze(1) + acq_value = acq_grid[top_indicies] + if batch_size == 1: acq_value = [acq_value]