From 6f870f09ec2bbf0045c938c7b4b9eedff8dfb1ee Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Tue, 18 Jun 2024 21:47:33 -0700
Subject: [PATCH] Reland [2] [FBGEMM][MTIA] Support MTIA in
 DenseTableBatchedEmbeddingBagsCodegen (#2100)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2100

Original diff D58137460, get reverted on D58449527

Will reland in two diffs, one with FBGEMM, one with other changes

Reviewed By: egienvalue

Differential Revision: D58473485

fbshipit-source-id: 53a75989bb05978694963a42c232b378a018dde8
---
 .../distributed/batched_embedding_kernel.py   | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index fc2a514b5..9b47c5e41 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -939,16 +939,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=PoolingMode.NONE,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -1376,16 +1380,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(