From d8e07cef21d8745345df0e1574ecec557f94138a Mon Sep 17 00:00:00 2001 From: Tuan Trieu Date: Tue, 11 Feb 2025 16:29:24 -0800 Subject: [PATCH] Support histogram_binning_calibration for export (#3657) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/733 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3657 Add fake tensor for `histogram_binning_calibration` which is needed to export old PA. Reviewed By: hongyang-zhao Differential Revision: D69089371 fbshipit-source-id: 8f3aedfe2d42248ee0004109a2b91f00659025e1 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 25 +++++++++++++++++++++++ fbgemm_gpu/test/sparse/failures_dict.json | 8 ++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 79359d0010..8b642361d8 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -1100,6 +1100,27 @@ def fused_8_bit_rowwise_quantized_to_half( return torch.empty(output_shape, dtype=torch.float16, device=input_t.device) +def generic_histogram_binning_calibration_by_feature( + logit: Tensor, + segment_value: Tensor, + segment_lengths: Tensor, + num_segments: int, + bin_num_examples: Tensor, + bin_num_positives: Tensor, + bin_boundaries: Tensor, + positive_weight: float, + bin_ctr_in_use_after: int, + bin_ctr_weight_value: float, +) -> Tuple[Tensor, Tensor]: + torch._check(bin_num_examples.numel() == bin_num_positives.numel()) + torch._check( + bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1) + ) + return torch.empty_like(logit), torch.empty( + [logit.numel()], dtype=torch.int64, device=logit.device + ) + + def _setup() -> None: # pyre-ignore[16] _setup.done = getattr(_setup, "done", False) @@ -1233,6 +1254,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None "fbgemm::histogram_binning_calibration", histogram_binning_calibration_abstract, ) + impl_abstract( + "fbgemm::generic_histogram_binning_calibration_by_feature", + generic_histogram_binning_calibration_by_feature, + ) impl_abstract( "fbgemm::FloatToHFP8Quantized", float_to_hfp8_quantized, diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 50cfe984e5..a86a086da2 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -141,19 +141,19 @@ "fbgemm::generic_histogram_binning_calibration_by_feature": { "HistogramBinningCalibrationTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "HistogramBinningCalibrationTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature_cpu_gpu": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "HistogramBinningCalibrationTest.test_faketensor__test_generic_histogram_binning_calibration_by_feature": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "HistogramBinningCalibrationTest.test_faketensor__test_generic_histogram_binning_calibration_by_feature_cpu_gpu": { "comment": "", - "status": "xfail" + "status": "xsuccess" } }, "fbgemm::group_index_select_dim0": {