Skip to content

Commit 38cb78b

Browse files
authored
Merge pull request su2code#2235 from su2code/improved_tape_statistics
Improved Tape Statistics
2 parents 910930b + a898882 commit 38cb78b

File tree

5 files changed

+81
-30
lines changed

5 files changed

+81
-30
lines changed

Common/include/basic_types/ad_structure.hpp

+77-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,13 @@ inline bool TapeActive() { return false; }
5858

5959
/*!
6060
* \brief Prints out tape statistics.
61+
*
62+
* Tape statistics are aggregated across OpenMP threads and MPI processes, if applicable.
63+
* With MPI, the given communicator is used to reduce data across MPI processes, and the printing behaviour can be set
64+
* per rank (usually, only the master rank prints).
6165
*/
62-
inline void PrintStatistics() {}
66+
template <typename Comm>
67+
inline void PrintStatistics(Comm communicator, bool printingRank) {}
6368

6469
/*!
6570
* \brief Registers the variable as an input. I.e. as a leaf of the computational graph.
@@ -348,7 +353,77 @@ FORCEINLINE void StopRecording() { AD::getTape().setPassive(); }
348353

349354
FORCEINLINE bool TapeActive() { return AD::getTape().isActive(); }
350355

351-
FORCEINLINE void PrintStatistics() { AD::getTape().printStatistics(); }
356+
template <typename Comm>
357+
FORCEINLINE void PrintStatistics(Comm communicator, bool printingRank) {
358+
if (printingRank) {
359+
std::cout << "-------------------------------------------------------\n";
360+
std::cout << " Serial parts of the tape\n";
361+
#ifdef HAVE_MPI
362+
std::cout << " (aggregated across MPI processes)\n";
363+
#endif
364+
std::cout << "-------------------------------------------------------\n";
365+
}
366+
367+
codi::TapeValues serialTapeValues = AD::getTape().getTapeValues();
368+
serialTapeValues.combineDataMPI(communicator);
369+
370+
if (printingRank) {
371+
serialTapeValues.formatDefault(std::cout);
372+
}
373+
374+
double totalMemoryUsed = serialTapeValues.getUsedMemorySize();
375+
double totalMemoryAllocated = serialTapeValues.getAllocatedMemorySize();
376+
377+
#ifdef HAVE_OPDI
378+
379+
if (printingRank) {
380+
std::cout << "-------------------------------------------------------\n";
381+
std::cout << " OpenMP parallel parts of the tape\n";
382+
std::cout << " (aggregated across OpenMP threads)\n";
383+
#ifdef HAVE_MPI
384+
std::cout << " (aggregated across MPI processes)\n";
385+
#endif
386+
std::cout << "-------------------------------------------------------\n";
387+
}
388+
389+
codi::TapeValues* aggregatedOpenMPTapeValues = nullptr;
390+
391+
// clang-format off
392+
393+
SU2_OMP_PARALLEL {
394+
if (omp_get_thread_num() == 0) { // master thread
395+
codi::TapeValues masterTapeValues = AD::getTape().getTapeValues();
396+
aggregatedOpenMPTapeValues = &masterTapeValues;
397+
398+
SU2_OMP_BARRIER // master completes initialization
399+
SU2_OMP_BARRIER // other threads complete adding their data
400+
401+
aggregatedOpenMPTapeValues->combineDataMPI(communicator);
402+
totalMemoryUsed += aggregatedOpenMPTapeValues->getUsedMemorySize();
403+
totalMemoryAllocated += aggregatedOpenMPTapeValues->getAllocatedMemorySize();
404+
if (printingRank) {
405+
aggregatedOpenMPTapeValues->formatDefault(std::cout);
406+
}
407+
aggregatedOpenMPTapeValues = nullptr;
408+
} else { // other threads
409+
SU2_OMP_BARRIER // master completes initialization
410+
SU2_OMP_CRITICAL {
411+
aggregatedOpenMPTapeValues->combineData(AD::getTape().getTapeValues());
412+
} END_SU2_OMP_CRITICAL
413+
SU2_OMP_BARRIER // other threads complete adding their data
414+
}
415+
} END_SU2_OMP_PARALLEL
416+
417+
// clang-format on
418+
#endif
419+
420+
if (printingRank) {
421+
std::cout << "-------------------------------------------------------\n";
422+
std::cout << " Total memory used : " << totalMemoryUsed / 1024.0 / 1024.0 << " MB\n";
423+
std::cout << " Total memory allocated : " << totalMemoryAllocated / 1024.0 / 1024.0 << " MB\n";
424+
std::cout << "-------------------------------------------------------\n";
425+
}
426+
}
352427

353428
FORCEINLINE void ClearAdjoints() { AD::getTape().clearAdjoints(); }
354429

SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp

+1-13
Original file line numberDiff line numberDiff line change
@@ -670,19 +670,7 @@ void CDiscAdjMultizoneDriver::SetRecording(RECORDING kind_recording, Kind_Tape t
670670
}
671671

672672
if (kind_recording != RECORDING::CLEAR_INDICES && driver_config->GetWrt_AD_Statistics()) {
673-
if (rank == MASTER_NODE) AD::PrintStatistics();
674-
#ifdef CODI_REVERSE_TYPE
675-
if (size > SINGLE_NODE) {
676-
su2double myMem = AD::getTape().getTapeValues().getUsedMemorySize(), totMem = 0.0;
677-
SU2_MPI::Allreduce(&myMem, &totMem, 1, MPI_DOUBLE, MPI_SUM, SU2_MPI::GetComm());
678-
if (rank == MASTER_NODE) {
679-
cout << "MPI\n";
680-
cout << "-------------------------------------\n";
681-
cout << " Total memory used : " << totMem << " MB\n";
682-
cout << "-------------------------------------\n" << endl;
683-
}
684-
}
685-
#endif
673+
AD::PrintStatistics(SU2_MPI::GetComm(), rank == MASTER_NODE);
686674
}
687675

688676
AD::StopRecording();

SU2_CFD/src/drivers/CDiscAdjSinglezoneDriver.cpp

+1-13
Original file line numberDiff line numberDiff line change
@@ -305,19 +305,7 @@ void CDiscAdjSinglezoneDriver::SetRecording(RECORDING kind_recording){
305305
SetObjFunction();
306306

307307
if (kind_recording != RECORDING::CLEAR_INDICES && config_container[ZONE_0]->GetWrt_AD_Statistics()) {
308-
if (rank == MASTER_NODE) AD::PrintStatistics();
309-
#ifdef CODI_REVERSE_TYPE
310-
if (size > SINGLE_NODE) {
311-
su2double myMem = AD::getTape().getTapeValues().getUsedMemorySize(), totMem = 0.0;
312-
SU2_MPI::Allreduce(&myMem, &totMem, 1, MPI_DOUBLE, MPI_SUM, SU2_MPI::GetComm());
313-
if (rank == MASTER_NODE) {
314-
cout << "MPI\n";
315-
cout << "-------------------------------------\n";
316-
cout << " Total memory used : " << totMem << " MB\n";
317-
cout << "-------------------------------------\n" << endl;
318-
}
319-
}
320-
#endif
308+
AD::PrintStatistics(SU2_MPI::GetComm(), rank == MASTER_NODE);
321309
}
322310

323311
AD::StopRecording();

meson_scripts/init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def init_submodules(
5555

5656
# This module information is used if the project was not cloned using git
5757
# The sha tag must be maintained manually to point to the correct commit
58-
sha_version_codi = "bb7689fb9479818d4ab55c4f3898c88d92890315"
58+
sha_version_codi = "c6b039e5c9edb7675f90ffc725f9dd8e66571264"
5959
github_repo_codi = "https://github.com/scicompkl/CoDiPack"
6060
sha_version_medi = "ab3a7688f6d518f8d940eb61a341d89f51922ba4"
6161
github_repo_medi = "https://github.com/SciCompKL/MeDiPack"

0 commit comments

Comments
 (0)