From 8c8dd131a46883794fde3ac987f26488ad76ec98 Mon Sep 17 00:00:00 2001
From: TensorFlow Recommenders Authors
Date: Thu, 15 Aug 2024 16:15:04 -0700
Subject: [PATCH] Fix a bug in loss metrics calculation, and allow specifying
 `score_mask` in `tfrs.tasks.Retrieval.call`

PiperOrigin-RevId: 663490989
---
 tensorflow_recommenders/tasks/retrieval.py   | 43 ++++++++++------
 .../tasks/retrieval_test.py                  | 49 ++++++++++++++++---
 2 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/tensorflow_recommenders/tasks/retrieval.py b/tensorflow_recommenders/tasks/retrieval.py
index c4b7204..3347ca6 100644
--- a/tensorflow_recommenders/tasks/retrieval.py
+++ b/tensorflow_recommenders/tasks/retrieval.py
@@ -15,14 +15,16 @@
 # Lint-as: python3
 """A factorized retrieval task."""
 
-from typing import Optional, Sequence, Union, Text, List
+from typing import List, Optional, Sequence, Text, Union
 
+import numpy as np
 import tensorflow as tf
-
 from tensorflow_recommenders import layers
 from tensorflow_recommenders import metrics as tfrs_metrics
 from tensorflow_recommenders.tasks import base
 
+MIN_FLOAT = np.finfo(np.float32).min / 100.0
+
 
 class Retrieval(tf.keras.layers.Layer, base.Task):
   """A factorized retrieval task.
@@ -116,14 +118,17 @@ def factorized_metrics(self,
 
     self._factorized_metrics = value
 
-  def call(self,
-           query_embeddings: tf.Tensor,
-           candidate_embeddings: tf.Tensor,
-           sample_weight: Optional[tf.Tensor] = None,
-           candidate_sampling_probability: Optional[tf.Tensor] = None,
-           candidate_ids: Optional[tf.Tensor] = None,
-           compute_metrics: bool = True,
-           compute_batch_metrics: bool = True) -> tf.Tensor:
+  def call(
+      self,
+      query_embeddings: tf.Tensor,
+      candidate_embeddings: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+      candidate_sampling_probability: Optional[tf.Tensor] = None,
+      candidate_ids: Optional[tf.Tensor] = None,
+      compute_metrics: bool = True,
+      compute_batch_metrics: bool = True,
+      score_mask: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
     """Computes the task loss and metrics.
 
     The main argument are pairs of query and candidate embeddings: the first row
@@ -149,10 +154,14 @@ def call(self,
         reflect the sampling probability of negative candidates.
       candidate_ids: Optional tensor containing candidate ids. When given,
         factorized top-K evaluation will be id-based rather than score-based.
-      compute_metrics: Whether to compute metrics. Set this to False
-        during training for faster training.
-      compute_batch_metrics: Whether to compute batch level metrics.
-        In-batch loss_metrics will still be computed.
+      compute_metrics: Whether to compute metrics. Set this to False during
+        training for faster training.
+      compute_batch_metrics: Whether to compute batch level metrics. In-batch
+        loss_metrics will still be computed.
+      score_mask: [num_queries, num_candidates] boolean tensor indicating for
+        each query, which candidates should be considered for loss and
+        metrics computation (false means the candidate is not considered).
+
     Returns:
       loss: Tensor of loss values.
     """
@@ -180,6 +189,9 @@ def call(self,
 
       scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)
 
+    if score_mask is not None:
+      scores = tf.where(score_mask, scores, MIN_FLOAT)
+
     if self._num_hard_negatives is not None:
       scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(
           scores,
@@ -189,8 +201,7 @@ def call(self,
 
     update_ops = []
     for metric in self._loss_metrics:
-      update_ops.append(
-          metric.update_state(loss, sample_weight=sample_weight))
+      update_ops.append(metric.update_state(loss))
 
     if compute_metrics:
       for metric in self._factorized_metrics:
diff --git a/tensorflow_recommenders/tasks/retrieval_test.py b/tensorflow_recommenders/tasks/retrieval_test.py
index 0bb409a..e8a9a04 100644
--- a/tensorflow_recommenders/tasks/retrieval_test.py
+++ b/tensorflow_recommenders/tasks/retrieval_test.py
@@ -37,13 +37,20 @@ def test_task(self):
 
     task = retrieval.Retrieval(
         metrics=metrics.FactorizedTopK(
-            candidates=candidate_dataset.batch(16),
-            ks=[5]
+            candidates=candidate_dataset.batch(16), ks=[5]
         ),
         batch_metrics=[
             tf.keras.metrics.TopKCategoricalAccuracy(
-                k=1, name="batch_categorical_accuracy_at_1")
-        ])
+                k=1, name="batch_categorical_accuracy_at_1"
+            )
+        ],
+        loss_metrics=[
+            tf.keras.metrics.Mean(
+                name="batch_loss",
+                dtype=tf.float32,
+            )
+        ],
+    )
 
     # All_pair_scores: [[6, 3], [9, 5]].
     # Normalized logits: [[3, 0], [4, 0]].
@@ -52,6 +59,7 @@
     expected_metrics = {
         "factorized_top_k/top_5_categorical_accuracy": 1.0,
         "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": expected_loss,
     }
     loss = task(query_embeddings=query, candidate_embeddings=candidate)
     metrics_ = {
@@ -70,7 +78,8 @@
         compute_metrics=False)
     expected_metrics1 = {
         "factorized_top_k/top_5_categorical_accuracy": 0.0,
-        "batch_categorical_accuracy_at_1": 0.5
+        "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": loss,
     }
     metrics1_ = {
         metric.name: metric.result().numpy() for metric in task.metrics
@@ -89,7 +98,8 @@
         compute_batch_metrics=False)
     expected_metrics2 = {
         "factorized_top_k/top_5_categorical_accuracy": 1.0,
-        "batch_categorical_accuracy_at_1": 0.0
+        "batch_categorical_accuracy_at_1": 0.0,
+        "batch_loss": loss,
     }
     metrics2_ = {
         metric.name: metric.result().numpy() for metric in task.metrics
@@ -99,6 +109,33 @@
     self.assertAllClose(expected_loss, loss)
     self.assertAllClose(expected_metrics2, metrics2_)
 
+    # Test computation of metrics with sample_weight
+    for metric in task.metrics:
+      metric.reset_states()
+    loss = task(
+        query_embeddings=query,
+        candidate_embeddings=candidate,
+        sample_weight=tf.constant([0.7, 0.3], dtype=tf.float32),
+    )
+
+    # All_pair_scores: [[6, 3], [9, 5]].
+    # Normalized logits: [[3, 0], [4, 0]].
+    expected_loss3 = -0.7 * np.log(_sigmoid(3.0)) - 0.3 * np.log(
+        1 - _sigmoid(4.0)
+    )
+
+    expected_metrics3 = {
+        "factorized_top_k/top_5_categorical_accuracy": 1.0,
+        "batch_categorical_accuracy_at_1": 0.7,
+        "batch_loss": expected_loss3,
+    }
+    metrics3_ = {
+        metric.name: metric.result().numpy() for metric in task.metrics
+    }
+    self.assertIsNotNone(loss)
+    self.assertAllClose(expected_loss3, loss)
+    self.assertAllClose(expected_metrics3, metrics3_)
+
   def test_task_graph(self):
 
     with tf.Graph().as_default():