Skip to content

Commit

Permalink
Add ability to calculate the hessian for MP mass model using auto diff
Browse files Browse the repository at this point in the history
  • Loading branch information
CKrawczyk committed Aug 19, 2024
1 parent d51e053 commit 68fca10
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion herculens/MassModel/mass_model_multiplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def ray_shooting(self, x, y, eta_flat, kwargs, N=None, k=None):
:type x: numpy array
:param y: y-position (preferentially arcsec)
:type y: numpy array
:param eta: upper triangular elements of eta matrix, values defined as
:param eta_flat: upper triangular elements of eta matrix, values defined as
eta_ij = D_ij D_i+1 / D_j D_ii+1 where D_ij is the angular diameter
distance between redshifts i and j. Only include values where
j > i+1. This convention implies that all einstein radii are defined
Expand Down Expand Up @@ -91,3 +91,71 @@ def ray_shooting(self, x, y, eta_flat, kwargs, N=None, k=None):
ys = ys - etas_j * dy
return xs, ys

def ray_shooting_slice(self, x, y, eta_flat, kwargs):
'''Helper function that give *scaler* inputs of x and y give a *vector*
output for each mass plane. Used for the vectorization of the `A` method'''
return jnp.stack(
self.ray_shooting(jnp.array([x]), jnp.array([y]), eta_flat, kwargs)
).T.squeeze()

def A_stack(self, x, y, eta_flat, kwargs):
'''Helper function that takes the jacobian of the ray shooting give *scaler*
inputs for x and y and returns a 2x2 array.'''
return jnp.stack(
jax.jacfwd(
self.ray_shooting_slice,
argnums=(0, 1)
)(x, y, eta_flat, kwargs)
)

@partial(jax.jit, static_argnums=(0,))
def A(self, x, y, eta_flat, kwargs):
'''
Area distortion matrix of the lens mapping.
Parameters
----------
x : jax.numpy array
x-position (preferentially arcsec)
y : jax.numpy array
y-position (preferentially arcsec)
eta_flat : jax.numpy array
upper triangular elements of eta matrix, values defined as
eta_ij = D_ij D_i+1 / D_j D_ii+1 where D_ij is the angular diameter
distance between redshifts i and j. Only include values where
j > i+1. This convention implies that all einstein radii are defined
with respect to the **next** mass plane back (**not** the last plane in
the stack).
kwargs: list of list
keyword arguments of lens model parameters matching the lens model classes
Returns
-------
A : jnp.numpy array
The area distortion matrix of the lens mapping for each position
and each mass plane (including the image plane) with shape
(N+1, *(x.shape), 2, 2) where N is the number of mass planes.
'''
A_stack_part = partial(
self.A_stack,
eta_flat=eta_flat,
kwargs=kwargs
)
return jnp.moveaxis(jnp.vectorize(
A_stack_part,
signature='(),()->(i,j,i)'
)(x, y), 3, 0)

def inverse_magnification(self, x, y, eta_flat, kwargs):
A = self.A(x, y, eta_flat, kwargs)
return A[..., 0, 0] * A[..., 1, 1] - A[..., 0, 1] * A[..., 1, 0]

def kappa(self, x, y, eta_flat, kwargs):
A = self.A(x, y, eta_flat, kwargs)
return 1 - 0.5 * (A[..., 0, 0] + A[..., 1, 1])

def gamma(self, x, y, eta_flat, kwargs):
A = self.A(x, y, eta_flat, kwargs)
gamma1 = 0.5 * (A[..., 1, 1] - A[..., 0, 0])
gamma2 = -A[..., 0, 1]
return gamma1, gamma2

0 comments on commit 68fca10

Please sign in to comment.