+* The pretrained ASR encoder can be either an open source pretrained model or trained from scratch.
+* Using part, but not all, of the ASR encoder tends to achieve better performance.
+* Sharpness-Aware Minimization (SAM) training seems to efficiently alleviate overfitting.
+#### Results:
+#### Runtime is based on libtorch, go to `subtools/runtime` and evaluate your models' RTF.
\ No newline at end of file
diff --git a/doc/papers/giatrans.jpg b/doc/papers/giatrans.jpg
new file mode 100755
index 0000000..6eaa342
Binary files /dev/null and b/doc/papers/giatrans.jpg differ
diff --git a/doc/papers/trans.jpg b/doc/papers/trans.jpg
new file mode 100755
index 0000000..69fa627
Binary files /dev/null and b/doc/papers/trans.jpg differ
diff --git a/doc/runtime_deploy.png b/doc/runtime_deploy.png
new file mode 100755
index 0000000..f66d9e6
Binary files /dev/null and b/doc/runtime_deploy.png differ
diff --git a/pytorch/launcher/runEcapaXvector_online.py b/pytorch/launcher/runEcapaXvector_online.py
old mode 100644
new mode 100755
index 4277ccf..c50acc3
--- a/pytorch/launcher/runEcapaXvector_online.py
+++ b/pytorch/launcher/runEcapaXvector_online.py
@@ -20,6 +20,7 @@
import libs.support.utils as utils
import libs.support.kaldi_common as kaldi_common
import libs.training.trainer_online as trainer
+import libs.training.trainer_online_sam as trainer_sam
import libs.training.lr_scheduler_online as learn_rate_scheduler
import libs.training.optim as optim
import libs.egs.egs_online as egs
@@ -29,10 +30,10 @@
Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering
them to python from shell.
-Note, this launcher does not contain dataset preparation, augmentation, extracting acoustic features and back-end scoring etc.
- 1.See subtools/recipe/voxceleb/runVoxceleb.sh to get complete stages.
- 2.See subtools/newCopyData.sh, subtools/makeFeatures.sh.sh, subtools/computeVad.sh, subtools/augmentDataByNoise.sh and
- subtools/scoreSets.sh and run these script separately before or after running this launcher.
+Note: this launcher does not include dataset preparation, augmentation, back-end scoring, etc.
+ 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages.
+ 2.An on-the-fly feature extraction mode.
How to modify this launcher:
1.Prepare your kaldi format dataset and model.py (model blueprint);
@@ -95,8 +96,8 @@
-parser.add_argument("--stage", type=int, default=0,
- help="The stage to control the start of training epoch (default 4).\n"
+parser.add_argument("--stage", type=int, default=3,
+ help="The stage to control the start of training epoch (default 3).\n"
" stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n"
" stage 1: remove utts (preprocess_raw_wav_egs.sh).\n"
" stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n"
@@ -105,7 +106,7 @@
" stage 4: extract xvector.")
parser.add_argument("--endstage", type=int, default=4,
- help="The endstage to control the endstart of training epoch (default 5).")
+ help="The endstage to control the endstart of training epoch (default 4).")
parser.add_argument("--train-stage", type=int, default=-1,
help="The stage to control the start of training epoch (default -1).\n"
@@ -252,6 +253,8 @@
"use_step": False,
"step_params": {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0},
"T": None,
"m": True, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4,
"s": False, "s_tuple": (30, 12), "s_list": None,
@@ -269,9 +272,16 @@
"weight_decay": 1e-1,
"lookahead.k": 5,
# 0 means not using lookahead and if used, suggest to set it as 0.5.
- "lookahead.alpha": 0,
+ "lookahead.alpha": 0.,
"gc": False, # If true, use gradient centralization.
- "nesterov": False # for sgd
+ "nesterov": False, # for sgd
+ "sam": False,
+ "sam.rho": 2.0, # 2.0 for adaptive
+ "sam.adaptive": True,
+ # "custwd_dict":{
+ # "train_len":0,
+ # "bias":0
+ # }
lr_scheduler_params = {
@@ -356,7 +366,6 @@
if utils.is_main_training():
logger.info("Get model_blueprint from model directory.")
# Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path.
- print(model_blueprint)
model_blueprint = utils.create_model_dir(
model_dir, model_blueprint, stage=train_stage)
@@ -384,13 +393,28 @@
model = model_py.ECAPA_TDNN(
info["feat_dim"], info["num_targets"], **model_params)
+ epoch_iters = (info['epoch_iters']//accum_grad)
+ if hasattr(model,'margin_warm'):
+ model.margin_warm.update_step_range(epoch_iters)
# If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important
# to make peformance stable.
# It will change nothing for single-GPU training.
model = utils.convert_synchronized_batchnorm(model)
if utils.is_main_training():
+ print(model)
+ p1=sum(p.numel() for p in model.parameters())
+ script_model = copy.deepcopy(model)
+ script_model.loss=None
+ p2 = sum(p.numel() for p in script_model.parameters())
+ logger.info("model params w/o proj layer: {} / {} .".format(p1,p2))
+ script_model = torch.jit.script(script_model)
+ script_model.save(os.path.join(model_dir, 'init.zip'))
+ logger.info("The number of steps per epoch is about {}.".format(epoch_iters))
logger.info("Define optimizer and lr_scheduler.")
+ del script_model
optimizer = optim.get_optimizer(model, optimizer_params)
lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(
optimizer, lr_scheduler_params)
@@ -399,7 +423,7 @@
if utils.is_main_training():
utils.write_list_to_file([egs_params, model_params, optimizer_params,
- lr_scheduler_params], model_dir+'/config/params.dict')
+ lr_scheduler_params], model_dir+'/config/params.dict',yml=True)
if utils.is_main_training():
logger.info("Init a simple trainer.")
@@ -409,14 +433,16 @@
"start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp,
"skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid,
"report_interval_iters": report_interval_iters, "record_file": "train.csv"})
- trainer = trainer.SimpleTrainer(package)
+ train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer
+ execuer = train_exec.SimpleTrainer(package)
if run_lr_finder:
- trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8,
+ execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8,
final_lr=10., num_iters=2000, beta=0.98)
endstage = 3 # Do not start extractor.
- trainer.run()
+ execuer.run()
# Extract xvector
@@ -440,7 +466,7 @@
gpu_id = ""
sleep_time = 10
feat_config = "feat_conf.yaml"
+ max_chunk = 10000
# Run a batch extracting process.
for position in to_extracted_positions:
@@ -480,12 +506,12 @@
# with python's threads to extract xvectors directly, but the shell script is more convenient.
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh "
" --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
- " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}"
+ " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} "
" --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} "
"{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj,
use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th,
- model_dir=model_dir, datadir=datadir, outdir=outdir))
+ max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir))
except BaseException as e:
if not isinstance(e, KeyboardInterrupt):
diff --git a/pytorch/launcher/runRepvggXvector.py b/pytorch/launcher/runRepvggXvector.py
old mode 100644
new mode 100755
index 1b62f71..3125b58
--- a/pytorch/launcher/runRepvggXvector.py
+++ b/pytorch/launcher/runRepvggXvector.py
@@ -20,6 +20,7 @@
import libs.training.optim as optim
import libs.training.lr_scheduler_online as learn_rate_scheduler
import libs.training.trainer_online as trainer
+import libs.training.trainer_online_sam as trainer_sam
import libs.support.kaldi_common as kaldi_common
import libs.support.utils as utils
from libs.support.logging_stdout import patch_logging_stream
@@ -32,6 +33,9 @@
Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering
them to python from shell.
+Note: this launcher does not include dataset preparation, augmentation, back-end scoring, etc.
+ 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages.
+ 2.An on-the-fly feature extraction mode.
How to modify this launcher:
1.Prepare your kaldi format dataset and model.py (model blueprint);
@@ -93,8 +97,8 @@
-parser.add_argument("--stage", type=int, default=0,
- help="The stage to control the start of training epoch (default 4).\n"
+parser.add_argument("--stage", type=int, default=3,
+ help="The stage to control the start of training epoch (default 3).\n"
" stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n"
" stage 1: remove utts (preprocess_raw_wav_egs.sh).\n"
" stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n"
@@ -103,7 +107,7 @@
" stage 4: extract xvector.")
parser.add_argument("--endstage", type=int, default=4,
- help="The endstage to control the endstart of training epoch (default 5).")
+ help="The endstage to control the endstart of training epoch (default 4).")
parser.add_argument("--train-stage", type=int, default=-1,
help="The stage to control the start of training epoch (default -1).\n"
@@ -264,6 +268,8 @@
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0},
"m":True, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4,
"s":False, "s_tuple":(30, 12), "s_list":None,
@@ -371,7 +377,6 @@
if stage <= 3 <= endstage:
if utils.is_main_training():logger.info("Get model_blueprint from model directory.")
# Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path.
- print(model_blueprint)
model_blueprint = utils.create_model_dir(model_dir, model_blueprint, stage=train_stage)
if utils.is_main_training():logger.info("Load egs to bunch.")
@@ -387,7 +392,8 @@
with open(feat_config_path,'w') as fou:
- if utils.is_main_training():logger.info("Create model from model blueprint.")
+ if utils.is_main_training():
+ logger.info("Create model from model blueprint.")
# Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and
# I don't want to change anything about extracting script when the model.py is changed.
model_py = utils.create_model_from_py(model_blueprint)
@@ -399,8 +405,22 @@
# It will change nothing for single-GPU training.
model = utils.convert_synchronized_batchnorm(model)
+ epoch_iters = (info['epoch_iters']//accum_grad)
+ if hasattr(model,'margin_warm'):
+ model.margin_warm.update_step_range(epoch_iters)
- if utils.is_main_training():logger.info("Define optimizer and lr_scheduler.")
+ if utils.is_main_training():
+ print(model)
+ p1=sum(p.numel() for p in model.parameters())
+ script_model = copy.deepcopy(model)
+ script_model.loss=None
+ p2 = sum(p.numel() for p in script_model.parameters())
+ logger.info("model params w/o proj layer: {} / {} .".format(p1,p2))
+ script_model = torch.jit.script(script_model)
+ script_model.save(os.path.join(model_dir, 'init.zip'))
+ logger.info("The number of steps per epoch is about {}.".format(epoch_iters))
+ logger.info("Define optimizer and lr_scheduler.")
+ del script_model
optimizer = optim.get_optimizer(model, optimizer_params)
lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params)
@@ -408,8 +428,9 @@
# Record params to model_dir
- if utils.is_main_training():utils.write_list_to_file([egs_params, model_params, optimizer_params,
- lr_scheduler_params], model_dir+'/config/params.dict')
+ if utils.is_main_training():
+ utils.write_list_to_file([egs_params, model_params, optimizer_params,
+ lr_scheduler_params], model_dir+'/config/params.dict',yml=True)
if utils.is_main_training():logger.info("Init a simple trainer.")
@@ -420,13 +441,15 @@
"skip_nan_batch":skip_nan_batch, "benchmark":benchmark, "suffix":suffix, "compute_batch_num_valid":compute_batch_num_valid,
"report_interval_iters":report_interval_iters, "record_file":"train.csv"})
- trainer = trainer.SimpleTrainer(package)
+ train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer
+ execuer = train_exec.SimpleTrainer(package)
if run_lr_finder and utils.is_main_training():
- trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98)
+ execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98)
endstage = 3 # Do not start extractor.
- trainer.run()
+ execuer.run()
# Plan to use del to avoid memeory account after training done and continue to execute stage 4.
# But it dose not work and is still a problem.
@@ -452,7 +475,7 @@
gpu_id = ""
sleep_time = 10
+ max_chunk = 10000
# Run a batch extracting process.
for position in to_extracted_positions:
@@ -500,12 +523,12 @@
# with python's threads to extract xvectors directly, but the shell script is more convenient.
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh "
" --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
- " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}"
+ " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} "
" --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} "
"{model_dir} {datadir} {outdir}".format(model_file=depoly_model_file, nj=nj,
use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
- model_dir=model_dir, datadir=datadir, outdir=outdir))
+ max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir))
except BaseException as e:
if not isinstance(e, KeyboardInterrupt):
diff --git a/pytorch/launcher/runResnetXvector_online.py b/pytorch/launcher/runResnetXvector_online.py
old mode 100644
new mode 100755
index e7ca4ba..af74739
--- a/pytorch/launcher/runResnetXvector_online.py
+++ b/pytorch/launcher/runResnetXvector_online.py
@@ -20,6 +20,7 @@
import libs.training.optim as optim
import libs.training.lr_scheduler_online as learn_rate_scheduler
import libs.training.trainer_online as trainer
+import libs.training.trainer_online_sam as trainer_sam
import libs.support.kaldi_common as kaldi_common
import libs.support.utils as utils
from libs.support.logging_stdout import patch_logging_stream
@@ -33,6 +34,10 @@
them to python from shell.
+Note: this launcher does not include dataset preparation, augmentation, back-end scoring, etc.
+ 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages.
+ 2.An on-the-fly feature extraction mode.
How to modify this launcher:
1.Prepare your kaldi format dataset and model.py (model blueprint);
2.Give the path of dataset, model blueprint, etc. in main parameters field;
@@ -93,8 +98,9 @@
-parser.add_argument("--stage", type=int, default=0,
- help="The stage to control the start of training epoch (default 4).\n"
+parser.add_argument("--stage", type=int, default=3,
+ help="The stage to control the start of training epoch (default 3).\n"
" stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n"
" stage 1: remove utts (preprocess_raw_wav_egs.sh).\n"
" stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n"
@@ -103,7 +109,7 @@
" stage 4: extract xvector.")
parser.add_argument("--endstage", type=int, default=4,
- help="The endstage to control the endstart of training epoch (default 5).")
+ help="The endstage to control the endstart of training epoch (default 4).")
parser.add_argument("--train-stage", type=int, default=-1,
help="The stage to control the start of training epoch (default -1).\n"
@@ -260,6 +266,8 @@
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0},
"m":True, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4,
"s":False, "s_tuple":(30, 12), "s_list":None,
@@ -278,7 +286,14 @@
"lookahead.alpha":0., # 0 means not using lookahead and if used, suggest to set it as 0.5.
"gc":False, # If true, use gradient centralization.
- "nesterov": False # for sgd
+ "nesterov": False, # for sgd
+ "sam": False, # suggest true when apply ASR transferring.
+ "sam.rho": 2.0, # 2.0 for adaptive
+ "sam.adaptive": True,
+ # "custwd_dict":{
+ # "train_len":0,
+ # "bias":0
+ # }
lr_scheduler_params = {
@@ -301,7 +316,7 @@
epochs = 50 # Total epochs to train. It is important. Here 18 = 6 -> 12 -> 18 with warmR.T_mult=1 and warmR.T_max=6.
compute_batch_num_valid = 10
-report_interval_iters = 100 # About validation computation and loss reporting. If report_times_every_epoch is not None,
+report_interval_iters = 500 # About validation computation and loss reporting. If report_times_every_epoch is not None,
# then compute report_interval_iters by report_times_every_epoch.
stop_early = False
suffix = "params" # Used in saved model file.
@@ -395,8 +410,22 @@
# It will change nothing for single-GPU training.
model = utils.convert_synchronized_batchnorm(model)
- if utils.is_main_training():logger.info("Define optimizer and lr_scheduler.")
+ epoch_iters = (info['epoch_iters']//accum_grad)
+ if hasattr(model,'margin_warm'):
+ model.margin_warm.update_step_range(epoch_iters)
+ if utils.is_main_training():
+ print(model)
+ p1=sum(p.numel() for p in model.parameters())
+ script_model = copy.deepcopy(model)
+ script_model.loss=None
+ p2 = sum(p.numel() for p in script_model.parameters())
+ logger.info("model params w/o proj layer: {} / {} .".format(p1,p2))
+ script_model = torch.jit.script(script_model)
+ script_model.save(os.path.join(model_dir, 'init.zip'))
+ logger.info("The number of steps per epoch is about {}.".format(epoch_iters))
+ logger.info("Define optimizer and lr_scheduler.")
+ del script_model
optimizer = optim.get_optimizer(model, optimizer_params)
lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params)
@@ -404,9 +433,10 @@
# Record params to model_dir
- if utils.is_main_training():utils.write_list_to_file([egs_params, model_params, optimizer_params,
- lr_scheduler_params], model_dir+'/config/params.dict')
+ if utils.is_main_training():
+ utils.write_list_to_file([egs_params, model_params, optimizer_params,
+ lr_scheduler_params], model_dir+'/config/params.dict',yml=True)
if utils.is_main_training():logger.info("Init a simple trainer.")
# Package(Elements:dict, Params:dict}. It is a key parameter's package to trainer and model_dir/config/.
@@ -416,13 +446,16 @@
"skip_nan_batch":skip_nan_batch, "benchmark":benchmark, "suffix":suffix, "compute_batch_num_valid":compute_batch_num_valid,
"report_interval_iters":report_interval_iters, "record_file":"train.csv"})
- trainer = trainer.SimpleTrainer(package)
+ train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer
+ execuer = train_exec.SimpleTrainer(package)
if run_lr_finder:
- trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98)
+ execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98)
endstage = 3 # Do not start extractor.
- trainer.run()
+ execuer.run()
# Plan to use del to avoid memeory account after training done and continue to execute stage 4.
# But it dose not work and is still a problem.
@@ -449,7 +482,7 @@
gpu_id = ""
sleep_time = 10
+ max_chunk = 10000
# Run a batch extracting process.
@@ -485,12 +518,12 @@
# with python's threads to extract xvectors directly, but the shell script is more convenient.
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh "
" --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
- " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}"
+ " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} "
" --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} "
"{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj,
use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
- model_dir=model_dir, datadir=datadir, outdir=outdir))
+ max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir))
except BaseException as e:
if not isinstance(e, KeyboardInterrupt):
diff --git a/pytorch/launcher/runTransformerXvector.py b/pytorch/launcher/runTransformerXvector.py
new file mode 100755
index 0000000..decb599
--- /dev/null
+++ b/pytorch/launcher/runTransformerXvector.py
@@ -0,0 +1,558 @@
+# -*- coding:utf-8 -*-
+# Copyright xmuspeech (Author: Leo 2021-10-05)
+# Apache 2.0
+# Only support "nccl" backend multi-gpu training.
+import sys
+import os
+import logging
+import argparse
+import traceback
+import time
+import yaml
+import copy
+import math
+import numpy as np
+import torch
+sys.path.insert(0, 'subtools/pytorch')
+from libs.support.logging_stdout import patch_logging_stream
+import libs.support.utils as utils
+import libs.support.kaldi_common as kaldi_common
+import libs.training.trainer_online as trainer
+import libs.training.trainer_online_sam as trainer_sam
+import libs.training.lr_scheduler_online as learn_rate_scheduler
+import libs.training.optim as optim
+import libs.egs.egs_online as egs
+"""A launcher script with python version (Snowdar's launcher to do experiments w.r.t snowdar-xvector.py).
+Python version is given (rather than Shell) to have more freedom, such as reducing the limitations on parameters when transferring
+them to python from shell.
+Note: this launcher does not include dataset preparation, augmentation, back-end scoring, etc.
+ 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages.
+ 2.An on-the-fly feature extraction mode.
+How to modify this launcher:
+ 1.Prepare your kaldi format dataset and model.py (model blueprint);
+ 2.Give the path of dataset, model blueprint, etc. in main parameters field;
+ 3.Change the imported name of model in 'model = model_py.model_name(...)' w.r.t model.py by yourself;
+ 4.Modify any training parameters what you want to change (epochs, optimizer and lr_scheduler etc.);
+ 5.Modify parameters of extracting in stage 4 w.r.t your own training config;
+ 6.Run this launcher.
+Conclusion: prepare -> config -> run.
+How to run this launcher to train a model:
+ 1.For CPU-based training case. The key option is --use-gpu.
+ python3 launcher.py --use-gpu=false
+ 2.For single-GPU training case (Default).
+ python3 launcher.py
+ 3.For DDP-based multi-GPU training case. Note --nproc_per_node is equal to number of gpu id in --gpu-id.
+ python3 -m torch.distributed.launch --nproc_per_node=2 launcher.py --gpu-id=0,1
+ 4.For Horovod-based multi-GPU training case. Note --np is equal to number of gpu id in --gpu-id.
+ horovodrun -np 2 launcher.py --gpu-id=0,1
+ 5.For all of above, a runLauncher.sh script has been created to launch launcher.py conveniently.
+ The key option to use single or multiple GPU is --gpu-id.
+ The subtools/runPytorchLauncher.sh is a symbolic link to subtools/pytorch/launcher/runLauncher.sh,
+ so just use it.
+ [ CPU ]
+ subtools/runPytorchLauncher.sh launcher.py --use-gpu=false
+ [ Single-GPU ]
+ (1) Auto-select GPU device
+ subtools/runPytorchLauncher.sh launcher.py
+ (2) Specify GPU device
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2
+ [ Multi-GPU ]
+ (1) Use DDP solution (Default).
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="ddp"
+ (2) Use Horovod solution.
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="horovod"
+If you have any other requirements, you can modify the code anywhere.
+For more details of multi-GPU development, see subtools/README.md.
+# Logger
+# Logger
+logger = logging.getLogger('libs')
+handler = logging.StreamHandler()
+formatter = logging.Formatter("%(asctime)s [ %(pathname)s:%(lineno)s - "
+ "%(funcName)s - %(levelname)s ]\n#### %(message)s")
+# Parser: add this parser to run launcher with some frequent options (really for convenience).
+parser = argparse.ArgumentParser(
+ description="""Train xvector framework with pytorch.""",
+ formatter_class=argparse.RawTextHelpFormatter,
+ conflict_handler='resolve')
+parser.add_argument("--stage", type=int, default=3,
+ help="The stage to control the start of training epoch (default 3).\n"
+ " stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n"
+ " stage 1: remove utts (preprocess_raw_wav_egs.sh).\n"
+ " stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n"
+ " stage 2.2: Prepare speech augment csv files.\n"
+ " stage 3: Training.\n"
+ " stage 4: extract xvector.")
+parser.add_argument("--endstage", type=int, default=4,
+ help="The endstage to control the endstart of training epoch (default 4).")
+parser.add_argument("--train-stage", type=int, default=-1,
+ help="The stage to control the start of training epoch (default -1).\n"
+ " -1 -> creating model_dir.\n"
+ " 0 -> model initialization (e.g. transfer learning).\n"
+ " >0 -> recovering training.")
+parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Clear the dir generated by preprocess.")
+parser.add_argument("--pre-rirmusan", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Prepare the openrir and musan dataset for adding reverb and noises.")
+parser.add_argument('--use-amp', type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help='Use automatic mixed precision training')
+parser.add_argument("--skip-nan-batch", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Whether skip optimizer stepping when the gradient has nan/inf values.")
+parser.add_argument("--accum-grad", type=int, default=1,
+ help="Using accumulate grad.")
+parser.add_argument("--multi-gpu-solution", type=str, default="ddp",
+ choices=["ddp"],
+ help="if number of gpu_id > 1, this option will be valid to init a multi-gpu solution.")
+parser.add_argument("--use-gpu", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Use GPU or not.")
+parser.add_argument("--gpu-id", type=str, default="",
+ help="If NULL, then it will be auto-specified.")
+parser.add_argument("--benchmark", type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help="If true, save training time but require a little more gpu-memory.")
+parser.add_argument("--run-lr-finder", type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help="If true, run lr finder rather than training.")
+parser.add_argument("--sleep", type=int, default=0,
+ help="The waiting time to launch a launcher.")
+parser.add_argument("--local_rank", type=int, default=0,
+ help="Do not delete it when using DDP-based multi-GPU training.\n"
+ "It is important for torch.distributed.launch.")
+args = parser.parse_args()
+######################################################### PARAMS ########################################################
+# Control options
+stage = max(0, args.stage)
+endstage = min(4, args.endstage)
+train_stage = max(-1, args.train_stage)
+# Preprocess options
+force_clear = args.force_clear
+preprocess_nj = 20
+whole_utt = True
+random_segment = False
+seg_dur = 3.015
+amp_th = 50
+de_silence = False
+vad_wav_savdir = "export/yourpath"
+min_len = 2.0
+max_len = 1000.0
+limit_utts = 8
+valid_split_type = "--total-spk" # --total-spk or --default
+valid_utts = 2048
+valid_chunk_num = 2
+valid_fix_chunk_num = False
+data_type = 'shard'
+num_utts_per_shard = 2000
+shard_dir = '/export/yourpath/voxceleb_dev_whole'
+# Prepare speech augmentation csv files.
+pre_rirmusan = args.pre_rirmusan # whether skip this stage.
+openrir_folder = "export/path" # where contains RIRS_NOISES folder.
+musan_folder = "export/path" # where contains musan folder.
+csv_aug_folder = "exp/aug_csv3" # csv file location.
+savewav_folder = "/export/yourpath/speech_aug_3015" # save the noise seg into SSD.
+max_noise_len = seg_dur # The max dur of noise.
+# Training options
+use_amp = args.use_amp
+skip_nan_batch = args.skip_nan_batch
+accum_grad = args.accum_grad
+use_gpu = args.use_gpu # Default true.
+# If true, save much training time but require a little more gpu-memory.
+benchmark = args.benchmark
+gpu_id = args.gpu_id # If NULL, then it will be auto-specified.
+run_lr_finder = args.run_lr_finder
+# Define model_params by model_blueprint w.r.t your model's __init__(model_params).
+model_params = {
+ "wenet_transfer":True, # change wenet conformer keys when transferring
+ "training": True, "extracted_embedding": "near",
+ "embd_dim": 256, # xvector dim
+ "transformer_type": "conformer", # [conformer, transformer, re_conformer]
+ "transformer_params": {
+ "attention_dim": 256,
+ "attention_heads": 4,
+ "num_blocks": 6,
+ "combiner_type": "norm", # [norm, mfa, random_frame, random_layer]
+ "aux_layer_period": 2, # aux: select inner layer for combiner.
+ "aux_layer_start": 3,
+ "dropout_rate": 0.1,
+ "layer_dropout":0.,
+ "linear_units": 2048,
+ "positional_dropout_rate": 0.1,
+ "attention_dropout_rate": 0.1,
+ "attention_norm_args": {
+ "norm_method": "softmax_plus", # [softmax, relu_plus, softmax_plus]
+ "train_len": 300,
+ },
+ "input_layer": "conv2d", # [linear, conv2d2, conv2d, conv2d6, conv2d8]
+ "cnn_module_kernel": 15, # for conformer
+ "pos_enc_type": "rot_pos", # [abs_pos, no_pos, rot_pos, rel_pos]
+ "convfnn_blocks": 0 # use conv type ffn in head blocks.
+ },
+ "pooling": "ecpa-attentive",
+ "pooling_params": {
+ "hidden_size": 128,
+ "time_attention": False,
+ "stddev": True,
+ },
+ "fc1": False,
+ "fc1_params": {
+ "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "bn_params": {"momentum": 0.5, "affine": False, "track_running_stats": True}},
+ "fc2_params": {
+ "nonlinearity": '', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "ln_replace": True, # replace BN with LN
+ "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}},
+ "margin_loss": True,
+ "margin_loss_params": {
+ "method": "aam", "m": 0.2, "feature_normalize": True,
+ "s": 30, "mhe_loss": False, "mhe_w": 0.01},
+ # margin_warm is supported to smoothly increase margin of softmax.
+ "use_step": False,
+ "step_params": {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0},
+ "T": None,
+ "m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4,
+ "s": False, "s_tuple": (30, 12), "s_list": None,
+ "t": False, "t_tuple": (0.5, 1.2),
+ "p": False, "p_tuple": (0.5, 0.1)}
+optimizer_params = {
+ "name": "adamW",
+ "learn_rate": 0.00025,
+ "beta1": 0.9,
+ "beta2": 0.999,
+ "beta3": 0.999,
+ # Should be large for decoupled weight decay (adamW) and small for L2 regularization (sgd, adam).
+ "weight_decay": 5e-2,
+ "lookahead.k": 5,
+ # 0 means not using lookahead and if used, suggest to set it as 0.5.
+ "lookahead.alpha": 0,
+ "gc": False, # If true, use gradient centralization.
+ "nesterov": False, # for sgd
+ "sam": False, # suggest true when apply ASR transferring.
+ "sam.rho": 2.0, # 2.0 for adaptive
+ "sam.adaptive": True,
+ # "custwd_dict":{
+ # "train_len":0,
+ # "bias":0
+ # }
+lr_scheduler_params = {
+ "name": "1cycle",
+ "1cycle.learn_rate":0.001,
+ "1cycle.warmup_steps":5000,
+ "1cycle.epochs": 50,
+ "1cycle.steps_per_epoch": 2200,
+ "1cycle.div_factor":1000.0, # initial_lr = max_lr/div_factor
+ "1cycle.final_div_factor":1000.0, # min_lr = initial_lr/final_div_factor
+ "1cycle.anneal_strategy":'cos', # ["cos", "linear"]
+ "1cycle.cycle_momentum":False,
+ "noam.warmup_steps": 5000,
+ "noam.step_decay": True,
+ "noam.step_size": 8800, # suggest 4 epochs
+ "noam.step_rate": 0.5,
+ "cyclic.max_lr": 1e-3,
+ "cyclic.base_lr": 1e-8,
+ "cyclic.step_size_up": 22000,
+ "cyclic.mode": 'triangular2',
+epochs = 51 # Total epochs to train. It is important.
+compute_batch_num_valid = 10
+# About validation computation and loss reporting. If report_times_every_epoch is not None,
+report_interval_iters = 500
+# then compute report_interval_iters by report_times_every_epoch.
+suffix = "params" # Used in saved model file.
+# Other options
+exist_model = "" # Use it in transfer learning.
+# Main params
+traindata = "data/raw/voxceleb2_dev"
+traindata_for_egs = "data/raw/voxceleb2_dev"
+egs_dir = "exp/egs/voxceleb2_dev_whole"
+egs_conf = "subtools/conf/egs_pre_sre_transformer.yaml"
+model_blueprint = "subtools/pytorch/model/transformer_xvector.py"
+model_dir = "exp/conformer_6L256D4H_4sub"
+######################################################### START #########################################################
+# Set seed
+# Init environment
+# It is used for multi-gpu training if used (number of gpu-id > 1).
+# And it will do nothing for single-GPU training.
+utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution)
+# Set sleep time for a rest
+# Use it to run a launcher with a countdown function when there is no free GPU memory
+# but you really want to go to bed and know when the GPU memory will be free.
+if args.sleep > 0:
+ time.sleep(args.sleep)
+# Auto-config params
+# If multi-GPU used, it will auto-scale learning rate by multiplying number of processes.
+optimizer_params["learn_rate"] = utils.auto_scale_lr(
+ optimizer_params["learn_rate"])
+# It is used for model.step() defined in model blueprint.
+if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
+ model_params["step_params"]["T"]=(lr_scheduler_params["warmR.T_max"], lr_scheduler_params["warmR.T_mult"])
+# Preprocess
+if stage <= 2 and endstage >= 0 and utils.is_main_training():
+ # Only a limited set of options is exposed here because passing all of them through is inconvenient.
+ # For full control, run this shell script separately beforehand and then continue with this launcher.
+ kaldi_common.execute_command("bash subtools/pytorch/pipeline/preprocess_wav_egs.sh "
+ "--stage {stage} --endstage {endstage} --nj {nj} --whole-utt {whole_utt} --random-segment {random_segment} "
+ "--seg-dur {seg_dur} --amp-th {amp_th} --de-silence {de_silence} --vad-wav-savdir {vad_wav_savdir} "
+ "--min-len {min_len} --max-len {max_len} --limit-utts {limit_utts} "
+ "--valid-split-type {valid_split_type} --valid-num-utts {valid_utts} --valid-chunk-num {valid_chunk_num} "
+ "--valid-fix-chunk-num {valid_fix_chunk_num} --force-clear {force_clear} "
+ "--pre-rirmusan {pre_rirmusan} --openrir-folder {openrir_folder} --musan-folder {musan_folder} "
+ "--csv-aug-folder {csv_aug_folder} --savewav-folder {savewav_folder} --max-noise-len {max_noise_len} "
+ "--data-type {data_type} --shard-dir {shard_dir} --num-utts-per-shard {num_utts_per_shard} "
+ "{traindata} {traindata_for_egs} {egs_dir}".format(stage=stage, endstage=endstage, nj=preprocess_nj,
+ whole_utt=str(whole_utt).lower(), random_segment=str(random_segment).lower(), seg_dur=seg_dur,
+ amp_th=amp_th, de_silence=str(de_silence).lower(), vad_wav_savdir=vad_wav_savdir,
+ min_len=min_len, max_len=max_len, limit_utts=limit_utts, valid_split_type=valid_split_type,
+ valid_utts=valid_utts, valid_chunk_num=valid_chunk_num, valid_fix_chunk_num=str(valid_fix_chunk_num).lower(),
+ force_clear=str(force_clear).lower(), pre_rirmusan=str(pre_rirmusan).lower(), openrir_folder=openrir_folder,
+ musan_folder=musan_folder, csv_aug_folder=csv_aug_folder, savewav_folder=savewav_folder, max_noise_len=max_noise_len,
+ data_type=data_type, num_utts_per_shard=num_utts_per_shard, shard_dir=shard_dir,
+ traindata=traindata, traindata_for_egs=traindata_for_egs, egs_dir=egs_dir))
+# Train model
+if stage <= 3 <= endstage:
+ if utils.is_main_training():
+ logger.info("Get model_blueprint from model directory.")
+ # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path.
+ model_blueprint = utils.create_model_dir(
+ model_dir, model_blueprint, stage=train_stage)
+ if utils.is_main_training():
+ logger.info("Load egs to bunch.")
+ # The dict [info] contains feat_dim and num_targets
+ with open(egs_conf, 'r') as fin:
+ egs_params = yaml.load(fin, Loader=yaml.FullLoader)
+ egs_params['dataset_conf']['csv_aug_folder'] = csv_aug_folder
+ bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params)
+ feat_extraction_config = copy.deepcopy(
+ egs_params['dataset_conf']['feature_extraction_conf'])
+ feat_extraction_config['kaldi_featset']['dither'] = 0.0
+ feat_config_path = os.path.join(model_dir, 'config', 'feat_conf.yaml')
+ if utils.is_main_training():
+ with open(feat_config_path, 'w') as fou:
+ yaml.dump(feat_extraction_config, fou)
+ if utils.is_main_training():
+ logger.info("Create model from model blueprint.")
+ # Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and
+ # I don't want to change anything about extracting script when the model.py is changed.
+ model_py = utils.create_model_from_py(model_blueprint)
+ model = model_py.TransformerXvector(
+ info["feat_dim"], info["num_targets"], **model_params)
+ # If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important
+ # to make performance stable.
+ # It will change nothing for single-GPU training.
+ model = utils.convert_synchronized_batchnorm(model)
+ # print(model)
+ epoch_iters = (info['epoch_iters']//accum_grad)
+ if hasattr(model,'margin_warm'):
+ model.margin_warm.update_step_range(epoch_iters)
+ # print(sum(p.numel() for p in model.parameters()))
+ # sys.exit()
+ if utils.is_main_training():
+ print(model)
+ p1=sum(p.numel() for p in model.parameters())
+ script_model = copy.deepcopy(model)
+ script_model.loss=None
+ p2 = sum(p.numel() for p in script_model.parameters())
+ logger.info("model params w/o proj layer: {} / {} .".format(p1,p2))
+ script_model = torch.jit.script(script_model)
+ script_model.save(os.path.join(model_dir, 'init.zip'))
+ logger.info("The number of steps per epoch is about {}.".format(epoch_iters))
+ logger.info("Define optimizer and lr_scheduler.")
+ del script_model
+ optimizer = optim.get_optimizer(model, optimizer_params)
+ lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(
+ optimizer, lr_scheduler_params)
+ # Record params to model_dir
+ if utils.is_main_training():
+ utils.write_list_to_file([egs_params, model_params, optimizer_params,
+ lr_scheduler_params], model_dir+'/config/params.dict',yml=True)
+ if utils.is_main_training():
+ logger.info("Init a simple trainer.")
+ # Package (Elements: dict, Params: dict). It is the key parameter package for the trainer and model_dir/config/.
+ package = ({"data": bunch, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler},
+ {"model_dir": model_dir, "model_blueprint": model_blueprint, "exist_model": exist_model, "accum_grad": accum_grad,
+ "start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp,
+ "skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid,
+ "report_interval_iters": report_interval_iters, "record_file": "train.csv"})
+ train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer
+ execuer = train_exec.SimpleTrainer(package)
+ if run_lr_finder:
+ execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8,
+ final_lr=10., num_iters=2000, beta=0.98)
+ endstage = 3 # Do not start extractor.
+ else:
+ execuer.run()
+# Extract xvector
+if stage <= 4 <= endstage and utils.is_main_training():
+ # There are some params for xvector extracting.
+ data_root = "data" # It contains all dataset just like Kaldi recipe.
+ prefix = "raw" # For to_extracted_data.
+ data_type_emb = "raw" # shard or raw or kaldi.
+ de_silence = False
+ amp_th = 50
+ to_extracted_positions = ["near"] # Define this w.r.t model_blueprint.
+ # All dataset should be in dataroot/prefix.
+ to_extracted_data = ["voxceleb1"]
+ # It is model's name, such as 10.params or final.params (suffix is w.r.t package).
+ to_extracted_epochs = ["50"]
+ nj = 8
+ force = True
+ use_gpu = True
+ gpu_id = ""
+ sleep_time = 10
+ feat_config = "feat_conf.yaml"
+ max_chunk = 300
+ # Run a batch extracting process.
+ try:
+ for position in to_extracted_positions:
+ # Generate the extracting config from nnet config where
+ # which position to extract depends on the 'extracted_embedding' parameter of model_creation (by my design).
+ model_blueprint, model_creation = utils.read_nnet_config(
+ "{0}/config/nnet.config".format(model_dir))
+ # To save memory without loading some independent components.
+ model_creation = model_creation.replace(
+ "training=True", "training=False")
+ model_creation = model_creation.replace(
+ model_params["extracted_embedding"], position)
+ extract_config = "{0}.extract.config".format(position)
+ utils.write_nnet_config(
+ model_blueprint, model_creation, "{0}/config/{1}".format(model_dir, extract_config))
+ for epoch in to_extracted_epochs:
+ model_file = "{0}.{1}".format(epoch, suffix)
+ point_name = "{0}_epoch_{1}".format(position, epoch)
+ # If the trainer runs in a background thread (not supported yet) or this launcher is run separately with stage=4
+ # (i.e. as another process), this polling loop lets extraction start as soon as the model file exists (but requires more GPU memory).
+ model_path = "{0}/{1}".format(model_dir, model_file)
+ while True:
+ if os.path.exists(model_path):
+ break
+ else:
+ time.sleep(sleep_time)
+ for data in to_extracted_data:
+ datadir = "{0}/{1}/{2}".format(data_root, prefix, data)
+ outdir = "{0}/{1}/{2}".format(model_dir, point_name, data)
+ # Use a well-optimized shell script (with multi-processes) to extract xvectors.
+ # Another way: use subtools/splitDataByLength.sh and subtools/pytorch/pipeline/onestep/extract_embeddings.py
+ # with python's threads to extract xvectors directly, but the shell script is more convenient.
+ kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh "
+ " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
+ " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} "
+ " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} "
+ "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj,
+ use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
+ feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th,
+ max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir))
+ except BaseException as e:
+ if not isinstance(e, KeyboardInterrupt):
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/pytorch/launcher/runTransformerXvector_LM.py b/pytorch/launcher/runTransformerXvector_LM.py
new file mode 100755
index 0000000..e767339
--- /dev/null
+++ b/pytorch/launcher/runTransformerXvector_LM.py
@@ -0,0 +1,508 @@
+# -*- coding:utf-8 -*-
+# Copyright xmuspeech (Author: Leo 2021-10-05)
+# Apache 2.0
+# Only support "nccl" backend multi-gpu training.
+import sys
+import os
+import logging
+import argparse
+import traceback
+import time
+import yaml
+import copy
+import math
+import numpy as np
+import torch
+sys.path.insert(0, 'subtools/pytorch')
+from pipeline.onestep.prepare_speechaug_csv import prepare_speech_aug
+from libs.support.logging_stdout import patch_logging_stream
+import libs.support.utils as utils
+import libs.support.kaldi_common as kaldi_common
+import libs.training.trainer_online as trainer
+import libs.training.trainer_online_sam as trainer_sam
+import libs.training.lr_scheduler_online as learn_rate_scheduler
+import libs.training.optim as optim
+import libs.egs.egs_online as egs
+"""A launcher script with python version (Snowdar's launcher to do experiments w.r.t snowdar-xvector.py).
+A Python version is given (rather than Shell) to have more freedom, such as reducing the limitations of transferring parameters
+to python from shell.
+Note, this script is for Large-Margin Finetuning, we assume that you have preprocessed the dataset and obtained a trained model.
+ `The IDLAB VoxCeleb Speaker Recognition Challenge 2020 System Description.`
+ https://arxiv.org/pdf/2010.12468.pdf
+How to modify this launcher:
+ 1.Prepare your kaldi format dataset and model.py (model blueprint);
+ 2.Give the path of dataset, model blueprint, etc. in main parameters field;
+ 3.Change the imported name of model in 'model = model_py.model_name(...)' w.r.t model.py by yourself;
+ 4.Modify any training parameters what you want to change (epochs, optimizer and lr_scheduler etc.);
+ 5.Modify parameters of extracting in stage 4 w.r.t your own training config;
+ 6.Run this launcher.
+Conclusion: prepare -> config -> run.
+How to run this launcher to train a model:
+ 1.For CPU-based training case. The key option is --use-gpu.
+ python3 launcher.py --use-gpu=false
+ 2.For single-GPU training case (Default).
+ python3 launcher.py
+ 3.For DDP-based multi-GPU training case. Note --nproc_per_node is equal to number of gpu id in --gpu-id.
+ python3 -m torch.distributed.launch --nproc_per_node=2 launcher.py --gpu-id=0,1
+ 4.For Horovod-based multi-GPU training case. Note --np is equal to number of gpu id in --gpu-id.
+ horovodrun -np 2 launcher.py --gpu-id=0,1
+ 5.For all of above, a runLauncher.sh script has been created to launch launcher.py conveniently.
+ The key option to use single or multiple GPU is --gpu-id.
+ The subtools/runPytorchLauncher.sh is a symbolic link to subtools/pytorch/launcher/runLauncher.sh,
+ so just use it.
+ [ CPU ]
+ subtools/runPytorchLauncher.sh launcher.py --use-gpu=false
+ [ Single-GPU ]
+ (1) Auto-select GPU device
+ subtools/runPytorchLauncher.sh launcher.py
+ (2) Specify GPU device
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2
+ [ Multi-GPU ]
+ (1) Use DDP solution (Default).
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="ddp"
+ (2) Use Horovod solution.
+ subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="horovod"
+If you have any other requirements, you could modify the codes in anywhere.
+For more details of multi-GPU development, see subtools/README.md.
+# Logger
+# Logger
+logger = logging.getLogger('libs')
+handler = logging.StreamHandler()
+formatter = logging.Formatter("%(asctime)s [ %(pathname)s:%(lineno)s - "
+ "%(funcName)s - %(levelname)s ]\n#### %(message)s")
+# Parser: add this parser to run launcher with some frequent options (really for convenience).
+parser = argparse.ArgumentParser(
+ description="""Train xvector framework with pytorch.""",
+ formatter_class=argparse.RawTextHelpFormatter,
+ conflict_handler='resolve')
+parser.add_argument("--stage", type=int, default=3,
+ help="The stage to control the start of training epoch (default 3).\n"
+ " stage 3: Training.\n"
+ " stage 4: extract xvector.")
+parser.add_argument("--endstage", type=int, default=4,
+ help="The endstage to control the endstart of training epoch (default 4).")
+parser.add_argument("--train-stage", type=int, default=-1,
+ help="The stage to control the start of training epoch (default -1).\n"
+ " -1 -> creating model_dir.\n"
+ " 0 -> model initialization (e.g. transfer learning).\n"
+ " >0 -> recovering training.")
+parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Clear the dir generated by preprocess.")
+parser.add_argument("--pre-rirmusan", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Prepare the openrir and musan dataset for adding reverb and noises.")
+parser.add_argument('--use-amp', type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help='Use automatic mixed precision training')
+parser.add_argument("--skip-nan-batch", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Whether skip optimizer stepping when the gradient has nan/inf values.")
+parser.add_argument("--accum-grad", type=int, default=1,
+ help="Using accumulate grad.")
+parser.add_argument("--multi-gpu-solution", type=str, default="ddp",
+ choices=["ddp"],
+ help="if number of gpu_id > 1, this option will be valid to init a multi-gpu solution.")
+parser.add_argument("--use-gpu", type=str, action=kaldi_common.StrToBoolAction,
+ default=True, choices=["true", "false"],
+ help="Use GPU or not.")
+parser.add_argument("--gpu-id", type=str, default="",
+ help="If NULL, then it will be auto-specified.")
+parser.add_argument("--benchmark", type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help="If true, save training time but require a little more gpu-memory.")
+parser.add_argument("--run-lr-finder", type=str, action=kaldi_common.StrToBoolAction,
+ default=False, choices=["true", "false"],
+ help="If true, run lr finder rather than training.")
+parser.add_argument("--sleep", type=int, default=0,
+ help="The waiting time to launch a launcher.")
+parser.add_argument("--local_rank", type=int, default=0,
+ help="Do not delete it when using DDP-based multi-GPU training.\n"
+ "It is important for torch.distributed.launch.")
+args = parser.parse_args()
+######################################################### PARAMS ########################################################
+# Control options
+stage = max(3, args.stage)
+endstage = min(4, args.endstage)
+train_stage = max(-1, args.train_stage)
+# Prepare speech augmentation csv files.
+pre_rirmusan = args.pre_rirmusan # whether skip this stage.
+openrir_folder = "export/path" # where contains RIRS_NOISES folder.
+musan_folder = "export/path" # where contains musan folder.
+csv_aug_folder = "exp/aug_csv6" # csv file location.
+savewav_folder = "/export/yourpath/speech_aug_6015" # save the noise seg into SSD.
+max_noise_len = 6.015 # The max dur of noise.
+# Training options
+use_amp = args.use_amp
+skip_nan_batch = args.skip_nan_batch
+accum_grad = args.accum_grad
+use_gpu = args.use_gpu # Default true.
+# If true, save much training time but require a little more gpu-memory.
+benchmark = args.benchmark
+gpu_id = args.gpu_id # If NULL, then it will be auto-specified.
+run_lr_finder = args.run_lr_finder
+# Define model_params by model_blueprint w.r.t your model's __init__(model_params).
+model_params = {
+ "wenet_transfer":True, # change wenet conformer keys when transferring
+ "training": True, "extracted_embedding": "near",
+ "embd_dim": 256, # xvector dim
+ "transformer_type": "conformer", # [conformer, transformer, re_conformer]
+ "transformer_params": {
+ "attention_dim": 256,
+ "attention_heads": 4,
+ "num_blocks": 6,
+ "combiner_type": "norm", # [norm, mfa, random_frame, random_layer]
+ "aux_layer_period": 2, # aux: select inner layer for combiner.
+ "aux_layer_start": 3,
+ "dropout_rate": 0.1,
+ "layer_dropout":0.,
+ "linear_units": 2048,
+ "positional_dropout_rate": 0.1,
+ "attention_dropout_rate": 0.1,
+ "attention_norm_args": {
+ "norm_method": "softmax_plus", # [softmax, relu_plus, softmax_plus]
+ "train_len": 300,
+ },
+ "input_layer": "conv2d", # [linear, conv2d2, conv2d, conv2d6, conv2d8]
+ "cnn_module_kernel": 15, # for conformer
+ "pos_enc_type": "rot_pos", # [abs_pos, no_pos, rot_pos, rel_pos]
+ "convfnn_blocks": 0 # use conv type ffn in head blocks.
+ },
+ "pooling": "ecpa-attentive",
+ "pooling_params": {
+ "hidden_size": 128,
+ "time_attention": False,
+ "stddev": True,
+ },
+ "fc1": False,
+ "fc1_params": {
+ "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "bn_params": {"momentum": 0.5, "affine": False, "track_running_stats": True}},
+ "fc2_params": {
+ "nonlinearity": '', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "ln_replace": True, # replace BN with LN
+ "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}},
+ "margin_loss": True,
+ "margin_loss_params": {
+ "method": "aam", "m": 0.5, "feature_normalize": True,
+ "s": 30, "mhe_loss": False, "mhe_w": 0.01},
+ "use_step": True,
+ "step_params": {
+ "margin_warm":True,
+ "margin_warm_conf":{"start_epoch":1,"end_epoch":2,"offset_margin":-0.3,"init_lambda":0.5},
+ "T": None,
+ "m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4,
+ "s": False, "s_tuple": (30, 12), "s_list": None,
+ "t": False, "t_tuple": (0.5, 1.2),
+ "p": False, "p_tuple": (0.5, 0.1)}
+optimizer_params = {
+ "name": "sgd",
+ "learn_rate": 0.0001,
+ "beta1": 0.9,
+ "beta2": 0.999,
+ "beta3": 0.999,
+ # Should be large for decoupled weight decay (adamW) and small for L2 regularization (sgd, adam).
+ "weight_decay": 1e-5,
+ "lookahead.k": 5,
+ # 0 means not using lookahead and if used, suggest to set it as 0.5.
+ "lookahead.alpha": 0,
+ "gc": False, # If true, use gradient centralization.
+ "nesterov": False, # for sgd
+ "sam": False,
+ "sam.rho": 2.0, # 2.0 for adaptive
+ "sam.adaptive": True,
+ "custwd_dict":{
+ "train_len":0,
+ "loss":1e-4,
+ # "bias":0
+ }
+lr_scheduler_params = {
+ "name": "1cycle",
+ "1cycle.learn_rate":4e-4,
+ "1cycle.warmup_steps":2000,
+ "1cycle.epochs": 4,
+ "1cycle.steps_per_epoch": 4300,
+ "1cycle.div_factor":10000.0, # initial_lr = max_lr/div_factor
+ "1cycle.final_div_factor":100.0, # min_lr = initial_lr/final_div_factor
+ "1cycle.anneal_strategy":'cos', # ["cos", "linear"]
+ "1cycle.cycle_momentum":False,
+ "noam.warmup_steps": 2000,
+ "noam.step_decay": True,
+ "noam.step_size": 4300, # suggest 4 epochs
+ "noam.step_rate": 0.5,
+ "cyclic.max_lr": 1e-3,
+ "cyclic.base_lr": 1e-8,
+ "cyclic.step_size_up": 22000,
+ "cyclic.mode": 'triangular2',
+epochs = 5 # Total epochs to train. It is important.
+compute_batch_num_valid = 10
+# About validation computation and loss reporting. If report_times_every_epoch is not None,
+report_interval_iters = 500
+# then compute report_interval_iters by report_times_every_epoch.
+suffix = "params" # Used in saved model file.
+# Other options
+exist_model = "exp/conformer_6L256D4H_4sub/50.params" # Use it in transfer learning.
+# Main params
+egs_dir = "exp/egs/voxceleb2_dev_whole"
+egs_conf = "subtools/conf/egs_pre_sre_lm.yaml"
+model_blueprint = "subtools/pytorch/model/transformer_xvector.py"
+model_dir = "exp/conformer_6L256D4H_4sub_lm"
+######################################################### START #########################################################
+# Set seed
+# Init environment
+# It is used for multi-gpu training if used (number of gpu-id > 1).
+# And it will do nothing for single-GPU training.
+utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution)
+# Set sleep time for a rest
+# Use it to run a launcher with a countdown function when there is no free GPU memory
+# but you really want to go to bed and know when the GPU memory will be free.
+if args.sleep > 0:
+ time.sleep(args.sleep)
+# Auto-config params
+# If multi-GPU used, it will auto-scale learning rate by multiplying number of processes.
+optimizer_params["learn_rate"] = utils.auto_scale_lr(
+ optimizer_params["learn_rate"])
+# It is used for model.step() defined in model blueprint.
+if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
+ model_params["step_params"]["T"]=(lr_scheduler_params["warmR.T_max"], lr_scheduler_params["warmR.T_mult"])
+if pre_rirmusan:
+ prepare_speech_aug(openrir_folder, musan_folder, csv_folder=csv_aug_folder, savewav_folder=savewav_folder, max_noise_len=max_noise_len, force_clear=True)
+# Train model
+if stage <= 3 <= endstage:
+ if utils.is_main_training():
+ logger.info("Get model_blueprint from model directory.")
+ # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path.
+ model_blueprint = utils.create_model_dir(
+ model_dir, model_blueprint, stage=train_stage)
+ if utils.is_main_training():
+ logger.info("Load egs to bunch.")
+ # The dict [info] contains feat_dim and num_targets
+ with open(egs_conf, 'r') as fin:
+ egs_params = yaml.load(fin, Loader=yaml.FullLoader)
+ egs_params['dataset_conf']['csv_aug_folder'] = csv_aug_folder
+ bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params)
+ feat_extraction_config = copy.deepcopy(
+ egs_params['dataset_conf']['feature_extraction_conf'])
+ feat_extraction_config['kaldi_featset']['dither'] = 0.0
+ feat_config_path = os.path.join(model_dir, 'config', 'feat_conf.yaml')
+ if utils.is_main_training():
+ with open(feat_config_path, 'w') as fou:
+ yaml.dump(feat_extraction_config, fou)
+ if utils.is_main_training():
+ logger.info("Create model from model blueprint.")
+ # Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and
+ # I don't want to change anything about extracting script when the model.py is changed.
+ model_py = utils.create_model_from_py(model_blueprint)
+ model = model_py.TransformerXvector(
+ info["feat_dim"], info["num_targets"], **model_params)
+ # If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important
+ # to make performance stable.
+ # It will change nothing for single-GPU training.
+ model = utils.convert_synchronized_batchnorm(model)
+ # print(model)
+ epoch_iters = (info['epoch_iters']//accum_grad)
+ if hasattr(model,'margin_warm'):
+ model.margin_warm.update_step_range(epoch_iters)
+ # print(sum(p.numel() for p in model.parameters()))
+ # sys.exit()
+ if utils.is_main_training():
+ print(model)
+ p1=sum(p.numel() for p in model.parameters())
+ script_model = copy.deepcopy(model)
+ script_model.loss=None
+ p2 = sum(p.numel() for p in script_model.parameters())
+ logger.info("model params w/o proj layer: {} / {} .".format(p1,p2))
+ script_model = torch.jit.script(script_model)
+ script_model.save(os.path.join(model_dir, 'init.zip'))
+ logger.info("The number of steps per epoch is about {}.".format(epoch_iters))
+ logger.info("Define optimizer and lr_scheduler.")
+ del script_model
+ optimizer = optim.get_optimizer(model, optimizer_params)
+ lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(
+ optimizer, lr_scheduler_params)
+ # Record params to model_dir
+ if utils.is_main_training():
+ utils.write_list_to_file([egs_params, model_params, optimizer_params,
+ lr_scheduler_params], model_dir+'/config/params.dict',yml=True)
+ if utils.is_main_training():
+ logger.info("Init a simple trainer.")
+ # Package (Elements: dict, Params: dict). It is the key parameter package for the trainer and model_dir/config/.
+ package = ({"data": bunch, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler},
+ {"model_dir": model_dir, "model_blueprint": model_blueprint, "exist_model": exist_model, "accum_grad": accum_grad,
+ "start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp,
+ "skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid,
+ "report_interval_iters": report_interval_iters, "record_file": "train.csv"})
+ train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer
+ execuer = train_exec.SimpleTrainer(package)
+ if run_lr_finder:
+ execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8,
+ final_lr=10., num_iters=2000, beta=0.98)
+ endstage = 3 # Do not start extractor.
+ else:
+ execuer.run()
+# Extract xvector
+if stage <= 4 <= endstage and utils.is_main_training():
+ # There are some params for xvector extracting.
+ data_root = "data" # It contains all dataset just like Kaldi recipe.
+ prefix = "raw" # For to_extracted_data.
+ data_type_emb = "raw" # shard or raw or kaldi.
+ de_silence = False
+ amp_th = 50
+ to_extracted_positions = ["near"] # Define this w.r.t model_blueprint.
+ # All dataset should be in dataroot/prefix.
+ to_extracted_data = ["voxceleb1", "voxceleb2_dev"]
+ # It is model's name, such as 10.params or final.params (suffix is w.r.t package).
+ to_extracted_epochs = ["4"]
+ nj = 8
+ force = True
+ use_gpu = True
+ gpu_id = ""
+ sleep_time = 10
+ feat_config = "feat_conf.yaml"
+ max_chunk = 4000
+ # Run a batch extracting process.
+ try:
+ for position in to_extracted_positions:
+ # Generate the extracting config from nnet config where
+ # which position to extract depends on the 'extracted_embedding' parameter of model_creation (by my design).
+ model_blueprint, model_creation = utils.read_nnet_config(
+ "{0}/config/nnet.config".format(model_dir))
+ # To save memory without loading some independent components.
+ model_creation = model_creation.replace(
+ "training=True", "training=False")
+ model_creation = model_creation.replace(
+ model_params["extracted_embedding"], position)
+ extract_config = "{0}.extract.config".format(position)
+ utils.write_nnet_config(
+ model_blueprint, model_creation, "{0}/config/{1}".format(model_dir, extract_config))
+ for epoch in to_extracted_epochs:
+ model_file = "{0}.{1}".format(epoch, suffix)
+ point_name = "{0}_epoch_{1}".format(position, epoch)
+ # If the trainer runs in a background thread (not supported yet) or this launcher is run separately with stage=4
+ # (i.e. as another process), this polling loop lets extraction start as soon as the model file exists (but requires more GPU memory).
+ model_path = "{0}/{1}".format(model_dir, model_file)
+ while True:
+ if os.path.exists(model_path):
+ break
+ else:
+ time.sleep(sleep_time)
+ for data in to_extracted_data:
+ datadir = "{0}/{1}/{2}".format(data_root, prefix, data)
+ outdir = "{0}/{1}/{2}".format(model_dir, point_name, data)
+ # Use a well-optimized shell script (with multi-processes) to extract xvectors.
+ # Another way: use subtools/splitDataByLength.sh and subtools/pytorch/pipeline/onestep/extract_embeddings.py
+ # with python's threads to extract xvectors directly, but the shell script is more convenient.
+ kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh "
+ " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
+ " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} "
+ " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} "
+ "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj,
+ use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
+ feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th,
+ max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir))
+ except BaseException as e:
+ if not isinstance(e, KeyboardInterrupt):
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/pytorch/libs/egs/egs_online.py b/pytorch/libs/egs/egs_online.py
old mode 100644
new mode 100755
index 311215f..c2ab949
--- a/pytorch/libs/egs/egs_online.py
+++ b/pytorch/libs/egs/egs_online.py
@@ -13,7 +13,6 @@
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
import torch.distributed as dist
sys.path.insert(0, "subtools/pytorch")
import libs.support.utils as utils
import libs.egs.processor as processor
@@ -62,6 +61,8 @@ def get_data_dur(self):
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
+ def __len__(self):
+ return len(self.source)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
@@ -144,28 +145,54 @@ def get_data_dur(self):
logger.warning('do not support get duration')
num_sample = sum([int(self.lists[index]['eg-num']) for index in self.data_this_rank]) if "eg-num" in self.lists[0] else len(self.data_this_rank)
+ # tot_sample = sum([int(self.lists[index]['eg-num']) for index in self.lists]) if "eg-num" in self.lists[0] else len(self.lists)
return total_dur,num_sample
+ def __len__(self):
+ return sum([int(self.lists[index]['eg-num']) for index in range(len(self.lists))]) if "eg-num" in self.lists[0] else len(self.lists)
-def WavEgs(egs_csv,conf,data_type='raw',partition=True):
+def WavEgs(egs_csv,conf,data_type='raw',partition=True,num_targets=0):
assert data_type in ['raw', 'shard', 'kaldi']
lists = utils.csv_to_list(egs_csv)
shuffle = conf.get('shuffle', True)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
- if data_type in ['raw','shard']:
+ if data_type in ['raw', 'shard']:
if data_type=='shard':
dataset = Processor(dataset, processor.url_opener)
dataset = Processor(dataset, processor.tar_file_and_group)
dataset = Processor(dataset, processor.parse_raw)
- random_chunk = conf.get('random_chunk',False)
- if random_chunk:
- random_chunk_size = conf.get('random_chunk_size',2.015)
- dataset = Processor(dataset, processor.random_chunk, random_chunk_size)
+ filt = conf.get('filter', False)
+ filter_conf = conf.get('filter_conf', {})
+ if filt:
+ dataset = Processor(dataset, processor.filter, **filter_conf)
resample = conf.get('resample', False)
if resample:
resample_conf = conf.get('resample_conf', {})
dataset = Processor(dataset, processor.resample, **resample_conf)
+ pre_speed_perturb = conf.get('pre_speed_perturb', False)
+ spkid_aug = 1
+ if pre_speed_perturb:
+ perturb_conf = conf.get('perturb_conf',{})
+ sp = processor.PreSpeedPerturb(spk_num=num_targets,**perturb_conf)
+ spkid_aug = sp._spkid_aug()
+ dataset = Processor(dataset,sp)
+ random_chunk = conf.get('random_chunk',False)
+ random_chunk_size = conf.get('random_chunk_size',2.015)
+ if random_chunk:
+ dataset = Processor(dataset, processor.random_chunk, random_chunk_size)
speech_aug = conf.get('speech_aug', False)
speech_aug_conf_file = conf.get('speech_aug_conf', '')
@@ -177,7 +204,11 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True):
csv_aug_folder = conf.get('csv_aug_folder','')
if csv_aug_folder:change_csv_folder(speech_aug_conf,csv_aug_folder)
- speechaug_pipline = processor.SpeechAugPipline(**speech_aug_conf)
+ speechaug_pipline = processor.SpeechAugPipline(spk_num=num_targets,**speech_aug_conf)
+ spkid_aug_lat = speechaug_pipline.get_spkid_aug()
+ if not (spkid_aug==1 or spkid_aug_lat==1):
+ raise ValueError("multi speaker id perturb setting, check your speech aug config")
+ spkid_aug = spkid_aug_lat*spkid_aug
dataset = Processor(dataset, speechaug_pipline)
feature_extraction_conf = conf.get('feature_extraction_conf',{})
@@ -185,6 +216,7 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True):
dataset = Processor(dataset, feature_extraction)
dataset = Processor(dataset,processor.offline_feat)
spec_aug = conf.get('spec_aug', False)
if spec_aug:
@@ -201,7 +233,8 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True):
batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)
dataset = Processor(dataset, processor.padding)
- return dataset
+ return dataset,spkid_aug
def WavEgsXvector(wav_scp,feat_conf={},data_type='kaldi',de_silence=False,de_sil_conf={},partition=False):
if data_type in ['raw', 'shard']:
@@ -242,19 +275,22 @@ def __init__(self, trainset, valid=None,prefetch_factor=2,num_workers=0, pin_mem
- def get_bunch_from_csv(self, trainset_csv: str, valid_csv: str = None, egs_params: dict = {}):
+ def get_bunch_from_csv(self, trainset_csv: str, valid_csv: str = None, egs_params: dict = {},num_targets=-1):
train_conf = egs_params['dataset_conf']
valid_conf = copy.deepcopy(train_conf)
valid_conf['speech_aug'] = False
+ valid_conf['pre_speed_perturb'] = False
valid_conf['spec_aug'] = False
valid_conf['shuffle'] = False
data_type = egs_params.get('data_type','raw')
- trainset = WavEgs(trainset_csv, train_conf, data_type=data_type,partition=True)
+ trainset, num_targets_t = WavEgs(trainset_csv, train_conf, data_type=data_type,partition=True,num_targets=num_targets)
if valid_csv != "" and valid_csv is not None:
- valid = WavEgs(valid_csv, valid_conf,data_type=data_type,partition=False)
+ valid,_ = WavEgs(valid_csv, valid_conf,data_type=data_type,partition=False)
valid = None
+ self.num_targets =num_targets*num_targets_t
return self(trainset, valid, **egs_params['data_loader_conf'])
@@ -275,9 +311,17 @@ def get_bunch_from_egsdir(self, egsdir: str, egs_params: dict={}):
egsdir, train_csv_name=train_csv_name, valid_csv_name=valid_csv_name)
assert 'feat_dim' in egs_params
feat_dim = int(egs_params['feat_dim'])
- info = {"feat_dim": feat_dim, "num_targets": num_targets}
bunch = self.get_bunch_from_csv(
- train_csv, valid_csv, egs_params)
+ train_csv, valid_csv, egs_params,num_targets)
+ num_targets = self.num_targets
+ tot_samples = len(bunch.train_loader.dataset)
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ if egs_params['dataset_conf']['batch_conf']['batch_type']=='static':
+ epoch_iters = (tot_samples//world_size)//egs_params['dataset_conf']['batch_conf']['batch_size']
+ else:
+ epoch_iters = None
+ info = {"feat_dim": feat_dim, "num_targets": num_targets, "epoch_iters": epoch_iters}
return bunch, info
@@ -301,4 +345,4 @@ def get_info_from_egsdir(egsdir, train_csv_name=None, valid_csv_name=None):
if __name__ == "__main__":
- pass
+ pass
\ No newline at end of file
diff --git a/pytorch/libs/egs/processor.py b/pytorch/libs/egs/processor.py
old mode 100644
new mode 100755
index e3fc4f9..39e1440
--- a/pytorch/libs/egs/processor.py
+++ b/pytorch/libs/egs/processor.py
@@ -15,7 +15,7 @@
import torchaudio.compliance.kaldi as kaldi
import libs.support.kaldi_io as kaldi_io
from libs.support.utils import batch_pad_right,get_torchaudio_backend
-from .speech_augment import SpeechAug
+from .speech_augment import SpeechAug,SpeedPerturb
from .signal_processing import de_silence
from libs.egs.augmentation import *
from libs.egs.kaldi_features import InputSequenceNormalization
@@ -87,7 +87,8 @@ def tar_file_and_group(data):
if postfix == 'txt':
label = file_obj.read().decode('utf8')
- example['label'] = int(label)
+ example['label'] = torch.tensor([int(label)],dtype=torch.long)
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
example['wav'] = waveform[:1,:]
@@ -134,7 +135,7 @@ def parse_raw(data):
waveform, sample_rate = torchaudio.load(wav_file)
waveform = waveform[:1,:]
- label = int(label)
+ label = torch.tensor([int(label)],dtype=torch.long)
lens = torch.ones(1)
example = dict(key=key,
@@ -173,6 +174,48 @@ def de_sil(data,win_len=0.1,min_eng=50,retry_times=1,force_output=True):
del waveform
yield sample
+class PreSpeedPerturb(object):
+ def __init__(self,
+ sample_rate=16000,
+ speeds=[90, 100, 110],
+ perturb_type='resample',
+ perturb_prob=1.0,
+ change_spk=True,
+ spk_num=0):
+ super().__init__()
+ self.change_spk = change_spk
+ self.speeder = SpeedPerturb(sample_rate, speeds=speeds, perturb_prob=perturb_prob, perturb_type=perturb_type, spk_num=spk_num,change_spk=change_spk,keep_shape=False)
+ def __call__(self, data):
+ """ Apply speed perturbation to each sample (the label is replaced only when change_spk is set).
+ Args:
+ data: Iterable[{key, wav, label, lens, sample_rate}]
+ Returns:
+ Iterable[{key, wav, label, lens, sample_rate}]
+ """
+ for sample in data:
+ assert 'wav' in sample
+ assert 'label' in sample
+ waveform = sample['wav']
+ spkid = sample['label']
+ try:
+ waveform,spkid = self.speeder(waveform, spkid)
+ sample['wav'] = waveform
+ if self.change_spk:
+ sample['label'] = spkid
+ yield sample
+ except Exception as ex:
+ logging.warning('Failed to speech aug {}'.format(sample['key']))
+ def _spkid_aug(self):
+ spkid_aug, _ = self.speeder.get_spkid_aug()
+ return spkid_aug
def random_chunk(data,chunk_len=2.015):
data: Iterable[{key, wav, label, lens, sample_rate}]
@@ -215,6 +258,7 @@ def offline_feat(data):
key = sample['eg-id']
ark_path = sample['ark-path']
label = int(sample['class-label'])
+ label=torch.tensor([label],dtype=torch.long)
if 'start-position' in sample:
@@ -257,17 +301,57 @@ def resample(data, resample_rate=16000):
+def filter(data,
+ max_length=15,
+ min_length=0.2,
+ max_cut=True):
+ """ Filter samples according to waveform duration.
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+ max_length: utterances longer than max_length (s) are randomly cropped to max_length when max_cut is True, otherwise dropped
+ min_length: drop utterances shorter than min_length (s)
+ Returns:
+ Iterable[{key, wav, *, sample_rate}]
+ """
+ for sample in data:
+ assert 'sample_rate' in sample
+ assert 'wav' in sample
+ duration_sample=sample['wav'].size(1)
+ duration = duration_sample / sample['sample_rate']
+ if duration < min_length:
+ continue
+ if duration > max_length:
+ if max_cut:
+ duration_sample=sample['wav'].size(1)
+ snt_len_sample = int(max_length*sample['sample_rate'])
+ start = random.randint(0, duration_sample - snt_len_sample - 1)
+ stop = stop = start + snt_len_sample
+ sample['wav'] = sample['wav'][:,start:stop]
+ else:
+ continue
+ yield sample
class SpeechAugPipline(object):
- def __init__(self, speechaug={}, tail_speechaug={}):
+ def __init__(self,spk_num=0,speechaug={}, tail_speechaug={}):
- self.speechaug = SpeechAug(**speechaug)
- self.tail_speechaug = SpeechAug(**tail_speechaug)
+ self.speechaug = SpeechAug(spk_num=spk_num,**speechaug)
+ self.tail_speechaug = SpeechAug(spk_num=spk_num,**tail_speechaug)
+ speechaug_nid_aug=self.speechaug.get_spkid_aug()
+ tail_speechaug_nid_aug=self.tail_speechaug.get_spkid_aug()
# The concat number of speech augment, which is used to modify target.
self.concat_pip= (speechaug_n_concat,tail_speechaug_n_concat)
+ self.spk_nid_aug = tail_speechaug_nid_aug*speechaug_nid_aug
+ if speechaug_nid_aug>1 and tail_speechaug_nid_aug>1:
+ raise ValueError("multi speaker id perturb setting, check your speech aug config")
def __call__(self, data):
""" speechaug.
@@ -283,17 +367,20 @@ def __call__(self, data):
waveforms = sample['wav']
lens = sample['lens']
- try:
- waveforms, lens = self.speechaug(waveforms, lens)
+ spkid = sample['label']
+ # try:
- waveforms, lens = self.tail_speechaug(waveforms, lens)
- sample['wav'] = waveforms
- sample['lens'] = lens
- yield sample
- except Exception as ex:
- logging.warning('Failed to speech aug {}'.format(sample['key']))
+ waveforms, lens,spkid = self.speechaug(waveforms, lens,spkid)
+ waveforms, lens,spkid = self.tail_speechaug(waveforms, lens,spkid)
+ sample['wav'] = waveforms
+ sample['label'] = spkid
+ sample['lens'] = lens
+ yield sample
+ # except Exception as ex:
+ # logging.warning('Failed to speech aug {}'.format(sample['key']))
+ def get_spkid_aug(self):
+ return self.spk_nid_aug
@@ -311,7 +398,6 @@ def __init__(self,feature_type='mfcc',kaldi_featset={},mean_var_conf={}):
assert feature_type in ['mfcc','fbank']
if self.feat_type=='mfcc':
@@ -327,7 +413,7 @@ def __call__(self,data):
data: Iterable[{key, wav, label, lens, sample_rate}]
- Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
for sample in data:
assert 'wav' in sample
@@ -339,9 +425,10 @@ def __call__(self,data):
self.kaldi_featset['sample_frequency'] = sample['sample_rate']
lens = sample['lens']
waveforms = sample['wav']
+ label = sample['label']
waveforms = waveforms * (1 << 15)
feats = []
- label = sample['label']
+ labels=[]
utt = sample['key']
@@ -349,27 +436,32 @@ def __call__(self,data):
for i,wav in enumerate(waveforms):
if len(wav.shape)==1:
# add channel
+ else:
+ wav = wav.transpose(0, 1)
wav= wav[:,:lens[i].long()]
logging.warning('Failed to make featrue for {}, aug version:{}'.format(sample['key'],i))
- pass
+ continue
feat = feat.detach()
key = sample['key']+'#{}'.format(i) if i>0 else sample['key']
+ labels.append(label[i])
if len(feats)==0:
- pass
+ continue
max_len = max([feat.size(0) for feat in feats])
- yield dict(utt=utt,keys=keys,feats=feats,label=label,max_len=max_len)
+ yield dict(utt=utt,keys=keys,feats=feats,labels=labels,max_len=max_len)
except Exception as ex:
logging.warning('Failed to make featrue {}'.format(sample['key']))
@@ -381,9 +473,9 @@ def __init__(self,aug=None,aug_params={}):
def __call__(self,data):
""" make features.
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
- Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
for sample in data:
assert 'keys' in sample
@@ -396,32 +488,7 @@ def __call__(self,data):
yield sample
-# def speed_perturb(data, speeds=None):
-# """ Apply speed perturb to the data.
-# Inplace operation.
-# Args:
-# data: Iterable[{key, wav, label, sample_rate}]
-# speeds(List[float]): optional speed
-# Returns:
-# Iterable[{key, wav, label, sample_rate}]
-# """
-# if speeds is None:
-# speeds = [0.9, 1.0, 1.1]
-# for sample in data:
-# assert 'sample_rate' in sample
-# assert 'wav' in sample
-# sample_rate = sample['sample_rate']
-# waveform = sample['wav']
-# speed = random.choice(speeds)
-# if speed != 1.0:
-# wav, _ = torchaudio.sox_effects.apply_effects_tensor(
-# waveform, sample_rate,
-# [['speed', str(speed)], ['rate', str(sample_rate)]])
-# sample['wav'] = wav
-# yield sample
@@ -429,14 +496,16 @@ def shuffle(data, shuffle_size=10000):
""" Local shuffle the data
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
shuffle_size: buffer size for shuffle
- Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
buf = []
for sample in data:
if len(buf) >= shuffle_size:
@@ -456,11 +525,11 @@ def sort(data, sort_size=500):
be less than `shuffle_size`
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
sort_size: buffer size for sort
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
buf = []
@@ -476,14 +545,13 @@ def sort(data, sort_size=500):
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
batch_size: batch size
- Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]]
+ Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]]
buf = []
for sample in data:
@@ -500,10 +568,10 @@ def dynamic_batch(data, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
- data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}]
+ data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}]
max_frames_in_batch: max_frames in one batch
- Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]]
+ Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]]
buf = []
longest_frames = 0
@@ -541,25 +609,27 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
def padding(data):
""" Padding the data into training data
- data: Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]]
+ data: Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]]
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
for sample in data:
assert isinstance(sample, list)
for x in sample:
- labels.extend([x['label']]*len(x['feats']))
+ labels.extend(x['labels'])
+ labels = torch.tensor(labels)
- labels = torch.LongTensor(labels)
feats = [(x.T) for x in feats]
- padded_feats, lens = batch_pad_right(feats)
+ padded_feats, feats_lens = batch_pad_right(feats)
- yield (padded_feats, labels)
+ yield (padded_feats, labels, feats_lens)
diff --git a/pytorch/libs/egs/speech_augment.py b/pytorch/libs/egs/speech_augment.py
old mode 100644
new mode 100755
index 547fe61..3920f1e
--- a/pytorch/libs/egs/speech_augment.py
+++ b/pytorch/libs/egs/speech_augment.py
@@ -3,6 +3,7 @@
# Importing libraries
import math
+from typing import Optional,List
import numpy as np
import random
import sklearn
@@ -23,10 +24,10 @@
from libs.support.utils import batch_pad_right
+import torch.distributed as dist
class NoiseDataset(Dataset):
- def __init__(self, csv_file, sorting="original", max_len=None):
+ def __init__(self, csv_file, sorting="original", max_len=None, filt_min=None):
head = pd.read_csv(csv_file, sep=" ", nrows=0).columns
@@ -37,6 +38,11 @@ def __init__(self, csv_file, sorting="original", max_len=None):
data = pd.read_csv(csv_file, sep=" ",header=0)
+ if filt_min:
+ data =data[data['duration']>filt_min]
if sorting == "decending":
data = data.sort_values(by=['duration'], ascending=False)
elif sorting == "ascending":
@@ -47,11 +53,13 @@ def __init__(self, csv_file, sorting="original", max_len=None):
self.path = data['wav'].values.astype(np.string_)
if max_len:
assert max_len > 0.0
self.max_len = max_len
del data
@@ -74,6 +82,7 @@ def __len__(self):
return len(self.path)
class ReproducibleRandomSampler(RandomSampler):
"""A modification of RandomSampler which always returns the same values.
@@ -195,6 +204,7 @@ class AddNoise(torch.nn.Module):
def __init__(
+ add_filt_min=None,
@@ -207,6 +217,7 @@ def __init__(
self.csv_file = csv_file
+ self.add_filt_min=add_filt_min
self.sorting = sorting
self.num_workers = num_workers
self.snr_low = snr_low
@@ -287,7 +298,9 @@ def _load_noise(self, lengths, max_length):
# Create a data loader for the noise wavforms
if self.csv_file is not None:
dataset = NoiseDataset(
- self.csv_file, sorting=self.sorting)
+ self.csv_file,
+ filt_min = self.add_filt_min,
+ sorting=self.sorting)
shuffle = (self.sorting == "random")
if torch.distributed.is_initialized():
sampler = torch.utils.data.distributed.DistributedSampler(
@@ -325,7 +338,9 @@ def _load_noise(self, lengths, max_length):
# Ensure noise batch is long enough
elif noise_batch.size(1) < max_length:
- padding = (0, max_length - noise_batch.size(1))
+ pad = max_length - noise_batch.size(1)
+ left_padding = torch.randint(high = pad+1, size=(1,))[0]
+ padding = (left_padding,pad-left_padding)
noise_batch = torch.nn.functional.pad(noise_batch, padding)
# Select a random starting location in the waveform
@@ -429,12 +444,12 @@ def __init__(
self.csv_file = csv_file
self.sorting = sorting
self.reverb_prob = reverb_prob
self.rir_scale_factor = rir_scale_factor
# Create a data loader for the RIR waveforms
dataset = NoiseDataset(
- self.csv_file, sorting=self.sorting)
+ self.csv_file,
+ sorting=self.sorting)
shuffle = (self.sorting == "random")
if torch.distributed.is_initialized():
sampler = torch.utils.data.distributed.DistributedSampler(
@@ -443,7 +458,8 @@ def __init__(
sampler = None
self.data_loader = make_dataloader(
- dataset, shuffle=shuffle, sampler=sampler
+ dataset, shuffle=shuffle,
+ sampler=sampler
self.rir_data = iter(self.data_loader)
@@ -546,6 +562,7 @@ class AddBabble(torch.nn.Module):
def __init__(
+ add_filt_min=None,
@@ -559,6 +576,7 @@ def __init__(
self.csv_file = csv_file
+ self.add_filt_min = add_filt_min
self.sorting = sorting
self.num_workers = num_workers
self.snr_low = snr_low
@@ -651,7 +669,10 @@ def _load_noise(self, lengths, max_length):
# Create a data loader for the noise wavforms
if self.csv_file is not None:
dataset = NoiseDataset(
- self.csv_file, sorting=self.sorting, max_len=self.babble_noise_max_len)
+ self.csv_file,
+ filt_min=self.add_filt_min,
+ sorting=self.sorting,
+ max_len=self.babble_noise_max_len)
shuffle = (self.sorting == "random")
if torch.distributed.is_initialized():
sampler = torch.utils.data.distributed.DistributedSampler(
@@ -689,7 +710,9 @@ def _load_noise(self, lengths, max_length):
# Ensure noise batch is long enough
elif noise_batch.size(1) < max_length:
- padding = (0, max_length - noise_batch.size(1))
+ pad = max_length - noise_batch.size(1)
+ left_padding = torch.randint(high = pad+1, size=(1,))[0]
+ padding = (left_padding,pad-left_padding)
noise_batch = torch.nn.functional.pad(noise_batch, padding)
# Select a random starting location in the waveform
@@ -991,6 +1014,56 @@ def forward(self, waveforms, lengths):
return dropped_waveform
+class RandomChunk(torch.nn.Module):
+ """Get segment.
+ Arguments
+ ---------
+ chunk_len : float
+ Length of the segment taken from each utterance, in seconds (s).
+ sample_rate : int
+ the sampling frequency of the input signal.
+ """
+ def __init__(
+ self,
+ random_chunk=False,
+ chunk_len=2.015,
+ sample_rate=16000,
+ ):
+ super().__init__()
+ self.random_chunk=random_chunk
+ self.lens = int(chunk_len*sample_rate)
+ def forward(self, waveforms,lengths):
+ """
+ Arguments
+ ---------
+ waveforms : tensor
+ Shape should be `[batch, time]` or `[batch, time, channels]`.
+ lengths : tensor
+ Shape should be a single dimension, `[batch]`.
+ Returns
+ -------
+ Tensor of shape `[batch, time]` or `[batch, time, channels]`
+ """
+ if not self.random_chunk:
+ return waveforms,lengths
+ lengths=(lengths * waveforms.shape[1]) # [B]
+ shape = list(waveforms.shape)
+ shape[1] = self.lens
+ chunk_sig = torch.zeros(shape,device=lengths.device)
+ for i in range(shape[0]):
+ if lengths[i] > self.lens:
+ max_chop = (lengths[i] - self.lens).long()
+ start_index = torch.randint(
+ high=max_chop, size=(1,))
+ chunk_sig[i] = waveforms[i,start_index: start_index + self.lens]
+ else:
+ repeat_num = math.ceil(self.lens/lengths[i])
+ chunk_sig[i:i+1] = waveforms[i:i+1,: ].repeat(1,repeat_num)[:,:self.lens]
+ lengths = torch.ones(shape[0])
+ if chunk_sig.shape!=48240:
+ print(chunk_sig.shape)
+ return chunk_sig, lengths
class DoClip(torch.nn.Module):
"""This function mimics audio clipping by clamping the input tensor.
@@ -1041,6 +1114,48 @@ def forward(self, waveforms):
return clipped_waveform
+class SoxEffectTransform(torch.nn.Module):
+ effects: List[List[str]]
+ def __init__(self, effects: List[List[str]],sample_rate:int):
+ super().__init__()
+ self.effects = effects
+ self.sample_rate = sample_rate
+ def forward(self, waveforms: torch.Tensor):
+ """
+ Arguments
+ ---------
+ waveforms : tensor
+ Shape should be `[batch, time]` or `[batch, time, channels]`.
+ Returns
+ -------
+ Tensor of shape `[batch, time]` or `[batch, time, channels]`.
+ """
+ wavs = []
+ if self.effects == [[]]:
+ return waveforms
+ unsqueezed = False
+ if len(waveforms.shape)==2:
+ # add channel
+ waveforms=waveforms.unsqueeze(1)
+ unsqueezed = True
+ else:
+ waveforms = waveforms.transpose(1, 2)
+ for i,wav in enumerate(waveforms):
+ wav,_ = torchaudio.sox_effects.apply_effects_tensor(wav, self.sample_rate, self.effects)
+ wavs.append(wav.unsqueeze(0))
+ wavs = torch.cat(wavs,dim=0)
+ if unsqueezed:
+ wavs=wavs.squeeze(1)
+ else:
+ wavs=wavs.transpose(1,2)
+ return wavs
class SpeedPerturb(torch.nn.Module):
"""Slightly speed up or slow down an audio signal.
@@ -1054,8 +1169,7 @@ class SpeedPerturb(torch.nn.Module):
orig_freq : int
The frequency of the original signal.
speeds : list
- The speeds that the signal should be changed to, as a percentage of the
- original signal (i.e. `speeds` is divided by 100 to get a ratio).
+ A set of different speeds to use to perturb each batch. larger -> slower.
perturb_prob : float
The chance that the batch will be speed-
perturbed. By default, every batch is perturbed.
@@ -1064,27 +1178,52 @@ class SpeedPerturb(torch.nn.Module):
def __init__(
- self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0, keep_shape=True
+ self, orig_freq, speeds=[95, 100, 105], perturb_prob=1.0, keep_shape=True, perturb_type='resample', change_spk=False,spk_num=0
+ assert perturb_type in ['resample','sox_speed','sox_tempo']
self.orig_freq = orig_freq
self.speeds = speeds
self.perturb_prob = perturb_prob
self.keep_shape = keep_shape
+ self.change_spk = change_spk
- # Initialize index of perturbation
- self.samp_index = 0
+ if change_spk:
+ assert spk_num>0, "change_spk need total spk number."
+ self.aug_spks = self._speed_to_speaker(speeds)
+ self.aug_spks = [spk_num*aug_spk for aug_spk in self.aug_spks]
# Initialize resamplers
- self.resamplers = []
+ self.speeders = []
for speed in self.speeds:
- config = {
- "orig_freq": self.orig_freq,
- "new_freq": self.orig_freq * speed // 100,
- }
- self.resamplers.append(Resample(**config))
+ if perturb_type == 'resample':
- def forward(self, waveform):
+ config = {
+ "orig_freq": self.orig_freq,
+ "new_freq": self.orig_freq * speed // 100,
+ }
+ self.speeders.append(Resample(**config))
+ else:
+ if perturb_type == 'sox_speed':
+ if speed==100:
+ effects = [[]]
+ else:
+ speed = round(100/speed,2)
+ effects = [['speed',str(speed)],['rate',str(orig_freq)]]
+ elif perturb_type == 'sox_tempo':
+ if speed==100:
+ effects = [[]]
+ else:
+ speed = round(100/speed,2)
+ effects = [['tempo', str(speed)]]
+ else:
+ raise ValueError("unsupport perturb_type: {}".format(perturb_type))
+ self.speeders.append(SoxEffectTransform(effects,orig_freq))
+ def forward(self, waveform: torch.Tensor, spk_id: torch.Tensor=torch.ones((0), dtype=torch.long)):
@@ -1092,6 +1231,8 @@ def forward(self, waveform):
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
+ spk_id: tensor
+ Shape should be a single dimension, `[batch]`
@@ -1100,11 +1241,13 @@ def forward(self, waveform):
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
if torch.rand(1) > self.perturb_prob:
- return waveform.clone()
+ return waveform.clone(),spk_id
# Perform a random perturbation
- self.samp_index = torch.randint(len(self.speeds), (1,))[0]
- perturbed_waveform = self.resamplers[self.samp_index](waveform)
+ speed_index = torch.randint(len(self.speeds), (1,))[0]
+ perturbed_waveform = self.speeders[speed_index](waveform)
+ if self.change_spk:
+ spk_id = self.aug_spks[speed_index]+ spk_id
if self.keep_shape:
# Managing speed change
@@ -1116,7 +1259,27 @@ def forward(self, waveform):
zero_sig[:, 0: perturbed_waveform.shape[1]
] = perturbed_waveform
perturbed_waveform = zero_sig
- return perturbed_waveform
+ return perturbed_waveform,spk_id
+ def get_spkid_aug(self):
+ sp_aug,spkid_aug =1,1
+ if self.perturb_prob>0:
+ sp_aug = len(set(self.speeds))
+ if self.change_spk:
+ spkid_aug=len(set(self.speeds))
+ return spkid_aug,sp_aug
+ def _speed_to_speaker(self,speeds):
+ assert 100 in speeds, "speed perturb with speaker aug need origin speed."
+ t = {}
+ spk_cont = 0
+ for s in sorted(set(speeds),key = speeds.index):
+ if s ==100:
+ t[s]=0
+ else:
+ spk_cont+=1
+ t[s]=spk_cont
+ return [t[sp] for sp in speeds]
class Resample(torch.nn.Module):
@@ -1449,6 +1612,8 @@ class EnvCorrupt(torch.nn.Module):
A prepared csv file for loading noise data, if None, means white noise.
babble_csv : str
A prepared csv file for loading babble data, if None, means simulated babble noise.
+ add_filt_min : float
+ Minimum duration; noise and babble entries whose csv `duration` is not greater than this value are filtered out when loading.
noise_num_workers : int
Number of workers to use for loading noises.
babble_speaker_count : int
@@ -1482,14 +1647,16 @@ def __init__(
+ add_filt_min = None,
- babble_snr_low=0,
- babble_snr_high=0,
+ babble_snr_low=13,
+ babble_snr_high=20,
- noise_snr_high=0,
- rir_scale_factor=1.0,
+ noise_snr_high=15,
+ pad_noise = False,
+ rir_scale_factor=1.0,
@@ -1506,23 +1673,27 @@ def __init__(
self.add_babble = AddBabble(
+ add_filt_min = add_filt_min,
+ pad_noise=pad_noise,
if noise_prob > 0.0:
self.add_noise = AddNoise(
+ add_filt_min = add_filt_min,
+ pad_noise=pad_noise,
- def forward(self, waveforms, lengths):
+ def forward(self, waveforms, lengths,spk_id:torch.ones((0),dtype=torch.long)):
"""Returns the distorted waveforms.
@@ -1543,7 +1714,7 @@ def forward(self, waveforms, lengths):
if hasattr(self, "add_noise"):
waveforms = self.add_noise(waveforms, lengths)
- return waveforms
+ return waveforms,lengths, spk_id
class TimeDomainSpecAugment(torch.nn.Module):
@@ -1553,8 +1724,9 @@ class TimeDomainSpecAugment(torch.nn.Module):
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
- 2. Drop frequency bands (with band-drop filters)
- 3. Speed peturbation (via resampling to slightly different rate)
+ 2. RandomChunk selection.
+ 3. Drop frequency bands (with band-drop filters)
+ 4. Speed perturbation (via resampling to slightly different rate)
@@ -1565,8 +1737,19 @@ class TimeDomainSpecAugment(torch.nn.Module):
drop_chunk_prob : float from 0 to 1
The probability that a batch will have chunks dropped.
speeds : list of ints
- A set of different speeds to use to perturb each batch.
- See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
+ A set of different speeds to use to perturb each batch. larger -> slower.
+ spk_num: int
+ The total number of speakers, used to offset speaker ids of speed-perturbed audio when change_spk is enabled.
+ perturb_type: str
+ ['resample','sox_speed','sox_tempo']
+ change_spk: bool
+ Whether speed-perturbed audio is assigned a new (offset) speaker id.
+ keep_shape: bool
+ keep time length after speed perturb.
+ random_chunk: bool
+ randomly select chunks after speed perturbation.
+ ramddom_chunsize:
+ random chunks length in seconds (s).
sample_rate : int
Sampling rate of the input waveforms.
drop_freq_count_low : int
@@ -1597,10 +1780,15 @@ class TimeDomainSpecAugment(torch.nn.Module):
def __init__(
- drop_freq_prob=1.0,
- drop_chunk_prob=1.0,
- speeds=[95, 100, 105],
+ drop_freq_prob=0.0,
+ drop_chunk_prob=0.0,
+ speeds=[95, 100, 110],
+ spk_num=0,
+ perturb_type='resample',
+ change_spk=False,
+ random_chunk=False,
+ ramddom_chunsize=2.015,
@@ -1613,7 +1801,18 @@ def __init__(
self.speed_perturb = SpeedPerturb(
- perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds, keep_shape=keep_shape
+ perturb_prob=perturb_prob,
+ orig_freq=sample_rate,
+ speeds=speeds,
+ spk_num=spk_num,
+ perturb_type=perturb_type,
+ change_spk=change_spk,
+ keep_shape=keep_shape
+ )
+ self.random_chunk = RandomChunk(
+ random_chunk = random_chunk,
+ chunk_len = ramddom_chunsize,
+ sample_rate = sample_rate
self.drop_freq = DropFreq(
@@ -1629,7 +1828,7 @@ def __init__(
- def forward(self, waveforms, lengths):
+ def forward(self, waveforms, lengths, spk_id:torch.ones((0),dtype=torch.long)):
"""Returns the distorted waveforms.
@@ -1639,11 +1838,18 @@ def forward(self, waveforms, lengths):
# Augmentation
with torch.no_grad():
- waveforms = self.speed_perturb(waveforms)
+ waveforms,spk_id = self.speed_perturb(waveforms,spk_id)
+ waveforms,lengths = self.random_chunk(waveforms,lengths)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
- return waveforms
+ return waveforms,lengths,spk_id
+ def get_spkid_aug(self):
+ return self.speed_perturb.get_spkid_aug()
class SpeechAug(torch.nn.Module):
@@ -1675,16 +1881,17 @@ class SpeechAug(torch.nn.Module):
signal, lens = speech_aug(signal,torch.ones(1))
- def __init__(self, aug_classes=[], mod="random"):
+ def __init__(self, spk_num=0,aug_classes=[], mod="random"):
assert mod in ["random", "concat", "chain"]
self.mod = mod
self.augment = []
self.augment_name = []
+ self.spk_num=spk_num
# define a weight of clean wav type
random_weights = [1] if self.mod == 'random' else []
+ self.spkid_aug,self.sp_aug = 1,1
+ spt_num=0
for aug_class in aug_classes:
assert 'aug_type' in aug_class
@@ -1709,59 +1916,76 @@ def __init__(self, aug_classes=[], mod="random"):
if aug_type == 'Env':
if aug_type == 'Time':
- self.augment.append(TimeDomainSpecAugment(**aug_class))
+ td_aug = TimeDomainSpecAugment(spk_num=spk_num,**aug_class)
+ spkid_aug,sp_aug = td_aug.get_spkid_aug()
+ spt_num+=(int(spkid_aug>1))
+ if spt_num > 1:
+ raise ValueError("multi speaker id perturb setting, check your speech aug config")
+ if spkid_aug>1:
+ self.spkid_aug = spkid_aug
+ self.augment.append(td_aug)
self.random_weight = torch.tensor(
random_weights, dtype=torch.float) if random_weights else None
- self.get_augment()
+ self.print_augment()
- def forward(self, waveforms, lengths):
+ def forward(self, waveforms, lengths, spkid: torch.Tensor=torch.ones((0),dtype=torch.long)):
if not self.augment:
- return waveforms, lengths
+ return waveforms, lengths, spkid
if self.mod == 'random':
- return self._random_forward(waveforms, lengths)
+ return self._random_forward(waveforms, lengths,spkid)
elif self.mod == 'chain':
- return self._chain_forward(waveforms, lengths)
+ return self._chain_forward(waveforms, lengths,spkid)
- return self._concat_forward(waveforms, lengths)
+ return self._concat_forward(waveforms, lengths,spkid)
- def _random_forward(self, waveforms, lengths):
+ def _random_forward(self, waveforms, lengths,spkid):
aug_idx = torch.multinomial(self.random_weight, 1)[0]
if aug_idx == 0:
- return waveforms, lengths
+ return waveforms, lengths,spkid
- waves=self.augment[aug_idx-1](waveforms, lengths)
+ waves,lengths,spkid=self.augment[aug_idx-1](waveforms, lengths,spkid)
raise ValueError('random_1:{},type:{},typename:{}'.format(waveforms,self.augment[aug_idx-1],self.augment_name[aug_idx-1]))
- return waves, lengths
+ return waves, lengths, spkid
- def _concat_forward(self, waveforms, lengths):
+ def _concat_forward(self, waveforms, lengths,spkid):
wavs_aug_tot = []
+ spkids=[]
+ lens = []
+ spkids.append(spkid)
+ lens.append(lengths)
for count, augment in enumerate(self.augment):
- wavs_aug = augment(waveforms, lengths)
+ wavs_aug,len,spkid_a = augment(waveforms, lengths,spkid)
raise ValueError('concat:{},type:{},typename:{}'.format(waveforms,self.augment[count],self.augment_name[count]))
+ spkids.append(spkid_a)
+ lens.append(len)
waveforms = torch.cat(wavs_aug_tot, dim=0)
- lengths = torch.cat([lengths] * len(wavs_aug_tot))
- return waveforms, lengths
+ lens = torch.cat(lens)
+ spkids = torch.cat(spkids, dim=0)
+ return waveforms, lens, spkids
- def _chain_forward(self, waveforms, lengths):
+ def _chain_forward(self, waveforms, lengths,spkid):
for count, augment in enumerate(self.augment):
- waveforms = augment(waveforms, lengths)
+ waveforms,lengths,spkid = augment(waveforms, lengths,spkid)
raise ValueError('chian:{},type:{},typename:{}'.format(waveforms,self.augment[count],self.augment_name[count]))
- return waveforms, lengths
+ return waveforms, lengths,spkid
def get_num_concat(self):
@@ -1770,7 +1994,13 @@ def get_num_concat(self):
return 1
- def get_augment(self):
+ def get_spkid_aug(self):
+ return self.spkid_aug
+ def print_augment(self):
if self.augment:
print('speech augment type is {}.'.format(self.mod))
diff --git a/pytorch/libs/nnet/activation.py b/pytorch/libs/nnet/activation.py
index 5c91ccb..95c2e8f 100644
--- a/pytorch/libs/nnet/activation.py
+++ b/pytorch/libs/nnet/activation.py
@@ -20,6 +20,51 @@ def __init__(self):
def forward(self, inputs):
return inputs * torch.tanh(F.softplus(inputs))
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
+ @staticmethod
+ def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+ x = x.detach()
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+ ctx.save_for_backward(s, y)
+ return y
+ @staticmethod
+ def backward(ctx, y_grad: torch.Tensor) -> torch.Tensor:
+ s, y = ctx.saved_tensors
+ return (y * (1 - s) + s) * y_grad
+class DoubleSwish(torch.nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ else:
+ return DoubleSwishFunction.apply(x)
## Wrapper ✿
def Nonlinearity(nonlinearity="relu", inplace=True, negative_slope=0.01):
"""A wrapper for activation.
@@ -34,6 +79,13 @@ def Nonlinearity(nonlinearity="relu", inplace=True, negative_slope=0.01):
activation = Mish()
elif nonlinearity == 'tanh' :
activation = torch.nn.Tanh()
+ elif nonlinearity == 'swish' :
+ func = getattr(torch.nn, "SiLU", Swish)
+ activation = func()
+ elif nonlinearity == 'gelu' :
+ activation = torch.nn.GELU(inplace=inplace)
+ elif nonlinearity == 'double_swish' :
+ activation = DoubleSwish()
elif nonlinearity == "" or nonlinearity is None or nonlinearity == False:
activation = None
diff --git a/pytorch/libs/nnet/components.py b/pytorch/libs/nnet/components.py
index eab1cf0..696ba47 100755
--- a/pytorch/libs/nnet/components.py
+++ b/pytorch/libs/nnet/components.py
@@ -12,7 +12,7 @@
from libs.support.utils import to_device
import libs.support.utils as utils
+from libs.nnet.transformer.layer_norm import LayerNorm
### There are some basic custom components/layers. ###
@@ -350,6 +350,7 @@ def add_relu_bn(self, output_dim=None, options:dict={}):
"nonlinearity_params":{"inplace":True, "negative_slope":0.01},
+ "ln_replace": False,
"bn_params":{"momentum":0.1, "affine":True, "track_running_stats":True},
@@ -369,13 +370,19 @@ def add_relu_bn(self, output_dim=None, options:dict={}):
self.bn_relu = False
self.activation = Nonlinearity(default_params["nonlinearity"], **default_params["nonlinearity_params"])
if default_params["bn"]:
- self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"])
+ if not default_params['ln_replace']:
+ self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"])
+ else:
+ self.batchnorm = LayerNorm(output_dim,dim=1,eps=1e-5,learnabel_affine=default_params["bn_params"]["affine"])
# self.after_forward = self._bn_relu_forward
self.bn_relu = True
if default_params["bn"]:
- self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"])
+ if not default_params['ln_replace']:
+ self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"])
+ else:
+ self.batchnorm = LayerNorm(output_dim,dim=1,eps=1e-5)
self.activation = Nonlinearity(default_params["nonlinearity"], **default_params["nonlinearity_params"])
if default_params["special_init"] and self.affine is not None and not default_params["jit_compile"]:
@@ -463,7 +470,7 @@ class ReluBatchNormTdnnfLayer(_BaseActivationBatchNorm):
def __init__(self, input_dim, output_dim, inner_size, context_size = 0, **options):
super(ReluBatchNormTdnnfLayer, self).__init__()
- self.affine = FTdnnBlock(input_dim, output_dim, inner_size, context_size)
+ self.affine = TdnnfBlock(input_dim, output_dim, inner_size, context_size)
self.add_relu_bn(output_dim, options=options)
@@ -842,4 +849,17 @@ def to(self, device):
return self
+# Leo 2022-08-14
+class LabelSmoothing(torch.nn.Module):
+ """Label-smoothing for target tensor.
+ """
+ def __init__(
+ self,
+ mean_norm=True,
+ std_norm=True,
+ ):
+ super().__init__()
+ self.mean_norm = mean_norm
+ self.std_norm = std_norm
+ self.eps = 1e-10
\ No newline at end of file
diff --git a/pytorch/libs/nnet/loss.py b/pytorch/libs/nnet/loss.py
index 895dfde..9f50c75 100644
--- a/pytorch/libs/nnet/loss.py
+++ b/pytorch/libs/nnet/loss.py
@@ -3,7 +3,7 @@
# Copyright xmuspeech (Author: Snowdar 2019-05-29)
import numpy as np
+import math
import torch
import torch.nn.functional as F
@@ -88,11 +88,11 @@ def compute_accuracy(self, outputs, targets):
class SoftmaxLoss(TopVirtualLoss):
""" An usual log-softmax loss with affine component.
- def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False):
+ def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False,label_smoothing = 0.0):
self.affine = TdnnAffine(input_dim, num_targets)
self.t = t # temperature
# CrossEntropyLoss() has included the LogSoftmax, so do not add this function extra.
- self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction)
+ self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing=label_smoothing)
# The special_init is not recommended in this loss component
if special_init :
@@ -120,12 +120,12 @@ class SoftmaxLoss_frame_phone_fix(TopVirtualLoss):
#Zheng Li 2021-06-08
""" An usual log-softmax loss with affine component.
- def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False):
+ def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False,label_smoothing = 0.0):
self.affine = TdnnAffine(input_dim, num_targets)
self.t = t # temperature
self.num_phones = num_targets - 1
# CrossEntropyLoss() has included the LogSoftmax, so do not add this function extra.
- self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction)
+ self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing=label_smoothing)
# The special_init is not recommended in this loss component
if special_init :
@@ -187,6 +187,7 @@ def forward(self, inputs, targets):
return self.loss_function(outputs, targets)
class MarginSoftmaxLoss(TopVirtualLoss):
"""Margin softmax loss.
There are AM, AAM, Double-AM, SM1 (Snowdar Margin softmax loss), SM2 and SM3.
@@ -225,12 +226,13 @@ def init(self, input_dim, num_targets,
- reduction='mean', eps=1.0e-10, init=True):
+ reduction='mean', eps=1.0e-10, scale_init=False,label_smoothing = 0.0):
self.input_dim = input_dim
self.num_targets = num_targets
self.weight = torch.nn.Parameter(torch.randn(num_targets, input_dim, 1))
self.s = s # scale factor with feature normalization
+ self.init_m = m
self.m = m # margin
self.t = t # temperature
self.feature_normalize = feature_normalize
@@ -241,8 +243,7 @@ def init(self, input_dim, num_targets,
self.mhe_w = mhe_w
self.inter_loss = inter_loss
self.ring_loss = ring_loss
- self.lambda_factor = 0
+ self.lambda_m = 1
self.curricular = CurricularMarginComponent() if curricular else None
if self.ring_loss > 0:
@@ -260,112 +261,207 @@ def init(self, input_dim, num_targets,
There are some suggested s : {suggested_s} w.r.t p_target {p_target}.".format(
s=self.s, suggested_s=suggested_s, p_target=p_target))
- self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction)
+ self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing = label_smoothing)
+ self.weight_scale =None
+ self.scale_init=scale_init
# Init weight.
- if init:
+ if scale_init:
+ initial_scale = torch.tensor(1.0).log()
+ self.weight_scale = torch.nn.Parameter(initial_scale.clone().detach())
+ std = 0.1
+ a = (3 ** 0.5) * std
+ torch.nn.init.uniform_(self.weight, -a, a)
+ fan_in = self.weight.shape[1] *self.weight[0][0].numel()
+ scale = fan_in ** -0.5 # 1/sqrt(fan_in)
+ with torch.no_grad():
+ self.weight_scale += torch.tensor(scale / std).log()
+ else:
# torch.nn.init.xavier_normal_(self.weight, gain=1.0)
torch.nn.init.normal_(self.weight, 0., 0.01) # It seems better.
+ def get_weight(self):
+ if self.scale_init:
+ return self.weight * self.weight_scale.exp()
+ else:
+ return self.weight
def forward(self, inputs, targets):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
assert len(inputs.shape) == 3
assert inputs.shape[2] == 1
+ # with torch.cuda.amp.autocast(enabled=False):
## Normalize
- normalized_x = F.normalize(inputs.squeeze(dim=2), dim=1) # [512, 512] [batch_size, output_dim]
- normalized_weight = F.normalize(self.weight.squeeze(dim=2), dim=1) # [1000, 512] [target_num, output_dim]
- cosine_theta = F.linear(normalized_x, normalized_weight) # Y = W*X
- if not self.feature_normalize :
- self.s = inputs.norm(2, dim=1) # [batch-size, l2-norm]
- # The accuracy must be reported before margin penalty added
- self.posterior = (self.s.detach() * cosine_theta.detach()).unsqueeze(2)
- else:
- self.posterior = (self.s * cosine_theta.detach()).unsqueeze(2)
- if not self.training:
- # For valid set.
- outputs = self.s * cosine_theta
- return self.loss_function(outputs, targets)
- ## Margin Penalty
- # cosine_theta [batch_size, num_class]
- # targets.unsqueeze(1) [batch_size, 1]
- cosine_theta_target = cosine_theta.gather(1, targets.unsqueeze(1))
- if self.inter_loss > 0:
- inter_cosine_theta = torch.softmax(self.s * cosine_theta, dim=1)
- inter_cosine_theta_target = inter_cosine_theta.gather(1, targets.unsqueeze(1))
- inter_loss = torch.log((inter_cosine_theta.sum(dim=1) - inter_cosine_theta_target)/(self.num_targets - 1) + self.eps).mean()
+ normalized_x = F.normalize(inputs.squeeze(dim=2), dim=1) # [512, 512] [batch_size, output_dim]
+ normalized_weight = F.normalize(self.get_weight().squeeze(dim=2), dim=1) # [1000, 512] [target_num, output_dim]
+ with torch.cuda.amp.autocast(enabled=False):
+ cosine_theta = F.linear(normalized_x, normalized_weight) # Y = W*X
+ if not self.feature_normalize :
+ self.s = inputs.norm(2, dim=1) # [batch-size, l2-norm]
+ # The accuracy must be reported before margin penalty added
+ self.posterior = (self.s.detach() * cosine_theta.detach()).unsqueeze(2)
+ else:
+ self.posterior = (self.s * cosine_theta.detach()).unsqueeze(2)
+ if not self.training:
+ # For valid set.
+ outputs = self.s * cosine_theta
+ return self.loss_function(outputs, targets)
+ ## Margin Penalty
+ # cosine_theta [batch_size, num_class]
+ # targets.unsqueeze(1) [batch_size, 1]
+ cosine_theta_target = cosine_theta.gather(1, targets.unsqueeze(1))
+ if self.inter_loss > 0:
+ inter_cosine_theta = torch.softmax(self.s * cosine_theta, dim=1)
+ inter_cosine_theta_target = inter_cosine_theta.gather(1, targets.unsqueeze(1))
+ inter_loss = torch.log((inter_cosine_theta.sum(dim=1) - inter_cosine_theta_target)/(self.num_targets - 1) + self.eps).mean()
+ if self.method == "am":
+ penalty_cosine_theta = cosine_theta_target - self.m
+ if self.double:
+ double_cosine_theta = cosine_theta + self.m
+ elif self.method == "aam":
+ # Another implementation w.r.t cosine(theta+m) = cosine_theta * cos_m - sin_theta * sin_m
+ # penalty_cosine_theta = self.cos_m * cosine_theta_target - self.sin_m * torch.sqrt((1-cosine_theta_target**2).clamp(min=0.))
+ penalty_cosine_theta = torch.cos(torch.acos(cosine_theta_target) + self.m)
+ if self.double:
+ double_cosine_theta = torch.cos(torch.acos(cosine_theta).add(-self.m))
+ elif self.method == "sm1":
+ # penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target) * self.m
+ penalty_cosine_theta = (1 + self.m) * cosine_theta_target - self.m
+ elif self.method == "sm2":
+ penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target**2) * self.m
+ elif self.method == "sm3":
+ penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target)**2 * self.m
+ else:
+ raise ValueError("Do not support this {0} margin w.r.t [ am | aam | sm1 | sm2 | sm3 ]".format(self.method))
+ penalty_cosine_theta = self.lambda_m * penalty_cosine_theta + \
+ (1 - self.lambda_m) * cosine_theta_target
- if self.method == "am":
- penalty_cosine_theta = cosine_theta_target - self.m
- if self.double:
- double_cosine_theta = cosine_theta + self.m
- elif self.method == "aam":
- # Another implementation w.r.t cosine(theta+m) = cosine_theta * cos_m - sin_theta * sin_m
- # penalty_cosine_theta = self.cos_m * cosine_theta_target - self.sin_m * torch.sqrt((1-cosine_theta_target**2).clamp(min=0.))
- penalty_cosine_theta = torch.cos(torch.acos(cosine_theta_target) + self.m)
if self.double:
- double_cosine_theta = torch.cos(torch.acos(cosine_theta).add(-self.m))
- elif self.method == "sm1":
- # penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target) * self.m
- penalty_cosine_theta = (1 + self.m) * cosine_theta_target - self.m
- elif self.method == "sm2":
- penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target**2) * self.m
- elif self.method == "sm3":
- penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target)**2 * self.m
- else:
- raise ValueError("Do not support this {0} margin w.r.t [ am | aam | sm1 | sm2 | sm3 ]".format(self.method))
- penalty_cosine_theta = 1 / (1 + self.lambda_factor) * penalty_cosine_theta + \
- self.lambda_factor / (1 + self.lambda_factor) * cosine_theta_target
- if self.double:
- cosine_theta = 1/(1+self.lambda_factor) * double_cosine_theta + self.lambda_factor/(1+self.lambda_factor) * cosine_theta
- if self.curricular is not None:
- cosine_theta = self.curricular(cosine_theta, cosine_theta_target, penalty_cosine_theta)
- outputs = self.s * cosine_theta.scatter(1, targets.unsqueeze(1), penalty_cosine_theta)
+ cosine_theta = self.lambda_m * double_cosine_theta + (1 - self.lambda_m) * cosine_theta
+ if self.curricular is not None:
+ cosine_theta = self.curricular(cosine_theta, cosine_theta_target, penalty_cosine_theta)
+ # cosine_theta = cosine_theta.float()
+ # penalty_cosine_theta = penalty_cosine_theta.float()
+ outputs = self.s * cosine_theta.scatter(1, targets.unsqueeze(1), penalty_cosine_theta)
+ ## Other extra loss
+ # The final reported loss is always higher than the plain softmax loss because of the absolute margin
+ # penalty, which is why it cannot be decreased to a minimum value. Strictly speaking, we should not
+ # report the loss after the margin penalty has been applied, but we report this inflated loss anyway
+ # to avoid computing the training loss twice.
+ if self.ring_loss > 0:
+ ring_loss = torch.mean((self.s - self.r)**2)/2
+ else:
+ ring_loss = 0.
+ if self.mhe_loss:
+ sub_weight = normalized_weight - torch.index_select(normalized_weight, 0, targets).unsqueeze(dim=1)
+ # [N, C]
+ normed_sub_weight = sub_weight.norm(2, dim=2)
+ mask = torch.full_like(normed_sub_weight, True, dtype=torch.bool).scatter_(1, targets.unsqueeze(dim=1), False)
+ # [N, C-1]
+ normed_sub_weight_clean = torch.masked_select(normed_sub_weight, mask).reshape(targets.size()[0], -1)
+ # torch.mean means 1/(N*(C-1))
+ the_mhe_loss = self.mhe_w * torch.mean((normed_sub_weight_clean**2).clamp(min=self.eps)**-1)
+ return self.loss_function(outputs/self.t, targets) + the_mhe_loss + self.ring_loss * ring_loss
+ elif self.inter_loss > 0:
+ return self.loss_function(outputs/self.t, targets) + self.inter_loss * inter_loss + self.ring_loss * ring_loss
+ else:
+ return self.loss_function(outputs/self.t, targets) + self.ring_loss * ring_loss
+ def step(self, lambda_m, add_margin=None):
+ self.lambda_m = lambda_m
+ if add_margin is not None:
+ self.m = max(0,self.init_m+add_margin)
- ## Other extra loss
- # Final reported loss will be always higher than softmax loss for the absolute margin penalty and
- # it is a lie about why we can not decrease the loss to a mininum value. We should not report the
- # loss after margin penalty did but we really report this invalid loss to avoid computing the
- # training loss twice.
- if self.ring_loss > 0:
- ring_loss = torch.mean((self.s - self.r)**2)/2
+ def extra_repr(self):
+ return '(~affine): (input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \
+ 'margin={m}, s={s}, t={t}, feature_normalize={feature_normalize}, mhe_loss={mhe_loss}, mhe_w={mhe_w}, ' \
+ 'eps={eps}, scale_init={scale_init})'.format(**self.__dict__)
+# Leo 2022-11-08
+class MarginWarm(torch.nn.Module):
+ def __init__(self,
+ start_epoch,
+ end_epoch,
+ offset_margin = 0.0,
+ init_lambda = 1.0,
+ epoch_iter=None):
+ '''
+ Between start_epoch and end_epoch, the margin offset decays exponentially
+ in magnitude from offset_margin (usually negative) towards 0,
+ and lambda_t increases linearly from init_lambda to 1.
+ It is designed to control MarginSoftmaxLoss through `margin + offset_margin` and
+ `penalty_cosine_theta = lambda * penalty_cosine_theta + (1 - lambda) * cosine_theta_target`.
+ '''
+ super().__init__()
+ if end_epoch < start_epoch:
+ raise ValueError("End_epoch should not smaller then start_epoch, but got end_epoch: {}, start_epoch:{}"
+ .format(end_epoch,start_epoch))
+ assert abs(init_lambda - 0.5)<=0.5,"init_lambda should be in [0, 1]"
+ self.start_epoch = start_epoch
+ self.end_epoch = end_epoch
+ self.offset_margin = offset_margin
+ self.init_lambda = init_lambda
+ self.epoch_iter = epoch_iter
+ if epoch_iter:
+ self.update_step_range(epoch_iter)
+ def update_step_range(self,epoch_iter,overwrite=False):
+ if not overwrite and self.epoch_iter:
+ raise ValueError("epoch_iter has been set as {}, but overwrite = {} now".format(self.epoch_iter,overwrite))
- ring_loss = 0.
- if self.mhe_loss:
- sub_weight = normalized_weight - torch.index_select(normalized_weight, 0, targets).unsqueeze(dim=1)
- # [N, C]
- normed_sub_weight = sub_weight.norm(2, dim=2)
- mask = torch.full_like(normed_sub_weight, True, dtype=torch.bool).scatter_(1, targets.unsqueeze(dim=1), False)
- # [N, C-1]
- normed_sub_weight_clean = torch.masked_select(normed_sub_weight, mask).reshape(targets.size()[0], -1)
- # torch.mean means 1/(N*(C-1))
- the_mhe_loss = self.mhe_w * torch.mean((normed_sub_weight_clean**2).clamp(min=self.eps)**-1)
- return self.loss_function(outputs/self.t, targets) + the_mhe_loss + self.ring_loss * ring_loss
- elif self.inter_loss > 0:
- return self.loss_function(outputs/self.t, targets) + self.inter_loss * inter_loss + self.ring_loss * ring_loss
+ self.epoch_iter = epoch_iter
+ self.increase_start_iter = (self.start_epoch - 1) * epoch_iter
+ self.fix_start_iter = (self.end_epoch - 1) * epoch_iter
+ self.step_range = self.fix_start_iter - self.increase_start_iter
+ def get_increase_margin(self, cur_step):
+ initial_val = 1.0
+ final_val = 1e-3
+ cur_pos = cur_step - self.increase_start_iter
+ ratio = math.exp(
+ (cur_pos / self.step_range) *
+ math.log(final_val / (initial_val + 1e-6))) * initial_val
+ offset_margin = self.offset_margin * ratio
+ lambda_t = self.init_lambda + (cur_pos / self.step_range) * (1-self.init_lambda)
+ return offset_margin, lambda_t
+ def step(self, cur_step):
+ if self.epoch_iter<0 or not isinstance(self.epoch_iter, int):
+ raise ValueError("MarginWarm expected positive integer epoch_iter, but got {}"
+ .format(self.epoch_iter))
+ if cur_step >= self.fix_start_iter:
+ return 0,1
+ elif cur_step > self.increase_start_iter:
+ return self.get_increase_margin(cur_step)
- return self.loss_function(outputs/self.t, targets) + self.ring_loss * ring_loss
- def step(self, lambda_factor):
- self.lambda_factor = lambda_factor
+ return self.offset_margin,self.init_lambda
def extra_repr(self):
- return '(~affine): (input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \
- 'margin={m}, s={s}, t={t}, feature_normalize={feature_normalize}, mhe_loss={mhe_loss}, mhe_w={mhe_w}, ' \
- 'eps={eps})'.format(**self.__dict__)
+ return 'start_epoch={start_epoch}, end_epoch={end_epoch}, epoch_iter={epoch_iter}, ' \
+ 'offset_margin={offset_margin} init_lambda={init_lambda})'.format(**self.__dict__)
class CurricularMarginComponent(torch.nn.Module):
diff --git a/pytorch/libs/nnet/pooling.py b/pytorch/libs/nnet/pooling.py
index 117bdc4..4fd119c 100644
--- a/pytorch/libs/nnet/pooling.py
+++ b/pytorch/libs/nnet/pooling.py
@@ -3,7 +3,7 @@
# Copyright xmuspeech (Author: Snowdar 2019-05-29 2020-06-10)
import numpy as np
+from typing import Optional
import torch
import torch.nn.functional as F
@@ -29,32 +29,45 @@ def __init__(self, input_dim, stddev=True, unbiased=False, eps=1.0e-10):
# Used for unbiased estimate of stddev
self.unbiased = unbiased
- def forward(self, inputs):
+ def forward(self, inputs, lengths:torch.Tensor = torch.ones((0),dtype=torch.long)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
assert len(inputs.shape) == 3
assert inputs.shape[1] == self.input_dim
- # Get the num of frames
- counts = inputs.shape[2]
- mean = inputs.sum(dim=2, keepdim=True) / counts
- if self.stddev :
- if self.unbiased and counts > 1:
- counts = counts - 1
- # The sqrt (as follows) is deprecated because it results in Nan problem.
- # std = torch.unsqueeze(torch.sqrt(torch.sum((inputs - mean)**2, dim=2) / counts), dim=2)
- # There is a eps to solve this problem.
- # Another method: Var is equal to std in "cat" way, actually. So, just use Var directly.
- var = torch.sum((inputs - mean)**2, dim=2, keepdim=True) / counts
- std = torch.sqrt(var.clamp(min=self.eps))
- return torch.cat((mean, std), dim=1)
+ if lengths.size(0) > 0:
+ mean = []
+ std = []
+ for i in range(inputs.shape[0]):
+ act_len = lengths[i]
+ act_counts = act_len
+ mean_i = torch.mean(inputs[i,:,:act_len],dim=1,keepdim=True)
+ mean.append(mean_i)
+ if self.stddev :
+ if self.unbiased and act_len > 1:
+ act_counts = act_len - 1
+ var = torch.sum((inputs[i,:,:act_len]-mean_i)**2, dim=1, keepdim=True)/act_counts
+ std.append(torch.sqrt(var.clamp(min=self.eps)))
+ mean = torch.stack(mean)
+ out = mean
+ if self.stddev :
+ std_o = torch.stack(std)
+ out = torch.cat((mean, std_o), dim=1)
- return mean
+ counts = inputs.shape[2]
+ mean = inputs.mean(dim=2, keepdim=True)
+ out = mean
+ if self.stddev:
+ if self.unbiased and counts > 1:
+ counts = counts - 1
+ var = torch.sum((inputs - mean)**2, dim=2, keepdim=True) / counts
+ std = torch.sqrt(var.clamp(min=self.eps))
+ out = torch.cat((mean, std), dim=1)
+ return out
def get_output_dim(self):
return self.output_dim
diff --git a/pytorch/libs/nnet/transformer/__init__.py b/pytorch/libs/nnet/transformer/__init__.py
index 79abb45..0dee3d0 100644
--- a/pytorch/libs/nnet/transformer/__init__.py
+++ b/pytorch/libs/nnet/transformer/__init__.py
@@ -3,9 +3,8 @@
# Copyright xmuspeech (Author: Snowdar 2020-02-21)
-__all__ = ["TransformerEncoder", "repeat", "MultiSequential", "MultiHeadedAttention", "LayerNorm"]
+__all__ = ["TransformerEncoder", "ConformerEncoder", "MultiHeadedAttention", "LayerNorm","ReConformerEncoder"]
-from .repeat import *
from .attention import *
from .layer_norm import *
-from .TransformerEncoder import *
\ No newline at end of file
+from .encoder import *
\ No newline at end of file
diff --git a/pytorch/libs/nnet/transformer/attention.py b/pytorch/libs/nnet/transformer/attention.py
index 6402b63..7e110c8 100644
--- a/pytorch/libs/nnet/transformer/attention.py
+++ b/pytorch/libs/nnet/transformer/attention.py
@@ -1,45 +1,75 @@
# -*- coding:utf-8 -*-
+# Copyright xmuspeech (Author: Leo 2022-07)
+# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/attention.py
-# Reference: https://github.com/espnet/espnet.
-import math
+"""Multi-Head Attention layer definition."""
-import numpy
+from cmath import pi
+import math
+from typing import Optional, Tuple, Dict, Any
import torch
from torch import nn
+from libs.nnet.activation import Nonlinearity
+from .scaling import ScaledLinear,ScaledConv1d,ActivationBalancer
class MultiHeadedAttention(nn.Module):
- """Multi-Head Attention layer
+ """Multi-Head Attention layer.
- :param int n_head: the number of head s
- :param int n_feat: the number of features
- :param float dropout_rate: dropout rate
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ add_t5rel_bias (bool): Whether to add a T5 relative-position bias to the attention score matrix.
+ t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers.
- def __init__(self, n_head, n_feat, dropout_rate):
- super(MultiHeadedAttention, self).__init__()
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False,attention_norm_args: Optional[Dict[str, Any]]=None,re_scale=False):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- self.linear_k = nn.Linear(n_feat, n_feat)
- self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.attn = None
+ self.conv_out = conv_out
+ self.att_norm = AttentionNormalize(self.d_k, att_type = "mlh", **attention_norm_args)
+ project = ScaledLinear if re_scale else nn.Linear
+ self.linear_q = project(n_feat, n_feat)
+ self.linear_k = project(n_feat, n_feat)
+ self.linear_v = project(n_feat, n_feat)
+ if self.conv_out:
+ conv = ScaledConv1d if re_scale else nn.Conv1d
+ self.linear_out = conv(n_feat, n_feat, 3,
+ stride=1, padding=1)
+ else:
+ self.linear_out = ScaledLinear(n_feat, n_feat,initial_scale=0.25) if re_scale else nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
- def forward(self, query, key, value, mask):
- """Compute 'Scaled Dot Product Attention'
+ self.add_t5rel_bias = add_t5rel_bias
+ self.t5rel_module = T5RelPositionBias(self.d_k**0.5) if self.add_t5rel_bias else None
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
- :param torch.Tensor query: (batch, time1, size)
- :param torch.Tensor key: (batch, time2, size)
- :param torch.Tensor value: (batch, time2, size)
- :param torch.Tensor mask: (batch, time1, time2)
- :param torch.nn.Dropout dropout:
- :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
- weighted by the query dot key attention (batch, head, time1, time2)
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
@@ -49,16 +79,655 @@ def forward(self, query, key, value, mask):
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # (batch, head, time1, time2)
- if mask is not None:
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
- scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
+ return q, k, v
+ def forward_attention(self, value: torch.Tensor, scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+ """
+ n_batch = value.size(0)
+ attn = self.att_norm(scores,mask)
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h * self.d_k)
+ ) # (batch, time1, d_model)
+ if self.conv_out:
+ x = self.linear_out(x.transpose(-1, 1)).transpose(-1, 1)
+ else:
+ x = self.linear_out(x)
+ return x # (batch, time1, d_model)
+ def forward(self, query: torch.Tensor, key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0)) -> torch.Tensor:
+ """Compute scaled dot product attention.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ Wenet.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1))
+ if self.add_t5rel_bias and self.t5rel_module is not None:
+ scores+=self.t5rel_module(scores)
+ return self.forward_attention(v, scores, mask)
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix.
+ t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers.
+ """
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False,attention_norm_args: Optional[Dict[str, Any]]=None,re_scale=False):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,re_scale)
+ # linear transformation for positional encoding
+ project = ScaledLinear if re_scale else nn.Linear
+ self.linear_pos = project(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+    def rel_shift(self, x, zero_triu: bool = False):
+        """Compute relative positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor. It is indexed as a 4-D tensor
+                here (presumably (batch, head, time1, time2), like
+                matrix_bd in forward -- confirm against callers).
+            zero_triu (bool): If true, return the lower triangular part of
+                the matrix.
+        Returns:
+            torch.Tensor: Output tensor with the same shape as x.
+        """
+        # Prepend one zero column, then reinterpret the padded buffer so that
+        # dropping the first row shifts every row relative to the previous one.
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)
+        if zero_triu:
+            # Build the triangular mask on the same device as x; a default
+            # (CPU) tensor cannot be multiplied with a tensor on another device.
+            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+        return x
+    def forward(self, query: torch.Tensor, key: torch.Tensor,
+                value: torch.Tensor,
+                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+                pos_emb: torch.Tensor = torch.empty(0)):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2). Defaults to an empty (0, 0, 0)
+                tensor, matching the other attention classes in this file;
+                AttentionNormalize only masks when mask.size(2) > 0.
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+        # compute matrix b and matrix d
+        # (batch, head, time1, time2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        # Remove rel_shift since it is useless in speech recognition,
+        # and it requires special attention for streaming.
+        # matrix_bd = self.rel_shift(matrix_bd)
+        scores = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        if self.add_t5rel_bias and self.t5rel_module is not None:
+            scores += self.t5rel_module(scores)
+        return self.forward_attention(v, scores, mask)
+# RoPE (Leo 2022-07-25)
+# reference:
+# RoFormer: Enhanced Transformer with Rotary Position Embedding.
+class RoPESelfAttention(MultiHeadedAttention):
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False, attention_norm_args: Optional[Dict[str, Any]]=None, rotary_value: bool = True,re_scale=False):
+ """Construct a RoPESelfAttention object.
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ rotary_value (bool): whether to also apply the rotary position embedding to the value tensor.
+ add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix.
+ t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share one T5 rel position matrix in all layers.
+ """
+ super().__init__(n_head, n_feat, dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,re_scale)
+ self.rotary_value = rotary_value
+ def forward(self, query: torch.Tensor, key: torch.Tensor,
+ value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb:torch.Tensor = torch.empty(0)):
+ """Compute 'Scaled Dot Product Attention' with rotary positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ pos_emb : Positional embedding tensor(#batch, time2, size//2) of query and key_value,
+ each is a tuple contains sin part tensor and cos part tensor.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, size).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = self.apply_rotary(q,pos_emb)
+ k = self.apply_rotary(k,pos_emb)
+ if self.rotary_value:
+ v = self.apply_rotary(v,pos_emb)
+ scores = torch.matmul(q, k.transpose(-2, -1))
+ if self.add_t5rel_bias and self.t5rel_module is not None:
+ scores+=self.t5rel_module(scores)
+ return self.forward_attention(v, scores, mask)
+ @staticmethod
+ def apply_rotary(x, sinusoidal_pos):
+ sin, cos = sinusoidal_pos.chunk(2,dim=-1) # (1, time1, d_model//2)
+ x1, x2 = x[..., 0::2], x[..., 1::2]
+ return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)
+# T5RelPE (Leo 2022-07-25)
+# reference:
+# Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
+# https://arxiv.org/abs/1910.10683
+class T5RelPositionBias(torch.nn.Module):
+ """T5 rel_pos, which is a trainable bias matrix added to the attention score.
+ Args:
+ scale (float): scale the bias, usually set to query_dim**0.5.
+ causal (bool): true, means omit the future.
+ num_buckets (int): relative to the length of sensitive area.
+ max_distance (int): when distance > max_distance,they share the same bias.
+ """
+ def __init__(
+ self,
+ scale,
+ causal = False,
+ num_buckets = 32,
+ max_distance = 128
+ ):
+ super().__init__()
+ self.scale = scale
+ self.causal = causal
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = torch.nn.Embedding(num_buckets, 1)
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position,
+ causal = False,
+ num_buckets = 32,
+ max_distance = 128
+ ):
+ ret = 0
+ n = -relative_position
+ if not causal:
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+ else:
+ n = torch.max(n, torch.zeros_like(n))
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+ def forward(self, x):
+ """Get the rel bias from relative_position_bucket.
+ Args:
+ x (torch.Tensor): attention scores. (#batch, head, time1, time2).
+ Returns:
+ torch.Tensor: bias tensor. (#time1, time2).
+ """
+ i, j, device = *x.shape[-2:], x.device
+ # get q,k position
+ q_pos = torch.arange(i, dtype = torch.long, device = device)
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
+ # calculate the relative position distance matrix (i, j)
+ rel_pos = k_pos.unsqueeze(0) - q_pos.unsqueeze(1)
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ bias = values.squeeze(-1)
+ return bias * self.scale
+class OffsetScale(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.ones(1,dim))
+ self.beta = nn.Parameter(torch.zeros(1,dim))
+ nn.init.xavier_uniform_(self.gamma)
+ def forward(self, x):
+ out = x*self.gamma + self.beta
+ return out
+# (Leo 2022-08-10)
+class GAU(nn.Module):
+ """Gated Attention Unit. Currently it only supports self-attention over the whole sequence.
+ Args:
+ hidden_dim (int): The size of hidden_dim, recommend 2*n_feat.
+ n_feat (int): The number of features.
+ d_k (int): Dim of query,key, default (128).
+ dropout_rate (float): Dropout rate.
+ activation_type (str): activation function.
+ add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix.
+ t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers.
+ """
+    def __init__(self, n_feat: int, hidden_dim: int, d_k: int = 128, dropout_rate: float = 0.,
+                 conv_out: bool = False,
+                 attention_norm_args: Optional[Dict[str, Any]] = None,
+                 re_scale: bool = False,
+                 activation_type='swish',
+                 add_t5rel_bias: bool = False,
+                 ):
+        """Construct a GAU object.
+        Args:
+            n_feat (int): The number of input/output features.
+            hidden_dim (int): Size of the gate and value projections.
+            d_k (int): Dim of the shared query/key projection.
+            dropout_rate (float): Dropout rate applied to the attention weights.
+            conv_out (bool): Use a kernel-size-3 1-D convolution instead of a
+                linear layer for the output projection.
+            attention_norm_args (dict, optional): Extra keyword arguments for
+                AttentionNormalize; None means no extra arguments.
+            re_scale (bool): Use ScaledLinear/ScaledConv1d instead of
+                nn.Linear/nn.Conv1d.
+            activation_type (str): Name passed to Nonlinearity.
+            add_t5rel_bias (bool): Whether to add a T5 relative position bias
+                to the attention scores.
+        """
+        super().__init__()
+        self.d_k = d_k
+        self.conv_out = conv_out
+        project = ScaledLinear if re_scale else nn.Linear
+        # NOTE(review): Identity is selected when re_scale is true, while
+        # ReConvolutionModule pairs its scaled layers with ActivationBalancer,
+        # and ActivationBalancer is built here without arguments -- confirm
+        # that this selection is intended.
+        banlancer = nn.Identity if re_scale else ActivationBalancer
+        self.to_gate = nn.Sequential(
+            project(n_feat, hidden_dim),
+            banlancer(),
+            Nonlinearity(activation_type)
+        )
+        self.to_v = nn.Sequential(
+            project(n_feat, hidden_dim),
+            banlancer(),
+            Nonlinearity(activation_type)
+        )
+        self.to_qk = nn.Sequential(
+            project(n_feat, d_k),
+            banlancer(),
+            Nonlinearity(activation_type)
+        )
+        if self.conv_out:
+            conv = ScaledConv1d if re_scale else nn.Conv1d
+            self.to_out = nn.Sequential(
+                conv(hidden_dim, n_feat, 3,
+                     stride=1, padding=1)
+            )
+        else:
+            self.to_out = nn.Sequential(
+                project(hidden_dim, n_feat)
+            )
+        # The default attention_norm_args is None, which cannot be unpacked
+        # with **; fall back to an empty mapping.
+        self.att_norm = AttentionNormalize(self.d_k, att_type="gau", **(attention_norm_args or {}))
+        self.dropout = nn.Dropout(dropout_rate)
+        self.scale_q = OffsetScale(d_k)
+        self.scale_k = OffsetScale(d_k)
+        self.add_t5rel_bias = add_t5rel_bias
+        self.t5rel_module = T5RelPositionBias(self.d_k**0.5) if self.add_t5rel_bias else None
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ now it assume q = k = v
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, time2, d_k).
+ torch.Tensor: gate tensor, size
+ (#batch, time1, hidden_dim).
+ torch.Tensor: value tensor, size
+ (#batch, time2, hidden_dim).
+ """
+ u = self.to_gate(query) # (batch, time1, hidden_dim)
+ v = self.to_v(value) # (batch, time2, hidden_dim)
+ # here q is the whole sequence.
+ qk = self.to_qk(key) # (batch, time1, d_k)
+ q = self.scale_q(qk)
+ k = self.scale_k(qk)
+ return q, k, u, v
+ def forward_qkuv(self, query: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ now it assume q = k = v
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, time2, d_k).
+ torch.Tensor: gate tensor, size
+ (#batch, time1, hidden_dim).
+ torch.Tensor: value tensor, size
+ (#batch, time2, hidden_dim).
+ """
+ u = self.to_gate(query) # (batch, time1, hidden_dim)
+ v = self.to_v(query) # (batch, time2, hidden_dim)
+ # here q is the whole sequence.
+ qk = self.to_qk(query) # (batch, time1, d_k)
+ q = self.scale_q(qk)
+ k = self.scale_k(qk)
+ return q, k, u, v
+ def forward_attention(self, u: torch.Tensor, value: torch.Tensor, scores: torch.Tensor,
+ mask:torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ u (torch.Tensor): Transformed gate, size
+ (#batch, time1, hidden_dim).
+ value (torch.Tensor): Transformed value, size
+ (#batch, time2, hidden_dim).
+ scores (torch.Tensor): Attention score, size
+ (#batch, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+ """
+ attn = self.att_norm(scores,mask)
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, time1, time2) (batch, time2, hidden_dim) -> (batch, time1, hidden_dim)
+ x = u * x
+ if self.conv_out:
+ x = self.to_out(x.transpose(-1, 1)).transpose(-1, 1)
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ x = self.to_out(x)
+ return x # (batch, time1, n_feat)
+ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor],
+ value: Optional[torch.Tensor],
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0)) -> torch.Tensor:
+ """Compute Gated Attention.
+ now it just support selfatt for the whole sequence,
+ which means q = k = v.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ Wenet.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, n_feat).
+ """
+ q, k, u, v = self.forward_qkuv(query)
+ scores = torch.matmul(q, k.transpose(-2, -1))
+ if self.add_t5rel_bias and self.t5rel_module is not None:
+ scores+=self.t5rel_module(scores)
+ return self.forward_attention(u, v, scores, mask)
+class RoPEGAU(GAU):
+ def __init__(self, n_feat: int, hidden_dim:int ,d_k: int = 128 ,dropout_rate: float = 0.,
+ conv_out: bool =False,
+ attention_norm_args: Optional[Dict[str, Any]]=None,
+ re_scale: bool=False,
+ activation_type='swish',
+ add_t5rel_bias: bool =False,
+ ):
+ """Construct a RoPEGAU object.
+ Args:
+ n_feat (int): The number of features.
+ hidden_dim (int): The size of hidden_dim, recommend 2*n_feat.
+ d_k (int): Dim of query,key, default (128).
+ dropout_rate (float): Dropout rate.
+ conv_out, attention_norm_args, re_scale: passed unchanged to GAU.
+ activation_type (str): activation function.
+ add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix.
+ """
+ super().__init__(n_feat, hidden_dim, d_k ,dropout_rate,
+ conv_out,
+ attention_norm_args,
+ re_scale,
+ activation_type,
+ add_t5rel_bias,
+ )
+ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor],
+ value: Optional[torch.Tensor], mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0)):
+ """Compute 'Scaled Dot Product Attention' with rotary positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ pos_emb : Positional embedding tensor(#batch, time2, size//2) of query and key_value,
+ each is a tuple contains sin part tensor and cos part tensor.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, size).
+ """
+ q, k, u, v = self.forward_qkuv(query)
+ q = self.apply_rotary(q,pos_emb)
+ k = self.apply_rotary(k,pos_emb)
+ scores = torch.matmul(q, k.transpose(-2, -1))
+ if self.add_t5rel_bias and self.t5rel_module is not None:
+ scores+=self.t5rel_module(scores)
+ return self.forward_attention(u,v, scores, mask)
+ @staticmethod
+ def apply_rotary(x, sinusoidal_pos):
+ sin, cos = sinusoidal_pos.chunk(2,dim=-1) # (1, time1, d_model//2)
+ x1, x2 = x[..., 0::2], x[..., 1::2]
+ return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)
+# (Leo 2022-08-10)
+class AttentionNormalize(nn.Module):
+ def __init__(self,
+ d_k: int,
+ att_type = "mlh",
+ scale_adapt: bool=False,
+ norm_method: str='softmax',
+ diag_mask: bool=False,
+ g_sa: bool=False,
+ train_len : int = 512,
+ dim: int=-1):
+ super().__init__()
+ self.method = norm_method
+ self.dim = dim
+ self.scale_adapt = scale_adapt
+ if self.scale_adapt:
+ self.scale = nn.Parameter(torch.log(torch.tensor(d_k**-0.5)))
+ else:
+ self.scale = torch.tensor(math.sqrt(d_k))
+ self.diag_mask = diag_mask
+ self.att_type = att_type
+ self.g_sa = g_sa
+ self.omiga = torch.tensor(0.001)
+ self.bias = torch.zeros(1)-0.001
+ if self.g_sa:
+ if 'softmax' not in norm_method:
+ raise ValueError("g_sa just support softmax form calculate now")
+ # self.omiga = nn.Parameter(torch.abs(nn.init.trunc_normal_(torch.empty(1)))+0.001)
+ # self.bias = nn.Parameter(-torch.abs(nn.init.trunc_normal_(torch.empty(1))))
+ self.omiga = nn.Parameter(torch.tensor(0.001))
+ self.bias = nn.Parameter(torch.zeros(1)-0.001)
+ # self.train_len = torch.tensor(math.log(train_len))
+ self.train_len = nn.Parameter(torch.tensor(math.log(train_len))) if self.method =="softmax_plus" else torch.empty(0)
+ def forward(self, scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)):
+ if self.g_sa:
+ i, j=scores.shape[-2:]
+ device = scores.device
+ q_pos = torch.arange(j-i,j, dtype = torch.long, device = device)
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
+ # calculate the relative position distance matrix (i, j)
+ dis_matrix = (k_pos.unsqueeze(0) - q_pos.unsqueeze(1))**2
+ dis_matrix = -torch.abs(torch.abs(dis_matrix*self.omiga) - torch.abs(self.bias)) # (i, j)
+ scores = scores + dis_matrix
+ if self.scale_adapt:
+ scores = scores*(self.scale.exp())
+ else:
+ scores = scores/self.scale
+ if mask.size(2) > 0 : # time2 > 0
+ mask = mask[..., :scores.size(-1)]
+ if self.diag_mask:
+ mask = (~torch.eye(scores.shape[-2],scores.shape[-1],dtype=torch.bool,device=scores.device)) & mask
+ if self.att_type == "gau":
+ mask = mask.eq(0) # (batch, *, time2)
+ else:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -1e4)
+ attn = self.attention_normalize(scores, dim=self.dim,method=self.method).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2) or # (batch, time1, time2)
+ else:
+ attn = self.attention_normalize(scores, dim=self.dim,method=self.method) # (batch, head, time1, time2) or (batch, time1, time2)
+ return attn
+ def attention_normalize(self,
+ a: torch.Tensor,
+ dim: int=-1,
+ method: str='softmax'):
+ """attention score normalization
+ softmax
+ relu_plus: https://arxiv.org/abs/2202.10447
+ softmax_plus: https://kexue.fm/archives/8823
+ """
+ assert method in ['softmax','relu_plus','softmax_plus']
+ if method == 'softmax':
+ return torch.softmax(a, dim=dim)
+ else:
+ mask = (a > -1e4).float()
+ l = torch.sum(mask ,dim=dim, keepdim=True).clamp_(1.)
+ if method == 'relu_plus':
+ return torch.relu(a)**2 / l
+ elif method == 'softmax_plus':
+ scale = torch.log(l) / self.train_len * mask + 1 - mask
+ return torch.softmax(a * scale, dim=dim)
+ else:
+ raise ValueError("check your attention norm ")
- p_attn = self.dropout(self.attn)
- x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
- x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
- return self.linear_out(x) # (batch, time1, d_model)
+ return a
\ No newline at end of file
diff --git a/pytorch/libs/nnet/transformer/convolution.py b/pytorch/libs/nnet/transformer/convolution.py
new file mode 100644
index 0000000..1b441c8
--- /dev/null
+++ b/pytorch/libs/nnet/transformer/convolution.py
@@ -0,0 +1,231 @@
+# -*- coding:utf-8 -*-
+# Reference: https://github.com/espnet/espnet.
+"""ConvolutionModule definition."""
+from typing import Optional, Tuple
+import torch
+from torch import nn
+from .scaling import ActivationBalancer,ScaledConv1d
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (bool): Whether use causal convolution or not
+ """
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation:nn.Module = nn.ReLU(),
+ norm: str = 'batch_norm',
+ causal: bool=False,
+ activation_balancer: bool=False,
+ bias: bool=True
+ ):
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+ assert norm in ['batch_norm', 'layer_norm', "basic_norm"]
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+ self.balancer1,self.balancer2 = None,None
+ self.activation_balancer = activation_balancer
+ if activation_balancer:
+ self.balancer1=ActivationBalancer(
+ channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+ )
+ self.balancer2=ActivationBalancer(
+ channel_dim=1, min_positive=0.05, max_positive=1.0
+ )
+ def forward(self,
+ x: torch.Tensor,
+ mask_pad: Optional[torch.Tensor] = None,
+ )-> torch.Tensor:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+ # mask batch padding
+ if mask_pad is not None:
+ x.masked_fill_(~mask_pad, 0.0)
+ if self.lorder>0:
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ if self.activation_balancer and self.balancer1 is not None:
+ x = self.balancer1(x)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ if self.activation_balancer and self.balancer2 is not None:
+ x = self.balancer2(x)
+ x = self.activation(x)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad is not None:
+ x.masked_fill_(~mask_pad, 0.0)
+ return x.transpose(1, 2)
+class ReConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (bool): Whether use causal convolution or not
+ """
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation:nn.Module = nn.ReLU(),
+ causal: bool=False,
+ bias: bool=True
+ ):
+ """Construct a ReConvolutionModule object."""
+ super(ReConvolutionModule, self).__init__()
+ self.pointwise_conv1 = ScaledConv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = ScaledConv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+ self.pointwise_conv2 = ScaledConv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ initial_scale=0.25,
+ )
+ self.activation = activation
+ self.balancer1=ActivationBalancer(
+ channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+ )
+ self.balancer2=ActivationBalancer(
+ channel_dim=1, min_positive=0.05, max_positive=1.0
+ )
+ def forward(self,
+ x: torch.Tensor,
+ mask_pad: Optional[torch.Tensor] = None,
+ )-> torch.Tensor:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+ # mask batch padding
+ if mask_pad is not None:
+ x.masked_fill_(~mask_pad, 0.0)
+ if self.lorder>0:
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = self.balancer1(x)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ x = self.balancer2(x)
+ x = self.activation(x)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad is not None:
+ x.masked_fill_(~mask_pad, 0.0)
+ return x.transpose(1, 2)
\ No newline at end of file
diff --git a/pytorch/libs/nnet/transformer/embedding.py b/pytorch/libs/nnet/transformer/embedding.py
index b22ed7b..6346957 100644
--- a/pytorch/libs/nnet/transformer/embedding.py
+++ b/pytorch/libs/nnet/transformer/embedding.py
@@ -1,109 +1,198 @@
# -*- coding:utf-8 -*-
-# Reference: https://github.com/espnet/espnet.
+# Copyright xmuspeech (Author: Leo 2022-07)
+# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/embedding.py.
"""Positonal Encoding Module."""
import math
-import torch
-def _pre_hook(state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs):
- """Perform pre-hook in load_state_dict for backward compatibility.
+from typing import Tuple
- Note:
- We saved self.pe until v.0.5.2 but we have omitted it later.
- Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
- """
- k = prefix + "pe"
- if k in state_dict:
- state_dict.pop(k)
+import torch
+def get_abs_position(dim: int,
+                     max_len: int) -> torch.Tensor:
+    """Build a sinusoidal absolute position table.
+    Even columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
+    with w_i = 10000^(-2i/dim).
+    Args:
+        dim (int): Embedding dimension (must be > 0).
+        max_len (int): Number of positions to generate.
+    Returns:
+        torch.Tensor: Table of shape (1, max_len, dim).
+    """
+    abs_pe = torch.zeros(max_len, dim)
+    position = torch.arange(0, max_len,
+                            dtype=torch.float32).unsqueeze(1)
+    div_term = torch.exp(
+        torch.arange(0, dim, 2, dtype=torch.float32) *
+        -(math.log(10000.0) / dim))
+    angles = position * div_term
+    abs_pe[:, 0::2] = torch.sin(angles)
+    # An odd dim has one fewer cosine column than sine column, so trim the
+    # angles to dim // 2 columns; for even dim this slice is a no-op.
+    abs_pe[:, 1::2] = torch.cos(angles[:, :dim // 2])
+    return abs_pe.unsqueeze(0)
class PositionalEncoding(torch.nn.Module):
- """Positional encoding."""
- def __init__(self, d_model, dropout_rate, max_len=5000):
- """Initialize class.
- :param int d_model: embedding dim
- :param float dropout_rate: dropout rate
- :param int max_len: maximum input length
- """
- super(PositionalEncoding, self).__init__()
+ """Positional encoding.
+ :param int d_model: embedding dim
+ :param float dropout_rate: dropout rate
+ :param int max_len: maximum input length
+ :param int att_h: invalid here,for compatibility to RoPositionalEncoding
+ :param bool rope_abs_plus: invalid here,for compatibility to RoPositionalEncoding
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+ """
+ # NOTE(review): `att_h`, `rope_abs_plus` and `reverse` are accepted but never
+ # read in this constructor; per the class docstring the first two exist only
+ # for signature compatibility with RoPositionalEncoding.
+ def __init__(self,
+ d_model: int,
+ dropout_rate: float,
+ att_h: int=4,
+ rope_abs_plus: bool=False,
+ max_len: int = 5000,
+ reverse: bool = False):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
- self.pe = None
- self.extend_pe(torch.tensor(0.0).expand(1, max_len))
- self._register_load_state_dict_pre_hook(_pre_hook)
- def extend_pe(self, x):
- """Reset the positional encodings."""
- if self.pe is not None:
- if self.pe.size(1) >= x.size(1):
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
- pe = torch.zeros(x.size(1), self.d_model)
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
- div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) *
- -(math.log(10000.0) / self.d_model))
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(0)
- self.pe = pe.to(device=x.device, dtype=x.dtype)
- def forward(self, x: torch.Tensor):
+ self.max_len = max_len
+ # Precompute the full sinusoidal table once, final shape (1, max_len, d_model).
+ # It is stored as a plain tensor attribute, and forward() moves it to the
+ # input's device on each call.
+ self.pe = torch.zeros(self.max_len, self.d_model)
+ position = torch.arange(0, self.max_len,
+ dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ self.pe[:, 0::2] = torch.sin(position * div_term)
+ self.pe[:, 1::2] = torch.cos(position * div_term)
+ self.pe = self.pe.unsqueeze(0)
+ def forward(self,
+ x: torch.Tensor,
+ offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ offset (int): position offset
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: for compatibility to RelPositionalEncoding
- self.extend_pe(x)
- x = x * self.xscale + self.pe[:, :x.size(1)]
- return self.dropout(x)
+ # Strict inequality: a slice that would end exactly at max_len is rejected.
+ assert offset + x.size(1) < self.max_len
+ self.pe = self.pe.to(x.device)
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
+ # The input is scaled by sqrt(d_model) before the table slice is added.
+ x = x * self.xscale + pos_emb
+ # Dropout is applied to the sum and, separately, to the returned slice.
+ return self.dropout(x), self.dropout(pos_emb)
+ def position_encoding(self, offset: int, size: int) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
-class ScaledPositionalEncoding(PositionalEncoding):
- """Scaled positional encoding module.
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a none
+ streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
- See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
+ Args:
+ offset (int): start offset
+ size (int): requried size of position encoding
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ # Same strict bound as forward(); the slice is taken from self.pe without
+ # moving it to another device here.
+ assert offset + size < self.max_len
+ return self.dropout(self.pe[:, offset:offset + size])
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ att_h (int): invalid here,for compatibility to RoPositionalEncoding
+ rope_abs_plus (bool): invalid here,for compatibility to RoPositionalEncoding
+ max_len (int): Maximum input length.
- def __init__(self, d_model, dropout_rate, max_len=5000):
- """Initialize class.
- :param int d_model: embedding dim
- :param float dropout_rate: dropout rate
- :param int max_len: maximum input length
+ # NOTE(review): `att_h` is accepted but not forwarded to the parent, and this
+ # signature has no `rope_abs_plus` parameter although the class docstring lists one.
+ # `reverse=True` is passed to the parent constructor, which does not read it.
+ def __init__(self, d_model: int, dropout_rate: float, att_h: int=4, max_len: int = 5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len=max_len, reverse=True)
+ def forward(self,
+ x: torch.Tensor,
+ offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
- super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
- self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+ # Strict inequality: a slice that would end exactly at max_len is rejected.
+ assert offset + x.size(1) < self.max_len
+ self.pe = self.pe.to(x.device)
+ # Unlike PositionalEncoding.forward, the table is not added to x: x is only
+ # scaled by sqrt(d_model) and the position slice is returned separately.
+ x = x * self.xscale
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
+ return self.dropout(x), self.dropout(pos_emb)
- def reset_parameters(self):
- """Reset parameters."""
- self.alpha.data = torch.tensor(1.0)
- def forward(self, x):
- """Add positional encoding.
+class NoPositionalEncoding(torch.nn.Module):
+    """Stand-in positional encoding that carries no position information.
+
+    It offers the same constructor arguments and methods as the other
+    positional-encoding classes; ``att_h`` and ``rope_abs_plus`` are accepted
+    and ignored.
+    """
+    def __init__(self, d_model: int, dropout_rate: float, att_h: int = 4, rope_abs_plus: bool = False):
+        super().__init__()
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.d_model = d_model
+    def forward(self,
+                x: torch.Tensor,
+                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply dropout to ``x`` and pair it with an all-zero position tensor.
+
+        Args:
+            x (torch.Tensor): input whose second dimension is the time axis.
+            offset (int): ignored; kept for interface compatibility.
+
+        Returns:
+            Tuple of ``dropout(x)`` and a zero tensor of shape
+            (1, x.size(1), d_model) placed on ``x``'s device.
+        """
+        frames = x.size(1)
+        zero_emb = torch.zeros(1, frames, self.d_model).to(x.device)
+        return self.dropout(x), zero_emb
+    def position_encoding(self, offset: int, size: int) -> torch.Tensor:
+        """Return a zero tensor of shape (1, size, d_model); ``offset`` is ignored."""
+        return torch.zeros(1, size, self.d_model)
+# RoPE (Leo 2022-07-25)
+# reference:
+# RoFormer: Enhanced Transformer with Rotary Position Embedding.
+class RoPositionalEncoding(PositionalEncoding):
+ """ Rotary positional encoding module. The cos features of sinusoidal are organized in
+ the 2nd half of the vector.
+ Args:
+ d_embed (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ att_h (int): attention head num.
+ rope_abs_plus (bool): rope plus abs.
+ max_len (int): Maximum input length.
+ d_roembed (int): width of the rotary table; values < 1 select d_embed // att_h.
+ """
+ def __init__(self, d_embed: int, dropout_rate: float ,att_h: int=4 ,rope_abs_plus: bool=False, max_len: int = 5000 , d_roembed: int=-1):
+ """Initialize class."""
+ super().__init__(d_embed, dropout_rate, max_len=max_len)
+ if d_roembed < 1:
+ # The default rotary width is d_embed // att_h, so d_embed has to split
+ # evenly across the heads. The former check, (d_embed % att_h) % 2 == 0,
+ # tested the remainder instead and let odd per-head widths through.
+ assert d_embed % att_h == 0, "d_embed must be divisible by att_h"
+ d_roembed = d_embed //att_h
+ # The table is split into a sine half and a cosine half below, which
+ # requires an even width whether it was derived or passed explicitly.
+ assert d_roembed % 2 == 0, "rotary embedding dim must be even"
+ abs_rope = get_abs_position(d_roembed, self.max_len)
+ freq = torch.zeros_like(abs_rope)
+ sentinel = d_roembed // 2
+ # Regroup the interleaved table: sine columns first, cosine columns second.
+ freq[...,0:sentinel] = abs_rope[...,0::2]
+ freq[...,sentinel:] = abs_rope[...,1::2]
+ # Keep the parent's absolute table for the optional additive term and
+ # make the rotary table the one returned by forward().
+ self.abs_pe = self.pe
+ self.pe = freq
+ self.rope_abs_plus = rope_abs_plus
+ def forward(self,
+ x: torch.Tensor,
+ offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
- x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ x (torch.Tensor): Input tensor (batch, time, `*`).
- torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
- self.extend_pe(x)
- x = x + self.alpha * self.pe[:, :x.size(1)]
- return self.dropout(x)
+ # Strict inequality: a slice that would end exactly at max_len is rejected.
+ assert offset + x.size(1) < self.max_len
+ self.pe = self.pe.to(x.device)
+ x = x * self.xscale
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
+ if self.rope_abs_plus:
+ self.abs_pe = self.abs_pe.to(x.device)
+ abs_pe = self.abs_pe[:, offset:offset + x.size(1)]
+ # In-place add does not touch the caller's tensor: x was rebound to the
+ # scaled copy above.
+ x += abs_pe
+ # Unlike the parent class, the returned table slice is not passed through dropout.
+ return self.dropout(x), pos_emb
diff --git a/pytorch/libs/nnet/transformer/encoder.py b/pytorch/libs/nnet/transformer/encoder.py
new file mode 100644
index 0000000..9e4989d
--- /dev/null
+++ b/pytorch/libs/nnet/transformer/encoder.py
@@ -0,0 +1,1104 @@
+# -*- coding:utf-8 -*-
+# Copyright xmuspeech (Author: Leo 2022-07)
+# Reference: https://github.com/wenet-e2e/wenet
+import torch
+import math
+from typing import Tuple, List, Optional
+from .attention import (
+ MultiHeadedAttention,
+ RelPositionMultiHeadedAttention,
+ RoPESelfAttention,
+from .embedding import (
+ PositionalEncoding,
+ RelPositionalEncoding,
+ NoPositionalEncoding,
+ RoPositionalEncoding
+ )
+from .encoder_layer import (
+ TransformerEncoderLayer,
+ ConformerEncoderLayer,
+ ReConformerEncoderLayer
+from .layer_norm import BasicNorm, LayerNorm,Trans_Bat
+from .multi_layer_conv import (
+ Conv1dLinear,
+ MultiLayeredConv1d,
+from .convolution import ConvolutionModule, ReConvolutionModule
+from .positionwise_feed_forward import PositionwiseFeedForward
+from .subsampling import (
+ LinearNoSubsampling,
+ Conv2dSubsampling2,
+ Conv2dSubsampling4,
+ ReConv2dSubsampling4,
+ SVConv2dSubsampling4,
+ Conv2dSubsampling6,
+ Conv2dSubsampling8,
+ SVConv2dSubsampling2,
+ FrameMerger
+import torch.nn as nn
+from .mask import add_optional_chunk_mask
+from .mask import make_pad_mask
+from libs.nnet.activation import Nonlinearity
+class BaseEncoder(torch.nn.Module):
+ """
+ :param int idim: input dim
+ :param int attention_dim: dimention of attention
+ :param int attention_heads: the number of heads of multi head attention
+ :param int linear_units: the number of units of position-wise feed forward
+ :param int num_blocks: the number of encoder blocks
+ :param int aux_layer_period: the period for randomcombiner/mfa.
+ :param int aux_layer_start: the start position for randomcombiner/mfa.
+ default :1. e.g. 3 means from 1/3 of the total block nums suggest for randomcombiner, mfa suggest set to 1+num_blocks.
+ :param float dropout_rate: dropout rate
+ :param float attention_dropout_rate: dropout rate in attention
+ :param float positional_dropout_rate: dropout rate after adding positional encoding
+ :param str input_layer: input layer type
+ :param str pos_enc_type: Encoder positional encoding layer type.
+ opitonal [abs_pos, rel_pos, no_pos, rot_pos]
+ :param bool rotary_value: whether apply rot_pos on value vector when use rot_pos, which contains abs position information.
+ :param bool rope_abs_plus: whether apply abs_pos when use rot_pos.
+ :param bool add_t5rel_bias: whether apply t5_rel position in attention score.
+ :param bool normalize_before: whether to use layer_norm before the first block
+ :param bool concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ :param str positionwise_layer_type: positionwise feedforward type
+ :param int positionwise_conv_kernel_size: Kernel size of positionwise conv1d layer.
+ :param int static_chunk_size: chunk size for static chunk training and decoding.
+ :param bool use_dynamic_chunk: whether use dynamic chunk size for
+ training or not, You can only use fixed chunk(chunk_size > 0)
+ or dynamic chunk size(use_dynamic_chunk = True)
+ :param bool use_dynamic_left_chunk (bool): whether use dynamic left chunk in
+ dynamic chunk training
+ :param str comnbiner_type: combine the output of encoder with its sublayers.
+ opitonal [norm, mfa, random_frame, random_layer]
+ """
+ aux_layers: List[int]
+ # NOTE(review): `attention_norm_args` defaults to a single shared dict object,
+ # and `concat_after` is accepted but not used inside this constructor.
+ def __init__(self, idim,
+ attention_dim=256,
+ attention_heads=4,
+ linear_units=2048,
+ mlp_head=False,
+ num_blocks=6,
+ aux_layer_period = 3,
+ aux_layer_start = 1,
+ dropout_rate=0.1,
+ positional_dropout_rate=0.1,
+ attention_dropout_rate=0.0,
+ attention_conv_out = False,
+ attention_norm_args = {},
+ input_layer="conv2d",
+ pos_enc_type="abs_pos",
+ rotary_value=True,
+ rope_abs_plus = False,
+ add_t5rel_bias = False,
+ att_type: str= 'multi',
+ gau_units: int = 512,
+ gau_key: int = 64,
+ normalize_before=True,
+ norm_type = "layer_norm",
+ concat_after=False,
+ positionwise_layer_type="linear",
+ positionwise_conv_kernel_size=1,
+ activation_type='relu',
+ activation_balancer=False,
+ static_chunk_size: int = 0,
+ left_chunk_size: int = -1,
+ use_dynamic_chunk: bool = False,
+ use_dynamic_left_chunk: bool = False,
+ combiner_type: str="norm",
+ re_scale: bool=False):
+ super().__init__()
+ self.att_type = att_type
+ self.mlp_head = mlp_head
+ # A GAU feed-forward combined with non-GAU attention requires the GAU key
+ # size to equal the per-head attention width.
+ if att_type != 'gau' and positionwise_layer_type == 'gau':
+ assert gau_key == attention_dim // attention_heads
+ # Extra keyword arguments that only the rotary positional encoding receives.
+ pos_enc_dict = {}
+ if pos_enc_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_type == "rel_pos":
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_type == "rot_pos":
+ pos_enc_class = RoPositionalEncoding
+ pos_enc_dict['rope_abs_plus']= rope_abs_plus
+ if self.att_type == 'gau':
+ pos_enc_dict['d_roembed']= gau_key
+ else:
+ pos_enc_dict['att_h']= attention_heads
+ elif pos_enc_type == "no_pos":
+ pos_enc_class = NoPositionalEncoding
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_type)
+ if input_layer == "linear":
+ subsampling_class = LinearNoSubsampling
+ elif input_layer == "conv2d2":
+ subsampling_class = SVConv2dSubsampling2
+ elif input_layer == "conv2d":
+ subsampling_class = Conv2dSubsampling4
+ elif input_layer == "re_conv2d":
+ subsampling_class = ReConv2dSubsampling4
+ elif input_layer == "conv2d6":
+ subsampling_class = Conv2dSubsampling6
+ elif input_layer == "conv2d8":
+ subsampling_class = Conv2dSubsampling8
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ # The positional-encoding module is constructed here and handed to the input layer.
+ self.embed = subsampling_class(idim, attention_dim, dropout_rate,mlp_head,pos_enc_class(attention_dim,positional_dropout_rate, **pos_enc_dict))
+ # pos_enc_type must be stored before the next call: the "gau" branch of
+ # get_positionwise_layer reads self.pos_enc_type.
+ self.pos_enc_type = pos_enc_type
+ # NOTE(review): gau_key is not passed below, so the "gau" branch falls back to
+ # get_positionwise_layer's own default of 64 - confirm this is intended.
+ self.positionwise_layer, self.positionwise_layer_args = self.get_positionwise_layer(
+ positionwise_layer_type,
+ attention_dim,
+ linear_units,
+ dropout_rate,
+ positionwise_conv_kernel_size,
+ activation_type,
+ activation_balancer,
+ re_scale,
+ attention_norm_args=attention_norm_args
+ )
+ # Only the layer class and its argument tuple are stored here; the layers
+ # themselves are not instantiated in this constructor.
+ if self.att_type == 'gau':
+ self.selfattn_layer, self.selfattn_layer_args = self.get_gau_layer(
+ pos_enc_type,
+ attention_dim,
+ gau_units,
+ gau_key,
+ attention_dropout_rate,
+ attention_conv_out,
+ attention_norm_args=attention_norm_args,
+ re_scale=re_scale
+ )
+ else:
+ self.selfattn_layer, self.selfattn_layer_args = self.get_selfattn_layer(
+ pos_enc_type,
+ attention_heads,
+ attention_dim,
+ attention_dropout_rate,
+ add_t5rel_bias,
+ attention_conv_out,
+ rotary_value,
+ attention_norm_args=attention_norm_args,
+ re_scale=re_scale
+ )
+ self.aux_layers,self.combiner = self.get_combiner(
+ num_blocks,
+ aux_layer_period,
+ aux_layer_start,
+ combiner_type
+ )
+ self.normalize_before = normalize_before
+ # With the "mfa" combiner the reported output width is attention_dim times
+ # the number of tapped layers; otherwise it is attention_dim.
+ self._output_size = attention_dim*len(self.aux_layers) if combiner_type=="mfa" else attention_dim
+ # A final norm is built only when normalize_before is set or the combiner is
+ # "mfa"; self.use_layer_norm is only defined in that case.
+ self.after_norm = None
+ if self.normalize_before or combiner_type=="mfa":
+ if norm_type == "batch_norm":
+ self.use_layer_norm=False
+ self.after_norm = Trans_Bat(self._output_size)
+ else:
+ self.use_layer_norm=True
+ if norm_type == "layer_norm":
+ self.after_norm = LayerNorm(self._output_size,eps=1e-5)
+ else:
+ self.after_norm = BasicNorm(self._output_size)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.left_chunk_size = left_chunk_size
+    def get_positionwise_layer(
+        self,
+        positionwise_layer_type="linear",
+        attention_dim=256,
+        linear_units=2048,
+        dropout_rate=0.1,
+        positionwise_conv_kernel_size=1,
+        activation_type='relu',
+        activation_balancer=False,
+        re_scale = False,
+        gau_key = 64,
+        attention_norm_args = {}
+    ):
+        """Pick the feed-forward layer class and its constructor arguments.
+
+        Supported ``positionwise_layer_type`` values are "linear", "conv1d",
+        "conv1d-linear" and "gau". The layer itself is not instantiated here.
+
+        Returns:
+            Tuple of (layer class, tuple of positional constructor arguments).
+
+        Raises:
+            NotImplementedError: for any other ``positionwise_layer_type``.
+        """
+        if positionwise_layer_type == "gau":
+            # Delegates to get_gau_layer, using linear_units as the hidden width
+            # and dropout_rate as the attention dropout.
+            return self.get_gau_layer(
+                self.pos_enc_type,
+                attention_dim,
+                linear_units,
+                gau_key,
+                dropout_rate,
+                attention_norm_args = attention_norm_args,
+                re_scale=re_scale
+            )
+        if positionwise_layer_type == "linear":
+            linear_args = (attention_dim, linear_units, dropout_rate,
+                           activation_type, activation_balancer, re_scale)
+            return PositionwiseFeedForward, linear_args
+        # Both convolutional variants take the same argument tuple.
+        conv_args = (
+            attention_dim,
+            linear_units,
+            positionwise_conv_kernel_size,
+            dropout_rate,
+            activation_type,
+            activation_balancer,
+            re_scale,
+        )
+        if positionwise_layer_type == "conv1d":
+            return MultiLayeredConv1d, conv_args
+        if positionwise_layer_type == "conv1d-linear":
+            return Conv1dLinear, conv_args
+        raise NotImplementedError("Support only linear or conv1d.")
+    def get_selfattn_layer(
+        self,
+        pos_enc_type="abs_pos",
+        attention_heads=4,
+        attention_dim=256,
+        attention_dropout_rate=0.0,
+        add_t5rel_bias=False,
+        conv_out=False,
+        rotary_value = False,
+        attention_norm_args={},
+        re_scale=False
+    ):
+        """Pick the self-attention class and its constructor arguments.
+
+        "rel_pos" selects RelPositionMultiHeadedAttention, "rot_pos" selects
+        RoPESelfAttention (whose argument tuple additionally carries
+        ``rotary_value`` before ``re_scale``), and every other value selects
+        MultiHeadedAttention.
+
+        Returns:
+            Tuple of (attention class, tuple of positional constructor arguments).
+        """
+        shared_args = (attention_heads, attention_dim, attention_dropout_rate,
+                       add_t5rel_bias, conv_out, attention_norm_args)
+        if pos_enc_type == "rot_pos":
+            return RoPESelfAttention, shared_args + (rotary_value, re_scale)
+        attn_cls = (RelPositionMultiHeadedAttention
+                    if pos_enc_type == "rel_pos" else MultiHeadedAttention)
+        return attn_cls, shared_args + (re_scale,)
+    def get_gau_layer(
+        self,
+        pos_enc_type="abs_pos",
+        attention_dim=256,
+        hidden_dim = 512,
+        d_qk = 64,
+        attention_dropout_rate=0.0,
+        conv_out=False,
+        attention_norm_args={},
+        re_scale=False
+    ):
+        """Pick the GAU layer class and its constructor arguments.
+
+        "abs_pos" selects GAU; every other ``pos_enc_type`` selects RoPEGAU.
+
+        Returns:
+            Tuple of (layer class, tuple of positional constructor arguments).
+        """
+        # NOTE(review): GAU and RoPEGAU are not among the imports visible in this
+        # hunk - confirm they are in scope for this module.
+        layer_cls = GAU if pos_enc_type == "abs_pos" else RoPEGAU
+        layer_args = (attention_dim, hidden_dim, d_qk, attention_dropout_rate,
+                      conv_out, attention_norm_args, re_scale)
+        return layer_cls, layer_args
+    def get_combiner(self,
+                     num_blocks: int,
+                     aux_layer_period: int = 3,
+                     aux_layer_start: int = 1,
+                     combiner_type="norm"
+                     ) -> Tuple[List[int],torch.nn.Module]:
+        """Define combiner layer.
+
+        The tapped layer indices start at ``num_blocks // aux_layer_start``,
+        advance by ``aux_layer_period`` and stop before the last block; the
+        index of the last block is then always appended.
+
+        Args:
+            num_blocks (int): total number of encoder blocks, must be > 0.
+            aux_layer_period (int): step between tapped layers, must be > 0.
+            aux_layer_start (int): divisor that sets the first tapped layer,
+                must be > 0.
+            combiner_type (str): one of "norm", "mfa", "random_frame",
+                "random_layer".
+
+        Returns:
+            Tuple of (list of tapped layer indices, RandomCombine module).
+
+        Raises:
+            AssertionError: if any argument is outside the ranges above.
+        """
+        assert combiner_type in ["norm", "mfa", "random_frame", "random_layer"], "unknown combiner_type {}".format(combiner_type)
+        # The former `assert aux_layer_period,aux_layer_start > 0` only tested
+        # aux_layer_period and used the comparison as the failure message, so a
+        # non-positive aux_layer_start was never rejected here.
+        assert aux_layer_period > 0 and aux_layer_start > 0, \
+            "aux_layer_period and aux_layer_start must be positive"
+        assert num_blocks > 0
+        aux_layers = list(
+            range(
+                num_blocks // aux_layer_start,
+                num_blocks - 1,
+                aux_layer_period,
+            )
+        )
+        assert len(set(aux_layers)) == len(aux_layers)
+        assert num_blocks - 1 not in aux_layers
+        aux_layers = aux_layers + [num_blocks - 1]
+        combiner = RandomCombine(
+            aux_layers=aux_layers,
+            combiner_type=combiner_type,
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+        return aux_layers,combiner
+ def output_size(self) -> int:
+ # Width of the encoder output as computed in __init__: attention_dim, or
+ # attention_dim * len(aux_layers) when the combiner type is "mfa".
+ return self._output_size
+    def forward(
+        self,
+        xs: torch.Tensor,
+        xs_lens: torch.Tensor,
+        decoding_chunk_size: int = -1,
+        decoding_left_chunk_size: int = -2,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Embed positions in tensor
+        :param torch.Tensor xs: input tensor (B, T, D)
+        :param torch.Tensor xs_lens: input length (B); not modified by this call
+        :param int decoding_chunk_size: decoding chunk size for dynamic chunk
+            0: use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        :param int decoding_left_chunk_size:
+            <-1: default , the number of left chunks is self.left_chunk_size.
+            -1: use full left chunk
+            0: no left chunk
+            >0: the number of left chunks
+        :param float warmup:
+            Model level warmup, a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+        :return: position embedded tensor and mask
+        :rtype Tuple[torch.Tensor, torch.Tensor]:
+        """
+        if decoding_left_chunk_size <= -2:
+            left_chunk_size = self.left_chunk_size
+        else:
+            left_chunk_size = decoding_left_chunk_size
+        T = xs.size(1)
+        if self.mlp_head:
+            # mlp_head lengthens the sequence by one position (presumably an
+            # extra head token added by the input layer - confirm there).
+            T += 1
+            # Out-of-place add: the former `xs_lens += 1` incremented the
+            # caller's length tensor in place on every forward call.
+            xs_lens = xs_lens + 1
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+        xs, pos_emb, masks = self.embed(xs, masks)
+        mask_pad = masks  # (B, 1, T/subsample_rate)
+        chunk_masks = add_optional_chunk_mask(xs, masks,
+                                              self.use_dynamic_chunk,
+                                              self.use_dynamic_left_chunk,
+                                              decoding_chunk_size,
+                                              self.static_chunk_size,
+                                              left_chunk_size)
+        # Collect the outputs of the tapped layers for the combiner.
+        out = []
+        for i, layer in enumerate(self.encoders):
+            xs, chunk_masks = layer(xs, chunk_masks, pos_emb, mask_pad, warmup = warmup)
+            if i in self.aux_layers:
+                out.append(xs)
+        if len(out) > 0:
+            xs = self.combiner(out)
+        if self.after_norm is not None:
+            # The batch-norm variant is applied with dims 1 and 2 swapped, so
+            # transpose around it and restore the layout afterwards.
+            if not self.use_layer_norm:
+                xs = xs.transpose(1, 2)
+            xs = self.after_norm(xs)
+            if not self.use_layer_norm:
+                xs = xs.transpose(1, 2)
+        return xs, masks
+class TransformerEncoder(BaseEncoder):
+ """Transformer encoder module."""
+ # NOTE(review): extra keyword arguments captured by **args are not used in
+ # this constructor, and `attention_norm_args` defaults to a shared dict object.
+ def __init__(
+ self,
+ idim: int,
+ attention_dim: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ mlp_head: bool=False,
+ num_blocks: int = 6,
+ aux_layer_period: int = 3,
+ aux_layer_start:int = 1,
+ dropout_rate: float = 0.1,
+ layer_dropout: float = 0.,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ attention_conv_out : bool = False,
+ attention_norm_args : dict = {},
+ input_layer: str = "conv2d",
+ pos_enc_type: str = "abs_pos",
+ rotary_value: bool = True,
+ rope_abs_plus: bool = False,
+ add_t5rel_bias: bool = False,
+ att_type: str= 'multi',
+ gau_units: int = 512,
+ gau_key: int = 64,
+ normalize_before: bool = True,
+ norm_type: str = "layer_norm",
+ concat_after: bool = False,
+ positionwise_layer_type="linear",
+ positionwise_conv_kernel_size=1,
+ activation_type='relu',
+ activation_balancer: bool=False,
+ static_chunk_size: int = 0,
+ left_chunk_size: int = -1,
+ use_dynamic_chunk: bool = False,
+ use_dynamic_left_chunk: bool = False,
+ combiner_type: str="norm",
+ re_scale: bool=False,
+ convfnn_blocks: int=0,
+ **args
+ ):
+ """ Construct TransformerEncoder
+ See Encoder for the meaning of each parameter.
+ """
+ # Arguments are forwarded positionally, so their order has to follow the
+ # parameter order of BaseEncoder.__init__.
+ super().__init__(idim, attention_dim, attention_heads, linear_units,
+ mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate, attention_conv_out,
+ attention_norm_args, input_layer, pos_enc_type, rotary_value, rope_abs_plus,
+ add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type,
+ concat_after, positionwise_layer_type, positionwise_conv_kernel_size,
+ activation_type, activation_balancer, static_chunk_size, left_chunk_size,
+ use_dynamic_chunk, use_dynamic_left_chunk,combiner_type, re_scale)
+ # The first `convfnn_blocks` layers use the "pre_" attention/feed-forward
+ # choices made below; they start from the attention selected by BaseEncoder.
+ pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args
+ if positionwise_layer_type == "gau":
+ # With a GAU feed-forward, both leading-layer slots become GAU layers built
+ # with conv_out=True.
+ pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer(
+ self.pos_enc_type,
+ attention_dim,
+ linear_units,
+ gau_key,
+ dropout_rate,
+ conv_out =True,
+ attention_norm_args = attention_norm_args,
+ re_scale=re_scale
+ )
+ pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer(
+ pos_enc_type,
+ attention_dim,
+ gau_units,
+ gau_key,
+ attention_dropout_rate,
+ conv_out =True,
+ attention_norm_args=attention_norm_args,
+ re_scale=re_scale
+ )
+ else:
+ # Otherwise the leading layers keep the base attention and use a
+ # MultiLayeredConv1d feed-forward.
+ pre_positionwise_layer = MultiLayeredConv1d
+ pre_positionwise_layer_args = (
+ attention_dim,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ activation_type,
+ activation_balancer,
+ re_scale
+ )
+ encoders = []
+ for _ in range(convfnn_blocks):
+ encoders.append(TransformerEncoderLayer(
+ attention_dim,
+ pre_selfattn_layer(*pre_selfattn_layer_args),
+ pre_positionwise_layer(*pre_positionwise_layer_args),
+ dropout_rate,
+ layer_dropout,
+ normalize_before,
+ norm_type,
+ positionwise_layer_type,
+ concat_after,
+ ))
+ # The remaining layers use the classes and argument tuples stored by BaseEncoder;
+ # a fresh module instance is created for every layer.
+ for _ in range(num_blocks-convfnn_blocks):
+ encoders.append(TransformerEncoderLayer(
+ attention_dim,
+ self.selfattn_layer(*self.selfattn_layer_args),
+ self.positionwise_layer(*self.positionwise_layer_args),dropout_rate,
+ layer_dropout,normalize_before, norm_type,positionwise_layer_type,concat_after))
+ self.encoders = torch.nn.ModuleList(encoders)
+class ConformerEncoder(BaseEncoder):
+ """Conformer encoder module."""
+ # NOTE(review): extra keyword arguments captured by **args are not used in
+ # this constructor, and `attention_norm_args` defaults to a shared dict object.
+ def __init__(
+ self,
+ idim: int,
+ attention_dim: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ mlp_head: bool=False,
+ num_blocks: int = 6,
+ aux_layer_period: int = 3,
+ aux_layer_start: int = 1,
+ dropout_rate: float = 0.1,
+ layer_dropout: float = 0.,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ attention_conv_out : bool = False,
+ attention_norm_args : dict = {},
+ input_layer: str = "conv2d",
+ pos_enc_type: str = "rel_pos",
+ rotary_value: bool = True,
+ rope_abs_plus: bool = False,
+ add_t5rel_bias: bool = False,
+ att_type: str= 'multi',
+ gau_units: int = 512,
+ gau_key: int = 64,
+ normalize_before: bool = True,
+ norm_type: str = "layer_norm",
+ concat_after: bool = False,
+ positionwise_layer_type="linear",
+ positionwise_conv_kernel_size=3,
+ activation_type: str = "swish",
+ activation_balancer: bool=False,
+ static_chunk_size: int = 0,
+ left_chunk_size: int = -1,
+ use_dynamic_chunk: bool = False,
+ use_dynamic_left_chunk: bool = False,
+ macaron_style: bool = True,
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ combiner_type: str="norm",
+ re_scale: bool=False,
+ convfnn_blocks: int=0,
+ **args
+ ):
+ """Construct ConformerEncoder
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ """
+ # Arguments are forwarded positionally, so their order has to follow the
+ # parameter order of BaseEncoder.__init__.
+ super().__init__(idim, attention_dim, attention_heads, linear_units,
+ mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate, attention_conv_out,
+ attention_norm_args,input_layer, pos_enc_type, rotary_value, rope_abs_plus,
+ add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type,
+ concat_after, positionwise_layer_type, positionwise_conv_kernel_size,
+ activation_type, activation_balancer,static_chunk_size, left_chunk_size,
+ use_dynamic_chunk, use_dynamic_left_chunk,combiner_type,re_scale)
+ activation = Nonlinearity(activation_type)
+ assert activation is not None
+ # convolution module definition
+ # The same `activation` object is part of the argument tuple used for every layer's
+ # convolution module.
+ convolution_layer = ConvolutionModule
+ convolution_layer_args = (attention_dim, cnn_module_kernel, activation,
+ cnn_module_norm, causal, activation_balancer)
+ # The first `convfnn_blocks` layers use the "pre_" attention/feed-forward
+ # choices made below; they start from the attention selected by BaseEncoder.
+ pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args
+ if positionwise_layer_type == "gau":
+ # With a GAU feed-forward, both leading-layer slots become GAU layers built
+ # with conv_out=True.
+ pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer(
+ self.pos_enc_type,
+ attention_dim,
+ linear_units,
+ gau_key,
+ dropout_rate,
+ conv_out =True,
+ attention_norm_args = attention_norm_args,
+ re_scale=re_scale
+ )
+ pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer(
+ pos_enc_type,
+ attention_dim,
+ gau_units,
+ gau_key,
+ attention_dropout_rate,
+ conv_out =True,
+ attention_norm_args=attention_norm_args,
+ re_scale=re_scale
+ )
+ else:
+ # Otherwise the leading layers keep the base attention and use a
+ # MultiLayeredConv1d feed-forward.
+ pre_positionwise_layer = MultiLayeredConv1d
+ pre_positionwise_layer_args = (
+ attention_dim,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ activation_type,
+ activation_balancer,
+ re_scale
+ )
+ encoders = []
+ # The macaron feed-forward and the convolution module are instantiated only when
+ # their flags are set; otherwise None is passed in their place.
+ for _ in range(convfnn_blocks):
+ encoders.append(ConformerEncoderLayer(
+ attention_dim,
+ pre_selfattn_layer(*pre_selfattn_layer_args),
+ pre_positionwise_layer(*pre_positionwise_layer_args),
+ pre_positionwise_layer(
+ *pre_positionwise_layer_args) if macaron_style else None,
+ convolution_layer(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ layer_dropout,
+ normalize_before,
+ norm_type,
+ positionwise_layer_type,
+ concat_after,
+ ))
+ # The remaining layers use the classes and argument tuples stored by BaseEncoder.
+ for _ in range(num_blocks-convfnn_blocks):
+ encoders.append(ConformerEncoderLayer(
+ attention_dim,
+ self.selfattn_layer(*self.selfattn_layer_args),
+ self.positionwise_layer(*self.positionwise_layer_args),
+ self.positionwise_layer(
+ *self.positionwise_layer_args) if macaron_style else None,
+ convolution_layer(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ layer_dropout,
+ normalize_before,
+ norm_type,
+ positionwise_layer_type,
+ concat_after,
+ ))
+ self.encoders = torch.nn.ModuleList(encoders)
+class ReConformerEncoder(BaseEncoder):
+ """Conformer encoder module."""
+ # NOTE(review): extra keyword arguments captured by **args are not used in
+ # this constructor, and `attention_norm_args` defaults to a shared dict object.
+ def __init__(
+ self,
+ idim: int,
+ attention_dim: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ mlp_head: bool=False,
+ num_blocks: int = 6,
+ aux_layer_period: int = 3,
+ aux_layer_start: int = 1,
+ dropout_rate: float = 0.1,
+ layer_dropout: float = 0.,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ attention_conv_out : bool = False,
+ attention_norm_args : dict = {},
+ input_layer: str = "re_conv2d",
+ pos_enc_type: str = "rel_pos",
+ rotary_value: bool = True,
+ rope_abs_plus: bool = False,
+ add_t5rel_bias: bool = False,
+ att_type: str= 'multi',
+ gau_units: int = 512,
+ gau_key: int = 64,
+ normalize_before: bool = False,
+ norm_type: str = "basic_norm",
+ concat_after: bool = False,
+ positionwise_layer_type="linear",
+ positionwise_conv_kernel_size=3,
+ activation_type: str = "double_swish",
+ activation_balancer: bool=True,
+ static_chunk_size: int = 0,
+ left_chunk_size: int = -1,
+ use_dynamic_chunk: bool = False,
+ use_dynamic_left_chunk: bool = False,
+ macaron_style: bool = True,
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ combiner_type: str="norm",
+ re_scale: bool=True,
+ convfnn_blocks: int=0,
+ **args
+ ):
+ """Construct ConformerEncoder
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ """
+ # This variant accepts exactly one normalization/scaling setup; the asserts
+ # reject any other combination before anything is built.
+ assert normalize_before == False,"set normalize_before=False."
+ assert norm_type == "basic_norm","set norm_type=basic_norm."
+ assert re_scale == True, "reconformer set res_scale=True."
+ assert activation_balancer == True
+ # Arguments are forwarded positionally, so their order has to follow the
+ # parameter order of BaseEncoder.__init__.
+ super().__init__(idim, attention_dim, attention_heads, linear_units,
+ mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate, attention_conv_out,
+ attention_norm_args,input_layer, pos_enc_type, rotary_value, rope_abs_plus,
+ add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type,
+ concat_after, positionwise_layer_type, positionwise_conv_kernel_size,
+ activation_type, activation_balancer,static_chunk_size, left_chunk_size,
+ use_dynamic_chunk, use_dynamic_left_chunk,combiner_type,re_scale)
+ activation = Nonlinearity(activation_type)
+ assert activation is not None
+ # convolution module definition
+ # ReConvolutionModule takes a shorter argument tuple than ConvolutionModule:
+ # no norm type and no balancer flag are passed.
+ convolution_layer = ReConvolutionModule
+ convolution_layer_args = (attention_dim, cnn_module_kernel, activation,
+ causal)
+ # The first `convfnn_blocks` layers use the "pre_" attention/feed-forward
+ # choices made below; they start from the attention selected by BaseEncoder.
+ pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args
+ if positionwise_layer_type == "gau":
+ # With a GAU feed-forward, both leading-layer slots become GAU layers built
+ # with conv_out=True.
+ pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer(
+ self.pos_enc_type,
+ attention_dim,
+ linear_units,
+ gau_key,
+ dropout_rate,
+ conv_out =True,
+ attention_norm_args = attention_norm_args,
+ re_scale=re_scale
+ )
+ pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer(
+ pos_enc_type,
+ attention_dim,
+ gau_units,
+ gau_key,
+ attention_dropout_rate,
+ conv_out =True,
+ attention_norm_args=attention_norm_args,
+ re_scale=re_scale
+ )
+ else:
+ # Otherwise the leading layers keep the base attention and use a
+ # MultiLayeredConv1d feed-forward.
+ pre_positionwise_layer = MultiLayeredConv1d
+ pre_positionwise_layer_args = (
+ attention_dim,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ activation_type,
+ activation_balancer,
+ re_scale
+ )
+ encoders = []
+ # ReConformerEncoderLayer is built without normalize_before/norm_type, unlike
+ # the layer classes used by the other encoders in this file.
+ for _ in range(convfnn_blocks):
+ encoders.append(ReConformerEncoderLayer(
+ attention_dim,
+ pre_selfattn_layer(*pre_selfattn_layer_args),
+ pre_positionwise_layer(*pre_positionwise_layer_args),
+ pre_positionwise_layer(
+ *pre_positionwise_layer_args) if macaron_style else None,
+ convolution_layer(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ layer_dropout,
+ positionwise_layer_type,
+ concat_after,
+ ))
+ # The remaining layers use the classes and argument tuples stored by BaseEncoder.
+ for _ in range(num_blocks-convfnn_blocks):
+ encoders.append(ReConformerEncoderLayer(
+ attention_dim,
+ self.selfattn_layer(*self.selfattn_layer_args),
+ self.positionwise_layer(*self.positionwise_layer_args),
+ self.positionwise_layer(
+ *self.positionwise_layer_args) if macaron_style else None,
+ convolution_layer(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ layer_dropout,
+ positionwise_layer_type,
+ concat_after,
+ ))
+ self.encoders = torch.nn.ModuleList(encoders)
+# RandomCombine conformer in k2, for deeper training.
+# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+class RandomCombine(nn.Module):
+ """
+ Combine a list of Tensors (typically the outputs of several encoder
+ layers) into a single output Tensor.
+ The behaviour is selected by `combiner_type` (see `forward`):
+ "mfa": concatenate all inputs along the last dimension.
+ "random_frame": in training, a random combination of the inputs using
+ one set of weights per frame; otherwise just the last input.
+ "random_layer": in training, a random combination of the inputs using
+ one set of weights per batch element; otherwise just the last input.
+ any other value (e.g. the default "norm"): just the last input.
+ For the random modes all inputs must have the same shape.
+ The idea is that the list of Tensors will be a list of outputs of multiple
+ conformer layers. This has a similar effect as iterated loss (see the
+ k2/icefall conformer implementation this class is adapted from).
+ """
+ def __init__(
+ self,
+ aux_layers: list,
+ combiner_type: str = "norm",
+ final_weight: float = 0.5,
+ pure_prob: float = 0.5,
+ stddev: float = 2.0,
+ ) -> None:
+ """
+ Args:
+ aux_layers:
+ A list with one entry per tensor input; its length becomes
+ `self.num_inputs`, i.e. the number of layers' outputs that are
+ fed into this module. E.g. in an 18-layer neural net if we
+ output layers 16, 12, 18, num_inputs would be 3. Within this
+ class the list itself is only stored and shown by `extra_repr`.
+ combiner_type:
+ One of "mfa", "random_frame", "random_layer"; any other value
+ (including the default "norm") makes `forward` return the last
+ input unchanged.
+ final_weight:
+ The amount of weight or probability we assign to the
+ final layer when randomly choosing layers or when choosing
+ continuous layer weights.
+ pure_prob:
+ The probability, on each frame, with which we choose
+ only a single layer to output (rather than an interpolation)
+ stddev:
+ A standard deviation that we add to log-probs for computing
+ randomized weights.
+ `final_weight`, `pure_prob` and `stddev` are only validated and stored
+ when `combiner_type` is "random_frame" or "random_layer".
+ The method of choosing which layers, or combinations of layers, to use,
+ is conceptually as follows::
+ With probability `pure_prob`::
+ With probability `final_weight`: choose final layer,
+ Else: choose random non-final layer.
+ Else::
+ Choose initial log-weights that correspond to assigning
+ weight `final_weight` to the final layer and equal
+ weights to other layers; then add Gaussian noise
+ with standard deviation `stddev` to these log-weights, and
+ normalize to weights (note: the average weight assigned to the
+ final layer here will not be `final_weight` if stddev>0).
+ """
+ super().__init__()
+ self.num_inputs = len(aux_layers)
+ self.aux_layers = aux_layers
+ assert self.num_inputs >= 1
+ if combiner_type in ["random_frame", "random_layer"]:
+ assert 0 <= pure_prob <= 1, pure_prob
+ assert 0 < final_weight < 1, final_weight
+ self.final_weight = final_weight
+ self.pure_prob = pure_prob
+ self.stddev = stddev
+ # Offset added to the final layer's log-weight so that, before noise,
+ # softmax over (0, ..., 0, final_log_weight) gives the final layer
+ # exactly `final_weight`. With num_inputs == 1 this is log(0) = -inf,
+ # but the random forward paths return early in that case.
+ self.final_log_weight = (
+ torch.tensor(
+ (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
+ )
+ .log()
+ .item()
+ )
+ self.combiner_type = combiner_type
+ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+ """Dispatch to the combiner selected by `self.combiner_type`.
+ Unrecognised types fall through to `forward_norm`.
+ """
+ if self.combiner_type == "mfa":
+ return self.forward_mfa(inputs)
+ elif self.combiner_type == "random_frame":
+ return self.forward_rand_frame(inputs)
+ elif self.combiner_type == "random_layer":
+ return self.forward_rand_layer(inputs)
+ return self.forward_norm(inputs)
+ def forward_mfa(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+ """Concatenate all inputs along the last dimension."""
+ return torch.cat(inputs,dim=-1)
+ def forward_norm(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+ """Return the last input unchanged."""
+ return inputs[-1]
+ def forward_rand_frame(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+ """Randomly combine the inputs with an independent weight set per frame.
+ Args:
+ inputs:
+ A list of Tensor, e.g. from various layers of a transformer.
+ All must be the same shape, of (*, num_channels)
+ Returns:
+ A Tensor of shape (*, num_channels). In test mode, when
+ scripting, or when there is only one input, this is just the
+ final input.
+ """
+ num_inputs = self.num_inputs
+ assert len(inputs) == num_inputs
+ if not self.training or torch.jit.is_scripting() or self.num_inputs==1:
+ return inputs[-1]
+ # Shape of weights: (*, num_inputs)
+ num_channels = inputs[0].shape[-1]
+ # every leading dimension is flattened into one "frame" axis
+ num_frames = inputs[0].numel() // num_channels
+ ndim = inputs[0].ndim
+ # stacked_inputs: (num_frames, num_channels, num_inputs)
+ stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
+ (num_frames, num_channels, num_inputs)
+ )
+ # weights: (num_frames, num_inputs)
+ weights = self._get_random_weights(
+ inputs[0].dtype, inputs[0].device, num_frames
+ )
+ weights = weights.reshape(num_frames, num_inputs, 1)
+ # ans: (num_frames, num_channels, 1)
+ ans = torch.matmul(stacked_inputs, weights)
+ # ans: (*, num_channels)
+ ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+ # The following if causes errors for torch script in torch 1.6.0
+ # if __name__ == "__main__":
+ # # for testing only...
+ # print("Weights = ", weights.reshape(num_frames, num_inputs))
+ return ans
+ def forward_rand_layer(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+ """Randomly combine the inputs with one weight set per batch element.
+ Args:
+ inputs:
+ A list of Tensor, e.g. from various layers of a transformer.
+ All must be the same shape, of (B, T, C)
+ Returns:
+ A Tensor of shape (B, T, C). In test mode, when scripting, or
+ when there is only one input, this is just the final input.
+ """
+ num_inputs = self.num_inputs
+ assert len(inputs) == num_inputs
+ if not self.training or torch.jit.is_scripting() or self.num_inputs==1:
+ return inputs[-1]
+ num_channels = inputs[0].shape[-1]
+ num_b = inputs[0].shape[0]
+ ndim = inputs[0].ndim
+ # stacked_inputs: (B, T, C, num_inputs)
+ stacked_inputs = torch.stack(inputs, dim=ndim)
+ # weights: (B, num_inputs)
+ weights = self._get_random_weights(
+ inputs[0].dtype, inputs[0].device, num_b
+ )
+ # (B, 1, num_inputs, 1): the size-1 axis lets matmul broadcast the same
+ # weights over every time step of a batch element.
+ weights = weights.reshape(num_b,1, num_inputs, 1)
+ # ans: (B, T, C, 1)
+ ans = torch.matmul(stacked_inputs, weights)
+ # ans: (B, T, C)
+ ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+ # The following if causes errors for torch script in torch 1.6.0
+ # if __name__ == "__main__":
+ # # for testing only...
+ # print("Weights = ", weights.reshape(num_frames, num_inputs))
+ return ans
+ def _get_random_weights(
+ self, dtype: torch.dtype, device: torch.device, num_frames: int
+ ) -> torch.Tensor:
+ """Return a tensor of random weights, of shape
+ `(num_frames, self.num_inputs)`,
+ Args:
+ dtype:
+ The data-type desired for the answer, e.g. float, double.
+ device:
+ The device needed for the answer.
+ num_frames:
+ The number of sets of weights desired
+ Returns:
+ A tensor of shape (num_frames, self.num_inputs), such that
+ `ans.sum(dim=1)` is all ones.
+ """
+ pure_prob = self.pure_prob
+ # pure_prob of exactly 0 or 1 skips generating the unused kind of weights
+ if pure_prob == 0.0:
+ return self._get_random_mixed_weights(dtype, device, num_frames)
+ elif pure_prob == 1.0:
+ return self._get_random_pure_weights(dtype, device, num_frames)
+ else:
+ p = self._get_random_pure_weights(dtype, device, num_frames)
+ m = self._get_random_mixed_weights(dtype, device, num_frames)
+ # per row: take the one-hot row with probability pure_prob,
+ # otherwise the mixed row
+ return torch.where(
+ torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+ )
+ def _get_random_pure_weights(
+ self, dtype: torch.dtype, device: torch.device, num_frames: int
+ ):
+ """Return a tensor of random one-hot weights, of shape
+ `(num_frames, self.num_inputs)`,
+ Args:
+ dtype:
+ The data-type desired for the answer, e.g. float, double.
+ device:
+ The device needed for the answer.
+ num_frames:
+ The number of sets of weights desired.
+ Returns:
+ A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+ exactly one weight equal to 1.0 on each frame.
+ """
+ final_prob = self.final_weight
+ # final contains self.num_inputs - 1 in all elements
+ final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+ # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
+ nonfinal = torch.randint(
+ self.num_inputs - 1, (num_frames,), device=device
+ )
+ # per row: the final index with probability final_prob, else a
+ # random non-final index
+ indexes = torch.where(
+ torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+ )
+ ans = torch.nn.functional.one_hot(
+ indexes, num_classes=self.num_inputs
+ ).to(dtype=dtype)
+ return ans
+ def _get_random_mixed_weights(
+ self, dtype: torch.dtype, device: torch.device, num_frames: int
+ ):
+ """Return a tensor of random soft (interpolation) weights, of shape
+ `(num_frames, self.num_inputs)`,
+ Args:
+ dtype:
+ The data-type desired for the answer, e.g. float, double.
+ device:
+ The device needed for the answer.
+ num_frames:
+ The number of sets of weights desired.
+ Returns:
+ A tensor of shape (num_frames, self.num_inputs), with elements
+ in [0..1] that sum to one over the second axis, i.e.
+ `ans.sum(dim=1)` is all ones.
+ """
+ # Gaussian noise scaled by stddev acts as the random log-weights
+ logprobs = (
+ torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+ * self.stddev
+ )
+ # bias the last (final-layer) column before normalising
+ logprobs[:, -1] += self.final_log_weight
+ return logprobs.softmax(dim=1)
+ def extra_repr(self):
+ """Describe the combiner; the random-only fields are shown only for
+ the "random_*" types, which are the only ones that set them."""
+ s = ('{combiner_type}, num_inputs_layer={num_inputs}, aux_layers={aux_layers}')
+ if "random" in self.combiner_type:
+ s += ', final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}, final_log_weight={final_log_weight}'
+ return s.format(**self.__dict__)
+# def scale_args(layer_args,num=2,scale=2):
+ # new_args=[]
+ # for i,arg in enumerate(layer_args):
+ # if i x + att(x)
- def __init__(self, size, self_attn, feed_forward, dropout_rate,
- normalize_before=True, concat_after=False):
- super(EncoderLayer, self).__init__()
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ dropout_rate: float=0.1,
+ layer_dropout: float=0.,
+ normalize_before: bool = True,
+ norm_type: str = "layer_norm",
+ positionwise_layer_type: str="linear",
+ concat_after: bool = False
+ ):
+ super(TransformerEncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
+ if norm_type == "batch_norm":
+ self.use_layer_norm = False
+ norm_tp = Trans_Bat
+ else:
+ self.use_layer_norm = True
+ if norm_type == "layer_norm":
+ norm_tp = LayerNorm
+ else:
+ norm_tp = BasicNorm
+ self.norm1 = norm_tp(size)
+ self.norm2 = norm_tp(size)
self.dropout = nn.Dropout(dropout_rate)
+ self.layer_dropout = layer_dropout
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
+ self.concat_linear = None
+ self.positionwise_layer_type = positionwise_layer_type
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
- def forward(self, x, mask):
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoded features
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
+ :param torch.Tensor pos_emb:
+ :param torch.Tensor mask_pad: not used in the transformer layer; kept only for a unified API with the conformer layer.
:rtype: Tuple[torch.Tensor, torch.Tensor]
+ warmup_scale = min(0.1 + warmup, 1.0)
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+ x_orig = x
residual = x
if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
x = self.norm1(x)
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x_att = self.self_attn(x, x, x, mask, pos_emb)
+ if self.concat_linear is not None:
+ x_concat = torch.cat((x,x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
- x = residual + self.dropout(self.self_attn(x, x, x, mask))
+ x = residual + self.dropout(x_att)
if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
x = self.norm1(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
residual = x
if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
x = self.norm2(x)
- x = residual + self.dropout(self.feed_forward(x))
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ if self.positionwise_layer_type == 'gau':
+ x = self.feed_forward(x, x, x, mask, pos_emb)
+ else:
+ x = self.feed_forward(x)
+ x = residual + self.dropout(x)
if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
x = self.norm2(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ if alpha != 1.0:
+ x = alpha * x + (1 - alpha) * x_orig
+ return x, mask
+class ConformerEncoderLayer(nn.Module):
+ """Conformer encoder layer: optional macaron feed-forward, self-attention,
+ optional convolution module and feed-forward, each with a residual connection.
+ :param int size: input dim
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward:
+ feed forward module
+ :param feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ When given, both feed-forward outputs are scaled by 0.5.
+ :param conv_module (torch.nn.Module): Convolution module instance.
+ `ConvolutionModule` instance can be used as the argument.
+ When given, an extra final norm is applied to the layer output.
+ :param float dropout_rate: dropout rate
+ :param float layer_dropout: in training, probability of blending the layer
+ output with its input using a factor of 0.1 (see forward)
+ :param bool normalize_before: whether to apply the norm before each sub-block
+ (True) or after its residual sum (False)
+ :param str norm_type: "batch_norm" selects Trans_Bat, "layer_norm" selects
+ LayerNorm, any other value selects BasicNorm
+ :param str positionwise_layer_type: when 'gau', the feed-forward modules are
+ called like attention modules, i.e. with (x, x, x, mask, pos_emb)
+ :param bool concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ """
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float=0.1,
+ layer_dropout: float=0.,
+ normalize_before: bool = True,
+ norm_type: str = "layer_norm",
+ positionwise_layer_type: str="linear",
+ concat_after: bool = False
+ ):
+ super(ConformerEncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ # use_layer_norm is False only for batch norm; forward() then swaps
+ # dims 1 and 2 around every norm call.
+ if norm_type == "batch_norm":
+ self.use_layer_norm = False
+ norm_tp = Trans_Bat
+ else:
+ self.use_layer_norm = True
+ if norm_type == "layer_norm":
+ norm_tp = LayerNorm
+ else:
+ norm_tp = BasicNorm
+ self.norm_ff = norm_tp(size)
+ self.norm_mha = norm_tp(size)
+ # the macaron norm exists only when a macaron feed-forward is supplied
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = norm_tp(size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = norm_tp(size) # for the CNN module
+ self.norm_final = norm_tp(size) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.layer_dropout = layer_dropout
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ self.positionwise_layer_type = positionwise_layer_type
+ # concat_linear stays None unless concat_after is set; forward() tests it
+ self.concat_linear = None
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute encoded features
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
+ :param torch.Tensor mask: mask for x (batch, max_time_in, max_time_in)
+ :param torch.Tensor pos_emb: positional embedding passed through to the
+ attention module (and to 'gau' feed-forward modules)
+ :param torch.Tensor mask_pad: batch padding mask used for conv module (batch, 1, time)
+ :param float warmup: in training, the layer output is blended with its
+ input using the factor min(0.1 + warmup, 1.0)
+ :return: the layer output and the unchanged mask
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
+ """
+ warmup_scale = min(0.1 + warmup, 1.0)
+ # Layer dropout: in training alpha is warmup_scale with probability
+ # (1 - layer_dropout) and 0.1 otherwise; in eval it is always 1.0.
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+ x_orig = x
+ # macaron feed-forward (only when the module was supplied)
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_ff_macaron(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ if self.positionwise_layer_type == 'gau':
+ x = self.feed_forward_macaron(x, x, x, mask, pos_emb)
+ else:
+ x = self.feed_forward_macaron(x)
+ x = residual + self.ff_scale * self.dropout(x)
+ if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_ff_macaron(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ # MHA
+ residual = x
+ if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_mha(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x_att = self.self_attn(x, x, x, mask, pos_emb)
+ # note: the concat branch adds the linear output without dropout
+ if self.concat_linear is not None:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + self.concat_linear(x_concat)
+ else:
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_mha(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_conv(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.conv_module(x, mask_pad)
+ x = residual + self.dropout(x)
+ if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_conv(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ # FFN
+ residual = x
+ if self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_ff(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ if self.positionwise_layer_type == 'gau':
+ x = self.feed_forward(x, x, x, mask, pos_emb)
+ else:
+ x = self.feed_forward(x)
+ x = residual + self.ff_scale * self.dropout(x)
+ if not self.normalize_before:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_ff(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ # final norm, present only when a conv module was supplied
+ if self.conv_module is not None:
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.norm_final(x)
+ if not self.use_layer_norm:
+ x = x.transpose(1, 2)
+ # blend the layer output with the layer input (layer dropout / warmup)
+ if alpha != 1.0:
+ x = alpha * x + (1 - alpha) * x_orig
+ return x, mask
+class ReConformerEncoderLayer(nn.Module):
+ """Conformer-style encoder layer without per-sub-block norms: the only
+ normalisation is an ActivationBalancer followed by a BasicNorm on the
+ layer output.
+ :param int size: input dim
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward:
+ feed forward module
+ :param feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ :param conv_module (torch.nn.Module): Convolution module instance.
+ `ConvolutionModule` instance can be used as the argument.
+ :param float dropout_rate: dropout rate
+ :param float layer_dropout: in training, probability of blending the layer
+ output with its input using a factor of 0.1 (see forward)
+ :param str positionwise_layer_type: when 'gau', the feed-forward modules are
+ called like attention modules, i.e. with (x, x, x, mask, pos_emb)
+ :param bool concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ """
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float=0.1,
+ layer_dropout: float=0.075,
+ positionwise_layer_type: str="linear",
+ concat_after: bool = False
+ ):
+ super(ReConformerEncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.dropout = nn.Dropout(dropout_rate)
+ self.layer_dropout = layer_dropout
+ self.size = size
+ self.concat_after = concat_after
+ self.positionwise_layer_type = positionwise_layer_type
+ # concat_linear stays None unless concat_after is set; forward() tests it
+ self.concat_linear = None
+ if self.concat_after:
+ self.concat_linear = ScaledLinear(size + size, size)
+ self.norm_final = BasicNorm(size)
+ # try to ensure the output is close to zero-mean (or at least, zero-median).
+ self.balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+ )
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute encoded features
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
+ :param torch.Tensor mask: mask for x (batch, max_time_in, max_time_in)
+ :param torch.Tensor pos_emb: positional embedding passed through to the
+ attention module (and to 'gau' feed-forward modules)
+ :param torch.Tensor mask_pad: batch padding mask used for conv module (batch, 1, time)
+ :param float warmup: in training, the layer output is blended with its
+ input using the factor min(0.1 + warmup, 1.0)
+ :return: the layer output and the unchanged mask
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
+ """
+ warmup_scale = min(0.1 + warmup, 1.0)
+ # Layer dropout: in training alpha is warmup_scale with probability
+ # (1 - layer_dropout) and 0.1 otherwise; in eval it is always 1.0.
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+ x_orig = x
+ # macaron feed-forward (only when the module was supplied); note that no
+ # 0.5 scaling is applied to the feed-forward outputs in this layer
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.positionwise_layer_type == 'gau':
+ x = self.feed_forward_macaron(x, x, x, mask, pos_emb)
+ else:
+ x = self.feed_forward_macaron(x)
+ x = residual + self.dropout(x)
+ # MHA
+ x_att = self.self_attn(x, x, x, mask, pos_emb)
+ # note: the concat branch adds the linear output without dropout
+ if self.concat_linear is not None:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = x + self.concat_linear(x_concat)
+ else:
+ x = x + self.dropout(x_att)
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ x = self.conv_module(x, mask_pad)
+ x = residual + self.dropout(x)
+ # FFN
+ residual = x
+ if self.positionwise_layer_type == 'gau':
+ x = self.feed_forward(x, x, x, mask, pos_emb)
+ else:
+ x = self.feed_forward(x)
+ x = residual + self.dropout(x)
+ # single normalisation point of the layer: balancer, then BasicNorm
+ x = self.norm_final(self.balancer(x))
+ # blend the layer output with the layer input (layer dropout / warmup)
+ if alpha != 1.0:
+ x = alpha * x + (1 - alpha) * x_orig
 return x, mask
diff --git a/pytorch/libs/nnet/transformer/layer_norm.py b/pytorch/libs/nnet/transformer/layer_norm.py
index 66e81fd..0c28a05 100644
--- a/pytorch/libs/nnet/transformer/layer_norm.py
+++ b/pytorch/libs/nnet/transformer/layer_norm.py
@@ -3,8 +3,57 @@
# Reference: https://github.com/espnet/espnet.
import torch
+import torch.nn.functional as F
+class Trans_Bat(torch.nn.BatchNorm1d):
+ """BatchNorm1d wrapper with an optional transpose around the norm.
+ :param int nout: output dim size (number of features given to BatchNorm1d)
+ :param bool transpose: if True, swap dim 1 and the last dim before the
+ norm and swap them back afterwards.
+ :param float eps: ignored; accepted only for signature compatibility with LayerNorm.
+ :param bool learnabel_affine: forwarded as `affine` to the wrapped BatchNorm1d.
+ """
+ # NOTE(review): this class both inherits from BatchNorm1d (initialised via
+ # super().__init__(nout)) and wraps a second BatchNorm1d in self.norm, while
+ # forward() only uses self.norm. The inherited norm state therefore looks
+ # unused -- confirm before changing, since removing it would presumably
+ # alter the module's state_dict keys.
+ def __init__(self, nout, transpose=False,eps=1e-12,learnabel_affine: bool = True):
+ super(Trans_Bat,self).__init__(nout)
+ self.norm = torch.nn.BatchNorm1d(nout,affine=learnabel_affine)
+ self.transpose = transpose
+ def forward(self, x):
+ """Apply BatchNorm1d normalization
+ :param torch.Tensor x: input tensor
+ :return: batch normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.transpose:
+ return self.norm(x.transpose(1, -1)).transpose(1, -1)
+ else:
+ return self.norm(x)
+# class LayerNorm(torch.nn.Module):
+# """Layer normalization module
+# :param int nout: output dim size
+# :param int dim: dimension to be normalized
+# """
+# def __init__(
+# self,
+# nout: int,
+# dim: int = -1, # CAUTION: see documentation.
+# eps: float = 1e-5,
+# learnabel_affine: bool = True,
+# ) -> None:
+# super(LayerNorm, self).__init__()
+# self.dim = dim
+# self.norm = torch.nn.LayerNorm(nout,eps=eps,elementwise_affine=learnabel_affine)
+# def forward(self,x):
+# if self.dim == -1:
+# return self.norm(x)
+# return self.norm(x.transpose(1, -1)).transpose(1, -1)
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module
@@ -12,8 +61,8 @@ class LayerNorm(torch.nn.LayerNorm):
:param int dim: dimension to be normalized
- def __init__(self, nout, dim=-1):
- super(LayerNorm, self).__init__(nout, eps=1e-12)
+ def __init__(self, nout, dim=-1, eps=1e-5,learnabel_affine: bool = True):
+ super(LayerNorm, self).__init__(nout, eps=eps,elementwise_affine=learnabel_affine)
self.dim = dim
def forward(self, x):
@@ -24,5 +73,56 @@ def forward(self, x):
:rtype torch.Tensor
if self.dim == -1:
- return super(LayerNorm, self).forward(x)
- return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return F.layer_norm(
+ x.transpose(1, -1), self.normalized_shape, self.weight, self.bias, self.eps).transpose(1, -1)
+ # return super().forward(x)
+ # return super().forward(x.transpose(1, -1)).transpose(1, -1)
+class BasicNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+ So the idea is to introduce this large constant value as an explicit
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
+ doesn't have to do this trick. We make the "eps" learnable.
+ Args:
+ num_channels: the number of channels, e.g. 512. Only stored; forward()
+ does not read it.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interpreted as an offset from the input's ndim if negative.
+ This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ eps: the initial "epsilon" that we add as ballast in:
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
+ Note: our epsilon is actually large, but we keep the name
+ to indicate the connection with conventional LayerNorm.
+ learn_eps: if true, we learn epsilon; if false, we keep it
+ at the initial value.
+ """
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ eps: float = 0.25,
+ learn_eps: bool = True,
+ ) -> None:
+ super(BasicNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ # self.eps holds log(eps), not eps itself; forward() applies exp(), so
+ # the effective epsilon stays positive even when it is learned.
+ if learn_eps:
+ self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach())
+ else:
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Scale x by (mean(x**2 over channel_dim) + exp(self.eps)) ** -0.5.
+ There is no mean subtraction and no learned weight or bias.
+ """
+ scales = (
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ + self.eps.exp()
+ ) ** -0.5
+ return x * scales
\ No newline at end of file
diff --git a/pytorch/libs/nnet/transformer/mask.py b/pytorch/libs/nnet/transformer/mask.py
new file mode 100755
index 0000000..3c34268
--- /dev/null
+++ b/pytorch/libs/nnet/transformer/mask.py
@@ -0,0 +1,136 @@
+# -*- coding:utf-8 -*-
+# Reference: https://github.com/wenet-e2e/wenet
+import torch
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size
+ Row i is True from the start of its visible left context up to (but not
+ including) the end of the chunk that contains position i.
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+ Returns:
+ torch.Tensor: boolean mask of shape (size, size)
+ Examples:
+ (shown as 0/1; the returned tensor has dtype torch.bool)
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ # NOTE(review): chunk_size is used as a divisor below, so chunk_size == 0
+ # would raise ZeroDivisionError -- callers presumably pass a positive value.
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ # first frame of the earliest visible chunk, clamped at 0
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ # one past the last frame of the chunk containing i, clamped at size
+ ending = min((i // chunk_size + 1) * chunk_size, size)
+ ret[i, start:ending] = True
+ return ret
+def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int, static_chunk_size: int,
+ left_chunk_size: int):
+ """ Apply optional mask for encoder.
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ masks (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: use random dynamic chunk.
+ <0: use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ left_chunk_size: number of left chunks.
+ >=0: use left_chunk_size
+ <0: use all left chunks
+ Returns:
+ torch.Tensor: chunk mask of the input xs; `masks` itself is returned
+ when neither dynamic nor static chunking is active.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = left_chunk_size
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ # NOTE(review): torch.randint needs low < high, so this call looks
+ # like it raises when max_len == 1 -- confirm such short inputs
+ # cannot reach this branch.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ # NOTE(review): same low < high constraint; max_left_chunks can
+ # be 0 when chunk_size exceeds max_len - 1 -- confirm.
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = left_chunk_size
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ return chunk_masks
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+ See description of make_non_pad_mask.
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ max_len (int): width of the mask; when not greater than 0 the
+ largest value in `lengths` is used.
+ Returns:
+ torch.Tensor: Boolean mask of shape (B, max_len) that is True at
+ padded positions (index >= length).
+ Examples:
+ (shown as 0/1; `lengths` must be a tensor because `.size(0)` and
+ `.max()` are called on it)
+ >>> lengths = torch.tensor([5, 3, 2])
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0,
+ max_len,
+ dtype=torch.int64,
+ device=lengths.device)
+ # broadcast positions (B, max_len) against lengths (B, 1)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
\ No newline at end of file
diff --git a/pytorch/libs/nnet/transformer/multi_layer_conv.py b/pytorch/libs/nnet/transformer/multi_layer_conv.py
index 3c21042..5128a0e 100644
--- a/pytorch/libs/nnet/transformer/multi_layer_conv.py
+++ b/pytorch/libs/nnet/transformer/multi_layer_conv.py
@@ -7,7 +7,8 @@
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
import torch
+from libs.nnet.activation import Nonlinearity
+from .scaling import ActivationBalancer,ScaledConv1d,ScaledLinear
class MultiLayeredConv1d(torch.nn.Module):
"""Multi-layered conv1d for Transformer block.
@@ -20,7 +21,7 @@ class MultiLayeredConv1d(torch.nn.Module):
- def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False):
"""Initialize MultiLayeredConv1d module.
@@ -31,13 +32,18 @@ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
super(MultiLayeredConv1d, self).__init__()
- self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
+ conv = ScaledConv1d if re_scale else torch.nn.Conv1d
+ self.w_1 = conv(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
- self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size,
+ self.w_2 = conv(hidden_chans, in_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.dropout = torch.nn.Dropout(dropout_rate)
+ self.activation = Nonlinearity(activation_type)
+ self.balancer = None
+ if activation_balancer:
+ self.balancer = ActivationBalancer(channel_dim=1)
- def forward(self, x):
+ def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)):
"""Calculate forward propagation.
@@ -47,7 +53,10 @@ def forward(self, x):
Tensor: Batch of output tensors (B, *, hidden_chans)
- x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ x = self.w_1(x.transpose(-1, 1))
+ if self.balancer is not None:
+ x=self.balancer(x)
+ x = self.activation(x).transpose(-1, 1)
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
@@ -58,7 +67,7 @@ class Conv1dLinear(torch.nn.Module):
- def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False):
"""Initialize Conv1dLinear module.
@@ -69,12 +78,17 @@ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
super(Conv1dLinear, self).__init__()
- self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
+ conv = ScaledConv1d if re_scale else torch.nn.Conv1d
+ self.w_1 = conv(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
- self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
+ self.w_2 = ScaledLinear(hidden_chans, in_chans,initial_scale=0.25) if re_scale else torch.nn.Linear(hidden_chans, in_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
- def forward(self, x):
+ self.activation = Nonlinearity(activation_type)
+ self.balancer = None
+ if activation_balancer:
+ self.balancer = ActivationBalancer(channel_dim=1)
+ def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)):
"""Calculate forward propagation.
@@ -84,5 +98,9 @@ def forward(self, x):
Tensor: Batch of output tensors (B, *, hidden_chans)
- x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ x = self.w_1(x.transpose(-1, 1))
+ if self.balancer is not None:
+ x=self.balancer(x)
+ x = self.activation(x).transpose(-1, 1)
return self.w_2(self.dropout(x))
diff --git a/pytorch/libs/nnet/transformer/positionwise_feed_forward.py b/pytorch/libs/nnet/transformer/positionwise_feed_forward.py
index cf70c87..7f2ff01 100644
--- a/pytorch/libs/nnet/transformer/positionwise_feed_forward.py
+++ b/pytorch/libs/nnet/transformer/positionwise_feed_forward.py
@@ -2,7 +2,11 @@
# Reference: https://github.com/espnet/espnet.
+from tkinter import N
+from tkinter.messagebox import NO
import torch
+from libs.nnet.activation import Nonlinearity
+from .scaling import ActivationBalancer,ScaledLinear
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward
@@ -12,11 +16,20 @@ class PositionwiseFeedForward(torch.nn.Module):
:param float dropout_rate: dropout rate
- def __init__(self, idim, hidden_units, dropout_rate):
+ def __init__(self, idim, hidden_units, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False):
super(PositionwiseFeedForward, self).__init__()
- self.w_1 = torch.nn.Linear(idim, hidden_units)
- self.w_2 = torch.nn.Linear(hidden_units, idim)
+ self.w_1 = ScaledLinear(idim, hidden_units) if re_scale else torch.nn.Linear(idim, hidden_units)
+ self.w_2 = ScaledLinear(hidden_units, idim,initial_scale=0.25) if re_scale else torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
- def forward(self, x):
- return self.w_2(self.dropout(torch.relu(self.w_1(x))))
+ self.activation = Nonlinearity(activation_type)
+ self.balancer = None
+ if activation_balancer:
+ self.balancer = ActivationBalancer(channel_dim=-1)
+ assert self.activation is not None
+ def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)):
+ x = self.w_1(x)
+ if self.balancer is not None:
+ x=self.balancer(x)
+ return self.w_2(self.dropout(self.activation(x)))
+ # return self.w_2(self.dropout(self.activation(self.w_1(x))))
diff --git a/pytorch/libs/nnet/transformer/scaling.py b/pytorch/libs/nnet/transformer/scaling.py
new file mode 100755
index 0000000..6d7debf
--- /dev/null
+++ b/pytorch/libs/nnet/transformer/scaling.py
@@ -0,0 +1,499 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
+# For reworked conformer
+# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+import collections
+from itertools import repeat
+from typing import Optional, Tuple
+import torch
+import torch.backends.cudnn.rnn as rnn
+import torch.nn as nn
+from torch import _VF, Tensor
+def _ntuple(n):
+    """Return a parser that expands a single value into an ``n``-tuple.
+
+    The returned ``parse(x)`` hands ``x`` back unchanged when it is iterable
+    (strings included) and otherwise returns a tuple holding ``x`` repeated
+    ``n`` times.
+    """
+    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
+    # ``collections.abc``. It is imported locally so this block does not
+    # depend on ``import collections`` exposing the submodule.
+    from collections.abc import Iterable
+
+    def parse(x):
+        if isinstance(x, Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+_single = _ntuple(1)
+_pair = _ntuple(2)
+class ActivationBalancerFunction(torch.autograd.Function):
+ """Identity in the forward pass; adjusts gradients in the backward pass.
+
+ forward() returns ``x`` unchanged. When ``x.requires_grad`` is set it also
+ records per-channel statistics (fraction of positive entries and mean
+ absolute value, reduced over every dim except ``channel_dim``).
+ backward() returns ``x_grad - x_grad.abs() * (factor + scale_factor)``
+ built from those statistics, plus ``None`` for the six non-tensor inputs.
+ """
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ channel_dim: int,
+ min_positive: float, # e.g. 0.05
+ max_positive: float, # e.g. 0.95
+ max_factor: float, # e.g. 0.01
+ min_abs: float, # e.g. 0.2
+ max_abs: float, # e.g. 100.0
+ ) -> Tensor:
+ if x.requires_grad:
+ # Turn a negative channel_dim into a non-negative axis index.
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+ # The above line is not torch scriptable for torch 1.6.0
+ # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
+ sum_dims = []
+ for d in range(x.ndim):
+ if d != channel_dim:
+ sum_dims.append(d)
+ # Per-channel fraction of entries that are strictly positive.
+ xgt0 = x > 0
+ proportion_positive = torch.mean(
+ xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+ )
+ # Non-zero only where the positive fraction is below min_positive.
+ factor1 = (
+ (min_positive - proportion_positive).relu()
+ * (max_factor / min_positive)
+ if min_positive != 0.0
+ else 0.0
+ )
+ # Non-zero only where the positive fraction exceeds max_positive; for
+ # max_positive < 1 the divisor is negative, so this term is <= 0.
+ factor2 = (
+ (proportion_positive - max_positive).relu()
+ * (max_factor / (max_positive - 1.0))
+ if max_positive != 1.0
+ else 0.0
+ )
+ factor = factor1 + factor2
+ # Both terms were the 0.0 fallback: use a zero tensor so it can be saved.
+ if isinstance(factor, float):
+ factor = torch.zeros_like(proportion_positive)
+ # Per-channel mean magnitude, compared against min_abs and max_abs.
+ mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+ below_threshold = mean_abs < min_abs
+ above_threshold = mean_abs > max_abs
+ ctx.save_for_backward(
+ factor, xgt0, below_threshold, above_threshold
+ )
+ ctx.max_factor = max_factor
+ ctx.sum_dims = sum_dims
+ return x
+ @staticmethod
+ def backward(
+ ctx, x_grad: Tensor
+ ) -> Tuple[Tensor, None, None, None, None, None, None]:
+ factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+ dtype = x_grad.dtype
+ # Per element this is +/- max_factor: the sign follows (x > 0) for channels
+ # whose mean |x| is below min_abs, is flipped for channels above max_abs,
+ # and the whole term is zero for channels inside the band.
+ scale_factor = (
+ (below_threshold.to(dtype) - above_threshold.to(dtype))
+ * (xgt0.to(dtype) - 0.5)
+ * (ctx.max_factor * 2.0)
+ )
+ neg_delta_grad = x_grad.abs() * (factor + scale_factor)
+ # One gradient for x; the remaining six forward() inputs get None.
+ return x_grad - neg_delta_grad, None, None, None, None, None, None
+class ScaledLinear(nn.Linear):
+    """
+    Variant of nn.Linear whose parameters are rescaled each time they are
+    used:
+
+        weight = self.weight * self.weight_scale.exp()
+        bias = self.bias * self.bias_scale.exp()
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        initial_scale: override this to increase or decrease the initial
+           magnitude of the module's output (it sets the starting value of
+           weight_scale and bias_scale). Re-initializing the parameters is
+           another way to get a similar effect.
+        initial_speed: controls how fast the parameters learn near the
+           start of training. Use a value below one if you suspect the
+           module contributes to early instability, or above one to make
+           it train faster initially. Must be greater than 0.
+           Note: regardless of this option, it's best to use schedulers
+           like Noam that have a warm-up period.
+    """
+
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
+        super(ScaledLinear, self).__init__(*args, **kwargs)
+        # The scales are stored in log space; both start from the same value.
+        log_scale = torch.tensor(initial_scale).log()
+        self.weight_scale = nn.Parameter(log_scale.clone().detach())
+        if self.bias is None:
+            self.register_parameter("bias_scale", None)
+        else:
+            self.bias_scale = nn.Parameter(log_scale.clone().detach())
+
+        # Re-initialize on top of what nn.Linear's own reset_parameters did.
+        self._reset_parameters(initial_speed)
+
+    def _reset_parameters(self, initial_speed: float):
+        """Re-draw the weight, zero the bias and fold the fan-in into weight_scale."""
+        std = 0.1 / initial_speed
+        bound = (3 ** 0.5) * std  # uniform(-bound, bound) has std `std`
+        nn.init.uniform_(self.weight, -bound, bound)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+        target_std = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += torch.tensor(target_std / std).log()
+
+    def get_weight(self):
+        """Return the weight multiplied by exp(weight_scale)."""
+        return self.weight * self.weight_scale.exp()
+
+    def get_bias(self):
+        """Return the bias multiplied by exp(bias_scale), or None without a bias."""
+        if self.bias is not None and self.bias_scale is not None:
+            return self.bias * self.bias_scale.exp()
+        return None
+
+    def forward(self, input: Tensor) -> Tensor:
+        scaled_weight = self.get_weight()
+        scaled_bias = self.get_bias()
+        return torch.nn.functional.linear(input, scaled_weight, scaled_bias)
+class ScaledConv1d(nn.Conv1d):
+ # See docs for ScaledLinear
+ # nn.Conv1d whose weight and bias are multiplied by exp(weight_scale) and
+ # exp(bias_scale) (learned scalar parameters) every time they are used.
+ def __init__(
+ self,
+ *args,
+ initial_scale: float = 1.0,
+ initial_speed: float = 1.0,
+ **kwargs
+ ):
+ # *args/**kwargs are forwarded to nn.Conv1d. initial_scale seeds both
+ # log-scale parameters; initial_speed is handed to _reset_parameters.
+ super(ScaledConv1d, self).__init__(*args, **kwargs)
+ # The scales are stored in log space.
+ initial_scale = torch.tensor(initial_scale).log()
+ self.bias_scale: Optional[nn.Parameter] # for torchscript
+ self.weight_scale = nn.Parameter(initial_scale.clone().detach())
+ if self.bias is not None:
+ self.bias_scale = nn.Parameter(initial_scale.clone().detach())
+ else:
+ self.register_parameter("bias_scale", None)
+ self._reset_parameters(
+ initial_speed
+ ) # Overrides the reset_parameters in base class
+ def _reset_parameters(self, initial_speed: float):
+ # Weight is drawn from U(-a, a) with a = sqrt(3) * std, i.e. standard
+ # deviation `std`; the bias (if any) is zeroed.
+ std = 0.1 / initial_speed
+ a = (3 ** 0.5) * std
+ nn.init.uniform_(self.weight, -a, a)
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0.0)
+ # weight.shape[1] times the number of elements in one kernel slice.
+ fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+ scale = fan_in ** -0.5 # 1/sqrt(fan_in)
+ # Shift the log-scale so the effective weight std becomes
+ # (initial scale) * 1/sqrt(fan_in) instead of (initial scale) * std.
+ with torch.no_grad():
+ self.weight_scale += torch.tensor(scale / std).log()
+ def get_weight(self):
+ # Effective weight actually used by forward().
+ return self.weight * self.weight_scale.exp()
+ def get_bias(self):
+ # Effective bias, or None when the layer was built without a bias.
+ bias = self.bias
+ bias_scale = self.bias_scale
+ if bias is None or bias_scale is None:
+ return None
+ else:
+ return bias * bias_scale.exp()
+ def forward(self, input: Tensor) -> Tensor:
+ F = torch.nn.functional
+ # Non-zero padding modes: pad explicitly with F.pad, then convolve with
+ # zero padding.
+ if self.padding_mode != "zeros":
+ return F.conv1d(
+ F.pad(
+ input,
+ self._reversed_padding_repeated_twice,
+ mode=self.padding_mode,
+ ),
+ self.get_weight(),
+ self.get_bias(),
+ self.stride,
+ (0,),
+ self.dilation,
+ self.groups,
+ )
+ # Default "zeros" mode: let conv1d apply self.padding itself.
+ return F.conv1d(
+ input,
+ self.get_weight(),
+ self.get_bias(),
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ )
+class ScaledConv2d(nn.Conv2d):
+ # See docs for ScaledLinear
+ # nn.Conv2d whose weight and bias are multiplied by exp(weight_scale) and
+ # exp(bias_scale) (learned scalar parameters) every time they are used.
+ def __init__(
+ self,
+ *args,
+ initial_scale: float = 1.0,
+ initial_speed: float = 1.0,
+ **kwargs
+ ):
+ # *args/**kwargs are forwarded to nn.Conv2d. initial_scale seeds both
+ # log-scale parameters; initial_speed is handed to _reset_parameters.
+ super(ScaledConv2d, self).__init__(*args, **kwargs)
+ # The scales are stored in log space.
+ initial_scale = torch.tensor(initial_scale).log()
+ self.weight_scale = nn.Parameter(initial_scale.clone().detach())
+ if self.bias is not None:
+ self.bias_scale = nn.Parameter(initial_scale.clone().detach())
+ else:
+ self.register_parameter("bias_scale", None)
+ self._reset_parameters(
+ initial_speed
+ ) # Overrides the reset_parameters in base class
+ def _reset_parameters(self, initial_speed: float):
+ # Weight is drawn from U(-a, a) with a = sqrt(3) * std, i.e. standard
+ # deviation `std`; the bias (if any) is zeroed.
+ std = 0.1 / initial_speed
+ a = (3 ** 0.5) * std
+ nn.init.uniform_(self.weight, -a, a)
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0.0)
+ # weight.shape[1] times the number of elements in one kernel slice.
+ fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+ scale = fan_in ** -0.5 # 1/sqrt(fan_in)
+ # Shift the log-scale so the effective weight std becomes
+ # (initial scale) * 1/sqrt(fan_in) instead of (initial scale) * std.
+ with torch.no_grad():
+ self.weight_scale += torch.tensor(scale / std).log()
+ def get_weight(self):
+ # Effective weight actually used by forward().
+ return self.weight * self.weight_scale.exp()
+ def get_bias(self):
+ # Effective bias, or None when the layer was built without a bias.
+ # see https://github.com/pytorch/pytorch/issues/24135
+ bias = self.bias
+ bias_scale = self.bias_scale
+ if bias is None or bias_scale is None:
+ return None
+ else:
+ return bias * bias_scale.exp()
+ def _conv_forward(self, input, weight):
+ # Convolve `input` with the given (already scaled) weight and the scaled
+ # bias from get_bias().
+ F = torch.nn.functional
+ # Non-zero padding modes: pad explicitly with F.pad, then convolve with
+ # zero padding.
+ if self.padding_mode != "zeros":
+ return F.conv2d(
+ F.pad(
+ input,
+ self._reversed_padding_repeated_twice,
+ mode=self.padding_mode,
+ ),
+ weight,
+ self.get_bias(),
+ self.stride,
+ (0, 0),
+ self.dilation,
+ self.groups,
+ )
+ # Default "zeros" mode: let conv2d apply self.padding itself.
+ return F.conv2d(
+ input,
+ weight,
+ self.get_bias(),
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ )
+ def forward(self, input: Tensor) -> Tensor:
+ return self._conv_forward(input, self.get_weight())
+class ActivationBalancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that the proportion of positive values stays between
+ `min_positive` and `max_positive`, and that the mean absolute value stays
+ between `min_abs` and `max_abs`. It does this by multiplying negative
+ derivative values by up to (1+max_factor), and positive derivative values
+ by up to (1-max_factor), interpolated from 1 at the threshold to those
+ extremal values when none of the inputs are positive.
+ The forward output is the input itself; only gradients are affected.
+ Args:
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ max_factor: the maximum factor by which we modify the derivatives for
+ either the sign constraint or the magnitude constraint;
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
+ values in the range [0.98..1.02].
+ min_abs: the minimum average-absolute-value per channel, which
+ we allow, before we start to modify the derivatives to prevent
+ this.
+ max_abs: the maximum average-absolute-value per channel, which
+ we allow, before we start to modify the derivatives to prevent
+ this.
+ """
+ def __init__(
+ self,
+ channel_dim: int,
+ min_positive: float = 0.05,
+ max_positive: float = 0.95,
+ max_factor: float = 0.01,
+ min_abs: float = 0.2,
+ max_abs: float = 100.0,
+ ):
+ super(ActivationBalancer, self).__init__()
+ # Plain configuration values; this module registers no parameters here.
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.max_factor = max_factor
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ def forward(self, x: Tensor) -> Tensor:
+ # Pass-through when scripting, tracing, or in eval mode; the custom
+ # autograd function is only applied in training mode.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return x
+ else:
+ return ActivationBalancerFunction.apply(
+ x,
+ self.channel_dim,
+ self.min_positive,
+ self.max_positive,
+ self.max_factor,
+ self.min_abs,
+ self.max_abs,
+ )
+class DoubleSwishFunction(torch.autograd.Function):
+    """
+      double_swish(x) = x * torch.sigmoid(x-1)
+
+    The function is defined this way because it is numerically close to
+    swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+    Memory-efficient derivative, with s(x) = torch.sigmoid(x-1):
+
+      double_swish(x)  = x * s(x)
+      double_swish'(x) = x * s'(x) + s(x)
+                       = x * s(x) * (1-s(x)) + s(x)     [s' = s * (1-s)]
+                       = double_swish(x) * (1-s(x)) + s(x)
+
+    so the backward pass only needs s(x) and the output, not x itself.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        inp = x.detach()
+        gate = torch.sigmoid(inp - 1.0)
+        out = inp * gate
+        # Keep the gate and the output; the input is not saved.
+        ctx.save_for_backward(gate, out)
+        return out
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        gate, out = ctx.saved_tensors
+        local_deriv = out * (1 - gate) + gate
+        return local_deriv * y_grad
+class DoubleSwish(torch.nn.Module):
+ # Module wrapper around the double-swish activation x * sigmoid(x - 1).
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ # When scripting or tracing, use the plain expression; otherwise go
+ # through DoubleSwishFunction, which has its own backward().
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ else:
+ return DoubleSwishFunction.apply(x)
+def _test_activation_balancer_sign():
+ """Print the gradients ActivationBalancer produces for the sign constraint.
+
+ Manual check only: nothing is asserted, the tensors are just printed.
+ """
+ # One row per target probability 0.00, 0.01, ..., 0.99; each row holds N
+ # values that are 1.0 with that probability and 0.0 otherwise.
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
+ x = x.detach()
+ x.requires_grad = True
+ # min_abs=0.0 so that no channel can fall below the magnitude threshold.
+ m = ActivationBalancer(
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ max_factor=0.2,
+ min_abs=0.0,
+ )
+ # Random +/-1 (or 0) upstream gradient.
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_sign: x = ", x)
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
+def _test_activation_balancer_magnitude():
+ """Print the gradients ActivationBalancer produces for the magnitude constraint.
+
+ Manual check only: nothing is asserted, the tensors are just printed.
+ """
+ # One row per magnitude 0.00, 0.01, ..., 0.99; each row holds N values of
+ # that magnitude with a random sign.
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+ -1
+ )
+ x = x.detach()
+ x.requires_grad = True
+ # min_positive=0.0 and max_positive=1.0 select the 0.0 fallbacks for the
+ # sign terms, leaving only the min_abs/max_abs constraint active.
+ m = ActivationBalancer(
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ max_factor=0.2,
+ min_abs=0.2,
+ max_abs=0.8,
+ )
+ # Random +/-1 (or 0) upstream gradient.
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_magnitude: x = ", x)
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
+def _test_basic_norm():
+ """Check that BasicNorm keeps the shape and shrinks the RMS by less than half.
+
+ NOTE(review): BasicNorm is neither defined nor imported in the visible
+ part of this module, so calling this function would presumably raise
+ NameError - confirm, and import it before enabling the call in __main__.
+ """
+ num_channels = 128
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
+ x = torch.randn(500, num_channels)
+ y = m(x)
+ assert y.shape == x.shape
+ # Root-mean-square over all elements of the input and of the output.
+ x_rms = (x ** 2).mean().sqrt()
+ y_rms = (y ** 2).mean().sqrt()
+ print("x rms = ", x_rms)
+ print("y rms = ", y_rms)
+ assert y_rms < x_rms
+ assert y_rms > 0.5 * x_rms
+def _test_double_swish_deriv():
+ """Run torch.autograd.gradcheck on DoubleSwish with a double-precision input."""
+ x = torch.randn(10, 12, dtype=torch.double) * 0.5
+ x.requires_grad = True
+ m = DoubleSwish()
+ # The return value is not inspected here.
+ torch.autograd.gradcheck(m, x)
+if __name__ == "__main__":
+ # Running this file directly executes only the sign test; the remaining
+ # checks are left commented out.
+ _test_activation_balancer_sign()
+ # _test_activation_balancer_magnitude()
+ # _test_basic_norm()
+ # _test_double_swish_deriv()
diff --git a/pytorch/libs/nnet/transformer/subsampling.py b/pytorch/libs/nnet/transformer/subsampling.py
index 085f50c..c4d3df4 100644
--- a/pytorch/libs/nnet/transformer/subsampling.py
+++ b/pytorch/libs/nnet/transformer/subsampling.py
@@ -2,44 +2,568 @@
# Reference: https://github.com/espnet/espnet.
+from turtle import xcor
import torch
+from typing import Tuple
+import torch.nn as nn
+from .scaling import ScaledLinear,ScaledConv2d,ActivationBalancer
+from libs.nnet.activation import DoubleSwish
+from .layer_norm import BasicNorm
+class TooShortUttError(Exception):
+    """Error signalling an utterance that is too short to be subsampled.
+
+    Args:
+        message (str): Text passed on to ``Exception``.
+        actual_size (int): The size that could not pass the subsampling.
+        limit (int): The size limit required by the subsampling.
+    """
+
+    def __init__(self, message, actual_size, limit):
+        """Record the offending size and the limit alongside the message."""
+        super(TooShortUttError, self).__init__(message)
+        self.actual_size, self.limit = actual_size, limit
-from .embedding import PositionalEncoding
+def check_short_utt(ins, size):
+    """Tell whether ``size`` is too small for the subsampling module ``ins``.
+
+    Returns:
+        tuple: ``(True, limit)`` when ``ins`` is one of the known
+            Conv2dSubsampling classes and ``size`` is below that class's
+            limit (7 for the 2x and 4x variants, 11 for 6x, 15 for 8x);
+            ``(False, -1)`` in every other case.
+    """
+    is_2x_or_4x = isinstance(ins, Conv2dSubsampling2) or isinstance(
+        ins, Conv2dSubsampling4
+    )
+    if is_2x_or_4x and size < 7:
+        return True, 7
+    if size < 11 and isinstance(ins, Conv2dSubsampling6):
+        return True, 11
+    if size < 15 and isinstance(ins, Conv2dSubsampling8):
+        return True, 15
+    return False, -1
-class Conv2dSubsampling(torch.nn.Module):
- """Convolutional 2D subsampling (to 1/4 length)
+class LinearNoSubsampling(torch.nn.Module):
+ """Linear transform the input without subsampling
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc_class (torch.nn.Module): Custom position encoding layer.
- :param int idim: input dim
- :param int odim: output dim
- :param flaot dropout_rate: dropout rate
+ def __init__(self, idim: int, odim: int, dropout_rate: float,mlp_head: bool,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ # torch.nn.LayerNorm(odim, eps=1e-5),
+ # torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 1
+ self.mlp_head = mlp_head
+ self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0)
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
- def __init__(self, idim, odim, dropout_rate):
- super(Conv2dSubsampling, self).__init__()
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+ """
+ x = self.out(x)
+ if self.mlp_head:
+ b,_,_ = x.shape
+ cls_tokens = self.cls_token.repeat(b,1,1)
+ x = torch.cat((cls_tokens,x),dim=1)
+ x, pos_emb = self.pos_enc(x)
+ return x, pos_emb, x_mask
+class Conv2dSubsampling4(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc_class (torch.nn.Module): Custom position encoding layer.
+ """
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ mlp_head: bool, pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling object."""
+ super(Conv2dSubsampling4, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.Conv2d(odim, odim, 3, 2),
- torch.nn.ReLU()
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+ )
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 4
+ self.mlp_head = mlp_head
+ self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0)
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if self.mlp_head:
+ b,_,_ = x.shape
+ cls_tokens = self.cls_token.repeat(b,1,1)
+ x = torch.cat((cls_tokens,x),dim=1)
+ x, pos_emb = self.pos_enc(x)
+ return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
+ def __getitem__(self, key):
+ """Get item.
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
+ return the positioning encoding.
+ """
+ if key != -1:
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+ return self.out[key]
+class ReConv2dSubsampling4(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+ Built from ScaledConv2d / ActivationBalancer / DoubleSwish layers followed
+ by a ScaledLinear projection, BasicNorm and an output balancer.
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate (not used in the code shown here).
+ mlp_head (bool): When True, a learned cls token is prepended to the
+ output sequence.
+ pos_enc_class (torch.nn.Module): Custom position encoding layer.
+ layer1_channels (int): Output channels of the first convolution.
+ layer2_channels (int): Output channels of the second convolution.
+ layer3_channels (int): Output channels of the third convolution.
+ """
+ def __init__(self,
+ idim: int,
+ odim: int,
+ dropout_rate: float,
+ mlp_head: bool,
+ pos_enc_class: torch.nn.Module,
+ layer1_channels: int = 8,
+ layer2_channels: int = 32,
+ layer3_channels: int = 128,
+ ):
+ """Construct a ReConv2dSubsampling4 object."""
+ assert idim >= 7
+ super(ReConv2dSubsampling4, self).__init__()
+ # First conv keeps the size (kernel 3, padding 1, default stride); the
+ # next two use stride 2 without padding.
+ self.conv = nn.Sequential(
+ ScaledConv2d(
+ in_channels=1,
+ out_channels=layer1_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer1_channels,
+ out_channels=layer2_channels,
+ kernel_size=3,
+ stride=2,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer2_channels,
+ out_channels=layer3_channels,
+ kernel_size=3,
+ stride=2,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ )
+ # Input width is channels times the feature size left after the two
+ # stride-2 convolutions.
+ self.out = ScaledLinear(
+ layer3_channels * (((idim - 1) // 2 - 1) // 2), odim
+ )
+ # set learn_eps=False because out_norm is preceded by `out`, and `out`
+ # itself has learned scale, so the extra degree of freedom is not
+ # needed.
+ self.out_norm = BasicNorm(odim, learn_eps=False)
+ # constrain median of output to be close to zero.
+ self.out_balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55
+ )
+ # Despite the name, this is stored and later called as a module instance.
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 4
+ self.mlp_head = mlp_head
+ # Learned token when mlp_head is set; otherwise an empty placeholder tensor.
+ self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0)
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Position embedding returned by pos_enc.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ # Merge channels and features: (b, c, t, f) -> (b, t, c * f).
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x = self.out_norm(x)
+ x = self.out_balancer(x)
+ if self.mlp_head:
+ # Prepend one cls token per batch item along the time axis.
+ # NOTE(review): the returned mask is not extended for this extra
+ # position; presumably the caller accounts for it - confirm.
+ b,_,_ = x.shape
+ cls_tokens = self.cls_token.repeat(b,1,1)
+ x = torch.cat((cls_tokens,x),dim=1)
+ x, pos_emb = self.pos_enc(x)
+ # Slice the mask once per stride-2 convolution.
+ return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
+class SVConv2dSubsampling4(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc_class (torch.nn.Module): Custom position encoding layer.
+ """
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ mlp_head: bool, pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling object."""
+ super(SVConv2dSubsampling4, self).__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, (2,1)),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, (2,1)),
+ torch.nn.ReLU(),
self.out = torch.nn.Sequential(
- torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
- PositionalEncoding(odim, dropout_rate)
+ torch.nn.Linear(odim * (idim-4) , odim)
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 4
+ self.mlp_head = mlp_head
+ self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0)
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if self.mlp_head:
+ b,_,_ = x.shape
+ cls_tokens = self.cls_token.repeat(b,1,1)
+ x = torch.cat((cls_tokens,x),dim=1)
+ x, pos_emb = self.pos_enc(x)
+ return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
- def forward(self, x, x_mask):
- """Subsample x
+ def __getitem__(self, key):
+ """Get item.
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
+ return the positioning encoding.
+ """
+ if key != -1:
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+ return self.out[key]
+class Conv2dSubsampling2(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/2 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc_class (torch.nn.Module): Custom position encoding layer.
+ """
- :param torch.Tensor x: input tensor
- :param torch.Tensor x_mask: input mask
- :return: subsampled x and mask
- :rtype Tuple[torch.Tensor, torch.Tensor]
+ def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling2 object."""
+ super(Conv2dSubsampling2, self).__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 1),
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim))
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 2
+ self.mlp_head = mlp_head
+ self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0)
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- if x_mask is None:
- return x, None
- return x, x_mask[:, :, :-2:2][:, :, :-2:2]
+ if self.mlp_head:
+ b,_,_ = x.shape
+ cls_tokens = self.cls_token.repeat(b,1,1)
+ x = torch.cat((cls_tokens,x),dim=1)
+ x, pos_emb = self.pos_enc(x)
+ return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:1]
+ def __getitem__(self, key):
+ """Get item.
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
+ return the positioning encoding.
+ """
+ if key != -1:
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+ return self.out[key]
+class SVConv2dSubsampling2(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/2 length).
+
+    The first convolution uses stride (2, 1), so only the time axis is
+    strided; the second convolution has stride 1.
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate (not used by this class).
+        mlp_head (bool): When True, a learned cls token is prepended to the
+            output sequence.
+        pos_enc_class (torch.nn.Module): Custom position encoding layer.
+    """
+
+    def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module):
+        """Construct an SVConv2dSubsampling2 object."""
+        super(SVConv2dSubsampling2, self).__init__()
+        conv_stack = [
+            torch.nn.Conv2d(1, odim, 3, (2, 1)),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 3, 1),
+            torch.nn.ReLU(),
+        ]
+        self.conv = torch.nn.Sequential(*conv_stack)
+        # Each 3-wide kernel removes two feature bins, hence idim - 4.
+        projection = torch.nn.Linear(odim * (idim - 4), odim)
+        self.out = torch.nn.Sequential(projection)
+        self.pos_enc = pos_enc_class
+        self.subsampling_rate = 2
+        self.mlp_head = mlp_head
+        if self.mlp_head:
+            self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim))
+        else:
+            self.cls_token = torch.empty(0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 2.
+            torch.Tensor: Position embedding returned by pos_enc.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 2.
+        """
+        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
+        batch, chans, frames, bins = feats.size()
+        flat = feats.transpose(1, 2).contiguous().view(batch, frames, chans * bins)
+        feats = self.out(flat)
+        if self.mlp_head:
+            # One copy of the cls token per batch item, placed before frame 0.
+            cls_tokens = self.cls_token.repeat(feats.shape[0], 1, 1)
+            feats = torch.cat((cls_tokens, feats), dim=1)
+        feats, pos_emb = self.pos_enc(feats)
+        # First slice follows the stride-2 convolution, second the stride-1 one.
+        sub_mask = x_mask[:, :, :-2:2][:, :, :-2:1]
+        return feats, pos_emb, sub_mask
+
+    def __getitem__(self, key):
+        """Get item.
+
+        Only ``-1`` is accepted (for `reset_parameters`); it returns the last
+        layer of ``self.out``.
+        """
+        if key == -1:
+            return self.out[key]
+        raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+class Conv2dSubsampling6(torch.nn.Module):
+    """Convolutional 2D front-end that subsamples to about 1/6 length.
+
+    Args:
+        idim (int): Input feature dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Accepted for signature compatibility; this
+            module does not use it.
+        mlp_head (bool): If True, a learnable class token is prepended to the
+            projected sequence.
+        pos_enc_class (torch.nn.Module): Position encoding layer. It is called
+            as ``pos_enc_class(x)`` and must return ``(x, pos_emb)``.
+    """
+
+    def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module):
+        """Construct an Conv2dSubsampling6 object."""
+        super().__init__()
+        self.conv = torch.nn.Sequential(
+            torch.nn.Conv2d(1, odim, 3, 2),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 5, 3),
+            torch.nn.ReLU(),
+        )
+        # Frequency bins left after the stride-2 (kernel 3) and stride-3
+        # (kernel 5) convolutions.
+        bins_after_conv = ((idim - 1) // 2 - 2) // 3
+        projection = torch.nn.Linear(odim * bins_after_conv, odim)
+        self.out = torch.nn.Sequential(projection)
+        self.pos_enc = pos_enc_class
+        self.subsampling_rate = 6
+        self.mlp_head = mlp_head
+        if self.mlp_head:
+            self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim))
+        else:
+            # Placeholder so the attribute always exists.
+            self.cls_token = torch.empty(0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim) after position
+                encoding, where time' is about time // 6 (plus one leading
+                frame when ``mlp_head`` is enabled).
+            torch.Tensor: Position embedding returned by the position encoding
+                layer.
+            torch.Tensor: Subsampled mask (#batch, 1, time'), sliced to follow
+                the two convolutions; it has no slot for the class token.
+        """
+        conv_out = self.conv(x.unsqueeze(1))  # (b, c, t', f')
+        batch, channels, frames, bins = conv_out.size()
+        flat = conv_out.transpose(1, 2).contiguous().view(batch, frames, channels * bins)
+        x = self.out(flat)
+        if self.mlp_head:
+            cls_tokens = self.cls_token.repeat(x.shape[0], 1, 1)
+            x = torch.cat((cls_tokens, x), dim=1)
+        x, pos_emb = self.pos_enc(x)
+        mask_stage1 = x_mask[:, :, :-2:2]
+        mask_stage2 = mask_stage1[:, :, :-4:3]
+        return x, pos_emb, mask_stage2
+class Conv2dSubsampling8(torch.nn.Module):
+    """Convolutional 2D front-end that subsamples to about 1/8 length.
+
+    Args:
+        idim (int): Input feature dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Accepted for signature compatibility; this
+            module does not use it.
+        mlp_head (bool): If True, a learnable class token is prepended to the
+            projected sequence.
+        pos_enc_class (torch.nn.Module): Position encoding layer. It is called
+            as ``pos_enc_class(x)`` and must return ``(x, pos_emb)``.
+    """
+
+    def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module):
+        """Construct an Conv2dSubsampling8 object."""
+        super().__init__()
+        # Three identical stride-2 stages, each followed by a ReLU.
+        stages = [
+            torch.nn.Conv2d(1, odim, 3, 2),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 3, 2),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 3, 2),
+            torch.nn.ReLU(),
+        ]
+        self.conv = torch.nn.Sequential(*stages)
+        remaining_bins = (((idim - 1) // 2 - 1) // 2 - 1) // 2
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(odim * remaining_bins, odim))
+        self.pos_enc = pos_enc_class
+        self.subsampling_rate = 8
+        self.mlp_head = mlp_head
+        if self.mlp_head:
+            self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim))
+        else:
+            # Placeholder so the attribute always exists.
+            self.cls_token = torch.empty(0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim) after position
+                encoding, where time' is about time // 8 (plus one leading
+                frame when ``mlp_head`` is enabled).
+            torch.Tensor: Position embedding returned by the position encoding
+                layer.
+            torch.Tensor: Subsampled mask (#batch, 1, time'), sliced once per
+                stride-2 convolution; it has no slot for the class token.
+        """
+        hidden = self.conv(x.unsqueeze(1))  # (b, c, t', f')
+        bsz, chans, steps, width = hidden.size()
+        hidden = hidden.transpose(1, 2).contiguous().view(bsz, steps, chans * width)
+        x = self.out(hidden)
+        if self.mlp_head:
+            cls_tokens = self.cls_token.repeat(x.shape[0], 1, 1)
+            x = torch.cat((cls_tokens, x), dim=1)
+        x, pos_emb = self.pos_enc(x)
+        sub_mask = x_mask[:, :, :-2:2]
+        sub_mask = sub_mask[:, :, :-2:2]
+        sub_mask = sub_mask[:, :, :-2:2]
+        return x, pos_emb, sub_mask
+class FrameMerger(nn.Module):
+    """Merge neighbouring frames with two strided 1-D convolutions.
+
+    Both convolutions use kernel 3 and stride 2; the first doubles the
+    channel count from ``dim`` to ``2 * dim``.
+
+    Args:
+        dim (int): Channel dimension of the input frames.
+        normalize_before (bool): If True, LayerNorm (over ``dim``) is applied
+            before the convolutions; otherwise LayerNorm (over ``2 * dim``) is
+            applied after them.
+        norm_type (str): Accepted for signature compatibility; not used here.
+    """
+
+    def __init__(self, dim: int,
+                 normalize_before: bool = True,
+                 norm_type: str = "layer_norm",):
+        super().__init__()
+        merged_dim = 2 * dim
+        self.conv = torch.nn.Sequential(
+            torch.nn.Conv1d(dim, merged_dim, 3, 2),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(merged_dim, merged_dim, 3, 2),
+            torch.nn.ReLU(),
+        )
+        self.normalize_before = normalize_before
+        # The norm width follows whichever side of the convolutions it sits on.
+        self.norm = nn.LayerNorm(dim if self.normalize_before else dim * 2)
+
+    def forward(self, x, pos_emb: torch.Tensor, mask: torch.Tensor, offset: int = 0):
+        """Shorten ``x`` along time and slice ``pos_emb`` and ``mask`` to match.
+
+        Args:
+            x: Input of shape (#batch, time, dim); it is transposed to
+                channel-first for the convolutions and transposed back.
+            pos_emb (torch.Tensor): Position embedding; a window of the new
+                sequence length starting at ``offset`` is taken along dim 1.
+            mask (torch.Tensor): Mask whose last axis is time; it is sliced
+                once per stride-2 convolution.
+            offset (int): Start index of the ``pos_emb`` window.
+
+        Returns:
+            Tuple of (merged x, sliced pos_emb, sliced mask).
+        """
+        if self.normalize_before:
+            x = self.norm(x)
+        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
+        if not self.normalize_before:
+            x = self.norm(x)
+        pos_window = pos_emb[:, offset:offset + x.size(1)]
+        merged_mask = mask[:, :, :-2:2]
+        merged_mask = merged_mask[:, :, :-2:2]
+        return x, pos_window, merged_mask
\ No newline at end of file
diff --git a/pytorch/libs/support/utils.py b/pytorch/libs/support/utils.py
index a9800a9..5f74073 100755
--- a/pytorch/libs/support/utils.py
+++ b/pytorch/libs/support/utils.py
@@ -4,6 +4,7 @@
import sys, os
import math
+import yaml
import random
import logging
import shutil
@@ -181,6 +182,7 @@ def create_model_from_py(model_blueprint, model_creation=""):
return model_module
model = eval("model_module.{0}".format(model_creation))
return model
@@ -214,7 +216,7 @@ def create_model_dir(model_dir:str, model_blueprint:str, stage=-1):
os.makedirs("{0}/config".format(model_dir), exist_ok=True)
os.makedirs("{0}/checkpoint_info".format(model_dir), exist_ok=True)
if is_main_training():
- if stage < 0 and model_blueprint != config_model_blueprint:
+ if stage <= 0 and model_blueprint != config_model_blueprint:
shutil.copy(model_blueprint, config_model_blueprint)
@@ -246,7 +248,7 @@ def read_file_to_list(file_path, every_bytes=10000000):
return list
-def write_list_to_file(this_list, file_path, mod='w'):
+def write_list_to_file(this_list, file_path, mod='w', yml=False):
@mod: could be 'w' or 'a'
@@ -254,8 +256,14 @@ def write_list_to_file(this_list, file_path, mod='w'):
this_list = [this_list]
with open(file_path, mod) as writer :
- writer.write('\n'.join(str(x) for x in this_list))
- writer.write('\n')
+ if yml:
+ for x in this_list:
+ yaml.dump(x,writer)
+ else:
+ writer.write('\n'.join(str(x) for x in this_list))
+ writer.write('\n')
def save_checkpoint(checkpoint_path, **kwargs):
@@ -309,9 +317,11 @@ def key_to_value(adict, key, return_none=True):
def assign_params_dict(default_params:dict, params:dict, force_check=False, support_unknow=False):
default_params = copy.deepcopy(default_params)
- default_keys = set(default_params.keys())
+ default_keys = set(default_params.keys())
# Should keep force_check=False to use support_unknow
if force_check:
for key in param.keys():
@@ -484,7 +494,7 @@ def get_free_port(ip=""):
return s.getsockname()[1]
# https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/utils/data_utils.py
-def batch_pad_right(tensors: list, mode="constant", value=0):
+def batch_pad_right(tensors: list, mode="constant", value=0,val_index=-1):
"""Given a list of torch tensors it batches them together by padding to the right
on each dimension in order to get same length for all.
@@ -543,10 +553,10 @@ def batch_pad_right(tensors: list, mode="constant", value=0):
t, max_shape, mode=mode, value=value
- valid.append(valid_percent[0])
+ valid.append(valid_percent[val_index])
batched = torch.stack(batched)
return batched, torch.tensor(valid)
@@ -589,7 +599,6 @@ def pad_right_to(
valid_vals.append(tensor.shape[j] / target_shape[j])
i -= 1
j += 1
tensor = torch.nn.functional.pad(tensor, pads, mode=mode, value=value)
return tensor, valid_vals
diff --git a/pytorch/libs/training/lr_scheduler_online.py b/pytorch/libs/training/lr_scheduler_online.py
old mode 100644
new mode 100755
index dbd3082..7377c43
--- a/pytorch/libs/training/lr_scheduler_online.py
+++ b/pytorch/libs/training/lr_scheduler_online.py
@@ -7,6 +7,7 @@
import math
import numpy as np
import torch
+from typing import Union
from torch.optim.lr_scheduler import _LRScheduler
from .optim import *
@@ -39,7 +40,8 @@ def __init__(self, optimizer, params:dict={}):
- "1cycle.anneal_strategy":'linear',
+ "1cycle.warmup_steps":None,
+ "1cycle.anneal_strategy":'cos', # ["cos", "linear"]
@@ -53,6 +55,11 @@ def __init__(self, optimizer, params:dict={}):
+ "noam.warmup_steps": 2000,
+ "noam.step_decay": False,
+ "noam.step_size": 34000,
+ "noam.step_rate": 0.5,
@@ -79,13 +86,23 @@ def __init__(self, optimizer, params:dict={}):
self.step_total=int(step_up + step_down)
self.lr_scheduler = torch.optim.lr_scheduler.CyclicLR(base_optimizer, base_lr, max_lr, **split_params["cyclic"])
elif self.name == "1cycle":
+ warmup_steps = split_params["1cycle"].pop("warmup_steps")
+ pct_start = split_params["1cycle"].pop("pct_start")
max_lr = split_params["1cycle"].pop("learn_rate")
- self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, **split_params["1cycle"])
+ self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, pct_start=pct_start,**split_params["1cycle"])
+ self.step_total = self.lr_scheduler.total_steps
+ if warmup_steps is not None:
+ pct_start = warmup_steps/self.step_total
+ self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, pct_start=pct_start,**split_params["1cycle"])
elif self.name == "warmR":
self.T_0 = split_params["warmR"].pop("T_max")
self.T_mult = split_params["warmR"]["T_mult"]
self.lr_decay_step = split_params["warmR"].pop("lr_decay_step")
self.lr_scheduler = CosineAnnealingWarmRestarts(base_optimizer, self.T_0, **split_params["warmR"])
+ elif self.name == "noam":
+ warmup_steps = split_params["noam"].pop("warmup_steps")
+ self.lr_scheduler = WarmupLR(base_optimizer, warmup_steps, **split_params["noam"])
+ pass
elif self.name == "reduceP":
self.check_interval = split_params["reduceP"].pop("check_interval")
self.metric = split_params["reduceP"].pop("metric")
@@ -120,7 +137,7 @@ def is_cycle_point(self, training_point):
if math.log(max(0.05, (epoch / self.T_0 * (self.T_mult - 1) + 1)), self.T_mult)%1==0 and epoch>0:
return False if (self.lr_decay_step == 0 and training_point[1]>1) else True
- if self.name=="cyclic":
+ if self.name in ["cyclic", "1cycle"] :
return True if training_point[2]%self.step_total==0 and training_point[2]>0 else False
@@ -137,12 +154,16 @@ def step(self, training_point=None, valid_metric=None):
elif self.name == "cyclic":
elif self.name == "1cycle":
- self.lr_scheduler.step(training_point[2])
+ self.lr_scheduler.step()
elif self.name == "reduceP":
# Sample a point in which the metrics of valid are computed and adjust learning rate at this point.
if self.is_reduce_point(training_point):
metric = valid_metric[0] if self.metric == "valid_loss" else valid_metric[1]
+ elif self.name == "noam":
+ self.lr_scheduler.step()
+ else:
+ raise ValueError("unknown lr_scheduler: " + self.name)
## Learn rate scheduler ✿
class CosineAnnealingWarmRestarts(_LRScheduler):
@@ -174,7 +195,7 @@ class CosineAnnealingWarmRestarts(_LRScheduler):
Base lr decay has been added. [Snowdar 2019-08-29]
- def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, factor=1.0, log_decay=False, last_epoch=-1):
+ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, factor=1.0, log_decay=False, last_epoch=-1, warmup_steps=5000):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult <=0: # or not isinstance(T_mult, int):
@@ -251,3 +272,67 @@ def step(self, epoch=None):
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
+# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/scheduler.py
+class WarmupLR(_LRScheduler):
+ """The WarmupLR scheduler
+ This scheduler is almost same as NoamLR Scheduler except for following
+ difference:
+ NoamLR:
+ lr = optimizer.lr * model_size ** -0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+ WarmupLR:
+ lr = optimizer.lr * warmup_step ** 0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+ Note that the maximum lr equals to optimizer.lr in this scheduler.
+ """
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: Union[int, float] = 25000,
+ step_decay: bool=False,
+ step_size: int=80000,
+ step_rate: float = 0.5,
+ last_epoch: int = -1,
+ ):
+ self.warmup_steps = warmup_steps
+ self.step_decay = step_decay
+ self.step_size = step_size
+ self.step_rate = step_rate
+ # The fields above must be set before calling super().__init__(),
+ # because the base-class constructor already invokes step().
+ super().__init__(optimizer, last_epoch)
+ def __repr__(self):
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+ def get_lr(self):
+ step_num = self.last_epoch + 1
+ if step_num>> from optimizer import RAdam
- >>> optimizer = RAdam(model.parameters(), lr=0.001)
- Note, here the weight decay is not L2 regularization.
- '''
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), N_sma_threshhold=4, eps=1e-8, weight_decay=0):
- defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
- self.N_sma_threshhold = N_sma_threshhold
- self.buffer = [[None, None, None] for ind in range(10)]
- super(RAdam, self).__init__(params, defaults)
- def __setstate__(self, state):
- super(RAdam, self).__setstate__(state)
- def step(self, closure=None):
- loss = None
- if closure is not None:
- loss = closure()
- for group in self.param_groups:
- for p in group['params']:
- if p.grad is None:
- continue
- grad = p.grad.data.float()
- if grad.is_sparse:
- raise RuntimeError('RAdam does not support sparse gradients')
- p_data_fp32 = p.data.float()
- state = self.state[p]
- if len(state) == 0:
- state['step'] = 0
- state['exp_avg'] = torch.zeros_like(p_data_fp32)
- state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
- else:
- state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
- state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
- beta1, beta2 = group['betas']
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- exp_avg.mul_(beta1).add_((1 - beta1) * grad)
- state['step'] += 1
- buffered = self.buffer[int(state['step'] % 10)]
- if state['step'] == buffered[0]:
- N_sma, step_size = buffered[1], buffered[2]
- else:
- buffered[0] = state['step']
- beta2_t = beta2 ** state['step']
- N_sma_max = 2 / (1 - beta2) - 1
- N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
- buffered[1] = N_sma
- if N_sma > self.N_sma_threshhold:
- step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
- else:
- step_size = group['lr'] / (1 - beta1 ** state['step'])
- buffered[2] = step_size
- if group['weight_decay'] != 0:
- p_data_fp32.add_(-group['weight_decay'] * group['lr'] * p_data_fp32)
- if N_sma > self.N_sma_threshhold:
- denom = exp_avg_sq.sqrt().add_(group['eps'])
- p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
- else:
- p_data_fp32.add_(-step_size * exp_avg)
- p.data.copy_(p_data_fp32)
- return loss
class Ralamb(Optimizer):
@@ -819,3 +760,223 @@ def step(self, closure=None):
p.data.add_(-group['lr'] * exp_avg)
return loss
+# Sharpness-Aware Minimization for Efficiently Improving Generalization.
+# https://openreview.net/pdf?id=6Tm1mposlrM
+# https://github.com/davda54/sam
+class SAM(Optimizer):
+ """Sharpness-Aware Minimization wrapper around an existing optimizer.
+ An update is split in two: `first_step` perturbs the weights along the
+ gradient direction, and `second_step` restores the saved weights and lets
+ the wrapped optimizer apply the update using the gradients present at that
+ point. The caller is responsible for running forward/backward before each
+ of the two steps.
+ Args:
+ optimizer: The base optimizer; its ``param_groups`` are shared with this wrapper.
+ rho (float): Size of the perturbation; must be non-negative.
+ adaptive (bool): If True, the perturbation is scaled element-wise by the squared weights.
+ """
+ def __init__(self, optimizer, rho=0.05, adaptive=False) -> None:
+ assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+ defaults = dict(rho=rho, adaptive=adaptive)
+ # Reuse the base optimizer's param_groups so both objects see the same groups.
+ super(SAM, self).__init__(optimizer.param_groups, defaults)
+ self.base_optimizer = optimizer
+ # Expose the base optimizer's defaults alongside rho/adaptive.
+ self.defaults.update(self.base_optimizer.defaults)
+ @torch.no_grad()
+ def first_step(self, zero_grad=False):
+ """Save the current weights and move them to ``w + e(w)``.
+ Parameters without a gradient are left untouched. If ``zero_grad`` is
+ True, gradients are cleared afterwards.
+ """
+ grad_norm = self._grad_norm()
+ for group in self.param_groups:
+ # 1e-12 guards against division by a zero gradient norm.
+ scale = group["rho"] / (grad_norm + 1e-12)
+ for p in group["params"]:
+ if p.grad is None: continue
+ self.state[p]["old_p"] = p.data.clone()
+ # Marks that this parameter currently holds perturbed weights (read by back_w).
+ self.state[p]["cache"] = True
+ e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
+ p.add_(e_w) # climb to the local maximum "w + e(w)"
+ if zero_grad: self.zero_grad()
+ @torch.no_grad()
+ def second_step(self, zero_grad=False):
+ """Restore the saved weights and run the base optimizer's step.
+ Only parameters that currently have a gradient are restored. If
+ ``zero_grad`` is True, gradients are cleared afterwards.
+ """
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None: continue
+ p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
+ self.state[p]["cache"] = False
+ self.base_optimizer.step() # do the actual "sharpness-aware" update
+ if zero_grad: self.zero_grad()
+ @torch.no_grad()
+ def back_w(self):
+ """Restore saved weights for every parameter still flagged as perturbed.
+ Unlike `second_step`, no optimizer update is performed.
+ """
+ for group in self.param_groups:
+ for p in group["params"]:
+ # Parameters that never went through first_step have no "cache" entry.
+ p_cache = self.state[p].get('cache',False)
+ if p_cache:
+ p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
+ self.state[p]["cache"] = False
+ @torch.no_grad()
+ def step(self,stage=1):
+ """Run one half of the SAM update and return True.
+ ``stage == 1`` runs `first_step`; any other value runs `second_step`.
+ Gradients are zeroed after either stage.
+ """
+ if stage==1:
+ self.first_step(zero_grad=True)
+ return True
+ else:
+ self.second_step(zero_grad=True)
+ return True
+ def _grad_norm(self):
+ """Return the 2-norm taken over the per-parameter gradient 2-norms.
+ With ``adaptive`` set, each gradient is first scaled by ``|p|``.
+ Parameters without a gradient are skipped.
+ """
+ shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
+ norm = torch.norm(
+ torch.stack([
+ ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
+ for group in self.param_groups for p in group["params"]
+ if p.grad is not None
+ ]),
+ p=2
+ )
+ return norm
+ def load_state_dict(self, state_dict):
+ """Load the state, then point the base optimizer at the reloaded param groups."""
+ super().load_state_dict(state_dict)
+ self.base_optimizer.param_groups = self.param_groups
+class Eve(Optimizer):
+    r"""
+    Implements Eve algorithm.  This is a modified version of AdamW with a special
+    way of setting the weight-decay / shrinkage-factor, which is designed to make the
+    rms of the parameters approach a particular target_rms (default: 0.1).  This is
+    for use with networks with 'scaled' versions of modules (see scaling.py), which
+    will be close to invariant to the absolute scale on the parameter matrix.
+
+    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+    Eve is unpublished so far.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.98))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-3).
+            It is not multiplied by the learning rate, and it is only applied
+            while the RMS value of the parameter is above target_rms.
+            Must lie in [0, 0.1].
+        target_rms (float, optional): target root-mean-square value of
+            parameters; if they fall below this we stop applying weight decay
+            (default: 0.1). Must lie in (0, 10].
+
+    Raises:
+        ValueError: If any hyper-parameter is outside its accepted range.
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
+        if not 0 < target_rms <= 10.0:
+            raise ValueError("Invalid target_rms value: {}".format(target_rms))
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
+        super(Eve, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Eve, self).__setstate__(state)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+
+        Returns:
+            The loss returned by ``closure``, or None when no closure is given.
+
+        Raises:
+            RuntimeError: If a parameter has a sparse gradient.
+        """
+        loss = None
+        if closure is not None:
+            # The method runs under no_grad, so re-enable grad for the closure.
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                # Perform optimization step
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Eve does not support sparse gradients"
+                    )
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )
+
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]
+
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors"
+                    # (which are scalar).
+                    # norm > target_rms * sqrt(numel) is the same test as
+                    # rms(p) > target_rms.
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+
+                p.addcdiv_(exp_avg, denom, value=-step_size)
+
+                # Constrain the range of scalar weights
+                if p.numel() == 1:
+                    p.clamp_(min=-10, max=2)
+
+        return loss
\ No newline at end of file
diff --git a/pytorch/libs/training/trainer_online.py b/pytorch/libs/training/trainer_online.py
old mode 100644
new mode 100755
index 9876603..60a0b95
--- a/pytorch/libs/training/trainer_online.py
+++ b/pytorch/libs/training/trainer_online.py
@@ -15,6 +15,7 @@
import numpy as np
import yaml
import torch
+import torch.distributed as dist
from torch.utils.data import DataLoader
from contextlib import nullcontext
from .reporter import Reporter_new as Reporter
@@ -60,8 +61,8 @@ class _BaseTrainer():
def __init__(self, package, stop_early=False):
default_elements = {"data": None, "model": None,
"optimizer": None, "lr_scheduler": None}
- default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10,
- "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 10.0, "use_amp": False, "accum_grad": 1,
+ default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10, "warmup_steps":0,
+ "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 5.0, "use_amp": False, "accum_grad": 1,
"compute_accuracy": True, "compute_valid_accuracy": True, "compute_batch_num_valid": 1,
"suffix": "params", "nan_debug": False, "skip_nan_batch": True, "use_tensorboard": True}
@@ -90,6 +91,7 @@ def __init__(self, package, stop_early=False):
self.training_point = copy.deepcopy([self.params["start_epoch"], 0, 0])
self.cycle_point = 0 # for cycle training.
+ self.warmup_steps = self.params["warmup_steps"] # model level warm up. just for transformer.
def select_device(self):
return utils.select_model_device(self.elements["model"], self.params["use_gpu"],
gpu_id=self.params["gpu_id"], benchmark=self.params["benchmark"])
@@ -140,7 +142,11 @@ def init_training(self):
pass # Now, it means use the raw initial model
# Select device
+ # for k,v in model.named_parameters():
+ # if "train_len" in k:
+ # print(v)
+ # assert 1==0
model = self.select_device()
# Original model is built in libs.nnet.framework.TopVirtualNnet, and it is not available after
@@ -215,7 +221,11 @@ def train_one_batch(self, batch, step_lr=True):
model = self.elements["model"]
model_forward = self.elements["model_forward"]
optimizer = self.elements["optimizer"]
+ cur_step = self.training_point[2]
+ # model level warm up. just for transformer.
+ warmup = cur_step / self.warmup_steps if self.warmup_steps >0 else 1.0
+ warmup = torch.FloatTensor([warmup])
if not model.training:
@@ -229,8 +239,13 @@ def train_one_batch(self, batch, step_lr=True):
- inputs, targets = batch
+ inputs, targets, feats_lens = batch
+ feats_lens = (feats_lens*inputs.shape[2]).long()
+ input_list = [inputs,feats_lens]
+ # model level warm up. just for transformer.
+ if self.warmup_steps >0:
+ input_list.append(warmup)
context = None
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
@@ -246,13 +261,17 @@ def train_one_batch(self, batch, step_lr=True):
# Managing automatic mixed precision (Leo 2021-11-08)
with torch.cuda.amp.autocast(self.scaler is not None):
- loss = model.get_loss(model_forward(inputs), targets)/self.params["accum_grad"]
+ loss = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"]
# loss = model_forward(inputs, targets)/self.params["accum_grad"]
if self.params["use_amp"]:
+ # for name, param in model.named_parameters():
+ # if param.grad is None:
+ # print(name)
+ # sys.exit()
loss.detach() # For safe.
loss = loss.item()*self.params["accum_grad"]
accuracy = None
@@ -320,13 +339,15 @@ def compute_validation(self, data_loader):
num_samples = 0
with torch.no_grad():
for idx,this_data in enumerate(data_loader):
- inputs, targets = this_data
+ inputs, targets, feats_lens = this_data
+ feats_lens = (feats_lens*inputs.shape[2]).long()
num_utts = targets.size(0)
+ input_list = [inputs,feats_lens]
if num_utts == 0:
# in valid stage, DO NOT call ddp model, for ddp model is in JOIN context wrapper.
# Leo 2022-02-03
- loss += model.get_loss(model(inputs),
+ loss += model.get_loss(model(*input_list),
targets).item() * len(targets)
# loss += model_forward(inputs,targets).item() * len(targets)
@@ -434,20 +455,32 @@ def run(self):
model_context = model_forward.join(throw_on_early_termination=True)
model_context = nullcontext()
+ device = utils.get_device(self.elements["model"])
+ stop_training = torch.zeros(1).to(device)
with model_context:
- for this_epoch in range(start_epoch, epochs):
- self.training_point[0]+=1
+ for this_epoch in range(start_epoch, epochs+5):
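+ # With DDP, ranks may hold uneven amounts of data, so extra epochs are added
+ # to the loop so that sub-ranks keep iterating until the main rank signals stop.
+ # NOTE(review): `utils.is_main_training` on the next line is not called (no
+ # parentheses), so the condition is always true on every rank -- confirm
+ # whether `utils.is_main_training()` was intended.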
+ if utils.is_main_training:
+ if this_epoch == epochs: # skip the last+1 epoch
+ stop_training = torch.ones(1).to(device)
+ self.training_point[0]+=1
# with model_context:
+ if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur()
for _, batch in enumerate(data.train_loader, 0):
# It is important for reporter.
- if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur()
+ dist.barrier()
+ if utils.use_ddp():dist.all_reduce(stop_training,op=dist.ReduceOp.SUM)
+ if stop_training:
+ break
self.training_point[1] +=1
num_utts = batch[0].size(0)
if num_utts == 0:
@@ -459,9 +492,12 @@ def run(self):
loss, acc = self.train_one_batch(batch)
+ if stop_training:
+ break
if utils.is_main_training():self.save_model(train_lr=self.train_lr)
self.training_point[1] =0
+ # dist.barrier()
if utils.is_main_training():self.reporter.finish()
if utils.is_main_training():
final_model_name = "{}_cycle".format(self.cycle_point) if self.cycle_point else epochs
@@ -475,6 +511,9 @@ def run(self):
if not isinstance(e, KeyboardInterrupt):traceback.print_exc()
def lr_finder_compute(self, train_batch):
model = self.elements["model"]
diff --git a/pytorch/libs/training/trainer_online_sam.py b/pytorch/libs/training/trainer_online_sam.py
new file mode 100755
index 0000000..244041d
--- /dev/null
+++ b/pytorch/libs/training/trainer_online_sam.py
@@ -0,0 +1,639 @@
+# -*- coding:utf-8 -*-
+# Copyright xmuspeech (Author: Snowdar 2019-11-25
+# Leo 2022-01-15)
+import os
+import sys
+import re
+import logging
+import copy
+import math
+import time
+import traceback
+import progressbar
+import pandas as pd
+import numpy as np
+import yaml
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+import torch.distributed as dist
+from contextlib import nullcontext
+from .reporter import Reporter_new as Reporter
+from .lr_scheduler_online import LRSchedulerWrapper
+from .lr_finder import for_lr_finder_new
+import libs.support.utils as utils
+# Wrap stderr before logger init.
+# Logger
+logger = logging.getLogger(__name__)
+This is the structure of Package:
+ Elements{
+ data:Bunch
+ model:TopVirtualNnet
+ optimizer:Optimizer
+ lr_scheduler:LR_Scheduler
+ },
+ Params{
+ model_dir:str
+ exist_model:str
+ start_epoch:int
+ epochs:int
+ ...
+ }
+ )
+training_point (this epoch, iter in epoch, global step)
+# Trainer ✿
+class _BaseTrainer():
+    def __init__(self, package, stop_early=False):
+        """Unpack the (elements, params) package and initialise the trainer state.
+        Args:
+            package: a pair ``(elements, params)``; each is merged over the defaults
+                below with ``utils.assign_params_dict``.
+            stop_early: stored on the instance; nothing in this method uses it.
+        Raises:
+            AssertionError: if data, model or optimizer is missing, if ``model_dir`` or
+                ``model_blueprint`` is empty, or if ``accum_grad`` is not 1.
+        """
+        default_elements = {"data": None, "model": None,
+                            "optimizer": None, "lr_scheduler": None}
+        default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10,"warmup_steps":0,
+                          "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 5.0, "use_amp": False, "accum_grad": 1,
+                          "compute_accuracy": True, "compute_valid_accuracy": True, "compute_batch_num_valid": 1,
+                          "suffix": "params", "nan_debug": False, "skip_nan_batch": True, "use_tensorboard": True}
+        elements, params = package
+        self.elements = utils.assign_params_dict(default_elements, elements)
+        self.params = utils.assign_params_dict(
+            default_params, params, support_unknow=True)
+        # lr_scheduler is the only element allowed to stay None.
+        assert self.elements["data"] is not None
+        assert self.elements["model"] is not None
+        assert self.elements["optimizer"] is not None
+        assert self.params["model_dir"] != ""
+        assert self.params["model_blueprint"] != ""
+        assert self.params["accum_grad"] == 1 ,"Sharpness Aware Minimization do not support accum_grad now."
+        # Until init_training() runs, the forward handle is the plain model.
+        self.elements["model_forward"] = self.elements["model"]
+        self.params["start_epoch"] = max(0, self.params["start_epoch"])
+        # Redundant after the assert above (accum_grad is already 1), kept as a safety net.
+        self.params["accum_grad"] = max(1,self.params["accum_grad"])
+        self.use_ddp = utils.use_ddp()
+        self.stop_early = stop_early # To do.
+        # Automatic mixed precision init.(Leo 2021-11-08)
+        self.scaler = torch.cuda.amp.GradScaler() if self.params["use_amp"] else None
+        # (epoch, iter in epoch, global step)
+        self.training_point = copy.deepcopy([self.params["start_epoch"], 0, 0])
+        self.cycle_point = 0 # for cycle training.
+        self.warmup_steps = self.params["warmup_steps"] # model level warm up. just for transformer.
+    def select_device(self):
+        """Hand the model to ``utils.select_model_device`` and return whatever it returns."""
+        cfg = self.params
+        return utils.select_model_device(
+            self.elements["model"],
+            cfg["use_gpu"],
+            gpu_id=cfg["gpu_id"],
+            benchmark=cfg["benchmark"])
+    def init_training(self):
+        """Prepare the model for training: write the config, restore weights, select the device.
+        With ``start_epoch > 0`` the checkpoint ``{model_dir}/{start_epoch}.{suffix}`` is loaded and,
+        when ``{model_dir}/checkpoint_info/{start_epoch}.yaml`` exists, the global step is restored
+        from it. Otherwise, if ``exist_model`` is an existing path, it is loaded through
+        ``model.load_transform_state_dict``. Finally the model is passed through ``select_device``;
+        if it comes back wrapped in DistributedDataParallel, the raw module is kept in
+        ``elements["model"]`` and the wrapper in ``elements["model_forward"]``.
+        """
+        model = self.elements["model"]
+        start_epoch = self.params["start_epoch"]
+        exist_model = self.params["exist_model"]
+        model_dir = self.params["model_dir"]
+        model_blueprint = self.params["model_blueprint"]
+        suffix = self.params["suffix"]
+        # Only a fresh run on the main process writes the network config.
+        if start_epoch <= 0 and utils.is_main_training():
+            model_creation = model.get_model_creation()
+            utils.write_nnet_config(
+                model_blueprint, model_creation, "{0}/config/nnet.config".format(model_dir))
+        # Recover checkpoint | Transform learning | Initialize parameters
+        if start_epoch > 0:
+            # This train_stage is equal to number of completed epoch
+            if utils.is_main_training():
+                logger.info(
+                    "Recover training from {0} epoch.".format(start_epoch))
+            model.load_state_dict(torch.load('{0}/{1}.{2}'.format(model_dir, start_epoch, suffix),
+                                             map_location="cpu"))
+            # info_path = '{0}/{1}/{2}.{3}'.format(
+            #     model_dir, "checkpoint_info", start_epoch, "info")
+            info_log_path = '{0}/{1}/{2}.{3}'.format(
+                model_dir, "checkpoint_info", start_epoch, "yaml")
+            if os.path.exists(info_log_path):
+                # info = torch.load(info_path)
+                # self.elements["optimizer"].load_state_dict(info['optimizer'])
+                # for state in self.elements["optimizer"].values():
+                #     for k, v in state.items():
+                #         if isinstance(v, torch.Tensor):
+                #             state[k] = v.cuda()
+                # NOTE(review): save_model writes this file with yaml.dump of a plain dict, so
+                # yaml.safe_load looks sufficient here - confirm before switching loaders.
+                with open(info_log_path, 'r') as fin:
+                    info = yaml.load(fin, Loader=yaml.FullLoader)
+                # Only the global step is restored; the optimizer state is not.
+                self.training_point[2] = info['step']
+        elif os.path.exists(exist_model):
+            if utils.is_main_training():
+                logger.info(
+                    "Use {0} as the initial model to start transform-training.".format(exist_model))
+            model.load_transform_state_dict(
+                torch.load(exist_model, map_location="cpu"))
+        else:
+            # Just use the raw initial model or initialize it again by some initial functions here
+            pass  # Now, it means use the raw initial model
+        # Select device
+        model = self.select_device()
+        # Original model is built in libs.nnet.framework.TopVirtualNnet, and it is not available after
+        # wrapped by DistributedDataParallel. So, to call functions of TopVirtualNnet conveniently, the
+        # self.elements["model_forward"] is set here to name DistributedDataParallel.
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            self.elements["model"] = model.module
+            self.elements["model_forward"] = model
+    def save_model(self, mod="epoch",train_lr=None,valid_loss=None):
+        """Save the model weights and a small YAML record of the current training point.
+        Args:
+            mod: naming scheme, one of "epoch" (``{epoch}``), "iter" (``{epoch}.{iter}``)
+                or "cycle" (``{cycle_point}_cycle``).
+            train_lr: learning rate to record; a falsy value is recorded as "see train.csv".
+            valid_loss: validation loss to record; a falsy value is recorded as "see train.csv".
+        Writes ``{model_dir}/{name}.{suffix}`` and ``{model_dir}/checkpoint_info/{name}.yaml``.
+        The ``checkpoint_info`` directory is not created here, so it must already exist.
+        """
+        assert mod in ["epoch", "iter", "cycle"]
+        if mod == "epoch":
+            model_name = self.training_point[0]
+        elif mod == "iter":
+            model_name = "{}.{}".format(
+                self.training_point[0], self.training_point[1])
+        else:
+            model_name = "{}_cycle".format(self.cycle_point)
+        model_path = '{0}/{1}.{2}'.format(
+            self.params["model_dir"], model_name, self.params["suffix"])
+        # info = {
+        #     'optimizer': self.elements["optimizer"].state_dict(),
+        #     'step': self.training_point[2],
+        # }
+        # "next_lr" is read from the first parameter group of the optimizer only.
+        info_log = {
+            'train_lr': train_lr if train_lr else "see train.csv",
+            "next_lr": self.elements["optimizer"].state_dict()['param_groups'][0]['lr'],
+            'epoch': self.training_point[0],
+            'iter in epoch': self.training_point[1],
+            'step': self.training_point[2],
+            'valid_loss':valid_loss if valid_loss else "see train.csv"
+        }
+        # info_path = '{0}/{1}/{2}.{3}'.format(
+        #     self.params["model_dir"], "checkpoint_info", model_name, "info")
+        # info_log_path = re.sub('.info$', '.yaml', info_path)
+        info_log_path = '{0}/{1}/{2}.{3}'.format(
+            self.params["model_dir"], "checkpoint_info", model_name, "yaml")
+        logger.info("Save model to {0}. \n epoch/iter: {1}/{2}.  cur_step: {3}".format(model_path, self.training_point[0],
+                                                                                      self.training_point[1], self.training_point[2]))
+        # elements["model"] is the unwrapped module (see init_training), so keys carry no DDP prefix.
+        torch.save(self.elements["model"].state_dict(), model_path)
+        # torch.save(info, info_path)
+        with open(info_log_path, 'w') as fout:
+            data = yaml.dump(info_log)
+            fout.write(data)
+    def run(self):
+        """Start training; abstract here and must be implemented by a subclass."""
+        raise NotImplementedError
+    @for_lr_finder_new
+    def lr_finder_compute(self, train_batch):
+        """Compute the metrics for one lr-finder batch; abstract here, implemented by subclasses."""
+        # Only train_batch parameter for it's always main metric.
+        raise NotImplementedError
+    def run_lr_finder(self):
+        """Run the learning-rate finder; abstract here and must be implemented by a subclass."""
+        # Implement this function w.r.t self.lr_finder_compute().
+        raise NotImplementedError
+class SimpleTrainer(_BaseTrainer):
+ """One input and one output.
+ """
+    def __init__(self, *args, **kwargs):
+        """Forward all arguments to the base trainer and reset the processed-batch counter."""
+        super().__init__(*args, **kwargs)
+        # Incremented once per train_one_batch() call.
+        self.num_batch = 0
+    def train_one_batch(self, batch, step_lr=True):
+        """Run one two-stage (SAM) update on a batch that has already been fetched.
+        Stage 1 does a forward/backward pass and calls ``optimizer.step(stage=1)``. If every
+        rank reports success, stage 2 repeats the forward/backward pass with batch-norm
+        momentum set to 0 and calls ``optimizer.step(stage=2)``. When stage 1 succeeded but
+        stage 2 did not, ``optimizer.back_w()`` is called before the gradients are cleared.
+        Args:
+            batch: ``(inputs, targets, feats_lens)``; ignored when ``nan_debug`` is set.
+            step_lr: when True, ``self.step_lr`` is called after the update.
+        Returns:
+            ``(loss, accuracy)`` of the first pass; accuracy is None unless
+            ``compute_accuracy`` is enabled.
+        Raises:
+            RuntimeError: on a non-finite gradient norm when ``skip_nan_batch`` is False
+                (the batch, targets and parameters are saved as ``nan.*`` first).
+        """
+        model = self.elements["model"]
+        model_forward = self.elements["model_forward"]
+        optimizer = self.elements["optimizer"]
+        device = utils.get_device(self.elements["model"])
+        # model level warm up. just for transformer.
+        cur_step = self.training_point[2]
+        warmup = cur_step / self.warmup_steps if self.warmup_steps >0 else 1.0
+        warmup = torch.FloatTensor([warmup])
+        if not model.training:
+            model.train()
+        if self.params["nan_debug"]:
+            # NOTE(review): this branch never builds input_list, so the forward call below
+            # cannot run as written - confirm how nan_debug replay is meant to feed the model.
+            device = utils.get_device(self.elements["model"])
+            inputs = torch.load(
+                "{0}/nan.batch".format(self.params["model_dir"])).to(device)
+            targets = torch.load(
+                "{0}/nan.targets".format(self.params["model_dir"])).to(device)
+            self.elements["model"].load_state_dict(torch.load("{0}/nan.params".format(self.params["model_dir"]),
+                                                              map_location="cpu"))
+            self.elements["model"].to(device)
+        else:
+            inputs, targets, feats_lens = batch
+            # Scale feats_lens by the size of inputs' third dimension and cast to long.
+            feats_lens = (feats_lens*inputs.shape[2]).long()
+            input_list = [inputs,feats_lens]
+        # model level warm up. just for transformer.
+        if self.warmup_steps >0:
+            input_list.append(warmup)
+        context = None
+        # Disable gradient synchronizations across DDP processes in first background.
+        if self.use_ddp :
+            context = model_forward.no_sync
+        # Used for single gpu training and DDP gradient synchronization
+        # processes.
+        else:
+            context = nullcontext
+        # first forward-backward pass
+        with context():
+            # Managing automatic mixed precision (Leo 2021-11-08)
+            with torch.cuda.amp.autocast(self.scaler is not None):
+                enable_running_stats(model_forward)
+                loss = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"]
+                # loss = model_forward(inputs, targets)/self.params["accum_grad"]
+            if self.params["use_amp"]:
+                self.scaler.scale(loss).backward()
+            else:
+                loss.backward()
+        loss.detach()  # For safe.
+        loss = loss.item()*self.params["accum_grad"]
+        accuracy = model.get_accuracy(targets) if self.params["compute_accuracy"] else None
+        # for name, param in model.named_parameters():
+        #     if param.grad is None:
+        #         print(name)
+        # sys.exit()
+        # Default to "no step taken": on the non-AMP path a skipped non-finite batch
+        # (skip_nan_batch=True) assigns nothing below, and the flag must still be readable.
+        flag_step1 = False
+        # Use mixed precision training (Leo 2021-11-08)
+        if self.params["use_amp"]:
+            self.scaler.unscale_(optimizer)
+            if not self.modify_grad() and not self.params["skip_nan_batch"]:
+                torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
+                torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
+                torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"]))
+                raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0],
+                                                                                                                        self.training_point[1], "{0}/nan.*".format(self.params["model_dir"])))
+            else:
+                flag_step1 = self.scaler.step(optimizer,stage=1)
+                self.scaler.update()
+        else:
+            if not self.modify_grad():
+                if not self.params["skip_nan_batch"]:
+                    torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
+                    torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
+                    torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"]))
+                    raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0],
+                                                                                                                            self.training_point[1], "{0}/nan.*".format(self.params["model_dir"])))
+            else:
+                flag_step1=optimizer.step(stage=1)
+        # Each rank votes 1 when its first step failed; a non-zero sum stops every rank.
+        if flag_step1:
+            stop_flag = torch.zeros(1).to(device)
+        else:
+            stop_flag = torch.ones(1).to(device)
+        if utils.use_ddp():
+            dist.all_reduce(stop_flag,op=dist.ReduceOp.SUM)
+            dist.barrier()
+        stop_flag = bool(stop_flag)
+        flag_step2 = False
+        if not stop_flag:
+            # Second pass runs outside no_sync. Batch-norm momentum is zeroed for it.
+            # Managing automatic mixed precision (Leo 2021-11-08)
+            with torch.cuda.amp.autocast(self.scaler is not None):
+                disable_running_stats(model_forward)
+                loss1 = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"]
+                # loss = model_forward(inputs, targets)/self.params["accum_grad"]
+            if self.params["use_amp"]:
+                self.scaler.scale(loss1).backward()
+            else:
+                loss1.backward()
+            loss1.detach()
+            # Use mixed precision training (Leo 2021-11-08)
+            if self.params["use_amp"]:
+                self.scaler.unscale_(optimizer)
+                if not self.modify_grad() and not self.params["skip_nan_batch"]:
+                    torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
+                    torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
+                    torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"]))
+                    raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0],
+                                                                                                                            self.training_point[1], "{0}/nan.*".format(self.params["model_dir"])))
+                else:
+                    flag_step2 = self.scaler.step(optimizer,stage=2)
+                    self.scaler.update()
+            else:
+                if not self.modify_grad():
+                    if not self.params["skip_nan_batch"]:
+                        torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
+                        torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
+                        torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"]))
+                        raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0],
+                                                                                                                                self.training_point[1], "{0}/nan.*".format(self.params["model_dir"])))
+                else:
+                    flag_step2 = optimizer.step(stage=2)
+        self.training_point[2] += 1 # update step
+        if step_lr:
+            self.step_lr(loss,accuracy,optimizer,self.elements["lr_scheduler"])
+        # Stage 1 succeeded but stage 2 did not (failed or skipped): call the optimizer's back_w().
+        if flag_step1 and flag_step2 is not True:
+            optimizer.back_w()
+        optimizer.zero_grad()
+        self.num_batch += 1
+        return loss, accuracy
+ # clip grad
+    def modify_grad(self):
+        """Clip the gradient norm to ``max_change`` and report whether it was finite.
+        Returns True for a finite norm and False otherwise. With ``nan_debug`` set, a
+        RuntimeError is raised in both cases to report the outcome of the debug run.
+        """
+        debugging = self.params["nan_debug"]
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.elements["model"].parameters(), self.params["max_change"])
+        if torch.isfinite(grad_norm):
+            if debugging:
+                raise RuntimeError(
+                    "[OK] There is no nan found for this debug.")
+            return True
+        logger.warning("Grad:{0} is not finite in epoch/iter: {1}/{2}".format(grad_norm, self.training_point[0],self.training_point[1]))
+        if debugging:
+            raise RuntimeError(
+                "[NOT OK] Nan is still found in this debug.")
+        return False
+    def compute_validation(self, data_loader):
+        """Evaluate the model on the first few batches of ``data_loader``.
+        Args:
+            data_loader: yields ``(inputs, targets, feats_lens)`` batches.
+        Returns:
+            ``(avg_loss, avg_accuracy)`` averaged over the samples seen; ``avg_accuracy``
+            is None when ``compute_valid_accuracy`` is disabled.
+        The model is switched to eval mode and restored to train mode afterwards if it
+        was training on entry.
+        """
+        model = self.elements["model"]
+        model_forward = self.elements["model_forward"]
+        train_status = model.training  # Record status.
+        model.eval()
+        loss = 0.
+        accuracy = 0. if self.params["compute_valid_accuracy"] else None
+        num_samples = 0
+        with torch.no_grad():
+            for idx,this_data in enumerate(data_loader):
+                inputs, targets, feats_lens = this_data
+                feats_lens = (feats_lens*inputs.shape[2]).long()
+                num_utts = targets.size(0)
+                input_list = [inputs,feats_lens]
+                if num_utts == 0:
+                    continue
+                # in valid stage, DO NOT call ddp model, for ddp model is in JOIN context wrapper.
+                # Leo 2022-02-03
+                loss += model.get_loss(model(*input_list),
+                                       targets).item() * len(targets)
+                # loss += model_forward(inputs,targets).item() * len(targets)
+                num_samples += len(targets)
+                if self.params["compute_valid_accuracy"]:
+                    # This will occupy extra GPU memory.
+                    accuracy += model.get_accuracy(targets) * len(targets)
+                # NOTE(review): the check runs after the batch is processed and uses ">", so
+                # compute_batch_num_valid + 1 batches are consumed - confirm this is intended.
+                if idx > self.params["compute_batch_num_valid"]-1:
+                    break
+            # NOTE(review): divides by zero if the loader yields no non-empty batch.
+            avg_loss = loss/num_samples
+            avg_accuracy = accuracy / num_samples if self.params["compute_valid_accuracy"] else None
+        if train_status:
+            model.train()
+        return avg_loss, avg_accuracy
+    def step_lr(self,train_loss,train_acc,base_optimizer,lr_scheduler):
+        """Report the latest batch, run validation when due, and advance the lr scheduler.
+        Args:
+            train_loss: loss of the batch just trained.
+            train_acc: accuracy of the batch just trained.
+            base_optimizer: optimizer whose first parameter group supplies the reported lr.
+            lr_scheduler: an ``LRSchedulerWrapper``, another scheduler accepting
+                ``step(epoch)``, or None (then only reporting happens).
+        On the main process this also saves "iter"/"cycle" checkpoints when a "reduceP"
+        scheduler lowers the lr or the wrapper signals a cycle point.
+        """
+        # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler
+        # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
+        # and some simple schedulers whose step() parameter is 'epoch' only are supported.
+        valid_dataloader=self.elements["data"].valid_loader
+        lr_scheduler_params = {
+            "training_point": self.training_point}
+        valid_loss = None
+        valid_computed = False
+        # lr_scheduler defaults to None in the trainer elements, so guard before touching .name.
+        if lr_scheduler is not None and lr_scheduler.name == "reduceP" and lr_scheduler.is_reduce_point(self.training_point):
+            assert valid_dataloader is not None
+            valid_loss, valid_acc = self.compute_validation(valid_dataloader)
+            lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
+            valid_computed = True
+        if utils.is_main_training():
+            # NOTE(review): the "*100" formatting below assumes train_acc / valid_acc are numbers,
+            # i.e. that compute_accuracy / compute_valid_accuracy are enabled - confirm.
+            if valid_computed or (valid_dataloader is not None and self.reporter.is_report(self.training_point)):
+                if not valid_computed:
+                    valid_loss, valid_acc = self.compute_validation(valid_dataloader)
+                    valid_computed = False
+                # real_snapshot is set for tensorboard to avoid workspace problem
+                real_snapshot = {"train_loss": train_loss, "valid_loss": valid_loss,
+                                 "train_acc": train_acc*100, "valid_acc": valid_acc*100}
+                snapshot = {"train_loss": "{0:.6f}".format(train_loss), "valid_loss": "{0:.6f}".format(valid_loss),
+                            "train_acc": "{0:.2f}".format(train_acc*100), "valid_acc": "{0:.2f}".format(valid_acc*100),
+                            "total_dur":self.origin_total_dur,"num_sample":self.num_sample,"real": real_snapshot}
+                # For ReduceLROnPlateau.
+                lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
+            else:
+                real_snapshot = {
+                    "train_loss": train_loss, "train_acc": train_acc*100}
+                snapshot = {"train_loss": "{0:.6f}".format(train_loss), "valid_loss": "",
+                            "train_acc": "{0:.2f}".format(train_acc*100), "valid_acc": "",
+                            "total_dur":self.origin_total_dur,"num_sample":self.num_sample,"real": real_snapshot}
+            training_point = (self.training_point[0],self.training_point[1],self.training_point[2])
+            # The lr is read before the scheduler steps, i.e. the lr this batch was trained with.
+            self.train_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
+            self.reporter.update(snapshot,training_point,self.train_lr)
+        if lr_scheduler is not None:
+            # It is not convenient to wrap lr_scheduler (doing).
+            if isinstance(lr_scheduler, LRSchedulerWrapper):
+                lr_scheduler.step(**lr_scheduler_params)
+                if utils.is_main_training():
+                    current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
+                    if lr_scheduler.name == "reduceP":
+                        # Checkpoint whenever the plateau scheduler has just lowered the lr, or
+                        # at each reduce point once the lr has reached its floor.
+                        if current_lr < self.last_lr:
+                            self.last_lr = current_lr
+                            self.save_model(mod="iter",train_lr=self.train_lr,valid_loss=valid_loss)
+                        elif current_lr <= lr_scheduler.min_lr and lr_scheduler.is_reduce_point(self.training_point):
+                            self.save_model(mod="iter",train_lr=self.train_lr,valid_loss=valid_loss)
+                    if lr_scheduler.is_cycle_point(self.training_point):
+                        self.cycle_point+=1
+                        self.save_model(mod="cycle",train_lr=self.train_lr,valid_loss=valid_loss)
+            else:
+                # For some pytorch lr_schedulers, but it is not available for all.
+                lr_scheduler.step(self.training_point[0])
+    def run(self):
+        """Main function to start a training process.
+        Initialises training, then loops over epochs and batches calling
+        ``train_one_batch``. The main process saves a checkpoint after every epoch and
+        finally points ``final.params`` at the last epoch (or last cycle) checkpoint.
+        Any exception, including KeyboardInterrupt, cleans up DDP and exits with status 1.
+        """
+        try:
+            self.init_training()
+            if utils.is_main_training():
+                self.reporter = Reporter(self)
+            start_epoch = self.params["start_epoch"]
+            epochs = self.params["epochs"]
+            data = self.elements["data"]
+            model = self.elements["model"]
+            # See init_training.
+            model_forward = self.elements["model_forward"]
+            self.train_lr = self.elements["optimizer"].state_dict()['param_groups'][0]['lr']
+            self.last_lr =  self.elements["optimizer"].state_dict()['param_groups'][0]['lr']
+            if utils.is_main_training():
+                logger.info("Training will run for {0} epochs.".format(epochs))
+            if utils.is_main_training() and self.params["accum_grad"] > 1:
+                logger.info("using accumulate grad,accum_num: {}".format(
+                    self.params["accum_grad"]))
+            if isinstance(model_forward, torch.nn.parallel.DistributedDataParallel):
+                model_context = model_forward.join(throw_on_early_termination=True)
+            else:
+                model_context = nullcontext()
+            device = utils.get_device(self.elements["model"])
+            stop_training = torch.zeros(1).to(device)
+            with model_context:
+                for this_epoch in range(start_epoch, epochs+5):
+                    # Ranks may receive uneven amounts of data under DDP, so the loop runs a few
+                    # extra epochs; the main rank raises the stop flag once the real epochs are
+                    # done and the all_reduce below spreads it to every rank.
+                    # is_main_training must be *called*: the bare function object is always truthy.
+                    if utils.is_main_training():
+                        if this_epoch == epochs:  # skip the last+1 epoch
+                            stop_training = torch.ones(1).to(device)
+                    self.training_point[0]+=1
+                    data.train_loader.dataset.set_epoch(this_epoch)
+                    if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur()
+                    # with model_context:
+                    for _, batch in enumerate(data.train_loader, 0):
+                        # It is important for reporter.
+                        # Collectives need an initialised process group, so only use them under DDP.
+                        if utils.use_ddp():
+                            dist.barrier()
+                            dist.all_reduce(stop_training,op=dist.ReduceOp.SUM)
+                        if stop_training:
+                            break
+                        self.training_point[1] +=1
+                        num_utts = batch[0].size(0)
+                        if num_utts == 0:
+                            continue
+                        if model.use_step:
+                            step_point = (self.training_point[0],self.training_point[2])
+                            model.step_iter(*step_point)
+                        loss, acc = self.train_one_batch(batch)
+                        model.backward_step(*self.training_point)
+                    # Leave before saving: the extra epoch trained nothing.
+                    if stop_training:
+                        break
+                    if utils.is_main_training():self.save_model(train_lr=self.train_lr)
+                    self.training_point[1] =0
+            if utils.is_main_training():self.reporter.finish()
+            if utils.is_main_training():
+                final_model_name = "{}_cycle".format(self.cycle_point) if self.cycle_point else epochs
+                final_model_path = os.path.join(self.params["model_dir"],'final.params')
+                # islink() also catches a dangling symlink, for which exists() is False.
+                if os.path.exists(final_model_path) or os.path.islink(final_model_path):
+                    os.remove(final_model_path)
+                os.symlink('{0}/{1}.{2}'.format(self.params["model_dir"], final_model_name, self.params["suffix"]), final_model_path)
+        except BaseException as e:
+            if utils.use_ddp():utils.cleanup_ddp()
+            if not isinstance(e, KeyboardInterrupt):traceback.print_exc()
+            sys.exit(1)
+    @for_lr_finder_new
+    def lr_finder_compute(self, train_batch):
+        """Train on one batch without stepping the lr scheduler and return the four metrics.
+        Returns a pair ``(names, values)`` for train/valid loss and accuracy. Validation
+        runs on the main process only; other processes report 0 for both valid values.
+        """
+        model = self.elements["model"]
+        if model.use_step:
+            step_point = (self.training_point[0],self.training_point[2])
+            model.step_iter(*step_point)
+        # step_lr=False: this method does not advance the trainer's lr scheduler.
+        loss, acc = self.train_one_batch(train_batch,step_lr=False)
+        model.backward_step(*self.training_point)
+        valid_loss, valid_acc=0,0
+        if utils.is_main_training():
+            valid_loss, valid_acc = self.compute_validation(
+                self.elements["data"].valid_loader)
+        return ["train_loss", "train_acc", "valid_loss", "valid_acc"], [loss, acc, valid_loss, valid_acc]
+    def run_lr_finder(self, save_file: str, comment=None, init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98):
+        """Run the learning-rate finder, save its curve as CSV, then exit the process.
+        Args:
+            save_file: CSV file name, written under ``{model_dir}/log/``.
+            comment: optional prefix for the file name; also passed to the finder.
+            init_lr, final_lr, num_iters, beta: forwarded to the decorated
+                ``lr_finder_compute``.
+        This method never returns: it always ends in ``sys.exit(1)``.
+        """
+        self.init_training()
+        log_dir = self.params["model_dir"] + "/log/"  # For tensorboardX
+        if comment is not None:
+            save_file = comment + "-" + save_file
+        save_file = log_dir + save_file
+        # lr_finder_compute is wrapped by for_lr_finder_new, hence the different call signature.
+        log_lrs, values_matrix = self.lr_finder_compute(self.elements["data"].train_loader, self.elements["optimizer"],
+                                                        init_lr=init_lr, final_lr=final_lr, num_iters=num_iters, beta=beta,
+                                                        log_dir=log_dir, comment=comment)
+        if utils.is_main_training():
+            df = pd.DataFrame(np.vstack([log_lrs, values_matrix]).T,
+                              columns=["log_lr", "train_loss", "train_acc", "valid_loss", "valid_acc"])
+            logger.info("Save lr finder values to {}.".format(save_file))
+            df.to_csv(save_file)
+        if utils.use_ddp():utils.cleanup_ddp()
+        # NOTE(review): exits with status 1 even after a successful run - confirm callers rely on it.
+        sys.exit(1)
+class MultitaskTrainer(_BaseTrainer):
+    """Placeholder for a one-input, multi-output trainer, e.g. speaker classification
+    combined with phone classification. It adds nothing to _BaseTrainer yet.
+    """
+class GANTrainer(_BaseTrainer):
+    """Placeholder for a GAN trainer. It adds nothing to _BaseTrainer yet."""
+# Function ✿
+def add_gaussian_noise_to_grad(model, t, n=0.1, gamma=0.55):
+    """Add annealed zero-mean Gaussian noise to every parameter gradient, in place.
+    The noise variance at step ``t`` is ``n / (1 + t) ** gamma``.
+    Args:
+        model: a ``torch.nn.Module`` whose parameters already hold gradients.
+        t: current training step (non-negative).
+        n: initial noise variance.
+        gamma: decay exponent of the variance.
+    Parameters without a gradient are left untouched.
+    """
+    var = n/(1+t)**gamma
+    # The sampler works with a standard deviation, so take the square root of the variance.
+    std = math.sqrt(var)
+    for param in model.parameters():
+        if param.grad is None:
+            # Frozen or unused parameters have nothing to perturb.
+            continue
+        # Sample directly with the gradient's device and dtype.
+        param.grad += torch.randn_like(param.grad) * std
+def disable_running_stats(model):
+    """Set the momentum of every batch-norm layer in ``model`` to 0, saving the old value.
+    The previous momentum is kept on each layer as ``backup_momentum``.
+    """
+    def _zero_momentum(layer):
+        if not isinstance(layer, _BatchNorm):
+            return
+        layer.backup_momentum = layer.momentum
+        layer.momentum = 0
+    model.apply(_zero_momentum)
+def enable_running_stats(model):
+    """Restore each batch-norm layer's momentum from its ``backup_momentum`` attribute.
+    Layers that carry no ``backup_momentum`` are left unchanged.
+    """
+    def _restore_momentum(layer):
+        if not isinstance(layer, _BatchNorm):
+            return
+        if hasattr(layer, "backup_momentum"):
+            layer.momentum = layer.backup_momentum
+    model.apply(_restore_momentum)
\ No newline at end of file
diff --git a/pytorch/model/ecapa_tdnn_xvector.py b/pytorch/model/ecapa_tdnn_xvector.py
old mode 100644
new mode 100755
index 5f79f1e..bf7022d
--- a/pytorch/model/ecapa_tdnn_xvector.py
+++ b/pytorch/model/ecapa_tdnn_xvector.py
@@ -1,3 +1,9 @@
+# Copyright xmuspeech (Author: Leo 2022-05-27)
+# refs:
+# 1. ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification
+# https://arxiv.org/abs/2005.07143
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -6,10 +12,6 @@
sys.path.insert(0, 'subtools/pytorch')
import libs.support.utils as utils
from libs.nnet import *
-# Copyright xmuspeech (Author: Leo 2022-05-27)
-# refs:
-# 1. ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification
-# https://arxiv.org/abs/2005.07143
class Res2NetBlock(torch.nn.Module):
@@ -34,7 +36,7 @@ class Res2NetBlock(torch.nn.Module):
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
- torch.Size([8, 120, 64])
+ torch.Size([8, 120, 64])
def __init__(
@@ -231,6 +233,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
"curricular": False}
default_step_params = {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0},
"T": None,
"m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4,
"s": False, "s_tuple": (30, 12), "s_list": None,
@@ -332,6 +336,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
if margin_loss:
self.loss = MarginSoftmaxLoss(
embd_dim, num_targets, **margin_loss_params)
+ if self.use_step and self.step_params["margin_warm"]:
+ self.margin_warm = MarginWarm(**step_params["margin_warm_conf"])
self.loss = SoftmaxLoss(embd_dim, num_targets)
# self.loss = AngleLoss(embd_dim,num_targets)
@@ -339,7 +345,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
self.loss, self.mixup) if mixup else None
# An example to using transform-learning without initializing loss.affine parameters
self.transform_keys = ["layer2", "layer3",
- "layer4", "conv", "stats", "fc1", "fc2"]
+ "layer4", "stats", "mfa", "bn_stats", "fc1", "fc2", "loss"]
if margin_loss and transfer_from == "softmax_loss":
# For softmax_loss to am_softmax_loss
@@ -348,7 +354,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
- def forward(self, x):
+ def forward(self, x, x_len: torch.Tensor=torch.empty(0)):
x = self.layer1(x)
x1 = self.layer2(x)
x2 = self.layer3(x+x1)
@@ -442,27 +448,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc
return xvector
- def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 10000, isMatrix: bool = True):
- if isMatrix:
- input = torch.unsqueeze(input, dim=0)
- input = input.transpose(1, 2)
- num_frames = input.shape[2]
- num_split = (num_frames + maxChunk - 1) // maxChunk
- split_size = num_frames // num_split
- offset = 0
- embedding_stats = torch.zeros(1, self.embd_dim, 1)
- for _ in range(0, num_split-1):
- this_embedding = self.extract_embedding_jit(
- input[:, :, offset:offset+split_size], position)
- offset += split_size
- embedding_stats += split_size*this_embedding
- last_embedding = self.extract_embedding_jit(
- input[:, :, offset:], position)
- embedding = (embedding_stats + (num_frames-offset)
- * last_embedding) / num_frames
- return torch.squeeze(embedding.transpose(1, 2)).cpu()
+    def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):
+        """Extract one embedding for a whole input by chunking along the last axis.
+        The input is cut into chunks of at most ``maxChunk`` frames, each chunk goes through
+        ``extract_embedding_jit``, and the chunk embeddings are averaged weighted by their
+        frame counts. Runs under ``torch.no_grad()`` and returns a squeezed CPU tensor.
+        With ``isMatrix`` the input gets a leading batch axis and its last two axes swapped.
+        """
+        with torch.no_grad():
+            if isMatrix:
+                input = torch.unsqueeze(input, dim=0)
+                input = input.transpose(1, 2)
+            num_frames = input.shape[2]
+            # Ceil division gives the chunk count; chunks are then made roughly equal in size.
+            # NOTE(review): num_frames == 0 makes num_split 0 and the next line divide by zero.
+            num_split = (num_frames + maxChunk - 1) // maxChunk
+            split_size = num_frames // num_split
+            offset = 0
+            # Accumulator lives on the input's device so the in-place add below works there.
+            embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)
+            for _ in range(0, num_split-1):
+                this_embedding = self.extract_embedding_jit(
+                    input[:, :, offset:offset+split_size], position)
+                offset += split_size
+                embedding_stats += split_size*this_embedding
+            # The last chunk takes all remaining frames and is weighted by that count.
+            last_embedding = self.extract_embedding_jit(
+                input[:, :, offset:], position)
+            embedding = (embedding_stats + (num_frames-offset)
+                         * last_embedding) / num_frames
+            return torch.squeeze(embedding.transpose(1, 2)).cpu()
def embedding_dim(self) -> int:
@@ -488,7 +495,8 @@ def step(self, epoch, this_iter, epoch_batchs):
current_postion = epoch*epoch_batchs + this_iter
lambda_factor = max(self.step_params["lambda_0"],
- self.loss.step(lambda_factor)
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
@@ -509,10 +517,15 @@ def step(self, epoch, this_iter, epoch_batchs):
def step_iter(self, epoch, cur_step):
# For iterabledataset
if self.use_step:
+ if self.step_params["margin_warm"]:
+ offset_margin, lambda_m = self.margin_warm.step(cur_step)
+ lambda_m = max(1e-3,lambda_m)
+ self.loss.step(lambda_m,offset_margin)
if self.step_params["m"]:
lambda_factor = max(self.step_params["lambda_0"],
- self.loss.step(lambda_factor)
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
diff --git a/pytorch/model/extended_xvector.py b/pytorch/model/extended_xvector.py
old mode 100644
new mode 100755
index 85d5d31..c539144
--- a/pytorch/model/extended_xvector.py
+++ b/pytorch/model/extended_xvector.py
@@ -49,8 +49,9 @@ def init(self, inputs_dim, num_targets, extend=True, nonlinearity="relu",
self.transform_keys = ["tdnn1","tdnn2","tdnn3","tdnn4","tdnn5","stats","tdnn6","tdnn7",
+ @torch.jit.unused
- def forward(self, inputs):
+ def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
diff --git a/pytorch/model/factored_xvector.py b/pytorch/model/factored_xvector.py
old mode 100644
new mode 100755
index 9699bc8..0ed7053
--- a/pytorch/model/factored_xvector.py
+++ b/pytorch/model/factored_xvector.py
@@ -61,7 +61,7 @@ def init(self, inputs_dim, num_targets, nonlinearity="relu", semi_orth=True,embd
- def forward(self, inputs):
+ def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
@@ -158,24 +158,28 @@ def extract_embedding_jit(self, inputs: torch.Tensor, position: str = 'far') ->
return xvector
- def extract_embedding_whole(self,input:torch.Tensor,position:str='far',maxChunk:int=10000,isMatrix:bool=True):
- if isMatrix:
- input=torch.unsqueeze(input,dim=0)
- input=input.transpose(1,2)
- num_frames = input.shape[2]
- num_split = (num_frames + maxChunk - 1) // maxChunk
- split_size = num_frames // num_split
- offset=0
- embedding_stats = torch.zeros(1,self.embd_dim,1)
- for _ in range(0, num_split-1):
- this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position)
- offset += split_size
- embedding_stats += split_size*this_embedding
- last_embedding = self.extract_embedding_jit(input[:, :, offset:],position)
- embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames
- return torch.squeeze(embedding.transpose(1,2)).cpu()
+    def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):
+        """Extract one embedding for a whole input by chunking along the last axis.
+        The input is cut into chunks of at most ``maxChunk`` frames, each chunk goes through
+        ``extract_embedding_jit``, and the chunk embeddings are averaged weighted by their
+        frame counts. Runs under ``torch.no_grad()`` and returns a squeezed CPU tensor.
+        With ``isMatrix`` the input gets a leading batch axis and its last two axes swapped.
+        """
+        # NOTE(review): the default position here is 'near', while extract_embedding_jit in
+        # this model defaults to 'far' and the replaced version of this method also used
+        # 'far' - confirm the new default is intended.
+        with torch.no_grad():
+            if isMatrix:
+                input = torch.unsqueeze(input, dim=0)
+                input = input.transpose(1, 2)
+            num_frames = input.shape[2]
+            # Ceil division gives the chunk count; chunks are then made roughly equal in size.
+            # NOTE(review): num_frames == 0 makes num_split 0 and the next line divide by zero.
+            num_split = (num_frames + maxChunk - 1) // maxChunk
+            split_size = num_frames // num_split
+            offset = 0
+            # Accumulator lives on the input's device so the in-place add below works there.
+            embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)
+            for _ in range(0, num_split-1):
+                this_embedding = self.extract_embedding_jit(
+                    input[:, :, offset:offset+split_size], position)
+                offset += split_size
+                embedding_stats += split_size*this_embedding
+            # The last chunk takes all remaining frames and is weighted by that count.
+            last_embedding = self.extract_embedding_jit(
+                input[:, :, offset:], position)
+            embedding = (embedding_stats + (num_frames-offset)
+                         * last_embedding) / num_frames
+            return torch.squeeze(embedding.transpose(1, 2)).cpu()
def embedding_dim(self) -> int:
diff --git a/pytorch/model/multi_task_xvector_fix.py b/pytorch/model/multi_task_xvector_fix.py
old mode 100644
new mode 100755
index efc681d..c5f653d
--- a/pytorch/model/multi_task_xvector_fix.py
+++ b/pytorch/model/multi_task_xvector_fix.py
@@ -323,7 +323,8 @@ def step(self, epoch, this_iter, epoch_batchs):
current_postion = epoch*epoch_batchs + this_iter
lambda_factor = max(self.step_params["lambda_0"],
- self.loss_spk.step(lambda_factor)
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss_spk.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
@@ -346,7 +347,8 @@ def step_iter(self, epoch, cur_step):
if self.step_params["m"]:
lambda_factor = max(self.step_params["lambda_0"],
- self.loss.step(lambda_factor)
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss_spk.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
diff --git a/pytorch/model/repvgg_xvector.py b/pytorch/model/repvgg_xvector.py
old mode 100644
new mode 100755
index fe0e200..4106bce
--- a/pytorch/model/repvgg_xvector.py
+++ b/pytorch/model/repvgg_xvector.py
@@ -12,7 +12,7 @@
from libs.nnet import *
class RepVggXvector(TopVirtualNnet):
- """ A repvgg vector framework """
+ """ A repvgg vector framework """
def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropout=0.,training=True, extracted_embedding="near",
deploy=False,repvgg_config={}, pooling="statistics", pooling_params={}, fc1=False, fc1_params={}, fc2_params={}, margin_loss=False, margin_loss_params={},
@@ -21,13 +21,13 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou
## Params.
default_repvgg_config = {
"auto_model" : False,
- "auto_model_name" : "RepVGG_B0",
- "block": "RepVGG",
+ "auto_model_name" : "RepVGG_A1",
+ "block": "RepSPK",
- "num_blocks": [4, 6, 16, 1],
+ "num_blocks": [2, 4, 14, 1],
"base_width": 32,
- "width_multiplier":[1, 1, 1, 2.5],
+ "width_multiplier": [1, 1, 1, 2.5],
"norm_layer_params":{"momentum":0.5, "affine":True},
"override_groups_map": None,
"use_se": False,
@@ -57,6 +57,8 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou
default_step_params = {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0},
"m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4,
"s":False, "s_tuple":(30, 12), "s_list":None,
@@ -125,13 +127,15 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou
if training :
if margin_loss:
self.loss = MarginSoftmaxLoss(embd_dim, num_targets, **margin_loss_params)
+ if self.use_step and self.step_params["margin_warm"]:
+ self.margin_warm = MarginWarm(**step_params["margin_warm_conf"])
elif adacos:
self.loss = AdaCos(embd_dim,num_targets)
self.loss = SoftmaxLoss(embd_dim, num_targets)
# An example to using transform-learning without initializing loss.affine parameters
- self.transform_keys = ["repvgg", "stats", "fc1", "fc2"]
+ self.transform_keys = ["repvgg", "stats", "fc1", "fc2","loss"]
# self.transform_keys = ["resnet"]
if margin_loss and transfer_from == "softmax_loss":
@@ -140,7 +144,7 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou
- def forward(self, x):
+ def forward(self, x, x_len: torch.Tensor=torch.empty(0)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
@@ -233,24 +237,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc
return xvector
- def extract_embedding_whole(self,input:torch.Tensor,position:str='near',maxChunk:int=10000,isMatrix:bool=True):
- if isMatrix:
- input=torch.unsqueeze(input,dim=0)
- input=input.transpose(1,2)
- num_frames = input.shape[2]
- num_split = (num_frames + maxChunk - 1) // maxChunk
- split_size = num_frames // num_split
- offset=0
- embedding_stats = torch.zeros(1,self.embd_dim,1)
- for _ in range(0, num_split-1):
- this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position)
- offset += split_size
- embedding_stats+=split_size*this_embedding
- last_embedding = self.extract_embedding_jit(input[:, :, offset:],position)
- embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames
- return torch.squeeze(embedding.transpose(1,2)).cpu()
+ def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):  # embed a whole input by averaging per-chunk embeddings
+ with torch.no_grad():  # extraction only: autograd is switched off
+ if isMatrix:  # isMatrix input is given a leading batch dimension
+ input = torch.unsqueeze(input, dim=0)
+ input = input.transpose(1, 2)  # swap axes 1 and 2; axis 2 is used as the frame axis below
+ num_frames = input.shape[2]
+ num_split = (num_frames + maxChunk - 1) // maxChunk  # ceil(num_frames / maxChunk) chunks
+ split_size = num_frames // num_split  # NOTE(review): divides by zero when num_frames is 0
+ offset = 0
+ embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)  # running sum of (chunk length * chunk embedding)
+ for _ in range(0, num_split-1):  # every chunk except the last, split_size frames each
+ this_embedding = self.extract_embedding_jit(
+ input[:, :, offset:offset+split_size], position)
+ offset += split_size
+ embedding_stats += split_size*this_embedding
+ last_embedding = self.extract_embedding_jit(
+ input[:, :, offset:], position)  # the last chunk takes every remaining frame
+ embedding = (embedding_stats + (num_frames-offset)
+ * last_embedding) / num_frames  # frame-count-weighted mean over all chunks
+ return torch.squeeze(embedding.transpose(1, 2)).cpu()  # drop size-1 dims and return the result on CPU
def embedding_dim(self) -> int:
@@ -277,9 +285,10 @@ def step(self, epoch, this_iter, epoch_batchs):
if self.use_step:
if self.step_params["m"]:
current_postion = epoch*epoch_batchs + this_iter
- lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ lambda_factor = max(self.step_params["lambda_0"],
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
@@ -287,10 +296,12 @@ def step(self, epoch, this_iter, epoch_batchs):
T_i = T_i * epoch_batchs
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
@@ -298,20 +309,26 @@ def step(self, epoch, this_iter, epoch_batchs):
def step_iter(self, epoch, cur_step):
# For iterabledataset
if self.use_step:
+ if self.step_params["margin_warm"]:
+ offset_margin, lambda_m = self.margin_warm.step(cur_step)
+ lambda_m = max(1e-3,lambda_m)
+ self.loss.step(lambda_m,offset_margin)
if self.step_params["m"]:
lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
diff --git a/pytorch/model/resnet_xvector.py b/pytorch/model/resnet_xvector.py
old mode 100644
new mode 100755
index 952eed5..d8f41ff
--- a/pytorch/model/resnet_xvector.py
+++ b/pytorch/model/resnet_xvector.py
@@ -29,8 +29,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
"head_conv":True, "head_conv_params":{"kernel_size":3, "stride":1, "padding":1},
"head_maxpool":False, "head_maxpool_params":{"kernel_size":3, "stride":1, "padding":1},
- "layers":[3, 4, 6, 3],
- "planes":[32, 64, 128, 256], # a.k.a channels.
+ "layers": [3, 4, 6, 3],
+ "planes": [32, 64, 128, 256], # a.k.a channels.
"use_se": False,
"se_ratio": 4,
@@ -63,6 +63,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
default_step_params = {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0},
"m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4,
"s":False, "s_tuple":(30, 12), "s_list":None,
@@ -85,7 +87,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
## Nnet.
self.aug_dropout = torch.nn.Dropout2d(p=aug_dropout) if aug_dropout > 0 else None
- self.cmvn_=InputSequenceNormalization(**cmvn_params) if cmvn else None
+ self.cmvn_=InputSequenceNormalization(**cmvn_params) if cmvn else torch.nn.Identity()
# [batch, 1, feats-dim, frames] for 2d and [batch, feats-dim, frames] for 1d.
# Should keep the channel/plane is always in 1-dim of tensor (index-0 based).
@@ -126,10 +128,14 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
if training :
if margin_loss:
self.loss = MarginSoftmaxLoss(resnet_params["planes"][3], num_targets, **margin_loss_params)
+ if self.use_step and self.step_params["margin_warm"]:
+ self.margin_warm = MarginWarm(**step_params["margin_warm_conf"])
self.loss = SoftmaxLoss(resnet_params["planes"][3], num_targets)
# An example to using transform-learning without initializing loss.affine parameters
+ # self.transform_keys = ["resnet", "stats", "fc1", "fc2"]
self.transform_keys = ["resnet", "stats", "fc1", "fc2","loss.weight"]
if margin_loss and transfer_from == "softmax_loss":
@@ -138,12 +144,12 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
- def forward(self, x):
+ def forward(self, x, x_len: torch.Tensor=torch.empty(0)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
- x = self.auto(self.cmvn_,x)
+ x = self.cmvn_(x)  # cmvn_ is a module attribute of the model itself; `self.self` does not exist
x = self.auto(self.aug_dropout, x) # This auto function is equal to "x = layer(x) if layer is not None else x" for convenience.
# [samples-index, frames-dim-index, frames-index] -> [samples-index, 1, frames-dim-index, frames-index]
@@ -181,7 +187,7 @@ def extract_embedding(self, x):
return: an 1-dimensional vector after processed by decorator
# Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv.
- x = self.auto(self.cmvn_,x)
+ x = self.cmvn_(x)
x = x.unsqueeze(1) if self.convXd == 2 else x
x = self.resnet(x)
x = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) if self.convXd == 2 else x
@@ -204,10 +210,10 @@ def extract_embedding(self, x):
def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor:
x: a 3-dimensional tensor with batch-dim = 1 or normal features matrix
- return: an 1-dimensional vector after processed by decorator
+ return: an 1-dimensional vector after processed by decorator
- x = self.auto(self.cmvn_,x)
+ x = self.cmvn_(x)
# Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv.
x = x.unsqueeze(1) if self.convXd == 2 else x
x = self.resnet(x)
@@ -230,24 +236,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc
return xvector
- def extract_embedding_whole(self,input:torch.Tensor,position:str='near',maxChunk:int=10000,isMatrix:bool=True):
- if isMatrix:
- input=torch.unsqueeze(input,dim=0)
- input=input.transpose(1,2)
- num_frames = input.shape[2]
- num_split = (num_frames + maxChunk - 1) // maxChunk
- split_size = num_frames // num_split
- offset=0
- embedding_stats = torch.zeros(1,self.embd_dim,1)
- for _ in range(0, num_split-1):
- this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position)
- offset += split_size
- embedding_stats += split_size*this_embedding
- last_embedding = self.extract_embedding_jit(input[:, :, offset:],position)
- embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames
- return torch.squeeze(embedding.transpose(1,2)).cpu()
+ def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):  # embed a whole input by averaging per-chunk embeddings
+ with torch.no_grad():  # extraction only: autograd is switched off
+ if isMatrix:  # isMatrix input is given a leading batch dimension
+ input = torch.unsqueeze(input, dim=0)
+ input = input.transpose(1, 2)  # swap axes 1 and 2; axis 2 is used as the frame axis below
+ num_frames = input.shape[2]
+ num_split = (num_frames + maxChunk - 1) // maxChunk  # ceil(num_frames / maxChunk) chunks
+ split_size = num_frames // num_split  # NOTE(review): divides by zero when num_frames is 0
+ offset = 0
+ embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)  # running sum of (chunk length * chunk embedding)
+ for _ in range(0, num_split-1):  # every chunk except the last, split_size frames each
+ this_embedding = self.extract_embedding_jit(
+ input[:, :, offset:offset+split_size], position)
+ offset += split_size
+ embedding_stats += split_size*this_embedding
+ last_embedding = self.extract_embedding_jit(
+ input[:, :, offset:], position)  # the last chunk takes every remaining frame
+ embedding = (embedding_stats + (num_frames-offset)
+ * last_embedding) / num_frames  # frame-count-weighted mean over all chunks
+ return torch.squeeze(embedding.transpose(1, 2)).cpu()  # drop size-1 dims and return the result on CPU
def embedding_dim(self) -> int:
@@ -274,9 +284,10 @@ def step(self, epoch, this_iter, epoch_batchs):
if self.use_step:
if self.step_params["m"]:
current_postion = epoch*epoch_batchs + this_iter
- lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ lambda_factor = max(self.step_params["lambda_0"],
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
@@ -284,10 +295,12 @@ def step(self, epoch, this_iter, epoch_batchs):
T_i = T_i * epoch_batchs
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
@@ -295,20 +308,26 @@ def step(self, epoch, this_iter, epoch_batchs):
def step_iter(self, epoch, cur_step):
# For iterabledataset
if self.use_step:
+ if self.step_params["margin_warm"]:
+ offset_margin, lambda_m = self.margin_warm.step(cur_step)
+ lambda_m = max(1e-3,lambda_m)
+ self.loss.step(lambda_m,offset_margin)
if self.step_params["m"]:
lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
diff --git a/pytorch/model/snowdar_xvector.py b/pytorch/model/snowdar_xvector.py
old mode 100644
new mode 100755
index 2a21f42..1d6b8d2
--- a/pytorch/model/snowdar_xvector.py
+++ b/pytorch/model/snowdar_xvector.py
@@ -47,7 +47,7 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False,
- "context":[0],
+ "context": [0],
@@ -65,6 +65,8 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False,
default_step_params = {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0},
"m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4,
"s":False, "s_tuple":(30, 12), "s_list":None,
@@ -153,6 +155,8 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False,
if training :
if margin_loss:
self.loss = MarginSoftmaxLoss(512, num_targets, **margin_loss_params)
+ if self.use_step and self.step_params["margin_warm"]:
+ self.margin_warm = MarginWarm(**step_params["margin_warm_conf"])
self.loss = SoftmaxLoss(512, num_targets)
@@ -161,14 +165,14 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False,
# An example to using transform-learning without initializing loss.affine parameters
self.transform_keys = ["tdnn1","tdnn2","tdnn3","tdnn4","tdnn5","stats","tdnn6","tdnn7",
- "se1","se2","se3","se4","loss"]
+ "se1","se2","se3","se4", "loss"]
if margin_loss and transfer_from == "softmax_loss":
# For softmax_loss to am_softmax_loss
self.rename_transform_keys = {"loss.affine.weight":"loss.weight"}
+ @torch.jit.unused
- def forward(self, inputs):
+ def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)):
@inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index]
@@ -272,6 +276,67 @@ def extract_embedding(self, inputs):
return xvector
+ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor:  # embedding extraction where `position` selects the output layer
+ """
+ inputs: a 3-dimensional tensor with batch-dim = 1 or normal features matrix
+ return: an 1-dimensional vector after processed by decorator
+ """
+ # Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv.
+ x = x.unsqueeze(1)  # insert a size-1 axis at position 1
+ x = self.repvgg(x)  # NOTE(review): this model's transform_keys name tdnn*/se* modules and no repvgg; confirm the attribute exists here
+ x = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3])  # merge axes 1 and 2 into a single axis
+ x = self.stats(x)
+ if position == "far" and self.fc1 is not None:  # NOTE(review): fc1/fc2 are also absent from this model's transform_keys; verify
+ xvector = self.fc1.affine(x)
+ elif position == "near_affine":
+ if self.fc1 is not None:
+ x=self.fc1(x)
+ xvector = self.fc2.affine(x)  # affine part of fc2 only
+ elif position == "near":
+ if self.fc1 is not None:
+ x=self.fc1(x)
+ xvector = self.fc2(x)  # full fc2 layer
+ # xvector = F.normalize(xvector)
+ else:
+ raise TypeError("Expected far or near position, but got {}".format(position))  # also reached for "far" when fc1 is None
+ return xvector
+ @torch.jit.export
+ def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):  # embed a whole input by averaging per-chunk embeddings
+ with torch.no_grad():  # extraction only: autograd is switched off
+ if isMatrix:  # isMatrix input is given a leading batch dimension
+ input = torch.unsqueeze(input, dim=0)
+ input = input.transpose(1, 2)  # swap axes 1 and 2; axis 2 is used as the frame axis below
+ num_frames = input.shape[2]
+ num_split = (num_frames + maxChunk - 1) // maxChunk  # ceil(num_frames / maxChunk) chunks
+ split_size = num_frames // num_split  # NOTE(review): divides by zero when num_frames is 0
+ offset = 0
+ embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)  # running sum of (chunk length * chunk embedding)
+ for _ in range(0, num_split-1):  # every chunk except the last, split_size frames each
+ this_embedding = self.extract_embedding_jit(
+ input[:, :, offset:offset+split_size], position)
+ offset += split_size
+ embedding_stats += split_size*this_embedding
+ last_embedding = self.extract_embedding_jit(
+ input[:, :, offset:], position)  # the last chunk takes every remaining frame
+ embedding = (embedding_stats + (num_frames-offset)
+ * last_embedding) / num_frames  # frame-count-weighted mean over all chunks
+ return torch.squeeze(embedding.transpose(1, 2)).cpu()  # drop size-1 dims and return the result on CPU
+ @torch.jit.export
+ def embedding_dim(self) -> int:  # simple accessor, no computation
+ """ Export interface for c++ call, return embedding dim of the model
+ """
+ return self.embd_dim  # value stored on the model as self.embd_dim
def get_warmR_T(self,T_0, T_mult, epoch):
n = int(math.log(max(0.05, (epoch / T_0 * (T_mult - 1) + 1)), T_mult))
T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
@@ -288,9 +353,10 @@ def step(self, epoch, this_iter, epoch_batchs):
if self.use_step:
if self.step_params["m"]:
current_postion = epoch*epoch_batchs + this_iter
- lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ lambda_factor = max(self.step_params["lambda_0"],
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
@@ -298,32 +364,39 @@ def step(self, epoch, this_iter, epoch_batchs):
T_i = T_i * epoch_batchs
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
def step_iter(self, epoch, cur_step):
# For iterabledataset
if self.use_step:
+ if self.step_params["margin_warm"]:
+ offset_margin, lambda_m = self.margin_warm.step(cur_step)
+ lambda_m = max(1e-3,lambda_m)
+ self.loss.step(lambda_m,offset_margin)
if self.step_params["m"]:
lambda_factor = max(self.step_params["lambda_0"],
- self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
- self.loss.step(lambda_factor)
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
if self.step_params["t"]:
- self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i)
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
if self.step_params["p"]:
- self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i)
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
if self.step_params["s"]:
self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
diff --git a/pytorch/model/transformer_xvector.py b/pytorch/model/transformer_xvector.py
new file mode 100755
index 0000000..4bf2371
--- /dev/null
+++ b/pytorch/model/transformer_xvector.py
@@ -0,0 +1,485 @@
+# Copyright xmuspeech (Author: Leo 2022-07-18)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import sys
+sys.path.insert(0, 'subtools/pytorch')
+import libs.support.utils as utils
+from libs.nnet import *
+def compute_statistics(x, m, dim: int=2, stddev: bool=True,eps: float=1e-5):  # weighted sum of x along `dim` with weights m, plus an optional std
+ mean = (m * x).sum(dim)  # NOTE(review): this is a mean only if m sums to 1 along `dim`; confirm at each call site
+ if stddev:
+ # std = torch.sqrt(
+ # (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
+ # )
+ std = torch.sqrt(
+ (torch.sum(m * (x ** 2), dim=dim) - mean ** 2).clamp(eps)  # weighted E[x^2] - mean^2, floored at eps before the sqrt
+ )
+ else:
+ std = torch.empty(0)  # empty placeholder so both branches return a tensor
+ return mean, std
+class AttentiveStatsPool(nn.Module):  # attention-weighted mean (and optional std) pooling over axis 2
+ def __init__(self, in_dim, hidden_size=128, time_attention=False, stddev=True):
+ super().__init__()
+ self.stddev = stddev
+ self.output_dim = in_dim*2 if self.stddev else in_dim  # mean and std are concatenated when stddev is on
+ self.time_attention = time_attention
+ accept_dim = in_dim
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
+ if time_attention:
+ accept_dim = in_dim*3 if self.stddev else in_dim*2  # attention input is x concatenated with its global mean (and std)
+ norm = LayerNorm(hidden_size,dim=1,eps=1e-5)
+ self.attention = nn.Sequential(
+ nn.Conv1d(accept_dim, hidden_size, kernel_size=1),
+ nn.ReLU(),
+ norm,
+ nn.Tanh(),
+ nn.Conv1d(hidden_size, in_dim, kernel_size=1)  # one attention score per input channel and frame
+ )
+ # gn_num = 2 if self.stddev else 1
+ # self.norm_stats = nn.GroupNorm(gn_num,self.output_dim )
+ self.norm_stats = LayerNorm(self.output_dim,dim=1)  # normalisation applied to the pooled statistics
+ def forward(self, x, mask: torch.Tensor = torch.ones((0, 0, 0))):  # a mask with size 0 on axis 2 means "no mask supplied"
+ B, C ,T = x.shape
+ if mask.size(2) == 0 :
+ mask = torch.ones((B, 1, T)).to(x.device)  # default: every frame counts
+ if self.time_attention:
+ total = mask.sum(dim=2, keepdim=True).float()  # per-sample sum of the mask along axis 2
+ mean, std = compute_statistics(x, mask / total,stddev = self.stddev)  # mask / total sums to 1 along axis 2
+ mean = mean.unsqueeze(2).repeat(1, 1, T)  # tile the global mean along axis 2
+ if self.stddev:
+ std = std.unsqueeze(2).repeat(1, 1, T)
+ x_in = [x,mean,std]
+ else:
+ x_in = [x,mean]
+ x_in = torch.cat(x_in, dim=1)  # stack along the channel axis
+ else:
+ x_in = x
+ alpha = self.attention(x_in)
+ alpha = alpha.masked_fill(mask == 0, float("-inf"))  # positions where mask == 0 get zero weight after the softmax
+ alpha = F.softmax(alpha, dim=2)  # weights sum to 1 along axis 2
+ mean, std = compute_statistics(x, alpha,stddev = self.stddev)  # attention-weighted statistics
+ if self.stddev:
+ out = torch.cat([mean, std], dim=1).unsqueeze(2)
+ else:
+ out = mean.unsqueeze(2)
+ return self.norm_stats(out)
+ def get_output_dim(self):
+ return self.output_dim
+class TransformerXvector(TopVirtualNnet):
+ def init(self, inputs_dim, num_targets, embd_dim=256,training=True,
+ extracted_embedding="near", mixup=False, mixup_alpha=1.0, pooling="ecpa-attentive", pooling_params={},
+ transformer_type="conformer", transformer_params={},tansformer_out={}, fc1=False, fc1_params={}, fc2_params={},
+ margin_loss=True, margin_loss_params={}, lsm_weight=0.0,use_step=False, step_params={}, transfer_from="softmax_loss",wenet_transfer=False):  # builds encoder, pooling, fc layers and, when training, the loss
+ default_transformer_params = {
+ "attention_dim": 256,
+ "att_type": 'multi', # [multi, gau] gau attention don't suppport rel_pos.
+ "attention_heads": 4,
+ "gau_key": 64, # gau key dim.
+ "gau_units": 512,
+ "num_blocks": 6,
+ "dropout_rate": 0.1,
+ "layer_dropout":0.,
+ "positionwise_layer_type": 'linear', # [linear, conv1d, conv1d-linear, gau, re_conv1d]
+ "positional_dropout_rate": 0.1,
+ "linear_units": 2048,
+ "positionwise_conv_kernel_size": 3,
+ "attention_dropout_rate": 0.0,
+ "attention_norm_args": {
+ "norm_method": "softmax", # [softmax, relu_plus, softmax_plus]
+ "train_len":300., # for softmax_plus.
+ },
+ "input_layer": "conv2d", # [linear, conv2d2, conv2d, re_conv2d, conv2d6, conv2d8]
+ "pos_enc_type": "abs_pos", # [abs_pos, no_pos, rot_pos, rel_pos]
+ "cnn_module_kernel": 15, # for conformer
+ "use_cnn_module": True, # for conformer
+ "cnn_module_norm": 'layer_norm', # for conformer ['batch_norm', 'layer_norm']
+ "static_chunk_size": 0,
+ "left_chunk_size": -1,
+ "use_dynamic_chunk": False,
+ "use_dynamic_left_chunk": False,
+ "combiner_type": "norm", # [norm, mfa, random_frame, random_layer]
+ "convfnn_blocks": 0
+ }
+ default_tansformer_out = {
+ "out_dim": 1536,
+ "nonlinearity": 'swish', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "ln_replace": True, # replace BN with LN
+ "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}
+ }
+ default_pooling_params = {
+ "hidden_size": 128,
+ "time_attention": False,
+ "stddev": True,
+ }
+ default_fc_params = {
+ "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True},
+ "bn-relu": False,
+ "bn": True,
+ "ln_replace": True, # replace BN with LN
+ "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}
+ }
+ default_margin_loss_params = {
+ "method": "am", "m": 0.2,
+ "feature_normalize": True, "s": 30,
+ "double": False,
+ "mhe_loss": False, "mhe_w": 0.01,
+ "inter_loss": 0.,
+ "ring_loss": 0.,
+ "curricular": False
+ }
+ default_step_params = {
+ "margin_warm":False,
+ "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0},
+ "T": None,
+ "m": True, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4,
+ "s": False, "s_tuple": (30, 12), "s_list": None,
+ "t": False, "t_tuple": (0.5, 1.2),
+ "p": False, "p_tuple": (0.5, 0.1)
+ }
+ self.use_step = use_step
+ # self.step_params is set further down, after the defaults have been merged in.
+ self.extracted_embedding = extracted_embedding
+ transformer_params = utils.assign_params_dict(
+ default_transformer_params, transformer_params,support_unknow=True)
+ tansformer_out = utils.assign_params_dict(default_tansformer_out, tansformer_out)
+ pooling_params = utils.assign_params_dict(
+ default_pooling_params, pooling_params)
+ fc1_params = utils.assign_params_dict(default_fc_params, fc1_params)
+ fc2_params = utils.assign_params_dict(default_fc_params, fc2_params)
+ margin_loss_params = utils.assign_params_dict(
+ default_margin_loss_params, margin_loss_params)
+ step_params = utils.assign_params_dict(default_step_params, step_params)
+ self.step_params = step_params  # store the merged dict so lookups such as self.step_params["margin_warm"] always find their key
+ self.embd_dim = embd_dim
+ self.mixup = Mixup(alpha=mixup_alpha) if mixup else None
+ if transformer_type == "transformer":
+ transformer_backbone = TransformerEncoder
+ elif transformer_type == "conformer":
+ transformer_backbone = ConformerEncoder
+ elif transformer_type == "re_conformer":
+ transformer_backbone = ReConformerEncoder
+ else:
+ raise ValueError("unknown transformer_type: " + transformer_type)
+ self.transformer = transformer_backbone(inputs_dim,**transformer_params)
+ self.transform_out = ReluBatchNormTdnnLayer(self.transformer.output_size(),tansformer_out["out_dim"],**tansformer_out)
+ # Pooling
+ stddev = pooling_params.pop("stddev")  # passed to the pooling layer by keyword below, so remove it from the ** kwargs
+ if pooling == "ecpa-attentive":
+ self.stats = AttentiveStatsPool(
+ tansformer_out["out_dim"], stddev=stddev,**pooling_params)
+ self.fc1 = ReluBatchNormTdnnLayer(
+ self.stats.get_output_dim(), embd_dim, **fc1_params) if fc1 else None
+ else:
+ raise ValueError("Only supoort asp for conformer now.")
+ if fc1:
+ fc2_in_dim = embd_dim
+ else:
+ fc2_in_dim = self.stats.get_output_dim()  # without fc1, fc2 consumes the pooled statistics directly
+ self.fc2 = ReluBatchNormTdnnLayer(fc2_in_dim, embd_dim, **fc2_params)
+ # print("num_targets---------------",num_targets)
+ # Loss
+ # Do not need when extracting embedding.
+ if training:
+ if margin_loss:
+ self.loss = MarginSoftmaxLoss(
+ embd_dim, num_targets, label_smoothing=lsm_weight,**margin_loss_params)
+ if self.use_step and self.step_params["margin_warm"]:
+ self.margin_warm = MarginWarm(**step_params["margin_warm_conf"])
+ else:
+ self.loss = SoftmaxLoss(embd_dim, num_targets,label_smoothing=lsm_weight)
+ # self.loss = AngleLoss(embd_dim,num_targets)
+ self.wrapper_loss = MixupLoss(
+ self.loss, self.mixup) if mixup else None
+ # An example to using transform-learning without initializing loss.affine parameters
+ self.transform_keys = ["transformer", "transform_out", "stats", "fc1", "fc2", "loss"]
+ if margin_loss and transfer_from == "softmax_loss":
+ # For softmax_loss to am_softmax_loss
+ self.rename_transform_keys = {
+ "loss.affine.weight": "loss.weight"}
+ self.wenet_transfer = wenet_transfer  # read by load_transform_state_dict to remap "encoder." keys
+ def load_transform_state_dict(self, state_dict):  # load only the entries selected by self.transform_keys from a checkpoint state_dict
+ """It is used in transform-learning.
+ """
+ assert isinstance(self.transform_keys, list)
+ assert isinstance(self.rename_transform_keys, dict)
+ remaining = {}  # filtered (and possibly renamed) entries that will actually be loaded
+ for k,v in state_dict.items():
+ # if "train_len" in k:
+ # print(k,v)
+ if self.wenet_transfer:
+ k = k.replace("encoder.","transformer.")  # map the checkpoint's "encoder." naming onto this model's "transformer." attribute
+ # k = k.replace("embed.","noembed.")
+ if k.split('.')[0] in self.transform_keys or k in self.transform_keys:  # match on the first dotted component or on the full key
+ k = utils.key_to_value(self.rename_transform_keys, k, False)  # NOTE(review): presumably renames k via rename_transform_keys when listed there; confirm helper semantics
+ remaining[k] = v
+ # for k in remaining.keys():
+ # print(k)
+ # assert 1==0
+ self.load_state_dict(remaining, strict=False)  # strict=False: missing or unexpected keys are not an error
+ return self
+ @torch.jit.unused
+ @utils.for_device_free
+ def forward(self, x, x_len,warmup: torch.Tensor=torch.FloatTensor([1.0])):  # training-time forward pass; returns the fc2 output
+ # [samples-index, frames-dim-index, frames-index] -> [samples-index, frames-index, frames-dim-index]
+ x = x.transpose(1,2)
+ x, masks = self.transformer(x,x_len,warmup=float(warmup))  # the encoder also returns the masks used by the pooling layer below
+ x = x.transpose(1,2)  # swap axes 1 and 2 back again
+ x = self.transform_out(x)
+ x = self.stats(x,masks)
+ if len(x.shape) != 3:
+ x = x.unsqueeze(dim=2)  # add a trailing size-1 axis when the pooled output is not 3-dimensional
+ with torch.cuda.amp.autocast(enabled=False):  # autocast is disabled for the statement(s) under this context
+ x = self.auto(self.fc1, x)  # NOTE(review): presumably applies fc1 only when it is not None; confirm self.auto
+ x = self.fc2(x)
+ return x
+ @utils.for_device_free
+ def get_loss(self, inputs, targets):  # dispatch to the mixup wrapper when one was built, else to the plain loss
+ """Should call get_loss() after forward() with using Xvector model function.
+ e.g.:
+ m=Xvector(20,10)
+ loss=m.get_loss(m(inputs),targets)
+ model.get_loss [custom] -> loss.forward [custom]
+ |
+ v
+ model.get_accuracy [custom] -> loss.get_accuracy [custom] -> loss.compute_accuracy [static] -> loss.predict [static]
+ """
+ if self.wrapper_loss is not None:  # wrapper_loss is the MixupLoss created in init when mixup is enabled
+ return self.wrapper_loss(inputs, targets)
+ else:
+ return self.loss(inputs, targets)
+ @utils.for_device_free
+ def get_accuracy(self, targets):
+ """Return the accuracy for `targets`; should be called after get_loss().
+ Delegates to `self.wrapper_loss.get_accuracy` when a wrapper loss is set,
+ otherwise to `self.loss.get_accuracy`.
+ @return: return accuracy
+ """
+ if self.wrapper_loss is not None:
+ return self.wrapper_loss.get_accuracy(targets)
+ else:
+ return self.loss.get_accuracy(targets)
+ @for_extract_embedding(maxChunk=300, isMatrix=True)
+ def extract_embedding(self, x):
+ """Extract an embedding at the position named by `self.extracted_embedding`.
+
+ Positions:
+ "far": output of `self.fc1.affine` (fc1 must not be None).
+ "near_affine": `self.auto(self.fc1, x)` followed by `self.fc2.affine`.
+ "near": `self.auto(self.fc1, x)` followed by the full `self.fc2`.
+
+ NOTE(review): `for_extract_embedding(maxChunk=300, isMatrix=True)` is defined
+ outside this view; presumably it handles chunking and 2-D input -- confirm.
+
+ Raises:
+ TypeError: if `self.extracted_embedding` is none of the positions above.
+ """
+ # A single length, x.shape[2], is passed to the encoder for the input.
+ x_lens = torch.LongTensor([x.shape[2]]).to(x.device)
+ x = x.transpose(1,2)
+ x, _ = self.transformer(x,x_lens)
+ x = x.transpose(1,2)
+ x = self.transform_out(x)
+ # Unlike forward(), no mask is passed to the pooling layer here.
+ x = self.stats(x)
+ if len(x.shape) != 3:
+ x = x.unsqueeze(dim=2)
+ if self.extracted_embedding == "far":
+ assert self.fc1 is not None
+ xvector = self.fc1.affine(x)
+ elif self.extracted_embedding == "near_affine":
+ x = self.auto(self.fc1, x)
+ xvector = self.fc2.affine(x)
+ elif self.extracted_embedding == "near":
+ x = self.auto(self.fc1, x)
+ xvector = self.fc2(x)
+ else:
+ raise TypeError("Expected far or near position, but got {}".format(
+ self.extracted_embedding))
+ return xvector
+ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor:
+ x_lens = torch.tensor([x.shape[2]]).to(x.device)
+ x = x.transpose(1,2)
+ x, _ = self.transformer(x,x_lens)
+ x = x.transpose(1,2)
+ x = self.transform_out(x)
+ x = self.stats(x)
+ if len(x.shape) != 3:
+ x = x.unsqueeze(dim=2)
+ if position == "far" and self.fc1 is not None:
+ xvector = self.fc1.affine(x)
+ elif position == "near_affine":
+ if self.fc1 is not None:
+ x = self.fc1(x)
+ xvector = self.fc2.affine(x)
+ elif position == "near":
+ if self.fc1 is not None:
+ x = self.fc1(x)
+ xvector = self.fc2(x)
+ else:
+ raise TypeError("Expected far or near position, but got {}".format(
+ self.extracted_embedding))
+ return xvector
+ @torch.jit.export
+ def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True):
+ """Embed a whole utterance by averaging chunk embeddings, weighted by chunk length.
+
+ The frame axis (axis 2 at the point where `num_frames` is read) is cut into
+ ceil(num_frames / maxChunk) pieces; all but the last have
+ num_frames // num_split frames and the last takes the remainder. Each piece is
+ embedded with `extract_embedding_jit` and the results are combined as a
+ frame-count-weighted mean. Runs under `torch.no_grad()`.
+
+ Args:
+ input: feature tensor. With `isMatrix` a leading axis is added first.
+ position: forwarded to `extract_embedding_jit`.
+ maxChunk: upper bound used to decide the number of pieces.
+ isMatrix: whether `input` lacks the leading axis.
+
+ Returns:
+ The averaged embedding with axes 1 and 2 swapped, squeezed, on the CPU.
+
+ NOTE(review): an input with zero frames makes `num_split` 0, so
+ `num_frames // num_split` divides by zero.
+ NOTE(review): indentation is lost in this diff; whether the transpose is
+ conditional on `isMatrix` cannot be told from here.
+ """
+ with torch.no_grad():
+ if isMatrix:
+ input = torch.unsqueeze(input, dim=0)
+ input = input.transpose(1, 2)
+ num_frames = input.shape[2]
+ # Ceiling division: number of pieces needed so none exceeds maxChunk.
+ num_split = (num_frames + maxChunk - 1) // maxChunk
+ split_size = num_frames // num_split
+ offset = 0
+ # Running sum of (frames in piece) * (piece embedding).
+ embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device)
+ for _ in range(0, num_split-1):
+ this_embedding = self.extract_embedding_jit(
+ input[:, :, offset:offset+split_size], position)
+ offset += split_size
+ embedding_stats += split_size*this_embedding
+ # The last piece runs to the end and is weighted by its own frame count.
+ last_embedding = self.extract_embedding_jit(
+ input[:, :, offset:], position)
+ embedding = (embedding_stats + (num_frames-offset)
+ * last_embedding) / num_frames
+ return torch.squeeze(embedding.transpose(1, 2)).cpu()
+ @torch.jit.export
+ def embedding_dim(self) -> int:
+ """Export interface for C++ callers: return the model's embedding dimension (`self.embd_dim`).
+ """
+ return self.embd_dim
+ def get_warmR_T(self, T_0, T_mult, epoch):
+ """Locate `epoch` inside a warm-restart schedule with geometrically growing cycles.
+
+ Args:
+ T_0: length of the first cycle.
+ T_mult: factor by which each cycle is longer than the previous one.
+ epoch: position to locate.
+
+ Returns:
+ (T_cur, T_i): the offset of `epoch` within its cycle and that cycle's length.
+
+ NOTE(review): T_mult == 1 divides by zero (both in `math.log(..., 1)` and in
+ the `/ (T_mult - 1)` term) -- confirm callers never pass it.
+ """
+ # n: index of the cycle containing `epoch`; the 0.05 floor keeps the log argument positive.
+ n = int(math.log(max(0.05, (epoch / T_0 * (T_mult - 1) + 1)), T_mult))
+ # Subtract the total length of the n completed cycles (geometric series sum).
+ T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
+ T_i = T_0 * T_mult ** (n)
+ return T_cur, T_i
+ def compute_decay_value(self, start, end, T_cur, T_i):
+ """Linearly interpolate from `start` (at cycle offset 0) to `end` (at offset T_i - 1).
+
+ `T_cur` is wrapped with `% T_i` before it is used.
+ NOTE(review): T_i == 1 makes the `(T_i-1)` divisor zero.
+ """
+ # Linear decay in every cycle time.
+ return start - (start - end)/(T_i-1) * (T_cur % T_i)
+ def step(self, epoch, this_iter, epoch_batchs):
+ """Update scheduled loss/dropout hyper-parameters; does nothing unless `self.use_step`.
+
+ Reads these `self.step_params` entries: "m", "lambda_0", "lambda_b", "gamma",
+ "alpha", "T", "t", "t_tuple", "p", "p_tuple", "s", "s_tuple", "s_list".
+
+ Args:
+ epoch: current epoch index.
+ this_iter: iteration index within the current epoch.
+ epoch_batchs: number of iterations per epoch.
+
+ NOTE(review): nesting is flattened in this diff; the branches below are
+ presumably all inside `if self.use_step` -- confirm against the real file.
+ """
+ # Heated up for t and s.
+ # Decay for margin and dropout p.
+ if self.use_step:
+ if self.step_params["m"]:
+ # Iteration index counted across epochs.
+ current_postion = epoch*epoch_batchs + this_iter
+ # lambda_factor = max(lambda_0, lambda_b * (1 + gamma * position) ** (-alpha))
+ lambda_factor = max(self.step_params["lambda_0"],
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
+ if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
+ T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch)
+ # Convert the epoch-based cycle offset and cycle length into iterations.
+ T_cur = T_cur*epoch_batchs + this_iter
+ T_i = T_i * epoch_batchs
+ if self.step_params["t"]:
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
+ if self.step_params["p"]:
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
+ if self.step_params["s"]:
+ # Per-epoch lookup: s_list[epoch] selects an entry of s_tuple.
+ self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
+ def step_iter(self, epoch, cur_step):
+ """Step-indexed variant of step(); does nothing unless `self.use_step`.
+
+ Here the margin factor and the warm-restart position are computed from
+ `cur_step` directly, so `self.step_params["T"]` is interpreted in steps.
+ `epoch` is only used for the "s" lookup.
+
+ NOTE(review): when both "margin_warm" and "m" are enabled, `self.loss.step`
+ is called twice in one invocation (with and without `offset_margin`).
+ NOTE(review): nesting is flattened in this diff -- confirm against the real file.
+ """
+ # For iterabledataset
+ if self.use_step:
+ if self.step_params["margin_warm"]:
+ offset_margin, lambda_m = self.margin_warm.step(cur_step)
+ # Keep lambda_m at or above 1e-3.
+ lambda_m = max(1e-3,lambda_m)
+ self.loss.step(lambda_m,offset_margin)
+ if self.step_params["m"]:
+ # lambda_factor = max(lambda_0, lambda_b * (1 + gamma * cur_step) ** (-alpha))
+ lambda_factor = max(self.step_params["lambda_0"],
+ self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"]))
+ lambda_m = 1/(1 + lambda_factor)
+ self.loss.step(lambda_m)
+ if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]):
+ T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step)
+ if self.step_params["t"]:
+ self.loss.t = self.compute_decay_value(
+ *self.step_params["t_tuple"], T_cur, T_i)
+ if self.step_params["p"]:
+ self.aug_dropout.p = self.compute_decay_value(
+ *self.step_params["p_tuple"], T_cur, T_i)
+ if self.step_params["s"]:
+ # Per-epoch lookup: s_list[epoch] selects an entry of s_tuple.
+ self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]]
+if __name__ == '__main__':
+ # Smoke test: build the model, extract one embedding, print the model, its
+ # parameter count and the embedding shape.
+ # The input built here is a 2-D (1000, 80) tensor; presumably frames x feat_dim, with the
+ # batch axis added by the extract_embedding decorator (isMatrix=True) -- TODO confirm.
+ # NOTE(review): `timer` is created but not used in this block.
+ timer = utils.Timer()
+ x = torch.zeros(1000, 80)
+ model = TransformerXvector(inputs_dim=80, num_targets=1211,training=False)
+ out = model.extract_embedding(x)
+ total = sum(p.numel() for p in model.parameters())
+ print(model)
+ print(total)
+ print(out.shape)
diff --git a/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh b/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh
old mode 100644
new mode 100755
index 7dc1979..9fe7ac2
--- a/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh
+++ b/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh
@@ -14,6 +14,7 @@ de_silence=false
@@ -87,8 +88,8 @@ if [ $stage -le 1 ]; then
for g in $(seq $nj); do
$cmd --gpu 1 ${dir}/log/extract.$g.log \
- python3 subtools/pytorch/pipeline/onestep/extract_embeddings_new.py --use-gpu=$use_gpu --gpu-id="$gpu_id" \
- --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th \
+ python3 subtools/pytorch/pipeline/onestep/extract_embeddings_online.py --use-gpu=$use_gpu --gpu-id="$gpu_id" \
+ --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th --max-chunk=$max_chunk \
--feat-config=$srcdir/$feat_config --nnet-config=$srcdir/$nnet_config \
"$srcdir/$model" "`echo $wavs | sed s/JOB/$g/g`" "`echo $output | sed s/JOB/$g/g`" || exit 1 &
sleep $sleep_time
@@ -98,8 +99,8 @@ if [ $stage -le 1 ]; then
$cmd JOB=1:$nj ${dir}/log/extract.JOB.log \
- python3 subtools/pytorch/pipeline/onestep/extract_embeddings_new.py --use-gpu="false" \
- --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th \
+ python3 subtools/pytorch/pipeline/onestep/extract_embeddings_online.py --use-gpu="false" \
+ --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th --max-chunk=$max_chunk \
--feat-config=$srcdir/$feat_config --nnet-config=$srcdir/$nnet_config \
"$srcdir/$model" "$wavs" "$output" || exit 1;
diff --git a/pytorch/pipeline/onestep/extract_embeddings_new.py b/pytorch/pipeline/onestep/extract_embeddings_online.py
old mode 100644
new mode 100755
similarity index 88%
rename from pytorch/pipeline/onestep/extract_embeddings_new.py
rename to pytorch/pipeline/onestep/extract_embeddings_online.py
index 4284d5b..17894b3
--- a/pytorch/pipeline/onestep/extract_embeddings_new.py
+++ b/pytorch/pipeline/onestep/extract_embeddings_online.py
@@ -43,7 +43,11 @@
parser.add_argument("--de-silence", type=str, action=kaldi_common.StrToBoolAction, default=False, choices=["true", "false"],
help="Vad or not")
parser.add_argument("--amp-th", type=int, default=50,
- help="De_silence threshold (16bit)")
+ help="De_silence threshold (16bit)")
+parser.add_argument("--max-chunk", type=int, default=10000,
+ help="Select chun_size of features when extracting xvector")
parser.add_argument("--feat-config",type=str,default="",help="The config yaml of feat extraction")
parser.add_argument("--use-gpu", type=str, default='true',
@@ -80,12 +84,14 @@
raise ValueError("Expected nnet_config or (model_blueprint, model_creation) to exist.")
model = utils.create_model_from_py(model_blueprint, model_creation)
+ position = model.extracted_embedding
model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=False)
# Select device
model = utils.select_model_device(model, args.use_gpu, gpu_id=args.gpu_id)
+ devc = utils.get_device(model)
+ model.eval()
feature_extraction_conf = None
if args.data_type != "kaldi":
@@ -102,7 +108,7 @@
- data_loader = DataLoader(dataset, batch_size=None,num_workers=2, prefetch_factor=500)
+ data_loader = DataLoader(dataset, batch_size=None,num_workers=2, prefetch_factor=100)
timer = utils.Timer()
with kaldi_io.open_or_fd(args.vectors_wspecifier, 'wb') as w:
@@ -111,11 +117,16 @@
pbar=tqdm(total=tot_len, position=0,ascii=True,miniters=tot_len/100,dynamic_ncols=True)
for idx,sample in enumerate(data_loader):
key = sample['keys'][0]
- feats = sample['feats'][0]
+ feats = sample['feats'][0].to(devc)
total_dur += feats.size(0)*0.01
- embedding = model.extract_embedding(feats)
+ embedding = model.extract_embedding_whole(feats,position=position,maxChunk=args.max_chunk)
+ # embedding1 = model.forward_1(feats)
+ # print(embedding)
+ # print(embedding1)
+ # assert 1==0
if cnt%500==0:
diff --git a/pytorch/pipeline/onestep/prepare_speechaug_csv.py b/pytorch/pipeline/onestep/prepare_speechaug_csv.py
old mode 100644
new mode 100755
index 898d131..470847e
--- a/pytorch/pipeline/onestep/prepare_speechaug_csv.py
+++ b/pytorch/pipeline/onestep/prepare_speechaug_csv.py
@@ -225,9 +225,14 @@ def prepare_aug_csv(items, csv_file, max_length=None):
for i in range(int(duration / max_length)):
start = int(max_length * i * rate)
- stop = int(
- min(max_length * (i + 1), duration) * rate
- )
+ if i == int(duration / max_length) -1:
+ stop = int(
+ duration * rate
+ )
+ else:
+ stop = int(
+ max_length * (i + 1) * rate
+ )
new_filename = (
filename[: -len(f".{ext}")] + f"_{i}.{ext}"
@@ -264,17 +269,17 @@ def concat_csv(out_file,*csv_files):
# Options
- parser.add_argument("--openrir-folder", type=str, default='/tsdata/ASR',
+ parser.add_argument("--openrir-folder", type=str, default='/data',
help="where has openslr rir.")
- parser.add_argument("--musan-folder", type=str, default='/tsdata/ASR',
+ parser.add_argument("--musan-folder", type=str, default='/data',
help="where has openslr musan.")
- parser.add_argument("--savewav-folder", type=str, default='/work1/ldx/speech_aug_2_new',
+ parser.add_argument("--savewav-folder", type=str, default='/export/yourpath/speech_aug_6',
help="noise clips for online speechaug, set it in SSD.")
parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction,
default=True, choices=["true", "false"],
help="force clear")
- parser.add_argument("--max-noise-len", type=float, default=2.015,
+ parser.add_argument("--max-noise-len", type=float, default=6.015,
help="the maximum noise length in seconds. Noises longer than this will be cut into pieces")
parser.add_argument("csv_aug_folder", type=str, help="csv file folder.")
diff --git a/recipe/voxcelebSRC/README.md b/recipe/voxcelebSRC/README.md
new file mode 100755
index 0000000..e9db8d4
--- /dev/null
+++ b/recipe/voxcelebSRC/README.md
@@ -0,0 +1,57 @@
+## Reports
+### Results of ResNet34
+* Egs = Voxceleb2_dev(online random aug) + sequential sampling(2s)
+* Optimization = [SGD (lr = 0.04) + ReduceLROnPlateau] x 4 GPUs (total batch-size=512)
+* ResNet34 (channels = 32, 64, 128, 256) + Stats-Pooling + FC-BN + AM-Softmax (margin = 0.2) + AMP training
+* Back-end = near + Cosine
+| EER% | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean |
+|:-------- |:------:|:------------:|:------:|:------------:|:------:|:------------:|
+| Submean | 1.071 | 0.920 | 1.257 | 1.135 | 2.205 | 2.072 |
+| AS-Norm | 0.970 | 0.819 | - | - | - | - |
+### Results of ECAPA-TDNN
+* Egs = Voxceleb2_dev(online random aug) + random chunk(2s)
+* Optimization = [adamW (lr = 1e-8 - 1e-3) + cyclic for 3 cycles with triangular2 strategy] x 4 GPUs (total batch-size=512)
+* ECAPA-TDNN (channels = 1024) + FC-BN + AAM-Softmax (margin = 0.2)
+* Back-end = near + Cosine
+| EER% | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean |
+|:-------- |:------:|:------------:|:------:|:------------:|:------:|:------------:|
+| Submean | 1.045 | 0.904 | 1.330 | 1.211 | 2.430 | 2.303 |
+| AS-Norm | 0.991 | 0.856 | - | - | - | - |
+### Results of Conformer
+* Egs = Voxceleb2_dev(online random aug) + random chunk(3s)
+* Optimization = [adamW (lr = 1e-6 - 1e-3) + 1cycle] x 4 GPUs (total batch-size=512)
+* Conformer + FC-Swish-LN + ASP + FC-LN + AAM-Softmax (margin = 0.2)
+* Back-end = near + Cosine
+* LM: Large-Margin Fine-tune (margin: 0.2 --> 0.5, chunk: 6s)
+| Config | | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean |
+|:---------------------------- |:------:|:------:|:------------:|:------:|:------------:|:------:|:------------:|
+| 6L-256D-4H-4Sub (50 epochs) | Cosine | 1.204 | 1.074 | 1.386 | 1.267 | 2.416 | 2.294 |
+| | AS-Norm | 1.092 | 0.952 | - | - | - | - |
+| $\quad+$ SAM training | Cosine | 1.103 | 0.984 | 1.350 | 1.234 | 2.380 | 2.257 |
+| | LM | 1.034 | 0.899 | 1.181 | 1.060 | 2.079 | 1.953 |
+| | AS-Norm | 0.943 | 0.792 | - | - | - | - |
+| 6L-256D-4H-2Sub (30 epochs) | Cosine | 1.066 | 0.915 | 1.298 | 1.177 | 2.167 | 2.034 |
+| | LM | 1.029 | 0.888 | 1.160 | 1.043 | 1.923 | 1.792 |
+| | AS-Norm | 0.949 | 0.792 | - | - | - | - |
+### Results of RTF
+* RTF is evaluated on LibTorch-based runtime, see `subtools/runtime`
+* One thread is used for CPU threading and TorchScript inference.
+* CPU: Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz.
+| Model | Config | Params | RTF |
+|:-----|:------ |:------:|:---:|
+| ResNet34 | base32 | 6.80M | 0.090 |
+| ECAPA | C1024 | 16.0M | 0.071 |
+| | C512 | 6.53M | 0.030 |
+| Conformer | 6L-256D-4H-4Sub | 18.8M | 0.025 |
+| | 6L-256D-4H-2Sub | 22.5M | 0.070 |
\ No newline at end of file
diff --git a/recipe/voxcelebSRC/runVoxcelebSRC_online.sh b/recipe/voxcelebSRC/runVoxcelebSRC_online.sh
old mode 100644
new mode 100755
index a2b8c6d..14b0675
--- a/recipe/voxcelebSRC/runVoxcelebSRC_online.sh
+++ b/recipe/voxcelebSRC/runVoxcelebSRC_online.sh
@@ -68,50 +68,50 @@ subtools/newCopyData.sh $prefix "voxceleb2_dev voxceleb1"
# [5] Sample egs. It will do cmn and vad firstly and then remove invalid utts. Finally,
# it samples egs to fixed chunk-size with instance sampling.
-subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=0 --endstage=2
+# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=0 --endstage=2
# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=0 --endstage=2
# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=0 --endstage=2
+subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=0 --endstage=2
# [6] Train a thin Resnet34 model with AM-Softmax loss and 8 GPUs will be used to accelerate training
-subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4
-# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4
-# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4
+# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3
+# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3
+# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3
+subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3
# [7] Extract near xvectors for voxceleb1 and voxceleb2_dev
-subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=4
+# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=4
# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=4
# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=4
+subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=4
+# [8] Large-Margin Fine-tune
+subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector_LM.py --stage=3 --endstage=3 --gpu-id=0,1,2,3
+# [9] Extract xvectors
+subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=4
### Back-end scoring
-# [14] Score with submean + Cosine + AS-Norm processes
+# [14] Score with submean + Cosine + AS-Norm processes.
+# tasks="vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean"
+# for task in $tasks;do
+# score_norm=false
+# [ "$task" == "vox1-O" ] && score_norm=true
+# [ "$task" == "vox1-O-clean" ] && score_norm=true
+# subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean true \
+# --vectordir "exp/resnet34_fbank80_online" --task $task --epochs "40" --positions "near" --trainset voxceleb2_dev_vad \
+# --score-norm $score_norm --score-norm-method "asnorm" --top-n 100 --cohort-set voxceleb2_dev_vad
+# done
tasks="vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean"
for task in $tasks;do
[ "$task" == "vox1-O" ] && score_norm=true
[ "$task" == "vox1-O-clean" ] && score_norm=true
- subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean true \
- --vectordir "exp/resnet34_fbank80_online" --task $task --epochs "40" --positions "near" --trainset voxceleb2_dev_vad \
- --score-norm $score_norm --score-norm-method "asnorm" --top-n 100 --cohort-set voxceleb2_dev_vad
+ subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean false \
+ --vectordir "exp/conformer_6L256D4H_4sub_lm" --task $task --epochs "4" --positions "near" --trainset voxceleb2_dev \
+ --score-norm $score_norm --score-norm-method "asnorm" --top-n 300 --cohort-set voxceleb2_dev
-#### Report ####
-# Egs = Voxceleb2_dev(online random aug) + sequential sampling
-# Optimization = [SGD (lr = 0.01) + ReduceLROnPlateau] x 4 GPUs (total batch-size=512)
-# Resnet34 (channels = 32, 64, 128, 256) + Stats-Pooling + FC-ReLU-BN-FC-BN + AM-Softmax (margin = 0.2) + AMP training
-# Back-end = near + Cosine
-# EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
-# Submean 1.071 0.920 1.257 1.135 2.205 2.072
-# AS-Norm 0.970 0.819 - - - -
-# Egs = Voxceleb2_dev(online random aug) + random chunk
-# ECAPA-TDNN (channels = 1024) + FC-ReLU-BN-FC-BN + AAM-Softmax (margin = 0.2)
-# Optimization = [adamW (lr = 1e-8 - 1e-3) + cyclic for 3 cycle with triangular2 strategy] x 4 GPUs (total batch-size=512)
-# Back-end = near + Cosine
-# EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
-# Submean 1.045 0.904 1.330 1.211 2.430 2.303
-# AS-Norm 0,991 0.856 - - - -