Add test for mixed sharding across multiple tables in SQEBC
Summary: Adds a test covering mixed sharding types (uneven and even row-wise, column-wise, and table-wise) across multiple tables in a sharded QuantEmbeddingBagCollection (SQEBC).

Reviewed By: jingsh

Differential Revision: D54313506

fbshipit-source-id: 4317d02ee46bd96ece37df53e979d5e8698a170c
gnahzg authored and facebook-github-bot committed Feb 29, 2024
1 parent e54532d commit 1f34283
torchrec/distributed/tests/test_infer_shardings.py: 81 additions & 0 deletions
@@ -30,7 +30,10 @@
)
from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.quant_embeddingbag import (
QuantEmbeddingBagCollection,
QuantEmbeddingBagCollectionSharder,
QuantFeatureProcessedEmbeddingBagCollectionSharder,
ShardedQuantEmbeddingBagCollection,
)
from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec
from torchrec.distributed.shard import _shard_modules
@@ -1138,6 +1141,84 @@ def test_rw_uneven_sharding_mutiple_table(
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 3,
"Not enough GPUs available",
)
# pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.sampled_from(...)` to decorator factory `hypothesis.given`.
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
device=st.sampled_from(["cuda"]), # TODO: add cpu test when it's fixed
)
@settings(max_examples=4, deadline=None)
def test_mix_sharding_multiple_table(
self,
weight_dtype: torch.dtype,
device: str,
) -> None:
num_embeddings = 512
emb_dim = 64
local_size = 4
world_size = 4
batch_size = 1
local_device = torch.device("cuda:0" if device == "cuda" else device)
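# Quantized test model with 4 feature tables, each 512 x 64, using a split scale/bias quantized state dict.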
mi = create_test_model(
num_embeddings,
emb_dim,
world_size,
batch_size,
dense_device=local_device,
sparse_device=local_device,
quant_state_dict_split_scale_bias=True,
weight_dtype=weight_dtype,
num_features=4,
)

non_sharded_model = mi.quant_model

sharder = QuantEmbeddingBagCollectionSharder()

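# Mixed sharding plan across the four tables:
#   table_0: row-wise with uneven shard sizes [256, 128, 64, 64]
#   table_1: row-wise with even shard sizes [128, 128, 128, 128]
#   table_2: column-wise across ranks 0 and 1
#   table_3: table-wise, placed entirely on rank 0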
module_plan = construct_module_sharding_plan(
non_sharded_model._module.sparse.ebc,
per_param_sharding={
"table_0": row_wise(
([256, 128, 64, 64], device),
),
"table_1": row_wise(([128, 128, 128, 128], device)),
"table_2": column_wise(ranks=[0, 1]),
"table_3": table_wise(rank=0),
},
# pyre-ignore
sharder=sharder,
local_size=local_size,
world_size=world_size,
)

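# Register the per-table plan under the fully qualified name of the EBC module inside the test model.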
plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan})

sharded_model = shard_qebc(
mi=mi,
sharding_type=ShardingType.ROW_WISE,
device=local_device,
expected_shards=None,  # shard layout is fully determined by the explicit plan above
plan=plan,
)
inputs = [
model_input_to_forward_args(inp.to(local_device))
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
]
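# Load weights from the non-sharded model so both models produce comparable outputs.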
sharded_model.load_state_dict(non_sharded_model.state_dict())

# We need this first inference to trigger all lazy initialization in forward
sharded_output = sharded_model(*inputs[0])
non_sharded_output = non_sharded_model(*inputs[0])
assert_close(non_sharded_output, sharded_output)

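# The sharded model should survive FX symbolic tracing and TorchScript scripting with unchanged outputs.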
gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
gm_script = torch.jit.script(gm)
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
