diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 63a646dfbc..3ff2e62712 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -913,7 +913,7 @@ def _update_tablewise_cache_miss( self.table_wise_cache_miss[i] += miss_count - def forward( + def _forward_impl( self, indices: Tensor, offsets: Tensor, @@ -1016,6 +1016,16 @@ def forward( fp8_exponent_bias=self.fp8_exponent_bias, ) + def forward( + self, + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + return self._forward_impl( + indices=indices, offsets=offsets, per_sample_weights=per_sample_weights + ) + def initialize_logical_weights_placements_and_offsets( self, ) -> None: