-
Notifications
You must be signed in to change notification settings - Fork 0
/
lambdaloss.py
78 lines (65 loc) · 3.83 KB
/
lambdaloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
def lambdaLoss(y_pred, y_true, eps=1e-10, padded_value_indicator=-1, weighing_scheme=None, k=None, sigma=1., mu=10.,
               reduction="sum", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.

    NOTE(review): both inputs are unsqueezed to a leading batch dimension of 1 below and the
    subsequent gather indexes dim 1, which is only consistent with 1-D inputs of shape
    [slate_length]. The original docstring claimed [batch_size, slate_length] — confirm
    against callers before relying on batched input.

    :param y_pred: predictions from the model, shape [slate_length] (see note above)
    :param y_true: ground truth labels, shape [slate_length] (see note above)
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weighing_scheme: a string corresponding to a name of one of the weighing schemes
        (looked up via globals(); None means uniform weights of 1.0, i.e. plain RankNet)
    :param k: rank at which the loss is truncated (None means no truncation)
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, could be either a sum or a mean
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
    :return: loss value, a torch.Tensor
    :raises ValueError: if reduction or reduction_log is not one of the supported values
    """
    device = y_pred.device
    # Clone so the padding writes below do not mutate the caller's tensors.
    y_pred = y_pred.clone().unsqueeze(0)
    y_true = y_true.clone().unsqueeze(0)
    padded_mask = y_true == padded_value_indicator
    # Push padded items to the end of any descending sort.
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")
    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    # Any pair involving a padded (-inf) label yields an inf/NaN diff -> excluded here.
    padded_pairs_mask = torch.isfinite(true_diffs)
    if weighing_scheme != "ndcgLoss1_scheme":
        # Keep only pairs where item i is strictly more relevant than item j.
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)
    # Truncate the loss at rank k; with k=None the [:None] slices select everything.
    ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
    ndcg_at_k_mask[:k, :k] = 1
    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)
    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1, device=device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    # clamp(min=eps) guards against division by zero for all-zero-relevance slates.
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]
    # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore
    # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
    # BUGFIX: masked_fill is out-of-place and its result was discarded, so the NaNs produced
    # by (-inf) - (-inf) on padded positions survived. They are masked out of the forward
    # value, but backward still propagates 0 * NaN = NaN through sigmoid into the gradients
    # of any padded slate. The in-place masked_fill_ actually zeroes them.
    scores_diffs.masked_fill_(torch.isnan(scores_diffs), 0.)
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")
    if reduction == "sum":
        loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
    elif reduction == "mean":
        loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
    else:
        raise ValueError("Reduction method can be either sum or mean")
    return loss