
Commit 2c02b4b
Merge pull request #35 from jakegrigsby/predict_method
Predict method
jakegrigsby authored Feb 16, 2022
2 parents 9d0b128 + 5bb1f01 commit 2c02b4b
Showing 3 changed files with 40 additions and 14 deletions.
14 changes: 7 additions & 7 deletions spacetimeformer/data/csv_dataset.py
@@ -78,9 +78,9 @@ def mask_intervals(mask, intervals, cond):

         self._scaler = self._scaler.fit(self._train_data[target_cols].values)
 
-        self._train_data = self.apply_scaling(df[train_mask])
-        self._val_data = self.apply_scaling(df[val_mask])
-        self._test_data = self.apply_scaling(df[test_mask])
+        self._train_data = self.apply_scaling_df(df[train_mask])
+        self._val_data = self.apply_scaling_df(df[val_mask])
+        self._test_data = self.apply_scaling_df(df[test_mask])
 
     def get_slice(self, split, start, stop, skip):
         assert split in ["train", "val", "test"]
@@ -91,25 +91,25 @@ def get_slice(self, split, start, stop, skip):
         else:
             return self.test_data.iloc[start:stop:skip]
 
-    def apply_scaling(self, df):
+    def apply_scaling_df(self, df):
         scaled = df.copy(deep=True)
         # scaled[self.target_cols] = self._scaler.transform(df[self.target_cols].values)
         scaled[self.target_cols] = (
             df[self.target_cols].values - self._scaler.mean_
         ) / self._scaler.scale_
         return scaled
 
+    def apply_scaling(self, array):
+        return (array - self._scaler.mean_) / self._scaler.scale_
+
     def reverse_scaling_df(self, df):
         scaled = df.copy(deep=True)
         # scaled[self.target_cols] = self._scaler.inverse_transform(df[self.target_cols].values)
         scaled[self.target_cols] = (
             df[self.target_cols] * self._scaler.scale_
         ) + self._scaler.mean_
         return scaled
 
     def reverse_scaling(self, array):
         return (array * self._scaler.scale_) + self._scaler.mean_
         # return self._scaler.inverse_transform(array)
 
     @property
     def train_data(self):
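The rename keeps the DataFrame-only path in apply_scaling_df and frees the apply_scaling name for a raw-array version the Forecaster can call on numpy data. A minimal sketch of the round trip, assuming a fitted sklearn StandardScaler and hypothetical target columns "y1"/"y2" standing in for target_cols:

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# hypothetical toy frame; "y1"/"y2" stand in for the dataset's target_cols
df = pd.DataFrame({"y1": [1.0, 2.0, 3.0], "y2": [10.0, 20.0, 30.0]})
scaler = StandardScaler().fit(df[["y1", "y2"]].values)

# array path (the new apply_scaling): standardize with the fitted statistics
scaled = (df[["y1", "y2"]].values - scaler.mean_) / scaler.scale_

# reverse_scaling undoes it exactly
restored = (scaled * scaler.scale_) + scaler.mean_
assert np.allclose(restored, df[["y1", "y2"]].values)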
29 changes: 26 additions & 3 deletions spacetimeformer/forecaster.py
@@ -20,6 +20,7 @@ def __init__(
     ):
         super().__init__()
         self._inv_scaler = lambda x: x
+        self._scaler = lambda x: x
         self.l2_coeff = l2_coeff
         self.learning_rate = learning_rate
         self.time_masked_idx = None
@@ -36,6 +36,9 @@ def set_null_value(self, val: float) -> None:
     def set_inv_scaler(self, scaler) -> None:
         self._inv_scaler = scaler
 
+    def set_scaler(self, scaler) -> None:
+        self._scaler = scaler
+
     @property
     @abstractmethod
     def train_step_forward_kwargs(self):
@@ -101,20 +105,39 @@ def predict(
         y_c: torch.Tensor,
         x_t: torch.Tensor,
         sample_preds: bool = False,
-    ) -> np.ndarray:
-        y_t = torch.zeros((x_t.shape[0], x_t.shape[1], y_c.shape[2])).to(x_t.device)
+    ) -> torch.Tensor:
+        og_device = y_c.device
+        # move to model device
+        x_c = x_c.to(self.device).float()
+        x_t = x_t.to(self.device).float()
+        # move y_c to cpu if it isn't already there, scale, and then move back to the model device
+        y_c = torch.from_numpy(self._scaler(y_c.cpu().numpy())).to(self.device).float()
+        # create dummy y_t of zeros
+        y_t = (
+            torch.zeros((x_t.shape[0], x_t.shape[1], y_c.shape[2]))
+            .to(self.device)
+            .float()
+        )
 
         with torch.no_grad():
+            # gradient-free prediction
             normalized_preds, *_ = self.forward(
                 x_c, y_c, x_t, y_t, **self.eval_step_forward_kwargs
             )
 
+        # handle case that the output is a distribution (spacetimeformer)
         if isinstance(normalized_preds, Normal):
             if sample_preds:
                 normalized_preds = normalized_preds.sample()
             else:
                 normalized_preds = normalized_preds.mean
 
-        preds = self._inv_scaler(normalized_preds.cpu().numpy())
+        # preds --> cpu --> inverse scale to original units --> original device of y_c
+        preds = (
+            torch.from_numpy(self._inv_scaler(normalized_preds.cpu().numpy()))
+            .to(og_device)
+            .float()
+        )
         return preds
 
     @abstractmethod
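With a forward scaler attached, predict now accepts unscaled context targets on any device, normalizes them internally, and returns a tensor in original units on y_c's original device. A usage sketch, assuming a trained Forecaster instance; the shapes and variable names below are illustrative, not part of the PR:

import torch

# hypothetical shapes: batch 4, context length 32, horizon 8,
# 6 time features, 2 target variables
x_c = torch.randn(4, 32, 6)  # context time features
y_c = torch.randn(4, 32, 2)  # context targets in original (unscaled) units
x_t = torch.randn(4, 8, 6)   # target time features

# predict scales y_c, builds a dummy y_t of zeros, runs forward without
# gradients, and inverse-scales the result back onto y_c's device
preds = forecaster.predict(x_c, y_c, x_t, sample_preds=False)
print(preds.shape)  # expected: torch.Size([4, 8, 2])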
11 changes: 7 additions & 4 deletions spacetimeformer/train.py
@@ -224,6 +224,7 @@ def create_model(config):

 def create_dset(config):
     INV_SCALER = lambda x: x
+    SCALER = lambda x: x
     NULL_VAL = None
 
     if config.dset == "metr-la" or config.dset == "pems-bay":
@@ -239,6 +240,7 @@ def create_dset(config):
             workers=config.workers,
         )
         INV_SCALER = data.inverse_scale
+        SCALER = data.scale
         NULL_VAL = 0.0
 
     elif config.dset == "precip":
@@ -302,9 +304,10 @@ def create_dset(config):
             workers=config.workers,
         )
         INV_SCALER = dset.reverse_scaling
+        SCALER = dset.apply_scaling
         NULL_VAL = None
 
-    return DATA_MODULE, INV_SCALER, NULL_VAL
+    return DATA_MODULE, INV_SCALER, SCALER, NULL_VAL
 
 
 def create_callbacks(config):
@@ -379,14 +382,14 @@ def main(args):
     )
     logger.log_hyperparams(config)
 
-
     # Dset
-    data_module, inv_scaler, null_val = create_dset(args)
+    data_module, inv_scaler, scaler, null_val = create_dset(args)
 
     # Model
     args.null_value = null_val
     forecaster = create_model(args)
     forecaster.set_inv_scaler(inv_scaler)
+    forecaster.set_scaler(scaler)
     forecaster.set_null_value(null_val)
 
     # Callbacks
@@ -419,7 +422,7 @@ def main(args):
         gradient_clip_val=args.grad_clip_norm,
         gradient_clip_algorithm="norm",
         overfit_batches=20 if args.debug else 0,
-        #track_grad_norm=2,
+        # track_grad_norm=2,
         accumulate_grad_batches=args.accumulate,
         sync_batchnorm=True,
         val_check_interval=0.25 if args.dset == "asos" else 1.0,
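Since SCALER, like INV_SCALER, defaults to an identity lambda, datasets without scaling need no changes, while scaled datasets override both directions so the round trip is a no-op. A small sketch of that contract (the override lines are commented out because they assume a dataset object like the one above):

import numpy as np

# defaults: both directions are identity until a dataset overrides them
SCALER = lambda x: x
INV_SCALER = lambda x: x

# a standardizing dataset would override both, e.g.:
# SCALER = dset.apply_scaling        # x -> (x - mean) / scale
# INV_SCALER = dset.reverse_scaling  # x -> x * scale + mean

x = np.array([1.0, 2.0, 3.0])
assert np.allclose(INV_SCALER(SCALER(x)), x)  # round trip recovers the input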
