From b9a27e540caedf7791634ce86211979be3e1f180 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 25 Jan 2024 16:48:36 +0200 Subject: [PATCH 01/18] Relax DFT precision error somewhat --- africanus/dft/tests/test_dft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/africanus/dft/tests/test_dft.py b/africanus/dft/tests/test_dft.py index 53915f5b..ac445d82 100644 --- a/africanus/dft/tests/test_dft.py +++ b/africanus/dft/tests/test_dft.py @@ -161,7 +161,7 @@ def test_adjointness(): RHS = (RH(gamma_vis, uvw, lm, frequency, flag).reshape( size_im, 1).T.dot(gamma_im.reshape(size_im, 1))).real - assert np.abs(LHS - RHS) < 1e-14 + assert np.abs(LHS - RHS) < 1e-13 def test_vis_to_im_flagged(): From ba32c64243695480e26abe8bfa67dadf0e370883 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 09:16:14 +0200 Subject: [PATCH 02/18] Initial commit --- africanus/averaging/bda_avg.py | 186 ++++++++++++------ africanus/averaging/bda_mapping.py | 136 ++++++++----- africanus/averaging/shared.py | 10 +- africanus/averaging/support.py | 20 +- .../averaging/tests/test_bda_averaging.py | 2 +- africanus/averaging/tests/test_bda_mapping.py | 74 +++---- africanus/averaging/time_and_channel_avg.py | 85 +++++++- .../averaging/time_and_channel_mapping.py | 16 +- .../calibration/phase_only/phase_only.py | 42 +++- .../utils/compute_and_corrupt_vis.py | 16 +- africanus/calibration/utils/correct_vis.py | 15 +- africanus/calibration/utils/corrupt_vis.py | 14 +- africanus/calibration/utils/residual_vis.py | 14 +- africanus/coordinates/coordinates.py | 41 +++- africanus/dft/kernels.py | 27 ++- africanus/experimental/rime/fused/core.py | 39 ++-- africanus/model/shape/gaussian_shape.py | 12 +- africanus/model/spectral/spec_model.py | 13 +- africanus/model/wsclean/spec_model.py | 21 +- africanus/util/numba.py | 5 + 20 files changed, 561 insertions(+), 227 deletions(-) diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index 3303a54c..39877f8e 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -10,7 +10,9 @@ merge_flags, vis_output_arrays) from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (generated_jit, +from africanus.util.numba import (njit, + overload, + JIT_OPTIONS, intrinsic, is_numba_type_none) @@ -20,11 +22,24 @@ RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_average(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): + return row_average_impl(meta, ant1, ant2, flag_row=flag_row, + time_centroid=time_centroid, exposure=exposure, uvw=uvw, + weight=weight, sigma=sigma) +def row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): + return NotImplementedError + + +@overload(row_average_impl, jit_options=JIT_OPTIONS) +def nb_row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): have_flag_row = not is_numba_type_none(flag_row) have_time_centroid = not is_numba_type_none(time_centroid) have_exposure = not is_numba_type_none(exposure) @@ -310,13 +325,36 @@ def codegen(context, builder, signature, args): return sig, codegen -@generated_jit(nopython=True, nogil=True, cache=True) + + +@njit(**JIT_OPTIONS) def row_chan_average(meta, flag_row=None, weight=None, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None): + return row_chan_average_impl(meta, flag_row=flag_row, weight=weight, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) + + +def row_chan_average_impl(meta, flag_row=None, weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): + + return NotImplementedError + +@overload(row_chan_average_impl, jit_options=JIT_OPTIONS) +def nb_row_chan_average(meta, flag_row=None, weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): + have_vis = not is_numba_type_none(visibilities) have_flag = not is_numba_type_none(flag) have_flag_row = not is_numba_type_none(flag_row) @@ -523,7 +561,7 @@ def impl(meta, flag_row=None, weight=None, _rowchan_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def bda(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, @@ -535,65 +573,89 @@ def bda(time, interval, antenna1, antenna2, decorrelation=0.98, time_bin_secs=None, min_nchan=1): - def impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): - # Merge flag_row and flag arrays - flag_row = merge_flags(flag_row, flag) - - meta = bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) - - row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 - time_centroid, exposure, uvw, - weight=weight, sigma=sigma) - - row_chan_avg = row_chan_average(meta, # noqa: F841 - flag_row=flag_row, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) - - # Have to explicitly write it out because numba tuples - # are highly constrained types - return AverageOutput(meta.map, - meta.offsets, - meta.decorr_chan_width, - meta.time, - meta.interval, - meta.chan_width, - meta.flag_row, - row_avg.antenna1, - row_avg.antenna2, - row_avg.time_centroid, - row_avg.exposure, - row_avg.uvw, - row_avg.weight, - row_avg.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_avg.visibilities, - row_chan_avg.flag, - row_chan_avg.weight_spectrum, - row_chan_avg.sigma_spectrum) - return impl + return bda_impl(time, interval, antenna1, antenna2, + time_centroid=time_centroid, exposure=exposure, flag_row=flag_row, + uvw=uvw, weight=weight, sigma=sigma, + chan_freq=chan_freq, chan_width=chan_width, + effective_bw=effective_bw, resolution=resolution, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, + max_uvw_dist=max_uvw_dist, max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, min_nchan=min_nchan) + +def bda_impl(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + max_uvw_dist=None, max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): + return NotImplementedError + +@overload(bda_impl, jit_options=JIT_OPTIONS) +def nb_bda_impl(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + max_uvw_dist=None, max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): + # Merge flag_row and flag arrays + flag_row = merge_flags(flag_row, flag) + + meta = bda_mapper(time, interval, antenna1, antenna2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan) + + row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 + time_centroid, exposure, uvw, + weight=weight, sigma=sigma) + + row_chan_avg = row_chan_average(meta, # noqa: F841 + flag_row=flag_row, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) + + # Have to explicitly write it out because numba tuples + # are highly constrained types + return AverageOutput(meta.map, + meta.offsets, + meta.decorr_chan_width, + meta.time, + meta.interval, + meta.chan_width, + meta.flag_row, + row_avg.antenna1, + row_avg.antenna2, + row_avg.time_centroid, + row_avg.exposure, + row_avg.uvw, + row_avg.weight, + row_avg.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_avg.visibilities, + row_chan_avg.flag, + row_chan_avg.weight_spectrum, + row_chan_avg.sigma_spectrum) BDA_DOCS = DocstringTemplate(""" diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index 34d5de62..c1b870a8 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -8,7 +8,11 @@ import numba.types from africanus.constants import c as lightspeed -from africanus.util.numba import generated_jit, njit, is_numba_type_none +from africanus.util.numba import ( + JIT_OPTIONS, + overload, + njit, + is_numba_type_none) from africanus.averaging.support import unique_time, unique_baselines @@ -16,71 +20,71 @@ class RowMapperError(Exception): pass -@njit(nogil=True, cache=True) -def erf26(x): - """Implements 7.1.26 erf approximation from Abramowitz and - Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7.""" +# @njit(nogil=True, cache=True) +# def erf26(x): +# """Implements 7.1.26 erf approximation from Abramowitz and +# Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7.""" - # Constants - p = 0.3275911 - a1 = 0.254829592 - a2 = -0.284496736 - a3 = 1.421413741 - a4 = -1.453152027 - a5 = 1.061405429 - e = 2.718281828 +# # Constants +# p = 0.3275911 +# a1 = 0.254829592 +# a2 = -0.284496736 +# a3 = 1.421413741 +# a4 = -1.453152027 +# a5 = 1.061405429 +# e = 2.718281828 - # t - t = 1.0/(1.0 + (p * x)) +# # t +# t = 1.0/(1.0 + (p * x)) - # Erf calculation - erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) - erf *= e ** -(x ** 2) +# # Erf calculation +# erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) +# erf *= e ** -(x ** 2) - return -round(erf, 9) if x < 0 else round(erf, 0) +# return -round(erf, 9) if x < 0 else round(erf, 0) -@njit(nogil=True, cache=True) -def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength): - sidereal_rotation_rate = 7.292118516e-5 - diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2) - term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit - return 1.0 - 1.0645 * erf26(0.8326*term) / term +# @njit(nogil=True, cache=True) +# def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength): +# sidereal_rotation_rate = 7.292118516e-5 +# diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2) +# term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit +# return 1.0 - 1.0645 * erf26(0.8326*term) / term -_SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200) +# _SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200) -@njit(nogil=True, cache=True, inline='always') -def inv_sinc(sinc_x, tol=1e-12): - # Invalid input - if sinc_x > 1.0: - raise ValueError("sinc_x > 1.0") +# @njit(nogil=True, cache=True, inline='always') +# def inv_sinc(sinc_x, tol=1e-12): +# # Invalid input +# if sinc_x > 1.0: +# raise ValueError("sinc_x > 1.0") - # Initial guess from reversion of Taylor series - # https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx - x = t_pow = np.sqrt(6*np.abs((1 - sinc_x))) - t_squared = t_pow*t_pow +# # Initial guess from reversion of Taylor series +# # https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx +# x = t_pow = np.sqrt(6*np.abs((1 - sinc_x))) +# t_squared = t_pow*t_pow - for coeff in numba.literal_unroll(_SERIES_COEFFS): - t_pow *= t_squared - x += coeff * t_pow +# for coeff in numba.literal_unroll(_SERIES_COEFFS): +# t_pow *= t_squared +# x += coeff * t_pow - # Use Newton Raphson to go the rest of the way - # https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D - while True: - # evaluate delta between this iteration sinc(x) and original - sinx = np.sin(x) - š¯˛“sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x +# # Use Newton Raphson to go the rest of the way +# # https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D +# while True: +# # evaluate delta between this iteration sinc(x) and original +# sinx = np.sin(x) +# š¯˛“sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x - # Stop if converged - if np.abs(š¯˛“sinc_x) < tol: - break +# # Stop if converged +# if np.abs(š¯˛“sinc_x) < tol: +# break - # Next iteration - x -= (x*x * š¯˛“sinc_x) / (x*np.cos(x) - sinx) +# # Next iteration +# x -= (x*x * š¯˛“sinc_x) / (x*np.cos(x) - sinx) - return x +# return x @njit(nogil=True, cache=True, inline='always') @@ -338,7 +342,7 @@ def finalise_bin(self, auto_corr, uvw, time, interval, "time", "interval", "chan_width", "flag_row"]) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def bda_mapper(time, interval, ant1, ant2, uvw, chan_width, chan_freq, max_uvw_dist, @@ -347,7 +351,35 @@ def bda_mapper(time, interval, ant1, ant2, uvw, decorrelation=0.98, time_bin_secs=None, min_nchan=1): + return bda_mapper_impl(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan) + +def bda_mapper_impl(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): + return NotImplementedError + +@overload(bda_mapper_impl, jit_options=JIT_OPTIONS) +def nb_bda_mapper(time, interval, ant1, ant2, uvw, + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): have_time_bin_secs = not is_numba_type_none(time_bin_secs) Omitted = numba.types.misc.Omitted @@ -381,7 +413,7 @@ def bda_mapper(time, interval, ant1, ant2, uvw, ('max_chan_freq', chan_freq.dtype), ('max_uvw_dist', max_uvw_dist)] - JitBinner = jitclass(spec)(Binner) + JitBinner = st(spec)(Binner) def impl(time, interval, ant1, ant2, uvw, chan_width, chan_freq, diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index 5d2deef8..64e5e3c2 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -5,7 +5,7 @@ from africanus.util.numba import (is_numba_type_none, intrinsic, njit, - generated_jit, + JIT_OPTIONS, overload) @@ -13,11 +13,11 @@ def shape_or_invalid_shape(array, ndim): pass -# TODO(sjperkins) -# maybe replace with njit and inline='always' if -# https://github.com/numba/numba/issues/4693 is resolved -@generated_jit(nopython=True, nogil=True, cache=True) def merge_flags(flag_row, flag): + pass + +@overload(merge_flags, inline='always') +def _merge_flags(flag_row, flag): have_flag_row = not is_numba_type_none(flag_row) have_flag = not is_numba_type_none(flag) diff --git a/africanus/averaging/support.py b/africanus/averaging/support.py index d5e15f68..3c671c0a 100644 --- a/africanus/averaging/support.py +++ b/africanus/averaging/support.py @@ -4,7 +4,7 @@ import numpy as np import numba -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import JIT_OPTIONS, overload, njit @njit(nogil=True, cache=True) @@ -53,8 +53,15 @@ def _unique_internal(data): return aux[mask], perm[mask], inv_idx, np.diff(np.array(counts)) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def unique_time(time): + return unique_time_impl(time) + +def unique_time_impl(time): + return NotImplementedError + +@overload(unique_time_impl, jit_options=JIT_OPTIONS) +def nb_unique_time(time): """ Return unique time, inverse index and counts """ if time.dtype not in (numba.float32, numba.float64): raise ValueError("time must be floating point but is %s" % time.dtype) @@ -65,8 +72,15 @@ def impl(time): return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def unique_baselines(ant1, ant2): + return unique_baselines_impl(ant1, ant2) + +def unique_baselines_impl(ant1, ant2): + return NotImplementedError + +@overload(unique_baselines_impl, jit_options=JIT_OPTIONS) +def nb_unique_baselines(ant1, ant2): """ Return unique baselines, inverse index and counts """ if not ant1.dtype == numba.int32 or not ant2.dtype == numba.int32: # Need these to be int32 for the bl_32bit.view(np.int64) trick diff --git a/africanus/averaging/tests/test_bda_averaging.py b/africanus/averaging/tests/test_bda_averaging.py index 549048f6..7ea2688c 100644 --- a/africanus/averaging/tests/test_bda_averaging.py +++ b/africanus/averaging/tests/test_bda_averaging.py @@ -212,7 +212,7 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): assert_array_equal(row_avg.exposure, out_interval) assert_array_equal(row_avg.uvw, out_uvw) assert_array_equal(row_avg.weight, out_weight) - assert_array_equal(row_avg.sigma, out_sigma) + assert_array_almost_equal(row_avg.sigma, out_sigma) vshape = (in_row, in_chan, in_corr) vis = rs.normal(size=vshape) + rs.normal(size=vshape)*1j diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index fc050869..b0ed2e65 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -172,43 +172,43 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): return ant1, ant2, uvw -@pytest.mark.parametrize("decorrelation", [0.95]) -@pytest.mark.parametrize("min_nchan", [1]) -def test_bda_mapper(time, synthesized_uvw, interval, - chan_freq, chan_width, - decorrelation, min_nchan): - time = np.unique(time) - ant1, ant2, uvw = synthesized_uvw - - nbl = ant1.shape[0] - ntime = time.shape[0] - - time = np.repeat(time, nbl) - interval = np.repeat(interval, nbl) - ant1 = np.tile(ant1, ntime) - ant2 = np.tile(ant2, ntime) - flag_row = np.zeros(time.shape[0], dtype=np.int8) - - max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() - - row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=3.0, - decorrelation=decorrelation, - min_nchan=min_nchan) - - offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) - assert_array_equal(offsets, row_meta.offsets[:-1]) - assert row_meta.map.max() + 1 == row_meta.offsets[-1] - - # NUM_CHAN divides number of channels exactly - num_chan = np.diff(row_meta.offsets) - _, remainder = np.divmod(chan_width.shape[0], num_chan) - assert np.all(remainder == 0) - decorr_cw = chan_width.sum() / num_chan - assert_array_equal(decorr_cw, row_meta.decorr_chan_width) +# @pytest.mark.parametrize("decorrelation", [0.95]) +# @pytest.mark.parametrize("min_nchan", [1]) +# def test_bda_mapper(time, synthesized_uvw, interval, +# chan_freq, chan_width, +# decorrelation, min_nchan): +# time = np.unique(time) +# ant1, ant2, uvw = synthesized_uvw + +# nbl = ant1.shape[0] +# ntime = time.shape[0] + +# time = np.repeat(time, nbl) +# interval = np.repeat(interval, nbl) +# ant1 = np.tile(ant1, ntime) +# ant2 = np.tile(ant2, ntime) +# flag_row = np.zeros(time.shape[0], dtype=np.int8) + +# max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() + +# row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 +# chan_width, chan_freq, +# max_uvw_dist, +# flag_row=flag_row, +# max_fov=3.0, +# decorrelation=decorrelation, +# min_nchan=min_nchan) + +# offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) +# assert_array_equal(offsets, row_meta.offsets[:-1]) +# assert row_meta.map.max() + 1 == row_meta.offsets[-1] + +# # NUM_CHAN divides number of channels exactly +# num_chan = np.diff(row_meta.offsets) +# _, remainder = np.divmod(chan_width.shape[0], num_chan) +# assert np.all(remainder == 0) +# decorr_cw = chan_width.sum() / num_chan +# assert_array_equal(decorr_cw, row_meta.decorr_chan_width) def test_bda_binner(time, interval, synthesized_uvw, diff --git a/africanus/averaging/time_and_channel_avg.py b/africanus/averaging/time_and_channel_avg.py index a1c4aeb8..08dc199d 100644 --- a/africanus/averaging/time_and_channel_avg.py +++ b/africanus/averaging/time_and_channel_avg.py @@ -13,7 +13,7 @@ vis_output_arrays) from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (is_numba_type_none, generated_jit, +from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, njit, overload, intrinsic) TUPLE_TYPE = 0 @@ -59,8 +59,21 @@ def chan_add(output, input, orow, ochan, irow, ichan, corr): RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_average(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): + return row_average_impl(meta, ant1, ant2, flag_row=flag_row, + time_centroid=time_centroid, exposure=exposure, + uvw=uvw, weight=weight, sigma=sigma) + +def row_average_impl(meta, ant1, ant2, flag_row=None, + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): + return NotImplementedError + +@overload(row_average_impl, jit_options=JIT_OPTIONS) +def nb_row_average(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): @@ -317,11 +330,27 @@ def codegen(context, builder, signature, args): return sig, codegen -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None): + return row_chan_average_impl(row_meta, chan_meta, + flag_row=flag_row, weight=weight, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum) + +def row_chan_average_impl(row_meta, chan_meta, + flag_row=None, weight=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None): + return NotImplementedError + +@overload(row_chan_average_impl, jit_options=JIT_OPTIONS) +def nb_row_chan_average(row_meta, chan_meta, + flag_row=None, weight=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None): dummy_chan_freq = None dummy_chan_width = None @@ -521,10 +550,20 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, _chan_output_fields = ["chan_freq", "chan_width", "effective_bw", "resolution"] ChannelAverageOutput = namedtuple("ChannelAverageOutput", _chan_output_fields) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def chan_average(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): + return chan_average_impl(chan_meta, chan_freq=chan_freq, chan_width=chan_width, + effective_bw=effective_bw, resolution=resolution) + +def chan_average_impl(chan_meta, chan_freq=None, chan_width=None, + effective_bw=None, resolution=None): + + return NotImplementedError + +@overload(chan_average_impl, jit_options=JIT_OPTIONS) +def nb_chan_average(chan_meta, chan_freq=None, chan_width=None, + effective_bw=None, resolution=None): def impl(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): @@ -579,8 +618,7 @@ def impl(chan_meta, chan_freq=None, chan_width=None, _chan_output_fields + _rowchan_output_fields) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def time_and_channel(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, @@ -589,17 +627,46 @@ def time_and_channel(time, interval, antenna1, antenna2, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None, time_bin_secs=1.0, chan_bin_size=1): + return time_and_channel_impl(time, interval, antenna1, antenna2, + time_centroid=time_centroid, exposure=exposure, flag_row=flag_row, + uvw=uvw, weight=weight, sigma=sigma, + chan_freq=chan_freq, chan_width=chan_width, + effective_bw=effective_bw, resolution=resolution, + visibilities=visibilities, flag=flag, + weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, + time_bin_secs=time_bin_secs, chan_bin_size=chan_bin_size) + +def time_and_channel_impl(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + time_bin_secs=1.0, chan_bin_size=1): + return NotImplementedError + + +@overload(time_and_channel_impl, jit_options=JIT_OPTIONS) +def nb_time_and_channel(time, interval, antenna1, antenna2, + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + time_bin_secs=1.0, chan_bin_size=1): valid_types = (types.misc.Omitted, types.scalars.Float, types.scalars.Integer) if not isinstance(time_bin_secs, valid_types): - raise TypeError("time_bin_secs must be a scalar float") + raise TypeError(f"time_bin_secs ({time_bin_secs}) must be a scalar float") valid_types = (types.misc.Omitted, types.scalars.Integer) if not isinstance(chan_bin_size, valid_types): - raise TypeError("chan_bin_size must be a scalar integer") + raise TypeError(f"chan_bin_size ({chan_bin_size}) must be a scalar integer") def impl(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py index a3f2984a..9acd92d0 100644 --- a/africanus/averaging/time_and_channel_mapping.py +++ b/africanus/averaging/time_and_channel_mapping.py @@ -7,7 +7,7 @@ import numba from africanus.averaging.support import unique_time, unique_baselines -from africanus.util.numba import is_numba_type_none, generated_jit, njit, jit +from africanus.util.numba import is_numba_type_none, njit, jit, JIT_OPTIONS, overload class RowMapperError(Exception): @@ -55,8 +55,7 @@ def impl(flag_row, in_row, out_flag_row, out_row, flagged): RowMapOutput = namedtuple("RowMapOutput", ["map", "time", "interval", "flag_row"]) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): """ @@ -178,6 +177,17 @@ def row_mapper(time, interval, antenna1, antenna2, Raised if an illegal condition occurs """ + return row_mapper_impl(time, interval, antenna1, antenna2, + flag_row=flag_row, time_bin_secs=time_bin_secs) + +def row_mapper_impl(time, interval, antenna1, antenna2, + flag_row=None, time_bin_secs=1): + return NotImplementedError + + +@overload(row_mapper_impl, jit_options=JIT_OPTIONS) +def nb_row_mapper(time, interval, antenna1, antenna2, + flag_row=None, time_bin_secs=1): have_flag_row = not is_numba_type_none(flag_row) is_flagged_fn = is_flagged_factory(have_flag_row) diff --git a/africanus/calibration/phase_only/phase_only.py b/africanus/calibration/phase_only/phase_only.py index 8668de19..afdac6b6 100755 --- a/africanus/calibration/phase_only/phase_only.py +++ b/africanus/calibration/phase_only/phase_only.py @@ -3,7 +3,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate from africanus.calibration.utils import residual_vis, check_type -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -22,9 +22,21 @@ def jacobian(a1j, blj, a2j, sign, out): return njit(nogil=True, inline='always')(jacobian) -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): + return compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, residual, + model, flag) + +def compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): + return NotImplementedError + + +@overload(compute_jhj_and_jhr_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): mode = check_type(jones, residual) if mode != DIAG_DIAG: @@ -70,9 +82,20 @@ def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1, return _jhj_and_jhr_fn -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhj(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag): + return compute_jhj_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, model, flag) + +def compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, flag): + return NotImplementedError + + +@overload(compute_jhj_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhj(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, flag): mode = check_type(jones, model, vis_type='model') @@ -110,9 +133,20 @@ def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1, return _compute_jhj_fn -@generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) +@njit(**JIT_OPTIONS) def compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): + return compute_jhr_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, residual, + model, flag) + +def compute_jhr_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): + return NotImplementedError + +@overload(compute_jhr_impl, jit_options=JIT_OPTIONS) +def nb_compute_jhr(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, residual, model, flag): mode = check_type(jones, model, vis_type='model') diff --git a/africanus/calibration/utils/compute_and_corrupt_vis.py b/africanus/calibration/utils/compute_and_corrupt_vis.py index 6902c06b..73f6199d 100644 --- a/africanus/calibration/utils/compute_and_corrupt_vis.py +++ b/africanus/calibration/utils/compute_and_corrupt_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.constants import minus_two_pi_over_c as m2pioc from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -72,10 +72,20 @@ def jones_mul(a1j, model, a2j, uvw, freq, lm, out): return njit(nogil=True, inline='always')(jones_mul) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm): + return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, uvw, freq, lm) + +def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, uvw, freq, lm): + return NotImplementedError + + +@overload(compute_and_corrupt_vis_impl, jit_options=JIT_OPTIONS) +def mb_compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model, uvw, freq, lm): mode = check_type(jones, model, vis_type='model') jones_mul = jones_mul_factory(mode) diff --git a/africanus/calibration/utils/correct_vis.py b/africanus/calibration/utils/correct_vis.py index 2071507c..4918c85b 100644 --- a/africanus/calibration/utils/correct_vis.py +++ b/africanus/calibration/utils/correct_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -65,10 +65,19 @@ def jones_inverse_mul(a1j, blj, a2j, out): t4*b11 return njit(nogil=True, inline='always')(jones_inverse_mul) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def correct_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag): + return correct_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag) + +def correct_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag): + return NotImplementedError + +@overload(correct_vis_impl, jit_options=JIT_OPTIONS) +def nb_correct_vis(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, vis, flag): mode = check_type(jones, vis) jones_inverse_mul = jones_inverse_mul_factory(mode) diff --git a/africanus/calibration/utils/corrupt_vis.py b/africanus/calibration/utils/corrupt_vis.py index e1059899..79f1cb1e 100644 --- a/africanus/calibration/utils/corrupt_vis.py +++ b/africanus/calibration/utils/corrupt_vis.py @@ -2,7 +2,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -56,9 +56,19 @@ def jones_mul(a1j, model, a2j, out): return njit(nogil=True, inline='always')(jones_mul) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): + return corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model) + +def corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model): + return NotImplementedError + +@overload(corrupt_vis_impl, jit_options=JIT_OPTIONS) +def nb_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, model): mode = check_type(jones, model, vis_type='model') jones_mul = jones_mul_factory(mode) diff --git a/africanus/calibration/utils/residual_vis.py b/africanus/calibration/utils/residual_vis.py index bc985d8b..37fb36b0 100644 --- a/africanus/calibration/utils/residual_vis.py +++ b/africanus/calibration/utils/residual_vis.py @@ -3,7 +3,7 @@ import numpy as np from functools import wraps from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.calibration.utils import check_type from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL @@ -59,9 +59,19 @@ def subtract_model(a1j, blj, a2j, model, out): return njit(nogil=True, inline='always')(subtract_model) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model): + return residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model) + +def residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model): + return NotImplementedError + +@overload(residual_vis_impl, jit_options=JIT_OPTIONS) +def nb_residual_vis(time_bin_indices, time_bin_counts, antenna1, + antenna2, jones, vis, flag, model): mode = check_type(jones, vis) subtract_model = subtract_model_factory(mode) diff --git a/africanus/coordinates/coordinates.py b/africanus/coordinates/coordinates.py index c2b12738..ac3e66fe 100644 --- a/africanus/coordinates/coordinates.py +++ b/africanus/coordinates/coordinates.py @@ -4,7 +4,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, generated_jit, jit +from africanus.util.numba import is_numba_type_none, jit, JIT_OPTIONS, njit, overload from africanus.util.requirements import requires_optional try: @@ -24,9 +24,15 @@ def _create_phase_centre(phase_centre, dtype): def _return_phase_centre(phase_centre, dtype): return phase_centre - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def radec_to_lmn(radec, phase_centre=None): + return radec_to_lmn_impl(radec, phase_centre=phase_centre) + +def radec_to_lmn_impl(radec, phase_centre=None): + raise NotImplementedError + +@overload(radec_to_lmn_impl, jit_options=JIT_OPTIONS) +def nb_radec_to_lmn(radec, phase_centre=None): dtype = radec.dtype if is_numba_type_none(phase_centre): @@ -63,9 +69,15 @@ def _radec_to_lmn_impl(radec, phase_centre=None): return _radec_to_lmn_impl - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def radec_to_lm(radec, phase_centre=None): + return radec_to_lm_impl(radec, phase_centre=phase_centre) + +def radec_to_lm_impl(radec, phase_centre=None): + raise NotImplementedError + +@overload(radec_to_lm_impl, jit_options=JIT_OPTIONS) +def nb_radec_to_lm(radec, phase_centre=None): dtype = radec.dtype if is_numba_type_none(phase_centre): @@ -100,9 +112,15 @@ def _radec_to_lm_impl(radec, phase_centre=None): return _radec_to_lm_impl - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def lmn_to_radec(lmn, phase_centre=None): + return lmn_to_radec_impl(lmn, phase_centre=phase_centre) + +def lmn_to_radec_impl(lmn, phase_centre=None): + raise NotImplementedError + +@overload(lmn_to_radec_impl, jit_options=JIT_OPTIONS) +def nb_lmn_to_radec(lmn, phase_centre=None): dtype = lmn.dtype if is_numba_type_none(phase_centre): @@ -131,8 +149,15 @@ def _lmn_to_radec_impl(lmn, phase_centre=None): return _lmn_to_radec_impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def lm_to_radec(lm, phase_centre=None): + return lm_to_radec_impl(lm, phase_centre=phase_centre) + +def lm_to_radec_impl(lm, phase_centre=None): + raise NotImplementedError + +@overload(lm_to_radec_impl, jit_options=JIT_OPTIONS) +def nb_lm_to_radec(lm, phase_centre=None): dtype = lm.dtype if is_numba_type_none(phase_centre): diff --git a/africanus/dft/kernels.py b/africanus/dft/kernels.py index 893302ed..f722c709 100644 --- a/africanus/dft/kernels.py +++ b/africanus/dft/kernels.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from africanus.util.numba import is_numba_type_none, generated_jit +from africanus.util.numba import is_numba_type_none, njit, overload, JIT_OPTIONS from africanus.util.docs import doc_tuple_to_str from collections import namedtuple @@ -10,10 +10,19 @@ from africanus.constants import minus_two_pi_over_c, two_pi_over_c - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def im_to_vis(image, uvw, lm, frequency, convention='fourier', dtype=None): + return im_to_vis_impl(image, uvw, lm, frequency, + convention=convention, dtype=dtype) + +def im_to_vis_impl(image, uvw, lm, frequency, + convention='fourier', dtype=None): + raise NotImplementedError + +@overload(im_to_vis_impl, jit_options=JIT_OPTIONS) +def nb_im_to_vis(image, uvw, lm, frequency, + convention='fourier', dtype=None): # Infer complex output dtype if none provided if is_numba_type_none(dtype): out_dtype = np.result_type(np.complex64, @@ -62,9 +71,19 @@ def impl(image, uvw, lm, frequency, return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def vis_to_im(vis, uvw, lm, frequency, flags, convention='fourier', dtype=None): + return vis_to_im_impl(vis, uvw, lm, frequency, flags, + convention=convention, dtype=dtype) + +def vis_to_im_impl(vis, uvw, lm, frequency, flags, + convention='fourier', dtype=None): + raise NotImplementedError + +@overload(vis_to_im_impl, jit_options=JIT_OPTIONS) +def nb_vis_to_im(vis, uvw, lm, frequency, flags, + convention='fourier', dtype=None): # Infer output dtype if none provided if is_numba_type_none(dtype): # Support both real and complex visibilities... diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index d23341dd..4ec1393d 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -2,10 +2,11 @@ from collections import defaultdict import numba -from numba import generated_jit, types +from numba import types import numpy as np from africanus.util.patterns import Multiton +from africanus.util.numba import overload, njit, JIT_OPTIONS from africanus.experimental.rime.fused.arguments import ArgumentDependencies from africanus.experimental.rime.fused.intrinsics import IntrinsicFactory from africanus.experimental.rime.fused.specification import RimeSpecification @@ -29,23 +30,27 @@ def rime_impl_factory(terms, transformers, ncorr): - @generated_jit(nopython=True, nogil=True, cache=True) - def rime(names, *inargs): - if len(inargs) != 1 or not isinstance(inargs[0], types.BaseTuple): - raise TypeError(f"{inargs[0]} must be be a Tuple") + def rime_impl(*args): + raise NotImplementedError - if not isinstance(names, types.BaseTuple): - raise TypeError(f"{names} must be a Tuple of strings") + @njit(**JIT_OPTIONS) + def rime(*args): + return rime_impl(*args) - if len(names) != len(inargs[0]): - raise ValueError(f"len(names): {len(names)} " - f"!= {len(inargs[0])}") + @overload(rime_impl, jit_options=JIT_OPTIONS, prefer_literal=True) + def nb_rime(*args): + if not len(args) % 2 == 0: + raise TypeError(f"len(args) {len(args)} is not divisible by 2") + + argstart = len(args) // 2 + names = args[:argstart] + inargs = args[argstart:] if not all(isinstance(n, types.Literal) for n in names): - raise TypeError(f"{names} must be a Tuple of strings") + raise TypeError(f"{names} must be a Tuple of Literal strings") if not all(n.literal_type is types.unicode_type for n in names): - raise TypeError(f"{names} must be a Tuple of strings") + raise TypeError(f"{names} must be a Tuple of Literal strings") # Get literal argument names names = tuple(n.literal_value for n in names) @@ -65,8 +70,8 @@ def rime(names, *inargs): except ValueError as e: raise ValueError(f"{str(e)} is required") - def impl(names, *inargs): - args_opt_idx = pack_opts_indices(inargs) + def impl(*args): + args_opt_idx = pack_opts_indices(args[argstart:]) args = pack_transformed(args_opt_idx) state = term_state(args) @@ -193,10 +198,8 @@ def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): keys = (self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys()))) - return self.impl(keys, time, - antenna1, antenna2, - feed1, feed2, - *kwargs.values()) + args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values()) + return self.impl(*args) def consolidate_args(args, kw): diff --git a/africanus/model/shape/gaussian_shape.py b/africanus/model/shape/gaussian_shape.py index 8a5b0905..84f02acb 100644 --- a/africanus/model/shape/gaussian_shape.py +++ b/africanus/model/shape/gaussian_shape.py @@ -4,12 +4,18 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit +from africanus.util.numba import njit, overload, JIT_OPTIONS from africanus.constants import c as lightspeed - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def gaussian(uvw, frequency, shape_params): + return gaussian_impl(uvw, frequency, shape_params) + +def gaussian_impl(uvw, frequency, shape_params): + raise NotImplementedError + +@overload(gaussian_impl, jit_options=JIT_OPTIONS) +def nb_gaussian(uvw, frequency, shape_params): # https://en.wikipedia.org/wiki/Full_width_at_half_maximum fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) fwhminv = 1.0 / fwhm diff --git a/africanus/model/spectral/spec_model.py b/africanus/model/spectral/spec_model.py index f7b672f1..9c189cf8 100644 --- a/africanus/model/spectral/spec_model.py +++ b/africanus/model/spectral/spec_model.py @@ -4,10 +4,9 @@ from numba import types import numpy as np -from africanus.util.numba import generated_jit, njit +from africanus.util.numba import overload, JIT_OPTIONS, njit from africanus.util.docs import DocstringTemplate - def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:] @@ -94,9 +93,15 @@ def impl(array): return njit(nogil=True, cache=True)(impl) - -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def spectral_model(stokes, spi, ref_freq, frequency, base=0): + return spectral_model_impl(stokes, spi, ref_freq, frequency, base=base) + +def spectral_model_impl(stokes, spi, ref_freq, frequency, base=0): + raise NotImplementedError + +@overload(spectral_model_impl, jit_options=JIT_OPTIONS) +def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0): arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (stokes, spi, ref_freq, frequency)) dtype = np.result_type(*arg_dtypes) diff --git a/africanus/model/wsclean/spec_model.py b/africanus/model/wsclean/spec_model.py index acec935a..d0f57d90 100644 --- a/africanus/model/wsclean/spec_model.py +++ b/africanus/model/wsclean/spec_model.py @@ -2,7 +2,7 @@ from numba import types import numpy as np -from africanus.util.numba import generated_jit +from africanus.util.numba import JIT_OPTIONS, njit, overload from africanus.util.docs import DocstringTemplate @@ -26,8 +26,11 @@ def log_spectral_model(I, coeffs, log_poly, ref_freq, freq): # noqa: E741 return I[:, None] * np.exp(term.sum(axis=2)) -@generated_jit(nopython=True, nogil=True, cache=True) def _check_log_poly_shape(coeffs, log_poly): + raise NotImplementedError + +@overload(_check_log_poly_shape) +def overload_check_log_poly_shape(coeffs, log_poly): if isinstance(log_poly, types.npytypes.Array): def impl(coeffs, log_poly): if coeffs.shape[0] != log_poly.shape[0]: @@ -41,8 +44,11 @@ def impl(coeffs, log_poly): return impl -@generated_jit(nopython=True, nogil=True, cache=True) def _log_polynomial(log_poly, s): + raise NotImplementedError + +@overload(_log_polynomial) +def overload_log_polynomial(log_poly, s): if isinstance(log_poly, types.npytypes.Array): def impl(log_poly, s): return log_poly[s] @@ -55,8 +61,15 @@ def impl(log_poly, s): return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 + return spectra_impl(I, coeffs, log_poly, ref_freq, frequency) + +def spectra_impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 + raise NotImplementedError + +@overload(spectra_impl, jit_option=JIT_OPTIONS) +def nb_spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (I, coeffs, ref_freq, frequency)) dtype = np.result_type(*arg_dtypes) diff --git a/africanus/util/numba.py b/africanus/util/numba.py index bd59b560..2d0cc457 100644 --- a/africanus/util/numba.py +++ b/africanus/util/numba.py @@ -6,6 +6,11 @@ from africanus.util.docs import on_rtd +JIT_OPTIONS = { + "cache": True, + "nogil": True, +} + if on_rtd(): # Fake decorators when on readthedocs def _fake_decorator(*args, **kwargs): From 87aab3e6b8c346ea1812c0b2d2c36791a23748ca Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 13:35:10 +0200 Subject: [PATCH 03/18] Remove dead code --- africanus/averaging/bda_mapping.py | 66 ------------------------------ 1 file changed, 66 deletions(-) diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index c1b870a8..e5504c4f 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -20,72 +20,6 @@ class RowMapperError(Exception): pass -# @njit(nogil=True, cache=True) -# def erf26(x): -# """Implements 7.1.26 erf approximation from Abramowitz and -# Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7.""" - -# # Constants -# p = 0.3275911 -# a1 = 0.254829592 -# a2 = -0.284496736 -# a3 = 1.421413741 -# a4 = -1.453152027 -# a5 = 1.061405429 -# e = 2.718281828 - -# # t -# t = 1.0/(1.0 + (p * x)) - -# # Erf calculation -# erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) -# erf *= e ** -(x ** 2) - -# return -round(erf, 9) if x < 0 else round(erf, 0) - - -# @njit(nogil=True, cache=True) -# def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength): -# sidereal_rotation_rate = 7.292118516e-5 -# diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2) -# term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit -# return 1.0 - 1.0645 * erf26(0.8326*term) / term - - -# _SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200) - - -# @njit(nogil=True, cache=True, inline='always') -# def inv_sinc(sinc_x, tol=1e-12): -# # Invalid input -# if sinc_x > 1.0: -# raise ValueError("sinc_x > 1.0") - -# # Initial guess from reversion of Taylor series -# # https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx -# x = t_pow = np.sqrt(6*np.abs((1 - sinc_x))) -# t_squared = t_pow*t_pow - -# for coeff in numba.literal_unroll(_SERIES_COEFFS): -# t_pow *= t_squared -# x += coeff * t_pow - -# # Use Newton Raphson to go the rest of the way -# # https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D -# while True: -# # evaluate delta between this iteration sinc(x) and original -# sinx = np.sin(x) -# š¯˛“sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x - -# # Stop if converged -# if np.abs(š¯˛“sinc_x) < tol: -# break - -# # Next iteration -# x -= (x*x * š¯˛“sinc_x) / (x*np.cos(x) - sinx) - -# return x - @njit(nogil=True, cache=True, inline='always') def factors(n): From 90888ade40af6c8393fe43542bc1ab5a5063f5cb Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 13:35:53 +0200 Subject: [PATCH 04/18] Disable caching of bda_mapper and re-enable its test case --- africanus/averaging/bda_mapping.py | 10 +-- africanus/averaging/tests/test_bda_mapping.py | 76 +++++++++---------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index e5504c4f..92db5f97 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -5,7 +5,7 @@ import numpy as np import numba from numba.experimental import jitclass -import numba.types +from numba import types from africanus.constants import c as lightspeed from africanus.util.numba import ( @@ -64,7 +64,7 @@ def max_chan_width(ref_freq, fractional_bandwidth): "nchan", "flag"]) -class Binner(object): +class Binner: def __init__(self, row_start, row_end, max_lm, decorrelation, time_bin_secs, max_chan_freq): @@ -305,7 +305,7 @@ def bda_mapper_impl(time, interval, ant1, ant2, uvw, min_nchan=1): return NotImplementedError -@overload(bda_mapper_impl, jit_options=JIT_OPTIONS) +@overload(bda_mapper_impl, jit_options={"nogil": True}) def nb_bda_mapper(time, interval, ant1, ant2, uvw, chan_width, chan_freq, max_uvw_dist, @@ -316,7 +316,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw, min_nchan=1): have_time_bin_secs = not is_numba_type_none(time_bin_secs) - Omitted = numba.types.misc.Omitted + Omitted = types.misc.Omitted decorr_type = (numba.typeof(decorrelation.value) if isinstance(decorrelation, Omitted) @@ -347,7 +347,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw, ('max_chan_freq', chan_freq.dtype), ('max_uvw_dist', max_uvw_dist)] - JitBinner = st(spec)(Binner) + JitBinner = jitclass(spec)(Binner) def impl(time, interval, ant1, ant2, uvw, chan_width, chan_freq, diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index b0ed2e65..58b1b8a5 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -5,7 +5,7 @@ import pytest from africanus.averaging.bda_mapping import bda_mapper, Binner - +from africanus.util.numba import njit @pytest.fixture(scope="session", params=[4096]) def nchan(request): @@ -172,43 +172,43 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): return ant1, ant2, uvw -# @pytest.mark.parametrize("decorrelation", [0.95]) -# @pytest.mark.parametrize("min_nchan", [1]) -# def test_bda_mapper(time, synthesized_uvw, interval, -# chan_freq, chan_width, -# decorrelation, min_nchan): -# time = np.unique(time) -# ant1, ant2, uvw = synthesized_uvw - -# nbl = ant1.shape[0] -# ntime = time.shape[0] - -# time = np.repeat(time, nbl) -# interval = np.repeat(interval, nbl) -# ant1 = np.tile(ant1, ntime) -# ant2 = np.tile(ant2, ntime) -# flag_row = np.zeros(time.shape[0], dtype=np.int8) - -# max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() - -# row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 -# chan_width, chan_freq, -# max_uvw_dist, -# flag_row=flag_row, -# max_fov=3.0, -# decorrelation=decorrelation, -# min_nchan=min_nchan) - -# offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) -# assert_array_equal(offsets, row_meta.offsets[:-1]) -# assert row_meta.map.max() + 1 == row_meta.offsets[-1] - -# # NUM_CHAN divides number of channels exactly -# num_chan = np.diff(row_meta.offsets) -# _, remainder = np.divmod(chan_width.shape[0], num_chan) -# assert np.all(remainder == 0) -# decorr_cw = chan_width.sum() / num_chan -# assert_array_equal(decorr_cw, row_meta.decorr_chan_width) +@pytest.mark.parametrize("decorrelation", [0.95]) +@pytest.mark.parametrize("min_nchan", [1]) +def test_bda_mapper(time, synthesized_uvw, interval, + chan_freq, chan_width, + decorrelation, min_nchan): + time = np.unique(time) + ant1, ant2, uvw = synthesized_uvw + + nbl = ant1.shape[0] + ntime = time.shape[0] + + time = np.repeat(time, nbl) + interval = np.repeat(interval, nbl) + ant1 = np.tile(ant1, ntime) + ant2 = np.tile(ant2, ntime) + flag_row = np.zeros(time.shape[0], dtype=np.int8) + + max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() + + row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=3.0, + decorrelation=decorrelation, + min_nchan=min_nchan) + + offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) + assert_array_equal(offsets, row_meta.offsets[:-1]) + assert row_meta.map.max() + 1 == row_meta.offsets[-1] + + # NUM_CHAN divides number of channels exactly + num_chan = np.diff(row_meta.offsets) + _, remainder = np.divmod(chan_width.shape[0], num_chan) + assert np.all(remainder == 0) + decorr_cw = chan_width.sum() / num_chan + assert_array_equal(decorr_cw, row_meta.decorr_chan_width) def test_bda_binner(time, interval, synthesized_uvw, From 4863069ce2123193a932496d604dc1f4b335349c Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 13:50:04 +0200 Subject: [PATCH 05/18] Further changes and autopep8 --- africanus/averaging/bda_avg.py | 118 +++++++++--------- africanus/averaging/bda_mapping.py | 30 ++--- africanus/averaging/shared.py | 1 + africanus/averaging/support.py | 4 + africanus/averaging/tests/test_bda_mapping.py | 1 + africanus/averaging/time_and_channel_avg.py | 31 +++-- .../averaging/time_and_channel_mapping.py | 2 + .../calibration/phase_only/phase_only.py | 10 +- .../utils/compute_and_corrupt_vis.py | 4 +- africanus/calibration/utils/correct_vis.py | 5 +- africanus/calibration/utils/corrupt_vis.py | 4 +- africanus/calibration/utils/residual_vis.py | 2 + africanus/coordinates/coordinates.py | 11 ++ africanus/dft/kernels.py | 7 +- africanus/experimental/rime/fused/core.py | 3 +- africanus/model/shape/gaussian_shape.py | 3 + africanus/model/spectral/spec_model.py | 4 + africanus/model/wsclean/spec_model.py | 4 + africanus/rime/phase.py | 13 +- africanus/rime/predict.py | 33 ++++- africanus/rime/wsclean_predict.py | 22 +++- 21 files changed, 211 insertions(+), 101 deletions(-) diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index 39877f8e..ddcc022d 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -30,6 +30,7 @@ def row_average(meta, ant1, ant2, flag_row=None, time_centroid=time_centroid, exposure=exposure, uvw=uvw, weight=weight, sigma=sigma) + def row_average_impl(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): @@ -325,8 +326,6 @@ def codegen(context, builder, signature, args): return sig, codegen - - @njit(**JIT_OPTIONS) def row_chan_average(meta, flag_row=None, weight=None, visibilities=None, @@ -341,19 +340,20 @@ def row_chan_average(meta, flag_row=None, weight=None, def row_chan_average_impl(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): return NotImplementedError + @overload(row_chan_average_impl, jit_options=JIT_OPTIONS) def nb_row_chan_average(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None): have_vis = not is_numba_type_none(visibilities) have_flag = not is_numba_type_none(flag) @@ -585,42 +585,44 @@ def bda(time, interval, antenna1, antenna2, decorrelation=decorrelation, time_bin_secs=time_bin_secs, min_nchan=min_nchan) + def bda_impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + max_uvw_dist=None, max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): return NotImplementedError + @overload(bda_impl, jit_options=JIT_OPTIONS) def nb_bda_impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): + time_centroid=None, exposure=None, flag_row=None, + uvw=None, weight=None, sigma=None, + chan_freq=None, chan_width=None, + effective_bw=None, resolution=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None, + max_uvw_dist=None, max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): # Merge flag_row and flag arrays flag_row = merge_flags(flag_row, flag) meta = bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan) row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 time_centroid, exposure, uvw, @@ -635,27 +637,27 @@ def nb_bda_impl(time, interval, antenna1, antenna2, # Have to explicitly write it out because numba tuples # are highly constrained types return AverageOutput(meta.map, - meta.offsets, - meta.decorr_chan_width, - meta.time, - meta.interval, - meta.chan_width, - meta.flag_row, - row_avg.antenna1, - row_avg.antenna2, - row_avg.time_centroid, - row_avg.exposure, - row_avg.uvw, - row_avg.weight, - row_avg.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_avg.visibilities, - row_chan_avg.flag, - row_chan_avg.weight_spectrum, - row_chan_avg.sigma_spectrum) + meta.offsets, + meta.decorr_chan_width, + meta.time, + meta.interval, + meta.chan_width, + meta.flag_row, + row_avg.antenna1, + row_avg.antenna2, + row_avg.time_centroid, + row_avg.exposure, + row_avg.uvw, + row_avg.weight, + row_avg.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_avg.visibilities, + row_chan_avg.flag, + row_chan_avg.weight_spectrum, + row_chan_avg.sigma_spectrum) BDA_DOCS = DocstringTemplate(""" diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index 92db5f97..a2abdbf7 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -20,7 +20,6 @@ class RowMapperError(Exception): pass - @njit(nogil=True, cache=True, inline='always') def factors(n): assert n >= 1 @@ -296,24 +295,25 @@ def bda_mapper(time, interval, ant1, ant2, uvw, def bda_mapper_impl(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): return NotImplementedError + @overload(bda_mapper_impl, jit_options={"nogil": True}) def nb_bda_mapper(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): + chan_width, chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1): have_time_bin_secs = not is_numba_type_none(time_bin_secs) Omitted = types.misc.Omitted diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index 64e5e3c2..d719c68c 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -16,6 +16,7 @@ def shape_or_invalid_shape(array, ndim): def merge_flags(flag_row, flag): pass + @overload(merge_flags, inline='always') def _merge_flags(flag_row, flag): have_flag_row = not is_numba_type_none(flag_row) diff --git a/africanus/averaging/support.py b/africanus/averaging/support.py index 3c671c0a..a4ca769e 100644 --- a/africanus/averaging/support.py +++ b/africanus/averaging/support.py @@ -57,9 +57,11 @@ def _unique_internal(data): def unique_time(time): return unique_time_impl(time) + def unique_time_impl(time): return NotImplementedError + @overload(unique_time_impl, jit_options=JIT_OPTIONS) def nb_unique_time(time): """ Return unique time, inverse index and counts """ @@ -76,9 +78,11 @@ def impl(time): def unique_baselines(ant1, ant2): return unique_baselines_impl(ant1, ant2) + def unique_baselines_impl(ant1, ant2): return NotImplementedError + @overload(unique_baselines_impl, jit_options=JIT_OPTIONS) def nb_unique_baselines(ant1, ant2): """ Return unique baselines, inverse index and counts """ diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index 58b1b8a5..af235e84 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -7,6 +7,7 @@ from africanus.averaging.bda_mapping import bda_mapper, Binner from africanus.util.numba import njit + @pytest.fixture(scope="session", params=[4096]) def nchan(request): return request.param diff --git a/africanus/averaging/time_and_channel_avg.py b/africanus/averaging/time_and_channel_avg.py index 08dc199d..7038306e 100644 --- a/africanus/averaging/time_and_channel_avg.py +++ b/africanus/averaging/time_and_channel_avg.py @@ -61,21 +61,23 @@ def chan_add(output, input, orow, ochan, irow, ichan, corr): @njit(**JIT_OPTIONS) def row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): return row_average_impl(meta, ant1, ant2, flag_row=flag_row, time_centroid=time_centroid, exposure=exposure, uvw=uvw, weight=weight, sigma=sigma) + def row_average_impl(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): return NotImplementedError + @overload(row_average_impl, jit_options=JIT_OPTIONS) def nb_row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): + time_centroid=None, exposure=None, uvw=None, + weight=None, sigma=None): have_flag_row = not is_numba_type_none(flag_row) flags_match = matching_flag_factory(have_flag_row) @@ -340,17 +342,19 @@ def row_chan_average(row_meta, chan_meta, visibilities=visibilities, flag=flag, weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum) + def row_chan_average_impl(row_meta, chan_meta, flag_row=None, weight=None, visibilities=None, flag=None, weight_spectrum=None, sigma_spectrum=None): return NotImplementedError + @overload(row_chan_average_impl, jit_options=JIT_OPTIONS) def nb_row_chan_average(row_meta, chan_meta, - flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None): + flag_row=None, weight=None, + visibilities=None, flag=None, + weight_spectrum=None, sigma_spectrum=None): dummy_chan_freq = None dummy_chan_width = None @@ -550,20 +554,23 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, _chan_output_fields = ["chan_freq", "chan_width", "effective_bw", "resolution"] ChannelAverageOutput = namedtuple("ChannelAverageOutput", _chan_output_fields) + @njit(**JIT_OPTIONS) def chan_average(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): return chan_average_impl(chan_meta, chan_freq=chan_freq, chan_width=chan_width, effective_bw=effective_bw, resolution=resolution) + def chan_average_impl(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): return NotImplementedError + @overload(chan_average_impl, jit_options=JIT_OPTIONS) def nb_chan_average(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None): + effective_bw=None, resolution=None): def impl(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): @@ -618,6 +625,7 @@ def impl(chan_meta, chan_freq=None, chan_width=None, _chan_output_fields + _rowchan_output_fields) + @njit(**JIT_OPTIONS) def time_and_channel(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, @@ -636,6 +644,7 @@ def time_and_channel(time, interval, antenna1, antenna2, weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, time_bin_secs=time_bin_secs, chan_bin_size=chan_bin_size) + def time_and_channel_impl(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, uvw=None, weight=None, sigma=None, @@ -661,12 +670,14 @@ def nb_time_and_channel(time, interval, antenna1, antenna2, types.scalars.Integer) if not isinstance(time_bin_secs, valid_types): - raise TypeError(f"time_bin_secs ({time_bin_secs}) must be a scalar float") + raise TypeError( + f"time_bin_secs ({time_bin_secs}) must be a scalar float") valid_types = (types.misc.Omitted, types.scalars.Integer) if not isinstance(chan_bin_size, valid_types): - raise TypeError(f"chan_bin_size ({chan_bin_size}) must be a scalar integer") + raise TypeError( + f"chan_bin_size ({chan_bin_size}) must be a scalar integer") def impl(time, interval, antenna1, antenna2, time_centroid=None, exposure=None, flag_row=None, diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py index 9acd92d0..7dbb6af0 100644 --- a/africanus/averaging/time_and_channel_mapping.py +++ b/africanus/averaging/time_and_channel_mapping.py @@ -55,6 +55,7 @@ def impl(flag_row, in_row, out_flag_row, out_row, flagged): RowMapOutput = namedtuple("RowMapOutput", ["map", "time", "interval", "flag_row"]) + @njit(**JIT_OPTIONS) def row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): @@ -180,6 +181,7 @@ def row_mapper(time, interval, antenna1, antenna2, return row_mapper_impl(time, interval, antenna1, antenna2, flag_row=flag_row, time_bin_secs=time_bin_secs) + def row_mapper_impl(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): return NotImplementedError diff --git a/africanus/calibration/phase_only/phase_only.py b/africanus/calibration/phase_only/phase_only.py index afdac6b6..f0c6d329 100755 --- a/africanus/calibration/phase_only/phase_only.py +++ b/africanus/calibration/phase_only/phase_only.py @@ -29,6 +29,7 @@ def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, antenna1, antenna2, jones, residual, model, flag) + def compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): return NotImplementedError @@ -36,7 +37,7 @@ def compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, antenna1, @overload(compute_jhj_and_jhr_impl, jit_options=JIT_OPTIONS) def nb_compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): + antenna2, jones, residual, model, flag): mode = check_type(jones, residual) if mode != DIAG_DIAG: @@ -88,6 +89,7 @@ def compute_jhj(time_bin_indices, time_bin_counts, antenna1, return compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag) + def compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag): return NotImplementedError @@ -95,7 +97,7 @@ def compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, @overload(compute_jhj_impl, jit_options=JIT_OPTIONS) def nb_compute_jhj(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): + antenna2, jones, model, flag): mode = check_type(jones, model, vis_type='model') @@ -140,10 +142,12 @@ def compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna1, antenna2, jones, residual, model, flag) + def compute_jhr_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): + antenna2, jones, residual, model, flag): return NotImplementedError + @overload(compute_jhr_impl, jit_options=JIT_OPTIONS) def nb_compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag): diff --git a/africanus/calibration/utils/compute_and_corrupt_vis.py b/africanus/calibration/utils/compute_and_corrupt_vis.py index 73f6199d..cd2b5ef8 100644 --- a/africanus/calibration/utils/compute_and_corrupt_vis.py +++ b/africanus/calibration/utils/compute_and_corrupt_vis.py @@ -72,12 +72,14 @@ def jones_mul(a1j, model, a2j, uvw, freq, lm, out): return njit(nogil=True, inline='always')(jones_mul) + @njit(**JIT_OPTIONS) def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm): return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm) + def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm): return NotImplementedError @@ -85,7 +87,7 @@ def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, @overload(compute_and_corrupt_vis_impl, jit_options=JIT_OPTIONS) def mb_compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, uvw, freq, lm): + antenna2, jones, model, uvw, freq, lm): mode = check_type(jones, model, vis_type='model') jones_mul = jones_mul_factory(mode) diff --git a/africanus/calibration/utils/correct_vis.py b/africanus/calibration/utils/correct_vis.py index 4918c85b..6e99cdea 100644 --- a/africanus/calibration/utils/correct_vis.py +++ b/africanus/calibration/utils/correct_vis.py @@ -65,19 +65,22 @@ def jones_inverse_mul(a1j, blj, a2j, out): t4*b11 return njit(nogil=True, inline='always')(jones_inverse_mul) + @njit(**JIT_OPTIONS) def correct_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag): return correct_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag) + def correct_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag): return NotImplementedError + @overload(correct_vis_impl, jit_options=JIT_OPTIONS) def nb_correct_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag): + antenna1, antenna2, jones, vis, flag): mode = check_type(jones, vis) jones_inverse_mul = jones_inverse_mul_factory(mode) diff --git a/africanus/calibration/utils/corrupt_vis.py b/africanus/calibration/utils/corrupt_vis.py index 79f1cb1e..191069f0 100644 --- a/africanus/calibration/utils/corrupt_vis.py +++ b/africanus/calibration/utils/corrupt_vis.py @@ -60,12 +60,14 @@ def jones_mul(a1j, model, a2j, out): def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): return corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model) + antenna2, jones, model) + def corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): return NotImplementedError + @overload(corrupt_vis_impl, jit_options=JIT_OPTIONS) def nb_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): diff --git a/africanus/calibration/utils/residual_vis.py b/africanus/calibration/utils/residual_vis.py index 37fb36b0..6f2d0c96 100644 --- a/africanus/calibration/utils/residual_vis.py +++ b/africanus/calibration/utils/residual_vis.py @@ -65,10 +65,12 @@ def residual_vis(time_bin_indices, time_bin_counts, antenna1, return residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model) + def residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model): return NotImplementedError + @overload(residual_vis_impl, jit_options=JIT_OPTIONS) def nb_residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model): diff --git a/africanus/coordinates/coordinates.py b/africanus/coordinates/coordinates.py index ac3e66fe..047f5b9c 100644 --- a/africanus/coordinates/coordinates.py +++ b/africanus/coordinates/coordinates.py @@ -24,13 +24,16 @@ def _create_phase_centre(phase_centre, dtype): def _return_phase_centre(phase_centre, dtype): return phase_centre + @njit(**JIT_OPTIONS) def radec_to_lmn(radec, phase_centre=None): return radec_to_lmn_impl(radec, phase_centre=phase_centre) + def radec_to_lmn_impl(radec, phase_centre=None): raise NotImplementedError + @overload(radec_to_lmn_impl, jit_options=JIT_OPTIONS) def nb_radec_to_lmn(radec, phase_centre=None): dtype = radec.dtype @@ -69,13 +72,16 @@ def _radec_to_lmn_impl(radec, phase_centre=None): return _radec_to_lmn_impl + @njit(**JIT_OPTIONS) def radec_to_lm(radec, phase_centre=None): return radec_to_lm_impl(radec, phase_centre=phase_centre) + def radec_to_lm_impl(radec, phase_centre=None): raise NotImplementedError + @overload(radec_to_lm_impl, jit_options=JIT_OPTIONS) def nb_radec_to_lm(radec, phase_centre=None): dtype = radec.dtype @@ -112,13 +118,16 @@ def _radec_to_lm_impl(radec, phase_centre=None): return _radec_to_lm_impl + @njit(**JIT_OPTIONS) def lmn_to_radec(lmn, phase_centre=None): return lmn_to_radec_impl(lmn, phase_centre=phase_centre) + def lmn_to_radec_impl(lmn, phase_centre=None): raise NotImplementedError + @overload(lmn_to_radec_impl, jit_options=JIT_OPTIONS) def nb_lmn_to_radec(lmn, phase_centre=None): dtype = lmn.dtype @@ -153,9 +162,11 @@ def _lmn_to_radec_impl(lmn, phase_centre=None): def lm_to_radec(lm, phase_centre=None): return lm_to_radec_impl(lm, phase_centre=phase_centre) + def lm_to_radec_impl(lm, phase_centre=None): raise NotImplementedError + @overload(lm_to_radec_impl, jit_options=JIT_OPTIONS) def nb_lm_to_radec(lm, phase_centre=None): dtype = lm.dtype diff --git a/africanus/dft/kernels.py b/africanus/dft/kernels.py index f722c709..267cff97 100644 --- a/africanus/dft/kernels.py +++ b/africanus/dft/kernels.py @@ -10,19 +10,22 @@ from africanus.constants import minus_two_pi_over_c, two_pi_over_c + @njit(**JIT_OPTIONS) def im_to_vis(image, uvw, lm, frequency, convention='fourier', dtype=None): return im_to_vis_impl(image, uvw, lm, frequency, convention=convention, dtype=dtype) + def im_to_vis_impl(image, uvw, lm, frequency, convention='fourier', dtype=None): raise NotImplementedError + @overload(im_to_vis_impl, jit_options=JIT_OPTIONS) def nb_im_to_vis(image, uvw, lm, frequency, - convention='fourier', dtype=None): + convention='fourier', dtype=None): # Infer complex output dtype if none provided if is_numba_type_none(dtype): out_dtype = np.result_type(np.complex64, @@ -77,10 +80,12 @@ def vis_to_im(vis, uvw, lm, frequency, flags, return vis_to_im_impl(vis, uvw, lm, frequency, flags, convention=convention, dtype=dtype) + def vis_to_im_impl(vis, uvw, lm, frequency, flags, convention='fourier', dtype=None): raise NotImplementedError + @overload(vis_to_im_impl, jit_options=JIT_OPTIONS) def nb_vis_to_im(vis, uvw, lm, frequency, flags, convention='fourier', dtype=None): diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 4ec1393d..04277227 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -198,7 +198,8 @@ def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): keys = (self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys()))) - args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values()) + args = keys + (time, antenna1, antenna2, feed1, + feed2) + tuple(kwargs.values()) return self.impl(*args) diff --git a/africanus/model/shape/gaussian_shape.py b/africanus/model/shape/gaussian_shape.py index 84f02acb..de341536 100644 --- a/africanus/model/shape/gaussian_shape.py +++ b/africanus/model/shape/gaussian_shape.py @@ -7,13 +7,16 @@ from africanus.util.numba import njit, overload, JIT_OPTIONS from africanus.constants import c as lightspeed + @njit(**JIT_OPTIONS) def gaussian(uvw, frequency, shape_params): return gaussian_impl(uvw, frequency, shape_params) + def gaussian_impl(uvw, frequency, shape_params): raise NotImplementedError + @overload(gaussian_impl, jit_options=JIT_OPTIONS) def nb_gaussian(uvw, frequency, shape_params): # https://en.wikipedia.org/wiki/Full_width_at_half_maximum diff --git a/africanus/model/spectral/spec_model.py b/africanus/model/spectral/spec_model.py index 9c189cf8..42d6b3cf 100644 --- a/africanus/model/spectral/spec_model.py +++ b/africanus/model/spectral/spec_model.py @@ -7,6 +7,7 @@ from africanus.util.numba import overload, JIT_OPTIONS, njit from africanus.util.docs import DocstringTemplate + def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:] @@ -93,13 +94,16 @@ def impl(array): return njit(nogil=True, cache=True)(impl) + @njit(**JIT_OPTIONS) def spectral_model(stokes, spi, ref_freq, frequency, base=0): return spectral_model_impl(stokes, spi, ref_freq, frequency, base=base) + def spectral_model_impl(stokes, spi, ref_freq, frequency, base=0): raise NotImplementedError + @overload(spectral_model_impl, jit_options=JIT_OPTIONS) def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0): arg_dtypes = tuple(np.dtype(a.dtype.name) for a diff --git a/africanus/model/wsclean/spec_model.py b/africanus/model/wsclean/spec_model.py index d0f57d90..ee51a199 100644 --- a/africanus/model/wsclean/spec_model.py +++ b/africanus/model/wsclean/spec_model.py @@ -29,6 +29,7 @@ def log_spectral_model(I, coeffs, log_poly, ref_freq, freq): # noqa: E741 def _check_log_poly_shape(coeffs, log_poly): raise NotImplementedError + @overload(_check_log_poly_shape) def overload_check_log_poly_shape(coeffs, log_poly): if isinstance(log_poly, types.npytypes.Array): @@ -47,6 +48,7 @@ def impl(coeffs, log_poly): def _log_polynomial(log_poly, s): raise NotImplementedError + @overload(_log_polynomial) def overload_log_polynomial(log_poly, s): if isinstance(log_poly, types.npytypes.Array): @@ -65,9 +67,11 @@ def impl(log_poly, s): def spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 return spectra_impl(I, coeffs, log_poly, ref_freq, frequency) + def spectra_impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 raise NotImplementedError + @overload(spectra_impl, jit_option=JIT_OPTIONS) def nb_spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 arg_dtypes = tuple(np.dtype(a.dtype.name) for a diff --git a/africanus/rime/phase.py b/africanus/rime/phase.py index 90970c2a..8fce52a7 100644 --- a/africanus/rime/phase.py +++ b/africanus/rime/phase.py @@ -4,12 +4,21 @@ from africanus.constants import minus_two_pi_over_c from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit +from africanus.util.numba import JIT_OPTIONS, overload, njit from africanus.util.type_inference import infer_complex_dtype -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def phase_delay(lm, uvw, frequency, convention='fourier'): + return phase_delay_impl(lm, uvw, frequency, convention=convention) + + +def phase_delay_impl(lm, uvw, frequency, convention='fourier'): + raise NotImplementedError + + +@overload(phase_delay_impl, jit_options=JIT_OPTIONS) +def nb_phase_delay(lm, uvw, frequency, convention='fourier'): # Bake constants in with the correct type one = lm.dtype(1.0) zero = lm.dtype(0.0) diff --git a/africanus/rime/predict.py b/africanus/rime/predict.py index a9eae8ce..8733acd7 100644 --- a/africanus/rime/predict.py +++ b/africanus/rime/predict.py @@ -4,7 +4,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, generated_jit, njit +from africanus.util.numba import is_numba_type_none, JIT_OPTIONS, njit, overload JONES_NOT_PRESENT = 0 @@ -413,10 +413,25 @@ def predict_checks(time_index, antenna1, antenna2, have_dies1, have_bvis, have_dies2) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def predict_vis(time_index, antenna1, antenna2, dde1_jones=None, source_coh=None, dde2_jones=None, die1_jones=None, base_vis=None, die2_jones=None): + return predict_vis_impl(time_index, antenna1, antenna2, + dde1_jones=dde1_jones, source_coh=source_coh, dde2_jones=dde2_jones, + die1_jones=die1_jones, base_vis=base_vis, die2_jones=die2_jones) + + +def predict_vis_impl(time_index, antenna1, antenna2, + dde1_jones=None, source_coh=None, dde2_jones=None, + die1_jones=None, base_vis=None, die2_jones=None): + raise NotImplementedError + + +@overload(predict_vis_impl, jit_options=JIT_OPTIONS) +def nb_predict_vis(time_index, antenna1, antenna2, + dde1_jones=None, source_coh=None, dde2_jones=None, + die1_jones=None, base_vis=None, die2_jones=None): tup = predict_checks(time_index, antenna1, antenna2, dde1_jones, source_coh, dde2_jones, @@ -490,9 +505,21 @@ def _predict_vis_fn(time_index, antenna1, antenna2, return _predict_vis_fn -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def apply_gains(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): + return apply_gains_impl(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones) + + +def apply_gains_impl(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones): + raise NotImplementedError + + +@overload(apply_gains_impl, jit_options=JIT_OPTIONS) +def nb_apply_gains(time_index, antenna1, antenna2, + die1_jones, corrupted_vis, die2_jones): def impl(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): diff --git a/africanus/rime/wsclean_predict.py b/africanus/rime/wsclean_predict.py index b200cd14..c490dbc0 100644 --- a/africanus/rime/wsclean_predict.py +++ b/africanus/rime/wsclean_predict.py @@ -4,12 +4,12 @@ from africanus.constants import two_pi_over_c, c as lightspeed from africanus.util.docs import DocstringTemplate -from africanus.util.numba import generated_jit, jit +from africanus.util.numba import JIT_OPTIONS, overload, njit from africanus.model.wsclean.spec_model import spectra -@jit(nopython=True, nogil=True, cache=True) -def wsclean_predict_impl(uvw, lm, source_type, gauss_shape, +@njit(**JIT_OPTIONS) +def wsclean_predict_main(uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype): fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) @@ -86,9 +86,21 @@ def wsclean_predict_impl(uvw, lm, source_type, gauss_shape, return vis -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(**JIT_OPTIONS) def wsclean_predict(uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency): + return wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency) + + +def wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency): + raise NotImplementedError + + +@overload(wsclean_predict_impl, jit_options=JIT_OPTIONS) +def nb_wsclean_predict(uvw, lm, source_type, flux, coeffs, + log_poly, ref_freq, gauss_shape, frequency): arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (uvw, lm, flux, coeffs, ref_freq, frequency)) dtype = np.result_type(np.complex64, *arg_dtypes) @@ -96,7 +108,7 @@ def wsclean_predict(uvw, lm, source_type, flux, coeffs, def impl(uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency): spectrum = spectra(flux, coeffs, log_poly, ref_freq, frequency) - return wsclean_predict_impl(uvw, lm, source_type, gauss_shape, + return wsclean_predict_main(uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype) return impl From 8fe3c73560890aa39886b53620c14956773df9f3 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 14:07:24 +0200 Subject: [PATCH 06/18] more flake8 fixes --- africanus/averaging/bda_avg.py | 15 +++++++------ africanus/averaging/shared.py | 1 - africanus/averaging/tests/test_bda_mapping.py | 1 - africanus/averaging/time_and_channel_avg.py | 22 +++++++++++++------ .../averaging/time_and_channel_mapping.py | 7 +++++- .../utils/compute_and_corrupt_vis.py | 5 +++-- africanus/coordinates/coordinates.py | 4 +++- africanus/dft/kernels.py | 3 ++- africanus/experimental/rime/fused/core.py | 1 - africanus/gridding/wgridder/vis2im.py | 4 ++-- africanus/rime/predict.py | 11 +++++++--- 11 files changed, 47 insertions(+), 27 deletions(-) diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index ddcc022d..12574a16 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -27,8 +27,8 @@ def row_average(meta, ant1, ant2, flag_row=None, time_centroid=None, exposure=None, uvw=None, weight=None, sigma=None): return row_average_impl(meta, ant1, ant2, flag_row=flag_row, - time_centroid=time_centroid, exposure=exposure, uvw=uvw, - weight=weight, sigma=sigma) + time_centroid=time_centroid, exposure=exposure, + uvw=uvw, weight=weight, sigma=sigma) def row_average_impl(meta, ant1, ant2, flag_row=None, @@ -575,12 +575,13 @@ def bda(time, interval, antenna1, antenna2, min_nchan=1): return bda_impl(time, interval, antenna1, antenna2, - time_centroid=time_centroid, exposure=exposure, flag_row=flag_row, - uvw=uvw, weight=weight, sigma=sigma, + time_centroid=time_centroid, exposure=exposure, + flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma, chan_freq=chan_freq, chan_width=chan_width, effective_bw=effective_bw, resolution=resolution, visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, max_uvw_dist=max_uvw_dist, max_fov=max_fov, decorrelation=decorrelation, time_bin_secs=time_bin_secs, min_nchan=min_nchan) @@ -625,8 +626,8 @@ def nb_bda_impl(time, interval, antenna1, antenna2, min_nchan=min_nchan) row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 - time_centroid, exposure, uvw, - weight=weight, sigma=sigma) + time_centroid, exposure, uvw, + weight=weight, sigma=sigma) row_chan_avg = row_chan_average(meta, # noqa: F841 flag_row=flag_row, diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index d719c68c..b53b43e9 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -5,7 +5,6 @@ from africanus.util.numba import (is_numba_type_none, intrinsic, njit, - JIT_OPTIONS, overload) diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index af235e84..fc050869 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -5,7 +5,6 @@ import pytest from africanus.averaging.bda_mapping import bda_mapper, Binner -from africanus.util.numba import njit @pytest.fixture(scope="session", params=[4096]) diff --git a/africanus/averaging/time_and_channel_avg.py b/africanus/averaging/time_and_channel_avg.py index 7038306e..208fd07e 100644 --- a/africanus/averaging/time_and_channel_avg.py +++ b/africanus/averaging/time_and_channel_avg.py @@ -340,7 +340,8 @@ def row_chan_average(row_meta, chan_meta, return row_chan_average_impl(row_meta, chan_meta, flag_row=flag_row, weight=weight, visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum) + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum) def row_chan_average_impl(row_meta, chan_meta, @@ -558,8 +559,10 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, @njit(**JIT_OPTIONS) def chan_average(chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None): - return chan_average_impl(chan_meta, chan_freq=chan_freq, chan_width=chan_width, - effective_bw=effective_bw, resolution=resolution) + return chan_average_impl(chan_meta, chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution) def chan_average_impl(chan_meta, chan_freq=None, chan_width=None, @@ -636,13 +639,18 @@ def time_and_channel(time, interval, antenna1, antenna2, weight_spectrum=None, sigma_spectrum=None, time_bin_secs=1.0, chan_bin_size=1): return time_and_channel_impl(time, interval, antenna1, antenna2, - time_centroid=time_centroid, exposure=exposure, flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma, chan_freq=chan_freq, chan_width=chan_width, - effective_bw=effective_bw, resolution=resolution, + effective_bw=effective_bw, + resolution=resolution, visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, - time_bin_secs=time_bin_secs, chan_bin_size=chan_bin_size) + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size) def time_and_channel_impl(time, interval, antenna1, antenna2, diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py index 7dbb6af0..31b287b8 100644 --- a/africanus/averaging/time_and_channel_mapping.py +++ b/africanus/averaging/time_and_channel_mapping.py @@ -7,7 +7,12 @@ import numba from africanus.averaging.support import unique_time, unique_baselines -from africanus.util.numba import is_numba_type_none, njit, jit, JIT_OPTIONS, overload +from africanus.util.numba import ( + is_numba_type_none, + njit, + jit, + JIT_OPTIONS, + overload) class RowMapperError(Exception): diff --git a/africanus/calibration/utils/compute_and_corrupt_vis.py b/africanus/calibration/utils/compute_and_corrupt_vis.py index cd2b5ef8..33ad85de 100644 --- a/africanus/calibration/utils/compute_and_corrupt_vis.py +++ b/africanus/calibration/utils/compute_and_corrupt_vis.py @@ -76,8 +76,9 @@ def jones_mul(a1j, model, a2j, uvw, freq, lm, out): @njit(**JIT_OPTIONS) def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm): - return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, uvw, freq, lm) + return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, + antenna1, antenna2, jones, model, + uvw, freq, lm) def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, diff --git a/africanus/coordinates/coordinates.py b/africanus/coordinates/coordinates.py index 047f5b9c..475691d3 100644 --- a/africanus/coordinates/coordinates.py +++ b/africanus/coordinates/coordinates.py @@ -4,7 +4,9 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, jit, JIT_OPTIONS, njit, overload +from africanus.util.numba import (is_numba_type_none, + jit, JIT_OPTIONS, + njit, overload) from africanus.util.requirements import requires_optional try: diff --git a/africanus/dft/kernels.py b/africanus/dft/kernels.py index 267cff97..000cbba1 100644 --- a/africanus/dft/kernels.py +++ b/africanus/dft/kernels.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from africanus.util.numba import is_numba_type_none, njit, overload, JIT_OPTIONS +from africanus.util.numba import (is_numba_type_none, njit, + overload, JIT_OPTIONS) from africanus.util.docs import doc_tuple_to_str from collections import namedtuple diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 04277227..5af73a10 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -44,7 +44,6 @@ def nb_rime(*args): argstart = len(args) // 2 names = args[:argstart] - inargs = args[argstart:] if not all(isinstance(n, types.Literal) for n in names): raise TypeError(f"{names} must be a Tuple of Literal strings") diff --git a/africanus/gridding/wgridder/vis2im.py b/africanus/gridding/wgridder/vis2im.py index 30155c56..752d376b 100644 --- a/africanus/gridding/wgridder/vis2im.py +++ b/africanus/gridding/wgridder/vis2im.py @@ -20,9 +20,9 @@ def _dirty_internal(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() nband = freq_bin_idx.size - if type(vis[0, 0]) == np.complex64: + if type(vis[0, 0]) is np.complex64: real_type = np.float32 - elif type(vis[0, 0]) == np.complex128: + elif type(vis[0, 0]) is np.complex128: real_type = np.float64 else: raise ValueError("Vis of incorrect type") diff --git a/africanus/rime/predict.py b/africanus/rime/predict.py index 8733acd7..cc7e13e1 100644 --- a/africanus/rime/predict.py +++ b/africanus/rime/predict.py @@ -4,7 +4,8 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import is_numba_type_none, JIT_OPTIONS, njit, overload +from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, + njit, overload) JONES_NOT_PRESENT = 0 @@ -418,8 +419,12 @@ def predict_vis(time_index, antenna1, antenna2, dde1_jones=None, source_coh=None, dde2_jones=None, die1_jones=None, base_vis=None, die2_jones=None): return predict_vis_impl(time_index, antenna1, antenna2, - dde1_jones=dde1_jones, source_coh=source_coh, dde2_jones=dde2_jones, - die1_jones=die1_jones, base_vis=base_vis, die2_jones=die2_jones) + dde1_jones=dde1_jones, + source_coh=source_coh, + dde2_jones=dde2_jones, + die1_jones=die1_jones, + base_vis=base_vis, + die2_jones=die2_jones) def predict_vis_impl(time_index, antenna1, antenna2, From da15d5a1b3d3d308cdedae3ebe5c60234c8fc8eb Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 14:57:57 +0200 Subject: [PATCH 07/18] Correctly provide a jitted merge_flags --- africanus/averaging/shared.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index b53b43e9..089bd9da 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -4,6 +4,7 @@ from africanus.util.numba import (is_numba_type_none, intrinsic, + JIT_OPTIONS, njit, overload) @@ -12,12 +13,17 @@ def shape_or_invalid_shape(array, ndim): pass +@njit(**JIT_OPTIONS) def merge_flags(flag_row, flag): - pass + return merge_flags_impl(flag_row, flag) + + +def merge_flags_impl(flag_row, flag): + raise NotImplementedError -@overload(merge_flags, inline='always') -def _merge_flags(flag_row, flag): +@overload(merge_flags_impl, jit_options=JIT_OPTIONS) +def nb_merge_flags(flag_row, flag): have_flag_row = not is_numba_type_none(flag_row) have_flag = not is_numba_type_none(flag) From a69f4fd80dfe9776dd6e51c1f73c740816cb1a2f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 14:58:14 +0200 Subject: [PATCH 08/18] Remove generated_jit from africanus.util.numba --- africanus/util/numba.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/africanus/util/numba.py b/africanus/util/numba.py index 2d0cc457..ada20306 100644 --- a/africanus/util/numba.py +++ b/africanus/util/numba.py @@ -23,7 +23,6 @@ def wrapper(*args, **kwargs): return decorator cfunc = _fake_decorator - generated_jit = _fake_decorator jit = _fake_decorator njit = _fake_decorator stencil = _fake_decorator @@ -32,7 +31,7 @@ def wrapper(*args, **kwargs): register_jitable = _fake_decorator intrinsic = _fake_decorator else: - from numba import cfunc, jit, njit, generated_jit, stencil # noqa + from numba import cfunc, jit, njit, stencil # noqa from numba.extending import overload, register_jitable, intrinsic # noqa From d3c037f402d5fb1795432cf04ac82916e25a4b70 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 14:58:28 +0200 Subject: [PATCH 09/18] Remove reference to numba.generated_jit in docs --- docs/experimental.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/experimental.rst b/docs/experimental.rst index 119d8ded..8b5f50d6 100644 --- a/docs/experimental.rst +++ b/docs/experimental.rst @@ -178,8 +178,7 @@ defined on the `Phase` term, called `init_fields`. Additionally, these arrays will be stored on the ``state`` object provided to the sampling function. -2. It supports reasoning about Numba types in a manner - similar to :func:`numba.generated_jit`. +2. It supports reasoning about Numba types. The ``lm``, ``uvw`` and ``chan_freq`` arguments contain the Numba types of the variables supplied to the RIME, while the ``typingctx`` argument contains a Numba From 733aa6775c29610d8bd636c6729ffb297f0fd00d Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 26 Jan 2024 15:11:10 +0200 Subject: [PATCH 10/18] Remove upper bound on numba versioning --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 077cc40d..fce7cadf 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,7 @@ # astropy breaks with numpy 1.15.3 # https://github.com/astropy/astropy/issues/7943 "numpy >= 1.14.0, != 1.15.3", - # https://github.com/ratt-ru/codex-africanus/issues/283 - "numba >= 0.53.1, < 0.59" + "numba >= 0.53.1" ] extras_require = { From eccd074343c11ce163b420e2112ddb1718ab7e98 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 12:04:54 +0200 Subject: [PATCH 11/18] Fix wsclean impotr --- africanus/rime/dask_predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/africanus/rime/dask_predict.py b/africanus/rime/dask_predict.py index 98e56e3e..e25b2ba1 100644 --- a/africanus/rime/dask_predict.py +++ b/africanus/rime/dask_predict.py @@ -17,7 +17,7 @@ predict_vis as np_predict_vis) from africanus.rime.wsclean_predict import ( WSCLEAN_PREDICT_DOCS, - wsclean_predict_impl as wsclean_predict_body) + wsclean_predict_main as wsclean_predict_body) from africanus.model.wsclean.spec_model import spectra as wsclean_spectra From d060c2b2a95abad52e257723dae09d3799ae312e Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 14:30:56 +0200 Subject: [PATCH 12/18] Remove unused prefer_literal kwarg --- africanus/experimental/rime/fused/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 5af73a10..4f32a6e6 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -37,7 +37,7 @@ def rime_impl(*args): def rime(*args): return rime_impl(*args) - @overload(rime_impl, jit_options=JIT_OPTIONS, prefer_literal=True) + @overload(rime_impl, jit_options=JIT_OPTIONS) def nb_rime(*args): if not len(args) % 2 == 0: raise TypeError(f"len(args) {len(args)} is not divisible by 2") From b6fd6f6de6af17a580545fb72ec927c5633d7c17 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 14:58:18 +0200 Subject: [PATCH 13/18] Disable numba kernel caching in the FUSED rime for the moment --- africanus/experimental/rime/fused/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 4f32a6e6..1bd2146c 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -33,7 +33,7 @@ def rime_impl_factory(terms, transformers, ncorr): def rime_impl(*args): raise NotImplementedError - @njit(**JIT_OPTIONS) + @njit(nogil=True) def rime(*args): return rime_impl(*args) From 4f368189ea43b3fd379d042a92278493472c2583 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 20:00:12 +0200 Subject: [PATCH 14/18] Create a hash uniquely identifying a RimeSpecification --- .../experimental/rime/fused/specification.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/africanus/experimental/rime/fused/specification.py b/africanus/experimental/rime/fused/specification.py index 1c61d50c..ee15f2d1 100644 --- a/africanus/experimental/rime/fused/specification.py +++ b/africanus/experimental/rime/fused/specification.py @@ -1,4 +1,5 @@ import ast +from hashlib import shake_256 from importlib import import_module import inspect import multiprocessing @@ -11,7 +12,7 @@ from africanus.experimental.rime.fused import terms as term_mod from africanus.experimental.rime.fused.transformers.core import Transformer from africanus.experimental.rime.fused import transformers as transformer_mod -from africanus.util.patterns import LazyProxy +from africanus.util.patterns import freeze, LazyProxy TERM_STRING_REGEX = re.compile("([A-Z])(pq|p|q)") @@ -335,6 +336,8 @@ def __init__(self, specification, terms=None, transformers=None): "process_pool": pool } + hash_elements = list(v for k, v in global_kw.items() if k != "process_pool") + for cls, cfg in zip(term_types, term_cfgs): if cfg == "pq": cfg = "middle" @@ -350,7 +353,7 @@ def __init__(self, specification, terms=None, transformers=None): cls_kw = {} if "configuration" not in init_sig.parameters: - raise RimeSpecification( + raise RimeSpecificationError( f"{cls}.__init__{init_sig} must take a " f"'configuration' argument and call " f"super().__init__(configuration)") @@ -358,7 +361,7 @@ def __init__(self, specification, terms=None, transformers=None): for a, p in list(init_sig.parameters.items())[1:]: if p.kind not in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: - raise RimeSpecification( + raise RimeSpecificationError( f"{cls}.__init__{init_sig} may not contain " f"*args or **kwargs") @@ -371,6 +374,8 @@ def __init__(self, specification, terms=None, transformers=None): f"Available args: {available_kw}") term = cls(**cls_kw) + hash_elements.append(".".join((cls.__module__, cls.__name__))) + hash_elements.append(cfg) terms.append(term) term_type_set = set(term_types) @@ -385,7 +390,7 @@ def __init__(self, specification, terms=None, transformers=None): transformers = [] - for cls in transformer_types.values(): + for _, cls in sorted(transformer_types.items()): init_sig = inspect.signature(cls.__init__) cls_kw = {} @@ -405,10 +410,12 @@ def __init__(self, specification, terms=None, transformers=None): f"Available args: {available_kw}") transformer = cls(**cls_kw) + hash_elements.append(".".join((cls.__module__, cls.__name__))) transformers.append(transformer) self.terms = terms self.transformers = transformers + self.spec_hash = shake_256(str((freeze(hash_elements))).encode("utf-8")).hexdigest(16) @staticmethod def _finalise_pool(pool): From 730dbaa5181db6d1cd6d30931de0514238e7c86d Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 20:00:44 +0200 Subject: [PATCH 15/18] Reintroduce rime caching by passing the rime spec hash in as the 1st literal arg --- africanus/experimental/rime/fused/core.py | 24 ++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 1bd2146c..f12a7773 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -30,20 +30,26 @@ def rime_impl_factory(terms, transformers, ncorr): - def rime_impl(*args): - raise NotImplementedError - - @njit(nogil=True) + @njit(**JIT_OPTIONS) def rime(*args): return rime_impl(*args) + def rime_impl(*args): + raise NotImplementedError + @overload(rime_impl, jit_options=JIT_OPTIONS) def nb_rime(*args): - if not len(args) % 2 == 0: - raise TypeError(f"len(args) {len(args)} is not divisible by 2") + if not len(args) > 0: + raise TypeError(f"rime must be called with at least the signature argument") + + if not isinstance(args[0], types.Literal): + raise TypeError(f"Signature hash ({args[0]}) must be a literal") + + if not len(args) % 2 == 1: + raise TypeError(f"Length of named arguments {len(args)} is not divisible by 2") - argstart = len(args) // 2 - names = args[:argstart] + argstart = 1 + (len(args) - 1) // 2 + names = args[1:argstart] if not all(isinstance(n, types.Literal) for n in names): raise TypeError(f"{names} must be a Tuple of Literal strings") @@ -199,7 +205,7 @@ def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values()) - return self.impl(*args) + return self.impl(types.literal(self.rime_spec.spec_hash), *args) def consolidate_args(args, kw): From 8ab1fa60662ec96395c1734328a5c4f4f686f434 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 20:17:58 +0200 Subject: [PATCH 16/18] flake8 --- africanus/experimental/rime/fused/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index f12a7773..af4f98fe 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -40,13 +40,15 @@ def rime_impl(*args): @overload(rime_impl, jit_options=JIT_OPTIONS) def nb_rime(*args): if not len(args) > 0: - raise TypeError(f"rime must be called with at least the signature argument") + raise TypeError("rime must be at least be called " + "with the signature argument") if not isinstance(args[0], types.Literal): raise TypeError(f"Signature hash ({args[0]}) must be a literal") if not len(args) % 2 == 1: - raise TypeError(f"Length of named arguments {len(args)} is not divisible by 2") + raise TypeError(f"Length of named arguments {len(args)} " + f"is not divisible by 2") argstart = 1 + (len(args) - 1) // 2 names = args[1:argstart] From c3e9c12b171c751bd1de702be2b1846053e89c2c Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 29 Jan 2024 20:52:57 +0200 Subject: [PATCH 17/18] more flake8 --- africanus/experimental/rime/fused/specification.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/africanus/experimental/rime/fused/specification.py b/africanus/experimental/rime/fused/specification.py index ee15f2d1..4966a7bf 100644 --- a/africanus/experimental/rime/fused/specification.py +++ b/africanus/experimental/rime/fused/specification.py @@ -336,7 +336,8 @@ def __init__(self, specification, terms=None, transformers=None): "process_pool": pool } - hash_elements = list(v for k, v in global_kw.items() if k != "process_pool") + hash_elements = list(v for k, v in global_kw.items() + if k != "process_pool") for cls, cfg in zip(term_types, term_cfgs): if cfg == "pq": @@ -415,7 +416,8 @@ def __init__(self, specification, terms=None, transformers=None): self.terms = terms self.transformers = transformers - self.spec_hash = shake_256(str((freeze(hash_elements))).encode("utf-8")).hexdigest(16) + str_elements = str((freeze(hash_elements))).encode("utf-8") + self.spec_hash = shake_256(str_elements).hexdigest(16) @staticmethod def _finalise_pool(pool): From 6ee5f3dfc64c182acb81e5c67dbbb0b58d95f476 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 30 Jan 2024 08:10:02 +0200 Subject: [PATCH 18/18] [skip ci] Update HISTORY.rst --- HISTORY.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/HISTORY.rst b/HISTORY.rst index 1a04b858..cb2f36ea 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Deprecate use of @generated_jit. Remove upper bound on numba. (:pr:`289`) * Remove unnecessary new_axes in calibration utils after upstream fix in dask (:pr:`288`) * Check that ncorr is never larger than 2 in calibration utils (:pr:`287`) * Optionally check NRT allocations (:pr:`286`)