From f2b6d520b81cdff086ce7336aa7a4ba60642aa33 Mon Sep 17 00:00:00 2001 From: Payam Jome Yazdian Date: Tue, 6 Feb 2024 23:18:31 -0800 Subject: [PATCH] Add files via upload --- config/DAE.yml | 150 +++++++++++++++++------------------ config/DAE_GENEA.yml | 144 +++++++++++++++++----------------- config/VQ-VAE.yml | 144 +++++++++++++++++----------------- config/VQ-VAE_GENEA.yml | 154 ++++++++++++++++++------------------ config/parse_args.py | 96 +++++++++++++++++++++++ config/seq2seq.yml | 138 ++++++++++++++++----------------- config/seq2seqtxt.yml | 168 ++++++++++++++++++++-------------------- 7 files changed, 545 insertions(+), 449 deletions(-) create mode 100644 config/parse_args.py diff --git a/config/DAE.yml b/config/DAE.yml index 0e1e116..f549a1b 100644 --- a/config/DAE.yml +++ b/config/DAE.yml @@ -1,75 +1,75 @@ -name: Frame_Level - -train_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train -sentence_level: True - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb//lmdb_test - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -model_save_path: ../output/2023/pose2vec_1_Vanilla -random_seed: 0 - -# model params -model: DAE_complex -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.2 - -input_motion_dim: 135 -data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 
0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] -data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 
0.00024, 0.08801] - - -#Autoencoder parameters: -autoencoder_denoising: True -autoencoder_att: False -autoencoder_fixed_weight: True -autoencoder_conditioned: False -autoencoder_vae: False -autoencoder_vq: False -autoencoder_vq_components: 80 -autoencoder_vq_commitment_cost: 0.25 -use_derivative: False -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 20 -batch_size: 128 -learning_rate: 0.0005 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 20 -n_poses: 30 -n_pre_poses: 1 -subdivision_stride: 5 -subdivision_stride_sentence: 20 -sentence_frame_length: 120 -loader_workers: 4 - -#reoresentation learning -rep_learning_checkpoint: ../output/DAE_old/train_DAE_H41/rep_learning_DAE_H41_checkpoint_020.bin -rep_learning_dim: 41 -autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin -#GAN -noise_dim: 400 - -Modality_Audio: False -Modality_Text: False -Modality_Gesture: True +name: Frame_Level + +train_data_path: ../../data/Training_data/lmdb/lmdb_train +sentence_level: True + + +val_data_path: ../../data/Training_data/lmdb//lmdb_test + +wordembed_dim: 300 +wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext + +model_save_path: ../output/2024/pose2vec_1_Vanilla +random_seed: 0 + +# model params +model: DAE_complex +hidden_size: 40 +n_layers: 2 +dropout_prob: 0.2 + +input_motion_dim: 135 +data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 
1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] +data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 
0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] + + +#Autoencoder parameters: +autoencoder_denoising: True +autoencoder_att: False +autoencoder_fixed_weight: True +autoencoder_conditioned: False +autoencoder_vae: False +autoencoder_vq: False +autoencoder_vq_components: 80 +autoencoder_vq_commitment_cost: 0.25 +use_derivative: False +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 20 +batch_size: 128 +learning_rate: 0.0005 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 20 +n_poses: 20 +n_pre_poses: 1 +subdivision_stride: 5 +subdivision_stride_sentence: 20 +sentence_frame_length: 120 +loader_workers: 4 + +#reoresentation learning +rep_learning_checkpoint: None +rep_learning_dim: 0 +autoencoder_checkpoint: None +#GAN +noise_dim: 400 + +Modality_Audio: False +Modality_Text: False +Modality_Gesture: True diff --git a/config/DAE_GENEA.yml b/config/DAE_GENEA.yml index 3d0af9e..aa8d650 100644 --- a/config/DAE_GENEA.yml +++ b/config/DAE_GENEA.yml @@ -1,72 +1,72 @@ -name: DAE - -train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train -sentence_level: True - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -model_save_path: 
../output/GENEA/DAE -random_seed: 0 - -input_motion_dim: 135 -data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, -0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] -data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 
0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] - - -# model params -model: seq2seq -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.0 - -#Atuoencoder parameters: -autoencoder_denoising: True -autoencoder_att: False -autoencoder_fixed_weight: True -autoencoder_conditioned: False -autoencoder_vae: False -autoencoder_vq: False -autoencoder_vq_components: 100 -autoencoder_vq_commitment_cost: 0.25 -use_derivative: False -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 30 -batch_size: 1024 -learning_rate: 0.00001 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 10 -n_poses: 30 -n_pre_poses: 1 -subdivision_stride: 5 
-subdivision_stride_sentence: 20 -sentence_frame_length: 120 -loader_workers: 4 - -#reoresentation learning -rep_learning_checkpoint: ../output/DAE_old/train_DAE_H41/rep_learning_DAE_H41_checkpoint_020.binXYZ -rep_learning_dim: 41 -autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.binXYZ -#GAN -noise_dim: 400 - +name: DAE + +train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train +sentence_level: True + +#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test +val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test + +wordembed_dim: 300 +wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext + +model_save_path: ../output/GENEA/DAE +random_seed: 0 + +input_motion_dim: 135 +data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, 
-0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] +data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] + + +# model params +model: seq2seq +hidden_size: 
200 +n_layers: 2 +dropout_prob: 0.0 + +#Atuoencoder parameters: +autoencoder_denoising: True +autoencoder_att: False +autoencoder_fixed_weight: True +autoencoder_conditioned: False +autoencoder_vae: False +autoencoder_vq: False +autoencoder_vq_components: 100 +autoencoder_vq_commitment_cost: 0.25 +use_derivative: False +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 30 +batch_size: 1024 +learning_rate: 0.00001 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 10 +n_poses: 30 +n_pre_poses: 1 +subdivision_stride: 5 +subdivision_stride_sentence: 20 +sentence_frame_length: 120 +loader_workers: 4 + +#reoresentation learning +rep_learning_checkpoint: ../output/DAE_old/train_DAE_H41/rep_learning_DAE_H41_checkpoint_020.binXYZ +rep_learning_dim: 41 +autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.binXYZ +#GAN +noise_dim: 400 + diff --git a/config/VQ-VAE.yml b/config/VQ-VAE.yml index 029663b..6466650 100644 --- a/config/VQ-VAE.yml +++ b/config/VQ-VAE.yml @@ -1,73 +1,73 @@ -name: VQVAE - -train_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train -sentence_level: True - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb//lmdb_test - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -model_save_path: ../output/2023/gesture2vec_VQ-VAE_16X -random_seed: 0 - -input_motion_dim: 135 -data_mean: [0.99414, 0.05276, 
-0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] -data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 
0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] - -# model params -model: seq2seq -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.0 - -#Atuoencoder parameters: -autoencoder_denoising: True -autoencoder_att: False -autoencoder_fixed_weight: False -autoencoder_conditioned: True -autoencoder_vae: False -autoencoder_vq: True -autoencoder_vq_components: 512 -autoencoder_vq_commitment_cost: 0.25 -use_derivative: False -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 200 -batch_size: 128 -learning_rate: 0.0005 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 20 -n_poses: 20 -n_pre_poses: 1 -subdivision_stride: 5 -subdivision_stride_sentence: 20 -sentence_frame_length: 120 -loader_workers: 4 - -#reoresentation learning -rep_learning_checkpoint: ../output/2023/pose2vec_1_Vanilla/train_DAE_H40/Frame_Level_H40_checkpoint_020.bin -rep_learning_dim: 40 -autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin -#GAN -noise_dim: 400 - -Modality_Audio: False -Modality_Text: False +name: VQVAE + +train_data_path: ../../data/Training_data/lmdb/lmdb_train +sentence_level: True + + +val_data_path: ../../data/Training_data/lmdb//lmdb_test + +wordembed_dim: 300 +wordembed_path: 
../resource/crawl-300d-2M-subword.bin # fasttext + +model_save_path: ../output/2024/Gesture2Vec_VQ-VAE +random_seed: 0 + +input_motion_dim: 135 +data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] +data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 
0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] + +# model params +model: seq2seq +hidden_size: 200 +n_layers: 2 +dropout_prob: 0.2 + +#Atuoencoder parameters: +autoencoder_denoising: True +autoencoder_att: False +autoencoder_fixed_weight: False +autoencoder_conditioned: True +autoencoder_vae: False +autoencoder_vq: True +autoencoder_vq_components: 512 +autoencoder_vq_commitment_cost: 0.25 +use_derivative: False +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 10 +batch_size: 128 +learning_rate: 0.0005 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 20 +n_poses: 20 +n_pre_poses: 1 +subdivision_stride: 5 +subdivision_stride_sentence: 20 +sentence_frame_length: 120 +loader_workers: 4 + +#reoresentation learning +rep_learning_checkpoint: ../output/2024/pose2vec_1_Vanilla/train_DAE_H40/Frame_Level_H40_checkpoint_020.bin +rep_learning_dim: 40 +autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin +#GAN +noise_dim: 400 + +Modality_Audio: False +Modality_Text: False Modality_Gesture: True \ No newline at end of file 
diff --git a/config/VQ-VAE_GENEA.yml b/config/VQ-VAE_GENEA.yml index 7687e5e..0e7859f 100644 --- a/config/VQ-VAE_GENEA.yml +++ b/config/VQ-VAE_GENEA.yml @@ -1,77 +1,77 @@ -name: VQVAE - - -train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train -sentence_level: True - - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test - - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -model_save_path: ../output/GENEA/VQ-VAE -random_seed: 0 - -# input_motion_dim: 135 -#data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 
0.94238, -0.00727, 0.00069, 0.00001, 0.99219] -#data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] - -input_motion_dim: 162 -data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 
0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, -0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] -data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 
0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] - - -# model params -model: seq2seq -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.0 - -#Atuoencoder parameters: -autoencoder_denoising: True -autoencoder_att: False -autoencoder_fixed_weight: False -autoencoder_conditioned: True -autoencoder_vae: False -autoencoder_vq: True -autoencoder_vq_components: 400 -autoencoder_vq_commitment_cost: 0.25 -use_derivative: False -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 3000 -batch_size: 128 -learning_rate: 0.0001 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 10 -n_poses: 10 -n_pre_poses: 1 -subdivision_stride: 20 -subdivision_stride_sentence: 20 -sentence_frame_length: 120 -loader_workers: 4 - -#reoresentation learning -rep_learning_checkpoint: ../output/GENEA/DAE/train_DAE_H45/DAE_H45_checkpoint_030.bin -rep_learning_dim: 45 -autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin -#GAN -noise_dim: 400 +name: VQVAE + + +train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train +sentence_level: True + + +#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test +val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test + + +wordembed_dim: 
300 +wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext + +model_save_path: ../output/GENEA/VQ-VAE +random_seed: 0 + +# input_motion_dim: 135 +#data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] +#data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 
0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] + +input_motion_dim: 162 +data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, -0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 
1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] +data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] + + +# model params +model: seq2seq +hidden_size: 200 +n_layers: 2 +dropout_prob: 0.0 + +#Atuoencoder parameters: +autoencoder_denoising: True +autoencoder_att: False +autoencoder_fixed_weight: False +autoencoder_conditioned: True +autoencoder_vae: False +autoencoder_vq: True +autoencoder_vq_components: 400 
+autoencoder_vq_commitment_cost: 0.25 +use_derivative: False +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 3000 +batch_size: 128 +learning_rate: 0.0001 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 10 +n_poses: 10 +n_pre_poses: 1 +subdivision_stride: 20 +subdivision_stride_sentence: 20 +sentence_frame_length: 120 +loader_workers: 4 + +#reoresentation learning +rep_learning_checkpoint: ../output/GENEA/DAE/train_DAE_H45/DAE_H45_checkpoint_030.bin +rep_learning_dim: 45 +autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin +#GAN +noise_dim: 400 diff --git a/config/parse_args.py b/config/parse_args.py new file mode 100644 index 0000000..2928594 --- /dev/null +++ b/config/parse_args.py @@ -0,0 +1,96 @@ +import configargparse + + +def str2bool(v): + """ from https://stackoverflow.com/a/43357954/1361529 """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +def parse_args(): + parser = configargparse.ArgParser() + parser.add('-c', '--config', required=True, is_config_file=True, help='Config file path') + parser.add("--name", type=str, default="main") + parser.add("--train_data_path", action="append", required=True) + parser.add("--val_data_path", action="append", required=True) + parser.add("--test_data_path", action="append", required=False) + parser.add("--model_save_path", required=True) + parser.add("--random_seed", type=int, default=-1) + + + # word embedding + parser.add("--wordembed_path", 
type=str, default=None) + parser.add("--wordembed_dim", type=int, default=200) + parser.add("--sentence_level", type=str, required=True) + parser.add("--sentence_frame_length", type=int, default=120) + + + # model + parser.add("--model", type=str, required=True) + parser.add("--epochs", type=int, default=10) + parser.add("--batch_size", type=int, default=50) + parser.add("--dropout_prob", type=float, default=0.3) + parser.add("--n_layers", type=int, default=2) + parser.add("--hidden_size", type=int, default=200) + + # Autoencoder + parser.add("--autoencoder_denoising", type=str, required=True) + parser.add("--autoencoder_att", type=str, required=True) + parser.add("--autoencoder_fixed_weight", type=str, required=True) + parser.add("--autoencoder_conditioned", type=str, required=True) + parser.add("--use_derivative", type=str, required=True) + parser.add("--autoencoder_checkpoint", type=str, required=True) + parser.add("--autoencoder_vae", type=str, required=True) + # parser.add("--autoenoder_train_decoder", type=str, required=True) # seems duplication + parser.add("--autoencoder_freeze_encoder", type=str, required=True) + parser.add("--autoencoder_vq", type=str, required=True) + parser.add("--autoencoder_vq_components", type=str, required=True) + parser.add("--autoencoder_vq_commitment_cost", type=str, required=True) + + # text2embedding + parser.add("--text2_embedding_discrete", type=str, required=True) + + + # similarity + parser.add("--use_similarity", type=str, required=True) + parser.add("--similarity_labels", type=str, required=True) + parser.add("--data_for_sim", type=str, required=True) + parser.add("--loss_label_weight", type=float) + # dataset + parser.add("--data_mean", action="append", type=float, nargs='*') + parser.add("--data_std", action="append", type=float, nargs='*') + parser.add("--motion_resampling_framerate", type=int, default=24) + parser.add("--n_poses", type=int, default=50) + parser.add("--n_pre_poses", type=int, default=5) + 
parser.add("--subdivision_stride", type=int, default=5) + parser.add("--subdivision_stride_sentence", type=int, default=30) + parser.add("--loader_workers", type=int, default=4) + parser.add("--input_motion_dim", type=int, default=135) + + # Modalities + parser.add("--Modality_Audio", type=str, required=True) + parser.add("--Modality_Text", type=str, required=True) + parser.add("--Modality_Gesture", type=str, required=True) + + # training + parser.add("--learning_rate", type=float, default=0.001) + parser.add("--loss_l1_weight", type=float, default=50) + parser.add("--loss_cont_weight", type=float, default=0.1) + parser.add("--loss_var_weight", type=float, default=0.01) + + # Representation learning: + parser.add("--rep_learning_checkpoint", type=str, default='') + parser.add("--rep_learning_dim", type=int, default=-1) + + + # GAN + parser.add("--noise_dim", type=int, default=200) + + args = parser.parse_args() + return args \ No newline at end of file diff --git a/config/seq2seq.yml b/config/seq2seq.yml index 1787ac8..ade59cf 100644 --- a/config/seq2seq.yml +++ b/config/seq2seq.yml @@ -1,69 +1,69 @@ -name: VQVAE - -train_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train -sentence_level: True - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb//lmdb_test - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -model_save_path: ../output/IROS_2/AI2_11_HQ -random_seed: 0 - -input_motion_dim: 135 -data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, 
-0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] -data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 
0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] - -# model params -model: seq2seq -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.0 - -#Atuoencoder parameters: -autoencoder_denoising: True -autoencoder_att: False -autoencoder_fixed_weight: False -autoencoder_conditioned: True -autoencoder_vae: False -autoencoder_vq: False -autoencoder_vq_components: 512 -autoencoder_vq_commitment_cost: 0.25 -use_derivative: False -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 20 -batch_size: 128 -learning_rate: 0.0005 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 20 -n_poses: 20 -n_pre_poses: 1 -subdivision_stride: 20 -subdivision_stride_sentence: 20 -sentence_frame_length: 120 -loader_workers: 4 - -#reoresentation learning -rep_learning_checkpoint: ../output/IROS_2/DAE_p2/DAE_H40_checkpoint_020.bin -rep_learning_dim: 40 -autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin -#GAN -noise_dim: 400 +name: VQVAE + +train_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train +sentence_level: True + +#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test +val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb//lmdb_test + +wordembed_dim: 300 +wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext + +model_save_path: ../output/IROS_2/AI2_11_HQ 
+random_seed: 0 + +input_motion_dim: 135 +data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] +data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 
0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] + +# model params +model: seq2seq +hidden_size: 200 +n_layers: 2 +dropout_prob: 0.0 + +#Atuoencoder parameters: +autoencoder_denoising: True +autoencoder_att: False +autoencoder_fixed_weight: False +autoencoder_conditioned: True +autoencoder_vae: False +autoencoder_vq: False +autoencoder_vq_components: 512 +autoencoder_vq_commitment_cost: 0.25 +use_derivative: False +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 20 +batch_size: 128 +learning_rate: 0.0005 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 20 +n_poses: 20 +n_pre_poses: 1 +subdivision_stride: 20 +subdivision_stride_sentence: 20 +sentence_frame_length: 120 +loader_workers: 4 + +#reoresentation learning +rep_learning_checkpoint: ../output/IROS_2/DAE_p2/DAE_H40_checkpoint_020.bin +rep_learning_dim: 40 +autoencoder_checkpoint: ../output/autoencoder/toturial/4th/VQ-DVAE_ablation1_checkpoint_015.bin +#GAN +noise_dim: 400 diff --git a/config/seq2seqtxt.yml b/config/seq2seqtxt.yml index 354d0d9..fed414e 100644 --- a/config/seq2seqtxt.yml +++ b/config/seq2seqtxt.yml @@ -1,84 +1,84 @@ -name: VQ-DVAE_ablation1 - -#train_data_path: 
/local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train -train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train - -sentence_level: True - -#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test -#val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_test -val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test - -wordembed_dim: 300 -wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext - -#model_save_path: ../output/autoencoder/toturial/ICLR_text2embedding/ -model_save_path: ../output/autoencoder/ICML/text2embedding/1 -random_seed: 0 - -input_motion_dim: 135 -#data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 
0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] -#data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] - - -#Taras_tor 10fps -data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 
1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, -0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] -data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 
0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] - - -# model params -model: seq2seq -hidden_size: 200 -n_layers: 2 -dropout_prob: 0.2 - -#Atuoencoder parameters: -autoencoder_denoising: False -autoencoder_att: True -autoencoder_fixed_weight: False -autoencoder_conditioned: True -autoencoder_vae: False -autoencoder_vq: True -autoencoder_vq_components: 512 -autoencoder_vq_commitment_cost: 0.01 - -use_derivative: True -#autoenoder_train_decoder: True -autoencoder_freeze_encoder: False - -#Text 2 Gesture -text2_embedding_discrete: True - -use_similarity: False -similarity_labels: data_loader/gesture_labels.txt -data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin -loss_label_weight: 5.5 - - -# train params -epochs: 250 -batch_size: 128 -learning_rate: 1e-5 -loss_l1_weight: 5 -loss_cont_weight: 0.1 -loss_var_weight: 0.5 - -# dataset params -motion_resampling_framerate: 10 -n_poses: 10 -n_pre_poses: 1 -subdivision_stride: 10 -subdivision_stride_sentence: 30 -sentence_frame_length: 120 -loader_workers: 128 - -#reoresentation learning -#rep_learning_checkpoint: ../output/DAE_old/train_DAE_H41/rep_learning_DAE_H41_checkpoint_020.bin -rep_learning_checkpoint: ../output/GENEA/DAE/train_DAE_H45/DAE_H45_checkpoint_030.bin - -rep_learning_dim: 82 -#autoencoder_checkpoint: ../output/autoencoder/toturial/ablation-study/22_Vanilla_VQ_Ideal/VQ-DVAE_ablation1_checkpoint_020.bin -autoencoder_checkpoint: ../output/GENEA/VQ-VAE/VQVAE_checkpoint_3000.bin - -#GAN -noise_dim: 400 +name: VQ-DVAE_ablation1 + +#train_data_path: 
/local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_train +train_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_train + +sentence_level: True + +#val_data_path: /mnt/work2/Trinity_Gesture_DB/lmdb/lmdb_test +#val_data_path: /local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Training_data/lmdb/lmdb_test +val_data_path: /local-scratch/pjomeyaz/rosie_gesture_benchmark/cloned/Clustering/must/GENEA/Co-Speech_Gesture_Generation/dataset/dataset_v1/trn/lmdb/lmdb_test + +wordembed_dim: 300 +wordembed_path: ../resource/crawl-300d-2M-subword.bin # fasttext + +#model_save_path: ../output/autoencoder/toturial/ICLR_text2embedding/ +model_save_path: ../output/autoencoder/ICML/text2embedding/1 +random_seed: 0 + +input_motion_dim: 135 +#data_mean: [0.99414, 0.05276, -0.01830, -0.05399, 0.98730, -0.07916, 0.01263, 0.07880, 0.99316, 1.00000, 0.01767, 0.00634, -0.01717, 0.99854, -0.07086, -0.00766, 0.07074, 0.99854, 0.99951, 0.01764, -0.00896, -0.01765, 1.00000, 0.00019, 0.00896, -0.00005, 0.99951, 1.00000, -0.00739, -0.01461, 0.00706, 1.00000, -0.01138, 0.01510, 0.01107, 1.00000, 0.99805, 0.00925, 0.00102, -0.00964, 0.98096, -0.17590, -0.00240, 0.17456, 0.97949, 0.98926, 0.00434, -0.03925, -0.00249, 0.99951, 0.02534, 0.03946, -0.02550, 0.98828, 1.00000, 0.00189, -0.00242, -0.00105, 0.95752, 0.27734, 0.00261, -0.27759, 0.95801, 0.95410, -0.02214, -0.23987, 0.02150, 0.98584, 0.00000, 0.23657, 0.00037, 0.96777, 0.36108, -0.81299, 0.27588, 0.80957, 0.37817, 0.16406, -0.24243, 0.15540, 0.83447, 0.05392, -0.22961, 0.67334, 0.00002, 0.81055, 0.23083, -0.84521, 0.07886, 0.07214, 0.92676, -0.23047, -0.09564, 0.22668, 0.93848, -0.03152, 0.10413, 0.00001, 0.98828, 0.96338, -0.00713, 0.21875, 0.00795, 0.99023, -0.00000, -0.21655, 0.00461, 0.97363, 0.30835, 
0.78467, -0.42993, -0.84326, 0.37915, 0.11230, 0.26318, 0.34448, 0.80078, 0.09741, 0.21948, -0.67285, -0.00000, 0.80664, 0.21936, 0.85010, 0.04178, 0.10510, 0.93457, 0.19019, -0.00217, -0.18896, 0.94238, -0.00727, 0.00069, 0.00001, 0.99219] +#data_std: [0.08801, 0.03111, 0.01127, 0.03111, 0.08801, 0.03806, 0.01147, 0.03815, 0.08801, 0.08801, 0.00571, 0.00195, 0.00560, 0.08801, 0.01682, 0.00254, 0.01680, 0.08801, 0.08801, 0.00809, 0.01999, 0.00826, 0.08801, 0.00847, 0.02000, 0.00852, 0.08801, 0.08801, 0.01180, 0.00590, 0.01175, 0.08801, 0.00827, 0.00601, 0.00827, 0.08801, 0.08801, 0.01738, 0.02643, 0.01482, 0.08801, 0.04764, 0.02838, 0.04758, 0.08801, 0.08801, 0.01122, 0.06299, 0.01103, 0.08801, 0.01845, 0.06305, 0.01826, 0.08801, 0.08801, 0.01365, 0.00780, 0.01556, 0.08801, 0.04874, 0.00441, 0.04880, 0.08801, 0.08801, 0.04425, 0.04453, 0.04507, 0.08801, 0.00000, 0.04440, 0.01525, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08813, 0.08801, 0.08801, 0.08844, 0.08801, 0.00055, 0.08801, 0.08923, 0.08801, 0.08533, 0.08801, 0.08801, 0.08508, 0.04401, 0.07373, 0.08801, 0.02698, 0.04498, 0.00000, 0.08801, 0.08801, 0.03989, 0.04425, 0.04178, 0.08801, 0.00000, 0.04410, 0.01121, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.08807, 0.08801, 0.00042, 0.08801, 0.09247, 0.08801, 0.08801, 0.08801, 0.08801, 0.08801, 0.04617, 0.08801, 0.08801, 0.02463, 0.04758, 0.00024, 0.08801] + + +#Taras_tor 10fps +data_mean: [1.00000, -0.00034, -0.00062, 0.00035, 1.00000, 0.00030, 0.00062, -0.00031, 1.00000, 0.99998, 0.00419, 0.00081, -0.00418, 0.99998, -0.00062, -0.00083, 0.00061, 0.99999, 0.99991, 0.00472, 0.00917, -0.00475, 1.00000, 0.00325, -0.00916, -0.00333, 0.99999, 0.99997, -0.00763, -0.00009, 0.00763, 0.99997, 0.00002, 0.00009, -0.00003, 1.00000, 0.99962, -0.01398, -0.01889, 0.01383, 1.00000, -0.00656, 0.01898, 0.00630, 0.99982, 1.00000, 0.00037, 0.00049, -0.00037, 1.00000, -0.00057, -0.00049, 0.00057, 
1.00000, 1.00000, -0.00301, 0.00043, 0.00301, 1.00000, -0.00056, -0.00043, 0.00057, 0.99999, 0.99989, 0.01280, -0.00062, -0.01276, 0.99989, 0.00084, 0.00063, -0.00084, 0.99997, 0.99988, 0.00165, -0.01069, -0.00165, 1.00000, -0.00064, 0.01069, 0.00068, 0.99988, 0.99998, -0.00350, -0.00003, 0.00350, 0.99998, 0.00001, 0.00003, -0.00001, 1.00000, 0.99955, -0.01916, 0.01381, 0.01899, 0.99965, 0.00691, -0.01391, -0.00668, 0.99984, 1.00000, 0.00124, -0.00088, -0.00123, 1.00000, 0.00148, 0.00088, -0.00147, 1.00000, 1.00000, 0.00017, 0.00002, -0.00017, 1.00000, 0.00017, -0.00002, -0.00017, 1.00000, 1.00000, -0.00079, 0.00032, 0.00078, 1.00000, 0.00019, -0.00032, -0.00018, 1.00000, 1.00000, 0.00003, 0.00005, -0.00003, 1.00000, -0.00006, -0.00005, 0.00006, 1.00000, 1.00000, 0.00061, 0.00007, -0.00061, 1.00000, -0.00020, -0.00007, 0.00020, 1.00000, 1.00000, -0.00245, 0.00005, 0.00245, 1.00000, -0.00005, -0.00005, 0.00005, 1.00000, 1.00000, -0.00062, 0.00006, 0.00062, 1.00000, 0.00000, -0.00006, -0.00000, 1.00000] +data_std: [0.00004, 0.00181, 0.00640, 0.00184, 0.00000, 0.00338, 0.00639, 0.00340, 0.00005, 0.00011, 0.00982, 0.00313, 0.00980, 0.00011, 0.00347, 0.00318, 0.00342, 0.00008, 0.00019, 0.00592, 0.00635, 0.00594, 0.00000, 0.00162, 0.00634, 0.00160, 0.00006, 0.00011, 0.00817, 0.00093, 0.00817, 0.00011, 0.00095, 0.00094, 0.00093, 0.00002, 0.00020, 0.00222, 0.00488, 0.00232, 0.00002, 0.00488, 0.00493, 0.00485, 0.00024, 0.00000, 0.00104, 0.00184, 0.00104, 0.00000, 0.00163, 0.00185, 0.00163, 0.00000, 0.00004, 0.00127, 0.00646, 0.00128, 0.00000, 0.00270, 0.00646, 0.00270, 0.00006, 0.00022, 0.00987, 0.00607, 0.00986, 0.00022, 0.00639, 0.00617, 0.00629, 0.00017, 0.00021, 0.00240, 0.01065, 0.00239, 0.00000, 0.00094, 0.01065, 0.00089, 0.00021, 0.00009, 0.00852, 0.00051, 0.00852, 0.00009, 0.00054, 0.00052, 0.00053, 0.00001, 0.00015, 0.00474, 0.00900, 0.00495, 0.00022, 0.00898, 0.00910, 0.00890, 0.00028, 0.00000, 0.00131, 0.00257, 0.00130, 0.00000, 0.00214, 0.00257, 0.00213, 
0.00000, 0.00000, 0.00322, 0.00078, 0.00322, 0.00000, 0.00322, 0.00079, 0.00322, 0.00000, 0.00000, 0.00320, 0.00243, 0.00319, 0.00000, 0.00155, 0.00243, 0.00154, 0.00000, 0.00000, 0.00003, 0.00077, 0.00003, 0.00000, 0.00027, 0.00077, 0.00027, 0.00000, 0.00000, 0.00051, 0.00125, 0.00051, 0.00000, 0.00107, 0.00125, 0.00107, 0.00000, 0.00000, 0.00173, 0.00080, 0.00173, 0.00000, 0.00038, 0.00080, 0.00038, 0.00000, 0.00000, 0.00043, 0.00076, 0.00043, 0.00000, 0.00010, 0.00076, 0.00010, 0.00000] + + +# model params +model: seq2seq +hidden_size: 200 +n_layers: 2 +dropout_prob: 0.2 + +#Autoencoder parameters: +autoencoder_denoising: False +autoencoder_att: True +autoencoder_fixed_weight: False +autoencoder_conditioned: True +autoencoder_vae: False +autoencoder_vq: True +autoencoder_vq_components: 512 +autoencoder_vq_commitment_cost: 0.01 + +use_derivative: True +#autoenoder_train_decoder: True +autoencoder_freeze_encoder: False + +#Text 2 Gesture +text2_embedding_discrete: True + +use_similarity: False +similarity_labels: data_loader/gesture_labels.txt +data_for_sim: ../output/clustering_results/org_latent_clustering_data.bin +loss_label_weight: 5.5 + + +# train params +epochs: 250 +batch_size: 128 +learning_rate: 1e-5 +loss_l1_weight: 5 +loss_cont_weight: 0.1 +loss_var_weight: 0.5 + +# dataset params +motion_resampling_framerate: 10 +n_poses: 10 +n_pre_poses: 1 +subdivision_stride: 10 +subdivision_stride_sentence: 30 +sentence_frame_length: 120 +loader_workers: 128 + +#representation learning +#rep_learning_checkpoint: ../output/DAE_old/train_DAE_H41/rep_learning_DAE_H41_checkpoint_020.bin +rep_learning_checkpoint: ../output/GENEA/DAE/train_DAE_H45/DAE_H45_checkpoint_030.bin + +rep_learning_dim: 82 +#autoencoder_checkpoint: ../output/autoencoder/toturial/ablation-study/22_Vanilla_VQ_Ideal/VQ-DVAE_ablation1_checkpoint_020.bin +autoencoder_checkpoint: ../output/GENEA/VQ-VAE/VQVAE_checkpoint_3000.bin + +#GAN +noise_dim: 400