-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #84 from Jammy2211/feature/hilbert_bg
Feature/hilbert bg
- Loading branch information
Showing
4 changed files
with
176 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .hilbert import Hilbert | ||
from .hilbert_balanced import HilbertBalanced | ||
from .overlay import Overlay | ||
from .kmeans import KMeans |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
145 changes: 145 additions & 0 deletions
145
autoarray/inversion/pixelization/image_mesh/hilbert_balanced.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
import numpy as np | ||
from scipy.interpolate import interp1d, griddata | ||
from typing import Optional | ||
|
||
from autoarray.structures.grids.uniform_2d import Grid2D | ||
from autoarray.mask.mask_2d import Mask2D | ||
from autoarray.inversion.pixelization.image_mesh.abstract_weighted import ( | ||
AbstractImageMeshWeighted, | ||
) | ||
from autoarray.structures.grids.irregular_2d import Grid2DIrregular | ||
|
||
from autoarray.inversion.pixelization.image_mesh.hilbert import image_and_grid_from | ||
from autoarray.inversion.pixelization.image_mesh.hilbert import ( | ||
inverse_transform_sampling_interpolated, | ||
) | ||
|
||
from autoarray import exc | ||
|
||
|
||
class HilbertBalanced(AbstractImageMeshWeighted): | ||
def __init__( | ||
self, | ||
pixels=10.0, | ||
weight_floor=0.0, | ||
weight_power=0.0, | ||
): | ||
""" | ||
Computes a balanced image-mesh by computing the Hilbert curve of the adapt data and drawing points from it. | ||
The standard `Hilbert` image-mesh suffers a systematic where the vast majority of points are drawn from | ||
the high weighted reigons. This often leaves few points to reconstruct the lower weight regions, leading to | ||
discontinuities in the reconstruction. | ||
This image-mesh addresses this by drawing half the points from the weight map and the other half from | ||
(1 - weight map). This ensures both high and low weighted regions are sampled equally, but still has sufficient | ||
flexibility to dedicate many points to the highest weighted regions. | ||
This requires an adapt-image, which is the image that the Hilbert curve algorithm adapts to in order to compute | ||
the image mesh. This could simply be the image itself, or a model fit to the image which removes certain | ||
features or noise. | ||
For example, using the adapt image, the image mesh is computed as follows: | ||
1) Convert the adapt image to a weight map, which is a 2D array of weight values. | ||
2) Run the Hilbert algorithm on the weight map, such that the image mesh pixels cluster around the weight map | ||
values with higher values. | ||
Parameters | ||
---------- | ||
pixels | ||
The total number of pixels in the image mesh and drawn from the Hilbert curve. | ||
weight_floor | ||
The minimum weight value in the weight map, which allows more pixels to be drawn from the lower weight | ||
regions of the adapt image. | ||
weight_power | ||
The power the weight values are raised too, which allows more pixels to be drawn from the higher weight | ||
regions of the adapt image. | ||
""" | ||
|
||
super().__init__( | ||
pixels=pixels, | ||
weight_floor=weight_floor, | ||
weight_power=weight_power, | ||
) | ||
|
||
def image_plane_mesh_grid_from( | ||
self, grid: Grid2D, adapt_data: Optional[np.ndarray] | ||
) -> Grid2DIrregular: | ||
""" | ||
Returns an image mesh by running the balanced Hilbert curve on the weight map. | ||
See the `__init__` docstring for a full description of how this is performed. | ||
Parameters | ||
---------- | ||
grid | ||
The grid of (y,x) coordinates of the image data the pixelization fits, which the Hilbert curve adapts to. | ||
adapt_data | ||
The weights defining the regions of the image the Hilbert curve adapts to. | ||
Returns | ||
------- | ||
""" | ||
|
||
if not grid.mask.is_circular: | ||
raise exc.PixelizationException( | ||
""" | ||
Hilbert image-mesh has been called but the input grid does not use a circular mask. | ||
Ensure that analysis is using a circular mask via the Mask2D.circular classmethod. | ||
""" | ||
) | ||
|
||
adapt_data_hb, grid_hb = image_and_grid_from( | ||
image=adapt_data, | ||
mask=grid.mask, | ||
mask_radius=grid.mask.circular_radius, | ||
pixel_scales=grid.mask.pixel_scales, | ||
hilbert_length=193, | ||
) | ||
|
||
weight_map = self.weight_map_from(adapt_data=adapt_data_hb) | ||
|
||
weight_map_background = 1.0 - weight_map | ||
|
||
weight_map /= np.sum(weight_map) | ||
weight_map_background /= np.sum(weight_map_background) | ||
|
||
if self.pixels % 2 == 1: | ||
pixels = self.pixels + 1 | ||
else: | ||
pixels = self.pixels | ||
|
||
( | ||
drawn_id, | ||
drawn_x, | ||
drawn_y, | ||
) = inverse_transform_sampling_interpolated( | ||
probabilities=weight_map, | ||
n_samples=pixels // 2, | ||
gridx=grid_hb[:, 1], | ||
gridy=grid_hb[:, 0], | ||
) | ||
|
||
grid = np.stack((drawn_y, drawn_x), axis=-1) | ||
|
||
( | ||
drawn_id, | ||
drawn_x, | ||
drawn_y, | ||
) = inverse_transform_sampling_interpolated( | ||
probabilities=weight_map_background, | ||
n_samples=(self.pixels // 2) + 1, | ||
gridx=grid_hb[:, 1], | ||
gridy=grid_hb[:, 0], | ||
) | ||
|
||
grid_background = np.stack((drawn_y, drawn_x), axis=-1) | ||
|
||
return Grid2DIrregular( | ||
values=np.concatenate((grid, grid_background[1:, :]), axis=0) | ||
) |
29 changes: 29 additions & 0 deletions
29
test_autoarray/inversion/pixelization/image_mesh/test_hilbert_balanced.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
|
||
import autoarray as aa | ||
|
||
|
||
def test__image_plane_mesh_grid_from(): | ||
mask = aa.Mask2D.circular( | ||
shape_native=(4, 4), | ||
radius=2.0, | ||
pixel_scales=1.0, | ||
sub_size=1, | ||
) | ||
|
||
grid = aa.Grid2D.from_mask(mask=mask) | ||
|
||
adapt_data = aa.Array2D.ones( | ||
shape_native=mask.shape_native, | ||
pixel_scales=1.0, | ||
) | ||
|
||
kmeans = aa.image_mesh.HilbertBalanced(pixels=10) | ||
image_mesh = kmeans.image_plane_mesh_grid_from(grid=grid, adapt_data=adapt_data) | ||
|
||
print(image_mesh) | ||
|
||
assert image_mesh[0, :] == pytest.approx( | ||
[-1.02590674, -1.70984456], | ||
1.0e-4, | ||
) |