From a5ddad2925a70c50ddc52eab932848627b884106 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 15 Dec 2024 14:42:40 +0000 Subject: [PATCH] fix dataset tests --- autoarray/dataset/grids.py | 18 ++---------- autoarray/dataset/imaging/dataset.py | 2 +- autoarray/dataset/plot/imaging_plotters.py | 6 ++-- autoarray/fixtures.py | 10 +++---- autoarray/structures/decorators/to_grid.py | 4 +-- .../dataset/abstract/test_dataset.py | 28 +++++++++---------- .../dataset/plot/test_imaging_plotters.py | 2 +- 7 files changed, 29 insertions(+), 41 deletions(-) diff --git a/autoarray/dataset/grids.py b/autoarray/dataset/grids.py index 1b86b44e..8e30c2d1 100644 --- a/autoarray/dataset/grids.py +++ b/autoarray/dataset/grids.py @@ -71,7 +71,7 @@ def uniform(self) -> Union[Grid1D, Grid2D]: """ return Grid2D.from_mask( mask=self.mask, - over_sampling=self.over_sampling.over_sampler, + over_sampling_size=self.over_sampling.uniform, ) @cached_property @@ -95,7 +95,7 @@ def non_uniform(self) -> Optional[Union[Grid1D, Grid2D]]: """ return Grid2D.from_mask( mask=self.mask, - over_sampling=self.over_sampling.non_uniform, + over_sampling_size=self.over_sampling.non_uniform, ) @cached_property @@ -116,7 +116,7 @@ def pixelization(self) -> Grid2D: """ return Grid2D.from_mask( mask=self.mask, - over_sampling=self.over_sampling.pixelization, + over_sampling_size=self.over_sampling.pixelization, ) @cached_property @@ -143,18 +143,6 @@ def blurring(self) -> Optional[Grid2D]: kernel_shape_native=self.psf.shape_native, ) - @cached_property - def over_sampler_uniform(self): - return self.uniform.over_sampling.over_sampler_from(mask=self.mask) - - @cached_property - def over_sampler_non_uniform(self): - return self.non_uniform.over_sampling.over_sampler_from(mask=self.mask) - - @cached_property - def over_sampler_pixelization(self): - return self.pixelization.over_sampling.over_sampler_from(mask=self.mask) - @cached_property def border_relocator(self) -> BorderRelocator: return BorderRelocator( diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 8ada9e4b..c3aa4faa 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -455,7 +455,7 @@ def apply_over_sampling( This class controls over sampling for all the different grids (e.g. `grid`, `grids.pixelization). """ - uniform = over_sampling.over_sampler or self.over_sampling.over_sampler + uniform = over_sampling.uniform or self.over_sampling.uniform non_uniform = over_sampling.non_uniform or self.over_sampling.non_uniform pixelization = over_sampling.pixelization or self.over_sampling.pixelization diff --git a/autoarray/dataset/plot/imaging_plotters.py b/autoarray/dataset/plot/imaging_plotters.py index 57fb92d8..f758cf59 100644 --- a/autoarray/dataset/plot/imaging_plotters.py +++ b/autoarray/dataset/plot/imaging_plotters.py @@ -127,7 +127,7 @@ def figures_2d( if over_sampling: self.mat_plot_2d.plot_array( - array=self.dataset.grids.over_sampler_uniform.sub_size, + array=self.dataset.grids.uniform.over_sampling_size, visuals_2d=self.get_visuals_2d(), auto_labels=AutoLabels( title=title_str or f"Over Sampling (Uniform)", @@ -138,7 +138,7 @@ def figures_2d( if over_sampling_non_uniform: self.mat_plot_2d.plot_array( - array=self.dataset.grids.over_sampler_non_uniform.sub_size, + array=self.dataset.grids.non_uniform.over_sampling_size, visuals_2d=self.get_visuals_2d(), auto_labels=AutoLabels( title=title_str or f"Over Sampling (Non Uniform)", @@ -149,7 +149,7 @@ def figures_2d( if over_sampling_pixelization: self.mat_plot_2d.plot_array( - array=self.dataset.grids.over_sampler_pixelization.sub_size, + array=self.dataset.grids.pixelization.over_sampling_size, visuals_2d=self.get_visuals_2d(), auto_labels=AutoLabels( title=title_str or f"Over Sampling (Pixelization)", diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index 601333b0..9cceef5b 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -152,7 +152,7 @@ def make_imaging_7x7(): data=make_image_7x7(), psf=make_psf_3x3(), noise_map=make_noise_map_7x7(), - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)), + over_sampling=aa.OverSamplingDataset(uniform=1), ) @@ -161,7 +161,7 @@ def make_imaging_7x7_sub_2(): data=make_image_7x7(), psf=make_psf_3x3(), noise_map=make_noise_map_7x7(), - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)), + over_sampling=aa.OverSamplingDataset(uniform=2), ) @@ -170,7 +170,7 @@ def make_imaging_covariance_7x7(): data=make_image_7x7(), psf=make_psf_3x3(), noise_covariance_matrix=make_noise_covariance_matrix_7x7(), - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)), + over_sampling=aa.OverSamplingDataset(uniform=1), ) @@ -179,7 +179,7 @@ def make_imaging_7x7_no_blur(): data=make_image_7x7(), psf=make_psf_3x3_no_blur(), noise_map=make_noise_map_7x7(), - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)), + over_sampling=aa.OverSamplingDataset(uniform=1), ) @@ -188,7 +188,7 @@ def make_imaging_7x7_no_blur_sub_2(): data=make_image_7x7(), psf=make_psf_3x3_no_blur(), noise_map=make_noise_map_7x7(), - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)), + over_sampling=aa.OverSamplingDataset(uniform=2), ) diff --git a/autoarray/structures/decorators/to_grid.py b/autoarray/structures/decorators/to_grid.py index 3b70b1b2..706397fc 100644 --- a/autoarray/structures/decorators/to_grid.py +++ b/autoarray/structures/decorators/to_grid.py @@ -25,13 +25,13 @@ def via_grid_2d(self, result) -> Union[Grid2D, List[Grid2D]]: return Grid2D( values=result, mask=self.mask, - over_sampling=self.over_sampling, + over_sampling_size=self.over_sampling_size, ) return [ Grid2D( values=res, mask=self.mask, - over_sampling=self.over_sampling, + over_sampling_size=self.over_sampling_size, ) for res in result ] diff --git a/test_autoarray/dataset/abstract/test_dataset.py b/test_autoarray/dataset/abstract/test_dataset.py index 8d7db370..05350fee 100644 --- a/test_autoarray/dataset/abstract/test_dataset.py +++ b/test_autoarray/dataset/abstract/test_dataset.py @@ -65,7 +65,7 @@ def test__grid__uses_mask_and_settings( masked_imaging_7x7 = ds.AbstractDataset( data=masked_image_7x7, noise_map=masked_noise_map_7x7, - over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)), + over_sampling=aa.OverSamplingDataset(uniform=2), ) assert isinstance(masked_imaging_7x7.grids.uniform, aa.Grid2D) @@ -95,13 +95,13 @@ def test__grids_pixelization__uses_mask_and_settings( data=masked_image_7x7, noise_map=masked_noise_map_7x7, over_sampling=aa.OverSamplingDataset( - uniform=aa.OverSampling(sub_size=2), - pixelization=aa.OverSampling(sub_size=4), + uniform=2, + pixelization=4, ), ) assert isinstance(masked_imaging_7x7.grids.pixelization, aa.Grid2D) - assert masked_imaging_7x7.grids.pixelization.over_sampling.sub_size == 4 + assert masked_imaging_7x7.grids.pixelization.over_sampling_size[0] == 4 def test__grid_settings__sub_size(image_7x7, noise_map_7x7): @@ -109,13 +109,13 @@ def test__grid_settings__sub_size(image_7x7, noise_map_7x7): data=image_7x7, noise_map=noise_map_7x7, over_sampling=aa.OverSamplingDataset( - uniform=aa.OverSampling(sub_size=2), - pixelization=aa.OverSampling(sub_size=4), + uniform=2, + pixelization=4, ), ) - assert dataset_7x7.grids.uniform.over_sampling.sub_size == 2 - assert dataset_7x7.grids.pixelization.over_sampling.sub_size == 4 + assert dataset_7x7.grids.uniform.over_sampling_size[0] == 2 + assert dataset_7x7.grids.pixelization.over_sampling_size[0] == 4 def test__new_imaging_with_arrays_trimmed_via_kernel_shape(): @@ -141,8 +141,8 @@ def test__apply_over_sampling(image_7x7, noise_map_7x7): data=image_7x7, noise_map=noise_map_7x7, over_sampling=aa.OverSamplingDataset( - uniform=aa.OverSampling(sub_size=2), - pixelization=aa.OverSampling(sub_size=2), + uniform=2, + pixelization=2, ), ) @@ -160,13 +160,13 @@ def test__apply_over_sampling(image_7x7, noise_map_7x7): dataset_7x7 = dataset_7x7.apply_over_sampling( over_sampling=aa.OverSamplingDataset( - uniform=aa.OverSampling(sub_size=4), - pixelization=aa.OverSampling(sub_size=4), + uniform=4, + pixelization=4, ) ) - assert dataset_7x7.over_sampling.over_sampler.sub_size == 4 - assert dataset_7x7.over_sampling.pixelization.sub_size == 4 + assert dataset_7x7.over_sampling.uniform == 4 + assert dataset_7x7.over_sampling.pixelization == 4 assert dataset_7x7.grids.uniform[0][0] == pytest.approx(3.0, 1.0e-4) assert dataset_7x7.grids.pixelization[0][0] == pytest.approx(3.0, 1.0e-4) diff --git a/test_autoarray/dataset/plot/test_imaging_plotters.py b/test_autoarray/dataset/plot/test_imaging_plotters.py index 84b9e420..65d1f9a0 100644 --- a/test_autoarray/dataset/plot/test_imaging_plotters.py +++ b/test_autoarray/dataset/plot/test_imaging_plotters.py @@ -20,7 +20,7 @@ def test__individual_attributes_are_output( visuals = aplt.Visuals2D(mask=mask_2d_7x7, positions=grid_2d_irregular_7x7_list) imaging_7x7 = imaging_7x7.apply_over_sampling( - over_sampling=aa.OverSamplingDataset(non_uniform=aa.OverSampling(sub_size=1)) + over_sampling=aa.OverSamplingDataset(non_uniform=1) ) dataset_plotter = aplt.ImagingPlotter(