Identify channel and correlation-like dimensions in non-standard MS columns (#329)
sjperkins authored May 13, 2024
1 parent b436c0e commit 747408f
Showing 7 changed files with 135 additions and 16 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
@@ -4,6 +4,8 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Improve table schema handling (:pr:`329`)
* Identify channel and correlation-like dimensions in non-standard MS columns (:pr:`329`)
* DaskMSStore depends on ``fsspec >= 2022.7.0`` (:pr:`328`)
* Optimise `broadcast_arrays` in katdal import (:pr:`326`)
* Change `dask-ms katdal import` to `dask-ms import katdal` (:pr:`325`)
18 changes: 15 additions & 3 deletions daskms/columns.py
@@ -222,11 +222,23 @@ def column_metadata(column, table_proxy, table_schema, chunks, exemplar_row=0):
"match shape of exemplar=%s" % (ndim, shape)
)

# Extract dimension schema
# Get the column schema, or create a default
try:
dims = table_schema[column]["dims"]
column_schema = table_schema[column]
except KeyError:
dims = tuple("%s-%d" % (column, i) for i in range(1, len(shape) + 1))
column_schema = {
"dims": tuple("%s-%d" % (column, i) for i in range(1, len(shape) + 1))
}

try:
dims = column_schema["dims"]
except KeyError:
raise ColumnMetadataError(
f"Column schema {column_schema} does not contain required 'dims' attribute"
)

if not isinstance(dims, tuple) or not all(isinstance(d, str) for d in dims):
raise ColumnMetadataError(f"Dimensions {dims} is not a tuple of strings")

dim_chunks = []

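An illustrative sketch, not part of this commit, of the fallback and validation above: a column missing from the table schema gets auto-generated dimension names of the form "<column>-<n>", which then pass the new tuple-of-strings check. The column name BITFLAG and the shape below are hypothetical.

column, shape = "BITFLAG", (16, 4)
column_schema = {
    "dims": tuple("%s-%d" % (column, i) for i in range(1, len(shape) + 1))
}
dims = column_schema["dims"]
# Auto-generated names follow the "<column>-<n>" convention, 1-indexed
assert dims == ("BITFLAG-1", "BITFLAG-2")
# ...and satisfy the new validation: a tuple containing only strings
assert isinstance(dims, tuple) and all(isinstance(d, str) for d in dims)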
7 changes: 6 additions & 1 deletion daskms/dask_ms.py
@@ -328,7 +328,12 @@ def xds_from_ms(ms, columns=None, index_cols=None, group_cols=None, **kwargs):
kwargs.setdefault("table_schema", "MS")

return xds_from_table(
ms, columns=columns, index_cols=index_cols, group_cols=group_cols, **kwargs
ms,
columns=columns,
index_cols=index_cols,
group_cols=group_cols,
context="ms",
**kwargs,
)


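A hedged usage sketch: xds_from_ms now forwards context="ms" to xds_from_table internally, which is what enables the chan/corr postprocessing added to reads.py below; callers keep using the public API unchanged. The Measurement Set path is a placeholder.

from daskms import xds_from_ms

# "my.ms" is a hypothetical path; columns follow the usual signature
datasets = xds_from_ms("my.ms", columns=["TIME", "DATA"])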
63 changes: 60 additions & 3 deletions daskms/reads.py
@@ -318,6 +318,7 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
self.table_keywords = kwargs.pop("table_keywords", False)
self.column_keywords = kwargs.pop("column_keywords", False)
self.table_proxy = kwargs.pop("table_proxy", False)
self.context = kwargs.pop("context", None)

if len(kwargs) > 0:
raise ValueError(f"Unhandled kwargs: {kwargs}")
@@ -359,8 +360,10 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
coords = {"ROWID": rowid}

attrs = {DASKMS_PARTITION_KEY: ()}

return Dataset(variables, coords=coords, attrs=attrs)
dataset = Dataset(variables, coords=coords, attrs=attrs)
return self.postprocess_dataset(
dataset, table_proxy, exemplar_row, orders, self.chunks[0], short_table_name
)

def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
_, t, s = table_path_split(self.canonical_name)
@@ -420,10 +423,64 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
group_id = [gid.item() for gid in group_id]
attrs.update(zip(self.group_cols, group_id))

datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
dataset = Dataset(group_var_dims, attrs=attrs, coords=coords)
dataset = self.postprocess_dataset(
dataset, table_proxy, exemplar_row, order, group_chunks, array_suffix
)
datasets.append(dataset)

return datasets

def postprocess_dataset(
self, dataset, table_proxy, exemplar_row, order, chunks, array_suffix
):
if not self.context or self.context != "ms":
return dataset

# Fixup any non-standard columns
# with dimensions like chan and corr
try:
chan = dataset.sizes["chan"]
corr = dataset.sizes["corr"]
except KeyError:
return dataset

schema_updates = {}

for name, var in dataset.data_vars.items():
new_dims = list(var.dims[1:])

unassigned = {"chan", "corr"}

for dim, dim_name in enumerate(var.dims[1:]):
# An automatically assigned dimension name
if dim_name == f"{name}-{dim + 1}":
if dataset.sizes[dim_name] == chan and "chan" in unassigned:
new_dims[dim] = "chan"
unassigned.discard("chan")
elif dataset.sizes[dim_name] == corr and "corr" in unassigned:
new_dims[dim] = "corr"
unassigned.discard("corr")

new_dims = tuple(new_dims)
if var.dims[1:] != new_dims:
schema_updates[name] = {"dims": new_dims}

if not schema_updates:
return dataset

return dataset.assign(
**_dataset_variable_factory(
table_proxy,
schema_updates,
list(schema_updates.keys()),
exemplar_row,
order,
chunks,
array_suffix,
)
)

def datasets(self):
table_proxy = self._table_proxy_factory()

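A minimal standalone sketch of the matching rule implemented in postprocess_dataset above, using a hypothetical BITFLAG column and made-up sizes: an automatically named dimension whose size equals the dataset's chan or corr size is relabelled, and each label is assigned at most once.

sizes = {"chan": 16, "corr": 4, "BITFLAG-1": 16, "BITFLAG-2": 4}
dims = ("row", "BITFLAG-1", "BITFLAG-2")

new_dims = list(dims[1:])
unassigned = {"chan", "corr"}

for i, dim_name in enumerate(dims[1:]):
    # Only automatically assigned names ("<column>-<n>") are considered
    if dim_name == "BITFLAG-%d" % (i + 1):
        if sizes[dim_name] == sizes["chan"] and "chan" in unassigned:
            new_dims[i] = "chan"
            unassigned.discard("chan")
        elif sizes[dim_name] == sizes["corr"] and "corr" in unassigned:
            new_dims[i] = "corr"
            unassigned.discard("corr")

assert tuple(new_dims) == ("chan", "corr")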
21 changes: 12 additions & 9 deletions daskms/table_schemas.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

from copy import deepcopy

try:
from collections.abc import Mapping
except ImportError:
@@ -158,19 +160,20 @@ def lookup_table_schema(table_name, lookup_str):
A dictionary of the form
:code:`{column: {'dims': (...)}}`.
"""
if lookup_str is None:
table_type = infer_table_type(table_name)
table_type = infer_table_type(table_name)

try:
return _ALL_SCHEMAS[table_type]
except KeyError:
raise ValueError(f"No schema registered " f"for table type '{table_type}'")
# Infer a base schema from the inferred table type, if available
try:
table_schema = deepcopy(_ALL_SCHEMAS[table_type])
except KeyError:
table_schema = {}

if not isinstance(lookup_str, (tuple, list)):
if lookup_str is None:
lookup_str = []
elif not isinstance(lookup_str, (tuple, list)):
lookup_str = [lookup_str]

table_schema = {}

# Add extra schema information to the table
for ls in lookup_str:
if isinstance(ls, Mapping):
table_schema.update(ls)
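A sketch of the merge behaviour introduced above, using stand-in dictionaries rather than the real MS schema: the inferred base schema is deep-copied and user-supplied entries are layered on top, so overriding DATA leaves other columns such as UVW untouched (as the updated test below also checks).

from copy import deepcopy

base_schema = {"DATA": {"dims": ("chan", "corr")}, "UVW": {"dims": ("uvw",)}}
override = {"DATA": {"dims": ("my-chan", "my-corr")}}

table_schema = deepcopy(base_schema)
table_schema.update(override)

assert table_schema["DATA"]["dims"] == ("my-chan", "my-corr")
assert table_schema["UVW"]["dims"] == ("uvw",)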
5 changes: 5 additions & 0 deletions daskms/tests/test_dataset.py
@@ -272,6 +272,11 @@ def test_dataset_table_schemas(ms):
table_schema = ["MS", {"DATA": {"dims": data_dims}}]
datasets = read_datasets(ms, [], [], [], table_schema=table_schema)
assert datasets[0].data_vars["DATA"].dims == ("row",) + data_dims
assert datasets[0].data_vars["UVW"].dims == ("row", "uvw")

datasets = read_datasets(ms, [], [], [], table_schema={"DATA": {"dims": data_dims}})
assert datasets[0].data_vars["DATA"].dims == ("row",) + data_dims
assert datasets[0].data_vars["UVW"].dims == ("row", "uvw")


@pytest.mark.parametrize(
35 changes: 35 additions & 0 deletions daskms/tests/test_ms_read_and_update.py
@@ -365,3 +365,38 @@ def test_mismatched_rowid(ms):

def test_request_rowid(ms):
xdsl = xds_from_ms(ms, columns=["TIME", "ROWID"]) # noqa


def test_postprocess_ms(ms):
"""Test that postprocessing of MS variables identifies chan/corr like data"""
xdsl = xds_from_ms(ms)

def _array(ds, dims):
shape = tuple(ds.sizes[d] for d in dims)
chunks = tuple(ds.chunks[d] for d in dims)
return (dims, da.random.random(size=shape, chunks=chunks))

# Write some non-standard columns back to the MS
for i, ds in enumerate(xdsl):
xdsl[i] = ds.assign(
**{
"BITFLAG": _array(ds, ("row", "chan", "corr")),
"HAS_CORRS": _array(ds, ("row", "corr")),
"HAS_CHANS": _array(ds, ("row", "chan")),
}
)

dask.compute(xds_to_table(xdsl, ms))

for ds in xds_from_ms(ms, chunks={"row": 1, "chan": 1, "corr": 1}):
assert ds.BITFLAG.dims == ("row", "chan", "corr")

assert ds.HAS_CORRS.dims == ("row", "corr")
assert ds.HAS_CHANS.dims == ("row", "chan")

assert dict(ds.chunks) == {
"uvw": (3,),
"row": (1,) * ds.sizes["row"],
"chan": (1,) * ds.sizes["chan"],
"corr": (1,) * ds.sizes["corr"],
}
