Commit d11d3dd

Rodrigo Kumpera authored and pytorchmergebot committed
[dist.cp] Introduce LoadPlanner and SavePlanner extensibility API. (pytorch#83419)
The planners come with default implementations in default_planner.py. The default planners expose their core functionality as standalone functions so that other checkpoint implementations can reuse it.

Pull Request resolved: pytorch#83419
Approved by: https://github.com/wanchaol
1 parent 4a033be commit d11d3dd

File tree: 8 files changed, +1026 −3 lines

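As orientation for the diffs below, here is a minimal sketch (not part of the commit) of driving the standalone default-planner functions end to end in a single process. It uses a plain-tensor state_dict so no process group is needed; the signatures are taken from the tests in this PR.

# Illustrative sketch only, assuming the signatures exercised by the tests
# below; this code is not part of the commit.
import torch
from torch.distributed._shard.checkpoint.default_planner import (
    create_default_global_save_plan,
    create_default_local_save_plan,
    create_default_local_load_plan,
    _create_default_local_metadata,
)

state_dict = {"weight": torch.rand(10), "extra": [1, 2, 3]}

# Save side: every rank builds a local plan (the flag marks the coordinator,
# which also plans replicated items), then one global pass merges the local
# plans and produces the checkpoint metadata.
local_plan = create_default_local_save_plan(state_dict, True)
final_plans, metadata = create_default_global_save_plan(all_plans=[local_plan])

# Load side: a local load plan is computed against checkpoint metadata; here
# the metadata is derived from the same state_dict with the helper the tests
# below also use.
load_plan = create_default_local_load_plan(state_dict, _create_default_local_metadata(state_dict))
for read_item in load_plan.items:
    print(read_item.type, read_item.dest_index.fqn)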

Diff for: test/distributed/_shard/checkpoint/test_planner.py (new file, +268)

@@ -0,0 +1,268 @@
# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._shard.checkpoint.planner import LoadItemType, WriteItemType

from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardMetadata,
    ShardedTensor,
    ShardedTensorMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

from torch.testing._internal.common_utils import (
    TestCase,
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.distributed._shard.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata
from torch.testing._internal.distributed.distributed_utils import (
    with_fake_comms,
    with_dist
)

from torch.distributed._shard.checkpoint.default_planner import (
    create_default_global_save_plan,
    create_default_local_save_plan,
    create_default_local_load_plan,
    _create_default_local_metadata
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu")
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank
            )
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards,
        sharded_tensor_metadata=sharded_tensor_md
    )


class TestSavePlan(TestCase):
    @with_fake_comms(rank=1, world_size=4)
    def test_local_plan(self):
        tensor = torch.rand(10)
        val = [1, 2, 3]
        st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
        state_dict = {
            "tensor": tensor,
            "value": val,
            "st": st
        }
        plan = create_default_local_save_plan(state_dict, False)
        self.assertEqual(1, len(plan.items))
        wi = plan.items[0]
        self.assertEqual(wi.index, MetadataIndex("st", [8]))
        self.assertEqual(wi.type, WriteItemType.SHARD)
        self.assertEqual(wi.tensor_data.size, st.size())
        self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([8]))
        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([8]))

        # On the coordinator rank, the plan should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(tensor_wi.tensor_data.properties, TensorProperties.create_from_tensor(tensor))
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        def create_data(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {
                    "tensor": tensor,
                    "value": val,
                    "st": st
                }
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)
        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(item_md.properties, old_item.tensor_data.properties)

                    self.assertIsNotNone(new_item.index.index)
                    # Make sure the hint is correct
                    self.assertEqual(item_md.chunks[new_item.index.index], old_item.tensor_data.chunk)

    def test_local_load_plan(self):
        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {
                    "tensor": tensor,
                    "value": val,
                    "st": st
                }

        state_dict = create_state_dict(1)
        metadata = _create_default_local_metadata(state_dict)

        load_plan = create_default_local_load_plan(state_dict, metadata)
        # This creates 3 entries
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "tensor")
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # This is an exact copy
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))

    def test_load_with_resharding(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

        # Rank 1 has a 16-element shard covering [16, 32)
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # Rank 1 has a 32-element shard covering [32, 64)
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario: going from world=8 to world=4. Each 4-world shard
        # has 32 elements, so it must load 2 of the stored 16-element shards.
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
        high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16]))

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario: going from world=4 to world=8, which loads half of
        # one shard. Rank 1 in the 8-world needs the upper half of rank 0's
        # 4-world shard.
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }

        # Rank 1 has a 30-element shard covering [30, 60)
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # Rank 1 has a 40-element shard covering [40, 80)
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # reads the stored [30, 60) shard to fill [40, 60)
        low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
        # reads the stored [60, 90) shard to fill [60, 80)
        high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20]))

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))


if __name__ == "__main__":
    run_tests()
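
The two resharding tests above reduce to intersecting half-open intervals: the destination shard against each chunk recorded in the metadata. As a standalone illustration of that arithmetic (plain Python; intersect is a hypothetical helper written for exposition, not the planner's actual code), the world=8 to world=4 case works out to the two reads asserted in test_load_with_resharding:

# Illustrative only: recompute the overlaps asserted above. This mirrors the
# planner's interval math; it is not the planner implementation.
def intersect(dest, stored):
    # dest/stored are half-open [start, end) intervals in the global tensor.
    lo = max(dest[0], stored[0])
    hi = min(dest[1], stored[1])
    return (lo, hi) if lo < hi else None

# world=8 metadata: eight 16-element chunks; rank 1 in world=4 owns [32, 64).
stored_chunks = [(i * 16, (i + 1) * 16) for i in range(8)]
dest = (32, 64)

for chunk in stored_chunks:
    overlap = intersect(dest, chunk)
    if overlap is None:
        continue
    lo, hi = overlap
    print(
        f"read chunk {chunk}: storage_offset={lo - chunk[0]}, "
        f"dest_offset={lo - dest[0]}, length={hi - lo}"
    )
# -> read chunk (32, 48): storage_offset=0, dest_offset=0, length=16
# -> read chunk (48, 64): storage_offset=0, dest_offset=16, length=16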

Diff for: test/distributed/_shard/checkpoint/test_utils.py (+2 −2)

@@ -81,6 +81,8 @@ def test_flat_data(self):
 
         a = find_state_dict_object(state_dict, MetadataIndex("a"))
         self.assertEqual(a, state_dict["a"])
+        a = find_state_dict_object(state_dict, MetadataIndex("a", [0]))
+        self.assertEqual(a, state_dict["a"])
         a = find_state_dict_object(state_dict, MetadataIndex("a", index=99))
         self.assertEqual(a, state_dict["a"])
 
@@ -91,8 +93,6 @@ def test_flat_data(self):
 
         with self.assertRaisesRegex(ValueError, "FQN"):
             find_state_dict_object(state_dict, MetadataIndex("c"))
-        with self.assertRaisesRegex(ValueError, "ShardedTensor"):
-            find_state_dict_object(state_dict, MetadataIndex("a", [0]))
         with self.assertRaisesRegex(ValueError, "ShardedTensor"):
             find_state_dict_object(state_dict, MetadataIndex("b", [1]))
 
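
The moved assertion above reflects a small behavior change: an offset hint on a non-ShardedTensor entry is now ignored instead of raising ValueError. A hedged sketch of the new behavior (the utils module path is an assumption inferred from the test file's location, not shown in this diff):

# Sketch of the behavior this hunk tests.
import torch
from torch.distributed._shard.checkpoint.metadata import MetadataIndex
from torch.distributed._shard.checkpoint.utils import find_state_dict_object

state_dict = {"a": torch.rand(10)}
# Previously this raised ValueError mentioning ShardedTensor; it now resolves
# to the tensor, with the offset hint ignored.
obj = find_state_dict_object(state_dict, MetadataIndex("a", [0]))
assert torch.equal(obj, state_dict["a"])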

Diff for: torch/distributed/_shard/checkpoint/__init__.py (+10)

@@ -13,3 +13,13 @@
 from .storage import StorageReader, StorageWriter
 from .filesystem import FileSystemReader, FileSystemWriter
 from .api import CheckpointException
+
+
+from .planner import (
+    SavePlanner,
+    LoadPlanner,
+    SavePlan,
+    LoadPlan,
+    ReadItem,
+    WriteItem,
+)
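
With these re-exports, the planner surface becomes importable from the package root alongside the existing storage types; for example:

# Expected import surface after this change (illustrative usage, not from the diff):
from torch.distributed._shard.checkpoint import (
    SavePlanner,
    LoadPlanner,
    SavePlan,
    LoadPlan,
    ReadItem,
    WriteItem,
)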
