Validation of neural context flow with 150000 epochs
rdes committed Jan 21, 2024
1 parent f764c97 commit 1e52a56
Showing 3 changed files with 23 additions and 5 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -1,2 +1,20 @@
# nodebias
Inductive bias learning for dynamical systems



NodeBias is built around five extensible modules, wired together as sketched below:
- a DataLoader: to store the dataset
- a Learner: a model and its loss function
- a Trainer: to train the model
- a VisualTester: to test and visualize the results
- an HPFactory: to find hyperparameters for our models

Diagram showing the flow across the modules.
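
The diagram itself is not reproduced here. As a stand-in, below is a minimal sketch of how the five modules chain together; the constructors and method names are illustrative assumptions rather than the actual NodeBias API, except `trainer.train(...)`, whose signature mirrors `examples/lotka-voltera/main.py`.

```python
## Hypothetical end-to-end flow (constructor/method names are assumptions;
## only trainer.train(...) mirrors examples/lotka-voltera/main.py).
from nodebias import DataLoader, Learner, Trainer, VisualTester, HPFactory

SEED = 2024

dataloader = DataLoader("data/lotka_volterra.npz")       ## stores the dataset
learner = Learner(model=vector_field, loss=loss_fn_ctx)  ## placeholders: a model and its loss
trainer = Trainer(dataloader, learner)

trainer.train(nb_epochs=10000, print_error_every=1000,
              update_context_every=1, save_path="tmp/", key=SEED)
VisualTester(trainer).visualize(save_path="tmp/")        ## test and plot
HPFactory(trainer).search({"init_lr": [3e-4, 3e-3]})     ## hyperparameter search
```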


A few implemented neural ODE models:
- One-Per-Env
- One-For-All
- Context-Informed
- Neural Context Flow
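
The context-informed variants condition a single shared model on a small per-environment context vector, the `ctx` threaded through `loss_fn_ctx` in `examples/lotka-voltera/main.py`. A rough sketch of that idea (not the repository's actual implementation):

```python
import jax.numpy as jnp

def vector_field(params, x, t, ctx):
    """Context-informed dynamics sketch: one shared network plus one learned
    ctx per environment. One-Per-Env instead trains separate params per
    environment; One-For-All shares params and drops ctx entirely."""
    inp = jnp.concatenate([x, ctx])                  ## condition on the environment
    h = jnp.tanh(params["W1"] @ inp + params["b1"])  ## t unused: autonomous system
    return params["W2"] @ h + params["b2"]           ## dx/dt
```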
2 changes: 1 addition & 1 deletion examples/lotka-voltera/dataset.py
@@ -92,7 +92,7 @@ def step(state, t):
n_traj_per_env = 32 ## testing
# n_traj_per_env = 1 ## adaptation

-n_steps_per_traj = int(10/0.5)+1 ## from coda
+n_steps_per_traj = int(10/0.5)+1 ## TODO: from coda, it is 20 to be precise
# n_steps_per_traj = 201

data = np.zeros((len(environments), n_traj_per_env, n_steps_per_traj, 2))
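
A note on the TODO above: `int(10/0.5)+1` evaluates to 21, since a horizon of 10 time units sampled every 0.5 gives 20 intervals and hence 21 grid points counting `t=0`; per the comment, CoDA uses a 20-point grid, one fewer sample.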
8 changes: 4 additions & 4 deletions examples/lotka-voltera/main.py
@@ -185,15 +185,15 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):

## Define optimiser and train the model

-nb_train_steps = nb_epochs * 11
-sched_node = optax.piecewise_constant_schedule(init_value=3e-3,
+nb_train_steps = nb_epochs * 3
+sched_node = optax.piecewise_constant_schedule(init_value=3e-4,
boundaries_and_scales={int(nb_train_steps*0.25):0.2,
int(nb_train_steps*0.5):0.1,
int(nb_train_steps*0.75):0.01})
# sched_node = 1e-3
# sched_node = optax.exponential_decay(3e-3, nb_epochs*2, 0.99)

-sched_ctx = optax.piecewise_constant_schedule(init_value=3e-2,
+sched_ctx = optax.piecewise_constant_schedule(init_value=3e-3,
boundaries_and_scales={int(nb_epochs*0.25):0.2,
int(nb_epochs*0.5):0.1,
int(nb_epochs*0.75):0.01})
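
For context on the schedules: `optax.piecewise_constant_schedule` multiplies the current value by the given scale each time the step count crosses a boundary, so `sched_node` now decays 3e-4 → 6e-5 → 6e-6 → 6e-8 over training. A self-contained check (the total step count is illustrative):

```python
import optax

nb_train_steps = 150000  ## illustrative; the script computes nb_epochs * 3
sched_node = optax.piecewise_constant_schedule(
    init_value=3e-4,
    boundaries_and_scales={int(nb_train_steps*0.25): 0.2,    ## 3e-4 -> 6e-5
                           int(nb_train_steps*0.5): 0.1,     ## 6e-5 -> 6e-6
                           int(nb_train_steps*0.75): 0.01})  ## 6e-6 -> 6e-8

for step in [0, 40000, 80000, 120000]:
    print(step, float(sched_node(step)))    ## 3e-4, 6e-5, 6e-6, 6e-8

opt = optax.adam(learning_rate=sched_node)  ## optax optimisers accept schedules
```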
@@ -210,7 +210,7 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):
for i, prop in enumerate(np.linspace(0.25, 1.0, 2)):
trainer.dataloader.int_cutoff = int(prop*nb_steps_per_traj)
# nb_epochs = nb_epochs // 2 if nb_epochs > 1000 else 1000
-trainer.train(nb_epochs=nb_epochs*(10**i), print_error_every=100*(10**i), update_context_every=1, save_path="tmp/", key=SEED)
+trainer.train(nb_epochs=nb_epochs*(2**i), print_error_every=1000*(2**i), update_context_every=1, save_path="tmp/", key=SEED)

#%%
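
A note on the loop above: `np.linspace(0.25, 1.0, 2)` yields exactly two stages, so the model first fits the initial quarter of each trajectory for `nb_epochs` epochs, then the full trajectory for twice as many, i.e. `3*nb_epochs` in total, consistent with the new `nb_train_steps = nb_epochs * 3`. Stage arithmetic (the `nb_epochs` value is an inference from the 150000 epochs in the commit title, not stated in the diff):

```python
import numpy as np

nb_epochs = 50000          ## inferred: 150000 total / 3
nb_steps_per_traj = 21     ## int(10/0.5)+1, from dataset.py

for i, prop in enumerate(np.linspace(0.25, 1.0, 2)):  ## prop = 0.25, then 1.0
    int_cutoff = int(prop * nb_steps_per_traj)        ## 5 steps, then all 21
    stage_epochs = nb_epochs * (2**i)                 ## 50000, then 100000
    print(f"stage {i}: first {int_cutoff} time steps, {stage_epochs} epochs")
```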

