From e1ba188c6833f199b81a34db880ff537fcb79242 Mon Sep 17 00:00:00 2001
From: TensorFlow Recommenders Authors
Date: Thu, 24 Oct 2024 09:22:42 -0700
Subject: [PATCH] Add support for multipoint query retrieval to TFRS Retrieval
 task.

PiperOrigin-RevId: 689408268
---
 tensorflow_recommenders/tasks/retrieval.py | 18 +++++--
 .../tasks/retrieval_test.py                | 48 +++++++++++++++++++
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/tensorflow_recommenders/tasks/retrieval.py b/tensorflow_recommenders/tasks/retrieval.py
index 3347ca6..803eecc 100644
--- a/tensorflow_recommenders/tasks/retrieval.py
+++ b/tensorflow_recommenders/tasks/retrieval.py
@@ -141,7 +141,10 @@ def call(
 
     Args:
       query_embeddings: [num_queries, embedding_dim] tensor of query
-        representations.
+        representations, or [num_queries, num_heads, embedding_dim]. If latter,
+        we do "maxsim" scoring over those multiple query heads. This applies to
+        the loss computation and batch metrics. Factorized metrics won't be
+        computed in this case.
       candidate_embeddings: [num_candidates, embedding_dim] tensor of candidate
         representations. Normally, `num_candidates` is the same as
         `num_queries`: there is a positive candidate corresponding for every
@@ -166,8 +169,15 @@ def call(
       loss: Tensor of loss values.
     """
 
-    scores = tf.linalg.matmul(
-        query_embeddings, candidate_embeddings, transpose_b=True)
+    if len(tf.shape(query_embeddings)) == 3:
+      scores = tf.einsum(
+          "qne,ce->qnc", query_embeddings, candidate_embeddings
+      )
+      scores = tf.math.reduce_max(scores, axis=1)
+    else:
+      scores = tf.linalg.matmul(
+          query_embeddings, candidate_embeddings, transpose_b=True
+      )
 
     num_queries = tf.shape(scores)[0]
     num_candidates = tf.shape(scores)[1]
@@ -203,7 +213,7 @@ def call(
     for metric in self._loss_metrics:
       update_ops.append(metric.update_state(loss))
 
-    if compute_metrics:
+    if compute_metrics and len(tf.shape(query_embeddings)) == 2:
       for metric in self._factorized_metrics:
         update_ops.append(
             metric.update_state(
diff --git a/tensorflow_recommenders/tasks/retrieval_test.py b/tensorflow_recommenders/tasks/retrieval_test.py
index e8a9a04..1af7cd9 100644
--- a/tensorflow_recommenders/tasks/retrieval_test.py
+++ b/tensorflow_recommenders/tasks/retrieval_test.py
@@ -250,5 +250,53 @@ def test_task(self):
     self.assertAllClose(expected_metrics2, metrics2_)
 
 
+class RetrievalTestWithMultipointQueries(tf.test.TestCase):
+
+  def test_task(self):
+
+    query = tf.constant(
+        [[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=tf.float32
+    )
+    candidate = tf.constant([[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=tf.float32)
+    candidate_dataset = tf.data.Dataset.from_tensor_slices(
+        np.array([[0, 0, 0]] * 20, dtype=np.float32)
+    )
+
+    task = retrieval.Retrieval(
+        metrics=metrics.FactorizedTopK(
+            candidates=candidate_dataset.batch(16), ks=[5]
+        ),
+        batch_metrics=[
+            tf.keras.metrics.TopKCategoricalAccuracy(
+                k=1, name="batch_categorical_accuracy_at_1"
+            )
+        ],
+    )
+
+    # Scores will have shape [num_queries, num_candidates]
+    # All_pair_scores: [[[2,2], [3,5], [5,3]], [[3, 3], [7,5], [5,7]]].
+    # Max-sim scores: [[2, 5, 5], [3, 7, 7]].
+    # Normalized logits: [[0, 3, 3], [1, 5, 5]].
+    expected_loss = -np.log(1 / (1 + np.exp(3) + np.exp(3))) - np.log(
+        np.exp(5) / (np.exp(1) + np.exp(5) + np.exp(5))
+    )
+
+    expected_metrics = {
+        "factorized_top_k/top_5_categorical_accuracy": (
+            0.0
+        ),  # not computed for multipoint queries
+        "batch_categorical_accuracy_at_1": 0.5,
+    }
+    loss = task(
+        query_embeddings=query,
+        candidate_embeddings=candidate,
+    )
+    metrics_ = {metric.name: metric.result().numpy() for metric in task.metrics}
+
+    self.assertIsNotNone(loss)
+    self.assertAllClose(expected_loss, loss)
+    self.assertAllClose(expected_metrics, metrics_)
+
+
 if __name__ == "__main__":
   tf.test.main()
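
Note (not part of the patch): a minimal standalone sketch of the "maxsim" scoring
the new 3-D branch performs, using the same toy tensors as the test above. The
final loss line assumes the Retrieval task's default categorical cross-entropy
with SUM reduction over queries; under that assumption it reproduces the test's
expected_loss (about 4.42).

    import tensorflow as tf

    # [num_queries=2, num_heads=2, embedding_dim=3] multipoint query embeddings.
    query = tf.constant(
        [[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=tf.float32)
    # [num_candidates=3, embedding_dim=3] candidate embeddings.
    candidate = tf.constant([[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=tf.float32)

    # All-pair scores over (query, head, candidate), as in the patched 3-D branch.
    all_pair_scores = tf.einsum("qne,ce->qnc", query, candidate)
    # Max-sim: keep the best-scoring head per (query, candidate) pair.
    scores = tf.math.reduce_max(all_pair_scores, axis=1)  # [[2, 5, 5], [3, 7, 7]]

    # In-batch softmax loss with candidate i as the positive for query i,
    # summed over queries (assumed SUM reduction, matching expected_loss above).
    labels = tf.eye(2, 3)
    loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=scores))
    print(scores.numpy(), loss.numpy())  # loss ~= 4.42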