Commit

fix: torch_forecasting_model load_weights with float16 and float32 (#…
Hsinfu authored Nov 6, 2023
1 parent 772d705 commit af5b141
Showing 3 changed files with 32 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 **Fixed**
 - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug when using encoders with `RegressionModel` and series with a non-evenly spaced frequency (e.g. Month Begin). This raised an error during lagged data creation when trying to divide a pd.Timedelta by the ambiguous frequency. [#2034](https://github.com/unit8co/darts/pull/2034) by [Antoine Madrona](https://github.com/madtoinou).
+- Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu).

 ### For developers of the library:

9 changes: 8 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
@@ -107,6 +107,12 @@
 RUNS_FOLDER = "runs"
 INIT_MODEL_NAME = "_model.pth.tar"

+TORCH_NP_DTYPES = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+}
+
 # pickling a TorchForecastingModel will not save below attributes: the keys specify the
 # attributes to be ignored, and the values are the default values getting assigned upon loading
 TFM_ATTRS_NO_PICKLE = {"model": None, "trainer": None}
@@ -1872,8 +1878,9 @@ def load_weights_from_checkpoint(
)

         # pl_forecasting module saves the train_sample shape, must recreate one
+        np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]]
         mock_train_sample = [
-            np.zeros(sample_shape) if sample_shape else None
+            np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None
             for sample_shape in ckpt["train_sample_shape"]
         ]
         self.train_sample = tuple(mock_train_sample)
23 changes: 23 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1045,6 +1045,29 @@ def test_load_weights(self, tmpdir_fn):
f"respectively {retrained_mape} and {original_mape}"
)

+    def test_load_weights_with_float32_dtype(self, tmpdir_fn):
+        ts_float32 = self.series.astype("float32")
+        model_name = "test_model"
+        ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
+        # barebone model
+        model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            n_epochs=1,
+        )
+        model.fit(ts_float32)
+        model.save(ckpt_path)
+        assert model.model._dtype == torch.float32  # type: ignore
+
+        # identical model
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+        )
+        loading_model.load_weights(ckpt_path)
+        loading_model.fit(ts_float32)
+        assert loading_model.model._dtype == torch.float32  # type: ignore
+
     def test_multi_steps_pipeline(self, tmpdir_fn):
         ts_training, ts_val = self.series.split_before(75)
         pretrain_model_name = "pre-train"
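
The change to `load_weights_from_checkpoint` hinges on the fact that `np.zeros` returns `float64` arrays unless a `dtype` is passed, so the mock training sample could disagree with a model saved at a lower precision. The sketch below illustrates that behaviour with NumPy only. `NAME_TO_NP_DTYPE`, `make_mock_train_sample` and the example shapes are hypothetical names introduced here; the commit itself keys its mapping by `torch` dtype objects and reads the dtype from `ckpt["model_dtype"]`.

```python
import numpy as np

# Stand-in for the commit's TORCH_NP_DTYPES. Keyed by dtype name here
# (an assumption) so the sketch runs without torch installed.
NAME_TO_NP_DTYPE = {
    "float16": np.float16,
    "float32": np.float32,
    "float64": np.float64,
}


def make_mock_train_sample(train_sample_shape, dtype_name):
    """Build placeholder arrays that match the saved shapes and precision."""
    np_dtype = NAME_TO_NP_DTYPE[dtype_name]
    return tuple(
        np.zeros(shape, dtype=np_dtype) if shape else None
        for shape in train_sample_shape
    )


# Hypothetical shapes; a None entry stands for a component that was absent.
shapes = [(4, 1), None, (1, 1)]

# Without an explicit dtype, NumPy falls back to float64.
default_sample = np.zeros(shapes[0])

# With the mapped dtype, the placeholder follows the requested precision.
mock_sample = make_mock_train_sample(shapes, "float32")

print(default_sample.dtype)
print(mock_sample[0].dtype)
```

Looking the dtype up in a small dictionary, rather than converting arbitrary dtypes, means an unsupported precision fails with a `KeyError` at load time instead of silently producing a mismatched sample.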