Skip to content

Commit

Permalink
[PyTorch] Use __torch_function__ as a class method (NVIDIA#783)
Browse files Browse the repository at this point in the history
Use torch function as a class method

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
ksivaman authored Apr 16, 2024
1 parent 63c7a1a commit d3552dd
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions transformer_engine/pytorch/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def forward(
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[[torch.Tensor, None], ...]:
) -> Tuple[Union[torch.Tensor, None], ...]:

if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
Expand Down Expand Up @@ -853,5 +853,8 @@ def _set_data(self, tensor: torch.Tensor) -> None:
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))

# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

0 comments on commit d3552dd

Please sign in to comment.