diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index 2fe7cce3b0..084a071dba 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -148,7 +148,7 @@ def loss_fn(self, pred, label): def metric_fn(self, pred, label): mask = torch.isfinite(label) - if self.metric in ("", "loss"): + if self.metric in ("", "loss","mse"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 008d789402..d934fb87ba 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -158,7 +158,7 @@ def loss_fn(self, pred, label, weight=None): def metric_fn(self, pred, label): mask = torch.isfinite(label) - if self.metric in ("", "loss"): + if self.metric in ("", "loss","mse"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric)