Skip to content

Commit

Permalink
Make sure session_ids Tensor is on the same device as model output (pyto…
Browse files Browse the repository at this point in the history
…rch#2274)

Summary:
Pull Request resolved: pytorch#2274

To support running NDCGMetric (https://fburl.com/code/c25larm3) on GPU, `session_ids_to_tensor` should create its Tensor on the same device as the model outputs (labels).

Reviewed By: joshuadeng

Differential Revision: D60800906

fbshipit-source-id: 754d3d987e7ac59b9d6aef64a94b6a2d259c06aa
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Aug 7, 2024
1 parent d50d319 commit 07dd9b9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
15 changes: 11 additions & 4 deletions torchrec/metrics/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from torchrec.metrics.rec_metric import RecTaskInfo


def session_ids_to_tensor(session_ids: List[str]) -> torch.Tensor:
def session_ids_to_tensor(
session_ids: List[str],
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
This function is used to prepare model outputs with session_ids as List[str] to tensor to be consumed by the Metric computation
"""
Expand All @@ -28,7 +31,7 @@ def session_ids_to_tensor(session_ids: List[str]) -> torch.Tensor:
curr_id += 1

session_lengths_list.append(curr_id)
return torch.tensor(session_lengths_list[1:])
return torch.tensor(session_lengths_list[1:], device=device)


def is_empty_signals(
Expand Down Expand Up @@ -86,14 +89,15 @@ def parse_required_inputs(
model_out: Dict[str, torch.Tensor],
required_inputs_list: List[str],
ndcg_transform_input: bool = False,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
required_inputs: Dict[str, torch.Tensor] = {}
for feature in required_inputs_list:
# convert feature defined from config only
if ndcg_transform_input:
model_out[feature] = (
# pyre-ignore[6]
session_ids_to_tensor(model_out[feature])
session_ids_to_tensor(model_out[feature], device=device)
if isinstance(model_out[feature], list)
else model_out[feature]
)
Expand Down Expand Up @@ -136,7 +140,10 @@ def parse_task_model_outputs(

if required_inputs_list is not None:
all_required_inputs = parse_required_inputs(
model_out, required_inputs_list, ndcg_transform_input
model_out,
required_inputs_list,
ndcg_transform_input,
device=labels.device,
)

return all_labels, all_predictions, all_weights, all_required_inputs
31 changes: 30 additions & 1 deletion torchrec/metrics/tests/test_recmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
import unittest

import torch
from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.metrics_config import DefaultTaskInfo, RecTaskInfo
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.mse import MSEMetric
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import gen_test_batch, gen_test_tasks


_CUDA_UNAVAILABLE: bool = not torch.cuda.is_available()


class RecMetricTest(unittest.TestCase):
def setUp(self) -> None:
# Create testing labels, predictions and weights
Expand Down Expand Up @@ -272,3 +275,29 @@ def test_reset(self) -> None:
ne.reset()
window_buffer = ne._batch_window_buffers["window_cross_entropy_sum"].buffers
self.assertEqual(len(window_buffer), 0)

@unittest.skipIf(_CUDA_UNAVAILABLE, "Test needs to run on GPU")
def test_parse_task_model_outputs_ndcg(self) -> None:
    """
    Check that the `session_id` required input returned by
    parse_task_model_outputs reports the same CUDA device as the
    label/weight/prediction tensors passed in model_out.
    """
    gpu = torch.device("cuda:0")
    ndcg_task = RecTaskInfo(name="ndcg_example")

    # session_id is deliberately a plain list of strings, unlike the
    # other entries, which are tensors already placed on the GPU.
    model_out = {
        "label": torch.tensor([0.0, 1.0, 0.0, 1.0], device=gpu),
        "weight": torch.tensor([1.0, 1.0, 1.0, 1.0], device=gpu),
        "prediction": torch.tensor([0.0, 1.0, 0.0, 1.0], device=gpu),
        "session_id": ["1", "1", "2", "2"],
    }

    parsed = parse_task_model_outputs(
        tasks=[ndcg_task],
        # pyre-fixme[6]: for argument model_out, expected Dict[str, Tensor] but
        #  got Dict[str, Union[List[str], Tensor]]
        model_out=model_out,
        required_inputs_list=["session_id"],
    )
    # Only the last element of the 4-tuple (required inputs) is checked.
    required_inputs = parsed[3]
    self.assertEqual(required_inputs["session_id"].device, gpu)

0 comments on commit 07dd9b9

Please sign in to comment.