kernel row alignment correction (#1789)
Summary:
Pull Request resolved: #1789

* The sharder can now inject the TBE row alignment into the quantized embedding kernel via an extra key in fused params.

Reviewed By: tissue3

Differential Revision: D54842984

fbshipit-source-id: df52030274fefe3eec271b87fdbb61923eafaa66
YazhiGao authored and facebook-github-bot committed Mar 13, 2024
1 parent cb6b69a commit 0f51389
Showing 2 changed files with 21 additions and 1 deletion.
12 changes: 12 additions & 0 deletions torchrec/distributed/fused_params.py
@@ -20,6 +20,7 @@
 FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = (
     "__register_quant_state_dict_split_scale_bias"
 )
+FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
 
 
 class TBEToRegisterMixIn:
@@ -47,6 +48,15 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool:
     )
 
 
+def get_fused_param_tbe_row_alignment(
+    fused_params: Optional[Dict[str, Any]]
+) -> Optional[int]:
+    if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params:
+        return None
+    else:
+        return fused_params[FUSED_PARAM_TBE_ROW_ALIGNMENT]
+
+
 def is_fused_param_quant_state_dict_split_scale_bias(
     fused_params: Optional[Dict[str, Any]]
 ) -> bool:
@@ -68,5 +78,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_REGISTER_TBE_BOOL)
     if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS)
+    if FUSED_PARAM_TBE_ROW_ALIGNMENT in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
 
     return fused_params_for_tbe
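
A minimal sketch of how the two helpers above interact (the fused_params contents are hypothetical; only the key name and the helpers come from this diff):

    # Minimal sketch; the dict contents below are hypothetical.
    from torchrec.distributed.fused_params import (
        FUSED_PARAM_TBE_ROW_ALIGNMENT,
        get_fused_param_tbe_row_alignment,
        tbe_fused_params,
    )

    fused_params = {
        FUSED_PARAM_TBE_ROW_ALIGNMENT: 8,  # alignment injected by the sharder
        "cache_load_factor": 0.2,  # an ordinary fused param, passed through
    }

    # The kernel reads the alignment through the accessor...
    assert get_fused_param_tbe_row_alignment(fused_params) == 8
    # ...while tbe_fused_params strips the marker key so it never reaches
    # the TBE constructor kwargs.
    assert tbe_fused_params(fused_params) == {"cache_load_factor": 0.2}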
10 changes: 9 additions & 1 deletion torchrec/distributed/quant_embedding_kernel.py
@@ -32,6 +32,7 @@
     GroupedEmbeddingConfig,
 )
 from torchrec.distributed.fused_params import (
+    get_fused_param_tbe_row_alignment,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
     tbe_fused_params,
@@ -318,6 +319,9 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
+        self._tbe_row_alignment: Optional[int] = get_fused_param_tbe_row_alignment(
+            fused_params
+        )
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
             IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=[
@@ -342,7 +346,11 @@
                 device=device,
                 pooling_mode=PoolingMode.NONE,
                 feature_table_map=self._feature_table_map,
-                row_alignment=16,
+                row_alignment=(
+                    self._tbe_row_alignment
+                    if self._tbe_row_alignment is not None
+                    else 16
+                ),
                 uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
                 **(tbe_fused_params(fused_params) or {}),
             )
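
End to end, the injection path described in the summary looks roughly like the sketch below. It assumes a sharder such as QuantEmbeddingCollectionSharder that forwards its fused_params down to the kernel patched above; only FUSED_PARAM_TBE_ROW_ALIGNMENT and the fallback to 16 come from this commit.

    # Rough sketch of injecting a TBE row alignment from the sharder side.
    # Assumption: QuantEmbeddingCollectionSharder forwards fused_params to
    # the quantized embedding kernel whose __init__ is patched above.
    from torchrec.distributed.fused_params import FUSED_PARAM_TBE_ROW_ALIGNMENT
    from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder

    sharder = QuantEmbeddingCollectionSharder(
        fused_params={FUSED_PARAM_TBE_ROW_ALIGNMENT: 8},  # hypothetical value
    )
    # When the sharded module is built, get_fused_param_tbe_row_alignment picks
    # the key up and IntNBitTableBatchedEmbeddingBagsCodegen is constructed with
    # row_alignment=8; without the key, it falls back to the old hard-coded 16.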
