Skip to content

Commit

Permalink
Optimise broadcast_arrays in katdal import (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Apr 16, 2024
1 parent 350415c commit 9690ecd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Optimise ``broadcast_arrays`` in katdal import (:pr:`326`)
* Change ``dask-ms katdal import`` to ``dask-ms import katdal`` (:pr:`325`)
* Configure dependabot (:pr:`319`)
* Add chunk specification to ``dask-ms katdal import`` (:pr:`318`)
Expand Down
37 changes: 28 additions & 9 deletions daskms/experimental/katdal/msv2_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
################################################################################

from functools import partial
from operator import getitem

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -126,14 +127,7 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe

flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
vis = DaskLazyIndexer(
dataset.vis,
(),
transforms=(
rechunk,
vis_transpose,
),
)
vis = DaskLazyIndexer(dataset.vis, (), (rechunk, vis_transpose))

time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
Expand All @@ -147,7 +141,32 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
row=self._row_view,
)

time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2)
# Better graph than da.broadcast_arrays
bcast = da.blockwise(
np.broadcast_arrays,
("time", "bl"),
time,
("time", "bl"),
ant1,
("time", "bl"),
ant2,
("time", "bl"),
align_arrays=False,
adjust_chunks={"time": time.chunks[0], "bl": ant1.chunks[1]},
meta=np.empty((0,) * 2, dtype=np.int32),
)

time = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 0, None, dtype=time.dtype
)

ant1 = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 1, None, dtype=ant1.dtype
)

ant2 = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 2, None, dtype=ant2.dtype
)

if self._row_view:
primary_dims = ("row",)
Expand Down

0 comments on commit 9690ecd

Please sign in to comment.