From ea5414b2a63f8b46954ba177570427f53eb03b75 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen <33718780+casparvl@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:15:40 +0200 Subject: [PATCH] Update eessi/testsuite/tests/apps/tensorflow/src/tf_test.py More elegant way of retrieving local rank Co-authored-by: Sam Moors --- eessi/testsuite/tests/apps/tensorflow/src/tf_test.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/eessi/testsuite/tests/apps/tensorflow/src/tf_test.py b/eessi/testsuite/tests/apps/tensorflow/src/tf_test.py index fe4ed0b7..ac41415b 100644 --- a/eessi/testsuite/tests/apps/tensorflow/src/tf_test.py +++ b/eessi/testsuite/tests/apps/tensorflow/src/tf_test.py @@ -36,13 +36,9 @@ def get_local_rank(rank_info, rank_info_list): # Note that rank_info_list is sorted by rank, by definition of the MPI allgather routine. # Thus, if our current rank is the n-th item in rank_info_list for which the hostname matches, # our local rank is n - local_rank = 0 - for item in rank_info_list: - if item['hostname'] == rank_info['hostname']: - if item['rank'] == rank_info['rank']: - return local_rank - else: - local_rank += 1 + for index, item in enumerate(rank_info_list): + if item['hostname'] == rank_info['hostname'] and item['rank'] == rank_info['rank']: + return index def get_rank_info(comm=MPI.COMM_WORLD): '''Create a dict for this worker containing rank, hostname and port to be used by this worker'''