Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Oct 3, 2024
1 parent 5f9b6ff commit 280d574
Show file tree
Hide file tree
Showing 15 changed files with 102 additions and 94 deletions.
1 change: 0 additions & 1 deletion autoarray/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ class PlottingException(Exception):
pass



class ProfilingException(Exception):
"""
Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator).
Expand Down
10 changes: 6 additions & 4 deletions autoarray/geometry/geometry_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,12 @@ def transform_grid_2d_to_reference_frame(
theta_coordinate_to_profile = np.arctan2(
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
) - np.radians(angle)
return np.vstack([
radius * np.sin(theta_coordinate_to_profile),
radius * np.cos(theta_coordinate_to_profile)
]).T
return np.vstack(
[
radius * np.sin(theta_coordinate_to_profile),
radius * np.cos(theta_coordinate_to_profile),
]
).T


def transform_grid_2d_from_reference_frame(
Expand Down
12 changes: 6 additions & 6 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,12 @@ def reconstruction(self) -> np.ndarray:

solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])

solutions[
values_to_solve
] = inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
solutions[values_to_solve] = (
inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
)
)
return solutions
else:
Expand Down
16 changes: 9 additions & 7 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
"""

return [
self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
(
self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
for linear_obj in self.linear_obj_list
]

Expand Down Expand Up @@ -141,9 +143,9 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
mapping_matrix=linear_func.mapping_matrix
)

linear_func_operated_mapping_matrix_dict[
linear_func
] = operated_mapping_matrix
linear_func_operated_mapping_matrix_dict[linear_func] = (
operated_mapping_matrix
)

return linear_func_operated_mapping_matrix_dict

Expand Down
12 changes: 6 additions & 6 deletions autoarray/inversion/pixelization/border_relocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ def sub_border_pixel_slim_indexes_from(
int(border_pixel)
]

sub_border_pixels[
border_1d_index
] = grid_2d_util.furthest_grid_2d_slim_index_from(
grid_2d_slim=sub_grid_2d_slim,
slim_indexes=sub_border_pixels_of_border_pixel,
coordinate=mask_centre,
sub_border_pixels[border_1d_index] = (
grid_2d_util.furthest_grid_2d_slim_index_from(
grid_2d_slim=sub_grid_2d_slim,
slim_indexes=sub_border_pixels_of_border_pixel,
coordinate=mask_centre,
)
)

return sub_border_pixels
Expand Down
7 changes: 3 additions & 4 deletions autoarray/inversion/plot/inversion_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,9 @@ def figures_2d_of_pixelization(
"inversion"
]["reconstruction_vmax_factor"]

self.mat_plot_2d.cmap.kwargs[
"vmax"
] = reconstruction_vmax_factor * np.max(
self.inversion.reconstruction
self.mat_plot_2d.cmap.kwargs["vmax"] = (
reconstruction_vmax_factor
* np.max(self.inversion.reconstruction)
)
vmax_custom = True

Expand Down
4 changes: 1 addition & 3 deletions autoarray/mask/mask_2d_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,7 @@ def elliptical_radius_from(
y_scaled_elliptical = r_scaled * np.sin(theta_rotated)
x_scaled_elliptical = r_scaled * np.cos(theta_rotated)

return np.sqrt(
x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0
)
return np.sqrt(x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0)


@numba_util.jit()
Expand Down
1 change: 0 additions & 1 deletion autoarray/numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax
from jax import numpy as np, jit


print("JAX mode enabled")
except ImportError:
raise ImportError(
Expand Down
65 changes: 28 additions & 37 deletions autoarray/operators/convolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ def __init__(self, mask, kernel):
mask_index_array=self.mask_index_array,
kernel_2d=np.array(self.kernel.native[:, :]),
)
self.image_frame_1d_indexes[
mask_1d_index, :
] = image_frame_1d_indexes
self.image_frame_1d_kernels[
mask_1d_index, :
] = image_frame_1d_kernels
self.image_frame_1d_indexes[mask_1d_index, :] = (
image_frame_1d_indexes
)
self.image_frame_1d_kernels[mask_1d_index, :] = (
image_frame_1d_kernels
)
self.image_frame_1d_lengths[mask_1d_index] = image_frame_1d_indexes[
image_frame_1d_indexes >= 0
].shape[0]
Expand Down Expand Up @@ -265,15 +265,15 @@ def __init__(self, mask, kernel):
mask_index_array=np.array(self.mask_index_array),
kernel_2d=np.array(self.kernel.native),
)
self.blurring_frame_1d_indexes[
mask_1d_index, :
] = image_frame_1d_indexes
self.blurring_frame_1d_kernels[
mask_1d_index, :
] = image_frame_1d_kernels
self.blurring_frame_1d_lengths[
mask_1d_index
] = image_frame_1d_indexes[image_frame_1d_indexes >= 0].shape[0]
self.blurring_frame_1d_indexes[mask_1d_index, :] = (
image_frame_1d_indexes
)
self.blurring_frame_1d_kernels[mask_1d_index, :] = (
image_frame_1d_kernels
)
self.blurring_frame_1d_lengths[mask_1d_index] = (
image_frame_1d_indexes[image_frame_1d_indexes >= 0].shape[0]
)
mask_1d_index += 1

@staticmethod
Expand Down Expand Up @@ -317,33 +317,28 @@ def frame_at_coordinates_jit(coordinates, mask, mask_index_array, kernel_2d):

return frame, kernel_frame

def jax_convolve(self, image, blurring_image, method='auto'):
def jax_convolve(self, image, blurring_image, method="auto"):
slim_to_2D_index_image = jnp.nonzero(
jnp.logical_not(self.mask.array),
size=image.shape[0]
jnp.logical_not(self.mask.array), size=image.shape[0]
)
slim_to_2D_index_blurring = jnp.nonzero(
jnp.logical_not(self.blurring_mask),
size=blurring_image.shape[0]
jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0]
)
expanded_image_native = jnp.zeros(self.mask.shape)
expanded_image_native = expanded_image_native.at[
slim_to_2D_index_image
].set(image.array)
expanded_image_native = expanded_image_native.at[
slim_to_2D_index_blurring
].set(blurring_image.array)
expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set(
image.array
)
expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set(
blurring_image.array
)
kernel = np.array(self.kernel.native.array)
convolve_native = jax.scipy.signal.convolve(
expanded_image_native,
kernel,
mode='same',
method=method
expanded_image_native, kernel, mode="same", method=method
)
convolve_slim = convolve_native[slim_to_2D_index_image]
return convolve_slim

def convolve_image(self, image, blurring_image, jax_method='fft'):
def convolve_image(self, image, blurring_image, jax_method="fft"):
"""
For a given 1D array and blurring array, convolve the two using this convolver.
Expand Down Expand Up @@ -371,14 +366,10 @@ def exception_message():
self.blurring_mask is None,
lambda _: jax.debug.callback(exception_message),
lambda _: None,
None
None,
)

return self.jax_convolve(
image,
blurring_image,
method=jax_method
)
return self.jax_convolve(image, blurring_image, method=jax_method)

else:
if self.blurring_mask is None:
Expand Down
32 changes: 23 additions & 9 deletions autoarray/operators/over_sampling/over_sample_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def sub_slim_index_for_sub_native_index_from(sub_mask_2d: np.ndarray):
for sub_mask_y in range(sub_mask_2d.shape[0]):
for sub_mask_x in range(sub_mask_2d.shape[1]):
if sub_mask_2d[sub_mask_y, sub_mask_x] == False:
sub_slim_index_for_sub_native_index[
sub_mask_y, sub_mask_x
] = sub_mask_1d_index
sub_slim_index_for_sub_native_index[sub_mask_y, sub_mask_x] = (
sub_mask_1d_index
)
sub_mask_1d_index += 1

return sub_slim_index_for_sub_native_index
Expand Down Expand Up @@ -407,18 +407,32 @@ def grid_2d_slim_over_sampled_via_mask_from(
for x1 in range(sub):
if use_jax:
# while this makes it run, it is very, very slow
grid_slim = grid_slim.at[sub_index, 0].set(-(
y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0)
))
grid_slim = grid_slim.at[sub_index, 0].set(
-(
y_scaled
- y_sub_half
+ y1 * y_sub_step
+ (y_sub_step / 2.0)
)
)
grid_slim = grid_slim.at[sub_index, 1].set(
x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0)
x_scaled
- x_sub_half
+ x1 * x_sub_step
+ (x_sub_step / 2.0)
)
else:
grid_slim[sub_index, 0] = -(
y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0)
y_scaled
- y_sub_half
+ y1 * y_sub_step
+ (y_sub_step / 2.0)
)
grid_slim[sub_index, 1] = (
x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0)
x_scaled
- x_sub_half
+ x1 * x_sub_step
+ (x_sub_step / 2.0)
)
sub_index += 1

Expand Down
6 changes: 3 additions & 3 deletions autoarray/plot/wrap/base/colorbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def tick_labels_from(
cb_unit = units.colorbar_label

middle_index = (len(manual_tick_labels) - 1) // 2
manual_tick_labels[
middle_index
] = rf"{manual_tick_labels[middle_index]}{cb_unit}"
manual_tick_labels[middle_index] = (
rf"{manual_tick_labels[middle_index]}{cb_unit}"
)

return manual_tick_labels

Expand Down
2 changes: 1 addition & 1 deletion autoarray/plot/wrap/base/title.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Title(AbstractMatWrap):
def __init__(self, prefix: str = None, disable_log10_label : bool = False, **kwargs):
def __init__(self, prefix: str = None, disable_log10_label: bool = False, **kwargs):
"""
The settings used to customize the figure's title.
Expand Down
20 changes: 13 additions & 7 deletions autoarray/structures/grids/grid_2d_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def check_grid_2d(grid_2d: np.ndarray):

def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D):
if len(grid_2d.shape) == 2:

def exception_message():
raise exc.GridException(
f"""
Expand All @@ -68,17 +69,19 @@ def exception_message():
The mask number of pixels is {mask_2d.pixels_in_mask}.
"""
)

if use_jax:
jax.lax.cond(
grid_2d.shape[0] != mask_2d.pixels_in_mask,
lambda _: jax.debug.callback(exception_message),
lambda _: None,
None
None,
)
elif grid_2d.shape[0] != mask_2d.pixels_in_mask:
exception_message()

elif len(grid_2d.shape) == 3:

def exception_message():
raise exc.GridException(
f"""
Expand All @@ -89,12 +92,13 @@ def exception_message():
The mask shape_native is {mask_2d.shape_native}.
"""
)

if use_jax:
jax.lax.cond(
(grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native,
lambda _: jax.debug.callback(exception_message),
lambda _: None,
None
None,
)
elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native:
exception_message()
Expand Down Expand Up @@ -283,8 +287,12 @@ def grid_2d_slim_via_mask_from(
for x in range(mask_2d.shape[1]):
if not mask_2d[y, x]:
if use_jax:
grid_slim = grid_slim.at[index, 0].set(-(y - centres_scaled[0]) * pixel_scales[0])
grid_slim = grid_slim.at[index, 1].set((x - centres_scaled[1]) * pixel_scales[1])
grid_slim = grid_slim.at[index, 0].set(
-(y - centres_scaled[0]) * pixel_scales[0]
)
grid_slim = grid_slim.at[index, 1].set(
(x - centres_scaled[1]) * pixel_scales[1]
)
else:
grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0]
grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1]
Expand Down Expand Up @@ -786,9 +794,7 @@ def grid_2d_slim_upscaled_from(
The pixel scale of the uniform grid that laid over the irregular grid of (y,x) coordinates.
"""

grid_2d_slim_upscaled = np.zeros(
shape=(grid_slim.shape[0] * upscale_factor**2, 2)
)
grid_2d_slim_upscaled = np.zeros(shape=(grid_slim.shape[0] * upscale_factor**2, 2))

upscale_index = 0

Expand Down
6 changes: 3 additions & 3 deletions autoarray/structures/grids/uniform_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,10 +845,10 @@ def distances_to_coordinate_from(
coordinate
The (y,x) coordinate from which the distance of every grid (y,x) coordinate is computed.
"""
squared_distance = self.squared_distances_to_coordinate_from(coordinate=coordinate)
distances = np.sqrt(
squared_distance.array
squared_distance = self.squared_distances_to_coordinate_from(
coordinate=coordinate
)
distances = np.sqrt(squared_distance.array)
return Array2D(values=distances, mask=self.mask)

def grid_2d_radial_projected_shape_slim_from(
Expand Down
Loading

0 comments on commit 280d574

Please sign in to comment.