Skip to content

Commit

Permalink
fix: circular grid search for SVA
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Apr 13, 2024
1 parent 63992dd commit 60264ac
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions pdf_agents/scientific_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from botorch import fit_gpytorch_mll
from botorch.acquisition import UpperConfidenceBound, qUpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim import optimize_acqf # noqa: F401
from gpytorch.mlls import ExactMarginalLogLikelihood
from scipy.spatial import distance_matrix

from .agents import PDFBaseAgent
from .utils import make_wafer_grid_list

logger = getLogger("pdf_agents.scientific_value")

Expand Down Expand Up @@ -140,9 +141,19 @@ def ask(self, batch_size: int = 1):
if batch_size == 1
else qUpperConfidenceBound(gp, beta=self.ucb_beta).to(self.device)
)
candidates, acq_value = optimize_acqf(
acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples
)

# Override the optimization to use a grid search over a circular wafer
# candidates, acq_value = optimize_acqf(
# acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples
# )
grid = torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[
:, None, :
]
acq_grid = acq(grid)
top_indicies = torch.argsort(acq_grid, descending=True, dim=0)[:batch_size]
candidates = grid[top_indicies].squeeze(1)
acq_value = acq_grid[top_indicies]

if batch_size == 1:
acq_value = [acq_value]

Expand Down

0 comments on commit 60264ac

Please sign in to comment.