diff --git a/tensorflow_ranking/python/metrics.py b/tensorflow_ranking/python/metrics.py
index ae6f658..4a83244 100644
--- a/tensorflow_ranking/python/metrics.py
+++ b/tensorflow_ranking/python/metrics.py
@@ -260,10 +260,13 @@ def compute(self, labels, predictions, weights):
     # Relevance = 1.0 when labels >= 1.0 to accommodate graded relevance.
     relevance = tf.cast(tf.greater_equal(sorted_labels, 1.0), dtype=tf.float32)
     reciprocal_rank = 1.0 / tf.cast(tf.range(1, topn + 1), dtype=tf.float32)
-    # MRR has a shape of [batch_size, 1]
+    # MRR has a shape of [batch_size, 1].
     mrr = tf.reduce_max(
         input_tensor=relevance * reciprocal_rank, axis=1, keepdims=True)
-    return tf.compat.v1.metrics.mean(mrr * tf.ones_like(weights), weights)
+    per_list_weights = _per_example_weights_to_per_list_weights(
+        weights=weights,
+        relevance=tf.cast(tf.greater_equal(labels, 1.0), dtype=tf.float32))
+    return tf.compat.v1.metrics.mean(mrr, per_list_weights)
 
 
 def mean_reciprocal_rank(labels, predictions, weights=None, name=None):
diff --git a/tensorflow_ranking/python/metrics_test.py b/tensorflow_ranking/python/metrics_test.py
index 6594917..9f3c2b2 100644
--- a/tensorflow_ranking/python/metrics_test.py
+++ b/tensorflow_ranking/python/metrics_test.py
@@ -108,7 +108,8 @@ def test_mean_reciprocal_rank(self):
     self._check_metrics([
         (m([labels[0]], [scores[0]]), 0.5),
         (m(labels, scores), (0.5 + 1.0) / 2),
-        (m(labels, scores, weights), (6. * 0.5 + 15. * 1.) / (6. + 15.)),
+        (m(labels, scores,
+           weights), (3. * 0.5 + (6. + 5.) / 2. * 1.) / (3. + (6. + 5) / 2.)),
     ])
 
   def test_make_mean_reciprocal_rank_fn(self):
@@ -125,7 +126,8 @@ def test_make_mean_reciprocal_rank_fn(self):
     self._check_metrics([
         (m([labels[0]], [scores[0]], features), 0.5),
         (m(labels, scores, features), (0.5 + 1.0) / 2),
-        (m_w(labels, scores, features), (6. * 0.5 + 15. * 1.) / (6. + 15.)),
+        (m_w(labels, scores, features),
+         (3. * 0.5 + (6. + 5.) / 2. * 1.) / (3. + (6. + 5.) / 2.)),
     ])
 
   def test_average_relevance_position(self):