@property
def backend_flags(self):
    """Return array of backend flags.

    Not all TOAs have the same flags for all data sets. In order to
    facilitate this we have a ranked ordering system that will look
    for flags. The order is `group`, `g`, `sys`, `i`, `f`, `fe`+`be`.

    Returns:
        A numpy string array with one label per TOA, ordered with
        ``self._isort`` (the same ordering that ``flags`` applies).
        TOAs for which none of the ranked flags is set get ``""``.
    """

    # collect flag names (record array or plain dict of arrays)
    if isinstance(self._flags, np.ndarray):
        flagnames = self._flags.dtype.names or ()
    else:
        flagnames = list(self._flags.keys())

    # start with every TOA unlabelled; the arrays built below are sized by
    # numpy from their contents, so a joined "fe_be" label is never truncated
    # to the width of an individual flag column
    ret = np.zeros(len(self._toas), dtype="U1")

    # go through the flags in reverse order of preference,
    # setting or replacing values for each TOA
    if "fe" in flagnames and "be" in flagnames:
        combined = [(a + "_" + b if (a and b) else "") for a, b in zip(self._flags["fe"], self._flags["be"])]
        ret = np.array(combined, dtype=str)

    for flag in ["f", "i", "sys", "g", "group"]:
        if flag in flagnames:
            values = np.asarray(self._flags[flag])
            # np.where promotes to the wider of the two string dtypes
            ret = np.where(values == "", ret, values)

    # keep the labels aligned with the sorted TOAs
    return ret[self._isort]
# Infrastructure for sharing Pulsar objects among processes
# (currently Tempo2Pulsar only).
# These are class-body members: `psr` plays the role of `self`.
# deflate() copies the selected numpy arrays to SharedMemory and swaps them
# for pickleable PulsarInflater objects; inflate() turns those objects back
# into numpy arrays backed by the SharedMemory storage.

_todeflate = ["_designmatrix", "_planetssb", "_sunssb", "_flags"]
_deflated = "pristine"


def deflate(psr):  # pragma: py-lt-38
    """Swap every ndarray listed in `_todeflate` for a PulsarInflater.

    Does nothing unless the state is "pristine"; the state then becomes
    "deflated". Attributes that are not ndarrays are left untouched.
    """
    if psr._deflated != "pristine":
        return

    for name in psr._todeflate:
        current = getattr(psr, name)
        if isinstance(current, np.ndarray):
            setattr(psr, name, PulsarInflater(current))

    psr._deflated = "deflated"


def inflate(psr):  # pragma: py-lt-38
    """Swap every PulsarInflater listed in `_todeflate` for its inflated array.

    Does nothing unless the state is "deflated"; the state then becomes
    "inflated".
    """
    if psr._deflated != "deflated":
        return

    for name in psr._todeflate:
        holder = getattr(psr, name)
        if isinstance(holder, PulsarInflater):
            setattr(psr, name, holder.inflate())

    psr._deflated = "inflated"


def destroy(psr):  # pragma: py-lt-38
    """Call destroy() on every PulsarInflater listed in `_todeflate`.

    Does nothing unless the state is "deflated"; the state then becomes
    "destroyed". The PulsarInflater objects themselves stay in place.
    """
    if psr._deflated != "deflated":
        return

    for name in psr._todeflate:
        holder = getattr(psr, name)
        if isinstance(holder, PulsarInflater):
            holder.destroy()

    psr._deflated = "destroyed"
import numpy as np

try:
    from multiprocessing import resource_tracker, shared_memory
except ImportError:
    # multiprocessing.shared_memory is unavailable in Python < 3.8
    resource_tracker = None
    shared_memory = None


def _require_shared_memory():
    """Raise a clear error when multiprocessing.shared_memory could not be imported."""
    if shared_memory is None:
        raise RuntimeError("PulsarInflater requires multiprocessing.shared_memory (Python >= 3.8)")


class memmap(np.ndarray):
    """ndarray subclass for arrays whose data lives in a SharedMemory buffer.

    PulsarInflater.inflate stores the SharedMemory handle on the array as
    ``shm``, so the handle stays referenced for as long as the array does.
    """

    def __del__(self):
        # NOTE(review): inflate() builds its result with .view(memmap), which
        # normally leaves `base` set, so this branch may never run for those
        # arrays -- confirm the intended owner of the close() call.
        if self.base is None and hasattr(self, "shm"):
            self.shm.close()


# lifecycle of shared pulsar arrays:
# - begin life as numpy arrays in Pulsar object
# - upon psr.deflate(), replaced by PulsarInflater objects
# - these objects save the array metadata, create a SharedMemory buffer, copy the arrays into it
# - the PulsarInflater objects cannot be used as arrays until re-inflated
# - upon psr.inflate, the PulsarInflater objects are replaced with ndarray views of the SharedMemory buffers
# - the views are special memmap objects that hold a reference to the SharedMemory
# - upon psr.destroy, the SharedMemory objects are unlinked and the arrays become unusable
# - standard usage requires 3+ processes:
#   - a creator, who calls deflate then pickle
#   - one or more users, who unpickle then inflate
#   - a destroyer, who unpickles then destroys


class PulsarInflater:
    """Pickleable stand-in for a numpy array whose data was copied to shared memory.

    Attributes:
        dtype, shape, nbytes: metadata of the array passed to the constructor.
        shmname: name of the SharedMemory block that holds a copy of its data.
    """

    def __init__(self, array):
        """Copy `array` into a newly created SharedMemory block.

        Raises:
            RuntimeError: if multiprocessing.shared_memory is unavailable.
        """
        _require_shared_memory()

        self.dtype, self.shape, self.nbytes = array.dtype, array.shape, array.nbytes

        # SharedMemory rejects size=0, so always request at least one byte
        shm = shared_memory.SharedMemory(create=True, size=max(array.nbytes, 1))
        try:
            target = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
            target[...] = array  # `[...]` also works for 0-d arrays
            del target
        except BaseException:
            # do not leave a half-initialised block behind
            shm.close()
            shm.unlink()
            raise

        self.shmname = shm.name

        # The block has to survive until destroy() is called, possibly by
        # another process, so take it out of this process's resource tracker.
        # The tracker is keyed by the private `_name` attribute.
        resource_tracker.unregister(shm._name, "shared_memory")

        # the creator no longer needs its own mapping of the block
        shm.close()

    def inflate(self):
        """Return a `memmap` array that views the shared block.

        Raises:
            RuntimeError: if multiprocessing.shared_memory is unavailable.
            FileNotFoundError: if the block no longer exists (e.g. after destroy()).
        """
        _require_shared_memory()

        shm = shared_memory.SharedMemory(self.shmname)
        try:
            view = np.ndarray(self.shape, dtype=self.dtype, buffer=shm.buf).view(memmap)
        except BaseException:
            shm.close()
            raise

        # the array carries the handle so the mapping is not dropped under it
        view.shm = shm

        return view

    def destroy(self):
        """Unlink the shared block; later inflate() calls can no longer attach to it."""
        _require_shared_memory()

        shm = shared_memory.SharedMemory(self.shmname)
        try:
            shm.unlink()
        finally:
            shm.close()
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python >= 3.8")
def test_deflate_inflate(self):
    """Round-trip a pulsar through deflate -> pickle -> inflate, then destroy it."""
    psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")

    dm = psr._designmatrix.copy()

    psr.deflate()
    try:
        psr.to_pickle()

        with open("B1855+09.pkl", "rb") as f:
            pkl_psr = pickle.load(f)
        pkl_psr.inflate()

        assert np.allclose(dm, pkl_psr._designmatrix)

        del pkl_psr
    finally:
        # always release the shared-memory blocks, even if a check above fails
        psr.destroy()

    with open("B1855+09.pkl", "rb") as f:
        pkl_psr = pickle.load(f)

    # the blocks are gone, so attaching to them must fail
    with self.assertRaises(FileNotFoundError):
        pkl_psr.inflate()