diff --git a/heat/core/io.py b/heat/core/io.py index c48c2c3a6..3135ee18d 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1183,11 +1183,14 @@ def load_npy_from_path( raise RuntimeError("Number of processes can't exceed number of files") rank = MPI_WORLD.rank - n_for_procs = n_files // process_number - idx = rank * n_for_procs - if rank + 1 == process_number: - n_for_procs += n_files % process_number + if rank < (n_files % process_number): + n_for_procs = n_files // process_number + 1 + idx = rank * n_for_procs + else: + n_for_procs = n_files // process_number + idx = rank * n_for_procs + (n_files % process_number) array_list = [np.load(path + "/" + element) for element in file_list[idx : idx + n_for_procs]] + larray = np.concatenate(array_list, split) larray = torch.from_numpy(larray)