Skip to content

Commit

Permalink
keep protocol intact when globbing from object stores
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 17, 2024
1 parent 581f664 commit bf9395c
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 47 deletions.
19 changes: 13 additions & 6 deletions pfb/utils/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,31 @@ def set_output_names(opts):

def xds_from_url(url, columns='ALL', chunks=-1):
    '''
    Returns a lazy view of all datasets contained in url
    as well as a list of full paths to each one.

    Parameters
    ----------
    url : str
        Location of the datasets. May carry a protocol prefix
        (e.g. ``s3://bucket/path``); a bare path is treated as local.
    columns : str
        Only the default ``'ALL'`` is currently supported.
    chunks : int
        Only the default ``-1`` (no chunking) is currently supported.

    Returns
    -------
    xds : list
        Lazily opened datasets (one per ``*.zarr`` child of ``url``).
    ds_list : list of str
        Full path (protocol included) to each dataset.

    Raises
    ------
    NotImplementedError
        If ``columns`` or ``chunks`` deviate from their defaults.
    ValueError
        If no ``*.zarr`` datasets are found at ``url``.
    '''
    if columns.upper() != 'ALL':
        raise NotImplementedError
    if chunks != -1:
        raise NotImplementedError

    url = url.rstrip('/')
    # fsspec.glob strips the protocol from results, so remember it here
    # and prepend it again below to keep the returned paths usable.
    if '://' in url:
        protocol = url.split('://')[0]
        prefix = f'{protocol}://'
    else:
        protocol = 'file'
        prefix = ''
    fs = fsspec.filesystem(protocol)
    ds_list = fs.glob(f'{url}/*.zarr')
    ds_list = [prefix + ds for ds in ds_list]
    # these will only be read in on first value access and won't be chunked
    open_zarr = partial(xr.open_zarr, chunks=None)
    xds = list(map(open_zarr, ds_list))

    if not xds:
        raise ValueError(f'Nothing found at {url}')
    return xds, ds_list


def read_var(ds, var):
Expand Down
3 changes: 1 addition & 2 deletions pfb/workers/fluxmop.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def _fluxmop(**kw):

dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)

nx, ny = dds[0].x.size, dds[0].y.size
nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size
Expand Down
19 changes: 4 additions & 15 deletions pfb/workers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def grid(**kw):
ti = time.time()
residual_mfs = _grid(**opts)

dds = xds_from_url(dds_store.url)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
dds, dds_list = xds_from_url(dds_store.url)

# convert to fits files
futures = []
Expand Down Expand Up @@ -210,7 +209,7 @@ def _grid(**kw):
raise ValueError(f"There must be a dataset at {xds_store.url}")

print(f"Lazy loading xds from {xds_store.url}", file=log)
xds = xds_from_url(xds_store.url)
xds, xds_list = xds_from_url(xds_store.url)

times_in = []
freqs_in = []
Expand Down Expand Up @@ -254,16 +253,6 @@ def _grid(**kw):
print(f"Field of view is ({nx*cell_deg:.3e},{ny*cell_deg:.3e}) degrees",
file=log)


# TODO - how to glob with protocol in tact?
ds_list = xds_store.fs.glob(f'{xds_store.url}/*')
if '://' in xds_store.url:
protocol = xds_store.url.split('://')[0]
else:
protocol = 'file'
url_prepend = protocol + '://'
ds_list = list(map(lambda x: url_prepend + x, ds_list))

# create dds and cache
dds_name = opts.output_filename + f'_{opts.suffix}' + '.dds'
dds_store = DaskMSStore(dds_name)
Expand Down Expand Up @@ -322,7 +311,7 @@ def _grid(**kw):
ntime = 1
times_out = np.mean(times_in, keepdims=True)
for b in range(nband):
for ds, ds_name in zip(xds, ds_list):
for ds, ds_name in zip(xds, xds_list):
if ds.bandid == b:
tbid = f'time0000_band{b:04d}'
xds_dct.setdefault(tbid, {})
Expand All @@ -336,7 +325,7 @@ def _grid(**kw):
times_out = times_in
for t in range(times_in.size):
for b in range(nband):
for ds, ds_name in zip(xds, ds_list):
for ds, ds_name in zip(xds, xds_list):
if ds.time_out == times_in[t] and ds.bandid == b:
tbid = f'time{t:04d}_band{b:04d}'
xds_dct.setdefault(tbid, {})
Expand Down
2 changes: 1 addition & 1 deletion pfb/workers/klean.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def klean(**kw):
ti = time.time()
_klean(**opts)

dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)

from pfb.utils.fits import dds2fits

Expand Down
3 changes: 1 addition & 2 deletions pfb/workers/model2comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def _model2comps(**kw):

dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)

if opts.model_out is not None:
coeff_name = opts.model_out
Expand Down
3 changes: 1 addition & 2 deletions pfb/workers/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def _restore(**kw):

dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)

if opts.drop_bands is not None:
ddso = []
Expand Down
3 changes: 1 addition & 2 deletions pfb/workers/sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ def sara(**kw):
basename = opts.output_filename
fits_oname = f'{opts.fits_output_folder}/{oname}'
dds_store = DaskMSStore(f'{basename}_{opts.suffix}.dds')
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')

with ExitStack() as stack:
ti = time.time()
_sara(**opts)

dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)

if opts.fits_mfs or opts.fits:
from pfb.utils.fits import dds2fits
Expand Down
13 changes: 2 additions & 11 deletions pfb/workers/spotless.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,7 @@ def _spotless(**kw):
except Exception as e:
raise ValueError(f"There must be a dataset at {xds_store.url}")

xds = xds_from_url(xds_name)

# TODO - how to glob with protocol in tact?
xds_list = xds_store.fs.glob(f'{xds_store.url}/*')
if '://' in xds_store.url:
protocol = xds_store.url.split('://')[0]
else:
protocol = 'file'
url_prepend = protocol + '://'
xds_list = list(map(lambda x: url_prepend + x, xds_list))
xds, xds_list = xds_from_url(xds_store.url)

# create dds and cache
dds_name = opts.output_filename + f'_{opts.suffix}' + '.dds'
Expand Down Expand Up @@ -231,7 +222,7 @@ def _spotless(**kw):

from_cache = True
print("Initialising from cached data products", file=log)
dds = xds_from_url(dds_store.url)
dds, dds_list = xds_from_url(dds_store.url)
iter0 = dds[0].niters
except Exception as e:
print(f'Cache verification failed on {attr}. '
Expand Down
2 changes: 1 addition & 1 deletion tests/test_band_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if __name__=='__main__':
# xds_name = '/scratch/bester/stage7_combined_bda_I.xds'
xds_name = '/home/landman/testing/pfb/out/data_I.xds'
xds = xds_from_url(xds_name)
xds, _ = xds_from_url(xds_name)
xds_store = DaskMSStore(xds_name)
xds_list = xds_store.fs.glob(f'{xds_store.url}/*')

Expand Down
2 changes: 1 addition & 1 deletion tests/test_klean.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_klean(do_gains, ms_name):
_klean(**klean_args)

# get inferred model
dds = xds_from_url(dds_name)
dds, _ = xds_from_url(dds_name)
model_inferred = np.zeros((nchan, nx, ny))
for ds in dds:
b = int(ds.bandid)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_polproducts.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def test_polproducts(do_gains, ms_name):
from pfb.workers.grid import _grid
_grid(**grid_args)

dds = xds_from_url(dds_name)
dds, _ = xds_from_url(dds_name)

for ds in dds:
wsum = ds.WSUM.values
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_sara(ms_name):

# the residual computed by the grid worker should be identical
# to that computed in sara when transferring model
dds = xds_from_url(dds_name)
dds, _ = xds_from_url(dds_name)

# grid data to produce dirty image
grid_args = {}
Expand All @@ -221,7 +221,7 @@ def test_sara(ms_name):
grid_args["suffix"] = 'subtract'
_grid(**grid_args)

dds2 = xds_from_url(f'{outname}_subtract.dds')
dds2, _ = xds_from_url(f'{outname}_subtract.dds')

for ds, ds2 in zip(dds, dds2):
wsum = ds.WSUM.values
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_sara(ms_name):

dds_name = f'{outname}_main.dds'

dds2 = xds_from_url(dds_name)
dds2, _ = xds_from_url(dds_name)

for ds, ds2 in zip(dds, dds2):
wsum = ds.WSUM.values
Expand Down

0 comments on commit bf9395c

Please sign in to comment.