diff --git a/HISTORY.rst b/HISTORY.rst index 1a04b8581..cb2f36ea0 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Deprecate use of @generated_jit. Remove upper bound on numba. (:pr:`289`) * Remove unnecessary new_axes in calibration utils after upstream fix in dask (:pr:`288`) * Check that ncorr is never larger than 2 in calibration utils (:pr:`287`) * Optionally check NRT allocations (:pr:`286`) diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index 3303a54c7..12574a161 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -10,7 +10,9 @@ merge_flags, vis_output_arrays) from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (generated_jit, +from africanus.util.numba import (njit, + overload, + JIT_OPTIONS, intrinsic, is_numba_type_none) @@ -20,11 +22,25 @@ RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_average(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): + return row_average_impl(meta, ant1, ant2, flag_row=flag_row, + time_centroid=time_centroid, exposure=exposure, + uvw=uvw, weight=weight, sigma=sigma) + +def row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): + return NotImplementedError + + +@overload(row_average_impl, jit_options=JIT_OPTIONS) +def nb_row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): have_flag_row = not is_numba_type_none(flag_row) have_time_centroid = not is_numba_type_none(time_centroid) have_exposure = not is_numba_type_none(exposure) @@ -310,13 +326,35 @@ def codegen(context, builder, signature, args): return sig, codegen -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_chan_average(meta, flag_row=None, weight=None, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None): + return row_chan_average_impl(meta, flag_row=flag_row, weight=weight, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) + + +def row_chan_average_impl(meta, flag_row=None, weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): + + return NotImplementedError + + +@overload(row_chan_average_impl, jit_options=JIT_OPTIONS) +def nb_row_chan_average(meta, flag_row=None, weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): + have_vis = not is_numba_type_none(visibilities) have_flag = not is_numba_type_none(flag) have_flag_row = not is_numba_type_none(flag_row) @@ -523,7 +561,7 @@ def impl(meta, flag_row=None, weight=None, _rowchan_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def bda(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, @@ -535,7 +573,21 @@ def bda(time, interval, antenna1, antenna2, decorrelation=0.98, time_bin_secs=None, min_nchan=1): - def impl(time, interval, antenna1, antenna2, + + return bda_impl(time, interval, antenna1, antenna2, + time_centroid=time_centroid, exposure=exposure, + flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma, + chan_freq=chan_freq, chan_width=chan_width, + effective_bw=effective_bw, resolution=resolution, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + max_uvw_dist=max_uvw_dist, max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, min_nchan=min_nchan) + + +def bda_impl(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, chan_freq=None, chan_width=None, @@ -546,54 +598,67 @@ def impl(time, interval, antenna1, antenna2, decorrelation=0.98, time_bin_secs=None, min_nchan=1): - # Merge flag_row and flag arrays - flag_row = merge_flags(flag_row, flag) - - meta = bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) - - row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 - time_centroid, exposure, uvw, - weight=weight, sigma=sigma) - - row_chan_avg = row_chan_average(meta, # noqa: F841 - flag_row=flag_row, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) - - # Have to explicitly write it out because numba tuples - # are highly constrained types - return AverageOutput(meta.map, - meta.offsets, - meta.decorr_chan_width, - meta.time, - meta.interval, - meta.chan_width, - meta.flag_row, - row_avg.antenna1, - row_avg.antenna2, - row_avg.time_centroid, - row_avg.exposure, - row_avg.uvw, - row_avg.weight, - row_avg.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_avg.visibilities, - row_chan_avg.flag, - row_chan_avg.weight_spectrum, - row_chan_avg.sigma_spectrum) - - return impl + return NotImplementedError + + +@overload(bda_impl, jit_options=JIT_OPTIONS) +def nb_bda_impl(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + max_uvw_dist=None, max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): + # Merge flag_row and flag arrays + flag_row = merge_flags(flag_row, flag) + + meta = bda_mapper(time, interval, antenna1, antenna2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan) + + row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 + time_centroid, exposure, uvw, + weight=weight, sigma=sigma) + + row_chan_avg = row_chan_average(meta, # noqa: F841 + flag_row=flag_row, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) + + # Have to explicitly write it out because numba tuples + # are highly constrained types + return AverageOutput(meta.map, + meta.offsets, + meta.decorr_chan_width, + meta.time, + meta.interval, + meta.chan_width, + meta.flag_row, + row_avg.antenna1, + row_avg.antenna2, + row_avg.time_centroid, + row_avg.exposure, + row_avg.uvw, + row_avg.weight, + row_avg.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_avg.visibilities, + row_chan_avg.flag, + row_chan_avg.weight_spectrum, + row_chan_avg.sigma_spectrum) BDA_DOCS = DocstringTemplate(""" diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index 34d5de625..a2abdbf7f 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -5,10 +5,14 @@ import numpy as np import numba from numba.experimental import jitclass -import numba.types +from numba import types from africanus.constants import c as lightspeed -from africanus.util.numba import generated_jit, njit, is_numba_type_none +from africanus.util.numba import ( + JIT_OPTIONS, + overload, + njit, + is_numba_type_none) from africanus.averaging.support import unique_time, unique_baselines @@ -16,73 +20,6 @@ class RowMapperError(Exception): pass -@njit(nogil=True, cache=True) -def erf26(x): - """Implements 7.1.26 erf approximation from Abramowitz and - Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7.""" - - # Constants - p = 0.3275911 - a1 = 0.254829592 - a2 = -0.284496736 - a3 = 1.421413741 - a4 = -1.453152027 - a5 = 1.061405429 - e = 2.718281828 - - # t - t = 1.0/(1.0 + (p * x)) - - # Erf calculation - erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) - erf *= e ** -(x ** 2) - - return -round(erf, 9) if x < 0 else round(erf, 0) - - -@njit(nogil=True, cache=True) -def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength): - sidereal_rotation_rate = 7.292118516e-5 - diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2) - term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit - return 1.0 - 1.0645 * erf26(0.8326*term) / term - - -_SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200) - - -@njit(nogil=True, cache=True, inline='always') -def inv_sinc(sinc_x, tol=1e-12): - # Invalid input - if sinc_x > 1.0: - raise ValueError("sinc_x > 1.0") - - # Initial guess from reversion of Taylor series - # https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx - x = t_pow = np.sqrt(6*np.abs((1 - sinc_x))) - t_squared = t_pow*t_pow - - for coeff in numba.literal_unroll(_SERIES_COEFFS): - t_pow *= t_squared - x += coeff * t_pow - - # Use Newton Raphson to go the rest of the way - # https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D - while True: - # evaluate delta between this iteration sinc(x) and original - sinx = np.sin(x) - š¯˛“sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x - - # Stop if converged - if np.abs(š¯˛“sinc_x) < tol: - break - - # Next iteration - x -= (x*x * š¯˛“sinc_x) / (x*np.cos(x) - sinx) - - return x - - @njit(nogil=True, cache=True, inline='always') def factors(n): assert n >= 1 @@ -126,7 +63,7 @@ def max_chan_width(ref_freq, fractional_bandwidth): "nchan", "flag"]) -class Binner(object): +class Binner: def __init__(self, row_start, row_end, max_lm, decorrelation, time_bin_secs, max_chan_freq): @@ -338,7 +275,7 @@ def finalise_bin(self, auto_corr, uvw, time, interval, "time", "interval", "chan_width", "flag_row"]) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def bda_mapper(time, interval, ant1, ant2, uvw, chan_width, chan_freq, max_uvw_dist, @@ -347,10 +284,39 @@ def bda_mapper(time, interval, ant1, ant2, uvw, decorrelation=0.98, time_bin_secs=None, min_nchan=1): - + return bda_mapper_impl(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan) + + +def bda_mapper_impl(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): + return NotImplementedError + + +@overload(bda_mapper_impl, jit_options={"nogil": True}) +def nb_bda_mapper(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): have_time_bin_secs = not is_numba_type_none(time_bin_secs) - Omitted = numba.types.misc.Omitted + Omitted = types.misc.Omitted decorr_type = (numba.typeof(decorrelation.value) if isinstance(decorrelation, Omitted) diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index 5d2deef89..089bd9da9 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -4,8 +4,8 @@ from africanus.util.numba import (is_numba_type_none, intrinsic, + JIT_OPTIONS, njit, - generated_jit, overload) @@ -13,11 +13,17 @@ def shape_or_invalid_shape(array, ndim): pass -# TODO(sjperkins) -# maybe replace with njit and inline='always' if -# https://github.com/numba/numba/issues/4693 is resolved -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def merge_flags(flag_row, flag): + return merge_flags_impl(flag_row, flag) + + +def merge_flags_impl(flag_row, flag): + raise NotImplementedError + + +@overload(merge_flags_impl, jit_options=JIT_OPTIONS) +def nb_merge_flags(flag_row, flag): have_flag_row = not is_numba_type_none(flag_row) have_flag = not is_numba_type_none(flag) diff --git a/africanus/averaging/support.py b/africanus/averaging/support.py index d5e15f681..a4ca769ef 100644 --- a/africanus/averaging/support.py +++ b/africanus/averaging/support.py @@ -4,7 +4,7 @@ import numpy as np import numba -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import JIT_OPTIONS, overload, njit @njit(nogil=True, cache=True) @@ -53,8 +53,17 @@ def _unique_internal(data): return aux[mask], perm[mask], inv_idx, np.diff(np.array(counts)) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def unique_time(time): + return unique_time_impl(time) + + +def unique_time_impl(time): + return NotImplementedError + + +@overload(unique_time_impl, jit_options=JIT_OPTIONS) +def nb_unique_time(time): """ Return unique time, inverse index and counts """ if time.dtype not in (numba.float32, numba.float64): raise ValueError("time must be floating point but is %s" % time.dtype) @@ -65,8 +74,17 @@ def impl(time): return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def unique_baselines(ant1, ant2): + return unique_baselines_impl(ant1, ant2) + + +def unique_baselines_impl(ant1, ant2): + return NotImplementedError + + +@overload(unique_baselines_impl, jit_options=JIT_OPTIONS) +def nb_unique_baselines(ant1, ant2): """ Return unique baselines, inverse index and counts """ if not ant1.dtype == numba.int32 or not ant2.dtype == numba.int32: # Need these to be int32 for the bl_32bit.view(np.int64) trick diff --git a/africanus/averaging/tests/test_bda_averaging.py b/africanus/averaging/tests/test_bda_averaging.py index 549048f69..7ea2688c1 100644 --- a/africanus/averaging/tests/test_bda_averaging.py +++ b/africanus/averaging/tests/test_bda_averaging.py @@ -212,7 +212,7 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): assert_array_equal(row_avg.exposure, out_interval) assert_array_equal(row_avg.uvw, out_uvw) assert_array_equal(row_avg.weight, out_weight) - assert_array_equal(row_avg.sigma, out_sigma) + assert_array_almost_equal(row_avg.sigma, out_sigma) vshape = (in_row, in_chan, in_corr) vis = rs.normal(size=vshape) + rs.normal(size=vshape)*1j diff --git a/africanus/averaging/time_and_channel_avg.py b/africanus/averaging/time_and_channel_avg.py index a1c4aeb87..208fd07e0 100644 --- a/africanus/averaging/time_and_channel_avg.py +++ b/africanus/averaging/time_and_channel_avg.py @@ -13,7 +13,7 @@ vis_output_arrays) from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (is_numba_type_none, generated_jit, +from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, njit, overload, intrinsic) TUPLE_TYPE = 0 @@ -59,10 +59,25 @@ def chan_add(output, input, orow, ochan, irow, ichan, corr): RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_average(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): + return row_average_impl(meta, ant1, ant2, flag_row=flag_row, + time_centroid=time_centroid, exposure=exposure, + uvw=uvw, weight=weight, sigma=sigma) + + +def row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): + return NotImplementedError + + +@overload(row_average_impl, jit_options=JIT_OPTIONS) +def nb_row_average(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): have_flag_row = not is_numba_type_none(flag_row) flags_match = matching_flag_factory(have_flag_row) @@ -317,11 +332,30 @@ def codegen(context, builder, signature, args): return sig, codegen -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None): + return row_chan_average_impl(row_meta, chan_meta, + flag_row=flag_row, weight=weight, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) + + +def row_chan_average_impl(row_meta, chan_meta, + flag_row=None, weight=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None): + return NotImplementedError + + +@overload(row_chan_average_impl, jit_options=JIT_OPTIONS) +def nb_row_chan_average(row_meta, chan_meta, + flag_row=None, weight=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None): dummy_chan_freq = None dummy_chan_width = None @@ -522,9 +556,24 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, ChannelAverageOutput = namedtuple("ChannelAverageOutput", _chan_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def chan_average(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): + return chan_average_impl(chan_meta, chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution) + + +def chan_average_impl(chan_meta, chan_freq=None, chan_width=None, + effective_bw=None, resolution=None): + + return NotImplementedError + + +@overload(chan_average_impl, jit_options=JIT_OPTIONS) +def nb_chan_average(chan_meta, chan_freq=None, chan_width=None, + effective_bw=None, resolution=None): def impl(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): @@ -580,7 +629,7 @@ def impl(chan_meta, chan_freq=None, chan_width=None, _rowchan_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def time_and_channel(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, @@ -589,17 +638,54 @@ def time_and_channel(time, interval, antenna1, antenna2, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None, time_bin_secs=1.0, chan_bin_size=1): + return time_and_channel_impl(time, interval, antenna1, antenna2, + time_centroid=time_centroid, + exposure=exposure, + flag_row=flag_row, + uvw=uvw, weight=weight, sigma=sigma, + chan_freq=chan_freq, chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size) + + +def time_and_channel_impl(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + time_bin_secs=1.0, chan_bin_size=1): + return NotImplementedError + + +@overload(time_and_channel_impl, jit_options=JIT_OPTIONS) +def nb_time_and_channel(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + time_bin_secs=1.0, chan_bin_size=1): valid_types = (types.misc.Omitted, types.scalars.Float, types.scalars.Integer) if not isinstance(time_bin_secs, valid_types): - raise TypeError("time_bin_secs must be a scalar float") + raise TypeError( + f"time_bin_secs ({time_bin_secs}) must be a scalar float") valid_types = (types.misc.Omitted, types.scalars.Integer) if not isinstance(chan_bin_size, valid_types): - raise TypeError("chan_bin_size must be a scalar integer") + raise TypeError( + f"chan_bin_size ({chan_bin_size}) must be a scalar integer") def impl(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py index a3f2984a5..31b287b86 100644 --- a/africanus/averaging/time_and_channel_mapping.py +++ b/africanus/averaging/time_and_channel_mapping.py @@ -7,7 +7,12 @@ import numba from africanus.averaging.support import unique_time, unique_baselines -from africanus.util.numba import is_numba_type_none, generated_jit, njit, jit +from africanus.util.numba import ( + is_numba_type_none, + njit, + jit, + JIT_OPTIONS, + overload) class RowMapperError(Exception): @@ -56,7 +61,7 @@ def impl(flag_row, in_row, out_flag_row, out_row, flagged): ["map", "time", "interval", "flag_row"]) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): """ @@ -178,6 +183,18 @@ def row_mapper(time, interval, antenna1, antenna2, Raised if an illegal condition occurs """ + return row_mapper_impl(time, interval, antenna1, antenna2, + flag_row=flag_row, time_bin_secs=time_bin_secs) + + +def row_mapper_impl(time, interval, antenna1, antenna2, + flag_row=None, time_bin_secs=1): + return NotImplementedError + + +@overload(row_mapper_impl, jit_options=JIT_OPTIONS) +def nb_row_mapper(time, interval, antenna1, antenna2, + flag_row=None, time_bin_secs=1): have_flag_row = not is_numba_type_none(flag_row) is_flagged_fn = is_flagged_factory(have_flag_row) diff --git a/africanus/calibration/phase_only/phase_only.py b/africanus/calibration/phase_only/phase_only.py index 8668de19c..f0c6d329a 100755 --- a/africanus/calibration/phase_only/phase_only.py +++ b/africanus/calibration/phase_only/phase_only.py @@ -3,7 +3,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate from africanus.calibration.utils import residual_vis, check_type -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -22,9 +22,22 @@ def jacobian(a1j, blj, a2j, sign, out): return njit(nogil=True, inline='always')(jacobian) -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): + return compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, residual, + model, flag) + + +def compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): + return NotImplementedError + + +@overload(compute_jhj_and_jhr_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): mode = check_type(jones, residual) if mode != DIAG_DIAG: @@ -70,9 +83,21 @@ def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1, return _jhj_and_jhr_fn -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhj(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag): + return compute_jhj_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, model, flag) + + +def compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, flag): + return NotImplementedError + + +@overload(compute_jhj_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhj(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, flag): mode = check_type(jones, model, vis_type='model') @@ -110,9 +135,22 @@ def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1, return _compute_jhj_fn -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): + return compute_jhr_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, residual, + model, flag) + + +def compute_jhr_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): + return NotImplementedError + + +@overload(compute_jhr_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhr(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): mode = check_type(jones, model, vis_type='model') diff --git a/africanus/calibration/utils/compute_and_corrupt_vis.py b/africanus/calibration/utils/compute_and_corrupt_vis.py index 6902c06bd..33ad85de0 100644 --- a/africanus/calibration/utils/compute_and_corrupt_vis.py +++ b/africanus/calibration/utils/compute_and_corrupt_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.constants import minus_two_pi_over_c as m2pioc from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -73,9 +73,22 @@ def jones_mul(a1j, model, a2j, uvw, freq, lm, out): return njit(nogil=True, inline='always')(jones_mul) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm): + return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, model, + uvw, freq, lm) + + +def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, uvw, freq, lm): + return NotImplementedError + + +@overload(compute_and_corrupt_vis_impl, jit_options=JIT_OPTIONS) +def mb_compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, uvw, freq, lm): mode = check_type(jones, model, vis_type='model') jones_mul = jones_mul_factory(mode) diff --git a/africanus/calibration/utils/correct_vis.py b/africanus/calibration/utils/correct_vis.py index 2071507cc..6e99cdea6 100644 --- a/africanus/calibration/utils/correct_vis.py +++ b/africanus/calibration/utils/correct_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -66,9 +66,21 @@ def jones_inverse_mul(a1j, blj, a2j, out): return njit(nogil=True, inline='always')(jones_inverse_mul) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def correct_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag): + return correct_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag) + + +def correct_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag): + return NotImplementedError + + +@overload(correct_vis_impl, jit_options=JIT_OPTIONS) +def nb_correct_vis(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag): mode = check_type(jones, vis) jones_inverse_mul = jones_inverse_mul_factory(mode) diff --git a/africanus/calibration/utils/corrupt_vis.py b/africanus/calibration/utils/corrupt_vis.py index e1059899d..191069f00 100644 --- a/africanus/calibration/utils/corrupt_vis.py +++ b/africanus/calibration/utils/corrupt_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -56,9 +56,21 @@ def jones_mul(a1j, model, a2j, out): return njit(nogil=True, inline='always')(jones_mul) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): + return corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model) + + +def corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model): + return NotImplementedError + + +@overload(corrupt_vis_impl, jit_options=JIT_OPTIONS) +def nb_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model): mode = check_type(jones, model, vis_type='model') jones_mul = jones_mul_factory(mode) diff --git a/africanus/calibration/utils/residual_vis.py b/africanus/calibration/utils/residual_vis.py index bc985d8b9..6f2d0c961 100644 --- a/africanus/calibration/utils/residual_vis.py +++ b/africanus/calibration/utils/residual_vis.py @@ -3,7 +3,7 @@ import numpy as np from functools import wraps from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -59,9 +59,21 @@ def subtract_model(a1j, blj, a2j, model, out): return njit(nogil=True, inline='always')(subtract_model) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model): + return residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model) + + +def residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model): + return NotImplementedError + + +@overload(residual_vis_impl, jit_options=JIT_OPTIONS) +def nb_residual_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model): mode = check_type(jones, vis) subtract_model = subtract_model_factory(mode) diff --git a/africanus/coordinates/coordinates.py b/africanus/coordinates/coordinates.py index c2b127382..475691d36 100644 --- a/africanus/coordinates/coordinates.py +++ b/africanus/coordinates/coordinates.py @@ -4,7 +4,9 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, generated_jit, jit +from africanus.util.numba import (is_numba_type_none, + jit, JIT_OPTIONS, + njit, overload) from africanus.util.requirements import requires_optional try: @@ -25,8 +27,17 @@ def _return_phase_centre(phase_centre, dtype): return phase_centre -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def radec_to_lmn(radec, phase_centre=None): + return radec_to_lmn_impl(radec, phase_centre=phase_centre) + + +def radec_to_lmn_impl(radec, phase_centre=None): + raise NotImplementedError + + +@overload(radec_to_lmn_impl, jit_options=JIT_OPTIONS) +def nb_radec_to_lmn(radec, phase_centre=None): dtype = radec.dtype if is_numba_type_none(phase_centre): @@ -64,8 +75,17 @@ def _radec_to_lmn_impl(radec, phase_centre=None): return _radec_to_lmn_impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def radec_to_lm(radec, phase_centre=None): + return radec_to_lm_impl(radec, phase_centre=phase_centre) + + +def radec_to_lm_impl(radec, phase_centre=None): + raise NotImplementedError + + +@overload(radec_to_lm_impl, jit_options=JIT_OPTIONS) +def nb_radec_to_lm(radec, phase_centre=None): dtype = radec.dtype if is_numba_type_none(phase_centre): @@ -101,8 +121,17 @@ def _radec_to_lm_impl(radec, phase_centre=None): return _radec_to_lm_impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def lmn_to_radec(lmn, phase_centre=None): + return lmn_to_radec_impl(lmn, phase_centre=phase_centre) + + +def lmn_to_radec_impl(lmn, phase_centre=None): + raise NotImplementedError + + +@overload(lmn_to_radec_impl, jit_options=JIT_OPTIONS) +def nb_lmn_to_radec(lmn, phase_centre=None): dtype = lmn.dtype if is_numba_type_none(phase_centre): @@ -131,8 +160,17 @@ def _lmn_to_radec_impl(lmn, phase_centre=None): return _lmn_to_radec_impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def lm_to_radec(lm, phase_centre=None): + return lm_to_radec_impl(lm, phase_centre=phase_centre) + + +def lm_to_radec_impl(lm, phase_centre=None): + raise NotImplementedError + + +@overload(lm_to_radec_impl, jit_options=JIT_OPTIONS) +def nb_lm_to_radec(lm, phase_centre=None): dtype = lm.dtype if is_numba_type_none(phase_centre): diff --git a/africanus/dft/kernels.py b/africanus/dft/kernels.py index 893302ed1..000cbba1d 100644 --- a/africanus/dft/kernels.py +++ b/africanus/dft/kernels.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from africanus.util.numba import is_numba_type_none, generated_jit +from africanus.util.numba import (is_numba_type_none, njit, + overload, JIT_OPTIONS) from africanus.util.docs import doc_tuple_to_str from collections import namedtuple @@ -11,9 +12,21 @@ from africanus.constants import minus_two_pi_over_c, two_pi_over_c -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def im_to_vis(image, uvw, lm, frequency, convention='fourier', dtype=None): + return im_to_vis_impl(image, uvw, lm, frequency, + convention=convention, dtype=dtype) + + +def im_to_vis_impl(image, uvw, lm, frequency, + convention='fourier', dtype=None): + raise NotImplementedError + + +@overload(im_to_vis_impl, jit_options=JIT_OPTIONS) +def nb_im_to_vis(image, uvw, lm, frequency, + convention='fourier', dtype=None): # Infer complex output dtype if none provided if is_numba_type_none(dtype): out_dtype = np.result_type(np.complex64, @@ -62,9 +75,21 @@ def impl(image, uvw, lm, frequency, return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def vis_to_im(vis, uvw, lm, frequency, flags, convention='fourier', dtype=None): + return vis_to_im_impl(vis, uvw, lm, frequency, flags, + convention=convention, dtype=dtype) + + +def vis_to_im_impl(vis, uvw, lm, frequency, flags, + convention='fourier', dtype=None): + raise NotImplementedError + + +@overload(vis_to_im_impl, jit_options=JIT_OPTIONS) +def nb_vis_to_im(vis, uvw, lm, frequency, flags, + convention='fourier', dtype=None): # Infer output dtype if none provided if is_numba_type_none(dtype): # Support both real and complex visibilities... diff --git a/africanus/dft/tests/test_dft.py b/africanus/dft/tests/test_dft.py index 53915f5b9..ac445d82f 100644 --- a/africanus/dft/tests/test_dft.py +++ b/africanus/dft/tests/test_dft.py @@ -161,7 +161,7 @@ def test_adjointness(): RHS = (RH(gamma_vis, uvw, lm, frequency, flag).reshape( size_im, 1).T.dot(gamma_im.reshape(size_im, 1))).real - assert np.abs(LHS - RHS) < 1e-14 + assert np.abs(LHS - RHS) < 1e-13 def test_vis_to_im_flagged(): diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index d23341dda..af4f98feb 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -2,10 +2,11 @@ from collections import defaultdict import numba -from numba import generated_jit, types +from numba import types import numpy as np from africanus.util.patterns import Multiton +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.experimental.rime.fused.arguments import ArgumentDependencies from africanus.experimental.rime.fused.intrinsics import IntrinsicFactory from africanus.experimental.rime.fused.specification import RimeSpecification @@ -29,23 +30,34 @@ def rime_impl_factory(terms, transformers, ncorr): - @generated_jit(nopython=True, nogil=True, cache=True) - def rime(names, *inargs): - if len(inargs) != 1 or not isinstance(inargs[0], types.BaseTuple): - raise TypeError(f"{inargs[0]} must be be a Tuple") + @njit(**JIT_OPTIONS) + def rime(*args): + return rime_impl(*args) - if not isinstance(names, types.BaseTuple): - raise TypeError(f"{names} must be a Tuple of strings") + def rime_impl(*args): + raise NotImplementedError - if len(names) != len(inargs[0]): - raise ValueError(f"len(names): {len(names)} " - f"!= {len(inargs[0])}") + @overload(rime_impl, jit_options=JIT_OPTIONS) + def nb_rime(*args): + if not len(args) > 0: + raise TypeError("rime must be at least be called " + "with the signature argument") + + if not isinstance(args[0], types.Literal): + raise TypeError(f"Signature hash ({args[0]}) must be a literal") + + if not len(args) % 2 == 1: + raise TypeError(f"Length of named arguments {len(args)} " + f"is not divisible by 2") + + argstart = 1 + (len(args) - 1) // 2 + names = args[1:argstart] if not all(isinstance(n, types.Literal) for n in names): - raise TypeError(f"{names} must be a Tuple of strings") + raise TypeError(f"{names} must be a Tuple of Literal strings") if not all(n.literal_type is types.unicode_type for n in names): - raise TypeError(f"{names} must be a Tuple of strings") + raise TypeError(f"{names} must be a Tuple of Literal strings") # Get literal argument names names = tuple(n.literal_value for n in names) @@ -65,8 +77,8 @@ def rime(names, *inargs): except ValueError as e: raise ValueError(f"{str(e)} is required") - def impl(names, *inargs): - args_opt_idx = pack_opts_indices(inargs) + def impl(*args): + args_opt_idx = pack_opts_indices(args[argstart:]) args = pack_transformed(args_opt_idx) state = term_state(args) @@ -193,10 +205,9 @@ def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): keys = (self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys()))) - return self.impl(keys, time, - antenna1, antenna2, - feed1, feed2, - *kwargs.values()) + args = keys + (time, antenna1, antenna2, feed1, + feed2) + tuple(kwargs.values()) + return self.impl(types.literal(self.rime_spec.spec_hash), *args) def consolidate_args(args, kw): diff --git a/africanus/experimental/rime/fused/specification.py b/africanus/experimental/rime/fused/specification.py index 1c61d50c9..4966a7bf4 100644 --- a/africanus/experimental/rime/fused/specification.py +++ b/africanus/experimental/rime/fused/specification.py @@ -1,4 +1,5 @@ import ast +from hashlib import shake_256 from importlib import import_module import inspect import multiprocessing @@ -11,7 +12,7 @@ from africanus.experimental.rime.fused import terms as term_mod from africanus.experimental.rime.fused.transformers.core import Transformer from africanus.experimental.rime.fused import transformers as transformer_mod -from africanus.util.patterns import LazyProxy +from africanus.util.patterns import freeze, LazyProxy TERM_STRING_REGEX = re.compile("([A-Z])(pq|p|q)") @@ -335,6 +336,9 @@ def __init__(self, specification, terms=None, transformers=None): "process_pool": pool } + hash_elements = list(v for k, v in global_kw.items() + if k != "process_pool") + for cls, cfg in zip(term_types, term_cfgs): if cfg == "pq": cfg = "middle" @@ -350,7 +354,7 @@ def __init__(self, specification, terms=None, transformers=None): cls_kw = {} if "configuration" not in init_sig.parameters: - raise RimeSpecification( + raise RimeSpecificationError( f"{cls}.__init__{init_sig} must take a " f"'configuration' argument and call " f"super().__init__(configuration)") @@ -358,7 +362,7 @@ def __init__(self, specification, terms=None, transformers=None): for a, p in list(init_sig.parameters.items())[1:]: if p.kind not in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: - raise RimeSpecification( + raise RimeSpecificationError( f"{cls}.__init__{init_sig} may not contain " f"*args or **kwargs") @@ -371,6 +375,8 @@ def __init__(self, specification, terms=None, transformers=None): f"Available args: {available_kw}") term = cls(**cls_kw) + hash_elements.append(".".join((cls.__module__, cls.__name__))) + hash_elements.append(cfg) terms.append(term) term_type_set = set(term_types) @@ -385,7 +391,7 @@ def __init__(self, specification, terms=None, transformers=None): transformers = [] - for cls in transformer_types.values(): + for _, cls in sorted(transformer_types.items()): init_sig = inspect.signature(cls.__init__) cls_kw = {} @@ -405,10 +411,13 @@ def __init__(self, specification, terms=None, transformers=None): f"Available args: {available_kw}") transformer = cls(**cls_kw) + hash_elements.append(".".join((cls.__module__, cls.__name__))) transformers.append(transformer) self.terms = terms self.transformers = transformers + str_elements = str((freeze(hash_elements))).encode("utf-8") + self.spec_hash = shake_256(str_elements).hexdigest(16) @staticmethod def _finalise_pool(pool): diff --git a/africanus/gridding/wgridder/vis2im.py b/africanus/gridding/wgridder/vis2im.py index 30155c565..752d376b0 100644 --- a/africanus/gridding/wgridder/vis2im.py +++ b/africanus/gridding/wgridder/vis2im.py @@ -20,9 +20,9 @@ def _dirty_internal(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() nband = freq_bin_idx.size - if type(vis[0, 0]) == np.complex64: + if type(vis[0, 0]) is np.complex64: real_type = np.float32 - elif type(vis[0, 0]) == np.complex128: + elif type(vis[0, 0]) is np.complex128: real_type = np.float64 else: raise ValueError("Vis of incorrect type") diff --git a/africanus/model/shape/gaussian_shape.py b/africanus/model/shape/gaussian_shape.py index 8a5b09050..de341536c 100644 --- a/africanus/model/shape/gaussian_shape.py +++ b/africanus/model/shape/gaussian_shape.py @@ -4,12 +4,21 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit +from africanus.util.numba import njit, overload, JIT_OPTIONS from africanus.constants import c as lightspeed -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def gaussian(uvw, frequency, shape_params): + return gaussian_impl(uvw, frequency, shape_params) + + +def gaussian_impl(uvw, frequency, shape_params): + raise NotImplementedError + + +@overload(gaussian_impl, jit_options=JIT_OPTIONS) +def nb_gaussian(uvw, frequency, shape_params): # https://en.wikipedia.org/wiki/Full_width_at_half_maximum fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) fwhminv = 1.0 / fwhm diff --git a/africanus/model/spectral/spec_model.py b/africanus/model/spectral/spec_model.py index f7b672f18..42d6b3cf8 100644 --- a/africanus/model/spectral/spec_model.py +++ b/africanus/model/spectral/spec_model.py @@ -4,7 +4,7 @@ from numba import types import numpy as np -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, JIT_OPTIONS, njit from africanus.util.docs import DocstringTemplate @@ -95,8 +95,17 @@ def impl(array): return njit(nogil=True, cache=True)(impl) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def spectral_model(stokes, spi, ref_freq, frequency, base=0): + return spectral_model_impl(stokes, spi, ref_freq, frequency, base=base) + + +def spectral_model_impl(stokes, spi, ref_freq, frequency, base=0): + raise NotImplementedError + + +@overload(spectral_model_impl, jit_options=JIT_OPTIONS) +def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0): arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (stokes, spi, ref_freq, frequency)) dtype = np.result_type(*arg_dtypes) diff --git a/africanus/model/wsclean/spec_model.py b/africanus/model/wsclean/spec_model.py index acec935ac..ee51a1998 100644 --- a/africanus/model/wsclean/spec_model.py +++ b/africanus/model/wsclean/spec_model.py @@ -2,7 +2,7 @@ from numba import types import numpy as np -from africanus.util.numba import generated_jit +from africanus.util.numba import JIT_OPTIONS, njit, overload from africanus.util.docs import DocstringTemplate @@ -26,8 +26,12 @@ def log_spectral_model(I, coeffs, log_poly, ref_freq, freq): # noqa: E741 return I[:, None] * np.exp(term.sum(axis=2)) -@generated_jit(nopython=True, nogil=True, cache=True) def _check_log_poly_shape(coeffs, log_poly): + raise NotImplementedError + + +@overload(_check_log_poly_shape) +def overload_check_log_poly_shape(coeffs, log_poly): if isinstance(log_poly, types.npytypes.Array): def impl(coeffs, log_poly): if coeffs.shape[0] != log_poly.shape[0]: @@ -41,8 +45,12 @@ def impl(coeffs, log_poly): return impl -@generated_jit(nopython=True, nogil=True, cache=True) def _log_polynomial(log_poly, s): + raise NotImplementedError + + +@overload(_log_polynomial) +def overload_log_polynomial(log_poly, s): if isinstance(log_poly, types.npytypes.Array): def impl(log_poly, s): return log_poly[s] @@ -55,8 +63,17 @@ def impl(log_poly, s): return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 + return spectra_impl(I, coeffs, log_poly, ref_freq, frequency) + + +def spectra_impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 + raise NotImplementedError + + +@overload(spectra_impl, jit_option=JIT_OPTIONS) +def nb_spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (I, coeffs, ref_freq, frequency)) dtype = np.result_type(*arg_dtypes) diff --git a/africanus/rime/dask_predict.py b/africanus/rime/dask_predict.py index 98e56e3ee..e25b2ba13 100644 --- a/africanus/rime/dask_predict.py +++ b/africanus/rime/dask_predict.py @@ -17,7 +17,7 @@ predict_vis as np_predict_vis) from africanus.rime.wsclean_predict import ( WSCLEAN_PREDICT_DOCS, - wsclean_predict_impl as wsclean_predict_body) + wsclean_predict_main as wsclean_predict_body) from africanus.model.wsclean.spec_model import spectra as wsclean_spectra diff --git a/africanus/rime/phase.py b/africanus/rime/phase.py index 90970c2a3..8fce52a7e 100644 --- a/africanus/rime/phase.py +++ b/africanus/rime/phase.py @@ -4,12 +4,21 @@ from africanus.constants import minus_two_pi_over_c from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit +from africanus.util.numba import JIT_OPTIONS, overload, njit from africanus.util.type_inference import infer_complex_dtype -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def phase_delay(lm, uvw, frequency, convention='fourier'): + return phase_delay_impl(lm, uvw, frequency, convention=convention) + + +def phase_delay_impl(lm, uvw, frequency, convention='fourier'): + raise NotImplementedError + + +@overload(phase_delay_impl, jit_options=JIT_OPTIONS) +def nb_phase_delay(lm, uvw, frequency, convention='fourier'): # Bake constants in with the correct type one = lm.dtype(1.0) zero = lm.dtype(0.0) diff --git a/africanus/rime/predict.py b/africanus/rime/predict.py index a9eae8cea..cc7e13e13 100644 --- a/africanus/rime/predict.py +++ b/africanus/rime/predict.py @@ -4,7 +4,8 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, generated_jit, njit +from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, + njit, overload) JONES_NOT_PRESENT = 0 @@ -413,10 +414,29 @@ def predict_checks(time_index, antenna1, antenna2, have_dies1, have_bvis, have_dies2) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def predict_vis(time_index, antenna1, antenna2, dde1_jones=None, source_coh=None, dde2_jones=None, die1_jones=None, base_vis=None, die2_jones=None): + return predict_vis_impl(time_index, antenna1, antenna2, + dde1_jones=dde1_jones, + source_coh=source_coh, + dde2_jones=dde2_jones, + die1_jones=die1_jones, + base_vis=base_vis, + die2_jones=die2_jones) + + +def predict_vis_impl(time_index, antenna1, antenna2, + dde1_jones=None, source_coh=None, dde2_jones=None, + die1_jones=None, base_vis=None, die2_jones=None): + raise NotImplementedError + + +@overload(predict_vis_impl, jit_options=JIT_OPTIONS) +def nb_predict_vis(time_index, antenna1, antenna2, + dde1_jones=None, source_coh=None, dde2_jones=None, + die1_jones=None, base_vis=None, die2_jones=None): tup = predict_checks(time_index, antenna1, antenna2, dde1_jones, source_coh, dde2_jones, @@ -490,9 +510,21 @@ def _predict_vis_fn(time_index, antenna1, antenna2, return _predict_vis_fn -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def apply_gains(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): + return apply_gains_impl(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones) + + +def apply_gains_impl(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones): + raise NotImplementedError + + +@overload(apply_gains_impl, jit_options=JIT_OPTIONS) +def nb_apply_gains(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones): def impl(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): diff --git a/africanus/rime/wsclean_predict.py b/africanus/rime/wsclean_predict.py index b200cd145..c490dbc07 100644 --- a/africanus/rime/wsclean_predict.py +++ b/africanus/rime/wsclean_predict.py @@ -4,12 +4,12 @@ from africanus.constants import two_pi_over_c, c as lightspeed from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, jit +from africanus.util.numba import JIT_OPTIONS, overload, njit from africanus.model.wsclean.spec_model import spectra -@jit(nopython=True, nogil=True, cache=True) -def wsclean_predict_impl(uvw, lm, source_type, gauss_shape, +@njit(**JIT_OPTIONS) +def wsclean_predict_main(uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype): fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) @@ -86,9 +86,21 @@ def wsclean_predict_impl(uvw, lm, source_type, gauss_shape, return vis -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def wsclean_predict(uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency): + return wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency) + + +def wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency): + raise NotImplementedError + + +@overload(wsclean_predict_impl, jit_options=JIT_OPTIONS) +def nb_wsclean_predict(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency): arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (uvw, lm, flux, coeffs, ref_freq, frequency)) dtype = np.result_type(np.complex64, *arg_dtypes) @@ -96,7 +108,7 @@ def wsclean_predict(uvw, lm, source_type, flux, coeffs, def impl(uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency): spectrum = spectra(flux, coeffs, log_poly, ref_freq, frequency) - return wsclean_predict_impl(uvw, lm, source_type, gauss_shape, + return wsclean_predict_main(uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype) return impl diff --git a/africanus/util/numba.py b/africanus/util/numba.py index bd59b5601..ada203069 100644 --- a/africanus/util/numba.py +++ b/africanus/util/numba.py @@ -6,6 +6,11 @@ from africanus.util.docs import on_rtd +JIT_OPTIONS = { + "cache": True, + "nogil": True, +} + if on_rtd(): # Fake decorators when on readthedocs def _fake_decorator(*args, **kwargs): @@ -18,7 +23,6 @@ def wrapper(*args, **kwargs): return decorator cfunc = _fake_decorator - generated_jit = _fake_decorator jit = _fake_decorator njit = _fake_decorator stencil = _fake_decorator @@ -27,7 +31,7 @@ def wrapper(*args, **kwargs): register_jitable = _fake_decorator intrinsic = _fake_decorator else: - from numba import cfunc, jit, njit, generated_jit, stencil # noqa + from numba import cfunc, jit, njit, stencil # noqa from numba.extending import overload, register_jitable, intrinsic # noqa diff --git a/docs/experimental.rst b/docs/experimental.rst index 119d8ded1..8b5f50d63 100644 --- a/docs/experimental.rst +++ b/docs/experimental.rst @@ -178,8 +178,7 @@ defined on the `Phase` term, called `init_fields`. Additionally, these arrays will be stored on the ``state`` object provided to the sampling function. -2. It supports reasoning about Numba types in a manner - similar to :func:`numba.generated_jit`. +2. It supports reasoning about Numba types. The ``lm``, ``uvw`` and ``chan_freq`` arguments contain the Numba types of the variables supplied to the RIME, while the ``typingctx`` argument contains a Numba diff --git a/setup.py b/setup.py index 077cc40d1..fce7cadfd 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,7 @@ # astropy breaks with numpy 1.15.3 # https://github.com/astropy/astropy/issues/7943 "numpy >= 1.14.0, != 1.15.3", - # https://github.com/ratt-ru/codex-africanus/issues/283 - "numba >= 0.53.1, < 0.59" + "numba >= 0.53.1" ] extras_require = {