Refactor ht.concatenate() #1210

Draft
wants to merge 8 commits into base: main
361 changes: 137 additions & 224 deletions heat/core/manipulations.py
@@ -462,250 +462,163 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray:
for arr in arrays:
sanitation.sanitize_in(arr)

if not isinstance(axis, int):
raise TypeError(f"axis must be an integer, currently: {type(axis)}")
axis = stride_tricks.sanitize_axis(arrays[0].gshape, axis)

# a single array cannot be concatenated
if len(arrays) < 2:
raise ValueError("concatenate requires 2 arrays")
# concatenate multiple arrays
elif len(arrays) > 2:
res = concatenate((arrays[0], arrays[1]), axis=axis)
for a in range(2, len(arrays)):
res = concatenate((res, arrays[a]), axis=axis)
return res
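# Illustration (hypothetical shapes, not part of this diff): repeatedly
# concatenating pairs is equivalent to a single call on the whole sequence,
# e.g.
#   >>> a, b, c = ht.zeros((2, 3)), ht.ones((2, 3)), ht.zeros((2, 3))
#   >>> ht.concatenate((a, b, c), axis=0).shape
#   (6, 3)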

# unpack the arrays
arr0, arr1 = arrays
res = arrays[0]
best_dtype = arrays[0].dtype
output_split = arrays[0].split
arrays_copy = []

if not isinstance(axis, int):
raise TypeError(f"axis must be an integer, currently: {type(axis)}")
axis = stride_tricks.sanitize_axis(arr0.gshape, axis)
split_values = {arr.split for arr in arrays}
if len(split_values) > 1 and output_split is None:
output_split = next(s for s in split_values if s is not None)

if arr0.ndim != arr1.ndim:
raise ValueError("DNDarrays must have the same number of dimensions")
for i, arr in enumerate(arrays):
# different dimensions may not be concatenated
if res.ndim != arr.ndim:
raise ValueError("DNDarrays must have the same number of dimensions")

if any(arr0.gshape[i] != arr1.gshape[i] for i in range(len(arr0.gshape)) if i != axis):
raise ValueError(
f"Arrays cannot be concatenated, shapes must be the same in every axis except the selected axis: {arr0.gshape}, {arr1.gshape}"
)
# different communicators may not be concatenated
if res.comm != arr.comm:
raise RuntimeError("Communicators of passed arrays mismatch.")

# different global shapes may not be concatenated
if any(i != axis and res.gshape[i] != arr.gshape[i] for i in range(len(res.gshape))):
raise ValueError(
f"Arrays cannot be concatenated, shapes must be the same in every axis except the selected axis: "
f"{res.gshape}, {arr.gshape}"
)

# different communicators may not be concatenated
if arr0.comm != arr1.comm:
raise RuntimeError("Communicators of passed arrays mismatch.")
if output_split != arr.split and output_split is not None and arr.split is not None:
raise RuntimeError(
f"DNDarrays given have differing split axes, arr0 {res.split} arr{i} {arr.split}"
)

if not arr.is_distributed():
arrays_copy.append(
factories.array(
obj=arr,
split=output_split,
is_split=None,
copy=True,
device=arr.device,
comm=arr.comm,
)
)
else:
arrays_copy.append(arr.copy())

# identify common data type
best_dtype = types.promote_types(res.dtype, arr.dtype)
if res.dtype != best_dtype:
res = best_dtype(res, device=res.device)

# convert all arrays to best_dtype
def conversion_func(x):
return best_dtype(x, device=x.device)

# identify common data type
out_dtype = types.promote_types(arr0.dtype, arr1.dtype)
if arr0.dtype != out_dtype:
arr0 = out_dtype(arr0, device=arr0.device)
if arr1.dtype != out_dtype:
arr1 = out_dtype(arr1, device=arr1.device)
arrays_copy = list(map(conversion_func, arrays_copy))
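# Example of the promotion step (illustrative values, not from the diff):
#   >>> best = types.promote_types(types.float32, types.float64)  # -> float64
# every entry of arrays_copy is cast to this common type before any
# communication takes place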

s0, s1 = arr0.split, arr1.split
# no splits, local concat
if s0 is None and s1 is None:
res_gshape = list(arrays_copy[0].gshape)
res_gshape[axis] = sum(arr.gshape[axis] for arr in arrays_copy)

if all(not arr.is_distributed() for arr in arrays):
# none of the arrays are distributed: use torch cat to concatenate the arrays sequence
return factories.array(
torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm
torch.cat([arri.larray for arri in arrays], dim=axis),
dtype=best_dtype,
is_split=None,
device=arrays[0].device,
comm=arrays[0].comm,
)
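# Illustration of the non-distributed path (hypothetical shapes): the result
# lives on every process and carries no split, e.g.
#   >>> a, b = ht.zeros((2, 3)), ht.ones((4, 3))
#   >>> ht.concatenate((a, b), axis=0).split is None
#   True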

# non-matching splits when both arrays are split
elif s0 != s1 and all([s is not None for s in [s0, s1]]):
raise RuntimeError(f"DNDarrays given have differing split axes, arr0 {s0} arr1 {s1}")

elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
_, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split)
_, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split)
out = factories.array(
torch.cat((arr0.larray[arr0_slice], arr1.larray[arr1_slice]), dim=axis),
dtype=out_dtype,
is_split=s1 if s1 is not None else s0,
device=arr1.device,
comm=arr0.comm,
elif axis != output_split:
return __concatenate_split_differ_axis(
arrays=arrays_copy, res_gshape=tuple(res_gshape), axis=axis, output_split=output_split
)
else: # axis=split
return __concatenate_split_equals_axis(
arrays=arrays_copy, res_gshape=tuple(res_gshape), axis=axis, output_split=output_split
)
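# Dispatch illustration (hypothetical shapes): with both inputs split along
# axis 0, concatenating along axis 1 takes __concatenate_split_differ_axis,
# while concatenating along axis 0 takes __concatenate_split_equals_axis, e.g.
#   >>> x = ht.zeros((4, 3), split=0)
#   >>> y = ht.ones((4, 3), split=0)
#   >>> ht.concatenate((x, y), axis=1).shape  # split != axis
#   (4, 6)
#   >>> ht.concatenate((x, y), axis=0).shape  # split == axis
#   (8, 3)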

return out

elif s0 == s1 or any(s is None for s in [s0, s1]):
if s0 != axis and all(s is not None for s in [s0, s1]):
# the axis is different than the split axis, this case can be easily implemented
# torch cat arrays together and return a new array that is_split

out = factories.array(
torch.cat((arr0.larray, arr1.larray), dim=axis),
dtype=out_dtype,
is_split=s0,
device=arr0.device,
comm=arr0.comm,
)
return out
def __concatenate_split_differ_axis(
arrays: Sequence[DNDarray, ...], res_gshape: Tuple[int], axis: int = 0, output_split: int = 0
) -> DNDarray:

def __balance_func(x: DNDarray):
if x.split is not None:
balance(x, copy=False)
else:
t_arr0 = arr0.larray
t_arr1 = arr1.larray
# maps are created for where the data is and the output shape is calculated
lshape_map = torch.zeros((2, arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
lshape_map[0, arr0.comm.rank, :] = torch.Tensor(arr0.lshape)
lshape_map[1, arr0.comm.rank, :] = torch.Tensor(arr1.lshape)
lshape_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

arr0_shape, arr1_shape = list(arr0.shape), list(arr1.shape)
arr0_shape[axis] += arr1_shape[axis]
out_shape = tuple(arr0_shape)

# the chunk map is used to determine how much data should be on each process
chunk_map = torch.zeros((arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
_, _, chk = arr0.comm.chunk(out_shape, s0 if s0 is not None else s1)
for i in range(len(out_shape)):
chunk_map[arr0.comm.rank, i] = chk[i].stop - chk[i].start
chunk_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, chunk_map, MPI.SUM)

lshape_map_comm.Wait()
chunk_map_comm.Wait()
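# Illustration (assumed layout): for out_shape = (10, 4) split along axis 0
# on 3 processes, comm.chunk assigns roughly even slices, so chunk_map could
# look like
#   [[4, 4],
#    [3, 4],
#    [3, 4]]
# (the exact placement of the remainder is up to comm.chunk), while lshape_map
# records where the data of arr0 and arr1 currently is.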

if s0 is not None:
send_slice = [slice(None)] * arr0.ndim
keep_slice = [slice(None)] * arr0.ndim
# data is first front-loaded onto the first size/2 processes
for spr in range(1, arr0.comm.size):
if arr0.comm.rank == spr:
for pr in range(spr):
send_amt = abs((chunk_map[pr, axis] - lshape_map[0, pr, axis]).item())
send_amt = (
send_amt if send_amt < t_arr0.shape[axis] else t_arr0.shape[axis]
)
if send_amt:
send_slice[arr0.split] = slice(0, send_amt)
keep_slice[arr0.split] = slice(send_amt, t_arr0.shape[axis])
send = arr0.comm.Isend(
t_arr0[send_slice].clone(),
dest=pr,
tag=pr + arr0.comm.size + spr,
)
t_arr0 = t_arr0[keep_slice].clone()
send.Wait()
for pr in range(spr):
snt = abs((chunk_map[pr, s0] - lshape_map[0, pr, s0]).item())
snt = (
snt
if snt < lshape_map[0, spr, axis]
else lshape_map[0, spr, axis].item()
)
if arr0.comm.rank == pr and snt:
shp = list(arr0.gshape)
shp[arr0.split] = snt
data = torch.zeros(
shp, dtype=out_dtype.torch_type(), device=arr0.device.torch_device
)

arr0.comm.Recv(data, source=spr, tag=pr + arr0.comm.size + spr)
t_arr0 = torch.cat((t_arr0, data), dim=arr0.split)
lshape_map[0, pr, arr0.split] += snt
lshape_map[0, spr, arr0.split] -= snt

if s1 is not None:
send_slice = [slice(None)] * arr0.ndim
keep_slice = [slice(None)] * arr0.ndim

# push the data backwards (arr1), making the data the proper size for arr1 on the last nodes
# the data is "compressed" on np/2 processes; data is sent from lower-rank processes to higher-rank processes
for spr in range(arr0.comm.size - 1, -1, -1):
if arr0.comm.rank == spr:
for pr in range(arr0.comm.size - 1, spr, -1):
# calculate the amount of data to send from the chunk map
send_amt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
send_amt = (
send_amt if send_amt < t_arr1.shape[axis] else t_arr1.shape[axis]
)
if send_amt:
send_slice[axis] = slice(
t_arr1.shape[axis] - send_amt, t_arr1.shape[axis]
)
keep_slice[axis] = slice(0, t_arr1.shape[axis] - send_amt)
send = arr1.comm.Isend(
t_arr1[send_slice].clone(),
dest=pr,
tag=pr + arr1.comm.size + spr,
)
t_arr1 = t_arr1[keep_slice].clone()
send.Wait()
for pr in range(arr1.comm.size - 1, spr, -1):
snt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
snt = (
snt
if snt < lshape_map[1, spr, axis]
else lshape_map[1, spr, axis].item()
)
balance(x.resplit(output_split), copy=False)

if arr1.comm.rank == pr and snt:
shp = list(arr1.gshape)
shp[axis] = snt
data = torch.zeros(
shp, dtype=out_dtype.torch_type(), device=arr1.device.torch_device
)
arr1.comm.Recv(data, source=spr, tag=pr + arr1.comm.size + spr)
t_arr1 = torch.cat((data, t_arr1), dim=axis)
lshape_map[1, pr, axis] += snt
lshape_map[1, spr, axis] -= snt

if s0 is None:
arb_slice = [None] * len(arr1.shape)
for c in range(len(chunk_map)):
arb_slice[axis] = c
# the chunk map is adjusted by subtracting what data is already in the correct place (the data from
# arr1 is already correctly placed) i.e. the chunk map shows how much data is still needed on each
# process; a local slice then selects that data from arr0
chunk_map[arb_slice] -= lshape_map[tuple([1] + arb_slice)]

# after adjusting arr1 need to now select the target data in arr0 on each node with a local slice
if arr0.comm.rank == 0:
lcl_slice = [slice(None)] * arr0.ndim
lcl_slice[axis] = slice(chunk_map[0, axis].item())
t_arr0 = t_arr0[lcl_slice].clone().squeeze()
ttl = chunk_map[0, axis].item()
for en in range(1, arr0.comm.size):
sz = chunk_map[en, axis]
if arr0.comm.rank == en:
lcl_slice = [slice(None)] * arr0.ndim
lcl_slice[axis] = slice(ttl, sz.item() + ttl, 1)
t_arr0 = t_arr0[lcl_slice].clone().squeeze()
ttl += sz.item()

if len(t_arr0.shape) < len(t_arr1.shape):
t_arr0.unsqueeze_(axis)

if s1 is None:
arb_slice = [None] * len(arr0.shape)
for c in range(len(chunk_map)):
arb_slice[axis] = c
chunk_map[arb_slice] -= lshape_map[tuple([0] + arb_slice)]

# get the desired data in arr1 on each node with a local slice
if arr1.comm.rank == arr1.comm.size - 1:
lcl_slice = [slice(None)] * arr1.ndim
lcl_slice[axis] = slice(
t_arr1.shape[axis] - chunk_map[-1, axis].item(), t_arr1.shape[axis], 1
)
t_arr1 = t_arr1[lcl_slice].clone().squeeze()
ttl = chunk_map[-1, axis].item()
for en in range(arr1.comm.size - 2, -1, -1):
sz = chunk_map[en, axis]
if arr1.comm.rank == en:
lcl_slice = [slice(None)] * arr1.ndim
lcl_slice[axis] = slice(
t_arr1.shape[axis] - (sz.item() + ttl), t_arr1.shape[axis] - ttl, 1
)
t_arr1 = t_arr1[lcl_slice].clone().squeeze()
ttl += sz.item()
if len(t_arr1.shape) < len(t_arr0.shape):
t_arr1.unsqueeze_(axis)

res = torch.cat((t_arr0, t_arr1), dim=axis)
out = factories.array(
res,
is_split=s0 if s0 is not None else s1,
dtype=out_dtype,
device=arr0.device,
comm=arr0.comm,
)
balanced_arrays = list(__balance_func(array) for array in arrays)
res = torch.cat([arr.larray for arr in balanced_arrays], dim=axis)

return out
# create a DNDarray from result
return DNDarray(
array=res,
gshape=res_gshape,
dtype=arrays[0].dtype,
split=output_split,
device=arrays[0].device,
comm=arrays[0].comm,
balanced=True,
)
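
# Minimal sketch (torch only; the helper name and the split dimension are
# assumptions, not part of this refactor): once the inputs are balanced to the
# same local length along the split dimension, concatenating along any other
# axis is a purely local torch.cat on every process.
def _local_cat_along_other_axis_sketch(local_chunks, axis):
    # local_chunks: one torch.Tensor per input array, all sharing the same
    # size along the (assumed) split dimension 0 on this process
    split_len = local_chunks[0].shape[0]
    assert all(chunk.shape[0] == split_len for chunk in local_chunks)
    return torch.cat(local_chunks, dim=axis)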


def __concatenate_split_equals_axis(
arrays: Sequence[DNDarray, ...], res_gshape: Tuple[int], axis: int = 0, output_split: int = 0
) -> DNDarray:
# determine the balanced local extent along the concatenation axis
res_arrays = []
local_axis_slice = res_gshape[axis] // arrays[0].comm.size
remainder = res_gshape[axis] % arrays[0].comm.size

target_map = arrays[0].lshape_map
target_map[:, axis] = 0
arr = 0
arr_offset = 0

# redistribute arrays for balanced final result
for device in range(arrays[0].comm.size):
device_load = 0
device_capacity = local_axis_slice + 1 * (device < remainder)
while device_load < device_capacity:
# take as much of the current array as fits into this process' remaining capacity
increment = min(arrays[arr].gshape[axis] - arr_offset, device_capacity - device_load)
target_map[device, axis] += increment
device_load += increment
arr_offset += increment

if arr_offset == arrays[arr].gshape[axis]:
# redistribute
arrays[arr].redistribute_(lshape_map=arrays[arr].lshape_map, target_map=target_map)
res_arrays.append(arrays[arr])
# proceed to next array
arr += 1
target_map[:, axis] = 0
arr_offset = 0

# local cat
res = torch.cat([arr.larray for arr in res_arrays], dim=axis)

# create a DNDarray from result
return DNDarray(
array=res,
gshape=res_gshape,
dtype=arrays[0].dtype,
split=output_split,
device=arrays[0].device,
comm=arrays[0].comm,
balanced=True,
)
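
# Minimal sketch (pure Python, illustration only; the helper name is made up):
# the balanced extent of each process along the concatenation axis is the
# floor division of the global extent by the number of processes, with the
# remainder spread over the first processes, the same arithmetic as
# device_capacity above.
def _balanced_axis_sizes_sketch(global_extent, n_procs):
    base, remainder = divmod(global_extent, n_procs)
    return [base + 1 * (rank < remainder) for rank in range(n_procs)]

# e.g. _balanced_axis_sizes_sketch(10, 3) == [4, 3, 3]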


def diag(a: DNDarray, offset: int = 0) -> DNDarray: