Skip to content

Commit

Permalink
Don't force global multiprocessing start method (gwastro#4890)
Browse files Browse the repository at this point in the history
* Don't force global multiprocessing start method

* code style fixes
  • Loading branch information
ahnitz authored and prayush committed Nov 21, 2024
1 parent 6b2fd63 commit 9092cb9
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 21 deletions.
5 changes: 3 additions & 2 deletions bin/all_sky_search/pycbc_calculate_psd
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python
""" Calculate psd estimates for analysis segments
"""
import logging, argparse, numpy, multiprocessing, time, copy
import logging, argparse, numpy, time, copy
from six.moves import zip_longest
import pycbc, pycbc.psd, pycbc.strain, pycbc.events
from pycbc.io import HFile
from pycbc.pool import BroadcastPool as Pool
from pycbc.fft.fftw import set_measure_level
from pycbc.workflow import resolve_td_option
from ligo.segments import segmentlist, segment
Expand Down Expand Up @@ -101,7 +102,7 @@ segments = segmentlist(frozenset(segments))
logging.info('%d psds to calculate', len(segments))

if len(segments) > 0:
pool = multiprocessing.Pool(args.cores)
pool = Pool(args.cores)
psds = pool.map_async(get_psd, zip(segments, range(len(segments))))
psds = psds.get()
else:
Expand Down
2 changes: 1 addition & 1 deletion bin/inference/pycbc_inference_plot_movie
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import logging
import subprocess
import os
import glob
from multiprocessing import Pool
from pycbc.pool import BroadcastPool as Pool

import numpy

Expand Down
2 changes: 1 addition & 1 deletion bin/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import argparse
import numpy
import itertools
import time
from multiprocessing import Pool
from pycbc.pool import BroadcastPool as Pool

import pycbc
from pycbc import vetoes, psd, waveform, strain, scheme, fft, DYN_RANGE_FAC, events
Expand Down
3 changes: 2 additions & 1 deletion bin/pycbc_optimal_snr
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from ligo.lw import lsctables
import pycbc
import pycbc.inject
import pycbc.psd
from pycbc.pool import BroadcastPool as Pool
from pycbc.filter import sigma, make_frequency_series
from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, \
MultiDetOptionAction, load_frequencyseries
Expand Down Expand Up @@ -312,7 +313,7 @@ if __name__ == '__main__':

if opts.cores > 1:
logging.info('Starting workers')
pool = multiprocessing.Pool(processes=opts.cores)
pool = Pool(processes=opts.cores)
iterator = pool.imap_unordered(compute_optimal_snr, inj_table)
else:
# do not bother spawning extra processes if running single-core
Expand Down
12 changes: 0 additions & 12 deletions pycbc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,6 @@ def makedir(path):
# platforms (mac) that are silly and don't use the standard gcc.
if sys.platform == 'darwin':
HAVE_OMP = False

# MacosX after python3.7 switched to 'spawn', however, this does not
# preserve common state information which we have relied on when using
# multiprocessing based pools.
import multiprocessing
if multiprocessing.get_start_method(allow_none=True) is None:
if hasattr(multiprocessing, 'set_start_method'):
multiprocessing.set_start_method('fork')
elif multiprocessing.get_start_method() != 'fork':
warnings.warn("PyCBC requires the use of the 'fork' start method"
" for multiprocessing, it is currently set to {}"
.format(multiprocessing.get_start_method()))
else:
HAVE_OMP = True

Expand Down
2 changes: 1 addition & 1 deletion pycbc/inference/models/brute_marg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import numpy

from multiprocessing import Pool
from pycbc.pool import BroadcastPool as Pool
from scipy.special import logsumexp

from .gaussian_noise import BaseGaussianNoise
Expand Down
14 changes: 11 additions & 3 deletions pycbc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
import multiprocessing.pool
import functools
from multiprocessing import TimeoutError, cpu_count
from multiprocessing import TimeoutError, cpu_count, get_context
import types
import signal
import atexit
Expand Down Expand Up @@ -49,13 +49,20 @@ def _shutdown_pool(p):
class BroadcastPool(multiprocessing.pool.Pool):
""" Multiprocessing pool with a broadcast method
"""
def __init__(self, processes=None, initializer=None, initargs=(), **kwds):
def __init__(self, processes=None, initializer=None, initargs=(),
context=None, **kwds):
global _process_lock
global _numdone
_process_lock = multiprocessing.Lock()
_numdone = multiprocessing.Value('i', 0)
noint = functools.partial(_noint, initializer)
super(BroadcastPool, self).__init__(processes, noint, initargs, **kwds)

# Default is fork to preserve child memory inheritance and
# copy on write
if context is None:
context = get_context("fork")
super(BroadcastPool, self).__init__(processes, noint, initargs,
context=context, **kwds)
atexit.register(_shutdown_pool, self)

def __len__(self):
Expand Down Expand Up @@ -166,6 +173,7 @@ def use_mpi(require_mpi=False, log=True):
size = rank = 0
return use_mpi, size, rank


def choose_pool(processes, mpi=False):
""" Get processing pool
"""
Expand Down

0 comments on commit 9092cb9

Please sign in to comment.