diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py index 3fab8c47d1..eb4bfa24e5 100644 --- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py +++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py @@ -71,6 +71,10 @@ class FeatureGate: FeatureGate.is_enabled(FeatureGateName.TBE_V2) """ + dummy_tensor = torch.empty(1) + @classmethod def is_enabled(cls, feature: FeatureGateName) -> bool: - return torch.ops.fbgemm.check_feature_gate_key(feature.name) + return torch.ops.fbgemm.check_feature_gate_key_pt2( + cls.dummy_tensor, feature.name + )[0].item() diff --git a/fbgemm_gpu/src/config/feature_gates.cpp b/fbgemm_gpu/src/config/feature_gates.cpp index e0e07a12d0..cb3aad5925 100644 --- a/fbgemm_gpu/src/config/feature_gates.cpp +++ b/fbgemm_gpu/src/config/feature_gates.cpp @@ -65,6 +65,21 @@ DLL_PUBLIC bool check_feature_gate_key(const std::string& key) { } } +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + output.data_ptr()[0] = check_feature_gate_key(key); + return output; +} + +DLL_PUBLIC at::Tensor check_feature_gate_key_pt2_meta( + at::Tensor& tensor, + const std::string& key) { + auto output = at::empty({1}, tensor.options().dtype(at::kBool)); + return output; +} + DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) { return check_feature_gate_key(to_string(feature)); } @@ -81,4 +96,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "check_feature_gate_key(str key) -> bool", fbgemm_gpu::config::check_feature_gate_key); + m.def("check_feature_gate_key_pt2(Tensor tensor, str key) -> Tensor"); + DISPATCH_TO_CPU( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_CUDA( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2); + DISPATCH_TO_META( + "check_feature_gate_key_pt2", + fbgemm_gpu::config::check_feature_gate_key_pt2_meta); }