Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bda overload to return an implementation #307

Merged
merged 2 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 119 additions & 58 deletions africanus/averaging/bda_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import namedtuple

import numpy as np
from numba import types

from africanus.averaging.bda_mapping import bda_mapper, RowMapOutput
from africanus.averaging.shared import chan_corrs, merge_flags, vis_output_arrays
Expand Down Expand Up @@ -757,72 +758,132 @@ def nb_bda_impl(
time_bin_secs=None,
min_nchan=1,
):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)
if is_numba_type_none(chan_width):
return TypeError(f"chan_width must be provided")

meta = bda_mapper(
if is_numba_type_none(chan_freq):
return TypeError(f"chan_freq must be provided")

if is_numba_type_none(uvw):
raise TypeError(f"uvw must be provided")

valid_types = (
types.misc.NoneType,
types.misc.Omitted,
types.scalars.Float,
types.scalars.Integer,
)

if not isinstance(max_uvw_dist, valid_types):
raise TypeError(f"max_uvw_dist ({max_uvw_dist}) must be a scalar float")

if not isinstance(max_fov, valid_types):
raise TypeError(f"max_fov ({max_fov}) must be a scalar float")

if not isinstance(decorrelation, valid_types):
raise TypeError(f"decorrelation ({decorrelation}) must be a scalar float")

if not isinstance(time_bin_secs, valid_types):
raise TypeError(f"time_bin_secs ({time_bin_secs}) must be a scalar float")

valid_types = (types.misc.NoneType, types.misc.Omitted, types.scalars.Integer)

if not isinstance(min_nchan, valid_types):
raise TypeError(f"min_nchan ({min_nchan}) must be an integer")

def impl(
time,
interval,
antenna1,
antenna2,
uvw,
chan_width,
chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan,
)
time_centroid=None,
exposure=None,
flag_row=None,
uvw=None,
weight=None,
sigma=None,
chan_freq=None,
chan_width=None,
effective_bw=None,
resolution=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None,
max_uvw_dist=None,
max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1,
):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(
time,
interval,
antenna1,
antenna2,
uvw,
chan_width,
chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan,
)

row_avg = row_average(
meta,
antenna1,
antenna2,
flag_row, # noqa: F841
time_centroid,
exposure,
uvw,
weight=weight,
sigma=sigma,
)
row_avg = row_average(
meta,
antenna1,
antenna2,
flag_row, # noqa: F841
time_centroid,
exposure,
uvw,
weight=weight,
sigma=sigma,
)

row_chan_avg = row_chan_average(
meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities,
flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum,
)
row_chan_avg = row_chan_average(
meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities,
flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum,
)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(
meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum,
)
# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(
meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why chan_freq is not returned?

Copy link
Member Author

@sjperkins sjperkins May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Part of the reason is that it technically transforms to (row, chan) dimensionality as the frequencies and channel widths can also vary per row. This sort of data is available on the meta object, higher up in this namedtuple construction. There's a chan_width member with (row, chan) dimensionality. However, there isn't a chan_freq member. Do you need it? I suppose it wouldn't be necessary if one were only doing time BDA.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I don't need it for this reason but just wanted to make sure it's not an oversight.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should add that when doing time only BDA you shoudn't need uvw to be expanded as rowchan. This actually uses quite a bit of memory. It's not a train smash for now but it might be worth adding a time only BDA function in the future that doesn't do this

# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum,
)

return impl


BDA_DOCS = DocstringTemplate(
Expand Down
16 changes: 14 additions & 2 deletions africanus/averaging/tests/test_bda_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from africanus.averaging.bda_mapping import RowMapOutput
from africanus.averaging.bda_avg import row_average, row_chan_average
from africanus.averaging.bda_avg import bda as bda_avg, row_average, row_chan_average
from africanus.averaging.dask import bda as dask_bda


Expand Down Expand Up @@ -94,7 +94,7 @@ def _calc_sigma(weight, sigma, rows):
return np.sqrt(numerator / denominator)


def test_bda_avg(bda_test_map, inv_bda_test_map, flags):
def test_bda_avg_in_parts(bda_test_map, inv_bda_test_map, flags):
rs = np.random.RandomState(42)

# Derive flag_row from flags
Expand All @@ -119,6 +119,7 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags):
weight = rs.normal(size=(in_row, in_corr))
sigma = rs.normal(size=(in_row, in_corr))
chan_width = np.repeat(0.856e9 / out_chan, out_chan)
chan_freq = np.linspace(0.856e9, 2 * 0.856e9, chan_width.size)

# Aggregate time and interval, in_row => out_row
# first channel in the map. We're only averaging over
Expand Down Expand Up @@ -236,6 +237,17 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags):
assert_array_almost_equal(row_chan_avg.weight_spectrum, out_ws)
assert_array_almost_equal(row_chan_avg.sigma_spectrum, out_ss)

result = bda_avg(
time=time,
interval=interval,
antenna1=ant1,
antenna2=ant2,
visibilities=vis,
uvw=uvw,
chan_width=chan_width,
chan_freq=chan_freq,
)


@pytest.mark.parametrize("vis_format", ["ragged", "flat"])
def test_dask_bda_avg(vis_format):
Expand Down