Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Segment NE to allow for more customization #2684

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions torchrec/metrics/segmented_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,25 @@ class SegmentedNEMetricComputation(RecMetricComputation):

Args:
include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE.
num_groups (int): number of groups to segment NE by.
grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64.
cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32.
"""

def __init__(
self,
*args: Any,
include_logloss: bool = False, # TODO - include
num_groups: int = 1,
grouping_keys: str = "grouping_keys",
cast_keys_to_int: bool = False,
**kwargs: Any,
) -> None:
self._include_logloss: bool = include_logloss
super().__init__(*args, **kwargs)
self._num_groups = num_groups # would there be checkpointing issues with this? maybe make this state
self._grouping_keys = grouping_keys
self._cast_keys_to_int = cast_keys_to_int
self._add_state(
"cross_entropy_sum",
torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
Expand Down Expand Up @@ -217,21 +224,30 @@ def update(
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update"
f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update"
)
elif (
"required_inputs" not in kwargs
or kwargs["required_inputs"].get("grouping_keys") is None
or kwargs["required_inputs"].get(self._grouping_keys) is None
):
raise RecMetricException(
f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}"
)
elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64:
raise RecMetricException(
f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}."
f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}"
)
elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64:
if (
self._cast_keys_to_int
and kwargs["required_inputs"][self._grouping_keys].dtype
== torch.float32
):
kwargs["required_inputs"][self._grouping_keys] = kwargs[
"required_inputs"
][self._grouping_keys].to(torch.int64)
else:
raise RecMetricException(
f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}."
)

grouping_keys = kwargs["required_inputs"]["grouping_keys"]
grouping_keys = kwargs["required_inputs"][self._grouping_keys]
states = get_segemented_ne_states(
labels,
predictions,
Expand Down Expand Up @@ -325,4 +341,8 @@ def __init__(
process_group=process_group,
**kwargs,
)
self._required_inputs.add("grouping_keys")
if "grouping_keys" not in kwargs:
self._required_inputs.add("grouping_keys")
else:
# pyre-ignore[6]
self._required_inputs.add(kwargs["grouping_keys"])
31 changes: 28 additions & 3 deletions torchrec/metrics/tests/test_segmented_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Iterable, Union
from typing import Any, Dict, Iterable, Union

import torch
from torch import no_grad
Expand All @@ -31,6 +31,8 @@ def _test_segemented_ne_helper(
weights: torch.Tensor,
expected_ne: torch.Tensor,
grouping_keys: torch.Tensor,
grouping_key_tensor_name: str = "grouping_keys",
cast_keys_to_int: bool = False,
) -> None:
num_task = labels.shape[0]
batch_size = labels.shape[0]
Expand All @@ -41,7 +43,7 @@ def _test_segemented_ne_helper(
"weights": {},
}
if grouping_keys is not None:
inputs["required_inputs"] = {"grouping_keys": grouping_keys}
inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys}
for i in range(num_task):
task_info = RecTaskInfo(
name=f"Task:{i}",
Expand All @@ -64,6 +66,10 @@ def _test_segemented_ne_helper(
tasks=task_list,
# pyre-ignore
num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1),
# pyre-ignore
grouping_keys=grouping_key_tensor_name,
# pyre-ignore
cast_keys_to_int=cast_keys_to_int,
)
ne.update(**inputs)
actual_ne = ne.compute()
Expand Down Expand Up @@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None:
raise


def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]:
return [
# base condition
{
Expand Down Expand Up @@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
), # for this case, both tasks have same groupings
"expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]),
},
# Custom grouping key tensor name
{
"labels": torch.tensor([[1, 0, 0, 1, 1]]),
"predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
"weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
"grouping_keys": torch.tensor([0, 1, 0, 1, 1]),
"expected_ne": torch.tensor([[3.1615, 1.6004]]),
"grouping_key_tensor_name": "custom_key",
},
# Cast grouping keys to int32
{
"labels": torch.tensor([[1, 0, 0, 1, 1]]),
"predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
"weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
"grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]),
"expected_ne": torch.tensor([[3.1615, 1.6004]]),
"grouping_key_tensor_name": "custom_key",
"cast_keys_to_int": True,
},
]
Loading