From 0b8a9b48d2fbafed532de9e5a7bd0631ef031af4 Mon Sep 17 00:00:00 2001 From: gb21553 Date: Wed, 21 Feb 2024 14:05:59 +0000 Subject: [PATCH] 2D visualisation for gray-scott PDE --- examples/gray-scott/main.py | 27 +++-- examples/gray-scott/nohup.log | 184 ++++++++++++++++------------------ nodax/visualtester.py | 120 ++++++++++++++++++---- 3 files changed, 203 insertions(+), 128 deletions(-) diff --git a/examples/gray-scott/main.py b/examples/gray-scott/main.py index 3c9dd3d..22a8dba 100644 --- a/examples/gray-scott/main.py +++ b/examples/gray-scott/main.py @@ -15,12 +15,12 @@ flow_pool_count = 2 ## Number of neighboring contexts j to use for a flow in env e context_size = 1024 -nb_epochs = 100*1 -nb_epochs_adapt = 2*1 +nb_epochs = 180*360 +nb_epochs_adapt = 180*360 print_error_every = 1000 -train = False +train = True save_trainer = True finetune = False @@ -54,7 +54,7 @@ else: - run_folder = "./runs/20022024-173401/" ## Needed for loading the model and finetuning TODO: opti + run_folder = "./runs/21022024-121527/" ## Needed for loading the model and finetuning TODO: opti print("No training. Loading data and results from:", run_folder) ## Create a folder for the adaptation results @@ -80,7 +80,7 @@ #%% ## Define dataloader for training and validation -train_dataloader = DataLoader(run_folder+"train_data.npz", batch_size=4, int_cutoff=1.0, shuffle=True, key=seed) +train_dataloader = DataLoader(run_folder+"train_data.npz", batch_size=1, int_cutoff=1.0, shuffle=True, key=seed) nb_envs = train_dataloader.nb_envs nb_trajs_per_env = train_dataloader.nb_trajs_per_env @@ -140,6 +140,7 @@ def __init__(self, data_res, kernel_size, nb_int_channels, context_size, key=Non lambda x: x.flatten()] def __call__(self, t, y, ctx): + # return jnp.zeros_like(y)*ctx[0] for layer in self.layers_context: ctx = layer(ctx) @@ -223,7 +224,7 @@ def __call__(self, t, x, ctxs): return vf(ctx_) + gradvf(ctx_, ctx) -augmentation = Augmentation(data_res=32, kernel_size=3, nb_int_channels=4, context_size=context_size, key=seed) +augmentation = Augmentation(data_res=32, kernel_size=3, nb_int_channels=1, context_size=context_size, key=seed) vectorfield = ContextFlowVectorField(augmentation, physics=None) print("\n\nTotal number of parameters in the model:", sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(vectorfield,eqx.is_array)) if x is not None), "\n\n") @@ -353,11 +354,13 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key): #%% -# ## Load the validation loss -# val_loss = np.load(run_folder+"val_losses.npy") -# val_loss - +## Custom Gray-Scott trajectory visualiser +if finetune: + savefigdir = finetunedir+"results_2D_ind.png" +else: + savefigdir = run_folder+"results_2D_ind.png" +visualtester.visualize2D(test_dataloader, int_cutoff=1.0, res=32, save_path=savefigdir); #%% @@ -497,7 +500,7 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key): os.system(f'python dataset.py --split=test --savepath="{run_folder}" --seed="{seed*2}" --verbose=0') os.system(f'python dataset.py --split=adapt --savepath="{adapt_folder}" --seed="{seed*3}" --verbose=0') - test_dataloader = DataLoader(run_folder+"test_data.npz", shuffle=False, batch_size=4, data_id="082026") + test_dataloader = DataLoader(run_folder+"test_data.npz", shuffle=False, batch_size=1, data_id="082026") adapt_test_dataloader = DataLoader(adapt_folder+"adapt_data.npz", adaptation=True, batch_size=1, key=seed, data_id="082026") ind_crit, _ = visualtester.test(test_dataloader, int_cutoff=1.0, verbose=False) @@ -553,3 +556,5 @@ def loss_fn_ctx(model, 
trajs, t_eval, ctx, all_ctx_s, key): # ## Save the odd_crit_all in numpy # np.save('mapes.npy', odd_crit_all) + +# %% diff --git a/examples/gray-scott/nohup.log b/examples/gray-scott/nohup.log index 6400fde..7380627 100644 --- a/examples/gray-scott/nohup.log +++ b/examples/gray-scott/nohup.log @@ -1,36 +1,36 @@ Running this script in ipython (Jupyter) session ? False === Parsed arguments to generate data === Split: train - Savepath: ./runs/18022024-221505/ + Savepath: ./runs/21022024-121527/ Seed: 2026 Running this script in ipython (Jupyter) session ? False === Parsed arguments to generate data === Split: test - Savepath: ./runs/18022024-221505/ + Savepath: ./runs/21022024-121527/ Seed: 4052 Running this script in ipython (Jupyter) session ? False === Parsed arguments to generate data === Split: adapt - Savepath: ./runs/18022024-221505/adapt/ + Savepath: ./runs/21022024-121527/adapt/ Seed: 6078 ############# Inductive Bias Learning for Dynamical Systems ############# -Jax version: 0.4.19 +Jax version: 0.4.23 Available devices: [cuda(id=0)] -Data folder created successfuly: ./runs/18022024-221505/ +Data folder created successfuly: ./runs/21022024-121527/ Completed copied scripts -WARNING: You did not provide a dataloader id. A new one has been generated: 221526 +WARNING: You did not provide a dataloader id. A new one has been generated: 121535 WARNING: Note that this id used to distuinguish between adaptations to different environments. -WARNING: You did not provide a dataloader id. A new one has been generated: 221526 +WARNING: You did not provide a dataloader id. A new one has been generated: 121536 WARNING: Note that this id used to distuinguish between adaptations to different environments. WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum. -Total number of parameters in the model: 319911 +Total number of parameters in the model: 4200374 WARNING: No key provided for the context initialization. Initializing at 0. @@ -38,64 +38,57 @@ WARNING: No key provided, using time as seed === Beginning training ... === - Number of examples in a batch: 4 - Number of train steps per epoch: 8 - Number of training epochs: 24000 - Total number of training steps: 192000 + Number of examples in a batch: 1 + Number of train steps per epoch: 1 + Number of training epochs: 18000 + Total number of training steps: 18000 Compiling function "train_step" for neural ode ... -Shapes of elements in a batch: (9, 4, 20, 7) (20,) +Shapes of elements in a batch: (4, 1, 10, 2048) (10,) Compiling function "train_step" for context ... 
-Shapes of elements in a batch: (9, 4, 20, 7) (20,) - Epoch: 0 LossTrajs: 0.28848881 ContextsNorm: 0.00150763 ValIndCrit: 0.23863795 - Epoch: 1 LossTrajs: 0.23736051 ContextsNorm: 0.00279875 ValIndCrit: 0.22864109 - Epoch: 2 LossTrajs: 0.23170462 ContextsNorm: 0.00403358 ValIndCrit: 0.22836266 - Epoch: 3 LossTrajs: 0.23057464 ContextsNorm: 0.00466054 ValIndCrit: 0.22816052 - Epoch: 1000 LossTrajs: 0.00352303 ContextsNorm: 0.11520766 ValIndCrit: 0.00458277 - Epoch: 2000 LossTrajs: 0.00210036 ContextsNorm: 0.11645799 ValIndCrit: 0.00377220 - Epoch: 3000 LossTrajs: 0.00209809 ContextsNorm: 0.11668418 ValIndCrit: 0.00377159 - Epoch: 4000 LossTrajs: 0.00201739 ContextsNorm: 0.11677950 ValIndCrit: 0.00373893 - Epoch: 5000 LossTrajs: 0.00198951 ContextsNorm: 0.11691947 ValIndCrit: 0.00373602 - Epoch: 6000 LossTrajs: 0.00196165 ContextsNorm: 0.11688989 ValIndCrit: 0.00373701 - Epoch: 7000 LossTrajs: 0.00195028 ContextsNorm: 0.11689670 ValIndCrit: 0.00376597 - Epoch: 8000 LossTrajs: 0.00191953 ContextsNorm: 0.11700038 ValIndCrit: 0.00374925 - Epoch: 9000 LossTrajs: 0.00190826 ContextsNorm: 0.11683951 ValIndCrit: 0.00374242 - Epoch: 10000 LossTrajs: 0.00187307 ContextsNorm: 0.11691865 ValIndCrit: 0.00374858 - Epoch: 11000 LossTrajs: 0.00188009 ContextsNorm: 0.11678100 ValIndCrit: 0.00375942 - Epoch: 12000 LossTrajs: 0.00183770 ContextsNorm: 0.11688248 ValIndCrit: 0.00376303 - Epoch: 13000 LossTrajs: 0.00183869 ContextsNorm: 0.11692899 ValIndCrit: 0.00379309 - Epoch: 14000 LossTrajs: 0.00179264 ContextsNorm: 0.11691536 ValIndCrit: 0.00378790 - Epoch: 15000 LossTrajs: 0.00179362 ContextsNorm: 0.11679574 ValIndCrit: 0.00378610 - Epoch: 16000 LossTrajs: 0.00178432 ContextsNorm: 0.11683217 ValIndCrit: 0.00377219 - Epoch: 17000 LossTrajs: 0.00173260 ContextsNorm: 0.11698029 ValIndCrit: 0.00382742 - Epoch: 18000 LossTrajs: 0.00175081 ContextsNorm: 0.11694510 ValIndCrit: 0.00379931 - Epoch: 19000 LossTrajs: 0.00170353 ContextsNorm: 0.11696473 ValIndCrit: 0.00384929 - Epoch: 20000 LossTrajs: 0.00170334 ContextsNorm: 0.11698332 ValIndCrit: 0.00384146 - Epoch: 21000 LossTrajs: 0.00165402 ContextsNorm: 0.11691977 ValIndCrit: 0.00388831 - Epoch: 22000 LossTrajs: 0.00165488 ContextsNorm: 0.11698642 ValIndCrit: 0.00388500 - Epoch: 23000 LossTrajs: 0.00167246 ContextsNorm: 0.11704455 ValIndCrit: 0.00388701 - Epoch: 23999 LossTrajs: 0.00163804 ContextsNorm: 0.11711274 ValIndCrit: 0.00386044 - -Total gradient descent training time: 17 hours 27 mins 46 secs -Environment weights at the end of the training: [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 - 0.11111111 0.11111111 0.11111111] -WARNING: You did not provide a dataloader id. 
A new one has been generated: 154323 +Shapes of elements in a batch: (4, 1, 10, 2048) (10,) + Epoch: 0 LossTrajs: 8937.43457031 ContextsNorm: 0.00000000 ValIndCrit: 21615.28125000 + Epoch: 1 LossTrajs: 21615.42968750 ContextsNorm: 0.00999988 ValIndCrit: 10067.46289062 + Epoch: 2 LossTrajs: 10067.57812500 ContextsNorm: 0.01846115 ValIndCrit: 47097.55859375 + Epoch: 3 LossTrajs: 47097.61328125 ContextsNorm: 0.01237317 ValIndCrit: 1540.37536621 + Epoch: 1000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 2000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 3000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 4000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 5000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 6000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 7000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 8000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 9000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 10000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 11000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 12000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 13000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 14000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 15000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 16000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 17000 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + Epoch: 17999 LossTrajs: 0.05446807 ContextsNorm: 0.12620294 ValIndCrit: 0.05316104 + +Total gradient descent training time: 0 hours 25 mins 43 secs +Environment weights at the end of the training: [0.25 0.25 0.25 0.25] +WARNING: You did not provide a dataloader id. A new one has been generated: 124122 WARNING: Note that this id used to distuinguish between adaptations to different environments. WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum. WARNING: No key provided, using time as seed == Begining in-domain testing ... == - Number of training environments: 9 - Final length of the training trajectories: 20 - Length of the testing trajectories: 20 -Test Score (In-Domain): 0.0038604357 + Number of training environments: 4 + Final length of the training trajectories: 10 + Length of the testing trajectories: 10 +Test Score (In-Domain): 0.05316104 == Begining in-domain visualisation ... == Environment id: 2 - Trajectory id: 11 - Final length of the training trajectories: 20 - Length of the testing trajectories: 20 -Testing finished. Figure saved in: ./runs/18022024-221505/results_in_domain.png + Trajectory id: 0 + Final length of the training trajectories: 10 + Length of the testing trajectories: 10 +Testing finished. Figure saved in: ./runs/21022024-121527/results_in_domain.png WARNING: batch_size must be between 0 and nb_trajs_per_env. Setting batch_size to maximum. WARNING: No key provided for the context initialization. Initializing at 0. @@ -103,64 +96,57 @@ WARNING: No key provided for the context initialization. Initializing at 0. === Beginning adaptation ... 
=== Number of examples in a batch: 1 Number of train steps per epoch: 1 - Number of training epochs: 24000 - Total number of training steps: 24000 + Number of training epochs: 18000 + Total number of training steps: 18000 WARNING: No key provided, using time as seed Compiling function "train_step" for context ... -Shapes of elements in a batch: (4, 1, 20, 7) (20,) - Epoch: 0 LossContext: 0.03991570 - Epoch: 1 LossContext: 0.03874778 - Epoch: 2 LossContext: 0.03758416 - Epoch: 3 LossContext: 0.03632322 - Epoch: 1000 LossContext: 0.00190338 - Epoch: 2000 LossContext: 0.00175669 - Epoch: 3000 LossContext: 0.00165465 - Epoch: 4000 LossContext: 0.00164393 - Epoch: 5000 LossContext: 0.00162050 - Epoch: 6000 LossContext: 0.00169902 - Epoch: 7000 LossContext: 0.00161232 - Epoch: 8000 LossContext: 0.00168804 - Epoch: 9000 LossContext: 0.00167818 - Epoch: 10000 LossContext: 0.00167914 - Epoch: 11000 LossContext: 0.00155257 - Epoch: 12000 LossContext: 0.00159554 - Epoch: 13000 LossContext: 0.00167166 - Epoch: 14000 LossContext: 0.00167052 - Epoch: 15000 LossContext: 0.00170710 - Epoch: 16000 LossContext: 0.00165121 - Epoch: 17000 LossContext: 0.00159982 - Epoch: 18000 LossContext: 0.00154785 - Epoch: 19000 LossContext: 0.00154855 - Epoch: 20000 LossContext: 0.00154777 - Epoch: 21000 LossContext: 0.00159213 - Epoch: 22000 LossContext: 0.00165402 - Epoch: 23000 LossContext: 0.00166323 - Epoch: 23999 LossContext: 0.00159121 - -Total gradient descent adaptation time: 0 hours 40 mins 24 secs +Shapes of elements in a batch: (4, 1, 10, 2048) (10,) + Epoch: 0 LossContext: 0.05053605 + Epoch: 1 LossContext: 0.05032079 + Epoch: 2 LossContext: 0.05032144 + Epoch: 3 LossContext: 0.05032148 + Epoch: 1000 LossContext: 0.05032149 + Epoch: 2000 LossContext: 0.05032149 + Epoch: 3000 LossContext: 0.05032149 + Epoch: 4000 LossContext: 0.05032149 + Epoch: 5000 LossContext: 0.05032149 + Epoch: 6000 LossContext: 0.05032149 + Epoch: 7000 LossContext: 0.05032149 + Epoch: 8000 LossContext: 0.05032149 + Epoch: 9000 LossContext: 0.05032149 + Epoch: 10000 LossContext: 0.05032149 + Epoch: 11000 LossContext: 0.05032149 + Epoch: 12000 LossContext: 0.05032149 + Epoch: 13000 LossContext: 0.05032149 + Epoch: 14000 LossContext: 0.05032149 + Epoch: 15000 LossContext: 0.05032149 + Epoch: 16000 LossContext: 0.05032149 + Epoch: 17000 LossContext: 0.05032149 + Epoch: 17999 LossContext: 0.05032149 + +Total gradient descent adaptation time: 0 hours 6 mins 58 secs Environment weights at the end of the adaptation: [0.25 0.25 0.25 0.25] -Saving adaptation parameters into ./runs/18022024-221505/adapt/ folder with id 170846 ... +Saving adaptation parameters into ./runs/21022024-121527/adapt/ folder with id 170846 ... == Begining out-of-distribution testing ... == - Number of training environments: 9 + Number of training environments: 4 Number of adaptation environments: 4 - Final length of the training trajectories: 20 - Length of the testing trajectories: 20 -Test Score (OOD): 0.0016007004 + Final length of the training trajectories: 10 + Length of the testing trajectories: 10 +Test Score (OOD): 0.050321486 -sh: 1: open: not found == Begining out-of-distribution visualisation ... == Environment id: 1 Trajectory id: 0 - Final length of the training trajectories: 20 - Length of the testing trajectories: 20 -Testing finished. Figure saved in: ./runs/18022024-221505/adapt/results_ood.png + Final length of the training trajectories: 10 + Length of the testing trajectories: 10 +Testing finished. 
Figure saved in: ./runs/21022024-121527/adapt/results_ood.png Full evaluation of the model on 10 random seeds seed ind_crit ood_crit count 1.00e+01 1.00e+01 1.00e+01 -mean 5.25e+03 4.02e-03 3.75e-03 -std 3.39e+03 5.55e-04 2.80e-03 +mean 5.25e+03 5.25e-02 5.04e-02 +std 3.39e+03 2.50e-03 5.76e-04 diff --git a/nodax/visualtester.py b/nodax/visualtester.py index a7b45d6..94932af 100644 --- a/nodax/visualtester.py +++ b/nodax/visualtester.py @@ -215,7 +215,7 @@ def test(self, data_loader, criterion=None, int_cutoff=1.0, verbose=True): - def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=False, key=None): + def visualize(self, data_loader, e=None, traj=None, dims=(0,1), context_dims=(0,1), int_cutoff=1.0, save_path=False, key=None): # assert data_loader.nb_envs == self.trainer.dataloader.nb_envs, "The number of environments in the test dataloader must be the same as the number of environments in the trainer." @@ -234,6 +234,7 @@ def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=Fa print("== Begining out-of-distribution visualisation ... ==") print(" Environment id:", e) print(" Trajectory id:", traj) + print(" Visualized dimensions:", dims) print(" Final length of the training trajectories:", self.trainer.dataloader.int_cutoff) print(" Length of the testing trajectories:", test_length) @@ -255,12 +256,13 @@ def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=Fa fig, ax = plt.subplot_mosaic('AB;CC;DD;EF', figsize=(6*2, 3.5*4)) mks = 2 + dim0, dim1 = dims - ax['A'].plot(t_test, X[:, 0], c="deepskyblue", label=r"$x_1$ (GT)") - ax['A'].plot(t_test, X_hat[:, 0], "o", c="royalblue", label=r"$\hat{x}_1$ (NCF)", markersize=mks) + ax['A'].plot(t_test, X[:, 0], c="deepskyblue", label=f"$x_{{{dim0}}}$ (GT)") + ax['A'].plot(t_test, X_hat[:, 0], "o", c="royalblue", label=f"$\\hat{{x}}_{{{dim0}}}$ (NCF)", markersize=mks) - ax['A'].plot(t_test, X[:, 1], c="violet", label=r"$x_1$ (GT)") - ax['A'].plot(t_test, X_hat[:, 1], "x", c="purple", label=r"$\hat{x}_2$ (NCF)", markersize=mks) + ax['A'].plot(t_test, X[:, 1], c="violet", label=f"$x_{{{dim1}}}$ (GT)") + ax['A'].plot(t_test, X_hat[:, 1], "x", c="purple", label=f"$\\hat{{x}}_{{{dim1}}}$ (NCF)", markersize=mks) ax['A'].set_xlabel("Time") ax['A'].set_ylabel("State") @@ -269,8 +271,8 @@ def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=Fa ax['B'].plot(X[:, 0], X[:, 1], c="turquoise", label="GT") ax['B'].plot(X_hat[:, 0], X_hat[:, 1], ".", c="teal", label="NCF") - ax['B'].set_xlabel(r"$x_1$") - ax['B'].set_ylabel(r"$x_2$") + ax['B'].set_xlabel(f"$x_{{{dim0}}}$") + ax['B'].set_ylabel(f"$x_{{{dim1}}}$") ax['B'].set_title("Phase space") ax['B'].legend() @@ -314,20 +316,21 @@ def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=Fa eps = 0.1 colors = ['dodgerblue', 'r', 'b', 'g', 'm', 'c', 'y', 'orange', 'purple', 'brown'] colors = colors*(nb_envs) + cdim0, cdim1 = context_dims - ax['F'].scatter(xis[:,0], xis[:,1], s=50, c=colors[:nb_envs], marker='o') - for i, (x, y) in enumerate(xis[:, :2]): - ax['F'].annotate(str(i), (x, y), fontsize=8) - ax['F'].set_title(r'Final Contexts ($\xi^e$)') - - ax['E'].scatter(init_xis[:,0], init_xis[:,1], s=30, c=colors[:nb_envs], marker='X') - ax['F'].scatter(xis[:,0], xis[:,1], s=50, c=colors[:nb_envs], marker='o') - for i, (x, y) in enumerate(init_xis[:, :2]): + ax['E'].scatter(init_xis[:,cdim0], init_xis[:,cdim1], s=30, c=colors[:nb_envs], marker='X') + ax['F'].scatter(xis[:,cdim0], xis[:,cdim1], s=50, 
c=colors[:nb_envs], marker='o')
+        for i, (x, y) in enumerate(init_xis[:, context_dims]):
             ax['E'].annotate(str(i), (x, y), fontsize=8)
-        for i, (x, y) in enumerate(xis[:, :2]):
+        for i, (x, y) in enumerate(xis[:, context_dims]):
             ax['F'].annotate(str(i), (x, y), fontsize=8)
-        ax['E'].set_title(r'Initial Contexts (first 2 dims)')
-        ax['F'].set_title(r'Final Contexts (first 2 dims)')
+        ax['E'].set_title(r'Initial Contexts')
+        ax['E'].set_xlabel(f'dim {cdim0}')
+        ax['E'].set_ylabel(f'dim {cdim1}')
+
+        ax['F'].set_title(r'Final Contexts')
+        ax['F'].set_xlabel(f'dim {cdim0}')
+        ax['F'].set_ylabel(f'dim {cdim1}')
 
         plt.suptitle(f"Results for env={e}, traj={traj}", fontsize=14)
 
@@ -338,3 +341,82 @@ def visualize(self, data_loader, e=None, traj=None, int_cutoff=1.0, save_path=Fa
         if save_path:
             plt.savefig(save_path, dpi=100, bbox_inches='tight')
             print("Testing finished. Figure saved in:", save_path);
+
+
+
+    def visualize2D(self, data_loader, e=None, traj=None, res=32, int_cutoff=1.0, nb_plot_timesteps=10, save_path=False, key=None):
+        """
+        Visualise the ground-truth and predicted trajectories of the trained
+        neural ODE on a 2D spatial grid, one column per plotted timestep.
+
+        :param data_loader: Dataloader holding the trajectories to visualise
+        :param e: Environment id to visualise (chosen at random if None)
+        :param traj: Trajectory id to visualise (chosen at random if None)
+        :param res: Side length of the square spatial grid (res x res)
+        :param int_cutoff: Fraction of the trajectory length to visualise
+        :param nb_plot_timesteps: Number of timesteps (columns) to plot
+        :param save_path: Where to save the figure (not saved if False)
+        :param key: Unused; the random choices are seeded from the clock
+        :return: None; the figure is drawn and optionally saved
+        """
+        e_key, traj_key = get_new_key(time.time_ns(), num=2)
+        e = e if e is not None else jax.random.randint(e_key, (1,), 0, data_loader.nb_envs)[0]
+        traj = traj if traj is not None else jax.random.randint(traj_key, (1,), 0, data_loader.nb_trajs_per_env)[0]
+
+        t_eval = data_loader.t_eval
+        test_length = int(data_loader.nb_steps_per_traj*int_cutoff)
+        X = data_loader.dataset[e, traj:traj+1, :test_length, :]
+        t_test = t_eval[:test_length]
+
+        if data_loader.adaptation == False:
+            print("== Beginning in-domain 2D visualisation ... ==")
+        else:
+            print("== Beginning out-of-distribution 2D visualisation ... ==")
+        print("    Environment id:", e)
+        print("    Trajectory id:", traj)
+        print("    Length of the testing trajectories:", test_length)
+
+        if data_loader.adaptation == False:
+            contexts = self.trainer.learner.contexts.params
+        else:
+            contexts = self.trainer.learner.contexts_adapt.params
+        X_hat, _ = self.trainer.learner.neuralode(X[:, 0, :],
+                                                  t_test,
+                                                  contexts[e],
+                                                  contexts[e])
+
+        X_hat = X_hat.squeeze()
+        X = X.squeeze()
+
+        ## Each timestep is a flat vector stacking nb_mats square (res x res) fields
+        nb_mats = X_hat.shape[1] // (res*res)
+        assert nb_mats > 0, f"Not enough dimensions to form a {res}x{res} matrix"
+
+        fig, ax = plt.subplots(nrows=nb_mats*2, ncols=nb_plot_timesteps, figsize=(2*nb_plot_timesteps, 2*nb_mats*2))
+        ## Row 2i shows the ground truth for field i, row 2i+1 the NCF prediction
+        for col, j in enumerate(range(0, test_length, max(1, test_length//nb_plot_timesteps))):
+            if col >= nb_plot_timesteps:    ## The stride above can overshoot the column count
+                break
+            gt_j = vec_to_mats(X[j], res, nb_mats)
+            ncf_j = vec_to_mats(X_hat[j], res, nb_mats)
+            for i in range(nb_mats):
+                ax[2*i, col].imshow(gt_j[i], cmap='gist_ncar', interpolation='bilinear', origin='lower')
+                ax[2*i+1, col].imshow(ncf_j[i], cmap='gist_ncar', interpolation='bilinear', origin='lower')
+
+        ## Remove the ticks and labels
+        for a in ax.flatten():
+            a.set_xticks([])
+            a.set_yticks([])
+            a.set_xticklabels([])
+            a.set_yticklabels([])
+
+        plt.suptitle(f"2D visualisation results for env={e}, traj={traj}", fontsize=20)
+
+        plt.tight_layout()
+        plt.draw();
+
+        if save_path:
+            plt.savefig(save_path, dpi=300, bbox_inches='tight')
+            print("Testing finished. Figure saved in:", save_path);
+
+        ## TODO: save the plots as an animated GIF as well
\ No newline at end of file
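
--
Note on the `vec_to_mats` helper used by `visualize2D`: it is defined elsewhere in
nodax and is not part of this patch. A minimal sketch consistent with how it is
called above, assuming each timestep is a flat vector stacking `nb_mats` square
fields in row-major order (for the Gray-Scott data, the two 32x32 species channels
flattened to 2048 entries, so nb_mats = 2048 // (32*32) = 2):

    import jax.numpy as jnp

    def vec_to_mats(vec, res=32, nb_mats=1):
        ## Split a flat (nb_mats*res*res,) state vector into nb_mats (res, res) fields
        return [jnp.reshape(vec[i*res*res:(i+1)*res*res], (res, res))
                for i in range(nb_mats)]

Each returned field can then be passed straight to imshow, giving one row pair
(ground truth, prediction) per species in the saved figure.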