Support padding based on row_dim (torchrec part) (pytorch#2204)
Summary:
Pull Request resolved: pytorch#2204

The MX4 GEMM kernel requires that the total number of elements in the KJT per rank be divisible by 32.

Failed job: aps-350x_lite-b76673ccdc

  RuntimeError:  Number of inputs needs to be a multiple of group size
  Exception raised from quantize_mx_cuda at fbcode/deeplearning/fbgemm/fbgemm_gpu/src/quantize_ops/quantize_mx.cu:63 (most recent call first):

  # 2  c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*)
  # 3  fbgemm_gpu::quantize_mx_cuda(at::Tensor const&, std::vector<long, std::allocator<long> > const&, long, long, long, double, long, bool, long)
  # 4  std::decay<c10::guts::infer_function_traits<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::vector<long, std::allocator<long> > const&, long, long, long, double, long, bool, long), &fbgemm_gpu::quantize_mx_cuda>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::vector<long, std::allocator<long> > const&, long, long, long, double, long, bool, long> > >::type::return_type>::type c10::impl::call_functor_with_args_from_stack_<c10::impl::detail::WrapFunctionIntoFunctor...

Reviewed By: sryap

Differential Revision: D58223717

fbshipit-source-id: 910d365b95b9c8d06b1ac4240b550816d723c9f0
Robert Luo authored and facebook-github-bot committed Jul 4, 2024
1 parent 291e7e1 commit 375bff6
44 changes: 34 additions & 10 deletions torchrec/distributed/fbgemm_qcomm_codec.py
@@ -68,6 +68,8 @@ class QCommsConfig:
fp8_quantize_dim: Optional[int] = None
fp8_quantize_dim_bwd: Optional[int] = None
fp8_bwd_uses_143: Optional[bool] = False
mx4_quantize_dim: Optional[int] = None
mx4_quantize_dim_bwd: Optional[int] = None

def __post_init__(self) -> None:
if (
@@ -90,10 +92,35 @@ def __post_init__(self) -> None:
f"No override of FP8 bwd row dim, using general FP8 row dim for backward: {self.fp8_quantize_dim_bwd} "
)

if (
self.forward_precision != CommType.MX4
and self.backward_precision != CommType.MX4
and (
self.mx4_quantize_dim is not None
or self.mx4_quantize_dim_bwd is not None
)
):
raise ValueError(
f"mx4_quantize_dim is set to {self.mx4_quantize_dim} and mx4_quantize_dim_bwd is set to {self.mx4_quantize_dim_bwd} but no MX4 precision is found in forward or backward precisions"
)
if (
self.backward_precision == CommType.MX4
and self.mx4_quantize_dim_bwd is None
):
self.mx4_quantize_dim_bwd = self.mx4_quantize_dim
logger.warning(
f"No override of MX4 bwd row dim, using general MX4 row dim for backward: {self.mx4_quantize_dim_bwd} "
)


def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCodecs:
codecs = QuantizedCommCodecs()
if qcomms_config is not None:
row_dim = None
if qcomms_config.forward_precision == CommType.FP8:
row_dim = qcomms_config.fp8_quantize_dim
elif qcomms_config.forward_precision == CommType.MX4:
row_dim = qcomms_config.mx4_quantize_dim
codecs.forward = cast(
QuantizedCommCodec[QuantizationContext],
FbgemmQuantizedCommCodec(
@@ -102,13 +129,14 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCodecs:
),
loss_scale=qcomms_config.forward_loss_scale,
is_fwd=True,
row_dim=(
qcomms_config.fp8_quantize_dim
if qcomms_config.forward_precision == CommType.FP8
else None
),
row_dim=row_dim,
),
)
row_dim_bwd = None
if qcomms_config.backward_precision == CommType.FP8:
row_dim_bwd = qcomms_config.fp8_quantize_dim_bwd
elif qcomms_config.backward_precision == CommType.MX4:
row_dim_bwd = qcomms_config.mx4_quantize_dim_bwd
codecs.backward = cast(
QuantizedCommCodec[QuantizationContext],
FbgemmQuantizedCommCodec(
@@ -120,11 +148,7 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCodecs:
True if qcomms_config.fp8_bwd_uses_143 else False
), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
# if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
row_dim=(
qcomms_config.fp8_quantize_dim_bwd
if qcomms_config.backward_precision == CommType.FP8
else None
),
row_dim=row_dim_bwd,
),
)
return codecs
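For context, a hedged usage sketch of the new fields (field names are taken from the diff above; the import path, defaults, and values are assumptions):

  from torchrec.distributed.fbgemm_qcomm_codec import (
      CommType,
      QCommsConfig,
      get_qcomm_codecs,
  )

  # MX4 comms in both directions. mx4_quantize_dim_bwd is left unset, so
  # __post_init__ falls back to mx4_quantize_dim and logs a warning.
  config = QCommsConfig(
      forward_precision=CommType.MX4,
      backward_precision=CommType.MX4,
      mx4_quantize_dim=1024,
  )
  codecs = get_qcomm_codecs(config)  # row_dim=1024 is passed to both codecs

  # Setting mx4_quantize_dim without MX4 in either precision raises a ValueError.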