diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h index 1cfc1a987c..b1c5386b0e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -30,6 +30,7 @@ using torch::autograd::variable_list; class PermuteMultiEmbeddingOp : public torch::autograd::Function { public: + static constexpr bool is_traceable = true; static variable_list forward( AutogradContext* ctx, const at::TensorList& pooled_embs,