Skip to content

Commit

Permalink
Ready for super long training on contex fflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ddrous committed Jan 21, 2024
1 parent 89e5766 commit f764c97
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/lotka-voltera/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(state, t):
# {"alpha": 0.5, "beta": 1.125, "gamma": 0.5, "delta": 1.125},
# ]

n_traj_per_env = 4 ## training
# n_traj_per_env = 32 ## testing
# n_traj_per_env = 4 ## training
n_traj_per_env = 32 ## testing
# n_traj_per_env = 1 ## adaptation

n_steps_per_traj = int(10/0.5)+1 ## from coda
Expand Down Expand Up @@ -146,7 +146,7 @@ def animate(i):
plt.show()

# Save t_eval and the solution to a npz file
np.savez('tmp/train_data.npz', t=t_eval, X=data)
np.savez('tmp/test_data.npz', t=t_eval, X=data)

## Save the movie to a small mp4 file
# ani.save('tmp/lotka_volterra.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
Expand Down
10 changes: 5 additions & 5 deletions examples/lotka-voltera/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
## Hyperparams
SEED = 3
context_size = 8000
nb_epochs = 1000
nb_epochs = 50000



Expand Down Expand Up @@ -188,15 +188,15 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):
nb_train_steps = nb_epochs * 11
sched_node = optax.piecewise_constant_schedule(init_value=3e-3,
boundaries_and_scales={int(nb_train_steps*0.25):0.2,
int(nb_train_steps*0.5):0.2,
int(nb_train_steps*0.75):0.2})
int(nb_train_steps*0.5):0.1,
int(nb_train_steps*0.75):0.01})
# sched_node = 1e-3
# sched_node = optax.exponential_decay(3e-3, nb_epochs*2, 0.99)

sched_ctx = optax.piecewise_constant_schedule(init_value=3e-2,
boundaries_and_scales={int(nb_epochs*0.25):0.2,
int(nb_epochs*0.5):0.2,
int(nb_epochs*0.75):0.2})
int(nb_epochs*0.5):0.1,
int(nb_epochs*0.75):0.01})
# sched_ctx = 1e-3

opt_node = optax.adabelief(sched_node)
Expand Down

0 comments on commit f764c97

Please sign in to comment.