2D visualisation for Gray-Scott PDE
ddrous committed Feb 21, 2024
1 parent 4dfe9e1 commit 0b8a9b4
Showing 3 changed files with 203 additions and 128 deletions.
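For context, the example being changed simulates the Gray-Scott reaction-diffusion system. In its standard textbook form (quoted here for reference, not taken from this repo's code), with diffusivities D_u and D_v, feed rate F, and kill rate k:

```latex
% Standard Gray-Scott reaction-diffusion system:
% u is the substrate, v the activator; F feeds u, and (F + k) removes v.
\begin{aligned}
\partial_t u &= D_u \nabla^2 u - u v^2 + F\,(1 - u) \\
\partial_t v &= D_v \nabla^2 v + u v^2 - (F + k)\, v
\end{aligned}
```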
27 changes: 16 additions & 11 deletions examples/gray-scott/main.py
@@ -15,12 +15,12 @@

flow_pool_count = 2 ## Number of neighboring contexts j to use for a flow in env e
context_size = 1024
-nb_epochs = 100*1
-nb_epochs_adapt = 2*1
+nb_epochs = 180*360
+nb_epochs_adapt = 180*360

print_error_every = 1000

-train = False
+train = True
save_trainer = True

finetune = False
@@ -54,7 +54,7 @@


else:
run_folder = "./runs/20022024-173401/" ## Needed for loading the model and finetuning TODO: opti
run_folder = "./runs/21022024-121527/" ## Needed for loading the model and finetuning TODO: opti
print("No training. Loading data and results from:", run_folder)

## Create a folder for the adaptation results
@@ -80,7 +80,7 @@
#%%

## Define dataloader for training and validation
-train_dataloader = DataLoader(run_folder+"train_data.npz", batch_size=4, int_cutoff=1.0, shuffle=True, key=seed)
+train_dataloader = DataLoader(run_folder+"train_data.npz", batch_size=1, int_cutoff=1.0, shuffle=True, key=seed)

nb_envs = train_dataloader.nb_envs
nb_trajs_per_env = train_dataloader.nb_trajs_per_env
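The training log further down reports batch shapes of (4, 1, 10, 2048), i.e. (environments, trajectories per batch, timesteps, flattened state); with res=32 and the two Gray-Scott species, 2048 = 2 x 32 x 32. A small sketch of unflattening such a batch (the channel-first layout is an assumption, not confirmed by this repo):

```python
import numpy as np

# Hypothetical batch with the shape reported in the log below:
# (nb_envs, batch_size, nb_timesteps, flattened 32x32x2 state).
batch = np.zeros((4, 1, 10, 2048))

res, channels = 32, 2  # 32x32 grid, two species (u, v)
fields = batch.reshape(4, 1, 10, channels, res, res)  # channel-first assumed
u, v = fields[..., 0, :, :], fields[..., 1, :, :]
print(u.shape)  # (4, 1, 10, 32, 32)
```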
@@ -140,6 +140,7 @@ def __init__(self, data_res, kernel_size, nb_int_channels, context_size, key=None
lambda x: x.flatten()]

def __call__(self, t, y, ctx):
+# return jnp.zeros_like(y)*ctx[0]

for layer in self.layers_context:
ctx = layer(ctx)
@@ -223,7 +224,7 @@ def __call__(self, t, x, ctxs):
return vf(ctx_) + gradvf(ctx_, ctx)


-augmentation = Augmentation(data_res=32, kernel_size=3, nb_int_channels=4, context_size=context_size, key=seed)
+augmentation = Augmentation(data_res=32, kernel_size=3, nb_int_channels=1, context_size=context_size, key=seed)

vectorfield = ContextFlowVectorField(augmentation, physics=None)
print("\n\nTotal number of parameters in the model:", sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(vectorfield,eqx.is_array)) if x is not None), "\n\n")
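The `vf(ctx_) + gradvf(ctx_, ctx)` line above reads like a first-order Taylor expansion of the vector field around a neighbouring context `ctx_`, evaluated at the environment's own context `ctx`. A minimal JAX sketch of that idea (names hypothetical; this is not the repo's actual `ContextFlowVectorField`):

```python
import jax
import jax.numpy as jnp

def taylor_flow(vf, ctx_j, ctx_e):
    """First-order Taylor expansion of vf around a neighbouring context ctx_j:
    vf(ctx_j) + J_vf(ctx_j) @ (ctx_e - ctx_j), computed with a single jvp."""
    value, tangent = jax.jvp(vf, (ctx_j,), (ctx_e - ctx_j,))
    return value + tangent

# Toy usage: vf maps a context vector to a vector-field evaluation.
vf = lambda ctx: jnp.tanh(ctx)
ctx_j = jnp.zeros(8)           # neighbouring context j
ctx_e = 0.1 * jnp.ones(8)      # current environment's context e
print(taylor_flow(vf, ctx_j, ctx_e).shape)  # (8,)
```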
@@ -353,11 +354,13 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key):


#%%
-# ## Load the validation loss
-# val_loss = np.load(run_folder+"val_losses.npy")
-# val_loss


+## Custom Gray-Scott trajectory visualiser
+if finetune:
+    savefigdir = finetunedir+"results_2D_ind.png"
+else:
+    savefigdir = run_folder+"results_2D_ind.png"
+visualtester.visualize2D(test_dataloader, int_cutoff=1.0, res=32, save_path=savefigdir);


#%%
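Since the commit's headline change is the 2D visualiser call above, here is a minimal sketch of what such a `visualize2D` might do, assuming trajectories stored as (timesteps, res*res*2) flattened (u, v) grids (an assumption consistent with the log shapes; the actual implementation lives elsewhere in the repo):

```python
import numpy as np
import matplotlib.pyplot as plt

def visualize2d_sketch(traj, res=32, save_path="results_2D_ind.png"):
    """Render the u field of one trajectory at a few evenly spaced timesteps.
    traj: array of shape (nb_timesteps, res*res*2); channel-first layout assumed."""
    steps = np.linspace(0, len(traj) - 1, num=5, dtype=int)
    fig, axs = plt.subplots(1, len(steps), figsize=(3 * len(steps), 3))
    for ax, t in zip(axs, steps):
        u = traj[t].reshape(2, res, res)[0]   # keep the u channel only
        ax.imshow(u, cmap="viridis")
        ax.set_title(f"t = {t}")
        ax.axis("off")
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
```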
@@ -497,7 +500,7 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key):
os.system(f'python dataset.py --split=test --savepath="{run_folder}" --seed="{seed*2}" --verbose=0')
os.system(f'python dataset.py --split=adapt --savepath="{adapt_folder}" --seed="{seed*3}" --verbose=0')

-test_dataloader = DataLoader(run_folder+"test_data.npz", shuffle=False, batch_size=4, data_id="082026")
+test_dataloader = DataLoader(run_folder+"test_data.npz", shuffle=False, batch_size=1, data_id="082026")
adapt_test_dataloader = DataLoader(adapt_folder+"adapt_data.npz", adaptation=True, batch_size=1, key=seed, data_id="082026")

ind_crit, _ = visualtester.test(test_dataloader, int_cutoff=1.0, verbose=False)
@@ -553,3 +556,5 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key):

# ## Save the odd_crit_all in numpy
# np.save('mapes.npy', odd_crit_all)

+# %%
184 changes: 85 additions & 99 deletions examples/gray-scott/nohup.log
@@ -1,166 +1,152 @@
Running this script in ipython (Jupyter) session ? False
=== Parsed arguments to generate data ===
Split: train
-Savepath: ./runs/18022024-221505/
+Savepath: ./runs/21022024-121527/
Seed: 2026

Running this script in ipython (Jupyter) session ? False
=== Parsed arguments to generate data ===
Split: test
-Savepath: ./runs/18022024-221505/
+Savepath: ./runs/21022024-121527/
Seed: 4052

Running this script in ipython (Jupyter) session ? False
=== Parsed arguments to generate data ===
Split: adapt
-Savepath: ./runs/18022024-221505/adapt/
+Savepath: ./runs/21022024-121527/adapt/
Seed: 6078


############# Inductive Bias Learning for Dynamical Systems #############

-Jax version: 0.4.19
+Jax version: 0.4.23
Available devices: [cuda(id=0)]
-Data folder created successfully: ./runs/18022024-221505/
+Data folder created successfully: ./runs/21022024-121527/
Completed copied scripts
-WARNING: You did not provide a dataloader id. A new one has been generated: 221526
+WARNING: You did not provide a dataloader id. A new one has been generated: 121535
WARNING: Note that this id is used to distinguish between adaptations to different environments.
-WARNING: You did not provide a dataloader id. A new one has been generated: 221526
+WARNING: You did not provide a dataloader id. A new one has been generated: 121536
WARNING: Note that this id is used to distinguish between adaptations to different environments.
WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum.


-Total number of parameters in the model: 319911
+Total number of parameters in the model: 4200374


WARNING: No key provided for the context initialization. Initializing at 0.
WARNING: No key provided, using time as seed


=== Beginning training ... ===
-Number of examples in a batch: 4
-Number of train steps per epoch: 8
-Number of training epochs: 24000
-Total number of training steps: 192000
+Number of examples in a batch: 1
+Number of train steps per epoch: 1
+Number of training epochs: 18000
+Total number of training steps: 18000

Compiling function "train_step" for neural ode ...
-Shapes of elements in a batch: (9, 4, 20, 7) (20,)
+Shapes of elements in a batch: (4, 1, 10, 2048) (10,)

Compiling function "train_step" for context ...
-Shapes of elements in a batch: (9, 4, 20, 7) (20,)
-Epoch: 0 LossTrajs: 0.28848881 ContextsNorm: 0.00150763 ValIndCrit: 0.23863795
-Epoch: 1 LossTrajs: 0.23736051 ContextsNorm: 0.00279875 ValIndCrit: 0.22864109
-Epoch: 2 LossTrajs: 0.23170462 ContextsNorm: 0.00403358 ValIndCrit: 0.22836266
-Epoch: 3 LossTrajs: 0.23057464 ContextsNorm: 0.00466054 ValIndCrit: 0.22816052
-Epoch: 1000 LossTrajs: 0.00352303 ContextsNorm: 0.11520766 ValIndCrit: 0.00458277
-Epoch: 2000 LossTrajs: 0.00210036 ContextsNorm: 0.11645799 ValIndCrit: 0.00377220
-Epoch: 3000 LossTrajs: 0.00209809 ContextsNorm: 0.11668418 ValIndCrit: 0.00377159
-Epoch: 4000 LossTrajs: 0.00201739 ContextsNorm: 0.11677950 ValIndCrit: 0.00373893
-Epoch: 5000 LossTrajs: 0.00198951 ContextsNorm: 0.11691947 ValIndCrit: 0.00373602
-Epoch: 6000 LossTrajs: 0.00196165 ContextsNorm: 0.11688989 ValIndCrit: 0.00373701
-Epoch: 7000 LossTrajs: 0.00195028 ContextsNorm: 0.11689670 ValIndCrit: 0.00376597
-Epoch: 8000 LossTrajs: 0.00191953 ContextsNorm: 0.11700038 ValIndCrit: 0.00374925
-Epoch: 9000 LossTrajs: 0.00190826 ContextsNorm: 0.11683951 ValIndCrit: 0.00374242
-Epoch: 10000 LossTrajs: 0.00187307 ContextsNorm: 0.11691865 ValIndCrit: 0.00374858
-Epoch: 11000 LossTrajs: 0.00188009 ContextsNorm: 0.11678100 ValIndCrit: 0.00375942
-Epoch: 12000 LossTrajs: 0.00183770 ContextsNorm: 0.11688248 ValIndCrit: 0.00376303
-Epoch: 13000 LossTrajs: 0.00183869 ContextsNorm: 0.11692899 ValIndCrit: 0.00379309
-Epoch: 14000 LossTrajs: 0.00179264 ContextsNorm: 0.11691536 ValIndCrit: 0.00378790
-Epoch: 15000 LossTrajs: 0.00179362 ContextsNorm: 0.11679574 ValIndCrit: 0.00378610
-Epoch: 16000 LossTrajs: 0.00178432 ContextsNorm: 0.11683217 ValIndCrit: 0.00377219
-Epoch: 17000 LossTrajs: 0.00173260 ContextsNorm: 0.11698029 ValIndCrit: 0.00382742
-Epoch: 18000 LossTrajs: 0.00175081 ContextsNorm: 0.11694510 ValIndCrit: 0.00379931
-Epoch: 19000 LossTrajs: 0.00170353 ContextsNorm: 0.11696473 ValIndCrit: 0.00384929
-Epoch: 20000 LossTrajs: 0.00170334 ContextsNorm: 0.11698332 ValIndCrit: 0.00384146
-Epoch: 21000 LossTrajs: 0.00165402 ContextsNorm: 0.11691977 ValIndCrit: 0.00388831
-Epoch: 22000 LossTrajs: 0.00165488 ContextsNorm: 0.11698642 ValIndCrit: 0.00388500
-Epoch: 23000 LossTrajs: 0.00167246 ContextsNorm: 0.11704455 ValIndCrit: 0.00388701
-Epoch: 23999 LossTrajs: 0.00163804 ContextsNorm: 0.11711274 ValIndCrit: 0.00386044
-
-Total gradient descent training time: 17 hours 27 mins 46 secs
-Environment weights at the end of the training: [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
- 0.11111111 0.11111111 0.11111111]
-WARNING: You did not provide a dataloader id. A new one has been generated: 154323
+Shapes of elements in a batch: (4, 1, 10, 2048) (10,)
+Epoch: 0 LossTrajs: 8937.43457031 ContextsNorm: 0.00000000 ValIndCrit: 21615.28125000
+Epoch: 1 LossTrajs: 21615.42968750 ContextsNorm: 0.00999988 ValIndCrit: 10067.46289062
+Epoch: 2 LossTrajs: 10067.57812500 ContextsNorm: 0.01846115 ValIndCrit: 47097.55859375
+Epoch: 3 LossTrajs: 47097.61328125 ContextsNorm: 0.01237317 ValIndCrit: 1540.37536621
+Epoch: 1000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 2000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 3000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 4000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 5000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 6000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 7000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 8000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 9000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 10000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 11000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 12000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 13000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 14000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 15000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 16000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 17000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+Epoch: 17999 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104
+
+Total gradient descent training time: 0 hours 25 mins 43 secs
+Environment weights at the end of the training: [0.25 0.25 0.25 0.25]
+WARNING: You did not provide a dataloader id. A new one has been generated: 124122
WARNING: Note that this id is used to distinguish between adaptations to different environments.
WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum.
WARNING: No key provided, using time as seed
== Beginning in-domain testing ... ==
-Number of training environments: 9
-Final length of the training trajectories: 20
-Length of the testing trajectories: 20
-Test Score (In-Domain): 0.0038604357
+Number of training environments: 4
+Final length of the training trajectories: 10
+Length of the testing trajectories: 10
+Test Score (In-Domain): 0.05316104

== Beginning in-domain visualisation ... ==
Environment id: 2
-Trajectory id: 11
-Final length of the training trajectories: 20
-Length of the testing trajectories: 20
-Testing finished. Figure saved in: ./runs/18022024-221505/results_in_domain.png
+Trajectory id: 0
+Final length of the training trajectories: 10
+Length of the testing trajectories: 10
+Testing finished. Figure saved in: ./runs/21022024-121527/results_in_domain.png
WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum.
WARNING: No key provided for the context initialization. Initializing at 0.


=== Beginning adaptation ... ===
Number of examples in a batch: 1
Number of train steps per epoch: 1
-Number of training epochs: 24000
-Total number of training steps: 24000
+Number of training epochs: 18000
+Total number of training steps: 18000
WARNING: No key provided, using time as seed

Compiling function "train_step" for context ...
-Shapes of elements in a batch: (4, 1, 20, 7) (20,)
-Epoch: 0 LossContext: 0.03991570
-Epoch: 1 LossContext: 0.03874778
-Epoch: 2 LossContext: 0.03758416
-Epoch: 3 LossContext: 0.03632322
-Epoch: 1000 LossContext: 0.00190338
-Epoch: 2000 LossContext: 0.00175669
-Epoch: 3000 LossContext: 0.00165465
-Epoch: 4000 LossContext: 0.00164393
-Epoch: 5000 LossContext: 0.00162050
-Epoch: 6000 LossContext: 0.00169902
-Epoch: 7000 LossContext: 0.00161232
-Epoch: 8000 LossContext: 0.00168804
-Epoch: 9000 LossContext: 0.00167818
-Epoch: 10000 LossContext: 0.00167914
-Epoch: 11000 LossContext: 0.00155257
-Epoch: 12000 LossContext: 0.00159554
-Epoch: 13000 LossContext: 0.00167166
-Epoch: 14000 LossContext: 0.00167052
-Epoch: 15000 LossContext: 0.00170710
-Epoch: 16000 LossContext: 0.00165121
-Epoch: 17000 LossContext: 0.00159982
-Epoch: 18000 LossContext: 0.00154785
-Epoch: 19000 LossContext: 0.00154855
-Epoch: 20000 LossContext: 0.00154777
-Epoch: 21000 LossContext: 0.00159213
-Epoch: 22000 LossContext: 0.00165402
-Epoch: 23000 LossContext: 0.00166323
-Epoch: 23999 LossContext: 0.00159121
-
-Total gradient descent adaptation time: 0 hours 40 mins 24 secs
+Shapes of elements in a batch: (4, 1, 10, 2048) (10,)
+Epoch: 0 LossContext: 0.05053605
+Epoch: 1 LossContext: 0.05032079
+Epoch: 2 LossContext: 0.05032144
+Epoch: 3 LossContext: 0.05032148
+Epoch: 1000 LossContext: 0.05032149
+Epoch: 2000 LossContext: 0.05032149
+Epoch: 3000 LossContext: 0.05032149
+Epoch: 4000 LossContext: 0.05032149
+Epoch: 5000 LossContext: 0.05032149
+Epoch: 6000 LossContext: 0.05032149
+Epoch: 7000 LossContext: 0.05032149
+Epoch: 8000 LossContext: 0.05032149
+Epoch: 9000 LossContext: 0.05032149
+Epoch: 10000 LossContext: 0.05032149
+Epoch: 11000 LossContext: 0.05032149
+Epoch: 12000 LossContext: 0.05032149
+Epoch: 13000 LossContext: 0.05032149
+Epoch: 14000 LossContext: 0.05032149
+Epoch: 15000 LossContext: 0.05032149
+Epoch: 16000 LossContext: 0.05032149
+Epoch: 17000 LossContext: 0.05032149
+Epoch: 17999 LossContext: 0.05032149
+
+Total gradient descent adaptation time: 0 hours 6 mins 58 secs
Environment weights at the end of the adaptation: [0.25 0.25 0.25 0.25]

-Saving adaptation parameters into ./runs/18022024-221505/adapt/ folder with id 170846 ...
+Saving adaptation parameters into ./runs/21022024-121527/adapt/ folder with id 170846 ...

== Beginning out-of-distribution testing ... ==
-Number of training environments: 9
+Number of training environments: 4
Number of adaptation environments: 4
-Final length of the training trajectories: 20
-Length of the testing trajectories: 20
-Test Score (OOD): 0.0016007004
+Final length of the training trajectories: 10
+Length of the testing trajectories: 10
+Test Score (OOD): 0.050321486

sh: 1: open: not found
== Beginning out-of-distribution visualisation ... ==
Environment id: 1
Trajectory id: 0
-Final length of the training trajectories: 20
-Length of the testing trajectories: 20
-Testing finished. Figure saved in: ./runs/18022024-221505/adapt/results_ood.png
+Final length of the training trajectories: 10
+Length of the testing trajectories: 10
+Testing finished. Figure saved in: ./runs/21022024-121527/adapt/results_ood.png

Full evaluation of the model on 10 random seeds

seed ind_crit ood_crit
count 1.00e+01 1.00e+01 1.00e+01
-mean 5.25e+03 4.02e-03 3.75e-03
-std 3.39e+03 5.55e-04 2.80e-03
+mean 5.25e+03 5.25e-02 5.04e-02
+std 3.39e+03 2.50e-03 5.76e-04
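The summary table above reads like a pandas `describe()` restricted to count/mean/std over the 10 seeds; a hypothetical way to produce such a table (the values below are placeholders, not the real results):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Placeholder per-seed scores; the real ind_crit/ood_crit come from the runs above.
results = pd.DataFrame({
    "seed": rng.integers(0, 10_000, size=10),
    "ind_crit": rng.uniform(0.04, 0.06, size=10),
    "ood_crit": rng.uniform(0.04, 0.06, size=10),
})
summary = results.describe().loc[["count", "mean", "std"]]
print(summary.to_string(float_format=lambda x: f"{x:.2e}"))
```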