Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate @generated_jit #289

Merged
merged 18 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Deprecate use of @generated_jit. Remove upper bound on numba. (:pr:`289`)
* Remove unnecessary new_axes in calibration utils after upstream fix in dask (:pr:`288`)
* Check that ncorr is never larger than 2 in calibration utils (:pr:`287`)
* Optionally check NRT allocations (:pr:`286`)
Expand Down
171 changes: 118 additions & 53 deletions africanus/averaging/bda_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
merge_flags,
vis_output_arrays)
from africanus.util.docs import DocstringTemplate
from africanus.util.numba import (generated_jit,
from africanus.util.numba import (njit,
overload,
JIT_OPTIONS,
intrinsic,
is_numba_type_none)

Expand All @@ -20,11 +22,25 @@
RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def row_average(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return row_average_impl(meta, ant1, ant2, flag_row=flag_row,
time_centroid=time_centroid, exposure=exposure,
uvw=uvw, weight=weight, sigma=sigma)


def row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return NotImplementedError


@overload(row_average_impl, jit_options=JIT_OPTIONS)
def nb_row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
have_flag_row = not is_numba_type_none(flag_row)
have_time_centroid = not is_numba_type_none(time_centroid)
have_exposure = not is_numba_type_none(exposure)
Expand Down Expand Up @@ -310,13 +326,35 @@ def codegen(context, builder, signature, args):
return sig, codegen


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return row_chan_average_impl(meta, flag_row=flag_row, weight=weight,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)


def row_chan_average_impl(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return NotImplementedError


@overload(row_chan_average_impl, jit_options=JIT_OPTIONS)
def nb_row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

have_vis = not is_numba_type_none(visibilities)
have_flag = not is_numba_type_none(flag)
have_flag_row = not is_numba_type_none(flag_row)
Expand Down Expand Up @@ -523,7 +561,7 @@ def impl(meta, flag_row=None, weight=None,
_rowchan_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def bda(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
Expand All @@ -535,7 +573,21 @@ def bda(time, interval, antenna1, antenna2,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
def impl(time, interval, antenna1, antenna2,

return bda_impl(time, interval, antenna1, antenna2,
time_centroid=time_centroid, exposure=exposure,
flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma,
chan_freq=chan_freq, chan_width=chan_width,
effective_bw=effective_bw, resolution=resolution,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum,
max_uvw_dist=max_uvw_dist, max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs, min_nchan=min_nchan)


def bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
Expand All @@ -546,54 +598,67 @@ def impl(time, interval, antenna1, antenna2,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)

return impl
return NotImplementedError


@overload(bda_impl, jit_options=JIT_OPTIONS)
def nb_bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
effective_bw=None, resolution=None,
visibilities=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
max_uvw_dist=None, max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)


BDA_DOCS = DocstringTemplate("""
Expand Down
112 changes: 39 additions & 73 deletions africanus/averaging/bda_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,21 @@
import numpy as np
import numba
from numba.experimental import jitclass
import numba.types
from numba import types

from africanus.constants import c as lightspeed
from africanus.util.numba import generated_jit, njit, is_numba_type_none
from africanus.util.numba import (
JIT_OPTIONS,
overload,
njit,
is_numba_type_none)
from africanus.averaging.support import unique_time, unique_baselines


class RowMapperError(Exception):
pass


@njit(nogil=True, cache=True)
def erf26(x):
"""Implements 7.1.26 erf approximation from Abramowitz and
Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7."""

# Constants
p = 0.3275911
a1 = 0.254829592
a2 = -0.284496736
a3 = 1.421413741
a4 = -1.453152027
a5 = 1.061405429
e = 2.718281828

# t
t = 1.0/(1.0 + (p * x))

# Erf calculation
erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t)
erf *= e ** -(x ** 2)

return -round(erf, 9) if x < 0 else round(erf, 0)


@njit(nogil=True, cache=True)
def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength):
sidereal_rotation_rate = 7.292118516e-5
diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2)
term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit
return 1.0 - 1.0645 * erf26(0.8326*term) / term


_SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200)


@njit(nogil=True, cache=True, inline='always')
def inv_sinc(sinc_x, tol=1e-12):
# Invalid input
if sinc_x > 1.0:
raise ValueError("sinc_x > 1.0")

# Initial guess from reversion of Taylor series
# https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx
x = t_pow = np.sqrt(6*np.abs((1 - sinc_x)))
t_squared = t_pow*t_pow

for coeff in numba.literal_unroll(_SERIES_COEFFS):
t_pow *= t_squared
x += coeff * t_pow

# Use Newton Raphson to go the rest of the way
# https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D
while True:
# evaluate delta between this iteration sinc(x) and original
sinx = np.sin(x)
𝞓sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x

# Stop if converged
if np.abs(𝞓sinc_x) < tol:
break

# Next iteration
x -= (x*x * 𝞓sinc_x) / (x*np.cos(x) - sinx)

return x


@njit(nogil=True, cache=True, inline='always')
def factors(n):
assert n >= 1
Expand Down Expand Up @@ -126,7 +63,7 @@ def max_chan_width(ref_freq, fractional_bandwidth):
"nchan", "flag"])


class Binner(object):
class Binner:
def __init__(self, row_start, row_end,
max_lm, decorrelation, time_bin_secs,
max_chan_freq):
Expand Down Expand Up @@ -338,7 +275,7 @@ def finalise_bin(self, auto_corr, uvw, time, interval,
"time", "interval", "chan_width", "flag_row"])


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def bda_mapper(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
Expand All @@ -347,10 +284,39 @@ def bda_mapper(time, interval, ant1, ant2, uvw,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):

return bda_mapper_impl(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)


def bda_mapper_impl(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=None,
max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
return NotImplementedError


@overload(bda_mapper_impl, jit_options={"nogil": True})
def nb_bda_mapper(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=None,
max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
have_time_bin_secs = not is_numba_type_none(time_bin_secs)

Omitted = numba.types.misc.Omitted
Omitted = types.misc.Omitted

decorr_type = (numba.typeof(decorrelation.value)
if isinstance(decorrelation, Omitted)
Expand Down
Loading