Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't force global multiprocessing start method #4890

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions bin/all_sky_search/pycbc_calculate_psd
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python
""" Calculate psd estimates for analysis segments
"""
import logging, argparse, numpy, multiprocessing, time, copy
import logging, argparse, numpy, time, copy
from six.moves import zip_longest
import pycbc, pycbc.psd, pycbc.strain, pycbc.events
from pycbc.io import HFile
from pycbc.pool import BroadcastPool as Pool
from pycbc.fft.fftw import set_measure_level
from pycbc.workflow import resolve_td_option
from ligo.segments import segmentlist, segment
Expand Down Expand Up @@ -101,7 +102,7 @@ segments = segmentlist(frozenset(segments))
logging.info('%d psds to calculate', len(segments))

if len(segments) > 0:
pool = multiprocessing.Pool(args.cores)
pool = Pool(args.cores)
psds = pool.map_async(get_psd, zip(segments, range(len(segments))))
psds = psds.get()
else:
Expand Down
2 changes: 1 addition & 1 deletion bin/inference/pycbc_inference_plot_movie
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import logging
import subprocess
import os
import glob
from multiprocessing import Pool
from pycbc.pool import BroadcastPool as Pool

import numpy

Expand Down
2 changes: 1 addition & 1 deletion bin/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import argparse
import numpy
import itertools
import time
from multiprocessing import Pool
from pycbc.pool import BroadcastPool as Pool

import pycbc
from pycbc import vetoes, psd, waveform, strain, scheme, fft, DYN_RANGE_FAC, events
Expand Down
3 changes: 2 additions & 1 deletion bin/pycbc_optimal_snr
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from ligo.lw import lsctables
import pycbc
import pycbc.inject
import pycbc.psd
from pycbc.pool import BroadcastPool as Pool
from pycbc.filter import sigma, make_frequency_series
from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, \
MultiDetOptionAction, load_frequencyseries
Expand Down Expand Up @@ -312,7 +313,7 @@ if __name__ == '__main__':

if opts.cores > 1:
logging.info('Starting workers')
pool = multiprocessing.Pool(processes=opts.cores)
pool = Pool(processes=opts.cores)
iterator = pool.imap_unordered(compute_optimal_snr, inj_table)
else:
# do not bother spawning extra processes if running single-core
Expand Down
12 changes: 0 additions & 12 deletions pycbc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,6 @@ def makedir(path):
# platforms (mac) that are silly and don't use the standard gcc.
if sys.platform == 'darwin':
HAVE_OMP = False

# MacosX after python3.7 switched to 'spawn', however, this does not
# preserve common state information which we have relied on when using
# multiprocessing based pools.
import multiprocessing
if multiprocessing.get_start_method(allow_none=True) is None:
if hasattr(multiprocessing, 'set_start_method'):
multiprocessing.set_start_method('fork')
elif multiprocessing.get_start_method() != 'fork':
warnings.warn("PyCBC requires the use of the 'fork' start method"
" for multiprocessing, it is currently set to {}"
.format(multiprocessing.get_start_method()))
else:
HAVE_OMP = True

Expand Down
2 changes: 1 addition & 1 deletion pycbc/inference/models/brute_marg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import numpy

from multiprocessing import Pool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly stupid question - some places have multiprocessing.Pool, other have multiprocessing.pool.Pool. This is how it was before these changes as well, so it must be okay, but wanted to flag it

Copy link

@meiyasan meiyasan Sep 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If using 'multiprocessing.Pool' I believe this will use default context as defined by 'set_start_method'. In this case one can just use 'get_context("fork/spawn").Pool()'

from pycbc.pool import BroadcastPool as Pool
from scipy.special import logsumexp

from .gaussian_noise import BaseGaussianNoise
Expand Down
14 changes: 11 additions & 3 deletions pycbc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
import multiprocessing.pool
import functools
from multiprocessing import TimeoutError, cpu_count
from multiprocessing import TimeoutError, cpu_count, get_context
import types
import signal
import atexit
Expand Down Expand Up @@ -49,13 +49,20 @@ def _shutdown_pool(p):
class BroadcastPool(multiprocessing.pool.Pool):
""" Multiprocessing pool with a broadcast method
"""
def __init__(self, processes=None, initializer=None, initargs=(), **kwds):
def __init__(self, processes=None, initializer=None, initargs=(),
context=None, **kwds):
global _process_lock
global _numdone
_process_lock = multiprocessing.Lock()
_numdone = multiprocessing.Value('i', 0)
noint = functools.partial(_noint, initializer)
super(BroadcastPool, self).__init__(processes, noint, initargs, **kwds)

# Default is fork to preserve child memory inheritance and
# copy on write
if context is None:
context = get_context("fork")
super(BroadcastPool, self).__init__(processes, noint, initargs,
context=context, **kwds)
atexit.register(_shutdown_pool, self)

def __len__(self):
Expand Down Expand Up @@ -166,6 +173,7 @@ def use_mpi(require_mpi=False, log=True):
size = rank = 0
return use_mpi, size, rank


def choose_pool(processes, mpi=False):
""" Get processing pool
"""
Expand Down
Loading