Skip to content

Commit

Permalink
Expose PointSourceModel methods to compute source plane and image pla…
Browse files Browse the repository at this point in the history
…ne errors
  • Loading branch information
aymgal committed Apr 10, 2024
1 parent d2fef44 commit 046ea18
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
35 changes: 22 additions & 13 deletions herculens/PointSourceModel/point_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,31 +193,40 @@ def source_amplitude(self, kwargs_point_source, kwargs_lens=None):
elif self.type == 'SOURCE_POSITION':
return jnp.array(kwargs_point_source['amp'])

def log_prob_image_plane(self, kwargs_point_source, kwargs_lens,
kwargs_solver, sigma_image=1e-3):
def error_image_plane(self, kwargs_point_source, kwargs_lens, kwargs_solver):
self._check_solver_install("log_prob_image_plane")
# get the optimized image positions
theta_x_opti = jnp.array(kwargs_point_source['ra'])
theta_y_opti = jnp.array(kwargs_point_source['dec'])
# find source position via ray-tracing
theta_x_in = jnp.array(kwargs_point_source['ra'])
theta_y_in = jnp.array(kwargs_point_source['dec'])
beta = self.mass_model.ray_shooting(theta_x_in, theta_y_in, kwargs_lens)
beta_x, beta_y = jnp.mean(beta[0]), jnp.mean(beta[1])
beta = self.mass_model.ray_shooting(theta_x_opti, theta_y_opti, kwargs_lens)
beta_x, beta_y = beta[0].mean(), beta[1].mean()
beta = jnp.array([beta_x, beta_y])
# solve lens equation to find corresponding image positions
# solve lens equation to find the predicted image positions
theta, beta = self.solver.solve(
beta, kwargs_lens, **kwargs_solver,
)
theta_x, theta_y = theta.T
theta_x_pred, theta_y_pred = theta.T
# return departures between original and new positions
return jnp.sqrt((theta_x_opti - theta_x_pred)**2 + (theta_y_opti - theta_y_pred)**2)

def log_prob_image_plane(self, kwargs_point_source, kwargs_lens,
kwargs_solver, sigma_image=1e-3):
error_image = self.error_image_plane(kwargs_point_source, kwargs_lens, kwargs_solver)
# penalize departures between original and new positions
return - jnp.sum((jnp.hypot(theta_x_in - theta_x, theta_y_in - theta_y) / sigma_image)**2)
return - jnp.sum((error_image / sigma_image)**2)

def log_prob_source_plane(self, kwargs_point_source, kwargs_lens, sigma_source=1e-3):
def error_source_plane(self, kwargs_point_source, kwargs_lens):
# find source position via ray-tracing
theta_x_in = jnp.array(kwargs_point_source['ra'])
theta_y_in = jnp.array(kwargs_point_source['dec'])
beta_x, beta_y = self.mass_model.ray_shooting(theta_x_in, theta_y_in, kwargs_lens)
beta_x_mean, beta_y_mean = jnp.mean(beta_x), jnp.mean(beta_y)
# penalize departures between original and new positions
return - jnp.sum((jnp.hypot(beta_x - beta_x_mean, beta_y - beta_y_mean) / sigma_source)**2)
# compute distance between mean position and ray-traced positions
return jnp.sqrt((beta_x - beta_x.mean())**2 + (beta_y - beta_y.mean())**2)

def log_prob_source_plane(self, kwargs_point_source, kwargs_lens, sigma_source=1e-3):
error_source = self.error_source_plane(kwargs_point_source, kwargs_lens)
return - jnp.sum((error_source / sigma_source)**2)

def _zero_amp_duplicated_images(self, amp_in, theta_x_in, theta_y_in, kwargs_solver):
"""This function takes as input the list of multiply lensed images
Expand Down
34 changes: 34 additions & 0 deletions herculens/PointSourceModel/point_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,23 @@ def get_source_plane_points(self, kwargs_point_source, kwargs_lens=None,

return beta_x, beta_y

def error_image_plane(self, kwargs_params, kwargs_solver):
"""This function takes as input the current parameters
and returns distance between the predicted image positions
based on the mean ray-traced source position, and the model image positions.
"""
error_image = 0.
kwargs_point_source = kwargs_params['kwargs_point_source']
for i, ps_type in enumerate(self.type_list):
if ps_type == 'IMAGE_POSITIONS':
ps = self.point_sources[i]
error_image += ps.error_image_plane(
kwargs_params['kwargs_point_source'][i],
kwargs_params['kwargs_lens'],
kwargs_solver,
)
return error_image

def log_prob_image_plane(self, kwargs_params, kwargs_solver, **kwargs_hyperparams):
"""This function takes as input the current parameters
and returns log-probability penalty term that enforces multiply imaged
Expand All @@ -161,6 +178,23 @@ def log_prob_image_plane(self, kwargs_params, kwargs_solver, **kwargs_hyperparam
)
return log_prob

def error_source_plane(self, kwargs_params):
"""This function takes as input the current parameters
and returns distance between the predicted source positions
based on the mean source position, and the ray-traced source positions
from the model image positions.
"""
error_source = 0.
kwargs_point_source = kwargs_params['kwargs_point_source']
for i, ps_type in enumerate(self.type_list):
if ps_type == 'IMAGE_POSITIONS':
ps = self.point_sources[i]
error_source += ps.error_source_plane(
kwargs_params['kwargs_point_source'][i],
kwargs_params['kwargs_lens'],
)
return error_source

def log_prob_source_plane(self, kwargs_params, **kwargs_hyperparams):
"""This function takes as input the current parameters
and returns log-probability penalty term that enforces multiply imaged
Expand Down

0 comments on commit 046ea18

Please sign in to comment.