Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
enable float types in pytorch for non comptue comms (#263)
Browse files Browse the repository at this point in the history
Summary:
Coupled with this: pytorch/pytorch#126556
test everytihng is pasing

Pull Request resolved: #263

Reviewed By: wanchaol

Differential Revision: D57505783

Pulled By: drisspg

fbshipit-source-id: cd928420f559839c63d79bfe7558416fbcfe1d69
  • Loading branch information
drisspg authored and facebook-github-bot committed May 18, 2024
1 parent 6891cbe commit f7a920d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 0 additions & 3 deletions float8_experimental/float8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,8 @@ def allgather_fp8(aten_op, args, kwargs=None):
), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"

fp8_data = fp8_input._data
fp8_data = fp8_data.view(torch.uint8)
fp8_data = fp8_data.contiguous()
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
fp8_out = torch.ops._c10d_functional.wait_tensor(fp8_out)
fp8_out = fp8_out.view(fp8_input._data.dtype)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
)
Expand Down
2 changes: 2 additions & 0 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,5 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e

torch.distributed.destroy_process_group()

0 comments on commit f7a920d

Please sign in to comment.