diff --git a/pdf_agents/scientific_value.py b/pdf_agents/scientific_value.py index fec2a70..d73da03 100644 --- a/pdf_agents/scientific_value.py +++ b/pdf_agents/scientific_value.py @@ -13,7 +13,6 @@ from scipy.spatial import distance_matrix from .agents import PDFBaseAgent -from .sklearn import PassiveKmeansAgent logger = getLogger("pdf_agents.scientific_value") @@ -88,7 +87,7 @@ def __init__( if device is None else torch.device(device) ) - self.bounds = torch.tensor(bounds, device=self.device).view(2, -1) + self.bounds = torch.tensor(bounds, device=self.device, dtype=torch.float).view(2, -1) self.num_restarts = num_restarts self.raw_samples = raw_samples @@ -152,12 +151,22 @@ def ask(self, batch_size: int = 1): candidate=candidate.detach().cpu().numpy(), acquisition_value=acq.detach().cpu().numpy(), latest_data=self.tell_cache[-1], - cache_len=len(self.independent_cache), + cache_len=( + len(self.independent_cache) + if isinstance(self.independent_cache, list) + else self.independent_cache.shape[0] + ), latest_value=value.squeeze()[-1], ucb_beta=self.ucb_beta, + absolute_position_offset=self._motor_origins, ) for candidate, acq in zip(candidates, acq_value) ] if not hasattr(self.independent_cache[0], "shape"): candidates = candidates.squeeze() return docs, torch.atleast_1d(candidates).detach().cpu().numpy().tolist() + + def measurement_plan(self, relative_point): + """Send measurement plan absolute point from reltive position""" + absolute_point = relative_point + self._motor_origins + return super().measurement_plan(absolute_point)