Skip to content

Commit

Permalink
Add Regression AUC (RAUC) metrics
Browse files Browse the repository at this point in the history
Summary:
Implement regression AUC metrics. Regression AUC is an extension of classification AUC. See Section 4.1.1 in https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf for related discussions.

On a high level, regression AUC is an extension of the traditional AUC for classification through the probabilistic interpretation: the area under the curve is equal to the probability that a classifier will rank a randomly chosen positive instance higher than a randomly chosen negative one.

We utilize merge sort to optimize time complexity to O(nlog(n)).

Reviewed By: zainhuda

Differential Revision: D53377225

fbshipit-source-id: 7dfccbddf9f17a6881c6fd00f9614466c603514d
  • Loading branch information
Z Zhou authored and facebook-github-bot committed Feb 28, 2024
1 parent e719551 commit f1fb67a
Show file tree
Hide file tree
Showing 6 changed files with 896 additions and 3 deletions.
2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from torchrec.metrics.multiclass_recall import MulticlassRecallMetric
from torchrec.metrics.ndcg import NDCGMetric
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.rauc import RAUCMetric
from torchrec.metrics.rec_metric import RecMetric, RecMetricList
from torchrec.metrics.recall_session import RecallSessionMetric
from torchrec.metrics.scalar import ScalarMetric
Expand All @@ -58,6 +59,7 @@
RecMetricEnum.CALIBRATION: CalibrationMetric,
RecMetricEnum.AUC: AUCMetric,
RecMetricEnum.AUPRC: AUPRCMetric,
RecMetricEnum.RAUC: RAUCMetric,
RecMetricEnum.MSE: MSEMetric,
RecMetricEnum.MAE: MAEMetric,
RecMetricEnum.MULTICLASS_RECALL: MulticlassRecallMetric,
Expand Down
1 change: 1 addition & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RecMetricEnum(RecMetricEnumBase):
CTR = "ctr"
AUC = "auc"
AUPRC = "auprc"
RAUC = "rauc"
CALIBRATION = "calibration"
MSE = "mse"
MAE = "mae"
Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ class MetricName(MetricNameBase):
RMSE = "rmse"
AUC = "auc"
AUPRC = "auprc"
RAUC = "rauc"
GROUPED_AUC = "grouped_auc"
GROUPED_AUPRC = "grouped_auprc"
GROUPED_RAUC = "grouped_rauc"
RECALL_SESSION_LEVEL = "recall_session_level"
MULTICLASS_RECALL = "multiclass_recall"
WEIGHTED_AVG = "weighted_avg"
Expand All @@ -76,6 +78,7 @@ class MetricNamespace(MetricNamespaceBase):
MSE = "mse"
AUC = "auc"
AUPRC = "auprc"
RAUC = "rauc"
MAE = "mae"
ACCURACY = "accuracy"

Expand Down
Loading

0 comments on commit f1fb67a

Please sign in to comment.