Updating split_table_batched_embeddings_ops_training.py (#3613)
Summary:
X-link: facebookresearch/FBGEMM#690


After this diff stack:

EmbeddingKernelConfig now supports adding embedding_table_int32_index_type and embedding_table_int32_offset_type to the fused_params.

These are used downstream to determine the index and offset dtypes used by split_table_batched_embeddings_ops_training.py.
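
For illustration only, a minimal usage sketch of the new constructor arguments (assuming the usual fbgemm_gpu import paths; the table sizes and locations below are arbitrary, and this snippet is not part of the diff):

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# One host-resident table with 1000 rows of dimension 64, computed on CPU.
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1000, 64, EmbeddingLocation.HOST, ComputeDevice.CPU)],
    embedding_table_index_type=torch.int32,   # new argument; defaults to torch.int64
    embedding_table_offset_type=torch.int32,  # new argument; defaults to torch.int64
)

# int64 inputs are still accepted; prepare_inputs() downcasts them to int32
# based on the configuration above.
indices = torch.tensor([1, 2, 3, 4], dtype=torch.int64)
offsets = torch.tensor([0, 2, 4], dtype=torch.int64)
output = tbe(indices=indices, offsets=offsets)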

Reviewed By: q10

Differential Revision: D68005929
basilwong authored and facebook-github-bot committed Jan 29, 2025
1 parent 5b048ab commit 50febbe
Showing 1 changed file with 30 additions and 0 deletions.
@@ -586,6 +586,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
are not enabled by default.
- `use_rowwise_bias_correction` is used in Adam to enable rowwise
bias correction computation
embedding_table_index_type (torch.dtype = torch.int64): The data type of
the embedding table index tensor. Options are `torch.int32` and
`torch.int64`
embedding_table_offset_type (torch.dtype = torch.int64): The data type of
the embedding table offset tensor. Options are `torch.int32` and
`torch.int64`
"""

embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
@@ -654,6 +662,8 @@ def __init__( # noqa C901
uvm_host_mapped: bool = False,
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
embedding_table_index_type: torch.dtype = torch.int64,
embedding_table_offset_type: torch.dtype = torch.int64,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

@@ -1343,6 +1353,17 @@ def __init__( # noqa C901
FeatureGateName.BOUNDS_CHECK_INDICES_V2
)

if embedding_table_index_type not in [torch.int32, torch.int64]:
raise ValueError(
f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
)
self.embedding_table_index_type: torch.dtype = embedding_table_index_type
if embedding_table_offset_type not in [torch.int32, torch.int64]:
raise ValueError(
f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
)
self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

@torch.jit.ignore
def log(self, msg: str) -> None:
"""
@@ -3409,6 +3430,15 @@ def prepare_inputs(
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
if self.embedding_table_index_type == torch.int32:
self.log(
"Casting indices to int32 based on embedding_table_index_type input."
)
indices = indices.to(torch.int32)
if self.embedding_table_index_type != self.embedding_table_offset_type:
self.log(
f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
)
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
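For reference, a self-contained sketch of the dtype normalization that prepare_inputs now applies (the helper name here is illustrative, not part of the module):

import torch

def normalize_tbe_inputs(
    indices: torch.Tensor,
    offsets: torch.Tensor,
    embedding_table_index_type: torch.dtype = torch.int64,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Downcast indices when the module was configured with int32 indices.
    if embedding_table_index_type == torch.int32:
        indices = indices.to(torch.int32)
    # The kernels assume indices and offsets share a dtype, so offsets always
    # follow the (possibly downcast) indices dtype.
    offsets = offsets.to(dtype=indices.dtype)
    return indices, offsets

indices, offsets = normalize_tbe_inputs(
    torch.tensor([1, 2, 3], dtype=torch.int64),
    torch.tensor([0, 1, 3], dtype=torch.int64),
    embedding_table_index_type=torch.int32,
)
assert indices.dtype == torch.int32 and offsets.dtype == torch.int32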
