Skip to content

Commit

Permalink
Improve hparams saving and loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Beitner committed Sep 25, 2021
1 parent 5706cf5 commit 95d8491
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def __init__(
if not hasattr(self, "optimizer"): # callables are removed from hyperparameters, so better to save them
self.optimizer = self.hparams.optimizer

# delete everything from hparams that cannot be serialized with yaml dump
# delete everything from hparams that cannot be serialized with yaml.dump
# which is particularly important for tensorboard logging
hparams_to_delete = []
for k, v in self.hparams.items():
try:
Expand Down Expand Up @@ -960,6 +961,8 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["hparams_name"] = "kwargs"
# save specials
checkpoint[self.CHECKPOINT_HYPER_PARAMS_SPECIAL_KEY] = {k: getattr(self, k) for k in self.hparams_special}
# add special hparams them back to save the hparams correctly for checkpoint
checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY].update(checkpoint[self.CHECKPOINT_HYPER_PARAMS_SPECIAL_KEY])

@property
def target_names(self) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/models/nbeats/sub_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def linspace(backcast_length: int, forecast_length: int, centered: bool = False)
norm = backcast_length + forecast_length
start = 0
stop = backcast_length + forecast_length - 1
lin_space = np.linspace(start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float6432)
lin_space = np.linspace(start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32)
b_ls = lin_space[:backcast_length]
f_ls = lin_space[backcast_length:]
return b_ls, f_ls
Expand Down

0 comments on commit 95d8491

Please sign in to comment.