diff --git a/pdf_agents/scientific_value.py b/pdf_agents/scientific_value.py index d73da03..64e669c 100644 --- a/pdf_agents/scientific_value.py +++ b/pdf_agents/scientific_value.py @@ -72,7 +72,7 @@ def __init__( bounds: torch.Tensor, device: torch.device = None, num_restarts: int = 10, - raw_samples: int = 20, + raw_samples: int = 128, observable_distance_function: Optional[Callable] = None, ucb_beta=1.0, **kwargs @@ -122,7 +122,7 @@ def tell(self, x, y): def report(self): value = self._value_function(np.array(self.independent_cache), np.array(self.observable_cache)) - dict(latest_data=self.tell_cache[-1], cache_len=len(self.independent_cache), latest_value=value[-1]) + return dict(latest_data=self.tell_cache[-1], cache_len=len(self.independent_cache), latest_value=value[-1]) def ask(self, batch_size: int = 1): value = self._value_function(np.array(self.independent_cache), np.array(self.observable_cache))