From 2daa274aa21c1f901f4337dae2e99b2717805c43 Mon Sep 17 00:00:00 2001 From: carusyte Date: Sun, 2 Feb 2025 21:40:27 +0800 Subject: [PATCH] add bfloat16 support in _base_multivariate --- neuralforecast/common/_base_multivariate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/neuralforecast/common/_base_multivariate.py b/neuralforecast/common/_base_multivariate.py index a1f8a51ac..3ac6fe6a0 100644 --- a/neuralforecast/common/_base_multivariate.py +++ b/neuralforecast/common/_base_multivariate.py @@ -5,6 +5,7 @@ # %% ../../nbs/common.base_multivariate.ipynb 5 import numpy as np +import ml_dtypes import torch import torch.nn as nn import pytorch_lightning as pl @@ -595,7 +596,11 @@ def predict( trainer = pl.Trainer(**pred_trainer_kwargs) fcsts = trainer.predict(self, datamodule=datamodule) - fcsts = torch.vstack(fcsts).numpy() + fcsts = torch.vstack(fcsts) + if fcsts.dtype == torch.bfloat16: + fcsts = fcsts.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + else: + fcsts = fcsts.numpy() fcsts = np.transpose(fcsts, (2, 0, 1)) fcsts = fcsts.flatten()