diff --git a/torchrec/metrics/segmented_ne.py b/torchrec/metrics/segmented_ne.py index cf0f01a3d..e7fc3c7d6 100644 --- a/torchrec/metrics/segmented_ne.py +++ b/torchrec/metrics/segmented_ne.py @@ -165,6 +165,9 @@ class SegmentedNEMetricComputation(RecMetricComputation): Args: include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE. + num_groups (int): number of groups to segment NE by. + grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64. + cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32. """ def __init__( @@ -172,11 +175,15 @@ def __init__( *args: Any, include_logloss: bool = False, # TODO - include num_groups: int = 1, + grouping_keys: str = "grouping_keys", + cast_keys_to_int: bool = False, **kwargs: Any, ) -> None: self._include_logloss: bool = include_logloss super().__init__(*args, **kwargs) self._num_groups = num_groups # would there be checkpointing issues with this? maybe make this state + self._grouping_keys = grouping_keys + self._cast_keys_to_int = cast_keys_to_int self._add_state( "cross_entropy_sum", torch.zeros((self._n_tasks, num_groups), dtype=torch.double), @@ -217,21 +224,30 @@ def update( ) -> None: if predictions is None or weights is None: raise RecMetricException( - "Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update" + f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update" ) elif ( "required_inputs" not in kwargs - or kwargs["required_inputs"].get("grouping_keys") is None + or kwargs["required_inputs"].get(self._grouping_keys) is None ): raise RecMetricException( - f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}" - ) - elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64: - raise RecMetricException( - f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}." + f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}" ) + elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64: + if ( + self._cast_keys_to_int + and kwargs["required_inputs"][self._grouping_keys].dtype + == torch.float32 + ): + kwargs["required_inputs"][self._grouping_keys] = kwargs[ + "required_inputs" + ][self._grouping_keys].to(torch.int64) + else: + raise RecMetricException( + f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}." + ) - grouping_keys = kwargs["required_inputs"]["grouping_keys"] + grouping_keys = kwargs["required_inputs"][self._grouping_keys] states = get_segemented_ne_states( labels, predictions, @@ -325,4 +341,8 @@ def __init__( process_group=process_group, **kwargs, ) - self._required_inputs.add("grouping_keys") + if "grouping_keys" not in kwargs: + self._required_inputs.add("grouping_keys") + else: + # pyre-ignore[6] + self._required_inputs.add(kwargs["grouping_keys"]) diff --git a/torchrec/metrics/tests/test_segmented_ne.py b/torchrec/metrics/tests/test_segmented_ne.py index 91a70d9e9..507a7cc8f 100644 --- a/torchrec/metrics/tests/test_segmented_ne.py +++ b/torchrec/metrics/tests/test_segmented_ne.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Iterable, Union +from typing import Any, Dict, Iterable, Union import torch from torch import no_grad @@ -31,6 +31,8 @@ def _test_segemented_ne_helper( weights: torch.Tensor, expected_ne: torch.Tensor, grouping_keys: torch.Tensor, + grouping_key_tensor_name: str = "grouping_keys", + cast_keys_to_int: bool = False, ) -> None: num_task = labels.shape[0] batch_size = labels.shape[0] @@ -41,7 +43,7 @@ def _test_segemented_ne_helper( "weights": {}, } if grouping_keys is not None: - inputs["required_inputs"] = {"grouping_keys": grouping_keys} + inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys} for i in range(num_task): task_info = RecTaskInfo( name=f"Task:{i}", @@ -64,6 +66,10 @@ def _test_segemented_ne_helper( tasks=task_list, # pyre-ignore num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1), + # pyre-ignore + grouping_keys=grouping_key_tensor_name, + # pyre-ignore + cast_keys_to_int=cast_keys_to_int, ) ne.update(**inputs) actual_ne = ne.compute() @@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None: raise -def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: +def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]: return [ # base condition { @@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: ), # for this case, both tasks have same groupings "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]), }, + # Custom grouping key tensor name + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + }, + # Cast grouping keys to int32 + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + "cast_keys_to_int": True, + }, ]