Commit

Reland [2] [FBGEMM][MTIA] Support MTIA in DenseTableBatchedEmbeddingBagsCodegen (#2100)

Summary:
Pull Request resolved: #2100

The original diff, D58137460, was reverted in D58449527.

This will be relanded in two diffs: one with the FBGEMM change, one with the other changes.

Reviewed By: egienvalue

Differential Revision: D58473485

fbshipit-source-id: 53a75989bb05978694963a42c232b378a018dde8
gnahzg authored and facebook-github-bot committed Jun 19, 2024
1 parent 52d689e commit 6f870f0
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -939,16 +939,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=PoolingMode.NONE,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -1376,16 +1380,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
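For readers skimming the diff, both hunks make the same change: use_cpu now falls back to CPU only when neither CUDA nor MTIA is available, and a use_mtia flag is passed whenever the target device is MTIA. The sketch below restates that selection logic in isolation; the helper name resolve_kernel_device is hypothetical (it does not exist in torchrec), and the code assumes a PyTorch build that exposes torch.mtia.is_available().

# Minimal sketch of the device-selection logic introduced by this diff.
# The helper name `resolve_kernel_device` is hypothetical (not part of torchrec),
# and it assumes a PyTorch build that exposes `torch.mtia.is_available()`.
from typing import Optional, Tuple

import torch


def resolve_kernel_device(device: Optional[torch.device]) -> Tuple[bool, bool]:
    """Return the (use_cpu, use_mtia) flags passed to DenseTableBatchedEmbeddingBagsCodegen."""
    use_cpu = (
        device is None
        or device.type == "cpu"
        # Fall back to CPU only when neither CUDA nor MTIA is available.
        or not (torch.cuda.is_available() or torch.mtia.is_available())
    )
    use_mtia = device is not None and device.type == "mtia"
    return use_cpu, use_mtia


# With no device given, the kernel always falls back to CPU and MTIA stays off.
assert resolve_kernel_device(None) == (True, False)

Precomputing use_cpu once and passing it by name, rather than inlining the expression in the constructor call, keeps the CUDA/MTIA fallback condition identical across both call sites.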
