Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passing profile instances directly to MassModel and LightModel #36

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion herculens/LightModel/Profiles/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from herculens.Util import param_util


__all__ = ['Gaussian']
__all__ = ['Gaussian', 'GaussianEllipse']


class Gaussian(object):
Expand Down
28 changes: 14 additions & 14 deletions herculens/LightModel/light_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@
__author__ = 'sibirrer', 'austinpeel', 'aymgal'


import numpy as np
import jax.numpy as jnp
# from functools import partial
# from jax import jit

from herculens.LightModel.light_model_base import LightModelBase
from herculens.Util import util


__all__ = ['LightModel']
Expand All @@ -31,16 +27,20 @@ class LightModel(LightModelBase):
for a given set of parameters.

"""
def __init__(self, light_model_list, smoothing=0.001,
shapelets_n_max=4, superellipse_exponent=2,
kwargs_pixelated=None, **kwargs):
"""Create a LightModel object."""
self.profile_type_list = light_model_list
super(LightModel, self).__init__(self.profile_type_list, smoothing=smoothing,
shapelets_n_max=shapelets_n_max,
superellipse_exponent=superellipse_exponent,
kwargs_pixelated=kwargs_pixelated,
**kwargs)
def __init__(self, profile_list, **kwargs):
"""Create a LightModel object.

Parameters
----------
profile_list : list of strings or profile instances
List of light profiles.
kwargs_pixelated : dictionary for settings related to PIXELATED profiles.
"""
if not isinstance(profile_list, (list, tuple)):
# useful when using a single profile
profile_list = [profile_list]
self.profile_type_list = profile_list
super(LightModel, self).__init__(self.profile_type_list, **kwargs)

def surface_brightness(self, x, y, kwargs_list, k=None,
pixels_x_coord=None, pixels_y_coord=None):
Expand Down
158 changes: 100 additions & 58 deletions herculens/LightModel/light_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,85 +10,127 @@
import numpy as np
import jax.numpy as jnp

from herculens.LightModel.Profiles import (sersic, pixelated, uniform, gaussian, multipole)
import herculens.LightModel.profile_mapping as pm
from herculens.Util import util

__all__ = ['LightModelBase']


SUPPORTED_MODELS = [
'GAUSSIAN', 'GAUSSIAN_ELLIPSE',
'SERSIC', 'SERSIC_ELLIPSE', 'SERSIC_SUPERELLIPSE',
'SHAPELETS', 'UNIFORM', 'PIXELATED',
'MULTIPOLE',
]
SUPPORTED_MODELS = pm.SUPPORTED_MODELS
STRING_MAPPING = pm.STRING_MAPPING


class LightModelBase(object):
"""Base class for source and lens light models."""
# TODO: instead of settings for creating PixelGrid objects, pass directly the object to the LightModel
def __init__(self, light_model_list, smoothing=0.001,
shapelets_n_max=4, superellipse_exponent=2,
pixel_interpol='bilinear', pixel_adaptive_grid=False,
pixel_allow_extrapolation=False, kwargs_pixelated=None):
def __init__(self, profile_list, kwargs_pixelated=None,
**profile_specific_kwargs):
"""Create a LightModelBase object.

NOTE: the extra keyword arguments are given to the corresponding profile class
only when that profile is given as a string instead of a class instance.

Parameters
----------
light_model_list : list of str
Light model types.
smoothing : float
Smoothing factor for some models (deprecated).
pixel_interpol : string
Type of interpolation for 'PIXELATED' profiles: 'bilinear' or 'bicubic'
pixel_allow_extrapolation : bool
For 'PIXELATED' profiles, wether or not to extrapolate flux values outside the chosen region
otherwise force values to be zero.
profile_list : list of strings or profile instances
List of mass profiles. If not a list, wrap the passed argument in a list.
kwargs_pixelated : dict
Settings related to the creation of the pixelated grid. See herculens.PixelGrid.create_model_grid for details
Settings related to the creation of the pixelated grid.
See herculens.PixelGrid.create_model_grid for details.
profile_specific_kwargs : dict
See docstring for get_class_from_string().

"""
func_list = []
pix_idx = None
for idx, profile_type in enumerate(light_model_list):
if profile_type == 'GAUSSIAN':
func_list.append(gaussian.Gaussian())
elif profile_type == 'GAUSSIAN_ELLIPSE':
func_list.append(gaussian.GaussianEllipse())
elif profile_type == 'SERSIC':
func_list.append(sersic.Sersic(smoothing))
elif profile_type == 'SERSIC_ELLIPSE':
func_list.append(sersic.SersicElliptic(smoothing, exponent=2))
elif profile_type == 'SERSIC_SUPERELLIPSE':
func_list.append(sersic.SersicElliptic(smoothing, exponent=superellipse_exponent))
elif profile_type == 'UNIFORM':
func_list.append(uniform.Uniform())
elif profile_type == 'MULTIPOLE':
func_list.append(multipole.Multipole())
elif profile_type == 'PIXELATED':
if pix_idx is not None:
raise ValueError("Multiple pixelated profiles is currently not supported.")
func_list.append(
pixelated.Pixelated(
interpolation_type=pixel_interpol,
allow_extrapolation=pixel_allow_extrapolation,
adaptive_grid=pixel_adaptive_grid,
)
)
pix_idx = idx
elif profile_type == 'SHAPELETS':
from herculens.LightModel.Profiles import shapelets # prevent importing GigaLens if not used
func_list.append(shapelets.Shapelets(shapelets_n_max))
else:
err_msg = (f"No light model of type {profile_type} found. " +
f"Supported types are: {SUPPORTED_MODELS}")
raise ValueError(err_msg)
self.func_list = func_list
self.func_list, self._pix_idx = self._load_model_instances(
profile_list, **profile_specific_kwargs
)
self._num_func = len(self.func_list)
self._pix_idx = pix_idx
self._model_list = profile_list
if kwargs_pixelated is None:
kwargs_pixelated = {}
self._kwargs_pixelated = kwargs_pixelated

def _load_model_instances(
self, profile_list, **profile_specific_kwargs,
):
func_list = []
pix_idx = None
for idx, profile_type in enumerate(profile_list):
if isinstance(profile_type, str):
# passing string is supported for backward-compatibility only
profile_class = self.get_class_from_string(
profile_type,
**profile_specific_kwargs,
)
if profile_type in ['PIXELATED']:
pix_idx = idx

# this is the new preferred way: passing the profile as a class
elif self.is_light_profile_class(profile_type):
profile_class = profile_type
if isinstance(profile_class, STRING_MAPPING['PIXELATED']):
pix_idx = idx
else:
raise ValueError("Each profile can either be a string or "
"directly the profile instance.")
func_list.append(profile_class)
return func_list, pix_idx

@staticmethod
def is_light_profile_class(profile):
"""Simply checks that the mass profile has the required methods"""
return hasattr(profile, 'function')

@staticmethod
def get_class_from_string(
profile_string,
smoothing=0.001,
shapelets_n_max=4,
superellipse_exponent=2,
pixel_interpol='bilinear',
pixel_adaptive_grid=False,
pixel_allow_extrapolation=False,
):
"""
Get the profile class of the corresponding type.
Keyword arguments are related to specific profile types.

Parameters
----------
smoothing : float
Smoothing factor for some models.
shapelets_n_max : int
Maximal order of the shapelets basis set.
superellipse_exponent : int, float
Exponent for super-elliptical profiles (e.g. 'SERSIC_SUPERELLIPSE').
pixel_interpol : string
Type of interpolation for 'PIXELATED' profiles: 'bilinear' or 'bicubic'
pixel_adaptive_grid : bool
Whether or not the pixelated light profile is defined on a grid
whose extent is adapted based on other model components.
pixel_allow_extrapolation : bool
Wether or not to allow the interpolator to predict values outside
the field of view of the pixelated profile
"""
if profile_string in SUPPORTED_MODELS:
profile_class = STRING_MAPPING[profile_string]
# treats the few special cases that require user settings
if profile_string == 'SERSIC':
return profile_class(smoothing=smoothing)
elif profile_string == 'SERSIC_ELLIPSE':
return profile_class(smoothing=smoothing, exponent=2)
elif profile_string == 'SERSIC_SUPERELLIPSE':
return profile_class(smoothing=smoothing, exponent=superellipse_exponent)
elif profile_string == 'SHAPELETS':
return profile_class(shapelets_n_max)
elif profile_string == 'PIXELATED':
return profile_class(interpolation_type=pixel_interpol,
allow_extrapolation=pixel_allow_extrapolation,
adaptive_grid=pixel_adaptive_grid)
else:
raise ValueError(f"Could not load profile type '{profile_string}'.")
# all remaining profiles take no extra arguments
return profile_class()

@property
def param_name_list(self):
Expand Down
26 changes: 26 additions & 0 deletions herculens/LightModel/profile_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# NOTE: this file is useful for backward-compatibility only,
# as the preferred way is now to pass the profile class directly
# to the LightModel() constructor.

from herculens.LightModel.Profiles.sersic import (Sersic, SersicElliptic)
from herculens.LightModel.Profiles.multipole import Multipole
from herculens.LightModel.Profiles.gaussian import (Gaussian, GaussianEllipse)
from herculens.LightModel.Profiles.pixelated import Pixelated
from herculens.LightModel.Profiles.uniform import Uniform
from herculens.LightModel.Profiles.shapelets import Shapelets


# mapping between the string name to the mass profile class.
STRING_MAPPING = {
'SERSIC': Sersic,
'SERSIC_ELLIPSE': SersicElliptic,
'GAUSSIAN': Gaussian,
'GAUSSIAN_ELLIPSE': GaussianEllipse,
'MULTIPOLE': Multipole,
'PIXELATED': Pixelated,
'UNIFORM': Uniform,
'SHAPELETS': Shapelets
}

SUPPORTED_MODELS = list(STRING_MAPPING.keys())
print("INSIDE", SUPPORTED_MODELS)
17 changes: 9 additions & 8 deletions herculens/MassModel/mass_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@

class MassModel(MassModelBase):
"""An arbitrary list of lens models."""
def __init__(self, mass_model_list, kwargs_pixelated=None, **kwargs):
def __init__(self, profile_list, **kwargs):
"""Create a MassModel object.

Parameters
----------
mass_model_list : list of str
Lens model profile names.
kwargs_pixelated : dictionary for settings related to PIXELATED profiles.
"""
self.profile_type_list = mass_model_list
super().__init__(self.profile_type_list, kwargs_pixelated=kwargs_pixelated,
**kwargs)
profile_list : list of strings or profile instances
List of mass profiles.
"""
if not isinstance(profile_list, (list, tuple)):
# useful when using a single profile
profile_list = [profile_list]
self.profile_type_list = profile_list
super().__init__(self.profile_type_list, **kwargs)

def ray_shooting(self, x, y, kwargs, k=None):
"""
Expand Down
Loading
Loading