Skip to content

Commit

Permalink
Attempt to use rFactor for allreduce
Browse files Browse the repository at this point in the history
```
mpirun -np 1 -x NVFUSER_DUMP=fusion_ir_preseg pytest tests/python/test_communication.py -k allreduce_rfactor -s --only-mpi
```

```
%kernel {
T2_l_float[iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3}] (DeviceMesh{0})
   = reduction( T0_g_float[iS0{2}, ideviceIdx.x6{1}, iS7{5}, iS2{3}] (DeviceMesh{0}), op = add, initial value = float(0), allreduce = false )
T1_g_float[iS15{2}, rS16{1}, iS17{3}] (DeviceMesh{0})
   = reduction( T2_l_float[iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3}] (DeviceMesh{0}), op = add, initial value = float(0), allreduce = false )

TransformPrinter :
T0_g_float[iS0{2}, ideviceIdx.x6{1}, iS7{5}, iS2{3}] (DeviceMesh{0})
 logical domain : (iS0{2}, iS1{5}, iS2{3})
 allocation domain : (iS0{2}, ideviceIdx.x6{1}, iS7{5}, iS2{3})
 contiguity: t t t t
  Outer split: iS1{5} by factor 1 -> ideviceIdx.x6{1}, iS7{5}
 loop domain : (iS0{2}, ideviceIdx.x6{1}, iS7{5}, iS2{3})
T2_l_float[iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3}] (DeviceMesh{0})
 root domain : (iS10{2}, rS11{5}rf, iS12{3})
  Outer split: rS11{5}rf by factor 1 -> ideviceIdx.x13{1}rf, rS14{5}rf
 logical domain : (iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3})
 allocation domain : (iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3})
 contiguity: t t n t
 loop domain : (iS10{2}, ideviceIdx.x13{1}rf, rS14{5}rf, iS12{3})
T1_g_float[iS15{2}, rS16{1}, iS17{3}] (DeviceMesh{0})
 logical domain : (iS15{2}, rS16{1}, iS17{3})
 allocation domain : (iS15{2}, rS16{1}, iS17{3})
 contiguity: t n t
 loop domain : (iS15{2}, rS16{1}, iS17{3})
} // %kernel
```

```
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/scheduler/vectorize_helper.cpp":1063, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Unexpected producer RF ID: iS2{3}
Exception raised from getVectorizationBreakPointOfReductionProducer at /opt/pytorch/nvfuser/csrc/scheduler/vectorize_helper.cpp:1063 (most recent call first):
```
  • Loading branch information
wujingyue committed Dec 7, 2024
1 parent af7dd68 commit 95eb150
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/python/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,44 @@ def multidevice_schedule(self):
torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_allreduce_rfactor(mpi_test):
    """Sum over a device-sharded axis using rFactor to split the reduction.

    The reduction axis k (size d * 5) is outer-split by the device count d.
    ``rfactor`` then separates the reduction into a local partial sum
    (``out_local``) followed by a cross-device reduction — presumably intended
    to lower to an allreduce (per the commit title); confirm once scheduling
    succeeds.  As committed, this is a repro for the internal assert
    "Unexpected producer RF ID" in vectorize_helper.cpp.
    """
    d = mpi_test.size  # number of devices in the MPI world
    m = 2
    n = 3
    k = d * 5  # reduction-axis extent; evenly divisible by the device count

    class Model(FusionDefinition):
        def definition(self):
            # inp: [m, k, n]; out reduces axis 1 (the k axis) with a sum.
            self.inp = self.define_tensor((m, k, n), contiguity=True, dtype=DataType.Float)
            self.out = self.ops.sum(self.inp, [1])
            self.add_output(self.out)

        def multidevice_schedule(self):
            # Outer-split axis 1 by factor d on both tensors: k -> [d, k/d].
            self.sched.split(self.inp, 1, d, False)
            self.sched.split(self.out, 1, d, False)
            # rFactor the inner (local) factor — axis 2 after the split — so
            # each device first reduces its own k/d slice into out_local.
            out_local = self.sched.rfactor(self.out, [2])

            # Place all three tensors on the same 1-D mesh of all d devices.
            mesh = self.sched._create_device_mesh(range(d))
            self.sched._set_device_mesh(self.inp, mesh)
            self.sched._set_device_mesh(self.out, mesh)
            self.sched._set_device_mesh(out_local, mesh)

            # Bind the device dimension (axis 1) to the mesh's x axis.
            self.sched.parallelize(self.inp, 1, nvfuser.ParallelType.mesh_x)
            self.sched.parallelize(out_local, 1, nvfuser.ParallelType.mesh_x)

            # Make allocation domains mirror the (sharded) loop domains.
            self.sched.set_allocation_as_loop(self.inp)
            self.sched.set_allocation_as_loop(out_local)
            self.sched.set_allocation_as_loop(self.out)

    unsharded = torch.randn(m, k, n)
    sharded = mpi_test.shard_tensor(unsharded, 1)  # each rank gets its k-slice

    fd = Model()
    outputs = fd.execute([sharded])
    # Compare against the full reduction computed on the unsharded host tensor.
    torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(1))


@pytest.mark.mpi
def test_reduce_scatter(mpi_test):
d = mpi_test.size
Expand Down

0 comments on commit 95eb150

Please sign in to comment.