From f764c9751d11692908b7fc7a777b080663b2c823 Mon Sep 17 00:00:00 2001 From: gb21553 Date: Sun, 21 Jan 2024 15:43:13 +0000 Subject: [PATCH] Ready for super long training on contex fflow --- examples/lotka-voltera/dataset.py | 6 +++--- examples/lotka-voltera/main.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/lotka-voltera/dataset.py b/examples/lotka-voltera/dataset.py index f1bb602..747afb7 100644 --- a/examples/lotka-voltera/dataset.py +++ b/examples/lotka-voltera/dataset.py @@ -88,8 +88,8 @@ def step(state, t): # {"alpha": 0.5, "beta": 1.125, "gamma": 0.5, "delta": 1.125}, # ] -n_traj_per_env = 4 ## training -# n_traj_per_env = 32 ## testing +# n_traj_per_env = 4 ## training +n_traj_per_env = 32 ## testing # n_traj_per_env = 1 ## adaptation n_steps_per_traj = int(10/0.5)+1 ## from coda @@ -146,7 +146,7 @@ def animate(i): plt.show() # Save t_eval and the solution to a npz file -np.savez('tmp/train_data.npz', t=t_eval, X=data) +np.savez('tmp/test_data.npz', t=t_eval, X=data) ## Save the movie to a small mp4 file # ani.save('tmp/lotka_volterra.mp4', fps=30, extra_args=['-vcodec', 'libx264']) diff --git a/examples/lotka-voltera/main.py b/examples/lotka-voltera/main.py index ae83e45..e54f15c 100644 --- a/examples/lotka-voltera/main.py +++ b/examples/lotka-voltera/main.py @@ -21,7 +21,7 @@ ## Hyperparams SEED = 3 context_size = 8000 -nb_epochs = 1000 +nb_epochs = 50000 @@ -188,15 +188,15 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key): nb_train_steps = nb_epochs * 11 sched_node = optax.piecewise_constant_schedule(init_value=3e-3, boundaries_and_scales={int(nb_train_steps*0.25):0.2, - int(nb_train_steps*0.5):0.2, - int(nb_train_steps*0.75):0.2}) + int(nb_train_steps*0.5):0.1, + int(nb_train_steps*0.75):0.01}) # sched_node = 1e-3 # sched_node = optax.exponential_decay(3e-3, nb_epochs*2, 0.99) sched_ctx = optax.piecewise_constant_schedule(init_value=3e-2, boundaries_and_scales={int(nb_epochs*0.25):0.2, - int(nb_epochs*0.5):0.2, - int(nb_epochs*0.75):0.2}) + int(nb_epochs*0.5):0.1, + int(nb_epochs*0.75):0.01}) # sched_ctx = 1e-3 opt_node = optax.adabelief(sched_node)