Skip to content

Commit

Permalink
fix dataset tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Dec 15, 2024
1 parent cdda058 commit a5ddad2
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 41 deletions.
18 changes: 3 additions & 15 deletions autoarray/dataset/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def uniform(self) -> Union[Grid1D, Grid2D]:
"""
return Grid2D.from_mask(
mask=self.mask,
over_sampling=self.over_sampling.over_sampler,
over_sampling_size=self.over_sampling.uniform,
)

@cached_property
Expand All @@ -95,7 +95,7 @@ def non_uniform(self) -> Optional[Union[Grid1D, Grid2D]]:
"""
return Grid2D.from_mask(
mask=self.mask,
over_sampling=self.over_sampling.non_uniform,
over_sampling_size=self.over_sampling.non_uniform,
)

@cached_property
Expand All @@ -116,7 +116,7 @@ def pixelization(self) -> Grid2D:
"""
return Grid2D.from_mask(
mask=self.mask,
over_sampling=self.over_sampling.pixelization,
over_sampling_size=self.over_sampling.pixelization,
)

@cached_property
Expand All @@ -143,18 +143,6 @@ def blurring(self) -> Optional[Grid2D]:
kernel_shape_native=self.psf.shape_native,
)

@cached_property
def over_sampler_uniform(self):
return self.uniform.over_sampling.over_sampler_from(mask=self.mask)

@cached_property
def over_sampler_non_uniform(self):
return self.non_uniform.over_sampling.over_sampler_from(mask=self.mask)

@cached_property
def over_sampler_pixelization(self):
return self.pixelization.over_sampling.over_sampler_from(mask=self.mask)

@cached_property
def border_relocator(self) -> BorderRelocator:
return BorderRelocator(
Expand Down
2 changes: 1 addition & 1 deletion autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def apply_over_sampling(
This class controls over sampling for all the different grids (e.g. `grid`, `grids.pixelization).
"""

uniform = over_sampling.over_sampler or self.over_sampling.over_sampler
uniform = over_sampling.uniform or self.over_sampling.uniform
non_uniform = over_sampling.non_uniform or self.over_sampling.non_uniform
pixelization = over_sampling.pixelization or self.over_sampling.pixelization

Expand Down
6 changes: 3 additions & 3 deletions autoarray/dataset/plot/imaging_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def figures_2d(

if over_sampling:
self.mat_plot_2d.plot_array(
array=self.dataset.grids.over_sampler_uniform.sub_size,
array=self.dataset.grids.uniform.over_sampling_size,
visuals_2d=self.get_visuals_2d(),
auto_labels=AutoLabels(
title=title_str or f"Over Sampling (Uniform)",
Expand All @@ -138,7 +138,7 @@ def figures_2d(

if over_sampling_non_uniform:
self.mat_plot_2d.plot_array(
array=self.dataset.grids.over_sampler_non_uniform.sub_size,
array=self.dataset.grids.non_uniform.over_sampling_size,
visuals_2d=self.get_visuals_2d(),
auto_labels=AutoLabels(
title=title_str or f"Over Sampling (Non Uniform)",
Expand All @@ -149,7 +149,7 @@ def figures_2d(

if over_sampling_pixelization:
self.mat_plot_2d.plot_array(
array=self.dataset.grids.over_sampler_pixelization.sub_size,
array=self.dataset.grids.pixelization.over_sampling_size,
visuals_2d=self.get_visuals_2d(),
auto_labels=AutoLabels(
title=title_str or f"Over Sampling (Pixelization)",
Expand Down
10 changes: 5 additions & 5 deletions autoarray/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_imaging_7x7():
data=make_image_7x7(),
psf=make_psf_3x3(),
noise_map=make_noise_map_7x7(),
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)),
over_sampling=aa.OverSamplingDataset(uniform=1),
)


Expand All @@ -161,7 +161,7 @@ def make_imaging_7x7_sub_2():
data=make_image_7x7(),
psf=make_psf_3x3(),
noise_map=make_noise_map_7x7(),
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)),
over_sampling=aa.OverSamplingDataset(uniform=2),
)


Expand All @@ -170,7 +170,7 @@ def make_imaging_covariance_7x7():
data=make_image_7x7(),
psf=make_psf_3x3(),
noise_covariance_matrix=make_noise_covariance_matrix_7x7(),
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)),
over_sampling=aa.OverSamplingDataset(uniform=1),
)


Expand All @@ -179,7 +179,7 @@ def make_imaging_7x7_no_blur():
data=make_image_7x7(),
psf=make_psf_3x3_no_blur(),
noise_map=make_noise_map_7x7(),
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=1)),
over_sampling=aa.OverSamplingDataset(uniform=1),
)


Expand All @@ -188,7 +188,7 @@ def make_imaging_7x7_no_blur_sub_2():
data=make_image_7x7(),
psf=make_psf_3x3_no_blur(),
noise_map=make_noise_map_7x7(),
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)),
over_sampling=aa.OverSamplingDataset(uniform=2),
)


Expand Down
4 changes: 2 additions & 2 deletions autoarray/structures/decorators/to_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def via_grid_2d(self, result) -> Union[Grid2D, List[Grid2D]]:
return Grid2D(
values=result,
mask=self.mask,
over_sampling=self.over_sampling,
over_sampling_size=self.over_sampling_size,
)
return [
Grid2D(
values=res,
mask=self.mask,
over_sampling=self.over_sampling,
over_sampling_size=self.over_sampling_size,
)
for res in result
]
Expand Down
28 changes: 14 additions & 14 deletions test_autoarray/dataset/abstract/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test__grid__uses_mask_and_settings(
masked_imaging_7x7 = ds.AbstractDataset(
data=masked_image_7x7,
noise_map=masked_noise_map_7x7,
over_sampling=aa.OverSamplingDataset(uniform=aa.OverSampling(sub_size=2)),
over_sampling=aa.OverSamplingDataset(uniform=2),
)

assert isinstance(masked_imaging_7x7.grids.uniform, aa.Grid2D)
Expand Down Expand Up @@ -95,27 +95,27 @@ def test__grids_pixelization__uses_mask_and_settings(
data=masked_image_7x7,
noise_map=masked_noise_map_7x7,
over_sampling=aa.OverSamplingDataset(
uniform=aa.OverSampling(sub_size=2),
pixelization=aa.OverSampling(sub_size=4),
uniform=2,
pixelization=4,
),
)

assert isinstance(masked_imaging_7x7.grids.pixelization, aa.Grid2D)
assert masked_imaging_7x7.grids.pixelization.over_sampling.sub_size == 4
assert masked_imaging_7x7.grids.pixelization.over_sampling_size[0] == 4


def test__grid_settings__sub_size(image_7x7, noise_map_7x7):
dataset_7x7 = ds.AbstractDataset(
data=image_7x7,
noise_map=noise_map_7x7,
over_sampling=aa.OverSamplingDataset(
uniform=aa.OverSampling(sub_size=2),
pixelization=aa.OverSampling(sub_size=4),
uniform=2,
pixelization=4,
),
)

assert dataset_7x7.grids.uniform.over_sampling.sub_size == 2
assert dataset_7x7.grids.pixelization.over_sampling.sub_size == 4
assert dataset_7x7.grids.uniform.over_sampling_size[0] == 2
assert dataset_7x7.grids.pixelization.over_sampling_size[0] == 4


def test__new_imaging_with_arrays_trimmed_via_kernel_shape():
Expand All @@ -141,8 +141,8 @@ def test__apply_over_sampling(image_7x7, noise_map_7x7):
data=image_7x7,
noise_map=noise_map_7x7,
over_sampling=aa.OverSamplingDataset(
uniform=aa.OverSampling(sub_size=2),
pixelization=aa.OverSampling(sub_size=2),
uniform=2,
pixelization=2,
),
)

Expand All @@ -160,13 +160,13 @@ def test__apply_over_sampling(image_7x7, noise_map_7x7):

dataset_7x7 = dataset_7x7.apply_over_sampling(
over_sampling=aa.OverSamplingDataset(
uniform=aa.OverSampling(sub_size=4),
pixelization=aa.OverSampling(sub_size=4),
uniform=4,
pixelization=4,
)
)

assert dataset_7x7.over_sampling.over_sampler.sub_size == 4
assert dataset_7x7.over_sampling.pixelization.sub_size == 4
assert dataset_7x7.over_sampling.uniform == 4
assert dataset_7x7.over_sampling.pixelization == 4

assert dataset_7x7.grids.uniform[0][0] == pytest.approx(3.0, 1.0e-4)
assert dataset_7x7.grids.pixelization[0][0] == pytest.approx(3.0, 1.0e-4)
2 changes: 1 addition & 1 deletion test_autoarray/dataset/plot/test_imaging_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test__individual_attributes_are_output(
visuals = aplt.Visuals2D(mask=mask_2d_7x7, positions=grid_2d_irregular_7x7_list)

imaging_7x7 = imaging_7x7.apply_over_sampling(
over_sampling=aa.OverSamplingDataset(non_uniform=aa.OverSampling(sub_size=1))
over_sampling=aa.OverSamplingDataset(non_uniform=1)
)

dataset_plotter = aplt.ImagingPlotter(
Expand Down

0 comments on commit a5ddad2

Please sign in to comment.