Skip to content

Commit

Permalink
Add support for multipoint query retrieval to TFRS Retrieval task.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689532029
  • Loading branch information
TensorFlow Recommenders Authors committed Oct 24, 2024
1 parent 685694e commit 36a1836
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
18 changes: 14 additions & 4 deletions tensorflow_recommenders/tasks/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def call(
Args:
query_embeddings: [num_queries, embedding_dim] tensor of query
representations.
representations, or [num_queries, num_heads, embedding_dim]. If latter,
we do "maxsim" scoring over those multiple query heads. This applies to
the loss computation and batch metrics. Factorized metrics won't be
computed in this case.
candidate_embeddings: [num_candidates, embedding_dim] tensor of candidate
representations. Normally, `num_candidates` is the same as
`num_queries`: there is a positive candidate corresponding for every
Expand All @@ -166,8 +169,15 @@ def call(
loss: Tensor of loss values.
"""

scores = tf.linalg.matmul(
query_embeddings, candidate_embeddings, transpose_b=True)
if len(tf.shape(query_embeddings)) == 3:
scores = tf.einsum(
"qne,ce->qnc", query_embeddings, candidate_embeddings
)
scores = tf.math.reduce_max(scores, axis=1)
else:
scores = tf.linalg.matmul(
query_embeddings, candidate_embeddings, transpose_b=True
)

num_queries = tf.shape(scores)[0]
num_candidates = tf.shape(scores)[1]
Expand Down Expand Up @@ -203,7 +213,7 @@ def call(
for metric in self._loss_metrics:
update_ops.append(metric.update_state(loss))

if compute_metrics:
if compute_metrics and len(tf.shape(query_embeddings)) == 2:
for metric in self._factorized_metrics:
update_ops.append(
metric.update_state(
Expand Down
48 changes: 48 additions & 0 deletions tensorflow_recommenders/tasks/retrieval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,53 @@ def test_task(self):
self.assertAllClose(expected_metrics2, metrics2_)


class RetrievalTestWithMultipointQueries(tf.test.TestCase):

def test_task(self):

query = tf.constant(
[[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=tf.float32
)
candidate = tf.constant([[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=tf.float32)
candidate_dataset = tf.data.Dataset.from_tensor_slices(
np.array([[0, 0, 0]] * 20, dtype=np.float32)
)

task = retrieval.Retrieval(
metrics=metrics.FactorizedTopK(
candidates=candidate_dataset.batch(16), ks=[5]
),
batch_metrics=[
tf.keras.metrics.TopKCategoricalAccuracy(
k=1, name="batch_categorical_accuracy_at_1"
)
],
)

# Scores will have shape [num_queries, num_candidates]
# All_pair_scores: [[[2,2], [3,5], [5,3]], [[3, 3], [7,5], [5,7]]].
# Max-sim scores: [[2, 5, 5], [3, 7, 7]].
# Normalized logits: [[0, 3, 3], [1, 5, 5]].
expected_loss = -np.log(1 / (1 + np.exp(3) + np.exp(3))) - np.log(
np.exp(5) / (np.exp(1) + np.exp(5) + np.exp(5))
)

expected_metrics = {
"factorized_top_k/top_5_categorical_accuracy": (
0.0
), # not computed for multipoint queries
"batch_categorical_accuracy_at_1": 0.5,
}
loss = task(
query_embeddings=query,
candidate_embeddings=candidate,
)
metrics_ = {metric.name: metric.result().numpy() for metric in task.metrics}

self.assertIsNotNone(loss)
self.assertAllClose(expected_loss, loss)
self.assertAllClose(expected_metrics, metrics_)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 36a1836

Please sign in to comment.