[TorchRec] Add support for FakeProcessGroup for EBC #2228

Closed · wants to merge 1 commit
27 changes: 19 additions & 8 deletions torchrec/distributed/comm_ops.py
@@ -559,14 +559,25 @@ def variable_batch_all2all_pooled_sync(
]

with record_function("## alltoall_fwd_single ##"):
sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
sharded_input_embeddings,
output_split_sizes,
input_split_sizes,
pg_name(pg),
pg.size(),
get_gradient_division(),
)
if pg._get_backend_name() == "fake":
sharded_output_embeddings = torch.empty(
sum(output_split_sizes),
device=sharded_input_embeddings.device,
dtype=sharded_input_embeddings.dtype,
)
s0 = sharded_output_embeddings.size(0)
# Assumes (not generally valid) that this rank's input is at least as large as its output
torch._check(s0 <= sharded_input_embeddings.size(0))
sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
else:
sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
sharded_input_embeddings,
output_split_sizes,
input_split_sizes,
pg_name(pg),
pg.size(),
get_gradient_division(),
)

if a2ai.codecs is not None:
codecs = none_throws(a2ai.codecs)
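The branch above bypasses the real collective whenever the process group reports the "fake" backend. For reference, a minimal sketch of how such a fake process group is created and detected, following the setup used in the new test at the bottom of this PR (not part of the diff itself):

import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# Single-process "fake" group: collectives return without real communication,
# which is what makes tracing/compiling a sharded module on one device possible.
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
pg = dist.distributed_c10d._get_default_group()
assert pg._get_backend_name() == "fake"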
54 changes: 42 additions & 12 deletions torchrec/distributed/dist_data.py
@@ -239,12 +239,25 @@ def __init__(
# https://github.com/pytorch/pytorch/issues/122788
with record_function("## all2all_data:kjt splits ##"):
input_tensor = torch.stack(input_tensors, dim=1).flatten()
self._output_tensor = dist._functional_collectives.all_to_all_single(
input_tensor,
output_split_sizes=None,
input_split_sizes=None,
group=pg,
)
if pg._get_backend_name() == "fake":
self._output_tensor = torch.empty(
[self.num_workers * len(input_tensors)],
device=input_tensors[0].device,
dtype=input_tensors[0].dtype,
)

self._output_tensor = input_tensor[
: input_tensor.size(0) // 2
].repeat(2)
else:
self._output_tensor = (
dist._functional_collectives.all_to_all_single(
input_tensor,
output_split_sizes=None,
input_split_sizes=None,
group=pg,
)
)
# To avoid hasattr in _wait_impl to check self._splits_awaitable
# pyre-ignore
self._splits_awaitable = None
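Here the fake path fabricates the splits all-to-all result from local data: it takes the first half of the stacked input and repeats it twice. The hard-coded 2 assumes a world size of two, as in the new test. A small illustration of the shape bookkeeping, not taken from the diff:

import torch

# e.g. 3 per-KJT split values stacked for 2 workers -> 2 * 3 = 6 elements
input_tensor = torch.arange(6)
fake_output = input_tensor[: input_tensor.size(0) // 2].repeat(2)
# Length matches the expected num_workers * len(input_tensors) output.
assert fake_output.numel() == input_tensor.numel()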
@@ -342,6 +355,7 @@ def __init__(
self._output_tensors: List[torch.Tensor] = []
self._awaitables: List[dist.Work] = []
self._world_size: int = self._pg.size()
rank = dist.get_rank(self._pg)

for input_split, output_split, input_tensor, label in zip(
input_splits,
@@ -353,12 +367,28 @@
# TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
# https://github.com/pytorch/pytorch/issues/122788
with record_function(f"## all2all_data:kjt {label} ##"):
output_tensor = dist._functional_collectives.all_to_all_single(
input_tensor,
output_split,
input_split,
pg,
)
if self._pg._get_backend_name() == "fake":
output_tensor = torch.empty(
sum(output_split),
device=self._device,
dtype=input_tensor.dtype,
)
_l = sum(output_split[:rank])
_r = _l + output_split[rank]
torch._check(_r < input_tensor.size(0))
torch._check(_l < input_tensor.size(0))
torch._check(_l <= _r)
torch._check(2 * (_r - _l) == output_tensor.size(0))
output_tensor.copy_(
input_tensor[_l:_r].repeat(self._world_size)
)
else:
output_tensor = dist._functional_collectives.all_to_all_single(
input_tensor,
output_split,
input_split,
pg,
)
self._output_tensors.append(output_tensor)
else:
output_tensor = torch.empty(
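In this second fake branch, each rank fills its output with its own slice of the input (bounded by output_split) tiled world_size times; the torch._check calls keep the unbacked sizes consistent, and the `2 *` again assumes two ranks. A standalone illustration with hypothetical split sizes, not taken from the diff:

import torch

world_size, rank = 2, 0
output_split = [3, 3]  # the fake path effectively assumes equal per-rank splits
input_tensor = torch.arange(8)
_l = sum(output_split[:rank])  # start of this rank's slice: 0
_r = _l + output_split[rank]   # end of this rank's slice: 3
output_tensor = torch.empty(sum(output_split), dtype=input_tensor.dtype)
output_tensor.copy_(input_tensor[_l:_r].repeat(world_size))
assert output_tensor.numel() == world_size * (_r - _l)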
4 changes: 3 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -657,7 +657,9 @@ def __init__(
broadcast_buffers=True,
static_graph=True,
)
self._initialize_torch_state()

if env.process_group and dist.get_backend(env.process_group) != "fake":
self._initialize_torch_state()

# TODO[zainhuda]: support module device coming from CPU
if module.device not in ["meta", "cpu"] and module.device.type not in [
194 changes: 194 additions & 0 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -25,6 +25,8 @@
from hypothesis import given, settings, strategies as st, Verbosity
from torch import distributed as dist
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed import ProcessGroup
from torch.testing._internal.distributed.fake_pg import FakeStore
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -499,6 +501,184 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
##### NUMERIC CHECK END #####


def _test_compile_fake_pg_fn(
rank: int,
world_size: int,
) -> None:
sharding_type = ShardingType.TABLE_WISE.value
input_type = _InputType.SINGLE_BATCH
torch_compile_backend = "eager"
config = _TestConfig()
num_embeddings = 256
# emb_dim must be divisible by 4 for fbgemm
emb_dim = 12
batch_size = 10
num_features: int = 5

num_float_features: int = 8
num_weighted_features: int = 1

device: torch.device = torch.device("cuda")
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
pg: ProcessGroup = dist.distributed_c10d._get_default_group()

topology: Topology = Topology(world_size=world_size, compute_device="cuda")
mi = TestModelInfo(
dense_device=device,
sparse_device=device,
num_features=num_features,
num_float_features=num_float_features,
num_weighted_features=num_weighted_features,
topology=topology,
)

mi.planner = EmbeddingShardingPlanner(
topology=topology,
batch_size=batch_size,
enumerator=EmbeddingEnumerator(
topology=topology,
batch_size=batch_size,
estimator=[
EmbeddingPerfEstimator(topology=topology),
EmbeddingStorageEstimator(topology=topology),
],
),
)

mi.tables = [
EmbeddingBagConfig(
num_embeddings=num_embeddings,
embedding_dim=emb_dim,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(mi.num_features)
]

mi.weighted_tables = [
EmbeddingBagConfig(
num_embeddings=num_embeddings,
embedding_dim=emb_dim,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(mi.num_weighted_features)
]

mi.model = _gen_model(_ModelType.EBC, mi)
mi.model.training = True

model = mi.model

planner = EmbeddingShardingPlanner(
topology=Topology(world_size, device.type),
constraints=None,
)

sharders = [
EBCSharderFixedShardingType(sharding_type),
ECSharderFixedShardingType(sharding_type),
]

plan: ShardingPlan = planner.plan(model, sharders) # pyre-ignore

def _dmp(m: torch.nn.Module) -> DistributedModelParallel: # pyre-ignore
return DistributedModelParallel(
m,
env=ShardingEnv(world_size, rank, pg),
plan=plan,
sharders=sharders,
device=device,
init_data_parallel=False,
)

dmp = _dmp(model)
dmp_compile = _dmp(model)

# TODO: Fix some data dependent failures on subsequent inputs
n_extra_numerics_checks = config.n_extra_numerics_checks_inputs
ins = []

for _ in range(1 + n_extra_numerics_checks):
if input_type == _InputType.VARIABLE_BATCH:
(
_,
local_model_inputs,
) = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
# pyre-ignore
tables=mi.tables,
)
else:
(
_,
local_model_inputs,
) = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
tables=mi.tables,
weighted_tables=mi.weighted_tables,
variable_batch_size=False,
)
ins.append(local_model_inputs)

local_model_input = ins[0][rank].to(device)

kjt = local_model_input.idlist_features
ff = local_model_input.float_features
ff.requires_grad = True
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True)

compile_input_ff = ff.clone().detach()
compile_input_ff.requires_grad = True

torchrec.distributed.comm_ops.set_use_sync_collectives(True)
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

dmp.train(True)
dmp_compile.train(True)

def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module
assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
return tbe.weights_dev.clone().detach()

original_weights = get_weights(dmp)
original_weights.zero_()
original_compile_weights = get_weights(dmp_compile)
original_compile_weights.zero_()

eager_out = dmp(kjt_ft, ff)
reduce_to_scalar_loss(eager_out).backward()

if torch_compile_backend is None:
return

##### COMPILE #####
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True

opt_fn = torch.compile(
dmp_compile,
backend=torch_compile_backend,
fullgraph=True,
)
compile_out = opt_fn(
kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff
)
torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
##### COMPILE END #####


class TestPt2Train(MultiProcessTestBase):
def disable_cuda_tf32(self) -> bool:
return True
@@ -580,3 +760,17 @@ def test_compile_multiprocess(
config=config,
torch_compile_backend=compile_backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() < 1,
"Not enough GPUs, this test requires one GPU",
)
@settings(verbosity=Verbosity.verbose, deadline=None)
def test_compile_multiprocess_fake_pg(
self,
) -> None:
_test_compile_fake_pg_fn(
rank=0,
world_size=2,
)
Loading