Skip to content

Commit

Permalink
update example
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Aug 14, 2023
1 parent 1c1aa60 commit 26deec9
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/fit_multiple_shooting_second_order_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ def h(self, x, u=None, t=None):
t_train = np.linspace(0, 10, 1000)
samplerate = 1 / t_train[1]
np.random.seed(42)
u_train = np.random.normal(size=len(t_train))
u_train = np.sum(
np.stack(
[np.sin(f * t_train) for f in np.random.uniform(0, samplerate / 4, size=10)]
),
axis=0,
)
initial_x = [0.0, 0.0]
x_train, y_train = true_model(initial_x, t_train, u_train)

Expand All @@ -81,10 +86,10 @@ def h(self, x, u=None, t=None):
y=y_train,
x0=initial_x,
u=u_train,
verbose=0,
verbose=2,
num_shots=num_shots,
)
model = res.x
model = res.model
x0s = res.x0s
ts = res.ts
ts0 = res.ts0
Expand All @@ -93,7 +98,6 @@ def h(self, x, u=None, t=None):

# check the results
x_pred, y_pred = model(initial_x, t_train, u_train)
assert np.allclose(x_train, x_pred, atol=1e-5, rtol=1e-5)

# plot
xs_pred, _ = jax.vmap(model)(x0s, ts0, us)
Expand All @@ -105,3 +109,5 @@ def h(self, x, u=None, t=None):
plt.plot()
plt.legend()
plt.show()

assert np.allclose(x_train, x_pred, atol=1e-5, rtol=1e-5)

0 comments on commit 26deec9

Please sign in to comment.