Skip to content

Commit

Permalink
Fix Strategy class to ensure consistent tensor operations for data no…
Browse files Browse the repository at this point in the history
…rmalization (facebookresearch#403)

Summary:
This PR addresses the second part of issue facebookresearch#365, focusing on the `Strategy` class and how data is added and normalized, transitioning the process to use tensors instead of NumPy operations.

The changes were made specifically within the `normalize_inputs` method of the `Strategy` class. Previously, this method had mismatched docstrings indicating `np.array` usage. Now, it consistently accepts and returns tensors, performing all operations within tensors.

The `normalize_inputs` method is called in `add_data()` (where the confusion arises), as the data passed can vary (either tensors or `np.array`). To resolve this, the method now acts as the first step, accepting both formats and then converting everything to tensors for consistent operations (model fitting later on). It’s also crucial to ensure the data type is `float64`, as `gpytorch` does not support other data types.

Additionally, a detailed docstring was added to clarify the method's expectations and ensure its proper use going forward.

Pull Request resolved: facebookresearch#403

Reviewed By: crasanders

Differential Revision: D64343236

Pulled By: JasonKChow

fbshipit-source-id: 413077605f4fa46b82405897c713cbc62b58a3f3
  • Loading branch information
yalsaffar authored and facebook-github-bot committed Oct 17, 2024
1 parent 39cc066 commit 09b1d59
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def __init__(
lb=self.lb, ub=self.ub, size=self._n_eval_points
)

self.x = None
self.y = None
self.x: Optional[torch.Tensor] = None
self.y: Optional[torch.Tensor] = None
self.n = 0
self.min_asks = min_asks
self._count = 0
Expand Down Expand Up @@ -170,38 +170,41 @@ def __init__(

self.name = name

def normalize_inputs(self, x, y):
def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""converts inputs into normalized format for this strategy
Args:
x (np.ndarray): training inputs
y (np.ndarray): training outputs
x (torch.Tensor): training inputs
y (torch.Tensor): training outputs
Returns:
x (np.ndarray): training inputs, normalized
y (np.ndarray): training outputs, normalized
x (torch.Tensor): training inputs, normalized
y (torch.Tensor): training outputs, normalized
n (int): number of observations
"""
assert (
x.shape == self.event_shape or x.shape[1:] == self.event_shape
), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"


# Handle scalar y values
if y.ndim == 0:
y = y.unsqueeze(0)

if x.shape == self.event_shape:
x = x[None, :]

if self.x is None:
x = np.r_[x]
else:
x = np.r_[self.x, x]
if self.x is not None:
x = torch.cat((self.x, x), dim=0)

if self.y is None:
y = np.r_[y]
else:
y = np.r_[self.y, y]
if self.y is not None:
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
x = x.to(torch.float64)
y = y.to(torch.float64)
n = y.shape[0]

return torch.Tensor(x), torch.Tensor(y), n
return x, y, n

# TODO: allow user to pass in generator options
@ensure_model_is_fresh
Expand Down Expand Up @@ -306,7 +309,21 @@ def n_trials(self):
)
return self.min_asks

def add_data(self, x, y):
def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]):
"""
Adds new data points to the strategy, and normalizes the inputs.
Args:
x (torch.Tensor, np.ndarray): The input data points. Can be a PyTorch tensor or NumPy array.
y (torch.Tensor, np.ndarray): The output data points. Can be a PyTorch tensor or NumPy array.
"""
# Necessary as sometimes the data is passed in as numpy arrays or torch tensors.
if not isinstance(y, torch.Tensor):
y = torch.tensor(y, dtype=torch.float64)
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=torch.float64)

self.x, self.y, self.n = self.normalize_inputs(x, y)
self._model_is_fresh = False

Expand Down

0 comments on commit 09b1d59

Please sign in to comment.