diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 794024a8d..f4a1af36a 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -63,7 +63,7 @@ def _check_hashes(x, y, datatype): def set_default_compiler(compiler): - """Set the PyOP2 default compiler, globally. + """Set the PyOP2 default compiler, globally over COMM_WORLD. :arg compiler: String with name or path to compiler executable OR a subclass of the Compiler class @@ -85,66 +85,73 @@ def set_default_compiler(compiler): ) -def sniff_compiler(exe): +def sniff_compiler(exe, comm=mpi.COMM_WORLD): """Obtain the correct compiler class by calling the compiler executable. :arg exe: String with name or path to compiler executable + :arg comm: Comm over which we want to determine the compiler type :returns: A compiler class """ - try: - output = subprocess.run( - [exe, "--version"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - except (subprocess.CalledProcessError, UnicodeDecodeError): - output = "" - - # Find the name of the compiler family - if output.startswith("gcc") or output.startswith("g++"): - name = "GNU" - elif output.startswith("clang"): - name = "clang" - elif output.startswith("Apple LLVM") or output.startswith("Apple clang"): - name = "clang" - elif output.startswith("icc"): - name = "Intel" - elif "Cray" in output.split("\n")[0]: - # Cray is more awkward eg: - # Cray clang version 11.0.4 () - # gcc (GCC) 9.3.0 20200312 (Cray Inc.) - name = "Cray" - else: - name = "unknown" - - # Set the compiler instance based on the platform (and architecture) - if sys.platform.find("linux") == 0: - if name == "Intel": - compiler = LinuxIntelCompiler - elif name == "GNU": - compiler = LinuxGnuCompiler - elif name == "clang": - compiler = LinuxClangCompiler - elif name == "Cray": - compiler = LinuxCrayCompiler + compiler = None + if comm.rank == 0: + # Note: + # Sniffing compiler for very large numbers of MPI ranks is + # expensive so we do this on one rank and broadcast + try: + output = subprocess.run( + [exe, "--version"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + except (subprocess.CalledProcessError, UnicodeDecodeError): + output = "" + + # Find the name of the compiler family + if output.startswith("gcc") or output.startswith("g++"): + name = "GNU" + elif output.startswith("clang"): + name = "clang" + elif output.startswith("Apple LLVM") or output.startswith("Apple clang"): + name = "clang" + elif output.startswith("icc"): + name = "Intel" + elif "Cray" in output.split("\n")[0]: + # Cray is more awkward eg: + # Cray clang version 11.0.4 () + # gcc (GCC) 9.3.0 20200312 (Cray Inc.) + name = "Cray" else: - compiler = AnonymousCompiler - elif sys.platform.find("darwin") == 0: - if name == "clang": - machine = platform.uname().machine - if machine == "arm64": - compiler = MacClangARMCompiler - elif machine == "x86_64": - compiler = MacClangCompiler - elif name == "GNU": - compiler = MacGNUCompiler + name = "unknown" + + # Set the compiler instance based on the platform (and architecture) + if sys.platform.find("linux") == 0: + if name == "Intel": + compiler = LinuxIntelCompiler + elif name == "GNU": + compiler = LinuxGnuCompiler + elif name == "clang": + compiler = LinuxClangCompiler + elif name == "Cray": + compiler = LinuxCrayCompiler + else: + compiler = AnonymousCompiler + elif sys.platform.find("darwin") == 0: + if name == "clang": + machine = platform.uname().machine + if machine == "arm64": + compiler = MacClangARMCompiler + elif machine == "x86_64": + compiler = MacClangCompiler + elif name == "GNU": + compiler = MacGNUCompiler + else: + compiler = AnonymousCompiler else: compiler = AnonymousCompiler - else: - compiler = AnonymousCompiler - return compiler + + return comm.bcast(compiler, 0) class Compiler(ABC): @@ -178,8 +185,8 @@ class Compiler(ABC): _debugflags = () def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): - # Get compiler version ASAP since it is used in __repr__ - self.sniff_compiler_version() + # Set compiler version ASAP since it is used in __repr__ + self.version = None self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) @@ -190,6 +197,7 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co # Compilation communicators are reference counted on the PyOP2 comm self.pcomm = mpi.internal_comm(comm, self) self.comm = mpi.compilation_comm(self.pcomm, self) + self.sniff_compiler_version() def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -238,23 +246,28 @@ def sniff_compiler_version(self, cpp=False): :arg cpp: If set to True will use the C++ compiler rather than the C compiler to determine the version number. """ + # Note: + # Sniffing the compiler version for very large numbers of + # MPI ranks is expensive exe = self.cxx if cpp else self.cc - self.version = None - # `-dumpversion` is not sufficient to get the whole version string (for some compilers), - # but other compilers do not implement `-dumpfullversion`! - for dumpstring in ["-dumpfullversion", "-dumpversion"]: - try: - output = subprocess.run( - [exe, dumpstring], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - self.version = Version(output) - break - except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): - continue + version = None + if self.comm.rank == 0: + # `-dumpversion` is not sufficient to get the whole version string (for some compilers), + # but other compilers do not implement `-dumpfullversion`! + for dumpstring in ["-dumpfullversion", "-dumpversion"]: + try: + output = subprocess.run( + [exe, dumpstring], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + version = Version(output) + break + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + continue + self.version = self.comm.bcast(version, 0) @property def bugfix_cflags(self): @@ -448,23 +461,6 @@ class LinuxGnuCompiler(Compiler): _optflags = ("-march=native", "-O3", "-ffast-math") _debugflags = ("-O0", "-g") - def sniff_compiler_version(self, cpp=False): - super(LinuxGnuCompiler, self).sniff_compiler_version() - if self.version >= Version("7.0"): - try: - # gcc-7 series only spits out patch level on dumpfullversion. - exe = self.cxx if cpp else self.cc - output = subprocess.run( - [exe, "-dumpfullversion"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - self.version = Version(output) - except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): - pass - @property def bugfix_cflags(self): """Flags to work around bugs in compilers.""" @@ -596,7 +592,7 @@ def __init__(self, code, argtypes): exe = configuration["cxx"] or "mpicxx" else: exe = configuration["cc"] or "mpicc" - compiler = sniff_compiler(exe) + compiler = sniff_compiler(exe, comm) dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) if isinstance(jitmodule, GlobalKernel):