Skip to content

Commit

Permalink
Initial hacky commit of zarr support.
Browse files Browse the repository at this point in the history
  • Loading branch information
JSKenyon committed Oct 18, 2023
1 parent c88bc52 commit cbacc25
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 24 deletions.
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
readme = readme_file.read()

requirements = [
'dask[array] == 2021.2.0',
'donfig >= 0.4.0',
'numpy >= 1.14.0, <= 1.19.5', # breakage in newer numpy + numerical errors
'numba >= 0.43.0',
'scipy >= 1.2.0',
'threadpoolctl >= 1.0.0',
'dask-ms == 0.2.6',
'zarr >= 2.3.1'
'dask-ms[xarray,zarr,s3]'
]

extras_require = {'testing': ['pytest',
Expand Down
24 changes: 15 additions & 9 deletions tricolour/apps/tricolour/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
ResourceProfiler,
CacheProfiler, visualize)
import numpy as np
from daskms import xds_from_ms, xds_from_table, xds_to_table
from daskms import (xds_from_storage_ms,
xds_from_storage_table,
xds_to_storage_table)
from threadpoolctl import threadpool_limits

from tricolour.apps.tricolour.strat_executor import StrategyExecutor
Expand Down Expand Up @@ -229,10 +231,10 @@ def support_tables(ms):
"""

# Get datasets for sub-tables partitioned by row when variably shaped
support = {t: xds_from_table("::".join((ms, t)), group_cols="__row__")
support = {t: xds_from_storage_table("::".join((ms, t)), group_cols="__row__")
for t in ["FIELD", "POLARIZATION", "SPECTRAL_WINDOW"]}
# These columns have fixed shapes
support.update({t: xds_from_table("::".join((ms, t)))[0]
support.update({t: xds_from_storage_table("::".join((ms, t)))[0]
for t in ["ANTENNA", "DATA_DESCRIPTION"]})

# Reify all values upfront
Expand Down Expand Up @@ -291,7 +293,7 @@ def _main(args):
if args.subtract_model_column is not None:
columns.append(args.subtract_model_column)

xds = list(xds_from_ms(args.ms,
xds = list(xds_from_storage_ms(args.ms,
columns=tuple(columns),
group_cols=group_cols,
index_cols=index_cols,
Expand Down Expand Up @@ -347,6 +349,7 @@ def _main(args):
field_dict = {i: fn for i, fn in enumerate(fieldnames)}

# List which hold our dask compute graphs for each dataset
writable_xds = []
write_computes = []
original_stats = []
final_stats = []
Expand Down Expand Up @@ -386,7 +389,7 @@ def _main(args):
# Generate unflagged defaults if we should ignore existing flags
# otherwise take flags from the dataset
if args.ignore_flags is True:
flags = da.full_like(vis, False, dtype=np.bool)
flags = da.full_like(vis, False, dtype=np.bool_)
log.critical("Completely ignoring measurement set "
"flags as per '-if' request. "
"Strategy WILL NOT or with original flags, even if "
Expand Down Expand Up @@ -471,10 +474,13 @@ def _main(args):
# Create new dataset containing new flags
new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

# Write back to original dataset
writes = xds_to_table(new_ds, args.ms, "FLAG")
# original should also have .compute called because we need stats
write_computes.append(writes)
# Append to list of datasets we intend to write to disk.
writable_xds.append(new_ds)

# Write back to original dataset
write_computes = xds_to_storage_table(
writable_xds, args.ms, columns=("FLAG",), rechunk=True
)

if len(write_computes) > 0:
# Combine stats from all datasets
Expand Down
2 changes: 1 addition & 1 deletion tricolour/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load_mask(filename, dilate):
# Load mask
mask = np.load(filename)

if mask.dtype[0] != np.bool or mask.dtype[1] != np.float64:
if mask.dtype[0] != np.bool_ or mask.dtype[1] != np.float64:
raise ValueError("Mask %s is not a valid static mask "
"with labelled channel axis "
"[dtype == (bool, float64)]" % filename)
Expand Down
4 changes: 2 additions & 2 deletions tricolour/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _create_window_dask(name, ntime, nchan, nbl, ncorr, token,

graph = HighLevelGraph.from_collections(collection_name, layers, ())
chunks = ((0,),) # One chunk containing single zarr array object
return da.Array(graph, collection_name, chunks, dtype=np.object)
return da.Array(graph, collection_name, chunks, dtype=object)


def create_vis_windows(ntime, nchan, nbl, ncorr, token,
Expand Down Expand Up @@ -343,7 +343,7 @@ def pack_data(time_inv, ubl,
flags, ("row", "chan", "corr"),
vis_win_obj, ("windim",),
flag_win_obj, ("windim",),
dtype=np.bool)
dtype=np.bool_)

# Expose visibility data at it's full resolution
vis_windows = da.blockwise(_packed_windows, _WINDOW_SCHEMA,
Expand Down
8 changes: 4 additions & 4 deletions tricolour/tests/test_flagging_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="or")

# Check that first mask's flags are applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -144,7 +144,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="or")

# Check that both mask's flags have been applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10, 4, 11, 5]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -157,7 +157,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="override")

# Check that only last mask's flags applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[4, 11, 5]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -176,7 +176,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
uvrange=uvrange)

# Check that both mask's flags have been applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10, 4, 11, 5]] = True

# Select baselines based on the uvrange
Expand Down
8 changes: 4 additions & 4 deletions tricolour/window_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ def window_stats(flag_window, ubls, chan_freqs,
field_name, None,
ddid, None,
nchanbins, None,
meta=np.empty((0,), dtype=np.object))
meta=np.empty((0,), dtype=object))

# Create an empty stats object if the user hasn't supplied one
if prev_stats is None:
def _window_stat_creator():
return WindowStatistics(nchanbins)

prev_stats = da.blockwise(_window_stat_creator, (),
meta=np.empty((), dtype=np.object))
meta=np.empty((), dtype=object))

# Combine per-baseline stats into a single stats object
return da.blockwise(_combine_baseline_window_stats, (),
stats, ("bl",),
prev_stats, (),
meta=np.empty((), dtype=np.object))
meta=np.empty((), dtype=object))


def _combine_window_stats(*args):
Expand Down Expand Up @@ -167,7 +167,7 @@ def combine_window_stats(window_stats):
args = (v for ws in window_stats for v in (ws, ()))

return da.blockwise(_combine_window_stats, (),
*args, dtype=np.object)
*args, dtype=object)


class WindowStatistics(object):
Expand Down

0 comments on commit cbacc25

Please sign in to comment.