Skip to content

Commit

Permalink
feat: add BaseLoss and BaseMetric;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed May 9, 2024
1 parent 620a95c commit c7dc6be
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
28 changes: 28 additions & 0 deletions pypots/nn/modules/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .metric import BaseMetric
from ..functional import calc_mse


class BaseLoss(BaseMetric):
def __init__(
self,
):
super().__init__()

def forward(self, prediction, target):
raise NotImplementedError


class MAE_Loss(BaseLoss):
def __init__(self):
super().__init__()

def forward(self, prediction, target, mask=None):
return calc_mse(prediction, target, mask)
29 changes: 29 additions & 0 deletions pypots/nn/modules/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


import torch.nn as nn

from ..functional import calc_pr_auc


class BaseMetric(nn.Module):
def __init__(self, lower_better: bool = True):
super().__init__()
self.lower_better = lower_better

def forward(self, prediction, target):
raise NotImplementedError


class PR_AUC(BaseMetric):
def __init__(self):
super().__init__(lower_better=False)

def forward(self, prediction, target):
pr_auc, _, _, _ = calc_pr_auc(prediction, target)
return pr_auc

0 comments on commit c7dc6be

Please sign in to comment.