From c7dc6beaf401fbdcf4175a3fe4c271b81cc53a03 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 9 May 2024 12:27:55 +0800 Subject: [PATCH] feat: add BaseLoss and BaseMetric; --- pypots/nn/modules/loss.py | 28 ++++++++++++++++++++++++++++ pypots/nn/modules/metric.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 pypots/nn/modules/loss.py create mode 100644 pypots/nn/modules/metric.py diff --git a/pypots/nn/modules/loss.py b/pypots/nn/modules/loss.py new file mode 100644 index 00000000..0868d2ca --- /dev/null +++ b/pypots/nn/modules/loss.py @@ -0,0 +1,28 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .metric import BaseMetric +from ..functional import calc_mse + + +class BaseLoss(BaseMetric): + def __init__( + self, + ): + super().__init__() + + def forward(self, prediction, target): + raise NotImplementedError + + +class MAE_Loss(BaseLoss): + def __init__(self): + super().__init__() + + def forward(self, prediction, target, mask=None): + return calc_mse(prediction, target, mask) diff --git a/pypots/nn/modules/metric.py b/pypots/nn/modules/metric.py new file mode 100644 index 00000000..faea3d78 --- /dev/null +++ b/pypots/nn/modules/metric.py @@ -0,0 +1,29 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import torch.nn as nn + +from ..functional import calc_pr_auc + + +class BaseMetric(nn.Module): + def __init__(self, lower_better: bool = True): + super().__init__() + self.lower_better = lower_better + + def forward(self, prediction, target): + raise NotImplementedError + + +class PR_AUC(BaseMetric): + def __init__(self): + super().__init__(lower_better=False) + + def forward(self, prediction, target): + pr_auc, _, _, _ = calc_pr_auc(prediction, target) + return pr_auc