From c3ae228ea63a4f193fc3c0a87d46b34012309210 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Tue, 30 Jul 2024 15:54:00 +0100 Subject: [PATCH 1/3] Allows mass profile to be given directly as instances (still allow string names for backward compat.) --- herculens/MassModel/mass_model_base.py | 126 ++++++++++++------------- herculens/MassModel/profile_mapping.py | 37 ++++++++ 2 files changed, 100 insertions(+), 63 deletions(-) create mode 100644 herculens/MassModel/profile_mapping.py diff --git a/herculens/MassModel/mass_model_base.py b/herculens/MassModel/mass_model_base.py index ee94087..40b9efd 100644 --- a/herculens/MassModel/mass_model_base.py +++ b/herculens/MassModel/mass_model_base.py @@ -7,25 +7,20 @@ __author__ = 'sibirrer', 'austinpeel', 'aymgal' -from herculens.MassModel.Profiles import (gaussian_potential, point_mass, multipole, - shear, sie, sis, nie, epl, pixelated) +import herculens.MassModel.profile_mapping as pm from herculens.Util import util -__all__ = ['MassModelBase'] -SUPPORTED_MODELS = [ - 'EPL', 'NIE', 'SIE', 'SIS', 'GAUSSIAN', 'POINT_MASS', - 'SHEAR', 'SHEAR_GAMMA_PSI', 'MULTIPOLE', - 'PIXELATED', 'PIXELATED_DIRAC', 'PIXELATED_FIXED', -] +__all__ = ['MassModelBase'] -# TODO: create parent for methods shared between MassProfileBase and LightProfileBase +SUPPORTED_MODELS = pm.SUPPORTED_MODELS +STRING_MAPPING = pm.STRING_MAPPING class MassModelBase(object): """Base class for managing lens models in single- or multi-plane lensing.""" - def __init__(self, lens_model_list, + def __init__(self, profile_list, kwargs_pixelated=None, no_complex_numbers=True, pixel_interpol='fast_bilinear', @@ -33,85 +28,90 @@ def __init__(self, lens_model_list, kwargs_pixel_grid_fixed=None): """Create a MassProfileBase object. + NOTE: extra keyword arguments are given to the corresponding profile class + only when that profile is given as a string instead of a class instance. + Parameters ---------- - lens_model_list : list of str + profile_list : list of str or profile class instance Lens model profile types. """ self.func_list, self._pix_idx = self._load_model_instances( - lens_model_list, pixel_derivative_type, pixel_interpol, + profile_list, pixel_derivative_type, pixel_interpol, no_complex_numbers, kwargs_pixel_grid_fixed ) self._num_func = len(self.func_list) - self._model_list = lens_model_list + self._model_list = profile_list if kwargs_pixelated is None: kwargs_pixelated = {} self._kwargs_pixelated = kwargs_pixelated - def _load_model_instances(self, - lens_model_list, pixel_derivative_type, pixel_interpol, + def _load_model_instances( + self, profile_list, pixel_derivative_type, pixel_interpol, no_complex_numbers, kwargs_pixel_grid_fixed, ): func_list = [] - imported_classes = {} pix_idx = None - for idx, lens_type in enumerate(lens_model_list): - # These models require a new instance per profile as certain pre-computations - # are relevant per individual profile - if lens_type in ['PIXELATED', 'PIXELATED_DIRAC']: - mass_model_class = self._import_class( - lens_type, pixel_derivative_type=pixel_derivative_type, pixel_interpol=pixel_interpol + for idx, profile_type in enumerate(profile_list): + # NOTE: Passing string is supported for backward-compatibility only + if isinstance(profile_type, str): + # These models require a new instance per profile as certain pre-computations + # are relevant per individual profile + profile_class = self.get_class_from_string( + profile_type, + kwargs_pixel_grid_fixed=kwargs_pixel_grid_fixed, + pixel_derivative_type=pixel_derivative_type, + pixel_interpol=pixel_interpol, + no_complex_numbers=no_complex_numbers, ) - pix_idx = idx + if profile_type in ['PIXELATED', 'PIXELATED_DIRAC']: + pix_idx = idx + + # NOTE: this is the new preferred way: passing the profile as a class + elif self.is_mass_profile_class(profile_type): + profile_class = profile_type + if isinstance( + profile_class, + (STRING_MAPPING['PIXELATED'], STRING_MAPPING['PIXELATED_DIRAC']) + ): + pix_idx = idx else: - if lens_type not in imported_classes.keys(): - mass_model_class = self._import_class( - lens_type, no_complex_numbers=no_complex_numbers, - kwargs_pixel_grid_fixed=kwargs_pixel_grid_fixed, - ) - imported_classes.update({lens_type: mass_model_class}) - else: - mass_model_class = imported_classes[lens_type] - func_list.append(mass_model_class) + raise ValueError("Each profile can either be a string or " + "directly the profile class (not instantiated).") + func_list.append(profile_class) return func_list, pix_idx + + @staticmethod + def is_mass_profile_class(profile): + """Simply checks that the mass profile has the required methods""" + return ( + hasattr(profile, 'function') and + hasattr(profile, 'derivatives') and + hasattr(profile, 'hessian') + ) @staticmethod - def _import_class( - lens_type, pixel_derivative_type=None, pixel_interpol=None, - no_complex_numbers=None, kwargs_pixel_grid_fixed=None + def get_class_from_string( + profile_string, pixel_derivative_type=None, pixel_interpol=None, + no_complex_numbers=None, kwargs_pixel_grid_fixed=None, ): """Get the lens profile class of the corresponding type.""" - if lens_type == 'GAUSSIAN': - return gaussian_potential.Gaussian() - elif lens_type == 'SHEAR': - return shear.Shear() - elif lens_type == 'SHEAR_GAMMA_PSI': - return shear.ShearGammaPsi() - elif lens_type == 'POINT_MASS': - return point_mass.PointMass() - elif lens_type == 'NIE': - return nie.NIE() - elif lens_type == 'SIE': - return sie.SIE() - elif lens_type == 'SIS': - return sis.SIS() - elif lens_type == 'EPL': - return epl.EPL(no_complex_numbers=no_complex_numbers) - elif lens_type == 'MULTIPOLE': - return multipole.Multipole() - elif lens_type == 'PIXELATED': - return pixelated.PixelatedPotential(derivative_type=pixel_derivative_type, interpolation_type=pixel_interpol) - elif lens_type == 'PIXELATED_DIRAC': - return pixelated.PixelatedPotentialDirac() - elif lens_type == 'PIXELATED_FIXED': + if profile_string not in list(STRING_MAPPING.keys()): + raise ValueError(f"{profile_string} is not a valid lens model. " + f"Supported types are {SUPPORTED_MODELS}") + profile_class = STRING_MAPPING[profile_string] + # treats the few special cases that require user settings + if profile_string == 'EPL': + return profile_class(no_complex_numbers=no_complex_numbers) + elif profile_string == 'PIXELATED': + return profile_class(derivative_type=pixel_derivative_type, interpolation_type=pixel_interpol) + elif profile_string == 'PIXELATED_FIXED': if kwargs_pixel_grid_fixed is None: raise ValueError("At least one pixel grid must be provided to use 'PIXELATED_FIXED' profile") - return pixelated.PixelatedFixed(**kwargs_pixel_grid_fixed) - else: - err_msg = (f"{lens_type} is not a valid lens model. " + - f"Supported types are {SUPPORTED_MODELS}") - raise ValueError(err_msg) + return profile_class(**kwargs_pixel_grid_fixed) + # all remaining profile takes no extra arguments + return profile_class() def _bool_list(self, k): return util.convert_bool_list(n=self._num_func, k=k) diff --git a/herculens/MassModel/profile_mapping.py b/herculens/MassModel/profile_mapping.py new file mode 100644 index 0000000..4a5eabd --- /dev/null +++ b/herculens/MassModel/profile_mapping.py @@ -0,0 +1,37 @@ +# NOTE: this file is useful for backward-compatibility only, +# as the preferred way is now to pass the profile class directly +# to the MassModel() constructor. + +from herculens.MassModel.Profiles import ( + gaussian_potential, + point_mass, + multipole, + shear, + sie, + sis, + nie, + epl, + pixelated +) + +SUPPORTED_MODELS = [ + 'EPL', 'NIE', 'SIE', 'SIS', 'GAUSSIAN', 'POINT_MASS', + 'SHEAR', 'SHEAR_GAMMA_PSI', 'MULTIPOLE', + 'PIXELATED', 'PIXELATED_DIRAC', 'PIXELATED_FIXED', +] + +# mapping between the string name to the mass profile class. +STRING_MAPPING = { + 'EPL': epl.EPL, + 'NIE': nie.NIE, + 'SIE': sie.SIE, + 'SIS': sis.SIS, + 'GAUSSIAN': gaussian_potential.Gaussian, + 'POINT_MASS': point_mass.PointMass, + 'SHEAR': shear.Shear, + 'SHEAR_GAMMA_PSI': shear.ShearGammaPsi, + 'MULTIPOLE': multipole.Multipole, + 'PIXELATED': pixelated.PixelatedPotential, + 'PIXELATED_DIRAC': pixelated.PixelatedPotentialDirac, + 'PIXELATED_FIXED': pixelated.PixelatedFixed, +} From cb508e68e42976e891ff875b6e4bd926df32b4d4 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Wed, 31 Jul 2024 00:32:24 +0100 Subject: [PATCH 2/3] Improve base mass/light classes behavior to handle string and instance inputs --- herculens/LightModel/Profiles/gaussian.py | 2 +- herculens/LightModel/light_model.py | 28 ++-- herculens/LightModel/light_model_base.py | 160 ++++++++++++++-------- herculens/LightModel/profile_mapping.py | 26 ++++ herculens/MassModel/mass_model.py | 17 +-- herculens/MassModel/mass_model_base.py | 67 +++++---- herculens/MassModel/profile_mapping.py | 54 ++++---- 7 files changed, 218 insertions(+), 136 deletions(-) create mode 100644 herculens/LightModel/profile_mapping.py diff --git a/herculens/LightModel/Profiles/gaussian.py b/herculens/LightModel/Profiles/gaussian.py index 059aa7a..40b4ad6 100644 --- a/herculens/LightModel/Profiles/gaussian.py +++ b/herculens/LightModel/Profiles/gaussian.py @@ -13,7 +13,7 @@ from herculens.Util import param_util -__all__ = ['Gaussian'] +__all__ = ['Gaussian', 'GaussianEllipse'] class Gaussian(object): diff --git a/herculens/LightModel/light_model.py b/herculens/LightModel/light_model.py index ebbf4bf..e3b1cdf 100644 --- a/herculens/LightModel/light_model.py +++ b/herculens/LightModel/light_model.py @@ -7,13 +7,9 @@ __author__ = 'sibirrer', 'austinpeel', 'aymgal' -import numpy as np import jax.numpy as jnp -# from functools import partial -# from jax import jit from herculens.LightModel.light_model_base import LightModelBase -from herculens.Util import util __all__ = ['LightModel'] @@ -31,16 +27,20 @@ class LightModel(LightModelBase): for a given set of parameters. """ - def __init__(self, light_model_list, smoothing=0.001, - shapelets_n_max=4, superellipse_exponent=2, - kwargs_pixelated=None, **kwargs): - """Create a LightModel object.""" - self.profile_type_list = light_model_list - super(LightModel, self).__init__(self.profile_type_list, smoothing=smoothing, - shapelets_n_max=shapelets_n_max, - superellipse_exponent=superellipse_exponent, - kwargs_pixelated=kwargs_pixelated, - **kwargs) + def __init__(self, profile_list, **kwargs): + """Create a LightModel object. + + Parameters + ---------- + profile_list : list of strings or profile instances + List of light profiles. + kwargs_pixelated : dictionary for settings related to PIXELATED profiles. + """ + if not isinstance(profile_list, (list, tuple)): + # useful when using a single profile + profile_list = [profile_list] + self.profile_type_list = profile_list + super(LightModel, self).__init__(self.profile_type_list, **kwargs) def surface_brightness(self, x, y, kwargs_list, k=None, pixels_x_coord=None, pixels_y_coord=None): diff --git a/herculens/LightModel/light_model_base.py b/herculens/LightModel/light_model_base.py index 2da9e73..281fb72 100644 --- a/herculens/LightModel/light_model_base.py +++ b/herculens/LightModel/light_model_base.py @@ -10,85 +10,129 @@ import numpy as np import jax.numpy as jnp -from herculens.LightModel.Profiles import (sersic, pixelated, uniform, gaussian, multipole) +import herculens.LightModel.profile_mapping as pm from herculens.Util import util __all__ = ['LightModelBase'] -SUPPORTED_MODELS = [ - 'GAUSSIAN', 'GAUSSIAN_ELLIPSE', - 'SERSIC', 'SERSIC_ELLIPSE', 'SERSIC_SUPERELLIPSE', - 'SHAPELETS', 'UNIFORM', 'PIXELATED', - 'MULTIPOLE', -] +SUPPORTED_MODELS = pm.SUPPORTED_MODELS +STRING_MAPPING = pm.STRING_MAPPING class LightModelBase(object): """Base class for source and lens light models.""" # TODO: instead of settings for creating PixelGrid objects, pass directly the object to the LightModel - def __init__(self, light_model_list, smoothing=0.001, - shapelets_n_max=4, superellipse_exponent=2, - pixel_interpol='bilinear', pixel_adaptive_grid=False, - pixel_allow_extrapolation=False, kwargs_pixelated=None): + def __init__(self, profile_list, kwargs_pixelated=None, + **profile_specific_kwargs): """Create a LightModelBase object. + NOTE: the extra keyword arguments are given to the corresponding profile class + only when that profile is given as a string instead of a class instance. + Parameters ---------- - light_model_list : list of str - Light model types. - smoothing : float - Smoothing factor for some models (deprecated). - pixel_interpol : string - Type of interpolation for 'PIXELATED' profiles: 'bilinear' or 'bicubic' - pixel_allow_extrapolation : bool - For 'PIXELATED' profiles, wether or not to extrapolate flux values outside the chosen region - otherwise force values to be zero. + profile_list : list of strings or profile instances + List of mass profiles. If not a list, wrap the passed argument in a list. kwargs_pixelated : dict - Settings related to the creation of the pixelated grid. See herculens.PixelGrid.create_model_grid for details + Settings related to the creation of the pixelated grid. + See herculens.PixelGrid.create_model_grid for details. + profile_specific_kwargs : dict + See docstring for get_class_from_string(). """ - func_list = [] - pix_idx = None - for idx, profile_type in enumerate(light_model_list): - if profile_type == 'GAUSSIAN': - func_list.append(gaussian.Gaussian()) - elif profile_type == 'GAUSSIAN_ELLIPSE': - func_list.append(gaussian.GaussianEllipse()) - elif profile_type == 'SERSIC': - func_list.append(sersic.Sersic(smoothing)) - elif profile_type == 'SERSIC_ELLIPSE': - func_list.append(sersic.SersicElliptic(smoothing, exponent=2)) - elif profile_type == 'SERSIC_SUPERELLIPSE': - func_list.append(sersic.SersicElliptic(smoothing, exponent=superellipse_exponent)) - elif profile_type == 'UNIFORM': - func_list.append(uniform.Uniform()) - elif profile_type == 'MULTIPOLE': - func_list.append(multipole.Multipole()) - elif profile_type == 'PIXELATED': - if pix_idx is not None: - raise ValueError("Multiple pixelated profiles is currently not supported.") - func_list.append( - pixelated.Pixelated( - interpolation_type=pixel_interpol, - allow_extrapolation=pixel_allow_extrapolation, - adaptive_grid=pixel_adaptive_grid, - ) - ) - pix_idx = idx - elif profile_type == 'SHAPELETS': - from herculens.LightModel.Profiles import shapelets # prevent importing GigaLens if not used - func_list.append(shapelets.Shapelets(shapelets_n_max)) - else: - err_msg = (f"No light model of type {profile_type} found. " + - f"Supported types are: {SUPPORTED_MODELS}") - raise ValueError(err_msg) - self.func_list = func_list + self.func_list, self._pix_idx = self._load_model_instances( + profile_list, **profile_specific_kwargs + ) self._num_func = len(self.func_list) - self._pix_idx = pix_idx + self._model_list = profile_list if kwargs_pixelated is None: kwargs_pixelated = {} self._kwargs_pixelated = kwargs_pixelated + + def _load_model_instances( + self, profile_list, **profile_specific_kwargs, + ): + func_list = [] + pix_idx = None + for idx, profile_type in enumerate(profile_list): + if isinstance(profile_type, str): + # passing string is supported for backward-compatibility only + profile_class = self.get_class_from_string( + profile_type, + **profile_specific_kwargs, + ) + if profile_type in ['PIXELATED']: + pix_idx = idx + + # this is the new preferred way: passing the profile as a class + elif self.is_light_profile_class(profile_type): + profile_class = profile_type + if isinstance(profile_class, STRING_MAPPING['PIXELATED']): + pix_idx = idx + else: + raise ValueError("Each profile can either be a string or " + "directly the profile instance.") + func_list.append(profile_class) + return func_list, pix_idx + + @staticmethod + def is_light_profile_class(profile): + """Simply checks that the mass profile has the required methods""" + return hasattr(profile, 'function') + + @staticmethod + def get_class_from_string( + profile_string, + smoothing=0.001, + shapelets_n_max=4, + superellipse_exponent=2, + pixel_interpol='bilinear', + pixel_adaptive_grid=False, + pixel_allow_extrapolation=False, + ): + """ + Get the profile class of the corresponding type. + Keyword arguments are related to specific profile types. + + Parameters + ---------- + smoothing : float + Smoothing factor for some models. + shapelets_n_max : int + Maximal order of the shapelets basis set. + superellipse_exponent : int, float + Exponent for super-elliptical profiles (e.g. 'SERSIC_SUPERELLIPSE'). + pixel_interpol : string + Type of interpolation for 'PIXELATED' profiles: 'bilinear' or 'bicubic' + pixel_adaptive_grid : bool + Whether or not the pixelated light profile is defined on a grid + whose extent is adapted based on other model components. + pixel_allow_extrapolation : bool + Wether or not to allow the interpolator to predict values outside + the field of view of the pixelated profile + """ + print("SUPPORTED_MODELS", SUPPORTED_MODELS) + raise + if profile_string in SUPPORTED_MODELS: + profile_class = STRING_MAPPING[profile_string] + # treats the few special cases that require user settings + if profile_string == 'SERSIC': + return profile_class(smoothing=smoothing) + elif profile_string == 'SERSIC_ELLIPSE': + return profile_class(smoothing=smoothing, exponent=2) + elif profile_string == 'SERSIC_SUPERELLIPSE': + return profile_class(smoothing=smoothing, exponent=superellipse_exponent) + elif profile_string == 'SHAPELETS': + return profile_class(shapelets_n_max) + elif profile_string == 'PIXELATED': + return profile_class(interpolation_type=pixel_interpol, + allow_extrapolation=pixel_allow_extrapolation, + adaptive_grid=pixel_adaptive_grid) + else: + raise ValueError(f"Could not load profile type '{profile_string}'.") + # all remaining profiles take no extra arguments + return profile_class() @property def param_name_list(self): diff --git a/herculens/LightModel/profile_mapping.py b/herculens/LightModel/profile_mapping.py new file mode 100644 index 0000000..5b28572 --- /dev/null +++ b/herculens/LightModel/profile_mapping.py @@ -0,0 +1,26 @@ +# NOTE: this file is useful for backward-compatibility only, +# as the preferred way is now to pass the profile class directly +# to the LightModel() constructor. + +from herculens.LightModel.Profiles.sersic import (Sersic, SersicElliptic) +from herculens.LightModel.Profiles.multipole import Multipole +from herculens.LightModel.Profiles.gaussian import (Gaussian, GaussianEllipse) +from herculens.LightModel.Profiles.pixelated import Pixelated +from herculens.LightModel.Profiles.uniform import Uniform +from herculens.LightModel.Profiles.shapelets import Shapelets + + +# mapping between the string name to the mass profile class. +STRING_MAPPING = { + 'SERSIC': Sersic, + 'SERSIC_ELLIPSE': SersicElliptic, + 'GAUSSIAN': Gaussian, + 'GAUSSIAN_ELLIPSE': GaussianEllipse, + 'MULTIPOLE': Multipole, + 'PIXELATED': Pixelated, + 'UNIFORM': Uniform, + 'SHAPELETS': Shapelets +} + +SUPPORTED_MODELS = list(STRING_MAPPING.keys()) +print("INSIDE", SUPPORTED_MODELS) \ No newline at end of file diff --git a/herculens/MassModel/mass_model.py b/herculens/MassModel/mass_model.py index 2eb079e..9fb41a3 100644 --- a/herculens/MassModel/mass_model.py +++ b/herculens/MassModel/mass_model.py @@ -18,18 +18,19 @@ class MassModel(MassModelBase): """An arbitrary list of lens models.""" - def __init__(self, mass_model_list, kwargs_pixelated=None, **kwargs): + def __init__(self, profile_list, **kwargs): """Create a MassModel object. Parameters ---------- - mass_model_list : list of str - Lens model profile names. - kwargs_pixelated : dictionary for settings related to PIXELATED profiles. - """ - self.profile_type_list = mass_model_list - super().__init__(self.profile_type_list, kwargs_pixelated=kwargs_pixelated, - **kwargs) + profile_list : list of strings or profile instances + List of mass profiles. + """ + if not isinstance(profile_list, (list, tuple)): + # useful when using a single profile + profile_list = [profile_list] + self.profile_type_list = profile_list + super().__init__(self.profile_type_list, **kwargs) def ray_shooting(self, x, y, kwargs, k=None): """ diff --git a/herculens/MassModel/mass_model_base.py b/herculens/MassModel/mass_model_base.py index 40b9efd..8fc165e 100644 --- a/herculens/MassModel/mass_model_base.py +++ b/herculens/MassModel/mass_model_base.py @@ -20,26 +20,25 @@ class MassModelBase(object): """Base class for managing lens models in single- or multi-plane lensing.""" - def __init__(self, profile_list, - kwargs_pixelated=None, - no_complex_numbers=True, - pixel_interpol='fast_bilinear', - pixel_derivative_type='interpol', - kwargs_pixel_grid_fixed=None): + def __init__(self, profile_list, kwargs_pixelated=None, + **profile_specific_kwargs): """Create a MassProfileBase object. - NOTE: extra keyword arguments are given to the corresponding profile class + NOTE: the extra keyword arguments are given to the corresponding profile class only when that profile is given as a string instead of a class instance. Parameters ---------- - profile_list : list of str or profile class instance - Lens model profile types. - + profile_list : list of strings or profile instances + List of mass profiles. If not a list, wrap the passed argument in a list. + kwargs_pixelated : dict + Settings related to the creation of the pixelated grid. + See herculens.PixelGrid.create_model_grid for details. + profile_specific_kwargs : dict + See docstring for get_class_from_string(). """ self.func_list, self._pix_idx = self._load_model_instances( - profile_list, pixel_derivative_type, pixel_interpol, - no_complex_numbers, kwargs_pixel_grid_fixed + profile_list, **profile_specific_kwargs, ) self._num_func = len(self.func_list) self._model_list = profile_list @@ -48,27 +47,21 @@ def __init__(self, profile_list, self._kwargs_pixelated = kwargs_pixelated def _load_model_instances( - self, profile_list, pixel_derivative_type, pixel_interpol, - no_complex_numbers, kwargs_pixel_grid_fixed, + self, profile_list, **profile_specific_kwargs, ): func_list = [] pix_idx = None for idx, profile_type in enumerate(profile_list): - # NOTE: Passing string is supported for backward-compatibility only if isinstance(profile_type, str): - # These models require a new instance per profile as certain pre-computations - # are relevant per individual profile + # passing string is supported for backward-compatibility only profile_class = self.get_class_from_string( profile_type, - kwargs_pixel_grid_fixed=kwargs_pixel_grid_fixed, - pixel_derivative_type=pixel_derivative_type, - pixel_interpol=pixel_interpol, - no_complex_numbers=no_complex_numbers, + **profile_specific_kwargs, ) if profile_type in ['PIXELATED', 'PIXELATED_DIRAC']: pix_idx = idx - # NOTE: this is the new preferred way: passing the profile as a class + # this is the new preferred way: passing the profile as a class elif self.is_mass_profile_class(profile_type): profile_class = profile_type if isinstance( @@ -78,7 +71,7 @@ def _load_model_instances( pix_idx = idx else: raise ValueError("Each profile can either be a string or " - "directly the profile class (not instantiated).") + "directly the profile class.") func_list.append(profile_class) return func_list, pix_idx @@ -93,10 +86,30 @@ def is_mass_profile_class(profile): @staticmethod def get_class_from_string( - profile_string, pixel_derivative_type=None, pixel_interpol=None, - no_complex_numbers=None, kwargs_pixel_grid_fixed=None, + profile_string, + pixel_derivative_type=None, + pixel_interpol=None, + no_complex_numbers=None, + kwargs_pixel_grid_fixed=None, ): - """Get the lens profile class of the corresponding type.""" + """ + Get the lens profile class of the corresponding type. + Keyword arguments are related to specific profile types. + + Parameters + ---------- + smoothing : float + Smoothing factor for some models (deprecated). + pixel_interpol : string + Type of interpolation for 'PIXELATED' profiles: 'fast_bilinear' or 'bicubic' + pixel_derivative_type : str + Type of interpolation: 'interpol' or 'autodiff' + no_complex_numbers : bool + Use or not complex number in the EPL's deflection computation. + kwargs_pixel_grid_fixed : dict + Settings related to the creation of the pixelated grid for profile type 'PIXELATED_FIXED'. + See herculens.PixelGrid.create_model_grid for details. + """ if profile_string not in list(STRING_MAPPING.keys()): raise ValueError(f"{profile_string} is not a valid lens model. " f"Supported types are {SUPPORTED_MODELS}") @@ -110,7 +123,7 @@ def get_class_from_string( if kwargs_pixel_grid_fixed is None: raise ValueError("At least one pixel grid must be provided to use 'PIXELATED_FIXED' profile") return profile_class(**kwargs_pixel_grid_fixed) - # all remaining profile takes no extra arguments + # all remaining profiles take no extra arguments return profile_class() def _bool_list(self, k): diff --git a/herculens/MassModel/profile_mapping.py b/herculens/MassModel/profile_mapping.py index 4a5eabd..b815077 100644 --- a/herculens/MassModel/profile_mapping.py +++ b/herculens/MassModel/profile_mapping.py @@ -2,36 +2,34 @@ # as the preferred way is now to pass the profile class directly # to the MassModel() constructor. -from herculens.MassModel.Profiles import ( - gaussian_potential, - point_mass, - multipole, - shear, - sie, - sis, - nie, - epl, - pixelated +from herculens.MassModel.Profiles.gaussian_potential import Gaussian +from herculens.MassModel.Profiles.point_mass import PointMass +from herculens.MassModel.Profiles.multipole import Multipole +from herculens.MassModel.Profiles.shear import Shear, ShearGammaPsi +from herculens.MassModel.Profiles.sis import SIS +from herculens.MassModel.Profiles.sie import SIE +from herculens.MassModel.Profiles.nie import NIE +from herculens.MassModel.Profiles.epl import EPL +from herculens.MassModel.Profiles.pixelated import ( + PixelatedPotential, + PixelatedFixed, + PixelatedPotentialDirac, ) -SUPPORTED_MODELS = [ - 'EPL', 'NIE', 'SIE', 'SIS', 'GAUSSIAN', 'POINT_MASS', - 'SHEAR', 'SHEAR_GAMMA_PSI', 'MULTIPOLE', - 'PIXELATED', 'PIXELATED_DIRAC', 'PIXELATED_FIXED', -] - # mapping between the string name to the mass profile class. STRING_MAPPING = { - 'EPL': epl.EPL, - 'NIE': nie.NIE, - 'SIE': sie.SIE, - 'SIS': sis.SIS, - 'GAUSSIAN': gaussian_potential.Gaussian, - 'POINT_MASS': point_mass.PointMass, - 'SHEAR': shear.Shear, - 'SHEAR_GAMMA_PSI': shear.ShearGammaPsi, - 'MULTIPOLE': multipole.Multipole, - 'PIXELATED': pixelated.PixelatedPotential, - 'PIXELATED_DIRAC': pixelated.PixelatedPotentialDirac, - 'PIXELATED_FIXED': pixelated.PixelatedFixed, + 'EPL': EPL, + 'NIE': NIE, + 'SIE': SIE, + 'SIS': SIS, + 'GAUSSIAN': Gaussian, + 'POINT_MASS': PointMass, + 'SHEAR': Shear, + 'SHEAR_GAMMA_PSI': ShearGammaPsi, + 'MULTIPOLE': Multipole, + 'PIXELATED': PixelatedPotential, + 'PIXELATED_DIRAC': PixelatedPotentialDirac, + 'PIXELATED_FIXED': PixelatedFixed, } + +SUPPORTED_MODELS = list(STRING_MAPPING.keys()) From 8f853abdeac9a8e65e506307296b1c7f4b010ba0 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Wed, 31 Jul 2024 00:37:34 +0100 Subject: [PATCH 3/3] Fix bugs --- herculens/LightModel/light_model_base.py | 2 -- herculens/MassModel/mass_model_base.py | 26 ++++++++++++------------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/herculens/LightModel/light_model_base.py b/herculens/LightModel/light_model_base.py index 281fb72..89f233d 100644 --- a/herculens/LightModel/light_model_base.py +++ b/herculens/LightModel/light_model_base.py @@ -112,8 +112,6 @@ def get_class_from_string( Wether or not to allow the interpolator to predict values outside the field of view of the pixelated profile """ - print("SUPPORTED_MODELS", SUPPORTED_MODELS) - raise if profile_string in SUPPORTED_MODELS: profile_class = STRING_MAPPING[profile_string] # treats the few special cases that require user settings diff --git a/herculens/MassModel/mass_model_base.py b/herculens/MassModel/mass_model_base.py index 8fc165e..29801af 100644 --- a/herculens/MassModel/mass_model_base.py +++ b/herculens/MassModel/mass_model_base.py @@ -110,19 +110,19 @@ def get_class_from_string( Settings related to the creation of the pixelated grid for profile type 'PIXELATED_FIXED'. See herculens.PixelGrid.create_model_grid for details. """ - if profile_string not in list(STRING_MAPPING.keys()): - raise ValueError(f"{profile_string} is not a valid lens model. " - f"Supported types are {SUPPORTED_MODELS}") - profile_class = STRING_MAPPING[profile_string] - # treats the few special cases that require user settings - if profile_string == 'EPL': - return profile_class(no_complex_numbers=no_complex_numbers) - elif profile_string == 'PIXELATED': - return profile_class(derivative_type=pixel_derivative_type, interpolation_type=pixel_interpol) - elif profile_string == 'PIXELATED_FIXED': - if kwargs_pixel_grid_fixed is None: - raise ValueError("At least one pixel grid must be provided to use 'PIXELATED_FIXED' profile") - return profile_class(**kwargs_pixel_grid_fixed) + if profile_string in SUPPORTED_MODELS: + profile_class = STRING_MAPPING[profile_string] + # treats the few special cases that require user settings + if profile_string == 'EPL': + return profile_class(no_complex_numbers=no_complex_numbers) + elif profile_string == 'PIXELATED': + return profile_class(derivative_type=pixel_derivative_type, interpolation_type=pixel_interpol) + elif profile_string == 'PIXELATED_FIXED': + if kwargs_pixel_grid_fixed is None: + raise ValueError("At least one pixel grid must be provided to use 'PIXELATED_FIXED' profile") + return profile_class(**kwargs_pixel_grid_fixed) + else: + raise ValueError(f"Could not load profile type '{profile_string}'.") # all remaining profiles take no extra arguments return profile_class()