diff --git a/src/atlas/parallel/acc/acc.cc b/src/atlas/parallel/acc/acc.cc index c73bdea2c..c1446f7a0 100644 --- a/src/atlas/parallel/acc/acc.cc +++ b/src/atlas/parallel/acc/acc.cc @@ -77,5 +77,21 @@ void* deviceptr(void* host_data) { #endif } +CompilerId compiler_id() { +#if ATLAS_HAVE_ACC + static CompilerId id = []() { + switch (atlas_acc_compiler_id()) { + case atlas_acc_compiler_id_cray: return CompilerId::cray; + case atlas_acc_compiler_id_nvidia: return CompilerId::nvidia; + default: return CompilerId::unknown; + } + }(); + return id; +#else + return CompilerId::unknown; +#endif +} + + } diff --git a/src/atlas/parallel/acc/acc.h b/src/atlas/parallel/acc/acc.h index 3a10f1e25..7d7760855 100644 --- a/src/atlas/parallel/acc/acc.h +++ b/src/atlas/parallel/acc/acc.h @@ -13,11 +13,18 @@ namespace atlas::acc { +enum class CompilerId { + unknown, + nvidia, + cray, +}; + int devices(); void map(void* host_data, void* device_data, std::size_t bytes); void unmap(void* host_data); bool is_present(void* host_data, std::size_t bytes); void* deviceptr(void* host_data); +CompilerId compiler_id(); } diff --git a/src/atlas_acc_support/atlas_acc.F90 b/src/atlas_acc_support/atlas_acc.F90 index bb73b5237..bdc1f378c 100644 --- a/src/atlas_acc_support/atlas_acc.F90 +++ b/src/atlas_acc_support/atlas_acc.F90 @@ -12,6 +12,17 @@ module atlas_acc contains +function atlas_acc_compiler_id() bind(C,name="atlas_acc_compiler_id") result(compiler_id) + use, intrinsic :: iso_c_binding, only : c_int + integer(c_int) :: compiler_id + ! compiler_id must match number in atlas_acc.h enum type +#ifdef _CRAYFTN + compiler_id = 2 ! cray +#else + compiler_id = 0 ! unknown +#endif +end function + function atlas_acc_get_num_devices() bind(C,name="atlas_acc_get_num_devices") result(num_devices) use, intrinsic :: iso_c_binding, only : c_int integer(c_int) :: num_devices diff --git a/src/atlas_acc_support/atlas_acc.cc b/src/atlas_acc_support/atlas_acc.cc index 8cc58981a..0a7cbd575 100644 --- a/src/atlas_acc_support/atlas_acc.cc +++ b/src/atlas_acc_support/atlas_acc.cc @@ -127,4 +127,13 @@ const char* atlas_acc_info_str() { int atlas_acc_get_num_devices() { return acc_get_num_devices(acc_get_device_type()); } + +atlas_acc_compiler_id_t atlas_acc_compiler_id() { +#if defined(__NVCOMPILER) + return atlas_acc_compiler_id_nvidia; +#else + return atlas_acc_compiler_id_unknown; +#endif +} + } diff --git a/src/atlas_acc_support/atlas_acc.h b/src/atlas_acc_support/atlas_acc.h index 37c928b95..2652e4ae9 100644 --- a/src/atlas_acc_support/atlas_acc.h +++ b/src/atlas_acc_support/atlas_acc.h @@ -27,6 +27,14 @@ void* atlas_acc_deviceptr(void* cpu_ptr); atlas_acc_device_t atlas_acc_get_device_type(); int atlas_acc_get_num_devices(); +typedef enum { + atlas_acc_compiler_id_unknown = 0, + atlas_acc_compiler_id_nvidia = 1, + atlas_acc_compiler_id_cray = 2 +} atlas_acc_compiler_id_t; + +atlas_acc_compiler_id_t atlas_acc_compiler_id(); + #ifdef __cplusplus } #endif