Add support for int32_t indices in TBE training (2K/N) (#3583)
Summary:
X-link: facebookresearch/FBGEMM#668

Pull Request resolved: #3583

- Update the TBE benchmark test to support `int32_t` indices
- Currently, only matching dtypes are supported: `int32_t` indices with `int32_t` offsets, or `int64_t` indices with `int64_t` offsets; mixed dtypes are not allowed (see the sketch below).  Depending on future feedback, we might enable a mixed mode, similar to the prototype work done in D63857531.
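
As a minimal sketch of the supported behavior (the helper `to_tbe_inputs` is hypothetical, not part of this change), the caller picks one index dtype and casts indices and offsets together before invoking TBE:

```python
import torch

def to_tbe_inputs(
    indices: torch.Tensor, offsets: torch.Tensor, use_int32: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    # The TBE kernels assume indices and offsets share a dtype, so both
    # tensors are cast together: either int32 or int64, never mixed.
    dtype = torch.int32 if use_int32 else torch.int64
    return indices.to(dtype=dtype), offsets.to(dtype=dtype)

# int64 inputs are downcast together for the int32_t TBE path.
indices = torch.tensor([1, 5, 7, 2], dtype=torch.int64)
offsets = torch.tensor([0, 2, 4], dtype=torch.int64)
indices, offsets = to_tbe_inputs(indices, offsets, use_int32=True)
assert indices.dtype == torch.int32 and offsets.dtype == torch.int32
```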

Reviewed By: basilwong

Differential Revision: D68296454

fbshipit-source-id: 6bf65f8228d2761e21a8a334c6850244ad376384
q10 authored and facebook-github-bot committed Jan 24, 2025
1 parent 5754ce7 commit 5f3adca
Showing 2 changed files with 18 additions and 7 deletions.
13 changes: 9 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -142,6 +142,7 @@ def cli() -> None:
 @click.option("--flush-gpu-cache-size-mb", default=0)
 @click.option("--dense", is_flag=True, default=False)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
+@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
 @click.option("--requests_data_file", type=str, default=None)
 @click.option("--tables", type=str, default=None)
 @click.option("--export-trace", is_flag=True, default=False)
@@ -189,6 +190,7 @@ def device( # noqa C901
     flush_gpu_cache_size_mb: int,
     dense: bool,
     output_dtype: SparseType,
+    indices_dtype: str,
     requests_data_file: Optional[str],
     tables: Optional[str],
     export_trace: bool,
@@ -201,6 +203,9 @@ def device( # noqa C901
 ) -> None:
     assert not ssd or not dense, "--ssd cannot be used together with --dense"
     num_requests = iters if num_requests == -1 else num_requests
+    indices_dtype_torch: torch.dtype = (
+        torch.int32 if int(indices_dtype) == 32 else torch.int64
+    )
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
Expand Down Expand Up @@ -378,8 +383,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices,
offsets,
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
@@ -411,8 +416,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
     time_per_iter = benchmark_requests(
         requests,
         lambda indices, offsets, per_sample_weights: emb(
-            indices,
-            offsets,
+            indices.to(dtype=indices_dtype_torch),
+            offsets.to(dtype=indices_dtype_torch),
             per_sample_weights,
             feature_requires_grad=feature_requires_grad,
         ),
12 changes: 9 additions & 3 deletions (second changed file)
@@ -3373,8 +3373,10 @@ def prepare_inputs(
         )

         if force_cast_input_types:
-            # Force casting indices and offsets to long
-            (indices, offsets) = indices.long(), offsets.long()
+            # NOTE: Force offsets to have the same dtype as indices since the
+            # kernels assume same dtype. We might need to revisit the assumption
+            # of same dtypes in the future.
+            offsets = offsets.to(dtype=indices.dtype)

         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
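
The new `prepare_inputs` rule can be illustrated with a small standalone sketch (an approximation of the casting behavior above, not the actual method): offsets now follow the dtype of the indices rather than both being forced to `int64`:

```python
import torch

# int32 indices paired with int64 offsets: under the new rule, offsets
# are cast to the indices dtype so the kernels see a single dtype.
indices = torch.tensor([3, 1, 4], dtype=torch.int32)
offsets = torch.tensor([0, 1, 3], dtype=torch.int64)
offsets = offsets.to(dtype=indices.dtype)
assert offsets.dtype == torch.int32  # previously both were forced to int64
```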
@@ -3731,7 +3733,11 @@ def forward(
             offsets, batch_size_per_feature_per_rank
         )

-        (indices, offsets) = indices.long(), offsets.long()
+        # NOTE: Force offsets to have the same dtype as indices since the
+        # kernels assume same dtype. We might need to revisit the assumption
+        # of same dtypes in the future.
+        offsets = offsets.to(dtype=indices.dtype)
+
         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
             per_sample_weights = per_sample_weights.float()
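With both the benchmark and the training module updated, the `int32_t` path can be exercised end-to-end by passing the new flag to the benchmark's `device` command, e.g. `--indices-dtype 32` (the default, `64`, preserves the existing `int64_t` behavior). Request indices and offsets are then cast to `torch.int32` before each forward call, as shown in the diff above.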
