Skip to content

Commit

Permalink
fixed array sizes, prep for tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
s-m-e committed Jan 14, 2024
1 parent 81d5611 commit 203357e
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/hapsira/core/math/ivp/_rk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
MIN_FACTOR = 0.2 # Minimum allowed decrease in a step size.
MAX_FACTOR = 10 # Maximum allowed increase in a step size.

INTERPOLATOR_POWER = 7
N_RV = 6
N_STAGES = 12
N_STAGES_EXTENDED = 16


@jit(nopython=False)
def norm(x: np.ndarray) -> float:
Expand Down Expand Up @@ -176,7 +181,7 @@ def validate_max_step(max_step: float) -> float:


@jit(nopython=False)
def validate_tol(rtol: float, atol: float, n: int) -> Tuple[float, float]:
def validate_tol(rtol: float, atol: float) -> Tuple[float, float]:
"""Validate tolerance values."""

if np.any(rtol < 100 * EPS):
Expand All @@ -187,7 +192,7 @@ def validate_tol(rtol: float, atol: float, n: int) -> Tuple[float, float]:
rtol = np.maximum(rtol, 100 * EPS)

atol = np.asarray(atol)
if atol.ndim > 0 and atol.shape != (n,):
if atol.ndim > 0 and atol.shape != (N_RV,):
raise ValueError("`atol` has wrong shape.")

if np.any(atol < 0):
Expand Down Expand Up @@ -342,20 +347,19 @@ class DOP853:

TOO_SMALL_STEP = "Required step size is less than spacing between numbers."

n_stages: int = 12 # N_STAGES == 12
order: int = 8
error_estimator_order: int = 7
A = dop853_coefficients.A[:n_stages, :n_stages]
A = dop853_coefficients.A[:N_STAGES, :N_STAGES]
B = dop853_coefficients.B
C = dop853_coefficients.C[:n_stages]
C = dop853_coefficients.C[:N_STAGES]
E: np.ndarray = NotImplemented
E3 = dop853_coefficients.E3
E5 = dop853_coefficients.E5
D = dop853_coefficients.D
P: np.ndarray = NotImplemented

A_EXTRA = dop853_coefficients.A[n_stages + 1 :]
C_EXTRA = dop853_coefficients.C[n_stages + 1 :]
A_EXTRA = dop853_coefficients.A[N_STAGES + 1 :]
C_EXTRA = dop853_coefficients.C[N_STAGES + 1 :]

def __init__(
self,
Expand All @@ -367,6 +371,8 @@ def __init__(
rtol: float = 1e-3,
atol: float = 1e-6,
):
assert y0.shape == (N_RV,)

self.t_old = None
self.t = t0
self._fun, self.y = check_arguments(fun, y0)
Expand All @@ -382,7 +388,6 @@ def fun(t, y):
self.fun_single = fun_single

self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
self.n = self.y.size
self.status = "running"

self.nfev = 0
Expand All @@ -391,7 +396,7 @@ def fun(t, y):

self.y_old = None
self.max_step = validate_max_step(max_step)
self.rtol, self.atol = validate_tol(rtol, atol, self.n)
self.rtol, self.atol = validate_tol(rtol, atol)
self.f = self.fun(self.t, self.y)
self.h_abs = select_initial_step(
self.fun,
Expand All @@ -406,10 +411,8 @@ def fun(t, y):
self.error_exponent = -1 / (self.error_estimator_order + 1)
self.h_previous = None

self.K_extended = np.empty(
(16, self.n), dtype=self.y.dtype
) # N_STAGES_EXTENDED == 16
self.K = self.K_extended[: self.n_stages + 1]
self.K_extended = np.empty((N_STAGES_EXTENDED, N_RV), dtype=self.y.dtype)
self.K = self.K_extended[: N_STAGES + 1]

@property
def step_size(self):
Expand All @@ -431,7 +434,7 @@ def step(self):
if self.status != "running":
raise RuntimeError("Attempt to step on a failed or finished " "solver.")

if self.n == 0 or self.t == self.t_bound:
if self.t == self.t_bound:
# Handle corner cases of empty solver or no integration.
self.t_old = self.t
self.t = self.t_bound
Expand Down Expand Up @@ -460,7 +463,7 @@ def dense_output(self):
"""
assert self.t_old is not None

assert not (self.n == 0 or self.t == self.t_old)
assert self.t != self.t_old

return self._dense_output_impl()

Expand Down Expand Up @@ -543,13 +546,11 @@ def _step_impl(self):
def _dense_output_impl(self):
K = self.K_extended
h = self.h_previous
for s, (a, c) in enumerate(
zip(self.A_EXTRA, self.C_EXTRA), start=self.n_stages + 1
):
for s, (a, c) in enumerate(zip(self.A_EXTRA, self.C_EXTRA), start=N_STAGES + 1):
dy = np.dot(K[:s].T, a[:s]) * h
K[s] = self.fun(self.t_old + c * h, self.y_old + dy)

F = np.empty((7, self.n), dtype=self.y_old.dtype) # INTERPOLATOR_POWER==7
F = np.empty((INTERPOLATOR_POWER, N_RV), dtype=self.y_old.dtype)

f_old = K[0]
delta_y = self.y - self.y_old
Expand Down

0 comments on commit 203357e

Please sign in to comment.