diff --git a/bin/all_sky_search/pycbc_calculate_psd b/bin/all_sky_search/pycbc_calculate_psd index 9764250fc9f..bcfd3ab1676 100755 --- a/bin/all_sky_search/pycbc_calculate_psd +++ b/bin/all_sky_search/pycbc_calculate_psd @@ -1,10 +1,11 @@ #!/usr/bin/env python """ Calculate psd estimates for analysis segments """ -import logging, argparse, numpy, multiprocessing, time, copy +import logging, argparse, numpy, time, copy from six.moves import zip_longest import pycbc, pycbc.psd, pycbc.strain, pycbc.events from pycbc.io import HFile +from pycbc.pool import BroadcastPool as Pool from pycbc.fft.fftw import set_measure_level from pycbc.workflow import resolve_td_option from ligo.segments import segmentlist, segment @@ -101,7 +102,7 @@ segments = segmentlist(frozenset(segments)) logging.info('%d psds to calculate', len(segments)) if len(segments) > 0: - pool = multiprocessing.Pool(args.cores) + pool = Pool(args.cores) psds = pool.map_async(get_psd, zip(segments, range(len(segments)))) psds = psds.get() else: diff --git a/bin/inference/pycbc_inference_plot_movie b/bin/inference/pycbc_inference_plot_movie index 7db38b92016..f5b3fed9f50 100644 --- a/bin/inference/pycbc_inference_plot_movie +++ b/bin/inference/pycbc_inference_plot_movie @@ -43,7 +43,7 @@ import logging import subprocess import os import glob -from multiprocessing import Pool +from pycbc.pool import BroadcastPool as Pool import numpy diff --git a/bin/pycbc_inspiral b/bin/pycbc_inspiral index 61e769ac2e7..24193310714 100644 --- a/bin/pycbc_inspiral +++ b/bin/pycbc_inspiral @@ -24,7 +24,7 @@ import argparse import numpy import itertools import time -from multiprocessing import Pool +from pycbc.pool import BroadcastPool as Pool import pycbc from pycbc import vetoes, psd, waveform, strain, scheme, fft, DYN_RANGE_FAC, events diff --git a/bin/pycbc_optimal_snr b/bin/pycbc_optimal_snr index 08fa83c4ae9..121f0b8c1ab 100644 --- a/bin/pycbc_optimal_snr +++ b/bin/pycbc_optimal_snr @@ -33,6 +33,7 @@ from ligo.lw import lsctables import pycbc import pycbc.inject import pycbc.psd +from pycbc.pool import BroadcastPool as Pool from pycbc.filter import sigma, make_frequency_series from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, \ MultiDetOptionAction, load_frequencyseries @@ -312,7 +313,7 @@ if __name__ == '__main__': if opts.cores > 1: logging.info('Starting workers') - pool = multiprocessing.Pool(processes=opts.cores) + pool = Pool(processes=opts.cores) iterator = pool.imap_unordered(compute_optimal_snr, inj_table) else: # do not bother spawning extra processes if running single-core diff --git a/pycbc/__init__.py b/pycbc/__init__.py index 99f4cc6ef37..49f4d97dcde 100644 --- a/pycbc/__init__.py +++ b/pycbc/__init__.py @@ -207,18 +207,6 @@ def makedir(path): # platforms (mac) that are silly and don't use the standard gcc. if sys.platform == 'darwin': HAVE_OMP = False - - # MacosX after python3.7 switched to 'spawn', however, this does not - # preserve common state information which we have relied on when using - # multiprocessing based pools. - import multiprocessing - if multiprocessing.get_start_method(allow_none=True) is None: - if hasattr(multiprocessing, 'set_start_method'): - multiprocessing.set_start_method('fork') - elif multiprocessing.get_start_method() != 'fork': - warnings.warn("PyCBC requires the use of the 'fork' start method" - " for multiprocessing, it is currently set to {}" - .format(multiprocessing.get_start_method())) else: HAVE_OMP = True diff --git a/pycbc/inference/models/brute_marg.py b/pycbc/inference/models/brute_marg.py index 3843a1e88e0..0b419838dd6 100644 --- a/pycbc/inference/models/brute_marg.py +++ b/pycbc/inference/models/brute_marg.py @@ -20,7 +20,7 @@ import logging import numpy -from multiprocessing import Pool +from pycbc.pool import BroadcastPool as Pool from scipy.special import logsumexp from .gaussian_noise import BaseGaussianNoise diff --git a/pycbc/pool.py b/pycbc/pool.py index a770b9537ec..bffa2202689 100644 --- a/pycbc/pool.py +++ b/pycbc/pool.py @@ -2,7 +2,7 @@ """ import multiprocessing.pool import functools -from multiprocessing import TimeoutError, cpu_count +from multiprocessing import TimeoutError, cpu_count, get_context import types import signal import atexit @@ -49,13 +49,20 @@ def _shutdown_pool(p): class BroadcastPool(multiprocessing.pool.Pool): """ Multiprocessing pool with a broadcast method """ - def __init__(self, processes=None, initializer=None, initargs=(), **kwds): + def __init__(self, processes=None, initializer=None, initargs=(), + context=None, **kwds): global _process_lock global _numdone _process_lock = multiprocessing.Lock() _numdone = multiprocessing.Value('i', 0) noint = functools.partial(_noint, initializer) - super(BroadcastPool, self).__init__(processes, noint, initargs, **kwds) + + # Default is fork to preserve child memory inheritance and + # copy on write + if context is None: + context = get_context("fork") + super(BroadcastPool, self).__init__(processes, noint, initargs, + context=context, **kwds) atexit.register(_shutdown_pool, self) def __len__(self): @@ -166,6 +173,7 @@ def use_mpi(require_mpi=False, log=True): size = rank = 0 return use_mpi, size, rank + def choose_pool(processes, mpi=False): """ Get processing pool """