Merge pull request #1602 from helmholtz-analytics/1601-bug-inefficient-distribution-of-files

`load_npy_from_path`: improve load balancing
mrfh92 authored Aug 9, 2024
2 parents ba3f4db + 0d0e575 commit 7acb068
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions heat/core/io.py
@@ -1183,11 +1183,14 @@ def load_npy_from_path(
         raise RuntimeError("Number of processes can't exceed number of files")

     rank = MPI_WORLD.rank
-    n_for_procs = n_files // process_number
-    idx = rank * n_for_procs
-    if rank + 1 == process_number:
-        n_for_procs += n_files % process_number
+    if rank < (n_files % process_number):
+        n_for_procs = n_files // process_number + 1
+        idx = rank * n_for_procs
+    else:
+        n_for_procs = n_files // process_number
+        idx = rank * n_for_procs + (n_files % process_number)
     array_list = [np.load(path + "/" + element) for element in file_list[idx : idx + n_for_procs]]

     larray = np.concatenate(array_list, split)
     larray = torch.from_numpy(larray)
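
The effect of this hunk can be checked without MPI. The following is a minimal sketch, not Heat's code: `old_split` and `new_split` are hypothetical helper names, and the rank is passed in as a plain argument instead of being read from `MPI_WORLD.rank`. It compares how many files each rank receives under the removed and the added logic.

```python
def old_split(n_files, n_procs, rank):
    """Removed logic: every rank gets the floor share, the last rank also takes the remainder."""
    count = n_files // n_procs
    start = rank * count
    if rank + 1 == n_procs:
        count += n_files % n_procs
    return start, count


def new_split(n_files, n_procs, rank):
    """Added logic: the first (n_files % n_procs) ranks each take one extra file."""
    base, rem = divmod(n_files, n_procs)
    if rank < rem:
        count = base + 1
        start = rank * count
    else:
        count = base
        # Skip past the `rem` extra files handed to the lower ranks.
        start = rank * count + rem
    return start, count


if __name__ == "__main__":
    n_files, n_procs = 7, 4  # illustrative values, not taken from the commit
    for name, split in (("old", old_split), ("new", new_split)):
        counts = [split(n_files, n_procs, r)[1] for r in range(n_procs)]
        print(name, counts)
    # old [1, 1, 1, 4]
    # new [2, 2, 2, 1]
```

With 7 files on 4 processes, the old scheme leaves the last rank with 4 files while the others load 1 each; the new scheme caps the difference between any two ranks at one file, and the slices still cover `file_list` contiguously and without overlap.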
