From 01cfbb594cda86604e78fb33f2ee671608d66bd0 Mon Sep 17 00:00:00 2001 From: Yousif Alsaffar Date: Tue, 15 Oct 2024 02:50:12 +0300 Subject: [PATCH] adding type hints and adjusting normalize_inputs conditions --- aepsych/strategy.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index dc9484961..401fdc0d4 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -170,7 +170,7 @@ def __init__( self.name = name - def normalize_inputs(self, x, y): + def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor): """converts inputs into normalized format for this strategy Args: @@ -193,14 +193,10 @@ def normalize_inputs(self, x, y): if x.shape == self.event_shape: x = x[None, :] - if self.x is None: - x = x - else: + if self.x is not None: x = torch.cat((self.x, x), dim=0) - if self.y is None: - y = y - else: + if self.y is not None: y = torch.cat((self.y, y), dim=0) # Ensure the correct dtype @@ -313,7 +309,7 @@ def n_trials(self): ) return self.min_asks - def add_data(self, x, y): + def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): """ Adds new data points to the strategy, and normalizes the inputs.