Add Support for FakeProcessGroup (pytorch#1877)
Summary:
X-link: pytorch/pytorch#133039

Pull Request resolved: pytorch#1877

# context
* use FakeProcessGroup to mimic the multi-process tests in a single process
* `_test_compile_fake_pg_fn` can be used as the single-process VB (variable-batch) compile test:
```
from torchrec.distributed.tests.test_pt2_multiprocess import _test_compile_fake_pg_fn
_test_compile_fake_pg_fn(
    rank=0,
    world_size=2,
)
```

reference: D59637444

NOTE: right now this is only tested for EBC; other sparse modules such as PEA or VLE are untested, but adding similar changes for them shouldn't be hard.
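
For orientation, the fake process group used throughout this change is set up like a regular one, just with the `fake` backend and a `FakeStore`. A minimal sketch (mirroring the setup inside the new test below, not additional API surface):

```python
import torch.distributed as dist

# FakeStore comes from torch's internal fake-process-group test utilities,
# the same import the new test adds.
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
pg = dist.distributed_c10d._get_default_group()
assert pg._get_backend_name() == "fake"  # the condition the sharded code branches on
```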

Reviewed By: ezyang

Differential Revision: D56124045
TroyGarden authored and facebook-github-bot committed Aug 9, 2024
1 parent 5f8a495 commit 9da8f4a
Showing 4 changed files with 258 additions and 21 deletions.
27 changes: 19 additions & 8 deletions torchrec/distributed/comm_ops.py
@@ -559,14 +559,25 @@ def variable_batch_all2all_pooled_sync(
     ]
 
     with record_function("## alltoall_fwd_single ##"):
-        sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
-            sharded_input_embeddings,
-            output_split_sizes,
-            input_split_sizes,
-            pg_name(pg),
-            pg.size(),
-            get_gradient_division(),
-        )
+        if pg._get_backend_name() == "fake":
+            sharded_output_embeddings = torch.empty(
+                sum(output_split_sizes),
+                device=sharded_input_embeddings.device,
+                dtype=sharded_input_embeddings.dtype,
+            )
+            s0 = sharded_output_embeddings.size(0)
+            # Bad assumption that our rank GE than other
+            torch._check(s0 <= sharded_input_embeddings.size(0))
+            sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
+        else:
+            sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+                sharded_input_embeddings,
+                output_split_sizes,
+                input_split_sizes,
+                pg_name(pg),
+                pg.size(),
+                get_gradient_division(),
+            )
 
     if a2ai.codecs is not None:
         codecs = none_throws(a2ai.codecs)
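
In plain terms, the `fake` branch above skips the collective and just materializes an output with `sum(output_split_sizes)` elements, filled from the local input. A standalone toy of that shape contract (made-up sizes, not code from the commit):

```python
import torch

# Hypothetical local shard and splits for a world size of 2.
sharded_input_embeddings = torch.randn(8)
output_split_sizes = [4, 4]

# Same steps as the fake-backend branch: allocate the expected output size,
# then copy the leading slice of the local input into it.
out = torch.empty(
    sum(output_split_sizes),
    device=sharded_input_embeddings.device,
    dtype=sharded_input_embeddings.dtype,
)
s0 = out.size(0)
assert s0 <= sharded_input_embeddings.size(0)  # same bound as the torch._check above
out.copy_(sharded_input_embeddings[:s0])
```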
54 changes: 42 additions & 12 deletions torchrec/distributed/dist_data.py
@@ -239,12 +239,25 @@ def __init__(
             # https://github.com/pytorch/pytorch/issues/122788
             with record_function("## all2all_data:kjt splits ##"):
                 input_tensor = torch.stack(input_tensors, dim=1).flatten()
-                self._output_tensor = dist._functional_collectives.all_to_all_single(
-                    input_tensor,
-                    output_split_sizes=None,
-                    input_split_sizes=None,
-                    group=pg,
-                )
+                if pg._get_backend_name() == "fake":
+                    self._output_tensor = torch.empty(
+                        [self.num_workers * len(input_tensors)],
+                        device=input_tensors[0].device,
+                        dtype=input_tensors[0].dtype,
+                    )
+
+                    self._output_tensor = input_tensor[
+                        : input_tensor.size(0) // 2
+                    ].repeat(2)
+                else:
+                    self._output_tensor = (
+                        dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split_sizes=None,
+                            input_split_sizes=None,
+                            group=pg,
+                        )
+                    )
             # To avoid hasattr in _wait_impl to check self._splits_awaitable
             # pyre-ignore
             self._splits_awaitable = None
@@ -342,6 +355,7 @@ def __init__(
         self._output_tensors: List[torch.Tensor] = []
         self._awaitables: List[dist.Work] = []
         self._world_size: int = self._pg.size()
+        rank = dist.get_rank(self._pg)
 
         for input_split, output_split, input_tensor, label in zip(
            input_splits,
@@ -353,12 +367,28 @@ def __init__(
                 # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
                 # https://github.com/pytorch/pytorch/issues/122788
                 with record_function(f"## all2all_data:kjt {label} ##"):
-                    output_tensor = dist._functional_collectives.all_to_all_single(
-                        input_tensor,
-                        output_split,
-                        input_split,
-                        pg,
-                    )
+                    if self._pg._get_backend_name() == "fake":
+                        output_tensor = torch.empty(
+                            sum(output_split),
+                            device=self._device,
+                            dtype=input_tensor.dtype,
+                        )
+                        _l = sum(output_split[:rank])
+                        _r = _l + output_split[rank]
+                        torch._check(_r < input_tensor.size(0))
+                        torch._check(_l < input_tensor.size(0))
+                        torch._check(_l <= _r)
+                        torch._check(2 * (_r - _l) == output_tensor.size(0))
+                        output_tensor.copy_(
+                            input_tensor[_l:_r].repeat(self._world_size)
+                        )
+                    else:
+                        output_tensor = dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split,
+                            input_split,
+                            pg,
+                        )
                 self._output_tensors.append(output_tensor)
             else:
                 output_tensor = torch.empty(
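
The tensors-path fake branch does the per-feature analogue: it takes this rank's slice of the flattened input, bounded by `_l` and `_r`, and tiles it `world_size` times so the output ends up with `sum(output_split)` elements. A toy with made-up splits (again not code from the commit, and like the original it assumes a world size of 2 with equal splits):

```python
import torch

world_size = 2
rank = 0
input_tensor = torch.arange(8)   # hypothetical flattened KJT payload
output_split = [4, 4]            # equal per-rank splits

_l = sum(output_split[:rank])    # start of this rank's slice
_r = _l + output_split[rank]     # end of this rank's slice

out = torch.empty(sum(output_split), dtype=input_tensor.dtype)
assert 2 * (_r - _l) == out.size(0)  # mirrors the torch._check constraint above
out.copy_(input_tensor[_l:_r].repeat(world_size))
```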
4 changes: 3 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -657,7 +657,9 @@ def __init__(
             broadcast_buffers=True,
             static_graph=True,
         )
-        self._initialize_torch_state()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state()
 
         # TODO[zainhuda]: support module device coming from CPU
         if module.device not in ["meta", "cpu"] and module.device.type not in [
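
The embeddingbag.py change is just a guard: torch-state initialization is skipped when the process group reports the `fake` backend. Restated as a hypothetical standalone helper (the name is illustrative, not part of the commit):

```python
import torch.distributed as dist

def should_initialize_torch_state(pg: dist.ProcessGroup) -> bool:
    # Mirrors the condition added in embeddingbag.py: initialize torch state
    # only for real (non-fake) process groups.
    return dist.get_backend(pg) != "fake"
```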
194 changes: 194 additions & 0 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -25,6 +25,8 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import distributed as dist
 from torch._dynamo.testing import reduce_to_scalar_loss
+from torch.distributed import ProcessGroup
+from torch.testing._internal.distributed.fake_pg import FakeStore
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -499,6 +501,184 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
     ##### NUMERIC CHECK END #####
 
 
+def _test_compile_fake_pg_fn(
+    rank: int,
+    world_size: int,
+) -> None:
+    sharding_type = ShardingType.TABLE_WISE.value
+    input_type = _InputType.SINGLE_BATCH
+    torch_compile_backend = "eager"
+    config = _TestConfig()
+    num_embeddings = 256
+    # emb_dim must be % 4 == 0 for fbgemm
+    emb_dim = 12
+    batch_size = 10
+    num_features: int = 5
+
+    num_float_features: int = 8
+    num_weighted_features: int = 1
+
+    device: torch.device = torch.device("cuda")
+    store = FakeStore()
+    dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+    pg: ProcessGroup = dist.distributed_c10d._get_default_group()
+
+    topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+    mi = TestModelInfo(
+        dense_device=device,
+        sparse_device=device,
+        num_features=num_features,
+        num_float_features=num_float_features,
+        num_weighted_features=num_weighted_features,
+        topology=topology,
+    )
+
+    mi.planner = EmbeddingShardingPlanner(
+        topology=topology,
+        batch_size=batch_size,
+        enumerator=EmbeddingEnumerator(
+            topology=topology,
+            batch_size=batch_size,
+            estimator=[
+                EmbeddingPerfEstimator(topology=topology),
+                EmbeddingStorageEstimator(topology=topology),
+            ],
+        ),
+    )
+
+    mi.tables = [
+        EmbeddingBagConfig(
+            num_embeddings=num_embeddings,
+            embedding_dim=emb_dim,
+            name="table_" + str(i),
+            feature_names=["feature_" + str(i)],
+        )
+        for i in range(mi.num_features)
+    ]
+
+    mi.weighted_tables = [
+        EmbeddingBagConfig(
+            num_embeddings=num_embeddings,
+            embedding_dim=emb_dim,
+            name="weighted_table_" + str(i),
+            feature_names=["weighted_feature_" + str(i)],
+        )
+        for i in range(mi.num_weighted_features)
+    ]
+
+    mi.model = _gen_model(_ModelType.EBC, mi)
+    mi.model.training = True
+
+    model = mi.model
+
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(world_size, device.type),
+        constraints=None,
+    )
+
+    sharders = [
+        EBCSharderFixedShardingType(sharding_type),
+        ECSharderFixedShardingType(sharding_type),
+    ]
+
+    plan: ShardingPlan = planner.plan(model, sharders)  # pyre-ignore
+
+    def _dmp(m: torch.nn.Module) -> DistributedModelParallel:  # pyre-ignore
+        return DistributedModelParallel(
+            m,
+            env=ShardingEnv(world_size, rank, pg),
+            plan=plan,
+            sharders=sharders,
+            device=device,
+            init_data_parallel=False,
+        )
+
+    dmp = _dmp(model)
+    dmp_compile = _dmp(model)
+
+    # TODO: Fix some data dependent failures on subsequent inputs
+    n_extra_numerics_checks = config.n_extra_numerics_checks_inputs
+    ins = []
+
+    for _ in range(1 + n_extra_numerics_checks):
+        if input_type == _InputType.VARIABLE_BATCH:
+            (
+                _,
+                local_model_inputs,
+            ) = ModelInput.generate_variable_batch_input(
+                average_batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=num_float_features,
+                # pyre-ignore
+                tables=mi.tables,
+            )
+        else:
+            (
+                _,
+                local_model_inputs,
+            ) = ModelInput.generate(
+                batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=num_float_features,
+                tables=mi.tables,
+                weighted_tables=mi.weighted_tables,
+                variable_batch_size=False,
+            )
+        ins.append(local_model_inputs)
+
+    local_model_input = ins[0][rank].to(device)
+
+    kjt = local_model_input.idlist_features
+    ff = local_model_input.float_features
+    ff.requires_grad = True
+    kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True)
+
+    compile_input_ff = ff.clone().detach()
+    compile_input_ff.requires_grad = True
+
+    torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
+
+    dmp.train(True)
+    dmp_compile.train(True)
+
+    def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
+        tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module
+        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+        return tbe.weights_dev.clone().detach()
+
+    original_weights = get_weights(dmp)
+    original_weights.zero_()
+    original_compile_weights = get_weights(dmp_compile)
+    original_compile_weights.zero_()
+
+    eager_out = dmp(kjt_ft, ff)
+    reduce_to_scalar_loss(eager_out).backward()
+
+    if torch_compile_backend is None:
+        return
+
+    ##### COMPILE #####
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
+        torch._dynamo.config.capture_scalar_outputs = True
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True
+
+        opt_fn = torch.compile(
+            dmp_compile,
+            backend=torch_compile_backend,
+            fullgraph=True,
+        )
+        compile_out = opt_fn(
+            kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff
+        )
+        torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+    ##### COMPILE END #####
+
+
 class TestPt2Train(MultiProcessTestBase):
     def disable_cuda_tf32(self) -> bool:
         return True
@@ -580,3 +760,17 @@ def test_compile_multiprocess(
             config=config,
             torch_compile_backend=compile_backend,
         )
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() < 1,
+        "Not enough GPUs, this test requires one GPU",
+    )
+    @settings(verbosity=Verbosity.verbose, deadline=None)
+    def test_compile_multiprocess_fake_pg(
+        self,
+    ) -> None:
+        _test_compile_fake_pg_fn(
+            rank=0,
+            world_size=2,
+        )
