Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Jan 26, 2024
1 parent b9a27e5 commit ba32c64
Show file tree
Hide file tree
Showing 20 changed files with 561 additions and 227 deletions.
186 changes: 124 additions & 62 deletions africanus/averaging/bda_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
merge_flags,
vis_output_arrays)
from africanus.util.docs import DocstringTemplate
from africanus.util.numba import (generated_jit,
from africanus.util.numba import (njit,
overload,
JIT_OPTIONS,
intrinsic,
is_numba_type_none)

Expand All @@ -20,11 +22,24 @@
RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def row_average(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return row_average_impl(meta, ant1, ant2, flag_row=flag_row,
time_centroid=time_centroid, exposure=exposure, uvw=uvw,
weight=weight, sigma=sigma)

def row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return NotImplementedError


@overload(row_average_impl, jit_options=JIT_OPTIONS)
def nb_row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
have_flag_row = not is_numba_type_none(flag_row)
have_time_centroid = not is_numba_type_none(time_centroid)
have_exposure = not is_numba_type_none(exposure)
Expand Down Expand Up @@ -310,13 +325,36 @@ def codegen(context, builder, signature, args):
return sig, codegen


@generated_jit(nopython=True, nogil=True, cache=True)


@njit(**JIT_OPTIONS)
def row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return row_chan_average_impl(meta, flag_row=flag_row, weight=weight,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)


def row_chan_average_impl(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return NotImplementedError

@overload(row_chan_average_impl, jit_options=JIT_OPTIONS)
def nb_row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

have_vis = not is_numba_type_none(visibilities)
have_flag = not is_numba_type_none(flag)
have_flag_row = not is_numba_type_none(flag_row)
Expand Down Expand Up @@ -523,7 +561,7 @@ def impl(meta, flag_row=None, weight=None,
_rowchan_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def bda(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
Expand All @@ -535,65 +573,89 @@ def bda(time, interval, antenna1, antenna2,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
def impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
effective_bw=None, resolution=None,
visibilities=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
max_uvw_dist=None, max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)

return impl
return bda_impl(time, interval, antenna1, antenna2,
time_centroid=time_centroid, exposure=exposure, flag_row=flag_row,
uvw=uvw, weight=weight, sigma=sigma,
chan_freq=chan_freq, chan_width=chan_width,
effective_bw=effective_bw, resolution=resolution,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum,
max_uvw_dist=max_uvw_dist, max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs, min_nchan=min_nchan)

def bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
effective_bw=None, resolution=None,
visibilities=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
max_uvw_dist=None, max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
return NotImplementedError

@overload(bda_impl, jit_options=JIT_OPTIONS)
def nb_bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
effective_bw=None, resolution=None,
visibilities=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
max_uvw_dist=None, max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)


BDA_DOCS = DocstringTemplate("""
Expand Down
Loading

0 comments on commit ba32c64

Please sign in to comment.