diff --git a/dkist/conftest.py b/dkist/conftest.py index 084fcbf9..b8bfee18 100644 --- a/dkist/conftest.py +++ b/dkist/conftest.py @@ -267,6 +267,31 @@ def dataset_4d(identity_gwcs_4d, empty_meta): return Dataset(array, wcs=identity_gwcs_4d, meta=empty_meta, unit=u.count) +@pytest.fixture +def dataset_5d(identity_gwcs_5d_stokes, empty_meta): + shape = (4, 40, 30, 20, 10) + x = np.ones(shape) + array = da.from_array(x, tuple(shape)) + + identity_gwcs_4d.pixel_shape = array.shape[::-1] + identity_gwcs_4d.array_shape = array.shape + + ds = Dataset(array, wcs=identity_gwcs_5d_stokes, meta={"inventory": {}, "headers": Table()}, unit=u.count) + fileuris = np.array([f"dummyfile_{i}" for i in range(np.prod(shape[:-2]))]).reshape(shape[:-2]) + ds._file_manager = FileManager.from_parts(fileuris, 0, float, shape[-2:], loader=AstropyFITSLoader, basepath="./") + + return ds + + +@pytest.fixture +def dataset_5d_dummy_filemanager_axis(dataset_5d): + shape = dataset_5d.data.shape + fileuris = np.array([f"dummyfile_{i}" for i in range(np.prod(shape[:-2]))]).reshape(shape[:-2]) + dataset_5d._file_manager = FileManager.from_parts(fileuris, 0, float, (1, *shape[-2:]), loader=AstropyFITSLoader, basepath="./") + + return dataset_5d + + @pytest.fixture def eit_dataset(): eitdir = Path(rootdir) / "EIT" diff --git a/dkist/dataset/tests/test_dataset.py b/dkist/dataset/tests/test_dataset.py index b6445347..90b8c577 100644 --- a/dkist/dataset/tests/test_dataset.py +++ b/dkist/dataset/tests/test_dataset.py @@ -179,12 +179,22 @@ def test_header_slicing_3D_slice(large_visp_dataset): @pytest.mark.accept_cli_dataset -def test_file_slicing_with_dummy_axis(large_visp_dataset): - assert len(large_visp_dataset[0].files) == 20 - assert len(large_visp_dataset[0, 0].files) == 1 +def test_file_slicing_with_dummy_axis(dataset_5d_dummy_filemanager_axis): + ds = dataset_5d_dummy_filemanager_axis + shape = ds.data.shape + assert len(ds.files) == np.prod(shape[:3]) + assert len(ds[0].files) == np.prod(shape[1:3]) + assert len(ds[0, 0].files) == np.prod(shape[2]) + assert len(ds[0, 0, 0].files) == 1 + assert len(ds[0, 0, 0, 0].files) == 1 @pytest.mark.accept_cli_dataset -def test_file_slicing_without_dummy_axis(large_visp_no_dummy_axis): - assert len(large_visp_no_dummy_axis[0].files) == 20 - assert len(large_visp_no_dummy_axis[0, 0].files) == 1 +def test_file_slicing_without_dummy_axis(dataset_5d): + ds = dataset_5d + shape = ds.data.shape + assert len(ds.files) == np.prod(shape[:3]) + assert len(ds[0].files) == np.prod(shape[1:3]) + assert len(ds[0, 0].files) == np.prod(shape[2]) + assert len(ds[0, 0, 0].files) == 1 + assert len(ds[0, 0, 0, 0].files) == 1