From af15d97a9289ab29f48e913f339711bdbe681a7a Mon Sep 17 00:00:00 2001 From: Lihanyiyi <99228366+Lihanyiyi@users.noreply.github.com> Date: Tue, 21 May 2024 12:40:45 +0800 Subject: [PATCH 1/2] Update pytorch_alstm_ts.py --- qlib/contrib/model/pytorch_alstm_ts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 008d789402..d934fb87ba 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -158,7 +158,7 @@ def loss_fn(self, pred, label, weight=None): def metric_fn(self, pred, label): mask = torch.isfinite(label) - if self.metric in ("", "loss"): + if self.metric in ("", "loss","mse"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) From 7d77d54cfd6b1deb3acb1a710036bc70fb5f2093 Mon Sep 17 00:00:00 2001 From: Lihanyiyi <99228366+Lihanyiyi@users.noreply.github.com> Date: Tue, 21 May 2024 12:42:21 +0800 Subject: [PATCH 2/2] Update pytorch_alstm.py --- qlib/contrib/model/pytorch_alstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index 2fe7cce3b0..084a071dba 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -148,7 +148,7 @@ def loss_fn(self, pred, label): def metric_fn(self, pred, label): mask = torch.isfinite(label) - if self.metric in ("", "loss"): + if self.metric in ("", "loss","mse"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric)