
Commit 8c8dd13

Fix a bug in loss metrics calculation, and allow specifying `score_mask` in `tfrs.tasks.Retrieval.call`

PiperOrigin-RevId: 663490989
TensorFlow Recommenders Authors committed Aug 16, 2024
1 parent 5e0629c · commit 8c8dd13
Showing 2 changed files with 70 additions and 22 deletions.
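As a quick illustration of the new API surface, here is a minimal, hypothetical sketch of passing `score_mask` to `tfrs.tasks.Retrieval.call`; the embeddings and mask values are made up, and the task is constructed with its defaults:

```python
import tensorflow as tf
import tensorflow_recommenders as tfrs

task = tfrs.tasks.Retrieval()

# In-batch convention: row i of the queries is a positive pair with
# row i of the candidates.
query_embeddings = tf.random.normal([2, 4])
candidate_embeddings = tf.random.normal([2, 4])

# New in this commit: a [num_queries, num_candidates] boolean mask.
# False excludes that candidate from the query's loss and metrics.
score_mask = tf.constant([[True, False],
                          [True, True]])

loss = task(
    query_embeddings=query_embeddings,
    candidate_embeddings=candidate_embeddings,
    score_mask=score_mask,
)
```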
tensorflow_recommenders/tasks/retrieval.py (27 additions, 16 deletions)
@@ -15,14 +15,16 @@
# Lint-as: python3
"""A factorized retrieval task."""

-from typing import Optional, Sequence, Union, Text, List
+from typing import List, Optional, Sequence, Text, Union

+import numpy as np
import tensorflow as tf

from tensorflow_recommenders import layers
from tensorflow_recommenders import metrics as tfrs_metrics
from tensorflow_recommenders.tasks import base

+MIN_FLOAT = np.finfo(np.float32).min / 100.0


class Retrieval(tf.keras.layers.Layer, base.Task):
  """A factorized retrieval task.
@@ -116,14 +118,17 @@ def factorized_metrics(self,

    self._factorized_metrics = value

-  def call(self,
-           query_embeddings: tf.Tensor,
-           candidate_embeddings: tf.Tensor,
-           sample_weight: Optional[tf.Tensor] = None,
-           candidate_sampling_probability: Optional[tf.Tensor] = None,
-           candidate_ids: Optional[tf.Tensor] = None,
-           compute_metrics: bool = True,
-           compute_batch_metrics: bool = True) -> tf.Tensor:
+  def call(
+      self,
+      query_embeddings: tf.Tensor,
+      candidate_embeddings: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+      candidate_sampling_probability: Optional[tf.Tensor] = None,
+      candidate_ids: Optional[tf.Tensor] = None,
+      compute_metrics: bool = True,
+      compute_batch_metrics: bool = True,
+      score_mask: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
"""Computes the task loss and metrics.
The main argument are pairs of query and candidate embeddings: the first row
@@ -149,10 +154,14 @@ def call(self,
        reflect the sampling probability of negative candidates.
      candidate_ids: Optional tensor containing candidate ids. When given,
        factorized top-K evaluation will be id-based rather than score-based.
-      compute_metrics: Whether to compute metrics. Set this to False
-        during training for faster training.
-      compute_batch_metrics: Whether to compute batch level metrics.
-        In-batch loss_metrics will still be computed.
+      compute_metrics: Whether to compute metrics. Set this to False during
+        training for faster training.
+      compute_batch_metrics: Whether to compute batch level metrics. In-batch
+        loss_metrics will still be computed.
+      score_mask: [num_queries, num_candidates] boolean tensor indicating for
+        each query, which candidates should be considered for loss and
+        metrics computation (false means the candidate is not considered).

    Returns:
      loss: Tensor of loss values.
    """
@@ -180,6 +189,9 @@ def call(self,
      )
      scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)

+    if score_mask is not None:
+      scores = tf.where(score_mask, scores, MIN_FLOAT)

    if self._num_hard_negatives is not None:
      scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(
          scores,
@@ -189,8 +201,7 @@

    update_ops = []
    for metric in self._loss_metrics:
-      update_ops.append(
-          metric.update_state(loss, sample_weight=sample_weight))
+      update_ops.append(metric.update_state(loss))

    if compute_metrics:
      for metric in self._factorized_metrics:
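The masking itself is just the `tf.where` above: a disallowed candidate keeps a large negative sentinel instead of its dot-product score, so it receives essentially zero probability under softmax and cannot surface in top-K. Using the finite `MIN_FLOAT` sentinel rather than `-inf` presumably keeps the downstream arithmetic well-defined even when a row is heavily masked. A standalone sketch of the behaviour (the scores come from the test below; the mask pattern is made up):

```python
import numpy as np
import tensorflow as tf

MIN_FLOAT = np.finfo(np.float32).min / 100.0

scores = tf.constant([[6.0, 3.0],
                      [9.0, 5.0]])
# Hypothetical mask: query 0 may not be scored against candidate 1.
score_mask = tf.constant([[True, False],
                          [True, True]])

masked = tf.where(score_mask, scores, MIN_FLOAT)
print(tf.nn.softmax(masked, axis=1).numpy())
# Row 0 puts essentially all probability on candidate 0;
# row 1 is the ordinary softmax over [9, 5].
```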
tensorflow_recommenders/tasks/retrieval_test.py (43 additions, 6 deletions)
@@ -37,13 +37,20 @@ def test_task(self):

    task = retrieval.Retrieval(
        metrics=metrics.FactorizedTopK(
-            candidates=candidate_dataset.batch(16),
-            ks=[5]
+            candidates=candidate_dataset.batch(16), ks=[5]
        ),
        batch_metrics=[
            tf.keras.metrics.TopKCategoricalAccuracy(
-                k=1, name="batch_categorical_accuracy_at_1")
-        ])
+                k=1, name="batch_categorical_accuracy_at_1"
+            )
+        ],
+        loss_metrics=[
+            tf.keras.metrics.Mean(
+                name="batch_loss",
+                dtype=tf.float32,
+            )
+        ],
+    )

    # All_pair_scores: [[6, 3], [9, 5]].
    # Normalized logits: [[3, 0], [4, 0]].
@@ -52,6 +59,7 @@
    expected_metrics = {
        "factorized_top_k/top_5_categorical_accuracy": 1.0,
        "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": expected_loss,
    }
    loss = task(query_embeddings=query, candidate_embeddings=candidate)
    metrics_ = {
@@ -70,7 +78,8 @@
        compute_metrics=False)
    expected_metrics1 = {
        "factorized_top_k/top_5_categorical_accuracy": 0.0,
-        "batch_categorical_accuracy_at_1": 0.5
+        "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": loss,
    }
    metrics1_ = {
        metric.name: metric.result().numpy() for metric in task.metrics
@@ -89,7 +98,8 @@
        compute_batch_metrics=False)
    expected_metrics2 = {
        "factorized_top_k/top_5_categorical_accuracy": 1.0,
-        "batch_categorical_accuracy_at_1": 0.0
+        "batch_categorical_accuracy_at_1": 0.0,
+        "batch_loss": loss,
    }
    metrics2_ = {
        metric.name: metric.result().numpy() for metric in task.metrics
@@ -99,6 +109,33 @@
    self.assertAllClose(expected_loss, loss)
    self.assertAllClose(expected_metrics2, metrics2_)

+    # Test computation of metrics with sample_weight
+    for metric in task.metrics:
+      metric.reset_states()
+    loss = task(
+        query_embeddings=query,
+        candidate_embeddings=candidate,
+        sample_weight=tf.constant([0.7, 0.3], dtype=tf.float32),
+    )
+
+    # All_pair_scores: [[6, 3], [9, 5]].
+    # Normalized logits: [[3, 0], [4, 0]].
+    expected_loss3 = -0.7 * np.log(_sigmoid(3.0)) - 0.3 * np.log(
+        1 - _sigmoid(4.0)
+    )
+
+    expected_metrics3 = {
+        "factorized_top_k/top_5_categorical_accuracy": 1.0,
+        "batch_categorical_accuracy_at_1": 0.7,
+        "batch_loss": expected_loss3,
+    }
+    metrics3_ = {
+        metric.name: metric.result().numpy() for metric in task.metrics
+    }
+    self.assertIsNotNone(loss)
+    self.assertAllClose(expected_loss3, loss)
+    self.assertAllClose(expected_metrics3, metrics3_)

  def test_task_graph(self):

    with tf.Graph().as_default():
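For reference, the expectation in the new sample-weighted test can be reproduced by hand. The normalized logits are [[3, 0], [4, 0]]; row 0's positive is candidate 0 and row 1's positive is candidate 1, so the softmax cross-entropy terms reduce to the sigmoid expressions used in the test, weighted 0.7 and 0.3. Because the fixed code passes the already sample-weighted loss to the `Mean` metric without re-applying `sample_weight`, `batch_loss` must match this value exactly. A small verification sketch (not library code):

```python
import numpy as np

def _sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

# Softmax CE for logits [3, 0], positive at index 0: -log(sigmoid(3)).
# Softmax CE for logits [4, 0], positive at index 1: -log(1 - sigmoid(4)).
ce = np.array([-np.log(_sigmoid(3.0)), -np.log(1.0 - _sigmoid(4.0))])
weights = np.array([0.7, 0.3])

expected_loss3 = weights @ ce
print(expected_loss3)  # ~1.2394
```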
