
Commit 0a04f57

henrylhtsang authored and facebook-github-bot committed
Add test_load_state_dict for FP-ebc (#1709)
Summary: Add test_load_state_dict for FP-ebc. Differential Revision: D53839126
1 parent 0eca3ca commit 0a04f57

3 files changed: 326 additions, 87 deletions
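At its core, the new test exercises a standard state_dict round trip: build two identically configured sharded models, load the first model's state_dict into the second, and check that both produce the same losses and predictions on the same batch. A minimal sketch of that pattern, using plain nn.Linear stand-ins rather than the DMP-wrapped FP-EBC models from the diff below (m1, m2, and batch here are hypothetical placeholders):

import torch
import torch.nn as nn

# Two independently initialized modules with the same architecture.
m1 = nn.Linear(4, 2)
m2 = nn.Linear(4, 2)

# Copy m1's parameters into m2 via its state_dict.
m2.load_state_dict(m1.state_dict())

# After loading, both modules must agree on any input batch.
batch = torch.randn(3, 4)
with torch.no_grad():
    assert torch.equal(m1(batch), m2(batch))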


torchrec/distributed/tests/test_fp_embeddingbag.py

Lines changed: 1 addition & 86 deletions
@@ -36,6 +36,7 @@
 from torchrec.distributed.test_utils.test_sharding import copy_state_dict
 from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
     create_module_and_freeze,
+    get_configs_and_kjt_inputs,
 )
 from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -209,92 +210,6 @@ def _test_sharding_from_meta(  # noqa C901
         torch.testing.assert_close(param, torch.ones_like(param))


-def get_configs_and_kjt_inputs() -> Tuple[
-    List[EmbeddingBagConfig], List[KeyedJaggedTensor]
-]:
-    embedding_bag_config = [
-        EmbeddingBagConfig(
-            name="table_0",
-            feature_names=["feature_0"],
-            embedding_dim=3 * 16,
-            num_embeddings=16,
-        ),
-        EmbeddingBagConfig(
-            name="table_1",
-            feature_names=["feature_1"],
-            embedding_dim=8,
-            num_embeddings=16,
-        ),
-        EmbeddingBagConfig(
-            name="table_2",
-            feature_names=["feature_2"],
-            embedding_dim=8,
-            num_embeddings=16,
-        ),
-        EmbeddingBagConfig(
-            name="table_3",
-            feature_names=["feature_3"],
-            embedding_dim=3 * 16,
-            num_embeddings=16,
-        ),
-    ]
-
-    # Rank 0
-    #               instance 0    instance 1    instance 2
-    # "feature_0"   [0, 1]        None          [2]
-    # "feature_1"   [0, 1]        None          [2]
-    # "feature_2"   [3, 1]        [4, 1]        [5]
-    # "feature_3"   [1]           [6, 1, 8]     [0, 3, 3]
-
-    # Rank 1
-
-    #               instance 0    instance 1    instance 2
-    # "feature_0"   [3, 2]        [1, 2]        [0, 1, 2, 3]
-    # "feature_1"   [2, 3]        None          [2]
-    # "feature_2"   [2, 7]        [1, 8, 2]     [8, 1]
-    # "feature_3"   [9]           [8]           [7]
-
-    kjt_input_per_rank = [  # noqa
-        KeyedJaggedTensor.from_lengths_sync(
-            keys=["feature_0", "feature_1", "feature_2", "feature_3"],
-            values=torch.LongTensor(
-                [0, 1, 2, 0, 1, 2, 3, 1, 4, 1, 5, 1, 6, 1, 8, 0, 3, 3]
-            ),
-            lengths=torch.LongTensor(
-                [
-                    2,
-                    0,
-                    1,
-                    2,
-                    0,
-                    1,
-                    2,
-                    2,
-                    1,
-                    1,
-                    3,
-                    3,
-                ]
-            ),
-            weights=torch.FloatTensor(
-                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
-            ),
-        ),
-        KeyedJaggedTensor.from_lengths_sync(
-            keys=["feature_0", "feature_1", "feature_2", "feature_3"],
-            values=torch.LongTensor(
-                [3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2, 2, 7, 1, 8, 2, 8, 1, 9, 8, 7]
-            ),
-            lengths=torch.LongTensor([2, 2, 4, 2, 0, 1, 2, 3, 2, 1, 1, 1]),
-            weights=torch.FloatTensor(
-                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
-            ),
-        ),
-    ]
-
-    return embedding_bag_config, kjt_input_per_rank
-
-
 @skip_if_asan_class
 class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     @unittest.skipIf(
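As an aside, the Rank 0 / Rank 1 comment tables in the relocated helper describe the layout expected by KeyedJaggedTensor.from_lengths_sync: lengths are laid out key-major (all batch instances of "feature_0", then "feature_1", and so on), and values are concatenated in that same order. A small illustrative sketch with made-up features and values, not the rank-0 batch above:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

#               instance 0    instance 1
# "feature_a"   [10, 11]      None
# "feature_b"   [20]          [21, 22]
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_a", "feature_b"],
    values=torch.LongTensor([10, 11, 20, 21, 22]),
    lengths=torch.LongTensor([2, 0, 1, 2]),  # feature_a per instance, then feature_b
)

# Per-key views recover the jagged layout shown in the comment table.
fa, fb = kjt["feature_a"], kjt["feature_b"]
print(fa.values(), fa.lengths())  # tensor([10, 11]) tensor([2, 0])
print(fb.values(), fb.lengths())  # tensor([20, 21, 22]) tensor([1, 2])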
Lines changed: 191 additions & 0 deletions (new file)
@@ -0,0 +1,191 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from typing import cast, Dict, List, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn
from hypothesis import given, settings, strategies as st, Verbosity
from torch import distributed as dist
from torchrec import distributed as trec_dist
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.model_parallel import get_default_sharders
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
    create_module_and_freeze,
    get_configs_and_kjt_inputs,
    TestFPEBCSharder,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardedTensor,
    ShardingEnv,
    ShardingType,
)
from torchrec.test_utils import get_free_port


class FPModelParallelStateDictTest(unittest.TestCase):
    def setUp(self) -> None:
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = str("localhost")
        os.environ["MASTER_PORT"] = str(get_free_port())

        self.backend = "nccl"
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        dist.init_process_group(backend=self.backend)

        self.tables, self.kjt_input_per_rank = get_configs_and_kjt_inputs()

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def _generate_dmps_and_batch(
        self,
        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
        constraints: Optional[Dict[str, trec_dist.planner.ParameterConstraints]] = None,
        use_fp_collection: bool = True,
    ) -> Tuple[List[DistributedModelParallel], ModelInput]:
        """
        Generate two DMPs based on Sequence Sparse NN and one batch of data.
        """
        if constraints is None:
            constraints = {}
        if sharders is None:
            sharders = get_default_sharders()

        batch = self.kjt_input_per_rank[0].to(self.device)

        dmps = []
        pg = dist.GroupMember.WORLD
        assert pg is not None, "Process group is not initialized"
        env = ShardingEnv.from_process_group(pg)

        planner = EmbeddingShardingPlanner(
            topology=Topology(
                local_world_size=trec_dist.comm.get_local_size(env.world_size),
                world_size=env.world_size,
                compute_device=self.device.type,
            ),
            constraints=constraints,
        )

        for _ in range(2):
            # Create two TestSparseNN modules, wrap both in DMP
            m = create_module_and_freeze(
                tables=self.tables,
                use_fp_collection=use_fp_collection,
                device=torch.device("meta"),
            )
            if pg is not None:
                plan = planner.collective_plan(m, sharders, pg)
            else:
                plan = planner.plan(m, sharders)

            dmp = DistributedModelParallel(
                module=m,
                init_data_parallel=False,
                device=self.device,
                sharders=sharders,
                plan=plan,
            )

            with torch.no_grad():
                dmp(batch)
                dmp.init_data_parallel()
            dmps.append(dmp)
        return (dmps, batch)

    @unittest.skipIf(
        torch.cuda.device_count() <= 0,
        "Not enough GPUs, this test requires at least one GPU",
    )
    # pyre-ignore[56]
    @given(
        sharding_type=st.sampled_from(
            [
                ShardingType.TABLE_WISE.value,
                ShardingType.COLUMN_WISE.value,
                ShardingType.TABLE_COLUMN_WISE.value,
            ]
        ),
        kernel_type=st.sampled_from(
            [
                EmbeddingComputeKernel.FUSED.value,
                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
                EmbeddingComputeKernel.FUSED_UVM.value,
            ]
        ),
        is_training=st.booleans(),
        use_fp_collection=st.booleans(),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
    def test_load_state_dict(
        self,
        sharding_type: str,
        kernel_type: str,
        is_training: bool,
        use_fp_collection: bool,
    ) -> None:
        sharders = [
            cast(
                ModuleSharder[nn.Module],
                TestFPEBCSharder(
                    sharding_type=sharding_type,
                    kernel_type=kernel_type,
                ),
            ),
        ]
        models, batch = self._generate_dmps_and_batch(
            sharders=sharders, use_fp_collection=use_fp_collection
        )
        m1, m2 = models

        # load the second's (m2's) with the first (m1's) state_dict
        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))

        # validate the models are equivalent
        if is_training:
            for _ in range(2):
                loss1, pred1 = m1(batch)
                loss2, pred2 = m2(batch)
                loss1.backward()
                loss2.backward()
                self.assertTrue(torch.equal(loss1, loss2))
                self.assertTrue(torch.equal(pred1, pred2))
        else:
            with torch.no_grad():
                loss1, pred1 = m1(batch)
                loss2, pred2 = m2(batch)
                self.assertTrue(torch.equal(loss1, loss2))
                self.assertTrue(torch.equal(pred1, pred2))

        sd1 = m1.state_dict()
        for key, value in m2.state_dict().items():
            v2 = sd1[key]
            if isinstance(value, ShardedTensor):
                assert len(value.local_shards()) == 1
                dst = value.local_shards()[0].tensor
            else:
                dst = value
            if isinstance(v2, ShardedTensor):
                assert len(v2.local_shards()) == 1
                src = v2.local_shards()[0].tensor
            else:
                src = v2
            self.assertTrue(torch.equal(src, dst))
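One note on the final state_dict comparison: because setUp pins WORLD_SIZE to 1, every ShardedTensor parameter materializes as exactly one local shard on this rank, so the test can compare the underlying dense tensors directly. The same unwrapping logic, pulled out as a standalone helper purely for illustration (the dense_view name is hypothetical and not part of the commit):

from typing import Union

import torch
from torchrec.distributed.types import ShardedTensor

def dense_view(value: Union[torch.Tensor, ShardedTensor]) -> torch.Tensor:
    # Return the underlying tensor of a single-shard ShardedTensor,
    # or the value itself if it is already a plain tensor.
    if isinstance(value, ShardedTensor):
        shards = value.local_shards()
        assert len(shards) == 1, "expected exactly one local shard per rank"
        return shards[0].tensor
    return value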
