Skip to content

Commit 6f870f0

Browse files
gnahzg authored and facebook-github-bot committed
Reland [2] [FBGEMM][MTIA] Support MTIA in DenseTableBatchedEmbeddingBagsCodegen (#2100)
Summary: Pull Request resolved: #2100 Original diff D58137460, get reverted on D58449527 Will reland in two diffs, one with FBGEMM, one with other changes Reviewed By: egienvalue Differential Revision: D58473485 fbshipit-source-id: 53a75989bb05978694963a42c232b378a018dde8
1 parent 52d689e commit 6f870f0

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -939,16 +939,20 @@ def __init__(
939939
weights_precision = data_type_to_sparse_type(config.data_type)
940940
fused_params = config.fused_params or {}
941941
output_dtype = fused_params.get("output_dtype", SparseType.FP32)
942+
use_cpu: bool = (
943+
device is None
944+
or device.type == "cpu"
945+
or (not (torch.cuda.is_available() or torch.mtia.is_available()))
946+
)
942947
self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
943948
DenseTableBatchedEmbeddingBagsCodegen(
944949
list(zip(self._local_rows, self._local_cols)),
945950
feature_table_map=self._feature_table_map,
946951
pooling_mode=PoolingMode.NONE,
947-
use_cpu=device is None
948-
or device.type == "cpu"
949-
or not torch.cuda.is_available(),
952+
use_cpu=use_cpu,
950953
weights_precision=weights_precision,
951954
output_dtype=output_dtype,
955+
use_mtia=device is not None and device.type == "mtia",
952956
)
953957
)
954958
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -1376,16 +1380,20 @@ def __init__(
13761380
weights_precision = data_type_to_sparse_type(config.data_type)
13771381
fused_params = config.fused_params or {}
13781382
output_dtype = fused_params.get("output_dtype", SparseType.FP32)
1383+
use_cpu: bool = (
1384+
device is None
1385+
or device.type == "cpu"
1386+
or (not (torch.cuda.is_available() or torch.mtia.is_available()))
1387+
)
13791388
self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
13801389
DenseTableBatchedEmbeddingBagsCodegen(
13811390
list(zip(self._local_rows, self._local_cols)),
13821391
feature_table_map=self._feature_table_map,
13831392
pooling_mode=self._pooling,
1384-
use_cpu=device is None
1385-
or device.type == "cpu"
1386-
or not torch.cuda.is_available(),
1393+
use_cpu=use_cpu,
13871394
weights_precision=weights_precision,
13881395
output_dtype=output_dtype,
1396+
use_mtia=device is not None and device.type == "mtia",
13891397
)
13901398
)
13911399
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(

0 commit comments

Comments (0)