Skip to content

Commit

Permalink
Merge pull request #16 from OSU-Nowlab/refactor_SP
Browse files Browse the repository at this point in the history
Refactor SP benchmarks
  • Loading branch information
Quentin-Anthony authored Oct 25, 2023
2 parents ca0a066 + e7699d5 commit c671670
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 453 deletions.
235 changes: 11 additions & 224 deletions benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import logging
from torchgems import parser
from torchgems.mp_pipeline import model_generator
from torchgems.train_spatial import train_model_spatial
from torchgems.train_spatial import train_model_spatial, split_input, get_shapes_spatial
import torchgems.comm as gems_comm

parser_obj = parser.get_parser()
Expand Down Expand Up @@ -205,178 +205,13 @@ def verify_config():

# Get the shape of model on each split rank for image_size and number of spatial parts
# NOTE(review): this view shows no indentation, so the nesting of the statements
# below (in particular where `temp_count += 1` sits relative to the loops) cannot
# be confirmed from here -- check against the properly indented file.
# Integer scale factor between the target image size and the size used for the
# sequential reference model (image_size_seq); fractional ratios are truncated.
image_size_times = int(image_size / image_size_seq)
# Counter compared against spatial_size to decide whether a shape is reduced, and
# used to index num_spatial_parts_list in the vertical/horizontal branches.
temp_count = 0
# "square": dims 2 and 3 of every reduced shape are scaled by image_size_times and
# divided by int(math.sqrt(spatial_part_size)).
if args.slice_method == "square":
amoebanet_shapes_list = []
for output_shape in model_gen_seq.shape_list:
# An entry may be a list of shape tuples; each tuple is converted separately.
if isinstance(output_shape, list):
temp_shape = []
for shape_tuple in output_shape:
if temp_count < spatial_size:
# reduce shape only when it is smaller than spatial size
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(
shape_tuple[2]
* image_size_times
/ int(math.sqrt(spatial_part_size))
),
int(
shape_tuple[3]
* image_size_times
/ int(math.sqrt(spatial_part_size))
),
)
temp_shape.append(x)
else:
# Not reduced: dims 2 and 3 are only scaled by image_size_times.
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(shape_tuple[2] * image_size_times),
int(shape_tuple[3] * image_size_times),
)
temp_shape.append(x)
amoebanet_shapes_list.append(temp_shape)
else:
# Two-element shapes are copied through with only the first entry cast to int.
if len(output_shape) == 2:
x = (int(output_shape[0]), output_shape[1])
amoebanet_shapes_list.append(x)
else:
if temp_count < spatial_size:
x = (
int(output_shape[0]),
output_shape[1],
int(
output_shape[2]
* image_size_times
/ int(math.sqrt(spatial_part_size))
),
int(
output_shape[3]
* image_size_times
/ int(math.sqrt(spatial_part_size))
),
)
amoebanet_shapes_list.append(x)
else:
x = (
int(output_shape[0]),
output_shape[1],
int(output_shape[2] * image_size_times),
int(output_shape[3] * image_size_times),
)
amoebanet_shapes_list.append(x)
temp_count += 1

# "vertical": only dim 3 is divided, by num_spatial_parts_list[temp_count]; dim 2
# is divided by 1, i.e. only scaled by image_size_times.
elif args.slice_method == "vertical":
amoebanet_shapes_list = []
for output_shape in model_gen_seq.shape_list:
if isinstance(output_shape, list):
temp_shape = []
for shape_tuple in output_shape:
if temp_count < spatial_size:
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(shape_tuple[2] * image_size_times / 1),
int(
shape_tuple[3]
* image_size_times
/ num_spatial_parts_list[temp_count]
),
)
temp_shape.append(x)
else:
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(shape_tuple[2] * image_size_times),
int(shape_tuple[3] * image_size_times),
)
temp_shape.append(x)
amoebanet_shapes_list.append(temp_shape)
else:
if len(output_shape) == 2:
x = (int(output_shape[0]), output_shape[1])
amoebanet_shapes_list.append(x)
else:
if temp_count < spatial_size:
x = (
int(output_shape[0]),
output_shape[1],
int(output_shape[2] * image_size_times / 1),
int(
output_shape[3]
* image_size_times
/ num_spatial_parts_list[temp_count]
),
)
amoebanet_shapes_list.append(x)
else:
x = (
int(output_shape[0]),
output_shape[1],
int(output_shape[2] * image_size_times),
int(output_shape[3] * image_size_times),
)
amoebanet_shapes_list.append(x)
temp_count += 1


# "horizontal": mirror of the vertical branch -- only dim 2 is divided, by
# num_spatial_parts_list[temp_count]; dim 3 is divided by 1.
elif args.slice_method == "horizontal":
amoebanet_shapes_list = []
for output_shape in model_gen_seq.shape_list:
if isinstance(output_shape, list):
temp_shape = []
for shape_tuple in output_shape:
if temp_count < spatial_size:
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(
shape_tuple[2]
* image_size_times
/ num_spatial_parts_list[temp_count]
),
int(shape_tuple[3] * image_size_times / 1),
)
temp_shape.append(x)
else:
x = (
int(shape_tuple[0]),
shape_tuple[1],
int(shape_tuple[2] * image_size_times),
int(shape_tuple[3] * image_size_times),
)
temp_shape.append(x)
amoebanet_shapes_list.append(temp_shape)
else:
if len(output_shape) == 2:
x = (int(output_shape[0]), output_shape[1])
amoebanet_shapes_list.append(x)
else:
if temp_count < spatial_size:
x = (
int(output_shape[0]),
output_shape[1],
int(
output_shape[2]
* image_size_times
/ num_spatial_parts_list[temp_count]
),
int(output_shape[3] * image_size_times / 1),
)
amoebanet_shapes_list.append(x)
else:
x = (
int(output_shape[0]),
output_shape[1],
int(output_shape[2] * image_size_times),
int(output_shape[3] * image_size_times),
)
amoebanet_shapes_list.append(x)
temp_count += 1
# amoebanet_shapes_list is (re)assigned from the imported get_shapes_spatial helper,
# which receives the same inputs the inline branches above read.
# NOTE(review): this looks like a diff view in which this call replaces the inline
# computation above rather than following it -- confirm against the real file.
amoebanet_shapes_list = get_shapes_spatial(
model_gen_seq.shape_list,
args.slice_method,
spatial_size,
num_spatial_parts_list,
image_size_times,
)

del model_seq
del model_gen_seq
Expand Down Expand Up @@ -507,56 +342,6 @@ def verify_config():

################################################################################


def split_input(inputs):
    """Return the spatial slice of ``inputs`` owned by the current rank.

    The slicing scheme is chosen by the module-level ``args.slice_method``;
    the geometry comes from the module-level ``image_size``,
    ``spatial_part_size`` and ``local_rank``.

    * ``"square"``: ranks are laid out row-major on a
      ``sqrt(spatial_part_size)`` x ``sqrt(spatial_part_size)`` grid and each
      rank gets one tile cut from dimensions 2 and 3.
    * ``"vertical"``: dimension 3 is cut into ``spatial_part_size`` strips.
    * ``"horizontal"``: dimension 2 is cut into ``spatial_part_size`` strips.

    For the strip methods the last rank takes everything from its start
    offset to the end, so it absorbs any remainder when ``image_size`` is not
    evenly divisible.

    Args:
        inputs: Object indexable with four indices; assumed to be a
            (batch, channels, height, width) tensor -- TODO confirm with the
            data loader.

    Returns:
        The slice of ``inputs`` for this rank.

    Raises:
        ValueError: If ``args.slice_method`` is not one of the supported
            methods (previously this case silently returned ``None``).
    """
    method = args.slice_method

    if method == "square":
        # Side length of the rank grid and of each square tile.
        grid_dim = int(math.sqrt(spatial_part_size))
        tile_size = int(image_size / math.sqrt(spatial_part_size))

        # Row-major position of this rank on the grid (integer arithmetic
        # instead of float division followed by truncation).
        row, col = divmod(local_rank, grid_dim)

        start_top = row * tile_size
        start_left = col * tile_size
        return inputs[
            :,
            :,
            start_top : start_top + tile_size,
            start_left : start_left + tile_size,
        ]

    if method in ("vertical", "horizontal"):
        strip_size = int(image_size / spatial_part_size)
        start = local_rank * strip_size
        # In case of odd GPU count, partition size will be uneven and the
        # last rank receives the remaining part of the image.
        if local_rank == spatial_part_size - 1:
            span = slice(start, None)
        else:
            span = slice(start, start + strip_size)

        if method == "vertical":
            return inputs[:, :, :, span]
        return inputs[:, :, span, :]

    raise ValueError(f"Unsupported slice_method: {method!r}")


################################# Train Model ##################################

perf = []
Expand All @@ -577,7 +362,9 @@ def run_epoch():
inputs, labels = data

if local_rank < spatial_part_size:
x = split_input(inputs)
x = split_input(
inputs, args.slice_method, image_size, spatial_part_size, local_rank
)
else:
x = inputs

Expand Down
Loading

0 comments on commit c671670

Please sign in to comment.