Skip to content

Commit

Permalink
Fix unreturned outputs that broke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Nov 21, 2024
1 parent 30d738b commit 8338b4f
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 43 deletions.
90 changes: 63 additions & 27 deletions herculens/LightModel/light_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,40 +47,76 @@ def __init__(self, profile_list, verbose=False, **kwargs):
profile_list = [profile_list]
self.profile_type_list = profile_list
super(LightModel, self).__init__(self.profile_type_list, **kwargs)
self._single_profile_mode = False
self._repeated_profile_mode = False
self._single_profile_mode = len(self.profile_type_list) == 1
if len(self.profile_type_list) > 0:
first_profile = self.profile_type_list[0]
self._single_profile_mode = (
self._repeated_profile_mode = (
all(p is first_profile for p in self.profile_type_list)
)
if verbose is True and self._single_profile_mode:
print("Single profile mode in LightModel.")

def surface_brightness(self, x, y, kwargs, k=None):
"""Total source flux at a given position.
Parameters
----------
x, y : float or array_like
Position coordinate(s) in arcsec relative to the image center.
kwargs_list : list
List of parameter dictionaries corresponding to each source model.
k : int, optional
Position index of a single source model component.
"""
# x = np.array(x, dtype=float)
# y = np.array(y, dtype=float)
if isinstance(k, int):
return self.func_list[k].function(x, y, **kwargs[k])
elif self._single_profile_mode:
return self._surf_bright_single(x, y, kwargs, k=k)
if verbose is True and self._repeated_profile_mode:
print("All LightModel profiles are identical.")

def surface_brightness(self, x, y, kwargs, k=None,
pixels_x_coord=None, pixels_y_coord=None):
"""Total source flux at a given position.
Parameters
----------
x : float or array_like
Position coordinate(s) in arcsec relative to the image center.
y : float or array_like
Position coordinate(s) in arcsec relative to the image center.
kwargs : list
List of parameter dictionaries corresponding to each source model.
k : int, optional
Position index of a single source model component.
pixels_x_coord : array_like, optional
x-coordinates of the pixelated light profile (if any).
pixels_y_coord : array_like, optional
y-coordinates of the pixelated light profile (if any).
Returns
-------
float or array_like
Total source flux at the given position(s).
"""
# x = np.array(x, dtype=float)
# y = np.array(y, dtype=float)
if isinstance(k, int):
return self._surf_bright_single(x, y, kwargs, k=k,
pixels_x_coord=pixels_x_coord,
pixels_y_coord=pixels_y_coord)
elif self._single_profile_mode:
return self._surf_bright_single(x, y, kwargs, k=0,
pixels_x_coord=pixels_x_coord,
pixels_y_coord=pixels_y_coord)
elif self._repeated_profile_mode:
return self._surf_bright_repeated(x, y, kwargs, k=k,
pixels_x_coord=pixels_x_coord,
pixels_y_coord=pixels_y_coord)
else:
return self._surf_bright_loop(x, y, kwargs, k=k,
pixels_x_coord=pixels_x_coord,
pixels_y_coord=pixels_y_coord)

def _surf_bright_single(self, x, y, kwargs, k=None,
pixels_x_coord=None, pixels_y_coord=None):
if k == self.pixelated_index:
return self.func_list[k].function(
x, y, **kwargs[k],
pixels_x_coord=pixels_x_coord,
pixels_y_coord=pixels_y_coord,
)
else:
return self._surf_bright_loop(x, y, kwargs, k=k)
return self.func_list[k].function(x, y, **kwargs[k])

def _surf_bright_single(self, x, y, kwargs, k=None):
def _surf_bright_repeated(self, x, y, kwargs, k=None,
pixels_x_coord=None, pixels_y_coord=None):
if k is not None:
raise NotImplementedError # TODO: implement case with k not None
raise NotImplementedError("Repeated profile mode not implemented "
"specific profile k.")
func = function_static_single(x, y, self.func_list[0].function)
return jnp.sum(
jnp.array([
Expand Down
14 changes: 7 additions & 7 deletions herculens/MassModel/mass_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def __init__(self, profile_list, use_jax_scan=False, verbose=False, **kwargs):
self.profile_type_list = profile_list
super().__init__(self.profile_type_list, **kwargs)
self._use_jax_scan = use_jax_scan
self._single_profile_mode = False
self._repeated_profile_mode = False
if len(self.profile_type_list) > 0:
first_profile = self.profile_type_list[0]
self._single_profile_mode = (
self._repeated_profile_mode = (
all(p is first_profile for p in self.profile_type_list)
)
if verbose is True and self._single_profile_mode:
print("Single profile mode in MassModel.")
if verbose is True and self._repeated_profile_mode:
print("All MassModel profiles are identical.")

@partial(jit, static_argnums=(0, 4))
def ray_shooting(self, x, y, kwargs, k=None):
Expand Down Expand Up @@ -138,14 +138,14 @@ def alpha(self, x, y, kwargs, k=None):
# y = np.array(y, dtype=float)
if isinstance(k, int):
return self.func_list[k].derivatives(x, y, **kwargs[k])
elif self._single_profile_mode:
return self._alpha_single(x, y, kwargs, k=k)
elif self._repeated_profile_mode:
return self._alpha_repeated(x, y, kwargs, k=k)
elif self._use_jax_scan:
return self._alpha_scan(x, y, kwargs, k=k)
else:
return self._alpha_loop(x, y, kwargs, k=k)

def _alpha_single(self, x, y, kwargs, k=None):
def _alpha_repeated(self, x, y, kwargs, k=None):
if k is not None:
raise NotImplementedError # TODO: implement case with k not None
alpha_func = alpha_static_single(x, y, self.func_list[0].derivatives)
Expand Down
35 changes: 27 additions & 8 deletions test/LightModel/light_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def base_setup():
# Populate kwargs with parameters associated to the base_light_model
kwargs_light = [
{
'amp': 1.0,
'amp': 1354.0,
'R_sersic': 0.5,
'n_sersic': 4.0,
'center_x': 0.04,
Expand All @@ -54,7 +54,7 @@ def base_setup():
'e2': 0.07,
},
{
'amp': 0.8,
'amp': 194.,
'sigma': 0.1,
'center_x': 0.0,
'center_y': 0.0,
Expand All @@ -67,7 +67,7 @@ def base_setup():
]
if TEST_SHAPELETS:
kwargs_light.append({
'amps': np.random.randn((n_max+1)*(n_max+2)//2),
'amps': 1e2*np.random.randn((n_max+1)*(n_max+2)//2),
'beta': 0.2,
'center_x': -0.02,
'center_y': 0.1,
Expand All @@ -81,15 +81,33 @@ def get_light_model_instance(alpha_method):
light_model = LightModel([hcl.SersicElliptic(), hcl.SersicElliptic(), hcl.SersicElliptic()], verbose=True)
else:
light_model = LightModel(3 * [hcl.SersicElliptic()], verbose=True)
kwargs_light = 3 * [
kwargs_light = [
{
'amp': 1.0,
'amp': 123.,
'R_sersic': 0.5,
'n_sersic': 4.0,
'center_x': 0.0,
'center_y': 0.0,
'e1': 0.0,
'e2': 0.0,
},
{
'amp': 21.,
'R_sersic': 0.6,
'n_sersic': 3.2,
'center_x': 0.01,
'center_y': 0.1,
'e1': 0.04,
'e2': 0.1,
},
{
'amp': 1335.,
'R_sersic': 2.1,
'n_sersic': 1.8,
'center_x': -0.01,
'center_y': -0.2,
'e1': -0.04,
'e2': 0.1,
}
]
return light_model, kwargs_light
Expand All @@ -102,16 +120,17 @@ def test_summation_methods(xy):
# unpack the coordinates
x, y = xy
# get the instance corresponding to the alpha_method
light_model1, kwargs_light2 = get_light_model_instance('repeated')
light_model1, kwargs_light1 = get_light_model_instance('repeated')
light_model2, kwargs_light2 = get_light_model_instance('unique')
# test the resulting values of the light profiles
print("AAAA", light_model1.surface_brightness(x, y, kwargs_light1, k=0), light_model2.surface_brightness(x, y, kwargs_light2, k=0))
assert np.allclose(
light_model1.surface_brightness(x, y, kwargs_light2),
light_model1.surface_brightness(x, y, kwargs_light1),
light_model2.surface_brightness(x, y, kwargs_light2), rtol=1e-8
)
# here we test the slightly different call when only one profile is evaluated
assert np.allclose(
light_model1.surface_brightness(x, y, kwargs_light2, k=0),
light_model1.surface_brightness(x, y, kwargs_light1, k=0),
light_model2.surface_brightness(x, y, kwargs_light2, k=0), rtol=1e-8
)

Expand Down
2 changes: 1 addition & 1 deletion test/modeling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def simulate_data(data_type, supersampling_factor):
elif data_type == 'lensed_source_only':
lens_mass_input = MassModel(['EPL', 'SHEAR'])
lens_light_input = LightModel([])
source_input = LightModel(['SERSIC_ELLIPSE'])
source_input = LightModel(['SERSIC_ELLIPSE'], verbose=True)
elif data_type == 'source_only':
lens_mass_input = MassModel([])
lens_light_input = LightModel([])
Expand Down

0 comments on commit 8338b4f

Please sign in to comment.