From 777d4b84a2e6b163ce65ff3108e65f87a7ffa164 Mon Sep 17 00:00:00 2001 From: Eric L Date: Wed, 3 May 2023 06:13:31 -0700 Subject: [PATCH] Add c2g autoencoder training function docstrings --- scripts/train_eval/train_seq2seq.py | 377 +++++++++++++++++----------- 1 file changed, 230 insertions(+), 147 deletions(-) diff --git a/scripts/train_eval/train_seq2seq.py b/scripts/train_eval/train_seq2seq.py index edc2649..a7b2dd6 100644 --- a/scripts/train_eval/train_seq2seq.py +++ b/scripts/train_eval/train_seq2seq.py @@ -1,4 +1,25 @@ -""" +"""This module provides single iteration training functions for all parts. + +The following parameters must be included in the config file: + loss_l1_weight: A float weight for l1 loss when summing total loss. + loss_cont_weight: A float weight for cont loss when summing total loss. + loss_var_weight: A float weight for var loss when summing total loss. + autoencoder_vq: A string boolean to train a VQVAE model. + autoencoder_vae: A string boolean to train a basic VAE model. + autoencoder_freeze_encoder: A string boolean if encoder state is frozen. + text2_embedding_discrete: A string boolean to use word vector representation. + +The following functions are currently exported: + train_iter_DAE + train_iter_Autoencoder_seq2seq + train_iter_Autoencoder_ssl_seq2seq + train_iter_Autoencoder_VQ_seq2seq + train_iter_text2embedding + +Typical usage example: + model = DAE_Network(135, 200) + optim = torch.optim.Adam(model.parameters()) + loss = train_iter_DAE(args, 1, noise_tensor, in_tensor, model, optim) """ @@ -15,7 +36,10 @@ loss_i = 0 -def custom_loss(output: torch.Tensor, target: torch.Tensor, args: argparse.Namespace) -> torch.Tensor: + +def custom_loss( + output: torch.Tensor, target: torch.Tensor, args: argparse.Namespace +) -> torch.Tensor: """Calculate a weighted l1, cont and var loss value. 
The 'args' argument must have the following keys: @@ -38,7 +62,9 @@ def custom_loss(output: torch.Tensor, target: torch.Tensor, args: argparse.Names l1_loss: torch.Tensor = l1_loss * args.loss_l1_weight # continuous motion - diff = [abs(output[:, n, :] - output[:, n-1, :]) for n in range(1, output.shape[1])] + diff = [ + abs(output[:, n, :] - output[:, n - 1, :]) for n in range(1, output.shape[1]) + ] cont_loss = torch.sum(torch.stack(diff)) / n_element cont_loss: torch.Tensor = cont_loss * args.loss_cont_weight @@ -52,14 +78,25 @@ def custom_loss(output: torch.Tensor, target: torch.Tensor, args: argparse.Names # inspect loss terms global loss_i if loss_i == 100: - logging.debug(' (loss terms) l1 %.5f, cont %.5f, var %.5f' % (l1_loss.item(), cont_loss.item(), var_loss.item())) + logging.debug( + " (loss terms) l1 %.5f, cont %.5f, var %.5f" + % (l1_loss.item(), cont_loss.item(), var_loss.item()) + ) loss_i = 0 loss_i += 1 return loss -def train_iter_seq2seq(args: argparse.Namespace, epoch: int, in_text: torch.Tensor, in_lengths: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> dict[str, float]: +def train_iter_seq2seq( + args: argparse.Namespace, + epoch: int, + in_text: torch.Tensor, + in_lengths: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> dict[str, float]: """Perform one iteration of model training. The 'args' argument must have the following keys: @@ -93,7 +130,7 @@ def train_iter_seq2seq(args: argparse.Namespace, epoch: int, in_text: torch.Tens torch.nn.utils.clip_grad_norm_(net.parameters(), 5) optim.step() - return {'loss': loss.item()} + return {"loss": loss.item()} class RMSLELoss(nn.Module): @@ -102,6 +139,7 @@ class RMSLELoss(nn.Module): Attributes: mse: A PyTorch MSELoss object. 
""" + def __init__(self): """Default initialization.""" super().__init__() @@ -119,7 +157,15 @@ def forward(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor: """ return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1))) -def train_iter_DAE(args: argparse.Namespace, epoch: int, noisy_poses: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> Tuple[dict, torch.Tensor] | dict: + +def train_iter_DAE( + args: argparse.Namespace, + epoch: int, + noisy_poses: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> Tuple[dict, torch.Tensor] | dict: """Train one iteration of a Part a model. The 'args' argument must have the following keys: @@ -148,13 +194,12 @@ def train_iter_DAE(args: argparse.Namespace, epoch: int, noisy_poses: torch.Tens # zero gradients optim.zero_grad() - # generation - if args.autoencoder_vq=='True' and args.autoencoder_vae == 'False': + if args.autoencoder_vq == "True" and args.autoencoder_vae == "False": outputs, vq_loss, perplexity_vq = net(noisy_poses) - elif args.autoencoder_vq == 'True' and args.autoencoder_vae == 'True': + elif args.autoencoder_vq == "True" and args.autoencoder_vae == "True": outputs, vq_loss, perplexity_vq, logvar, meu = net(noisy_poses) - elif args.autoencoder_vq == 'False' and args.autoencoder_vae == 'True': + elif args.autoencoder_vq == "False" and args.autoencoder_vae == "True": outputs, logvar, meu = net(noisy_poses) else: outputs = net(noisy_poses) @@ -164,24 +209,25 @@ def train_iter_DAE(args: argparse.Namespace, epoch: int, noisy_poses: torch.Tens rec_loss: torch.Tensor = loss_fn(outputs, target_poses) - if args.autoencoder_vq == 'True': + if args.autoencoder_vq == "True": GSOFT = False if GSOFT: rec_loss = outputs.log_prob(target_poses).sum(dim=1).mean() - - loss = vq_loss - rec_loss/100 + loss = vq_loss - rec_loss / 100 # print("LOSSSSSSSS!", vq_loss, rec_loss) else: loss = rec_loss + vq_loss 
else: loss = rec_loss - if args.autoencoder_vae == 'True': - loss_KLD = -2.5 * torch.mean(torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1)) + if args.autoencoder_vae == "True": + loss_KLD = -2.5 * torch.mean( + torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1) + ) # loss_KLD = -0.5 * torch.mean(torch.mean(1 + logvar - logvar.exp(), 1)) print("Kista", loss_KLD) - loss += 5 * loss_KLD #0.11 + loss += 5 * loss_KLD # 0.11 loss.backward() @@ -189,14 +235,20 @@ def train_iter_DAE(args: argparse.Namespace, epoch: int, noisy_poses: torch.Tens torch.nn.utils.clip_grad_norm_(net.parameters(), 5) optim.step() - if args.autoencoder_vq == 'True': - return {'loss': rec_loss.item()}, perplexity_vq + if args.autoencoder_vq == "True": + return {"loss": rec_loss.item()}, perplexity_vq else: - return {'loss': loss.item()} - + return {"loss": loss.item()} -def train_iter_Autoencoder_seq2seq(args: argparse.Namespace, epoch: int, input_poses: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> dict[str, float]: +def train_iter_Autoencoder_seq2seq( + args: argparse.Namespace, + epoch: int, + input_poses: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> dict[str, float]: """Train one iteration of a Part b model. 
The 'args' argument must contain the following keys: @@ -221,20 +273,21 @@ def train_iter_Autoencoder_seq2seq(args: argparse.Namespace, epoch: int, input_p optim.zero_grad() # generation - if args.autoencoder_vae == 'True': + if args.autoencoder_vae == "True": outputs, _, meu, logvar = net(input_poses, target_poses) else: outputs, _ = net(input_poses, target_poses) # loss - #Todo: important: I removed custom loss and replaced it by ll to test + # Todo: important: I removed custom loss and replaced it by ll to test loss = custom_loss(outputs, target_poses, args) # loss = F.mse_loss(outputs, target_poses) - if args.autoencoder_vae == 'True': - + if args.autoencoder_vae == "True": # loss_KLD = 0.5 * torch.mean(logvar.exp()-logvar-1 + meu.pow(2)) - loss_KLD = -0.5 * torch.mean(torch.mean( 1 + logvar - logvar.exp() - meu.pow(2), 1)) + loss_KLD = -0.5 * torch.mean( + torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1) + ) # if epoch%10==0: # print("____________________") # print("loss", loss) @@ -243,8 +296,8 @@ def train_iter_Autoencoder_seq2seq(args: argparse.Namespace, epoch: int, input_p # print("____________________") kl_start_epoch = 5 - if epoch>kl_start_epoch and args.autoencoder_freeze_encoder == 'False': - loss += loss_KLD * 0.01 * (epoch-kl_start_epoch)/args.epochs + if epoch > kl_start_epoch and args.autoencoder_freeze_encoder == "False": + loss += loss_KLD * 0.01 * (epoch - kl_start_epoch) / args.epochs # print("!!!!!!!!!!!!!!!!!!!!!!!!!!") loss.backward() @@ -252,35 +305,55 @@ def train_iter_Autoencoder_seq2seq(args: argparse.Namespace, epoch: int, input_p torch.nn.utils.clip_grad_norm_(net.parameters(), 5) optim.step() - return {'loss': loss.item()} + return {"loss": loss.item()} -def train_iter_Autoencoder_ssl_seq2seq(args: argparse.Namespace, epoch: int, input_poses: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer, - stack_pairs1: torch.Tensor, stackpairs2: torch.Tensor, stack_label: torch.Tensor) -> 
dict[str, float]: - """ +def train_iter_Autoencoder_ssl_seq2seq( + args: argparse.Namespace, + epoch: int, + input_poses: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, + stack_pairs1: torch.Tensor, + stack_pairs2: torch.Tensor, + stack_label: torch.Tensor, +) -> dict[str, float]: + """Train one iteration of a Part b model. + + #TODO + + The 'args' argument must contain the following keys: + autoencoder_vae: A string boolean if a VAE model was trained. + autoencoder_freeze_encoder: A string boolean if encoder state is frozen. + loss_l1_weight: A float weight for l1 loss when summing total loss. + loss_cont_weight: A float weight for cont loss when summing total loss. + loss_var_weight: A float weight for var loss when summing total loss. Args: - args: - epoch: - input_poses: - target_poses: - net: - optim: - stack_pairs1: - stack_pairs2: - stack_label: + args: A configargparser object with specified parameters (See above). + epoch: An integer number of iterations (unused). + input_poses: A Tensor of input data. + target_poses: A Tensor of ground truth data. + net: A PyTorch neural net (Autoencoder) model (from Part b). + optim: A PyTorch optimization algorithm object. + stack_pairs1: #TODO + stack_pairs2: #TODO + stack_label: #TODO Returns: - + A dict with a string key 'loss' and a float loss score. 
""" # zero gradients optim.zero_grad() # generation # Unlabeled - if args.autoencoder_vae == 'True': + if args.autoencoder_vae == "True": outputs, _, meu, logvar = net(input_poses, target_poses) - loss_KLD = -0.5 * torch.mean(torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1)) + loss_KLD = -0.5 * torch.mean( + torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1) + ) else: outputs, _ = net(input_poses, target_poses) @@ -289,12 +362,12 @@ def train_iter_Autoencoder_ssl_seq2seq(args: argparse.Namespace, epoch: int, inp if debug: print("stack_pairs1", stack_pairs1.shape) print("net.decoder.n_layers", net.decoder.n_layers) - if args.autoencoder_vae == 'True': + if args.autoencoder_vae == "True": outputs_p1, latents_p1, mu_1, logvar_1 = net(stack_pairs1, stack_pairs1) - outputs_p2, latents_p2, mu_2, logvar_2 = net(stackpairs2, stackpairs2) + outputs_p2, latents_p2, mu_2, logvar_2 = net(stack_pairs2, stack_pairs2) else: outputs_p1, latents_p1 = net(stack_pairs1, stack_pairs1) - outputs_p2, latents_p2 = net(stackpairs2, stackpairs2) + outputs_p2, latents_p2 = net(stack_pairs2, stack_pairs2) if debug: print("1. 
latentp1.shape:", latents_p1.shape) @@ -304,7 +377,7 @@ def train_iter_Autoencoder_ssl_seq2seq(args: argparse.Namespace, epoch: int, inp latents_p2 = torch.hstack((latents_p2[0], latents_p2[1])) # Normal loss - #Todo: important: I removed custom loss and replaced it by ll to test + # Todo: important: I removed custom loss and replaced it by ll to test # loss = custom_loss(outputs, target_poses, args) # loss_unlabeled = F.mse_loss(outputs, target_poses) loss_unlabeled = custom_loss(outputs, target_poses, args) @@ -316,18 +389,17 @@ def train_iter_Autoencoder_ssl_seq2seq(args: argparse.Namespace, epoch: int, inp print("latentp1.shape:", latents_p1.shape) print("cosine_similarity.shape:", cos_dist.shape) print("stack_label", stack_label.shape) - mask = (stack_label == 1) - cos_dist[mask] = cos_dist[mask]*-1 + mask = stack_label == 1 + cos_dist[mask] = cos_dist[mask] * -1 loss_labeled = torch.sum(cos_dist) loss: torch.Tensor = args.loss_label_weight + loss_unlabeled - if args.autoencoder_vae == 'True': + if args.autoencoder_vae == "True": kl_start_epoch = 10 if epoch > kl_start_epoch: loss += loss_KLD * 0.1 * (epoch - kl_start_epoch) / args.epochs - loss.backward() # optimize @@ -337,15 +409,18 @@ def train_iter_Autoencoder_ssl_seq2seq(args: argparse.Namespace, epoch: int, inp print("loss_unlabeled", loss_unlabeled) print("Loss_labeled", loss_labeled) print("loss_label_weight:", args.loss_label_weight) - return {'loss': loss.item()} - - - + return {"loss": loss.item()} - -def train_iter_c2g_seq2seq(args: argparse.Namespace, epoch: int, input_cluster: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> dict[str, float]: - """ +def train_iter_c2g_seq2seq( + args: argparse.Namespace, + epoch: int, + input_cluster: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> dict[str, float]: + """Train a single iteration of a Part d model (cluster to gesture). 
The 'args' argument must contain the following keys: autoencoder_vae: A string boolean if a VAE model was trained. @@ -372,7 +447,7 @@ def train_iter_c2g_seq2seq(args: argparse.Namespace, epoch: int, input_cluster: outputs = net(input_cluster, target_poses) # loss - #Todo: important: I removed custom loss and replaced it by ll to test + # Todo: important: I removed custom loss and replaced it by ll to test loss = custom_loss(outputs, target_poses, args) # loss = F.mse_loss(outputs, target_poses) loss.backward() @@ -381,42 +456,57 @@ def train_iter_c2g_seq2seq(args: argparse.Namespace, epoch: int, input_cluster: torch.nn.utils.clip_grad_norm_(net.parameters(), 5) optim.step() - return {'loss': loss.item()} + return {"loss": loss.item()} -def train_iter_text2embedding(args: argparse.Namespace, epoch: int, in_text: torch.Tensor, in_lengths: torch.Tensor, in_audio: torch.Tensor, target_poses: torch.Tensor, cluster_targets: torch.Tensor, - GPT3_Embedding: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> dict[str, float]: - """ +def train_iter_text2embedding( + args: argparse.Namespace, + epoch: int, + in_text: torch.Tensor, + in_lengths: torch.Tensor, + in_audio: torch.Tensor, + target_poses: torch.Tensor, + cluster_targets: torch.Tensor, + GPT3_Embedding: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> dict[str, float]: + """Train one iteration of a Part d model (text2embedding_model). The 'args' argument must have the following keys: - text2_embedding_discrete: + text2_embedding_discrete: A string boolean to use word vector representation. Args: - args: - epoch: - in_text: - in_lengths: - in_audio: - target_poses: - cluster_targets: - GPT3_Embedding: - net: - optim: + args: A configargparser object with specific keys (See above). + epoch: An integer number of epochs (unused). + in_text: A Tensor of input data (text). + in_lengths: A Tensor of dimensions of 'in_text'. + in_audio: A Tensor of input data (audio). 
+ target_poses: A Tensor of data (gesture) as a starting point in output. + cluster_targets: A Tensor of input data (gesture). + GPT3_Embedding: A Tensor of word vectors data from GPT3. + net: A custom Part d PyTorch 'text2embedding_model'. + optim: A PyTorch optimizer object. Returns: - + A dict with the following keys: + loss: A float loss score of the iteration. """ # zero gradients optim.zero_grad() # generation - if args.text2_embedding_discrete == 'False': - outputs, _ = net(in_text, in_lengths, in_audio, target_poses, GPT3_Embedding, None) + if args.text2_embedding_discrete == "False": + outputs, _ = net( + in_text, in_lengths, in_audio, target_poses, GPT3_Embedding, None + ) else: - outputs, _ = net(in_text, in_lengths, in_audio, cluster_targets, GPT3_Embedding, None) + outputs, _ = net( + in_text, in_lengths, in_audio, cluster_targets, GPT3_Embedding, None + ) # loss # print(outputs.shape) - if args.text2_embedding_discrete=='False': + if args.text2_embedding_discrete == "False": loss = F.mse_loss(outputs[:, 1:, :], target_poses[:, 1:, :]) else: os = cluster_targets.shape @@ -427,15 +517,9 @@ def train_iter_text2embedding(args: argparse.Namespace, epoch: int, in_text: tor print("cc", q.shape) w = F.one_hot(q.to(torch.int64), 300) print("----", w.shape, w) - - - # cluster_targets_one_hot = F.one_hot(cluster_targets.reshape(-1).to(torch.int64), 300) - # cluster_targets_one_hot = cluster_targets_one_hot.reshape(os[0], os[1], -1) outputs = outputs[:, 1:, :] outputs = outputs.reshape(-1, outputs.shape[2]) - # cluster_targets.reshape(-1) - - cluster_targets = cluster_targets[:,1:] + cluster_targets = cluster_targets[:, 1:] cluster_targets = cluster_targets.reshape(-1) if debug: @@ -444,16 +528,6 @@ def train_iter_text2embedding(args: argparse.Namespace, epoch: int, in_text: tor a = outputs.cpu().detach().numpy() b = cluster_targets.cpu().detach().numpy() loss = torch.nn.CrossEntropyLoss()(outputs.float(), cluster_targets.long()) - # loss = 
torch.nn.NLLLoss()(outputs.float(), cluster_targets.long()) - - # cluster_targets_one_hot = F.one_hot(cluster_targets.reshape(-1).to(torch.int64), 514) - # loss = F.l1_loss(outputs, cluster_targets_one_hot) - - # loss2 = torch.mean(torch.mean(loss, dim=2)/cluster_portion) - - - - loss.backward() @@ -461,12 +535,23 @@ def train_iter_text2embedding(args: argparse.Namespace, epoch: int, in_text: tor torch.nn.utils.clip_grad_norm_(net.parameters(), 5) optim.step() - return {'loss': loss.item()} + return {"loss": loss.item()} -def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: torch.Tensor, in_lengths: torch.Tensor, target_poses: torch.Tensor, cluster_portion: torch.Tensor, - g_net: torch.nn.Module, d_net: torch.nn.Module, g_optim: torch.optim.Optimizer, d_optim: torch.optim.Optimizer) -> Tuple[np.ndarray, np.ndarray]: - """ +def train_iter_text2embedding_GAN( + args: argparse.Namespace, + epoch: int, + in_text: torch.Tensor, + in_lengths: torch.Tensor, + target_poses: torch.Tensor, + cluster_portion: torch.Tensor, + g_net: torch.nn.Module, + d_net: torch.nn.Module, + g_optim: torch.optim.Optimizer, + d_optim: torch.optim.Optimizer, +) -> Tuple[np.ndarray, np.ndarray]: + """Experimental. Provided as-is. + Args: args: epoch: @@ -489,7 +574,7 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: # 1. * Generate fake data with torch.no_grad(): - fake_y = g_net(in_text, in_lengths, target_poses, None) + fake_y = g_net(in_text, in_lengths, target_poses, None) # 2. 
* Train Discriminator @@ -498,12 +583,12 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: d_real_error = 0 d_fake_error = 0 - real_logit = d_net(in_text, in_lengths, target_poses, None ) + real_logit = d_net(in_text, in_lengths, target_poses, None) real_label = torch.ones_like(real_logit) real_error = bce_loss(real_logit, real_label) d_real_error = torch.mean(real_error) - fake_logit = d_net(in_text, in_lengths, fake_y, None ) + fake_logit = d_net(in_text, in_lengths, fake_y, None) fake_label = torch.zeros_like(fake_logit) if debug: print("fake_label", fake_label.shape) @@ -511,13 +596,12 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: fake_error = bce_loss(fake_logit, fake_label) d_fake_error = torch.mean(fake_error) - d_loss = d_real_error + d_fake_error d_loss.backward() d_optim.step() - d_real_loss = (d_real_error.cpu().detach().numpy()) - d_fake_loss = (d_fake_error.cpu().detach().numpy()) + d_real_loss = d_real_error.cpu().detach().numpy() + d_fake_loss = d_fake_error.cpu().detach().numpy() # 2. 
* Unrolling step unroll_steps = 10 @@ -530,8 +614,7 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: d_real_error = 0 d_fake_error = 0 - - real_logit = d_net(in_text, in_lengths, target_poses, None ) + real_logit = d_net(in_text, in_lengths, target_poses, None) real_label = torch.ones_like(real_logit) real_error = bce_loss(real_logit, real_label) d_real_error = torch.mean(real_error) @@ -548,14 +631,13 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: # * Train G g_optim.zero_grad() - gen_y = g_net(in_text, in_lengths, target_poses, None ) + gen_y = g_net(in_text, in_lengths, target_poses, None) gen_logit = d_net(in_text, in_lengths, gen_y, None) gen_lable = torch.ones_like(gen_logit) gen_error = bce_loss(gen_logit, gen_lable) g_error = torch.mean(gen_error) - g_error.backward() g_optim.step() @@ -579,9 +661,14 @@ def train_iter_text2embedding_GAN(args: argparse.Namespace, epoch: int, in_text: return d_real_loss, d_fake_loss - - -def train_iter_Autoencoder_VQ_seq2seq(args: argparse.Namespace, epoch: int, input_poses: torch.Tensor, target_poses: torch.Tensor, net: torch.nn.Module, optim: torch.optim.Optimizer) -> Tuple[dict[str, float], torch.Tensor] | dict[str,float]: +def train_iter_Autoencoder_VQ_seq2seq( + args: argparse.Namespace, + epoch: int, + input_poses: torch.Tensor, + target_poses: torch.Tensor, + net: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> Tuple[dict[str, float], torch.Tensor] | dict[str, float]: """ Args: @@ -599,53 +686,56 @@ def train_iter_Autoencoder_VQ_seq2seq(args: argparse.Namespace, epoch: int, inpu optim.zero_grad() vq_start_epoch = 0 - # net.vq_layer.embedding_grad(epoch % 3 == 0) # generation - if args.autoencoder_vq == 'True' and args.autoencoder_vae == 'True': - outputs, _, meu, logvar, loss_vq, perplexity_vq = net(input_poses, target_poses, epoch > vq_start_epoch) - if args.autoencoder_vq == 'True' and args.autoencoder_vae == 'False': - outputs, _, 
loss_vq, perplexity_vq = net(input_poses, target_poses, epoch > vq_start_epoch) - if args.autoencoder_vq == 'False' and args.autoencoder_vae == 'True': - outputs, _, meu, logvar = net(input_poses, target_poses) - if args.autoencoder_vq == 'False' and args.autoencoder_vae == 'False': - outputs , _ = net(input_poses, target_poses) + if args.autoencoder_vq == "True" and args.autoencoder_vae == "True": + outputs, _, meu, logvar, loss_vq, perplexity_vq = net( + input_poses, target_poses, epoch > vq_start_epoch + ) + if args.autoencoder_vq == "True" and args.autoencoder_vae == "False": + outputs, _, loss_vq, perplexity_vq = net( + input_poses, target_poses, epoch > vq_start_epoch + ) + if args.autoencoder_vq == "False" and args.autoencoder_vae == "True": + outputs, _, meu, logvar = net(input_poses, target_poses) + if args.autoencoder_vq == "False" and args.autoencoder_vae == "False": + outputs, _ = net(input_poses, target_poses) # loss - #Todo: important: I removed custom loss and replaced it by ll to test + # Todo: important: I removed custom loss and replaced it by ll to test loss = custom_loss(outputs, target_poses, args) # loss = F.mse_loss(outputs, target_poses) - if args.autoencoder_vae == 'True': + if args.autoencoder_vae == "True": # loss_KLD = 0.5 * torch.mean(logvar.exp()-logvar-1 + meu.pow(2)) - loss_KLD = -0.5 * torch.mean(torch.mean( 1 + logvar - logvar.exp() - meu.pow(2), 1)) + loss_KLD = -0.5 * torch.mean( + torch.mean(1 + logvar - logvar.exp() - meu.pow(2), 1) + ) loss_KLD = 0.5 * torch.mean(logvar.exp() - logvar - 1 + meu.pow(2)) - if epoch%10==0: + if epoch % 10 == 0: print("____________________") print("loss", loss) - print("loss_KLD", loss_KLD ) - print("Epoch ratio",epoch, args.epochs, epoch/args.epochs) + print("loss_KLD", loss_KLD) + print("Epoch ratio", epoch, args.epochs, epoch / args.epochs) print("____________________") if debug: - print('rec_loss', loss) + print("rec_loss", loss) # print("loss_vqn", loss_vq) print("loss_KLD", loss_KLD) 
kl_start_epoch = 0 - if epoch>kl_start_epoch: - loss += loss_KLD * 0.1 * (epoch-kl_start_epoch)/args.epochs + if epoch > kl_start_epoch: + loss += loss_KLD * 0.1 * (epoch - kl_start_epoch) / args.epochs if epoch > vq_start_epoch: - if args.autoencoder_vq == 'True': + if args.autoencoder_vq == "True": emb_w = net.vq_layer._embedding.weight.detach() # mycdist = torch.cdist(emb_w, emb_w).mean() # print("________________________________________________\n loss:", # loss.data, " vq_los", loss_vq.data, "mycdist:", mycdist.data) - loss = (loss) + 1 * loss_vq/400 #+ -1*torch.log(mycdist) - - + loss = (loss) + 1 * loss_vq / 400 # + -1*torch.log(mycdist) loss.backward() @@ -660,16 +750,9 @@ def train_iter_Autoencoder_VQ_seq2seq(args: argparse.Namespace, epoch: int, inpu # print("Loss:", loss, "-------Loss_VQ", loss_vq) # print("____________________") - - if args.autoencoder_vq == 'True': + if args.autoencoder_vq == "True": # return {'loss': loss.item()}, loss_vq # Todo: it should be perplexity_vq not loss_vq - return {'loss': loss.item()}, perplexity_vq.detach() + return {"loss": loss.item()}, perplexity_vq.detach() else: - return {'loss': loss.item()} - - - - - - + return {"loss": loss.item()}