Skip to content

Commit

Permalink
Merge pull request #29 from ahuang314/main
Browse files Browse the repository at this point in the history
Improves the R_omega function in jax_util
  • Loading branch information
aymgal authored Jun 7, 2024
2 parents e61183c + f4069b6 commit 1229382
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions herculens/Util/jax_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ def R_omega(z, t, q, nmax):
omega_i = z # jnp.array(np.copy(z)) # Avoid overwriting z ?
partial_sum = omega_i

for i in range(1, nmax):
# Iteration-dependent factor
ratio = (2. * i - (2. - t)) / (2. * i + (2 - t))
# Current Omega term proportional to the previous term
omega_i = -f * ratio * ei2phi * omega_i
# Update the partial sum
partial_sum += omega_i
return partial_sum
@jit
def body_fun(i, val):
# Currrent term in the series is proportional to the previous
ratio = (2. * i + t - 2.) / (2. * i - t + 2.)
val[1] = -f * ei2phi * ratio * val[1]
# Adds the current term to the partial sum
val[0] += val[1]
return val

return lax.fori_loop(1, nmax, body_fun, [partial_sum, omega_i])[0]


def omega_real(x, y, t, q, nmax):
Expand All @@ -87,7 +89,7 @@ def omega_real(x, y, t, q, nmax):
Cs, Ss = jnp.cos(phi), jnp.sin(phi)
Cs2, Ss2 = jnp.cos(2 * phi), jnp.sin(2 * phi)
def update(n, val):
prefac = -f * (2 * n - (2 - t)) / (2 * n + (2 - t))
prefac = -f * (2. * n - (2. - t)) / (2. * n + (2. - t))
last_x, last_y, fx, fy = val
last_x, last_y = prefac * (Cs2 * last_x - Ss2 * last_y), prefac * (
Ss2 * last_x + Cs2 * last_y
Expand Down

0 comments on commit 1229382

Please sign in to comment.