diff --git a/neuralforecast/common/_base_multivariate.py b/neuralforecast/common/_base_multivariate.py index a1f8a51ac..3ac6fe6a0 100644 --- a/neuralforecast/common/_base_multivariate.py +++ b/neuralforecast/common/_base_multivariate.py @@ -5,6 +5,7 @@ # %% ../../nbs/common.base_multivariate.ipynb 5 import numpy as np +import ml_dtypes import torch import torch.nn as nn import pytorch_lightning as pl @@ -595,7 +596,11 @@ def predict( trainer = pl.Trainer(**pred_trainer_kwargs) fcsts = trainer.predict(self, datamodule=datamodule) - fcsts = torch.vstack(fcsts).numpy() + fcsts = torch.vstack(fcsts) + if fcsts.dtype == torch.bfloat16: + fcsts = fcsts.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + else: + fcsts = fcsts.numpy() fcsts = np.transpose(fcsts, (2, 0, 1)) fcsts = fcsts.flatten()