From 0d0e575b784bdff2f9ec86bf966664ef15150bf1 Mon Sep 17 00:00:00 2001 From: "Nguyen Xuan, Tu" Date: Wed, 7 Aug 2024 14:53:56 +0200 Subject: [PATCH] Now more optimal distribution of files --- heat/core/io.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index c48c2c3a6..3135ee18d 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1183,11 +1183,14 @@ def load_npy_from_path( raise RuntimeError("Number of processes can't exceed number of files") rank = MPI_WORLD.rank - n_for_procs = n_files // process_number - idx = rank * n_for_procs - if rank + 1 == process_number: - n_for_procs += n_files % process_number + if rank < (n_files % process_number): + n_for_procs = n_files // process_number + 1 + idx = rank * n_for_procs + else: + n_for_procs = n_files // process_number + idx = rank * n_for_procs + (n_files % process_number) array_list = [np.load(path + "/" + element) for element in file_list[idx : idx + n_for_procs]] + larray = np.concatenate(array_list, split) larray = torch.from_numpy(larray)