Identify channel and correlation-like dimensions in non-standard MS columns
sjperkins committed May 13, 2024
1 parent b436c0e commit de9a0dc
Showing 3 changed files with 96 additions and 4 deletions.
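
At a glance: when an MS is opened via xds_from_ms, trailing dimensions of non-standard columns (which otherwise get generated names such as "BITFLAG-1") are renamed to "chan" or "corr" wherever their extents match the standard channel and correlation dimensions. A minimal sketch of the behaviour this enables, mirroring the new test below (the "test.ms" path and BITFLAG column are illustrative):

import dask
import dask.array as da
from daskms import xds_from_ms, xds_to_table

xdsl = xds_from_ms("test.ms")

def chan_corr_array(ds, dims):
    # Build a random dask array matching the dataset's sizes and chunking
    shape = tuple(ds.sizes[d] for d in dims)
    chunks = tuple(ds.chunks[d] for d in dims)
    return (dims, da.random.random(size=shape, chunks=chunks))

# Write a non-standard (row, chan, corr) column back to the MS
xdsl = [
    ds.assign({"BITFLAG": chan_corr_array(ds, ("row", "chan", "corr"))})
    for ds in xdsl
]
dask.compute(xds_to_table(xdsl, "test.ms"))

# Previously BITFLAG read back as ("row", "BITFLAG-1", "BITFLAG-2");
# its chan/corr-sized dimensions are now identified on the round trip
for ds in xds_from_ms("test.ms"):
    assert ds.BITFLAG.dims == ("row", "chan", "corr")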
7 changes: 6 additions & 1 deletion daskms/dask_ms.py
@@ -328,7 +328,12 @@ def xds_from_ms(ms, columns=None, index_cols=None, group_cols=None, **kwargs):
     kwargs.setdefault("table_schema", "MS")
 
     return xds_from_table(
-        ms, columns=columns, index_cols=index_cols, group_cols=group_cols, **kwargs
+        ms,
+        columns=columns,
+        index_cols=index_cols,
+        group_cols=group_cols,
+        context="ms",
+        **kwargs,
     )
 
 
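Note that context="ms" is consumed by DatasetFactory in reads.py below (kwargs.pop("context", None)), so only reads routed through xds_from_ms opt in to the dimension fixup; a plain xds_from_table read leaves context unset and skips it. Roughly (path illustrative):

from daskms import xds_from_ms, xds_from_table

# Generic table read: no context, so a non-standard column keeps its
# generated trailing dimension names, e.g. ("row", "BITFLAG-1", "BITFLAG-2")
generic_dsl = xds_from_table("test.ms")

# MS read: context="ms" is supplied internally and chan/corr-sized
# trailing dimensions are renamed, e.g. to ("row", "chan", "corr")
ms_dsl = xds_from_ms("test.ms")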
58 changes: 55 additions & 3 deletions daskms/reads.py
@@ -318,6 +318,7 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
         self.table_keywords = kwargs.pop("table_keywords", False)
         self.column_keywords = kwargs.pop("column_keywords", False)
         self.table_proxy = kwargs.pop("table_proxy", False)
+        self.context = kwargs.pop("context", None)
 
         if len(kwargs) > 0:
             raise ValueError(f"Unhandled kwargs: {kwargs}")
@@ -359,8 +360,10 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
         coords = {"ROWID": rowid}
 
         attrs = {DASKMS_PARTITION_KEY: ()}
-
-        return Dataset(variables, coords=coords, attrs=attrs)
+        dataset = Dataset(variables, coords=coords, attrs=attrs)
+        return self.postprocess_dataset(
+            dataset, table_proxy, exemplar_row, orders, self.chunks[0], short_table_name
+        )
 
     def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
         _, t, s = table_path_split(self.canonical_name)
@@ -420,10 +423,59 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
             group_id = [gid.item() for gid in group_id]
             attrs.update(zip(self.group_cols, group_id))
 
-            datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
+            dataset = Dataset(group_var_dims, attrs=attrs, coords=coords)
+            dataset = self.postprocess_dataset(
+                dataset, table_proxy, exemplar_row, order, group_chunks, array_suffix
+            )
+            datasets.append(dataset)
 
         return datasets
 
+    def postprocess_dataset(
+        self, dataset, table_proxy, exemplar_row, order, chunks, array_suffix
+    ):
+        if not self.context or self.context != "ms":
+            return dataset
+
+        # Fixup any non-standard columns
+        # with dimensions like chan and corr
+        try:
+            chan = dataset.sizes["chan"]
+            corr = dataset.sizes["corr"]
+        except KeyError:
+            return dataset
+
+        schema_updates = {}
+
+        for name, var in dataset.data_vars.items():
+            new_dims = list(var.dims[1:])
+
+            for dim, dim_name in enumerate(var.dims[1:]):
+                if dim_name.startswith(f"{name}-"):
+                    if dataset.sizes[dim_name] == chan:
+                        new_dims[dim] = "chan"
+                    elif dataset.sizes[dim_name] == corr:
+                        new_dims[dim] = "corr"
+
+            new_dims = tuple(new_dims)
+            if var.dims[1:] != new_dims:
+                schema_updates[name] = {"dims": new_dims}
+
+        if not schema_updates:
+            return dataset
+
+        return dataset.assign(
+            **_dataset_variable_factory(
+                table_proxy,
+                schema_updates,
+                list(schema_updates.keys()),
+                exemplar_row,
+                order,
+                chunks,
+                array_suffix,
+            )
+        )
+
     def datasets(self):
         table_proxy = self._table_proxy_factory()
 
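The heart of postprocess_dataset is a size-matching heuristic: a trailing dimension whose generated name starts with "{column}-" and whose extent equals the dataset's chan (or corr) extent is relabelled, and any affected variables are rebuilt through _dataset_variable_factory with the updated dims. Since chan is tested first, an ambiguous dimension in an MS with equal channel and correlation counts resolves to "chan". A standalone sketch of just the renaming logic, with plain dicts standing in for the dataset:

from typing import Dict, Tuple

def identify_chan_corr_dims(
    name: str, dims: Tuple[str, ...], sizes: Dict[str, int]
) -> Tuple[str, ...]:
    """Rename generated trailing dims of column `name` matching chan/corr extents"""
    chan, corr = sizes["chan"], sizes["corr"]
    new_dims = list(dims[1:])  # skip the leading "row" dimension

    for i, dim_name in enumerate(dims[1:]):
        if dim_name.startswith(f"{name}-"):
            if sizes[dim_name] == chan:
                new_dims[i] = "chan"
            elif sizes[dim_name] == corr:
                new_dims[i] = "corr"

    return (dims[0],) + tuple(new_dims)

# A 16-channel, 4-correlation dataset with a non-standard BITFLAG column
sizes = {"chan": 16, "corr": 4, "BITFLAG-1": 16, "BITFLAG-2": 4}
assert identify_chan_corr_dims(
    "BITFLAG", ("row", "BITFLAG-1", "BITFLAG-2"), sizes
) == ("row", "chan", "corr")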
35 changes: 35 additions & 0 deletions daskms/tests/test_ms_read_and_update.py
@@ -365,3 +365,38 @@ def test_mismatched_rowid(ms):
 
 def test_request_rowid(ms):
     xdsl = xds_from_ms(ms, columns=["TIME", "ROWID"])  # noqa
+
+
+def test_postprocess_ms(ms):
+    """Test that postprocessing of MS variables identifies chan/corr like data"""
+    xdsl = xds_from_ms(ms)
+
+    def _array(ds, dims):
+        shape = tuple(ds.sizes[d] for d in dims)
+        chunks = tuple(ds.chunks[d] for d in dims)
+        return (dims, da.random.random(size=shape, chunks=chunks))
+
+    # Write some non-standard columns back to the MS
+    for i, ds in enumerate(xdsl):
+        xdsl[i] = ds.assign(
+            {
+                "BITFLAG": _array(ds, ("row", "chan", "corr")),
+                "HAS_CORRS": _array(ds, ("row", "corr")),
+                "HAS_CHANS": _array(ds, ("row", "chan")),
+            }
+        )
+
+    dask.compute(xds_to_table(xdsl, ms))
+
+    for ds in xds_from_ms(ms, chunks={"row": 1, "chan": 1, "corr": 1}):
+        assert ds.BITFLAG.dims == ("row", "chan", "corr")
+
+        assert ds.HAS_CORRS.dims == ("row", "corr")
+        assert ds.HAS_CHANS.dims == ("row", "chan")
+
+        assert dict(ds.chunks) == {
+            "uvw": (3,),
+            "row": (1,) * ds.sizes["row"],
+            "chan": (1,) * ds.sizes["chan"],
+            "corr": (1,) * ds.sizes["corr"],
+        }
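
A detail in the final assertion: the standard UVW column contributes a fixed "uvw" dimension of extent 3, and since the read requests no chunking for it, it remains a single (3,) chunk while row, chan and corr are each split into unit chunks as requested.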
