Update Segment NE to allow for more customization (#2684)
Summary:
Pull Request resolved: #2684

Update the SegmentedNE metric definition to allow:
1. Providing a custom grouping_keys tensor name
2. Casting the grouping_keys tensor to int64 if the original dtype is float32 (see the sketch below)
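
With these options, the metric can be constructed roughly as follows. This is a sketch based on the new test helper, not code from this diff; the import paths and the RecTaskInfo defaults are assumptions:

```python
import torch

from torchrec.metrics.rec_metric import RecTaskInfo
from torchrec.metrics.segmented_ne import SegmentedNEMetric

# Hypothetical configuration exercising both new options.
metric = SegmentedNEMetric(
    world_size=1,
    my_rank=0,
    batch_size=5,
    tasks=[RecTaskInfo(name="Task:0")],
    num_groups=2,
    grouping_keys="custom_key",  # (1) custom name for the grouping tensor in required_inputs
    cast_keys_to_int=True,  # (2) float32 grouping keys are cast to int64 inside update()
)
```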

Reviewed By: iamzainhuda

Differential Revision: D68172071

fbshipit-source-id: ea9056f608bca9ad430c3b2912358397849d3a12
Shashank Bhushan authored and facebook-github-bot committed Jan 16, 2025
1 parent 4685bac commit 8cf154d
Showing 2 changed files with 57 additions and 12 deletions.
38 changes: 29 additions & 9 deletions torchrec/metrics/segmented_ne.py
@@ -165,18 +165,25 @@ class SegmentedNEMetricComputation(RecMetricComputation):
     Args:
         include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE.
         num_groups (int): number of groups to segment NE by.
+        grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64.
+        cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32.
     """
 
     def __init__(
         self,
         *args: Any,
         include_logloss: bool = False,  # TODO - include
         num_groups: int = 1,
+        grouping_keys: str = "grouping_keys",
+        cast_keys_to_int: bool = False,
         **kwargs: Any,
     ) -> None:
         self._include_logloss: bool = include_logloss
         super().__init__(*args, **kwargs)
         self._num_groups = num_groups  # would there be checkpointing issues with this? maybe make this state
+        self._grouping_keys = grouping_keys
+        self._cast_keys_to_int = cast_keys_to_int
         self._add_state(
             "cross_entropy_sum",
             torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
@@ -217,21 +224,30 @@ def update(
     ) -> None:
         if predictions is None or weights is None:
             raise RecMetricException(
-                "Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update"
+                f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update"
             )
         elif (
             "required_inputs" not in kwargs
-            or kwargs["required_inputs"].get("grouping_keys") is None
+            or kwargs["required_inputs"].get(self._grouping_keys) is None
         ):
             raise RecMetricException(
-                f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}"
-            )
-        elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64:
-            raise RecMetricException(
-                f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}."
+                f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}"
             )
+        elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64:
+            if (
+                self._cast_keys_to_int
+                and kwargs["required_inputs"][self._grouping_keys].dtype
+                == torch.float32
+            ):
+                kwargs["required_inputs"][self._grouping_keys] = kwargs[
+                    "required_inputs"
+                ][self._grouping_keys].to(torch.int64)
+            else:
+                raise RecMetricException(
+                    f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}."
+                )
 
-        grouping_keys = kwargs["required_inputs"]["grouping_keys"]
+        grouping_keys = kwargs["required_inputs"][self._grouping_keys]
         states = get_segemented_ne_states(
             labels,
             predictions,
@@ -325,4 +341,8 @@ def __init__(
             process_group=process_group,
             **kwargs,
         )
-        self._required_inputs.add("grouping_keys")
+        if "grouping_keys" not in kwargs:
+            self._required_inputs.add("grouping_keys")
+        else:
+            # pyre-ignore[6]
+            self._required_inputs.add(kwargs["grouping_keys"])
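
The dtype handling added to update() above reduces to the following standalone sketch (illustrative helper name, plain ValueError instead of RecMetricException; not code from this commit): int64 keys pass through, float32 keys are cast in place when cast_keys_to_int is enabled, and anything else raises.

```python
import torch


def normalize_grouping_keys(
    required_inputs: dict, key: str, cast_keys_to_int: bool
) -> torch.Tensor:
    # int64 passes through; float32 is cast in place when enabled; otherwise raise.
    keys = required_inputs[key]
    if keys.dtype != torch.int64:
        if cast_keys_to_int and keys.dtype == torch.float32:
            required_inputs[key] = keys.to(torch.int64)
        else:
            raise ValueError(
                f"Grouping keys must be torch.int64, or torch.float32 with "
                f"cast_keys_to_int=True, got {keys.dtype}."
            )
    return required_inputs[key]


# float32 keys are accepted once cast_keys_to_int is enabled:
print(normalize_grouping_keys({"custom_key": torch.tensor([0.0, 1.0, 1.0])}, "custom_key", True))
# tensor([0, 1, 1])
```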
31 changes: 28 additions & 3 deletions torchrec/metrics/tests/test_segmented_ne.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import unittest
-from typing import Dict, Iterable, Union
+from typing import Any, Dict, Iterable, Union
 
 import torch
 from torch import no_grad
@@ -31,6 +31,8 @@ def _test_segemented_ne_helper(
     weights: torch.Tensor,
     expected_ne: torch.Tensor,
     grouping_keys: torch.Tensor,
+    grouping_key_tensor_name: str = "grouping_keys",
+    cast_keys_to_int: bool = False,
 ) -> None:
     num_task = labels.shape[0]
     batch_size = labels.shape[0]
@@ -41,7 +43,7 @@
         "weights": {},
     }
     if grouping_keys is not None:
-        inputs["required_inputs"] = {"grouping_keys": grouping_keys}
+        inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys}
     for i in range(num_task):
         task_info = RecTaskInfo(
             name=f"Task:{i}",
@@ -64,6 +66,10 @@
         tasks=task_list,
         # pyre-ignore
         num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1),
+        # pyre-ignore
+        grouping_keys=grouping_key_tensor_name,
+        # pyre-ignore
+        cast_keys_to_int=cast_keys_to_int,
     )
     ne.update(**inputs)
     actual_ne = ne.compute()
@@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None:
                 raise
 
 
-def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
+def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]:
     return [
         # base condition
         {
@@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
             ),  # for this case, both tasks have same groupings
             "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]),
         },
+        # Custom grouping key tensor name
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0, 1, 0, 1, 1]),
+            "expected_ne": torch.tensor([[3.1615, 1.6004]]),
+            "grouping_key_tensor_name": "custom_key",
+        },
+        # Cast grouping keys to int32
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]),
+            "expected_ne": torch.tensor([[3.1615, 1.6004]]),
+            "grouping_key_tensor_name": "custom_key",
+            "cast_keys_to_int": True,
+        },
     ]
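
The last case expects the same per-group NE as the custom-key case above it because casting its float32 keys recovers identical int64 groups; a quick standalone check (not part of the diff):

```python
import torch

# Casting the float32 grouping keys yields the same int64 groups used by the custom-key case.
assert torch.equal(
    torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]).to(torch.int64),
    torch.tensor([0, 1, 0, 1, 1]),
)
```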
