Skip to content

Commit

Permalink
fix: SVA relative corrdinate operations
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Apr 13, 2024
1 parent dd4b6a3 commit ec83d18
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions pdf_agents/scientific_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from scipy.spatial import distance_matrix

from .agents import PDFBaseAgent
from .sklearn import PassiveKmeansAgent

logger = getLogger("pdf_agents.scientific_value")

Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(
if device is None
else torch.device(device)
)
self.bounds = torch.tensor(bounds, device=self.device).view(2, -1)
self.bounds = torch.tensor(bounds, device=self.device, dtype=torch.float).view(2, -1)

self.num_restarts = num_restarts
self.raw_samples = raw_samples
Expand Down Expand Up @@ -152,12 +151,22 @@ def ask(self, batch_size: int = 1):
candidate=candidate.detach().cpu().numpy(),
acquisition_value=acq.detach().cpu().numpy(),
latest_data=self.tell_cache[-1],
cache_len=len(self.independent_cache),
cache_len=(
len(self.independent_cache)
if isinstance(self.independent_cache, list)
else self.independent_cache.shape[0]
),
latest_value=value.squeeze()[-1],
ucb_beta=self.ucb_beta,
absolute_position_offset=self._motor_origins,
)
for candidate, acq in zip(candidates, acq_value)
]
if not hasattr(self.independent_cache[0], "shape"):
candidates = candidates.squeeze()
return docs, torch.atleast_1d(candidates).detach().cpu().numpy().tolist()

def measurement_plan(self, relative_point):
"""Send measurement plan absolute point from reltive position"""
absolute_point = relative_point + self._motor_origins
return super().measurement_plan(absolute_point)

0 comments on commit ec83d18

Please sign in to comment.