Skip to content

Commit

Permalink
Merge pull request #132 from Jammy2211/feature/jax_remove_preload
Browse files Browse the repository at this point in the history
Feature/jax remove preload
  • Loading branch information
Jammy2211 authored Oct 25, 2024
2 parents 2143e94 + 280d574 commit c597f33
Show file tree
Hide file tree
Showing 39 changed files with 127 additions and 1,656 deletions.
1 change: 0 additions & 1 deletion autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from . import fixtures
from . import mock as m
from .numba_util import profile_func
from .preloads import Preloads
from .dataset import preprocess
from .dataset.abstract.dataset import AbstractDataset
from .dataset.abstract.w_tilde import AbstractWTilde
Expand Down
3 changes: 0 additions & 3 deletions autoarray/config/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,3 @@ pixelization:
voronoi_nn_max_interpolation_neighbors: 300
structures:
native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti.
test:
preloads_check_threshold: 1.0 # If the figure of merit of a fit with and without preloads is greater than this threshold, the check preload test fails and an exception raised for a model-fit.

2 changes: 1 addition & 1 deletion autoarray/config/visualize/include.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ include_1d:
mask: false # Include a Mask ?
origin: false # Include the (x,) origin of the data's coordinate system ?
include_2d:
border: true # Include the border of the mask (all pixels on the outside of the mask) ?
border: false # Include the border of the mask (all pixels on the outside of the mask) ?
grid: false # Include the data's 2D grid of (y,x) coordinates ?
mapper_image_plane_mesh_grid: false # For an Inversion, include the pixel centres computed in the image-plane / data frame?
mapper_source_plane_data_grid: false # For an Inversion, include the centres of the image-plane grid mapped to the source-plane / frame in source-plane figures?
Expand Down
11 changes: 0 additions & 11 deletions autoarray/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,6 @@ class PlottingException(Exception):
pass


class PreloadsException(Exception):
"""
Raises exceptions associated with the `preloads.py` module and `Preloads` class.
For example if the preloaded quantities lead to a change in figure of merit of a fit compared to a fit without
preloading.
"""

pass


class ProfilingException(Exception):
"""
Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator).
Expand Down
10 changes: 6 additions & 4 deletions autoarray/geometry/geometry_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,12 @@ def transform_grid_2d_to_reference_frame(
theta_coordinate_to_profile = np.arctan2(
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
) - np.radians(angle)
return np.vstack([
radius * np.sin(theta_coordinate_to_profile),
radius * np.cos(theta_coordinate_to_profile)
]).T
return np.vstack(
[
radius * np.sin(theta_coordinate_to_profile),
radius * np.cos(theta_coordinate_to_profile),
]
).T


def transform_grid_2d_from_reference_frame(
Expand Down
31 changes: 6 additions & 25 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Optional["Preloads"] = None,
run_time_dict: Optional[Dict] = None,
):
"""
Expand Down Expand Up @@ -70,17 +69,10 @@ def __init__(
input dataset's data and whose values are solved for via the inversion.
settings
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
preloads
Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation,
for example certain matrices used by the linear algebra could be preloaded.
run_time_dict
A dictionary which contains timing of certain functions calls which is used for profiling.
"""

from autoarray.preloads import Preloads

preloads = preloads or Preloads()

# try:
# import numba
# except ModuleNotFoundError:
Expand All @@ -98,7 +90,6 @@ def __init__(

self.settings = settings

self.preloads = preloads
self.run_time_dict = run_time_dict

@property
Expand Down Expand Up @@ -322,10 +313,6 @@ def operated_mapping_matrix(self) -> np.ndarray:
If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
linear equations are solved simultaneously.
"""

if self.preloads.operated_mapping_matrix is not None:
return self.preloads.operated_mapping_matrix

return np.hstack(self.operated_mapping_matrix_list)

@cached_property
Expand Down Expand Up @@ -356,9 +343,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion
are regularized so high their value is forced to zero.
"""
if self.preloads.regularization_matrix is not None:
return self.preloads.regularization_matrix

return block_diag(
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
)
Expand Down Expand Up @@ -509,12 +493,12 @@ def reconstruction(self) -> np.ndarray:

solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])

solutions[
values_to_solve
] = inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
solutions[values_to_solve] = (
inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
)
)
return solutions
else:
Expand Down Expand Up @@ -735,9 +719,6 @@ def log_det_regularization_matrix_term(self) -> float:
if not self.has(cls=AbstractRegularization):
return 0.0

if self.preloads.log_det_regularization_matrix_term is not None:
return self.preloads.log_det_regularization_matrix_term

try:
lu = splu(csc_matrix(self.regularization_matrix_reduced))
diagL = lu.L.diagonal()
Expand Down
31 changes: 2 additions & 29 deletions autoarray/inversion/inversion/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.preloads import Preloads


def inversion_from(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = Preloads(),
run_time_dict: Optional[Dict] = None,
):
"""
Expand All @@ -51,9 +49,6 @@ def inversion_from(
input dataset's data and whose values are solved for via the inversion.
settings
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
preloads
Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation,
for example certain matrices used by the linear algebra could be preloaded.
run_time_dict
A dictionary which contains timing of certain functions calls which is used for profiling.
Expand All @@ -66,7 +61,6 @@ def inversion_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

Expand All @@ -82,7 +76,6 @@ def inversion_imaging_from(
dataset,
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = Preloads(),
run_time_dict: Optional[Dict] = None,
):
"""
Expand Down Expand Up @@ -112,9 +105,6 @@ def inversion_imaging_from(
input dataset's data and whose values are solved for via the inversion.
settings
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
preloads
Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation,
for example certain matrices used by the linear algebra could be preloaded.
run_time_dict
A dictionary which contains timing of certain functions calls which is used for profiling.
Expand All @@ -127,34 +117,27 @@ def inversion_imaging_from(
for linear_obj in linear_obj_list
):
use_w_tilde = False
elif preloads.use_w_tilde is not None:
use_w_tilde = preloads.use_w_tilde
else:
use_w_tilde = settings.use_w_tilde

if not settings.use_w_tilde:
use_w_tilde = False

if use_w_tilde:
if preloads.w_tilde is not None:
w_tilde = preloads.w_tilde
else:
w_tilde = dataset.w_tilde
w_tilde = dataset.w_tilde

return InversionImagingWTilde(
dataset=dataset,
w_tilde=w_tilde,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

return InversionImagingMapping(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

Expand All @@ -163,7 +146,6 @@ def inversion_interferometer_from(
dataset: Union[Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = Preloads(),
run_time_dict: Optional[Dict] = None,
):
"""
Expand Down Expand Up @@ -197,9 +179,6 @@ def inversion_interferometer_from(
input dataset's data and whose values are solved for via the inversion.
settings
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
preloads
Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation,
for example certain matrices used by the linear algebra could be preloaded.
run_time_dict
A dictionary which contains timing of certain functions calls which is used for profiling.
Expand All @@ -222,17 +201,13 @@ def inversion_interferometer_from(

if not settings.use_linear_operators:
if use_w_tilde:
if preloads.w_tilde is not None:
w_tilde = preloads.w_tilde
else:
w_tilde = dataset.w_tilde
w_tilde = dataset.w_tilde

return InversionInterferometerWTilde(
dataset=dataset,
w_tilde=w_tilde,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

Expand All @@ -241,7 +216,6 @@ def inversion_interferometer_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

Expand All @@ -250,6 +224,5 @@ def inversion_interferometer_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)
44 changes: 9 additions & 35 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads=None,
run_time_dict: Optional[Dict] = None,
):
"""
Expand Down Expand Up @@ -65,22 +64,14 @@ def __init__(
input dataset's data and whose values are solved for via the inversion.
settings
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
preloads
Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation,
for example certain matrices used by the linear algebra could be preloaded.
run_time_dict
A dictionary which contains timing of certain functions calls which is used for profiling.
"""

from autoarray.preloads import Preloads

preloads = preloads or Preloads()

super().__init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
run_time_dict=run_time_dict,
)

Expand All @@ -104,11 +95,13 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
"""

return [
self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
(
self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
for linear_obj in self.linear_obj_list
]

Expand Down Expand Up @@ -140,12 +133,6 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
A dictionary mapping every linear function object to its operated mapping matrix.
"""

if self.preloads.linear_func_operated_mapping_matrix_dict is not None:
return self._updated_cls_key_dict_from(
cls=AbstractLinearObjFuncList,
preload_dict=self.preloads.linear_func_operated_mapping_matrix_dict,
)

linear_func_operated_mapping_matrix_dict = {}

for linear_func in self.cls_list_from(cls=AbstractLinearObjFuncList):
Expand All @@ -156,9 +143,9 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
mapping_matrix=linear_func.mapping_matrix
)

linear_func_operated_mapping_matrix_dict[
linear_func
] = operated_mapping_matrix
linear_func_operated_mapping_matrix_dict[linear_func] = (
operated_mapping_matrix
)

return linear_func_operated_mapping_matrix_dict

Expand Down Expand Up @@ -192,12 +179,6 @@ def data_linear_func_matrix_dict(self):
A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum
of the values of a linear object function convolved with the PSF kernel at the data pixel.
"""
if self.preloads.data_linear_func_matrix_dict is not None:
return self._updated_cls_key_dict_from(
cls=AbstractLinearObjFuncList,
preload_dict=self.preloads.data_linear_func_matrix_dict,
)

linear_func_list = self.cls_list_from(cls=AbstractLinearObjFuncList)

data_linear_func_matrix_dict = {}
Expand Down Expand Up @@ -237,13 +218,6 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict:
-------
A dictionary mapping every mapper object to its operated mapping matrix.
"""

if self.preloads.mapper_operated_mapping_matrix_dict is not None:
return self._updated_cls_key_dict_from(
cls=AbstractMapper,
preload_dict=self.preloads.mapper_operated_mapping_matrix_dict,
)

mapper_operated_mapping_matrix_dict = {}

for mapper in self.cls_list_from(cls=AbstractMapper):
Expand Down
Loading

0 comments on commit c597f33

Please sign in to comment.