From b80208b63f792285d53a15357c712095d8ae6d54 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Sat, 30 Nov 2024 01:31:12 -0800 Subject: [PATCH 1/2] Make check_feature_gate_key PT2 compatible (#3425) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/512 Add a new API for `check_feature_gate_key` that is PT2 compatible. PT2 complains when an op does not take/return a tensor. Thus, `check_feature_gate_key_pt2` (the new API) takes a dummy tensor as an input and returns a boolean tensor as an output. Differential Revision: D66611784 --- fbgemm_gpu/src/config/feature_gates.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/fbgemm_gpu/src/config/feature_gates.cpp b/fbgemm_gpu/src/config/feature_gates.cpp index e0e07a12d0..cb3aad5925 100644 --- a/fbgemm_gpu/src/config/feature_gates.cpp +++ b/fbgemm_gpu/src/config/feature_gates.cpp @@ -65,6 +65,21 @@ DLL_PUBLIC bool check_feature_gate_key(const std::string& key) { } } +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + output.data_ptr()[0] = check_feature_gate_key(key); + return output; +} + +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2_meta( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + return output; +} + DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) { return check_feature_gate_key(to_string(feature)); } @@ -81,4 +96,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "check_feature_gate_key(str key) -> bool", fbgemm_gpu::config::check_feature_gate_key); + m.def("check_feature_gate_key_pt2(Tensor tensor, str key) -> Tensor"); + DISPATCH_TO_CPU( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_CUDA( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_META( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2_meta); } From 8f561dd4154df50f35dbab74e3f979187a17df29 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Sat, 30 Nov 2024 01:31:12 -0800 Subject: [PATCH 2/2] Make check_feature_gate_key PT2 compatible Summary: Add a new API for `check_feature_gate_key` that is PT2 compatible. PT2 complains when an op does not take/return a tensor. Thus, `check_feature_gate_key_pt2` (the new API) takes a dummy tensor as an input and returns a boolean tensor as an output. Differential Revision: D66611785 --- fbgemm_gpu/fbgemm_gpu/config/feature_list.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py index 3fab8c47d1..eb4bfa24e5 100644 --- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py +++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py @@ -71,6 +71,10 @@ class FeatureGate: FeatureGate.is_enabled(FeatureGateName.TBE_V2) """ + dummy_tensor = torch.empty(1) + @classmethod def is_enabled(cls, feature: FeatureGateName) -> bool: - return torch.ops.fbgemm.check_feature_gate_key(feature.name) + return torch.ops.fbgemm.check_feature_gate_key_pt2( + cls.dummy_tensor, feature.name + )[0].item()