diff --git a/conf/egs_pre_sre_lm.yaml b/conf/egs_pre_sre_lm.yaml new file mode 100755 index 0000000..5f5725d --- /dev/null +++ b/conf/egs_pre_sre_lm.yaml @@ -0,0 +1,85 @@ +# Copyright xmuspeech (Author: Leo 2022-01-23) +feat_dim: 80 # the num_mel_bins of fbank and the num_ceps of mfcc +data_type: 'shard' # shard or raw +# feature extraction +dataset_conf: + # asv_target: true + filter: false + filter_conf: + max_length: 15.0 + min_length: 0.2 + max_cut: true + # resample + resample: false + resample_conf: + resample_rate: 16000 + + # pre speed_perturb + pre_speed_perturb: false + perturb_conf: + speeds: [90, 100, 110] # larger->slower + sample_rate: 16000 + # random_chunk + random_chunk: true + random_chunk_size: 6.015 + + # waveform true config + speech_aug: true + speech_aug_conf: subtools/conf/speech_aug_lm.yaml + csv_aug_folder: '' + + # It seems exit some bug, DO NOT set dither and use_energy together. + feature_extraction_conf: + # feature_type: 'mfcc' + # kaldi_featset: + # num_ceps: 23 + # num_mel_bins: 23 + # frame_shift: 10 + # frame_length: 25 + # low_freq: 40.0 + # high_freq: -200 + # energy_floor: 0.0 + # dither: 0.0 # conflicted with use_energy=true. + # use_energy: true # if you want use energy-based vad, set it true. + + feature_type: 'fbank' + kaldi_featset: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + low_freq: 40 + high_freq: -200 + energy_floor: 0.0 + use_energy: false + + mean_var_conf: + mean_norm: true + std_norm: false + + + + # spec level config + spec_aug: false + spec_aug_conf: + aug: specaugment # None or specaugment + aug_params: + frequency: 0.2 + frame: 0.2 + rows: 4 + cols: 4 + random_rows: true + random_cols: true + + + shuffle: true + shuffle_conf: + shuffle_size: 3000 + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 64 + +# attention: Do not specify batch size in dataloader. +data_loader_conf: + num_workers: 8 + pin_memory: false + prefetch_factor: 20 # pf(400) * bs(16) is about 2 shards which has 3000 samples each. diff --git a/conf/egs_pre_sre_transformer.yaml b/conf/egs_pre_sre_transformer.yaml new file mode 100755 index 0000000..5271893 --- /dev/null +++ b/conf/egs_pre_sre_transformer.yaml @@ -0,0 +1,85 @@ +# Copyright xmuspeech (Author: Leo 2022-01-23) +feat_dim: 80 # the num_mel_bins of fbank and the num_ceps of mfcc +data_type: 'shard' # shard or raw +# feature extraction +dataset_conf: + # asv_target: true + filter: false + filter_conf: + max_length: 15.0 + min_length: 0.2 + max_cut: true + # resample + resample: false + resample_conf: + resample_rate: 16000 + + # pre speed_perturb + pre_speed_perturb: false + perturb_conf: + speeds: [90, 100, 110] # larger->slower + sample_rate: 16000 + # random_chunk + random_chunk: true + random_chunk_size: 3.015 + + # waveform true config + speech_aug: true + speech_aug_conf: subtools/conf/speech_aug_transformer.yaml + csv_aug_folder: '' + + # It seems exit some bug, DO NOT set dither and use_energy together. + feature_extraction_conf: + # feature_type: 'mfcc' + # kaldi_featset: + # num_ceps: 23 + # num_mel_bins: 23 + # frame_shift: 10 + # frame_length: 25 + # low_freq: 40.0 + # high_freq: -200 + # energy_floor: 0.0 + # dither: 0.0 # conflicted with use_energy=true. + # use_energy: true # if you want use energy-based vad, set it true. 
+ + feature_type: 'fbank' + kaldi_featset: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + low_freq: 40 + high_freq: -200 + energy_floor: 0.0 + use_energy: false + + mean_var_conf: + mean_norm: true + std_norm: false + + + + # spec level config + spec_aug: false + spec_aug_conf: + aug: specaugment # None or specaugment + aug_params: + frequency: 0.2 + frame: 0.2 + rows: 4 + cols: 4 + random_rows: true + random_cols: true + + + shuffle: true + shuffle_conf: + shuffle_size: 3000 + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 128 + +# attention: Do not specify batch size in dataloader. +data_loader_conf: + num_workers: 8 + pin_memory: false + prefetch_factor: 50 # pf(400) * bs(16) is about 2 shards which has 3000 samples each. diff --git a/conf/speech_aug_lm.yaml b/conf/speech_aug_lm.yaml new file mode 100755 index 0000000..556dc53 --- /dev/null +++ b/conf/speech_aug_lm.yaml @@ -0,0 +1,73 @@ +speechaug: + mod: random # chain,concat,random + aug_classes: + - + aug_name: add_noise # Define the speeech augment name + aug_type: Env # Env or Time + random_mod_weight: 0.5 + reverb_prob: 0.0 + noise_prob: 1.0 + noise_snr_low: 5 + noise_snr_high: 15 + noise_csv: exp/aug_csv/combine_music_noise.csv + add_filt_min: 0.5 + pad_noise: true + noise_num_workers: 0 + + - + aug_name: add_babble_noise + aug_type: Env + random_mod_weight: 0.2 + reverb_prob: 0.0 + noise_prob: 0.0 + babble_prob: 1.0 + babble_speaker_count: 4 + babble_snr_low: 13 + babble_snr_high: 20 + babble_csv: exp/aug_csv/musan_speech.csv + babble_noise_max_len: 8.0 + add_filt_min: 0.5 + pad_noise: true + noise_num_workers: 0 + + - + aug_name: add_rev + aug_type: Env + random_mod_weight: 0.3 + reverb_prob: 1.0 + noise_prob: 0.0 + babble_prob: 0.0 + reverb_csv: exp/aug_csv/combine_sim_small_medium_rev.csv + rir_scale_factor: 1.0 + + + - + aug_name: add_rev_noise + aug_type: Env + random_mod_weight: 0.2 + reverb_prob: 1.0 + noise_prob: 0.5 + noise_snr_low: 0 + noise_snr_high: 15 + noise_csv: exp/aug_csv/pointsrc_noise.csv + reverb_csv: exp/aug_csv/real_reverb.csv + add_filt_min: 0.5 + pad_noise: true + noise_num_workers: 0 + rir_scale_factor: 1.0 + +# You can define here for more augment strategy. 
+# tail_speechaug: +# mod: chain +# aug_classes: +# - +# aug_name: augment_speed +# aug_type: Time +# perturb_type: resample # ['resample','sox_speed','sox_tempo'] +# perturb_prob: 0.0 +# drop_freq_prob: 0.0 +# drop_chunk_prob: 0.0 +# sample_rate: 16000 +# speeds: [95, 100, 105] +# keep_shape: true +# change_spk: false \ No newline at end of file diff --git a/conf/speech_aug_transformer.yaml b/conf/speech_aug_transformer.yaml new file mode 100755 index 0000000..b50fa63 --- /dev/null +++ b/conf/speech_aug_transformer.yaml @@ -0,0 +1,126 @@ +speechaug: + mod: random # chain,concat,random + aug_classes: + - + aug_name: add_noise # Define the speeech augment name + aug_type: Env # Env or Time + random_mod_weight: 0.5 + reverb_prob: 0.0 + noise_prob: 1.0 + noise_snr_low: 5 + noise_snr_high: 15 + noise_csv: exp/aug_csv/combine_music_noise.csv + noise_num_workers: 0 + + + + # - + # aug_name: add_white_noise + # aug_type: Env + # random_mod_weight: 1 + # reverb_prob: 0.0 + # noise_prob: 1.0 + # noise_snr_low: 0 + # noise_snr_high: 15 + # noise_csv: ~ + # noise_num_workers: 0 + + - + aug_name: add_babble_noise + aug_type: Env + random_mod_weight: 0.2 + reverb_prob: 0.0 + noise_prob: 0.0 + babble_prob: 1.0 + babble_speaker_count: 4 + babble_snr_low: 13 + babble_snr_high: 20 + babble_csv: exp/aug_csv/musan_speech.csv + babble_noise_max_len: 3.015 + noise_num_workers: 0 + + - + aug_name: add_rev + aug_type: Env + random_mod_weight: 0.3 + reverb_prob: 1.0 + noise_prob: 0.0 + babble_prob: 0.0 + reverb_csv: exp/aug_csv/combine_sim_small_medium_rev.csv + rir_scale_factor: 1.0 + + + - + aug_name: add_rev_noise + aug_type: Env + random_mod_weight: 0.2 + reverb_prob: 1.0 + noise_prob: 0.5 + noise_snr_low: 0 + noise_snr_high: 15 + noise_csv: exp/aug_csv/pointsrc_noise.csv + reverb_csv: exp/aug_csv/real_reverb.csv + noise_num_workers: 0 + rir_scale_factor: 1.0 + + # - + # aug_name: augment_wavedrop + # aug_type: Time + # random_mod_weight: 1 + # perturb_prob: 0.0 + # drop_freq_prob: 1.0 + # drop_chunk_prob: 1.0 + # drop_freq_count_low: 0 + # drop_freq_count_high: 3 + # drop_chunk_count_low: 0 + # drop_chunk_count_high: 4 + # drop_chunk_length_low: 1000 + # drop_chunk_length_high: 2000 + # sample_rate: 16000 + # speeds: [100] + + # - + # aug_name: augment_speed + # aug_type: Time + # random_mod_weight: 1 + # perturb_prob: 1.0 + # drop_freq_prob: 1.0 + # drop_chunk_prob: 1.0 + # drop_freq_count_low: 0 + # drop_freq_count_high: 3 + # drop_chunk_count_low: 0 + # drop_chunk_count_high: 4 + # drop_chunk_length_low: 1000 + # drop_chunk_length_high: 2000 + # sample_rate: 16000 + # speeds: [95, 100, 105] + # keep_shape: true + +# You can define here for more augment strategy. 
+tail_speechaug: + mod: chain + aug_classes: + # - + # aug_name: augment_wavedrop + # aug_type: Time + # random_mod_weight: 1 + # perturb_prob: 0.0 + # drop_freq_prob: 1.0 + # drop_chunk_prob: 1.0 + # drop_freq_count_low: 0 + # drop_freq_count_high: 3 + # drop_chunk_count_low: 0 + # drop_chunk_count_high: 4 + # drop_chunk_length_low: 1000 + # drop_chunk_length_high: 2000 + # sample_rate: 16000 + # speeds: [100] + - + aug_name: augment_speed + aug_type: Time + perturb_prob: 1.0 + drop_freq_prob: 0.0 + drop_chunk_prob: 0.0 + sample_rate: 16000 + speeds: [95, 100, 105] + keep_shape: true \ No newline at end of file diff --git a/doc/papers/cntrans.jpg b/doc/papers/cntrans.jpg new file mode 100755 index 0000000..8c3f4d9 Binary files /dev/null and b/doc/papers/cntrans.jpg differ diff --git a/doc/papers/conformer.md b/doc/papers/conformer.md new file mode 100755 index 0000000..047d1b2 --- /dev/null +++ b/doc/papers/conformer.md @@ -0,0 +1,26 @@ +# ASR transferring for ASV conformer + +#### Baseline ASV conformers are conducted on VoxCeleb and CNCeleb. +* VoxCeleb: `subtools/recipe/voxcelebSRC/runVoxcelebSRC_online.sh` +* CNCeleb: To be released. + +#### ASR transferring is conducted on CNCeleb +
+
+* The pretrained ASR encoder can either be an open-source pretrained model or be trained from scratch.
+* Transferring part of the ASR encoder (rather than all of it) tends to achieve better performance.
+* Sharpness-Aware Minimization (SAM) training seems to alleviate overfitting efficiently (a minimal sketch of the SAM update is given after the results below).
+
+#### Results:
+
+<!-- results figures (see the .jpg files added under doc/papers/ in this PR) -->
+
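+A note on the SAM bullet above: the launchers in this PR import `libs.training.trainer_online_sam` and expose `sam`, `sam.rho` and `sam.adaptive` switches in `optimizer_params`; when `sam` is true, `optim.get_optimizer` presumably returns an `optim.SAM`-wrapped optimizer and the SAM trainer is selected. The snippet below is only a rough, self-contained sketch of one (adaptive) SAM step — it is not the repository implementation, and `sam_step`/`loss_fn` are made-up names for illustration.
+
+```python
+import torch
+
+
+def sam_step(model, loss_fn, batch, base_optimizer, rho=2.0, adaptive=True):
+    """One illustrative (A)SAM update: perturb, re-compute gradient, restore, step."""
+    inputs, targets = batch
+
+    # 1) First forward/backward: gradients g at the current weights w.
+    loss_fn(model(inputs), targets).backward()
+    params = [p for p in model.parameters() if p.grad is not None]
+
+    # 2) Climb to w + e(w). Plain SAM: e = rho * g / ||g||.
+    #    Adaptive SAM scales element-wise by |w|: e = rho * |w|^2 * g / || |w| * g ||.
+    with torch.no_grad():
+        scaled = [(p.abs() if adaptive else 1.0) * p.grad for p in params]
+        grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in scaled]), p=2)
+        eps = []
+        for p in params:
+            e_w = rho * ((p * p) if adaptive else 1.0) * p.grad / (grad_norm + 1e-12)
+            p.add_(e_w)
+            eps.append(e_w)
+
+    # 3) Second forward/backward at the perturbed point gives the sharpness-aware gradient.
+    model.zero_grad()
+    loss_fn(model(inputs), targets).backward()
+
+    # 4) Move back to w and let the base optimizer apply that gradient.
+    with torch.no_grad():
+        for p, e_w in zip(params, eps):
+            p.sub_(e_w)
+    base_optimizer.step()
+    base_optimizer.zero_grad()
+```
+
+Because the adaptive variant rescales the perturbation element-wise by `|w|`, its radius is set much larger than the values typically used for plain SAM, which is presumably why these configs pair `sam.rho: 2.0` with `sam.adaptive: true`. Note that every SAM step costs roughly two forward/backward passes.
+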
+ +#### Runtime is based on libtorch, go to `subtools/runtime` and evaluate your models' RTF. \ No newline at end of file diff --git a/doc/papers/giatrans.jpg b/doc/papers/giatrans.jpg new file mode 100755 index 0000000..6eaa342 Binary files /dev/null and b/doc/papers/giatrans.jpg differ diff --git a/doc/papers/trans.jpg b/doc/papers/trans.jpg new file mode 100755 index 0000000..69fa627 Binary files /dev/null and b/doc/papers/trans.jpg differ diff --git a/doc/runtime_deploy.png b/doc/runtime_deploy.png new file mode 100755 index 0000000..f66d9e6 Binary files /dev/null and b/doc/runtime_deploy.png differ diff --git a/pytorch/launcher/runEcapaXvector_online.py b/pytorch/launcher/runEcapaXvector_online.py old mode 100644 new mode 100755 index 4277ccf..c50acc3 --- a/pytorch/launcher/runEcapaXvector_online.py +++ b/pytorch/launcher/runEcapaXvector_online.py @@ -20,6 +20,7 @@ import libs.support.utils as utils import libs.support.kaldi_common as kaldi_common import libs.training.trainer_online as trainer +import libs.training.trainer_online_sam as trainer_sam import libs.training.lr_scheduler_online as learn_rate_scheduler import libs.training.optim as optim import libs.egs.egs_online as egs @@ -29,10 +30,10 @@ Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering them to python from shell. -Note, this launcher does not contain dataset preparation, augmentation, extracting acoustic features and back-end scoring etc. - 1.See subtools/recipe/voxceleb/runVoxceleb.sh to get complete stages. - 2.See subtools/newCopyData.sh, subtools/makeFeatures.sh.sh, subtools/computeVad.sh, subtools/augmentDataByNoise.sh and - subtools/scoreSets.sh and run these script separately before or after running this launcher. +Note, this launcher does not contain dataset preparation, augmentation, and back-end scoring etc. + 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages. + 2.An on-the-fly feature extraction mod. + How to modify this launcher: 1.Prepare your kaldi format dataset and model.py (model blueprint); @@ -95,8 +96,8 @@ formatter_class=argparse.RawTextHelpFormatter, conflict_handler='resolve') -parser.add_argument("--stage", type=int, default=0, - help="The stage to control the start of training epoch (default 4).\n" +parser.add_argument("--stage", type=int, default=3, + help="The stage to control the start of training epoch (default 3).\n" " stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n" " stage 1: remove utts (preprocess_raw_wav_egs.sh).\n" " stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n" @@ -105,7 +106,7 @@ " stage 4: extract xvector.") parser.add_argument("--endstage", type=int, default=4, - help="The endstage to control the endstart of training epoch (default 5).") + help="The endstage to control the endstart of training epoch (default 4).") parser.add_argument("--train-stage", type=int, default=-1, help="The stage to control the start of training epoch (default -1).\n" @@ -252,6 +253,8 @@ "use_step": False, "step_params": { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0}, "T": None, "m": True, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4, "s": False, "s_tuple": (30, 12), "s_list": None, @@ -269,9 +272,16 @@ "weight_decay": 1e-1, "lookahead.k": 5, # 0 means not using lookahead and if used, suggest to set it as 0.5. 
- "lookahead.alpha": 0, + "lookahead.alpha": 0., "gc": False, # If true, use gradient centralization. - "nesterov": False # for sgd + "nesterov": False, # for sgd + "sam": False, + "sam.rho": 2.0, # 2.0 for adaptive + "sam.adaptive": True, + # "custwd_dict":{ + # "train_len":0, + # "bias":0 + # } } lr_scheduler_params = { @@ -356,7 +366,6 @@ if utils.is_main_training(): logger.info("Get model_blueprint from model directory.") # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path. - print(model_blueprint) model_blueprint = utils.create_model_dir( model_dir, model_blueprint, stage=train_stage) @@ -384,13 +393,28 @@ model = model_py.ECAPA_TDNN( info["feat_dim"], info["num_targets"], **model_params) + epoch_iters = (info['epoch_iters']//accum_grad) + if hasattr(model,'margin_warm'): + model.margin_warm.update_step_range(epoch_iters) + # If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important # to make peformance stable. # It will change nothing for single-GPU training. model = utils.convert_synchronized_batchnorm(model) if utils.is_main_training(): + print(model) + p1=sum(p.numel() for p in model.parameters()) + script_model = copy.deepcopy(model) + script_model.loss=None + p2 = sum(p.numel() for p in script_model.parameters()) + logger.info("model params w/o proj layer: {} / {} .".format(p1,p2)) + script_model = torch.jit.script(script_model) + script_model.save(os.path.join(model_dir, 'init.zip')) + logger.info("The number of steps per epoch is about {}.".format(epoch_iters)) logger.info("Define optimizer and lr_scheduler.") + del script_model + optimizer = optim.get_optimizer(model, optimizer_params) lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper( optimizer, lr_scheduler_params) @@ -399,7 +423,7 @@ if utils.is_main_training(): utils.write_list_to_file([egs_params, model_params, optimizer_params, - lr_scheduler_params], model_dir+'/config/params.dict') + lr_scheduler_params], model_dir+'/config/params.dict',yml=True) if utils.is_main_training(): logger.info("Init a simple trainer.") @@ -409,14 +433,16 @@ "start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp, "skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid, "report_interval_iters": report_interval_iters, "record_file": "train.csv"}) - trainer = trainer.SimpleTrainer(package) + train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer + + execuer = train_exec.SimpleTrainer(package) if run_lr_finder: - trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, + execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98) endstage = 3 # Do not start extractor. else: - trainer.run() + execuer.run() # Extract xvector @@ -440,7 +466,7 @@ gpu_id = "" sleep_time = 10 feat_config = "feat_conf.yaml" - + max_chunk = 10000 # Run a batch extracting process. try: for position in to_extracted_positions: @@ -480,12 +506,12 @@ # with python's threads to extract xvectors directly, but the shell script is more convenient. 
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh " " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' " - " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}" + " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} " " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} " "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj, use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config, feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th, - model_dir=model_dir, datadir=datadir, outdir=outdir)) + max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir)) except BaseException as e: if not isinstance(e, KeyboardInterrupt): traceback.print_exc() diff --git a/pytorch/launcher/runRepvggXvector.py b/pytorch/launcher/runRepvggXvector.py old mode 100644 new mode 100755 index 1b62f71..3125b58 --- a/pytorch/launcher/runRepvggXvector.py +++ b/pytorch/launcher/runRepvggXvector.py @@ -20,6 +20,7 @@ import libs.training.optim as optim import libs.training.lr_scheduler_online as learn_rate_scheduler import libs.training.trainer_online as trainer +import libs.training.trainer_online_sam as trainer_sam import libs.support.kaldi_common as kaldi_common import libs.support.utils as utils from libs.support.logging_stdout import patch_logging_stream @@ -32,6 +33,9 @@ Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering them to python from shell. +Note, this launcher does not contain dataset preparation, augmentation, and back-end scoring etc. + 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages. + 2.An on-the-fly feature extraction mod. How to modify this launcher: 1.Prepare your kaldi format dataset and model.py (model blueprint); @@ -93,8 +97,8 @@ formatter_class=argparse.RawTextHelpFormatter, conflict_handler='resolve') -parser.add_argument("--stage", type=int, default=0, - help="The stage to control the start of training epoch (default 4).\n" +parser.add_argument("--stage", type=int, default=3, + help="The stage to control the start of training epoch (default 3).\n" " stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n" " stage 1: remove utts (preprocess_raw_wav_egs.sh).\n" " stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n" @@ -103,7 +107,7 @@ " stage 4: extract xvector.") parser.add_argument("--endstage", type=int, default=4, - help="The endstage to control the endstart of training epoch (default 5).") + help="The endstage to control the endstart of training epoch (default 4).") parser.add_argument("--train-stage", type=int, default=-1, help="The stage to control the start of training epoch (default -1).\n" @@ -264,6 +268,8 @@ "use_step":True, "step_params":{ + "margin_warm":False, + "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0}, "T":None, "m":True, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4, "s":False, "s_tuple":(30, 12), "s_list":None, @@ -371,7 +377,6 @@ if stage <= 3 <= endstage: if utils.is_main_training():logger.info("Get model_blueprint from model directory.") # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path. 
- print(model_blueprint) model_blueprint = utils.create_model_dir(model_dir, model_blueprint, stage=train_stage) if utils.is_main_training():logger.info("Load egs to bunch.") @@ -387,7 +392,8 @@ with open(feat_config_path,'w') as fou: yaml.dump(feat_extraction_config,fou) - if utils.is_main_training():logger.info("Create model from model blueprint.") + if utils.is_main_training(): + logger.info("Create model from model blueprint.") # Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and # I don't want to change anything about extracting script when the model.py is changed. model_py = utils.create_model_from_py(model_blueprint) @@ -399,8 +405,22 @@ # It will change nothing for single-GPU training. model = utils.convert_synchronized_batchnorm(model) + epoch_iters = (info['epoch_iters']//accum_grad) + if hasattr(model,'margin_warm'): + model.margin_warm.update_step_range(epoch_iters) - if utils.is_main_training():logger.info("Define optimizer and lr_scheduler.") + if utils.is_main_training(): + print(model) + p1=sum(p.numel() for p in model.parameters()) + script_model = copy.deepcopy(model) + script_model.loss=None + p2 = sum(p.numel() for p in script_model.parameters()) + logger.info("model params w/o proj layer: {} / {} .".format(p1,p2)) + script_model = torch.jit.script(script_model) + script_model.save(os.path.join(model_dir, 'init.zip')) + logger.info("The number of steps per epoch is about {}.".format(epoch_iters)) + logger.info("Define optimizer and lr_scheduler.") + del script_model optimizer = optim.get_optimizer(model, optimizer_params) lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params) @@ -408,8 +428,9 @@ # Record params to model_dir - if utils.is_main_training():utils.write_list_to_file([egs_params, model_params, optimizer_params, - lr_scheduler_params], model_dir+'/config/params.dict') + if utils.is_main_training(): + utils.write_list_to_file([egs_params, model_params, optimizer_params, + lr_scheduler_params], model_dir+'/config/params.dict',yml=True) if utils.is_main_training():logger.info("Init a simple trainer.") @@ -420,13 +441,15 @@ "skip_nan_batch":skip_nan_batch, "benchmark":benchmark, "suffix":suffix, "compute_batch_num_valid":compute_batch_num_valid, "report_interval_iters":report_interval_iters, "record_file":"train.csv"}) - trainer = trainer.SimpleTrainer(package) + train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer + + execuer = train_exec.SimpleTrainer(package) if run_lr_finder and utils.is_main_training(): - trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98) + execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98) endstage = 3 # Do not start extractor. else: - trainer.run() + execuer.run() # Plan to use del to avoid memeory account after training done and continue to execute stage 4. # But it dose not work and is still a problem. @@ -452,7 +475,7 @@ gpu_id = "" sleep_time = 10 feat_config="feat_conf.yaml" - + max_chunk = 10000 # Run a batch extracting process. try: for position in to_extracted_positions: @@ -500,12 +523,12 @@ # with python's threads to extract xvectors directly, but the shell script is more convenient. 
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh " " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' " - " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}" + " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} " " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} " "{model_dir} {datadir} {outdir}".format(model_file=depoly_model_file, nj=nj, use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config, feat_config=feat_config,data_type=data_type_emb,de_silence=str(de_silence).lower(),amp_th=amp_th, - model_dir=model_dir, datadir=datadir, outdir=outdir)) + max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir)) except BaseException as e: if not isinstance(e, KeyboardInterrupt): traceback.print_exc() diff --git a/pytorch/launcher/runResnetXvector_online.py b/pytorch/launcher/runResnetXvector_online.py old mode 100644 new mode 100755 index e7ca4ba..af74739 --- a/pytorch/launcher/runResnetXvector_online.py +++ b/pytorch/launcher/runResnetXvector_online.py @@ -20,6 +20,7 @@ import libs.training.optim as optim import libs.training.lr_scheduler_online as learn_rate_scheduler import libs.training.trainer_online as trainer +import libs.training.trainer_online_sam as trainer_sam import libs.support.kaldi_common as kaldi_common import libs.support.utils as utils from libs.support.logging_stdout import patch_logging_stream @@ -33,6 +34,10 @@ them to python from shell. +Note, this launcher does not contain dataset preparation, augmentation, and back-end scoring etc. + 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages. + 2.An on-the-fly feature extraction mod. + How to modify this launcher: 1.Prepare your kaldi format dataset and model.py (model blueprint); 2.Give the path of dataset, model blueprint, etc. in main parameters field; @@ -93,8 +98,9 @@ formatter_class=argparse.RawTextHelpFormatter, conflict_handler='resolve') -parser.add_argument("--stage", type=int, default=0, - help="The stage to control the start of training epoch (default 4).\n" + +parser.add_argument("--stage", type=int, default=3, + help="The stage to control the start of training epoch (default 3).\n" " stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n" " stage 1: remove utts (preprocess_raw_wav_egs.sh).\n" " stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n" @@ -103,7 +109,7 @@ " stage 4: extract xvector.") parser.add_argument("--endstage", type=int, default=4, - help="The endstage to control the endstart of training epoch (default 5).") + help="The endstage to control the endstart of training epoch (default 4).") parser.add_argument("--train-stage", type=int, default=-1, help="The stage to control the start of training epoch (default -1).\n" @@ -260,6 +266,8 @@ "use_step":True, "step_params":{ + "margin_warm":False, + "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0}, "T":None, "m":True, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4, "s":False, "s_tuple":(30, 12), "s_list":None, @@ -278,7 +286,14 @@ "lookahead.k":5, "lookahead.alpha":0., # 0 means not using lookahead and if used, suggest to set it as 0.5. "gc":False, # If true, use gradient centralization. 
- "nesterov": False # for sgd + "nesterov": False, # for sgd + "sam": False, # suggest true when apply ASR transferring. + "sam.rho": 2.0, # 2.0 for adaptive + "sam.adaptive": True, + # "custwd_dict":{ + # "train_len":0, + # "bias":0 + # } } lr_scheduler_params = { @@ -301,7 +316,7 @@ epochs = 50 # Total epochs to train. It is important. Here 18 = 6 -> 12 -> 18 with warmR.T_mult=1 and warmR.T_max=6. compute_batch_num_valid = 10 -report_interval_iters = 100 # About validation computation and loss reporting. If report_times_every_epoch is not None, +report_interval_iters = 500 # About validation computation and loss reporting. If report_times_every_epoch is not None, # then compute report_interval_iters by report_times_every_epoch. stop_early = False suffix = "params" # Used in saved model file. @@ -395,8 +410,22 @@ # It will change nothing for single-GPU training. model = utils.convert_synchronized_batchnorm(model) - - if utils.is_main_training():logger.info("Define optimizer and lr_scheduler.") + epoch_iters = (info['epoch_iters']//accum_grad) + if hasattr(model,'margin_warm'): + model.margin_warm.update_step_range(epoch_iters) + + if utils.is_main_training(): + print(model) + p1=sum(p.numel() for p in model.parameters()) + script_model = copy.deepcopy(model) + script_model.loss=None + p2 = sum(p.numel() for p in script_model.parameters()) + logger.info("model params w/o proj layer: {} / {} .".format(p1,p2)) + script_model = torch.jit.script(script_model) + script_model.save(os.path.join(model_dir, 'init.zip')) + logger.info("The number of steps per epoch is about {}.".format(epoch_iters)) + logger.info("Define optimizer and lr_scheduler.") + del script_model optimizer = optim.get_optimizer(model, optimizer_params) lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params) @@ -404,9 +433,10 @@ # Record params to model_dir - if utils.is_main_training():utils.write_list_to_file([egs_params, model_params, optimizer_params, - lr_scheduler_params], model_dir+'/config/params.dict') + if utils.is_main_training(): + utils.write_list_to_file([egs_params, model_params, optimizer_params, + lr_scheduler_params], model_dir+'/config/params.dict',yml=True) if utils.is_main_training():logger.info("Init a simple trainer.") # Package(Elements:dict, Params:dict}. It is a key parameter's package to trainer and model_dir/config/. @@ -416,13 +446,16 @@ "skip_nan_batch":skip_nan_batch, "benchmark":benchmark, "suffix":suffix, "compute_batch_num_valid":compute_batch_num_valid, "report_interval_iters":report_interval_iters, "record_file":"train.csv"}) - trainer = trainer.SimpleTrainer(package) + train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer + + execuer = train_exec.SimpleTrainer(package) if run_lr_finder: - trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98) + execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98) endstage = 3 # Do not start extractor. else: - trainer.run() + execuer.run() + # Plan to use del to avoid memeory account after training done and continue to execute stage 4. # But it dose not work and is still a problem. @@ -449,7 +482,7 @@ gpu_id = "" sleep_time = 10 feat_config="feat_conf.yaml" - + max_chunk = 10000 # Run a batch extracting process. try: @@ -485,12 +518,12 @@ # with python's threads to extract xvectors directly, but the shell script is more convenient. 
kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh " " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' " - " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th}" + " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} " " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} " "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj, use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config, feat_config=feat_config,data_type=data_type_emb,de_silence=str(de_silence).lower(),amp_th=amp_th, - model_dir=model_dir, datadir=datadir, outdir=outdir)) + max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir)) except BaseException as e: if not isinstance(e, KeyboardInterrupt): traceback.print_exc() diff --git a/pytorch/launcher/runTransformerXvector.py b/pytorch/launcher/runTransformerXvector.py new file mode 100755 index 0000000..decb599 --- /dev/null +++ b/pytorch/launcher/runTransformerXvector.py @@ -0,0 +1,558 @@ +# -*- coding:utf-8 -*- + +# Copyright xmuspeech (Author: Leo 2021-10-05) +# Apache 2.0 +# Only support "nccl" backend multi-gpu training. +import sys +import os +import logging +import argparse +import traceback +import time +import yaml +import copy +import math +import numpy as np + +import torch +sys.path.insert(0, 'subtools/pytorch') +from libs.support.logging_stdout import patch_logging_stream +import libs.support.utils as utils +import libs.support.kaldi_common as kaldi_common +import libs.training.trainer_online as trainer +import libs.training.trainer_online_sam as trainer_sam +import libs.training.lr_scheduler_online as learn_rate_scheduler +import libs.training.optim as optim +import libs.egs.egs_online as egs +torch.multiprocessing.set_sharing_strategy('file_system') +"""A launcher script with python version (Snowdar's launcher to do experiments w.r.t snowdar-xvector.py). + +Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering +them to python from shell. + +Note, this launcher does not contain dataset preparation, augmentation, and back-end scoring etc. + 1.See subtools/recipe/voxcelebSRC/runVoxceleb_online.sh to get complete stages. + 2.An on-the-fly feature extraction mod. + +How to modify this launcher: + 1.Prepare your kaldi format dataset and model.py (model blueprint); + 2.Give the path of dataset, model blueprint, etc. in main parameters field; + 3.Change the imported name of model in 'model = model_py.model_name(...)' w.r.t model.py by yourself; + 4.Modify any training parameters what you want to change (epochs, optimizer and lr_scheduler etc.); + 5.Modify parameters of extracting in stage 4 w.r.t your own training config; + 6.Run this launcher. + +Conclusion: preprare -> config -> run. + +How to run this launcher to train a model: + 1.For CPU-based training case. The key option is --use-gpu. + python3 launcher.py --use-gpu=false + 2.For single-GPU training case (Default). + python3 launcher.py + 3.For DDP-based multi-GPU training case. Note --nproc_per_node is equal to number of gpu id in --gpu-id. + python3 -m torch.distributed.launch --nproc_per_node=2 launcher.py --gpu-id=0,1 + 4.For Horovod-based multi-GPU training case. Note --np is equal to number of gpu id in --gpu-id. 
+ horovodrun -np 2 launcher.py --gpu-id=0,1 + 5.For all of above, a runLauncher.sh script has been created to launch launcher.py conveniently. + The key option to use single or multiple GPU is --gpu-id. + The subtools/runPytorchLauncher.sh is a soft symbolic which is linked to subtools/pytorch/launcher/runLauncher.sh, + so just use it. + + [ CPU ] + subtools/runPytorchLauncher.sh launcher.py --use-gpu=false + + [ Single-GPU ] + (1) Auto-select GPU device + subtools/runPytorchLauncher.sh launcher.py + (2) Specify GPU device + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2 + + [ Multi-GPU ] + (1) Use DDP solution (Default). + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="ddp" + (2) Use Horovod solution. + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="horovod" + +If you have any other requirements, you could modify the codes in anywhere. +For more details of multi-GPU devolopment, see subtools/README.md. +""" + +# Logger +# Logger +patch_logging_stream(logging.INFO) +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [ %(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ]\n#### %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + +# Parser: add this parser to run launcher with some frequent options (really for conveninece). +parser = argparse.ArgumentParser( + description="""Train xvector framework with pytorch.""", + formatter_class=argparse.RawTextHelpFormatter, + conflict_handler='resolve') + +parser.add_argument("--stage", type=int, default=3, + help="The stage to control the start of training epoch (default 3).\n" + " stage 0: Generate raw wav kaldidir which contains utt2chunk and utt2dur. (preprocess_raw_wav_egs.sh).\n" + " stage 1: remove utts (preprocess_raw_wav_egs.sh).\n" + " stage 2.1: get chunk egs (preprocess_raw_wav_egs.sh).\n" + " stage 2.2: Prepare speech augment csv files.\n" + " stage 3: Training.\n" + " stage 4: extract xvector.") + +parser.add_argument("--endstage", type=int, default=4, + help="The endstage to control the endstart of training epoch (default 4).") + +parser.add_argument("--train-stage", type=int, default=-1, + help="The stage to control the start of training epoch (default -1).\n" + " -1 -> creating model_dir.\n" + " 0 -> model initialization (e.g. 
transfer learning).\n" + " >0 -> recovering training.") + +parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Clear the dir generated by preprocess.") + +parser.add_argument("--pre-rirmusan", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Prepare the openrir and musan dataset for adding reverb and noises.") + +parser.add_argument('--use-amp', type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help='Use automatic mixed precision training') + +parser.add_argument("--skip-nan-batch", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Whether skip optimizer stepping when the gradient has nan/inf values.") + +parser.add_argument("--accum-grad", type=int, default=1, + help="Using accumulate grad.") + +parser.add_argument("--multi-gpu-solution", type=str, default="ddp", + choices=["ddp"], + help="if number of gpu_id > 1, this option will be valid to init a multi-gpu solution.") + +parser.add_argument("--use-gpu", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Use GPU or not.") + +parser.add_argument("--gpu-id", type=str, default="", + help="If NULL, then it will be auto-specified.") + +parser.add_argument("--benchmark", type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help="If true, save training time but require a little more gpu-memory.") + +parser.add_argument("--run-lr-finder", type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help="If true, run lr finder rather than training.") + +parser.add_argument("--sleep", type=int, default=0, + help="The waiting time to launch a launcher.") + +parser.add_argument("--local_rank", type=int, default=0, + help="Do not delete it when using DDP-based multi-GPU training.\n" + "It is important for torch.distributed.launch.") + +args = parser.parse_args() +## +######################################################### PARAMS ######################################################## +## +##--------------------------------------------------## +# Control options +stage = max(0, args.stage) +endstage = min(4, args.endstage) +train_stage = max(-1, args.train_stage) +##--------------------------------------------------## +# Preprocess options +force_clear = args.force_clear +preprocess_nj = 20 + + +whole_utt = True +random_segment = False +seg_dur = 3.015 +amp_th = 50 +de_silence = False +vad_wav_savdir = "export/yourpath" + + + +min_len = 2.0 +max_len = 1000.0 +limit_utts = 8 + +valid_split_type = "--total-spk" # --total-spk or --default +valid_utts = 2048 +valid_chunk_num = 2 +valid_fix_chunk_num = False + +data_type = 'shard' +num_utts_per_shard = 2000 +shard_dir = '/export/yourpath/voxceleb_dev_whole' +##--------------------------------------------------## +# Prepare speech augmention csv files. +pre_rirmusan = args.pre_rirmusan # whether skip this stage. +openrir_folder = "export/path" # where contains RIRS_NOISES folder. +musan_folder = "export/path" # where contains musan folder. +csv_aug_folder = "exp/aug_csv3" # csv file location. +savewav_folder = "/export/yourpath/speech_aug_3015" # save the noise seg into SSD. +max_noise_len = seg_dur # The max dur of noise. 
+##--------------------------------------------------## +# Training options +use_amp = args.use_amp +skip_nan_batch = args.skip_nan_batch +accum_grad = args.accum_grad + +use_gpu = args.use_gpu # Default true. +# If true, save much training time but require a little more gpu-memory. +benchmark = args.benchmark +gpu_id = args.gpu_id # If NULL, then it will be auto-specified. +run_lr_finder = args.run_lr_finder + +##--------------------------------------------------## +# Define model_params by model_blueprint w.r.t your model's __init__(model_params). + +model_params = { + "wenet_transfer":True, # change wenet conformer keys when transferring + "training": True, "extracted_embedding": "near", + "embd_dim": 256, # xvector dim + "transformer_type": "conformer", # [conformer, transformer, re_conformer] + "transformer_params": { + "attention_dim": 256, + "attention_heads": 4, + "num_blocks": 6, + "combiner_type": "norm", # [norm, mfa, random_frame, random_layer] + "aux_layer_period": 2, # aux: select inner layer for combiner. + "aux_layer_start": 3, + "dropout_rate": 0.1, + "layer_dropout":0., + "linear_units": 2048, + "positional_dropout_rate": 0.1, + "attention_dropout_rate": 0.1, + "attention_norm_args": { + "norm_method": "softmax_plus", # [softmax, relu_plus, softmax_plus] + "train_len": 300, + }, + "input_layer": "conv2d", # [linear, conv2d2, conv2d, conv2d6, conv2d8] + "cnn_module_kernel": 15, # for conformer + "pos_enc_type": "rot_pos", # [abs_pos, no_pos, rot_pos, rel_pos] + "convfnn_blocks": 0 # use conv type ffn in head blocks. + }, + + + "pooling": "ecpa-attentive", + "pooling_params": { + "hidden_size": 128, + "time_attention": False, + "stddev": True, + + }, + "fc1": False, + "fc1_params": { + "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "bn_params": {"momentum": 0.5, "affine": False, "track_running_stats": True}}, + + "fc2_params": { + "nonlinearity": '', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "ln_replace": True, # replace BN with LN + "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}}, + + + "margin_loss": True, + "margin_loss_params": { + "method": "aam", "m": 0.2, "feature_normalize": True, + "s": 30, "mhe_loss": False, "mhe_w": 0.01}, + + + # margin_warm is supported to smoothly increase margin of softmax. + "use_step": False, + "step_params": { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":1,"end_epoch":1,"offset_margin":-0.0,"init_lambda":1.0}, + "T": None, + "m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4, + "s": False, "s_tuple": (30, 12), "s_list": None, + "t": False, "t_tuple": (0.5, 1.2), + "p": False, "p_tuple": (0.5, 0.1)} +} + +optimizer_params = { + "name": "adamW", + "learn_rate": 0.00025, + "beta1": 0.9, + "beta2": 0.999, + "beta3": 0.999, + # Should be large for decouped weight decay (adamW) and small for L2 regularization (sgd, adam). + "weight_decay": 5e-2, + "lookahead.k": 5, + # 0 means not using lookahead and if used, suggest to set it as 0.5. + "lookahead.alpha": 0, + "gc": False, # If true, use gradient centralization. + "nesterov": False, # for sgd + "sam": False, # suggest true when apply ASR transferring. 
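+    # Illustrative note (not part of the original config comments): when "sam" is
+    # true, optim.get_optimizer presumably wraps the base optimizer in optim.SAM and
+    # the launcher later picks trainer_online_sam, so every step does two
+    # forward/backward passes: perturb the weights by roughly sam.rho * g/||g||
+    # (scaled element-wise by |w| when "sam.adaptive" is true), re-compute the
+    # gradient there, restore the weights, and then apply the base update.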
+ "sam.rho": 2.0, # 2.0 for adaptive + "sam.adaptive": True, + # "custwd_dict":{ + # "train_len":0, + # "bias":0 + # } +} + +lr_scheduler_params = { + "name": "1cycle", + "1cycle.learn_rate":0.001, + "1cycle.warmup_steps":5000, + "1cycle.epochs": 50, + "1cycle.steps_per_epoch": 2200, + "1cycle.div_factor":1000.0, # initial_lr = max_lr/div_factor + "1cycle.final_div_factor":1000.0, # min_lr = initial_lr/final_div_factor + "1cycle.anneal_strategy":'cos', # ["cos", "linear"] + "1cycle.cycle_momentum":False, + + + "noam.warmup_steps": 5000, + "noam.step_decay": True, + "noam.step_size": 8800, # suggest 4 epochs + "noam.step_rate": 0.5, + + "cyclic.max_lr": 1e-3, + "cyclic.base_lr": 1e-8, + "cyclic.step_size_up": 22000, + "cyclic.mode": 'triangular2', + +} + +epochs = 51 # Total epochs to train. It is important. + +compute_batch_num_valid = 10 +# About validation computation and loss reporting. If report_times_every_epoch is not None, +report_interval_iters = 500 +# then compute report_interval_iters by report_times_every_epoch. + +suffix = "params" # Used in saved model file. +# Other options +exist_model = "" # Use it in transfer learning. +##--------------------------------------------------## +# Main params +traindata = "data/raw/voxceleb2_dev" +traindata_for_egs = "data/raw/voxceleb2_dev" +egs_dir = "exp/egs/voxceleb2_dev_whole" +egs_conf = "subtools/conf/egs_pre_sre_transformer.yaml" +model_blueprint = "subtools/pytorch/model/transformer_xvector.py" +model_dir = "exp/conformer_6L256D4H_4sub" +##--------------------------------------------------## +## +######################################################### START ######################################################### +## +# Set seed +utils.set_all_seed(1024) +## +# Init environment +# It is used for multi-gpu training if used (number of gpu-id > 1). +# And it will do nothing for single-GPU training. +utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution) +## +# Set sleep time for a rest +# Use it to run a launcher with a countdown function when there are no extra GPU memory +# but you really want to go to bed and know when the GPU memory will be free. +if args.sleep > 0: + time.sleep(args.sleep) + +## +# Auto-config params +# If multi-GPU used, it will auto-scale learning rate by multiplying number of processes. +optimizer_params["learn_rate"] = utils.auto_scale_lr( + optimizer_params["learn_rate"]) +# It is used for model.step() defined in model blueprint. +if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]: + model_params["step_params"]["T"]=(lr_scheduler_params["warmR.T_max"], lr_scheduler_params["warmR.T_mult"]) + +# Preprocess +if stage <= 2 and endstage >= 0 and utils.is_main_training(): + # Here only give limited options because it is not convenient. + # Suggest to pre-execute this shell script to make it freedom and then continue to run this launcher. 
+ kaldi_common.execute_command("bash subtools/pytorch/pipeline/preprocess_wav_egs.sh " + "--stage {stage} --endstage {endstage} --nj {nj} --whole-utt {whole_utt} --random-segment {random_segment} " + "--seg-dur {seg_dur} --amp-th {amp_th} --de-silence {de_silence} --vad-wav-savdir {vad_wav_savdir} " + "--min-len {min_len} --max-len {max_len} --limit-utts {limit_utts} " + "--valid-split-type {valid_split_type} --valid-num-utts {valid_utts} --valid-chunk-num {valid_chunk_num} " + "--valid-fix-chunk-num {valid_fix_chunk_num} --force-clear {force_clear} " + "--pre-rirmusan {pre_rirmusan} --openrir-folder {openrir_folder} --musan-folder {musan_folder} " + "--csv-aug-folder {csv_aug_folder} --savewav-folder {savewav_folder} --max-noise-len {max_noise_len} " + "--data-type {data_type} --shard-dir {shard_dir} --num-utts-per-shard {num_utts_per_shard} " + "{traindata} {traindata_for_egs} {egs_dir}".format(stage=stage, endstage=endstage, nj=preprocess_nj, + whole_utt=str(whole_utt).lower(), random_segment=str(random_segment).lower(), seg_dur=seg_dur, + amp_th=amp_th, de_silence=str(de_silence).lower(), vad_wav_savdir=vad_wav_savdir, + min_len=min_len, max_len=max_len, limit_utts=limit_utts, valid_split_type=valid_split_type, + valid_utts=valid_utts, valid_chunk_num=valid_chunk_num, valid_fix_chunk_num=str(valid_fix_chunk_num).lower(), + force_clear=str(force_clear).lower(), pre_rirmusan=str(pre_rirmusan).lower(), openrir_folder=openrir_folder, + musan_folder=musan_folder, csv_aug_folder=csv_aug_folder, savewav_folder=savewav_folder, max_noise_len=max_noise_len, + data_type=data_type, num_utts_per_shard=num_utts_per_shard, shard_dir=shard_dir, + traindata=traindata, traindata_for_egs=traindata_for_egs, egs_dir=egs_dir)) + +# Train model +if stage <= 3 <= endstage: + if utils.is_main_training(): + logger.info("Get model_blueprint from model directory.") + # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path. + model_blueprint = utils.create_model_dir( + model_dir, model_blueprint, stage=train_stage) + + if utils.is_main_training(): + logger.info("Load egs to bunch.") + # The dict [info] contains feat_dim and num_targets + with open(egs_conf, 'r') as fin: + egs_params = yaml.load(fin, Loader=yaml.FullLoader) + egs_params['dataset_conf']['csv_aug_folder'] = csv_aug_folder + bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params) + feat_extraction_config = copy.deepcopy( + egs_params['dataset_conf']['feature_extraction_conf']) + feat_extraction_config['kaldi_featset']['dither'] = 0.0 + feat_config_path = os.path.join(model_dir, 'config', 'feat_conf.yaml') + if utils.is_main_training(): + with open(feat_config_path, 'w') as fou: + yaml.dump(feat_extraction_config, fou) + + if utils.is_main_training(): + logger.info("Create model from model blueprint.") + # Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and + # I don't want to change anything about extracting script when the model.py is changed. + model_py = utils.create_model_from_py(model_blueprint) + + model = model_py.TransformerXvector( + info["feat_dim"], info["num_targets"], **model_params) + + + # If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important + # to make peformance stable. + # It will change nothing for single-GPU training. 
+ model = utils.convert_synchronized_batchnorm(model) + # print(model) + epoch_iters = (info['epoch_iters']//accum_grad) + if hasattr(model,'margin_warm'): + model.margin_warm.update_step_range(epoch_iters) + # print(sum(p.numel() for p in model.parameters())) + # sys.exit() + if utils.is_main_training(): + print(model) + p1=sum(p.numel() for p in model.parameters()) + script_model = copy.deepcopy(model) + script_model.loss=None + p2 = sum(p.numel() for p in script_model.parameters()) + logger.info("model params w/o proj layer: {} / {} .".format(p1,p2)) + script_model = torch.jit.script(script_model) + script_model.save(os.path.join(model_dir, 'init.zip')) + logger.info("The number of steps per epoch is about {}.".format(epoch_iters)) + logger.info("Define optimizer and lr_scheduler.") + del script_model + + optimizer = optim.get_optimizer(model, optimizer_params) + lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper( + optimizer, lr_scheduler_params) + + # Record params to model_dir + + if utils.is_main_training(): + utils.write_list_to_file([egs_params, model_params, optimizer_params, + lr_scheduler_params], model_dir+'/config/params.dict',yml=True) + + if utils.is_main_training(): + logger.info("Init a simple trainer.") + # Package(Elements:dict, Params:dict}. It is a key parameter's package to trainer and model_dir/config/. + package = ({"data": bunch, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}, + {"model_dir": model_dir, "model_blueprint": model_blueprint, "exist_model": exist_model, "accum_grad": accum_grad, + "start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp, + "skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid, + "report_interval_iters": report_interval_iters, "record_file": "train.csv"}) + train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer + + execuer = train_exec.SimpleTrainer(package) + + if run_lr_finder: + execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, + final_lr=10., num_iters=2000, beta=0.98) + endstage = 3 # Do not start extractor. + else: + execuer.run() + + +# Extract xvector +if stage <= 4 <= endstage and utils.is_main_training(): + # There are some params for xvector extracting. + data_root = "data" # It contains all dataset just like Kaldi recipe. + prefix = "raw" # For to_extracted_data. + data_type_emb = "raw" # shard or raw or kaldi. + de_silence = False + amp_th = 50 + + to_extracted_positions = ["near"] # Define this w.r.t model_blueprint. + # All dataset should be in dataroot/prefix. + to_extracted_data = ["voxceleb1"] + # It is model's name, such as 10.params or final.params (suffix is w.r.t package). + to_extracted_epochs = ["50"] + + nj = 8 + force = True + use_gpu = True + gpu_id = "" + sleep_time = 10 + feat_config = "feat_conf.yaml" + max_chunk = 300 + # Run a batch extracting process. + try: + for position in to_extracted_positions: + # Generate the extracting config from nnet config where + # which position to extract depends on the 'extracted_embedding' parameter of model_creation (by my design). + model_blueprint, model_creation = utils.read_nnet_config( + "{0}/config/nnet.config".format(model_dir)) + # To save memory without loading some independent components. 
+ model_creation = model_creation.replace( + "training=True", "training=False") + model_creation = model_creation.replace( + model_params["extracted_embedding"], position) + extract_config = "{0}.extract.config".format(position) + + utils.write_nnet_config( + model_blueprint, model_creation, "{0}/config/{1}".format(model_dir, extract_config)) + + for epoch in to_extracted_epochs: + model_file = "{0}.{1}".format(epoch, suffix) + point_name = "{0}_epoch_{1}".format(position, epoch) + + # If run a trainer with background thread (do not be supported now) or run this launcher extrally with stage=4 + # (it means another process), then this while-listen is useful to start extracting immediately (but require more gpu-memory). + model_path = "{0}/{1}".format(model_dir, model_file) + + while True: + if os.path.exists(model_path): + break + else: + time.sleep(sleep_time) + + for data in to_extracted_data: + datadir = "{0}/{1}/{2}".format(data_root, prefix, data) + outdir = "{0}/{1}/{2}".format(model_dir, point_name, data) + # Use a well-optimized shell script (with multi-processes) to extract xvectors. + # Another way: use subtools/splitDataByLength.sh and subtools/pytorch/pipeline/onestep/extract_embeddings.py + # with python's threads to extract xvectors directly, but the shell script is more convenient. + kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh " + " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' " + " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} " + " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} " + "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj, + use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config, + feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th, + max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir)) + except BaseException as e: + if not isinstance(e, KeyboardInterrupt): + traceback.print_exc() + sys.exit(1) diff --git a/pytorch/launcher/runTransformerXvector_LM.py b/pytorch/launcher/runTransformerXvector_LM.py new file mode 100755 index 0000000..e767339 --- /dev/null +++ b/pytorch/launcher/runTransformerXvector_LM.py @@ -0,0 +1,508 @@ +# -*- coding:utf-8 -*- + +# Copyright xmuspeech (Author: Leo 2021-10-05) +# Apache 2.0 +# Only support "nccl" backend multi-gpu training. +import sys +import os +import logging +import argparse +import traceback +import time +import yaml +import copy +import math +import numpy as np + +import torch +sys.path.insert(0, 'subtools/pytorch') +from pipeline.onestep.prepare_speechaug_csv import prepare_speech_aug +from libs.support.logging_stdout import patch_logging_stream +import libs.support.utils as utils +import libs.support.kaldi_common as kaldi_common +import libs.training.trainer_online as trainer +import libs.training.trainer_online_sam as trainer_sam +import libs.training.lr_scheduler_online as learn_rate_scheduler +import libs.training.optim as optim +import libs.egs.egs_online as egs + +torch.multiprocessing.set_sharing_strategy('file_system') +"""A launcher script with python version (Snowdar's launcher to do experiments w.r.t snowdar-xvector.py). + +Python version is gived (rather than Shell) to have more freedom, such as decreasing limitation of parameters that transfering +them to python from shell. 
+ +Note, this script is for Large-Margin Finetuning, we assume that you have preprocessed the dataset and obtained a trained model. + `The IDLAB VoxCeleb Speaker Recognition Challenge 2020 System Description.` + https://arxiv.org/pdf/2010.12468.pdf + +How to modify this launcher: + 1.Prepare your kaldi format dataset and model.py (model blueprint); + 2.Give the path of dataset, model blueprint, etc. in main parameters field; + 3.Change the imported name of model in 'model = model_py.model_name(...)' w.r.t model.py by yourself; + 4.Modify any training parameters what you want to change (epochs, optimizer and lr_scheduler etc.); + 5.Modify parameters of extracting in stage 4 w.r.t your own training config; + 6.Run this launcher. + +Conclusion: preprare -> config -> run. + +How to run this launcher to train a model: + 1.For CPU-based training case. The key option is --use-gpu. + python3 launcher.py --use-gpu=false + 2.For single-GPU training case (Default). + python3 launcher.py + 3.For DDP-based multi-GPU training case. Note --nproc_per_node is equal to number of gpu id in --gpu-id. + python3 -m torch.distributed.launch --nproc_per_node=2 launcher.py --gpu-id=0,1 + 4.For Horovod-based multi-GPU training case. Note --np is equal to number of gpu id in --gpu-id. + horovodrun -np 2 launcher.py --gpu-id=0,1 + 5.For all of above, a runLauncher.sh script has been created to launch launcher.py conveniently. + The key option to use single or multiple GPU is --gpu-id. + The subtools/runPytorchLauncher.sh is a soft symbolic which is linked to subtools/pytorch/launcher/runLauncher.sh, + so just use it. + + [ CPU ] + subtools/runPytorchLauncher.sh launcher.py --use-gpu=false + + [ Single-GPU ] + (1) Auto-select GPU device + subtools/runPytorchLauncher.sh launcher.py + (2) Specify GPU device + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2 + + [ Multi-GPU ] + (1) Use DDP solution (Default). + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="ddp" + (2) Use Horovod solution. + subtools/runPytorchLauncher.sh launcher.py --gpu-id=2,3 --multi-gpu-solution="horovod" + +If you have any other requirements, you could modify the codes in anywhere. +For more details of multi-GPU devolopment, see subtools/README.md. +""" + +# Logger +# Logger +patch_logging_stream(logging.INFO) +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [ %(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ]\n#### %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + +# Parser: add this parser to run launcher with some frequent options (really for conveninece). +parser = argparse.ArgumentParser( + description="""Train xvector framework with pytorch.""", + formatter_class=argparse.RawTextHelpFormatter, + conflict_handler='resolve') + +parser.add_argument("--stage", type=int, default=3, + help="The stage to control the start of training epoch (default 3).\n" + " stage 3: Training.\n" + " stage 4: extract xvector.") + +parser.add_argument("--endstage", type=int, default=4, + help="The endstage to control the endstart of training epoch (default 4).") + +parser.add_argument("--train-stage", type=int, default=-1, + help="The stage to control the start of training epoch (default -1).\n" + " -1 -> creating model_dir.\n" + " 0 -> model initialization (e.g. 
transfer learning).\n" + " >0 -> recovering training.") + +parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Clear the dir generated by preprocess.") + +parser.add_argument("--pre-rirmusan", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Prepare the openrir and musan dataset for adding reverb and noises.") + +parser.add_argument('--use-amp', type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help='Use automatic mixed precision training') + +parser.add_argument("--skip-nan-batch", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Whether skip optimizer stepping when the gradient has nan/inf values.") + +parser.add_argument("--accum-grad", type=int, default=1, + help="Using accumulate grad.") + +parser.add_argument("--multi-gpu-solution", type=str, default="ddp", + choices=["ddp"], + help="if number of gpu_id > 1, this option will be valid to init a multi-gpu solution.") + +parser.add_argument("--use-gpu", type=str, action=kaldi_common.StrToBoolAction, + default=True, choices=["true", "false"], + help="Use GPU or not.") + +parser.add_argument("--gpu-id", type=str, default="", + help="If NULL, then it will be auto-specified.") + +parser.add_argument("--benchmark", type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help="If true, save training time but require a little more gpu-memory.") + +parser.add_argument("--run-lr-finder", type=str, action=kaldi_common.StrToBoolAction, + default=False, choices=["true", "false"], + help="If true, run lr finder rather than training.") + +parser.add_argument("--sleep", type=int, default=0, + help="The waiting time to launch a launcher.") + +parser.add_argument("--local_rank", type=int, default=0, + help="Do not delete it when using DDP-based multi-GPU training.\n" + "It is important for torch.distributed.launch.") + +args = parser.parse_args() +## +######################################################### PARAMS ######################################################## +## +##--------------------------------------------------## +# Control options +stage = max(3, args.stage) +endstage = min(4, args.endstage) +train_stage = max(-1, args.train_stage) + +##--------------------------------------------------## +# Prepare speech augmention csv files. +pre_rirmusan = args.pre_rirmusan # whether skip this stage. +openrir_folder = "export/path" # where contains RIRS_NOISES folder. +musan_folder = "export/path" # where contains musan folder. +csv_aug_folder = "exp/aug_csv6" # csv file location. +savewav_folder = "/export/yourpath/speech_aug_6015" # save the noise seg into SSD. +max_noise_len = 6.015 # The max dur of noise. +##--------------------------------------------------## +# Training options +use_amp = args.use_amp +skip_nan_batch = args.skip_nan_batch +accum_grad = args.accum_grad + +use_gpu = args.use_gpu # Default true. +# If true, save much training time but require a little more gpu-memory. +benchmark = args.benchmark +gpu_id = args.gpu_id # If NULL, then it will be auto-specified. +run_lr_finder = args.run_lr_finder + +##--------------------------------------------------## +# Define model_params by model_blueprint w.r.t your model's __init__(model_params). 
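As the comment above says, every key in model_params has to line up with an argument of the blueprint's __init__; stage 3 below simply unpacks the dict into the constructor. A condensed sketch of that call (in the real launcher feat_dim and num_targets come from the egs info dict; the literal 80 and 5994 here are placeholders only):

    # Condensed illustration of the constructor call made in the training stage below.
    # 80 (feat_dim) and 5994 (num_targets) are placeholder values, not read from any config.
    model_py = utils.create_model_from_py(model_blueprint)
    model = model_py.TransformerXvector(80, 5994, **model_params)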
+ +model_params = { + "wenet_transfer":True, # change wenet conformer keys when transferring + "training": True, "extracted_embedding": "near", + "embd_dim": 256, # xvector dim + "transformer_type": "conformer", # [conformer, transformer, re_conformer] + "transformer_params": { + "attention_dim": 256, + "attention_heads": 4, + "num_blocks": 6, + "combiner_type": "norm", # [norm, mfa, random_frame, random_layer] + "aux_layer_period": 2, # aux: select inner layer for combiner. + "aux_layer_start": 3, + "dropout_rate": 0.1, + "layer_dropout":0., + "linear_units": 2048, + "positional_dropout_rate": 0.1, + "attention_dropout_rate": 0.1, + "attention_norm_args": { + "norm_method": "softmax_plus", # [softmax, relu_plus, softmax_plus] + "train_len": 300, + }, + "input_layer": "conv2d", # [linear, conv2d2, conv2d, conv2d6, conv2d8] + "cnn_module_kernel": 15, # for conformer + "pos_enc_type": "rot_pos", # [abs_pos, no_pos, rot_pos, rel_pos] + "convfnn_blocks": 0 # use conv type ffn in head blocks. + }, + + + "pooling": "ecpa-attentive", + "pooling_params": { + "hidden_size": 128, + "time_attention": False, + "stddev": True, + + }, + "fc1": False, + "fc1_params": { + "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "bn_params": {"momentum": 0.5, "affine": False, "track_running_stats": True}}, + + "fc2_params": { + "nonlinearity": '', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "ln_replace": True, # replace BN with LN + "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True}}, + + + "margin_loss": True, + "margin_loss_params": { + "method": "aam", "m": 0.5, "feature_normalize": True, + "s": 30, "mhe_loss": False, "mhe_w": 0.01}, + + "use_step": True, + "step_params": { + "margin_warm":True, + "margin_warm_conf":{"start_epoch":1,"end_epoch":2,"offset_margin":-0.3,"init_lambda":0.5}, + "T": None, + "m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4, + "s": False, "s_tuple": (30, 12), "s_list": None, + "t": False, "t_tuple": (0.5, 1.2), + "p": False, "p_tuple": (0.5, 0.1)} +} + +optimizer_params = { + "name": "sgd", + "learn_rate": 0.0001, + "beta1": 0.9, + "beta2": 0.999, + "beta3": 0.999, + # Should be large for decouped weight decay (adamW) and small for L2 regularization (sgd, adam). + "weight_decay": 1e-5, + "lookahead.k": 5, + # 0 means not using lookahead and if used, suggest to set it as 0.5. + "lookahead.alpha": 0, + "gc": False, # If true, use gradient centralization. + "nesterov": False, # for sgd + "sam": False, + "sam.rho": 2.0, # 2.0 for adaptive + "sam.adaptive": True, + "custwd_dict":{ + "train_len":0, + "loss":1e-4, + # "bias":0 + } +} + +lr_scheduler_params = { + "name": "1cycle", + "1cycle.learn_rate":4e-4, + "1cycle.warmup_steps":2000, + "1cycle.epochs": 4, + "1cycle.steps_per_epoch": 4300, + "1cycle.div_factor":10000.0, # initial_lr = max_lr/div_factor + "1cycle.final_div_factor":100.0, # min_lr = initial_lr/final_div_factor + "1cycle.anneal_strategy":'cos', # ["cos", "linear"] + "1cycle.cycle_momentum":False, + + + "noam.warmup_steps": 2000, + "noam.step_decay": True, + "noam.step_size": 4300, # suggest 4 epochs + "noam.step_rate": 0.5, + + "cyclic.max_lr": 1e-3, + "cyclic.base_lr": 1e-8, + "cyclic.step_size_up": 22000, + "cyclic.mode": 'triangular2', + +} + +epochs = 5 # Total epochs to train. It is important. + +compute_batch_num_valid = 10 +# About validation computation and loss reporting. 
If report_times_every_epoch is not None, +report_interval_iters = 500 +# then compute report_interval_iters by report_times_every_epoch. + +suffix = "params" # Used in saved model file. +# Other options +exist_model = "exp/conformer_6L256D4H_4sub/50.params" # Use it in transfer learning. +##--------------------------------------------------## +# Main paramsW +egs_dir = "exp/egs/voxceleb2_dev_whole" +egs_conf = "subtools/conf/egs_pre_sre_lm.yaml" +model_blueprint = "subtools/pytorch/model/transformer_xvector.py" +model_dir = "exp/conformer_6L256D4H_4sub_lm" +##--------------------------------------------------## +## +######################################################### START ######################################################### +## +# Set seed +utils.set_all_seed(1024) +## +# Init environment +# It is used for multi-gpu training if used (number of gpu-id > 1). +# And it will do nothing for single-GPU training. +utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution) +## +# Set sleep time for a rest +# Use it to run a launcher with a countdown function when there are no extra GPU memory +# but you really want to go to bed and know when the GPU memory will be free. +if args.sleep > 0: + time.sleep(args.sleep) + +## +# Auto-config params +# If multi-GPU used, it will auto-scale learning rate by multiplying number of processes. +optimizer_params["learn_rate"] = utils.auto_scale_lr( + optimizer_params["learn_rate"]) +# It is used for model.step() defined in model blueprint. +if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]: + model_params["step_params"]["T"]=(lr_scheduler_params["warmR.T_max"], lr_scheduler_params["warmR.T_mult"]) + +if pre_rirmusan: + prepare_speech_aug(openrir_folder, musan_folder, csv_folder=csv_aug_folder, savewav_folder=savewav_folder, max_noise_len=max_noise_len, force_clear=True) + + +# Train model +if stage <= 3 <= endstage: + if utils.is_main_training(): + logger.info("Get model_blueprint from model directory.") + # Save the raw model_blueprint in model_dir/config and get the copy of model_blueprint path. + model_blueprint = utils.create_model_dir( + model_dir, model_blueprint, stage=train_stage) + + if utils.is_main_training(): + logger.info("Load egs to bunch.") + # The dict [info] contains feat_dim and num_targets + with open(egs_conf, 'r') as fin: + egs_params = yaml.load(fin, Loader=yaml.FullLoader) + egs_params['dataset_conf']['csv_aug_folder'] = csv_aug_folder + bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params) + feat_extraction_config = copy.deepcopy( + egs_params['dataset_conf']['feature_extraction_conf']) + feat_extraction_config['kaldi_featset']['dither'] = 0.0 + feat_config_path = os.path.join(model_dir, 'config', 'feat_conf.yaml') + if utils.is_main_training(): + with open(feat_config_path, 'w') as fou: + yaml.dump(feat_extraction_config, fou) + + if utils.is_main_training(): + logger.info("Create model from model blueprint.") + # Another way: import the model.py in this python directly, but it is not friendly to the shell script of extracting and + # I don't want to change anything about extracting script when the model.py is changed. + model_py = utils.create_model_from_py(model_blueprint) + + model = model_py.TransformerXvector( + info["feat_dim"], info["num_targets"], **model_params) + + + # If multi-GPU used, then batchnorm will be converted to synchronized batchnorm, which is important + # to make peformance stable. + # It will change nothing for single-GPU training. 
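The comment above describes what utils.convert_synchronized_batchnorm is expected to do; a minimal standalone sketch of that behaviour, assuming the util is a thin wrapper around torch.nn.SyncBatchNorm.convert_sync_batchnorm that only fires when a distributed process group has been initialized:

    import torch

    def convert_sync_bn_if_distributed(model: torch.nn.Module) -> torch.nn.Module:
        # Multi-GPU (DDP) run: replace every BatchNorm*d with SyncBatchNorm so running
        # statistics are computed across all processes.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # Single-GPU / CPU run: the model is returned unchanged.
        return model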
+ model = utils.convert_synchronized_batchnorm(model) + # print(model) + epoch_iters = (info['epoch_iters']//accum_grad) + if hasattr(model,'margin_warm'): + model.margin_warm.update_step_range(epoch_iters) + # print(sum(p.numel() for p in model.parameters())) + # sys.exit() + if utils.is_main_training(): + print(model) + p1=sum(p.numel() for p in model.parameters()) + script_model = copy.deepcopy(model) + script_model.loss=None + p2 = sum(p.numel() for p in script_model.parameters()) + logger.info("model params w/o proj layer: {} / {} .".format(p1,p2)) + script_model = torch.jit.script(script_model) + script_model.save(os.path.join(model_dir, 'init.zip')) + logger.info("The number of steps per epoch is about {}.".format(epoch_iters)) + logger.info("Define optimizer and lr_scheduler.") + del script_model + + optimizer = optim.get_optimizer(model, optimizer_params) + lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper( + optimizer, lr_scheduler_params) + + # Record params to model_dir + + if utils.is_main_training(): + utils.write_list_to_file([egs_params, model_params, optimizer_params, + lr_scheduler_params], model_dir+'/config/params.dict',yml=True) + + if utils.is_main_training(): + logger.info("Init a simple trainer.") + # Package(Elements:dict, Params:dict}. It is a key parameter's package to trainer and model_dir/config/. + package = ({"data": bunch, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}, + {"model_dir": model_dir, "model_blueprint": model_blueprint, "exist_model": exist_model, "accum_grad": accum_grad, + "start_epoch": train_stage, "epochs": epochs, "use_gpu": use_gpu, "gpu_id": gpu_id, "use_amp": use_amp, + "skip_nan_batch": skip_nan_batch, "benchmark": benchmark, "suffix": suffix, "compute_batch_num_valid": compute_batch_num_valid, + "report_interval_iters": report_interval_iters, "record_file": "train.csv"}) + train_exec = trainer_sam if isinstance(optimizer,optim.SAM) else trainer + + execuer = train_exec.SimpleTrainer(package) + + if run_lr_finder: + execuer.run_lr_finder("lr_finder.csv", init_lr=1e-8, + final_lr=10., num_iters=2000, beta=0.98) + endstage = 3 # Do not start extractor. + else: + execuer.run() + + +# Extract xvector +if stage <= 4 <= endstage and utils.is_main_training(): + # There are some params for xvector extracting. + data_root = "data" # It contains all dataset just like Kaldi recipe. + prefix = "raw" # For to_extracted_data. + data_type_emb = "raw" # shard or raw or kaldi. + de_silence = False + amp_th = 50 + + to_extracted_positions = ["near"] # Define this w.r.t model_blueprint. + # All dataset should be in dataroot/prefix. + to_extracted_data = ["voxceleb1", "voxceleb2_dev"] + # It is model's name, such as 10.params or final.params (suffix is w.r.t package). + to_extracted_epochs = ["4"] + + nj = 8 + force = True + use_gpu = True + gpu_id = "" + sleep_time = 10 + feat_config = "feat_conf.yaml" + max_chunk = 4000 + # Run a batch extracting process. + try: + for position in to_extracted_positions: + # Generate the extracting config from nnet config where + # which position to extract depends on the 'extracted_embedding' parameter of model_creation (by my design). + model_blueprint, model_creation = utils.read_nnet_config( + "{0}/config/nnet.config".format(model_dir)) + # To save memory without loading some independent components. 
+ model_creation = model_creation.replace( + "training=True", "training=False") + model_creation = model_creation.replace( + model_params["extracted_embedding"], position) + extract_config = "{0}.extract.config".format(position) + + utils.write_nnet_config( + model_blueprint, model_creation, "{0}/config/{1}".format(model_dir, extract_config)) + + for epoch in to_extracted_epochs: + model_file = "{0}.{1}".format(epoch, suffix) + point_name = "{0}_epoch_{1}".format(position, epoch) + + # If run a trainer with background thread (do not be supported now) or run this launcher extrally with stage=4 + # (it means another process), then this while-listen is useful to start extracting immediately (but require more gpu-memory). + model_path = "{0}/{1}".format(model_dir, model_file) + + while True: + if os.path.exists(model_path): + break + else: + time.sleep(sleep_time) + + for data in to_extracted_data: + datadir = "{0}/{1}/{2}".format(data_root, prefix, data) + outdir = "{0}/{1}/{2}".format(model_dir, point_name, data) + # Use a well-optimized shell script (with multi-processes) to extract xvectors. + # Another way: use subtools/splitDataByLength.sh and subtools/pytorch/pipeline/onestep/extract_embeddings.py + # with python's threads to extract xvectors directly, but the shell script is more convenient. + kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh " + " --model {model_file} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' " + " --data-type '{data_type}' --de-silence {de_silence} --amp-th {amp_th} --max-chunk {max_chunk} " + " --force {force} --nnet-config config/{extract_config} --feat-config config/{feat_config} " + "{model_dir} {datadir} {outdir}".format(model_file=model_file, nj=nj, + use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config, + feat_config=feat_config, data_type=data_type_emb, de_silence=str(de_silence).lower(), amp_th=amp_th, + max_chunk=max_chunk, model_dir=model_dir, datadir=datadir, outdir=outdir)) + except BaseException as e: + if not isinstance(e, KeyboardInterrupt): + traceback.print_exc() + sys.exit(1) diff --git a/pytorch/libs/egs/egs_online.py b/pytorch/libs/egs/egs_online.py old mode 100644 new mode 100755 index 311215f..c2ab949 --- a/pytorch/libs/egs/egs_online.py +++ b/pytorch/libs/egs/egs_online.py @@ -13,7 +13,6 @@ from torch.utils.data import IterableDataset from torch.utils.data import DataLoader import torch.distributed as dist - sys.path.insert(0, "subtools/pytorch") import libs.support.utils as utils import libs.egs.processor as processor @@ -62,6 +61,8 @@ def get_data_dur(self): def apply(self, f): assert callable(f) return Processor(self, f, *self.args, **self.kw) + def __len__(self): + return len(self.source) class DistributedSampler: def __init__(self, shuffle=True, partition=True): @@ -144,28 +145,54 @@ def get_data_dur(self): logger.warning('do not support get duration') total_dur=None num_sample = sum([int(self.lists[index]['eg-num']) for index in self.data_this_rank]) if "eg-num" in self.lists[0] else len(self.data_this_rank) + # tot_sample = sum([int(self.lists[index]['eg-num']) for index in self.lists]) if "eg-num" in self.lists[0] else len(self.lists) return total_dur,num_sample + def __len__(self): + return sum([int(self.lists[index]['eg-num']) for index in range(len(self.lists))]) if "eg-num" in self.lists[0] else len(self.lists) -def WavEgs(egs_csv,conf,data_type='raw',partition=True): +def 
WavEgs(egs_csv,conf,data_type='raw',partition=True,num_targets=0): assert data_type in ['raw', 'shard', 'kaldi'] lists = utils.csv_to_list(egs_csv) + + shuffle = conf.get('shuffle', True) + + dataset = DataList(lists, shuffle=shuffle, partition=partition) - if data_type in ['raw','shard']: + + if data_type in ['raw', 'shard']: if data_type=='shard': dataset = Processor(dataset, processor.url_opener) dataset = Processor(dataset, processor.tar_file_and_group) else: dataset = Processor(dataset, processor.parse_raw) - - random_chunk = conf.get('random_chunk',False) - if random_chunk: - random_chunk_size = conf.get('random_chunk_size',2.015) - dataset = Processor(dataset, processor.random_chunk, random_chunk_size) + filt = conf.get('filter', False) + filter_conf = conf.get('filter_conf', {}) + if filt: + dataset = Processor(dataset, processor.filter, **filter_conf) + resample = conf.get('resample', False) if resample: resample_conf = conf.get('resample_conf', {}) dataset = Processor(dataset, processor.resample, **resample_conf) + + + pre_speed_perturb = conf.get('pre_speed_perturb', False) + spkid_aug = 1 + if pre_speed_perturb: + perturb_conf = conf.get('perturb_conf',{}) + sp = processor.PreSpeedPerturb(spk_num=num_targets,**perturb_conf) + spkid_aug = sp._spkid_aug() + dataset = Processor(dataset,sp) + + random_chunk = conf.get('random_chunk',False) + random_chunk_size = conf.get('random_chunk_size',2.015) + + + if random_chunk: + dataset = Processor(dataset, processor.random_chunk, random_chunk_size) + + speech_aug = conf.get('speech_aug', False) speech_aug_conf_file = conf.get('speech_aug_conf', '') @@ -177,7 +204,11 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True): csv_aug_folder = conf.get('csv_aug_folder','') if csv_aug_folder:change_csv_folder(speech_aug_conf,csv_aug_folder) - speechaug_pipline = processor.SpeechAugPipline(**speech_aug_conf) + speechaug_pipline = processor.SpeechAugPipline(spk_num=num_targets,**speech_aug_conf) + spkid_aug_lat = speechaug_pipline.get_spkid_aug() + if not (spkid_aug==1 or spkid_aug_lat==1): + raise ValueError("multi speaker id perturb setting, check your speech aug config") + spkid_aug = spkid_aug_lat*spkid_aug dataset = Processor(dataset, speechaug_pipline) feature_extraction_conf = conf.get('feature_extraction_conf',{}) @@ -185,6 +216,7 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True): dataset = Processor(dataset, feature_extraction) else: dataset = Processor(dataset,processor.offline_feat) + spec_aug = conf.get('spec_aug', False) if spec_aug: spec_aug_conf=conf.get('spec_aug_conf',{}) @@ -201,7 +233,8 @@ def WavEgs(egs_csv,conf,data_type='raw',partition=True): batch_conf = conf.get('batch_conf', {}) dataset = Processor(dataset, processor.batch, **batch_conf) dataset = Processor(dataset, processor.padding) - return dataset + + return dataset,spkid_aug def WavEgsXvector(wav_scp,feat_conf={},data_type='kaldi',de_silence=False,de_sil_conf={},partition=False): if data_type in ['raw', 'shard']: @@ -242,19 +275,22 @@ def __init__(self, trainset, valid=None,prefetch_factor=2,num_workers=0, pin_mem @classmethod - def get_bunch_from_csv(self, trainset_csv: str, valid_csv: str = None, egs_params: dict = {}): + def get_bunch_from_csv(self, trainset_csv: str, valid_csv: str = None, egs_params: dict = {},num_targets=-1): train_conf = egs_params['dataset_conf'] valid_conf = copy.deepcopy(train_conf) valid_conf['speech_aug'] = False + valid_conf['pre_speed_perturb'] = False valid_conf['spec_aug'] = False valid_conf['shuffle'] = False data_type = 
egs_params.get('data_type','raw') - trainset = WavEgs(trainset_csv, train_conf, data_type=data_type,partition=True) + trainset, num_targets_t = WavEgs(trainset_csv, train_conf, data_type=data_type,partition=True,num_targets=num_targets) + if valid_csv != "" and valid_csv is not None: - valid = WavEgs(valid_csv, valid_conf,data_type=data_type,partition=False) + valid,_ = WavEgs(valid_csv, valid_conf,data_type=data_type,partition=False) else: valid = None + self.num_targets =num_targets*num_targets_t return self(trainset, valid, **egs_params['data_loader_conf']) @@ -275,9 +311,17 @@ def get_bunch_from_egsdir(self, egsdir: str, egs_params: dict={}): egsdir, train_csv_name=train_csv_name, valid_csv_name=valid_csv_name) assert 'feat_dim' in egs_params feat_dim = int(egs_params['feat_dim']) - info = {"feat_dim": feat_dim, "num_targets": num_targets} + bunch = self.get_bunch_from_csv( - train_csv, valid_csv, egs_params) + train_csv, valid_csv, egs_params,num_targets) + num_targets = self.num_targets + tot_samples = len(bunch.train_loader.dataset) + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if egs_params['dataset_conf']['batch_conf']['batch_type']=='static': + epoch_iters = (tot_samples//world_size)//egs_params['dataset_conf']['batch_conf']['batch_size'] + else: + epoch_iters = None + info = {"feat_dim": feat_dim, "num_targets": num_targets, "epoch_iters": epoch_iters} return bunch, info @@ -301,4 +345,4 @@ def get_info_from_egsdir(egsdir, train_csv_name=None, valid_csv_name=None): if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/pytorch/libs/egs/processor.py b/pytorch/libs/egs/processor.py old mode 100644 new mode 100755 index e3fc4f9..39e1440 --- a/pytorch/libs/egs/processor.py +++ b/pytorch/libs/egs/processor.py @@ -15,7 +15,7 @@ import torchaudio.compliance.kaldi as kaldi import libs.support.kaldi_io as kaldi_io from libs.support.utils import batch_pad_right,get_torchaudio_backend -from .speech_augment import SpeechAug +from .speech_augment import SpeechAug,SpeedPerturb from .signal_processing import de_silence from libs.egs.augmentation import * from libs.egs.kaldi_features import InputSequenceNormalization @@ -87,7 +87,8 @@ def tar_file_and_group(data): try: if postfix == 'txt': label = file_obj.read().decode('utf8') - example['label'] = int(label) + + example['label'] = torch.tensor([int(label)],dtype=torch.long) elif postfix in AUDIO_FORMAT_SETS: waveform, sample_rate = torchaudio.load(file_obj) example['wav'] = waveform[:1,:] @@ -134,7 +135,7 @@ def parse_raw(data): else: waveform, sample_rate = torchaudio.load(wav_file) waveform = waveform[:1,:] - label = int(label) + label = torch.tensor([int(label)],dtype=torch.long) lens = torch.ones(1) example = dict(key=key, label=label, @@ -173,6 +174,48 @@ def de_sil(data,win_len=0.1,min_eng=50,retry_times=1,force_output=True): del waveform yield sample +class PreSpeedPerturb(object): + def __init__(self, + sample_rate=16000, + speeds=[90, 100, 110], + perturb_type='resample', + perturb_prob=1.0, + change_spk=True, + spk_num=0): + super().__init__() + self.change_spk = change_spk + self.speeder = SpeedPerturb(sample_rate, speeds=speeds, perturb_prob=perturb_prob, perturb_type=perturb_type, spk_num=spk_num,change_spk=change_spk,keep_shape=False) + + def __call__(self, data): + """ speechaug. 
+ Args: + data: Iterable[{key, wav, label, lens, sample_rate}] + Returns: + Iterable[{key, wav, label, lens, sample_rate}] + """ + + for sample in data: + assert 'wav' in sample + assert 'label' in sample + waveform = sample['wav'] + spkid = sample['label'] + + try: + + waveform,spkid = self.speeder(waveform, spkid) + + sample['wav'] = waveform + if self.change_spk: + sample['label'] = spkid + yield sample + except Exception as ex: + logging.warning('Failed to speech aug {}'.format(sample['key'])) + + def _spkid_aug(self): + spkid_aug, _ = self.speeder.get_spkid_aug() + return spkid_aug + + def random_chunk(data,chunk_len=2.015): """ data: Iterable[{key, wav, label, lens, sample_rate}] @@ -215,6 +258,7 @@ def offline_feat(data): key = sample['eg-id'] ark_path = sample['ark-path'] label = int(sample['class-label']) + label=torch.tensor([label],dtype=torch.long) try: if 'start-position' in sample: @@ -257,17 +301,57 @@ def resample(data, resample_rate=16000): +def filter(data, + max_length=15, + min_length=0.2, + max_cut=True): + """ Filter sample according to feature. + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(s) + min_length: drop utterance which is less than min_length(s) + + Returns: + Iterable[{key, wav, *, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + duration_sample=sample['wav'].size(1) + + duration = duration_sample / sample['sample_rate'] + if duration < min_length: + continue + if duration > max_length: + if max_cut: + duration_sample=sample['wav'].size(1) + + snt_len_sample = int(max_length*sample['sample_rate']) + start = random.randint(0, duration_sample - snt_len_sample - 1) + stop = stop = start + snt_len_sample + sample['wav'] = sample['wav'][:,start:stop] + else: + continue + + yield sample + + + class SpeechAugPipline(object): - def __init__(self, speechaug={}, tail_speechaug={}): + def __init__(self,spk_num=0,speechaug={}, tail_speechaug={}): super().__init__() - self.speechaug = SpeechAug(**speechaug) - self.tail_speechaug = SpeechAug(**tail_speechaug) + self.speechaug = SpeechAug(spk_num=spk_num,**speechaug) + self.tail_speechaug = SpeechAug(spk_num=spk_num,**tail_speechaug) + speechaug_nid_aug=self.speechaug.get_spkid_aug() + tail_speechaug_nid_aug=self.tail_speechaug.get_spkid_aug() speechaug_n_concat=self.speechaug.get_num_concat() tail_speechaug_n_concat=self.tail_speechaug.get_num_concat() # The concat number of speech augment, which is used to modify target. self.concat_pip= (speechaug_n_concat,tail_speechaug_n_concat) - + self.spk_nid_aug = tail_speechaug_nid_aug*speechaug_nid_aug + if speechaug_nid_aug>1 and tail_speechaug_nid_aug>1: + raise ValueError("multi speaker id perturb setting, check your speech aug config") def __call__(self, data): """ speechaug. 
Args: @@ -283,17 +367,20 @@ def __call__(self, data): waveforms = sample['wav'] lens = sample['lens'] - try: - waveforms, lens = self.speechaug(waveforms, lens) + spkid = sample['label'] + # try: - waveforms, lens = self.tail_speechaug(waveforms, lens) - sample['wav'] = waveforms - - sample['lens'] = lens - yield sample - except Exception as ex: - logging.warning('Failed to speech aug {}'.format(sample['key'])) + waveforms, lens,spkid = self.speechaug(waveforms, lens,spkid) + waveforms, lens,spkid = self.tail_speechaug(waveforms, lens,spkid) + sample['wav'] = waveforms + sample['label'] = spkid + sample['lens'] = lens + yield sample + # except Exception as ex: + # logging.warning('Failed to speech aug {}'.format(sample['key'])) + def get_spkid_aug(self): + return self.spk_nid_aug @@ -311,7 +398,6 @@ def __init__(self,feature_type='mfcc',kaldi_featset={},mean_var_conf={}): super().__init__() assert feature_type in ['mfcc','fbank'] self.feat_type=feature_type - self.kaldi_featset=kaldi_featset if self.feat_type=='mfcc': self.extract=kaldi.mfcc @@ -327,7 +413,7 @@ def __call__(self,data): Args: data: Iterable[{key, wav, label, lens, sample_rate}] Returns: - Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] """ for sample in data: assert 'wav' in sample @@ -339,9 +425,10 @@ def __call__(self,data): self.kaldi_featset['sample_frequency'] = sample['sample_rate'] lens = sample['lens'] waveforms = sample['wav'] + label = sample['label'] waveforms = waveforms * (1 << 15) feats = [] - label = sample['label'] + labels=[] keys=[] utt = sample['key'] try: @@ -349,27 +436,32 @@ def __call__(self,data): lens=lens*waveforms.shape[1] for i,wav in enumerate(waveforms): - + if len(wav.shape)==1: # add channel wav=wav.unsqueeze(0) + else: + wav = wav.transpose(0, 1) + wav= wav[:,:lens[i].long()] + feat=self.extract(wav,**self.kaldi_featset) if(torch.any((torch.isnan(feat)))): logging.warning('Failed to make featrue for {}, aug version:{}'.format(sample['key'],i)) - pass + continue feat = feat.detach() feat=self.mean_var(feat) key = sample['key']+'#{}'.format(i) if i>0 else sample['key'] feats.append(feat) - + labels.append(label[i]) keys.append(key) if len(feats)==0: - pass + continue max_len = max([feat.size(0) for feat in feats]) - yield dict(utt=utt,keys=keys,feats=feats,label=label,max_len=max_len) + + yield dict(utt=utt,keys=keys,feats=feats,labels=labels,max_len=max_len) except Exception as ex: logging.warning('Failed to make featrue {}'.format(sample['key'])) @@ -381,9 +473,9 @@ def __init__(self,aug=None,aug_params={}): def __call__(self,data): """ make features. Args: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] Returns: - Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] """ for sample in data: assert 'keys' in sample @@ -396,32 +488,7 @@ def __call__(self,data): yield sample -# def speed_perturb(data, speeds=None): -# """ Apply speed perturb to the data. -# Inplace operation. 
- -# Args: -# data: Iterable[{key, wav, label, sample_rate}] -# speeds(List[float]): optional speed -# Returns: -# Iterable[{key, wav, label, sample_rate}] -# """ -# if speeds is None: -# speeds = [0.9, 1.0, 1.1] -# for sample in data: -# assert 'sample_rate' in sample -# assert 'wav' in sample -# sample_rate = sample['sample_rate'] -# waveform = sample['wav'] -# speed = random.choice(speeds) -# if speed != 1.0: -# wav, _ = torchaudio.sox_effects.apply_effects_tensor( -# waveform, sample_rate, -# [['speed', str(speed)], ['rate', str(sample_rate)]]) -# sample['wav'] = wav - -# yield sample @@ -429,14 +496,16 @@ def shuffle(data, shuffle_size=10000): """ Local shuffle the data Args: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] shuffle_size: buffer size for shuffle Returns: - Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] """ + buf = [] for sample in data: + buf.append(sample) if len(buf) >= shuffle_size: random.shuffle(buf) @@ -456,11 +525,11 @@ def sort(data, sort_size=500): be less than `shuffle_size` Args: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] sort_size: buffer size for sort Returns: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] """ buf = [] @@ -476,14 +545,13 @@ def sort(data, sort_size=500): for x in buf: yield x - def static_batch(data, batch_size=16): """ Static batch the data by `batch_size` Args: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] batch_size: batch size Returns: - Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]] + Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]] """ buf = [] for sample in data: @@ -500,10 +568,10 @@ def dynamic_batch(data, max_frames_in_batch=12000): """ Dynamic batch the data until the total frames in batch reach `max_frames_in_batch` Args: - data: Iterable[{utt:str, keys:list, label, feats:list, max_len:int}] + data: Iterable[{utt:str, keys:list, labels:list, feats:list, max_len:int}] max_frames_in_batch: max_frames in one batch Returns: - Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]] + Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]] """ buf = [] longest_frames = 0 @@ -541,25 +609,27 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): def padding(data): """ Padding the data into training data Args: - data: Iterable[List[{utt:str, keys:list, label, feats:list, max_len:int}]] + data: Iterable[List[{utt:str, keys:list, labels:list, feats:list, max_len:int}]] Returns: Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] """ for sample in data: assert isinstance(sample, list) + feats=[] labels=[] keys=[] for x in sample: feats.extend(x['feats']) - labels.extend([x['label']]*len(x['feats'])) + labels.extend(x['labels']) keys.extend(x['keys']) + + labels = torch.tensor(labels) - labels = torch.LongTensor(labels) feats = [(x.T) for x in feats] - padded_feats, lens = batch_pad_right(feats) + padded_feats, feats_lens = batch_pad_right(feats) - yield (padded_feats, labels) + yield (padded_feats, labels, feats_lens) diff --git 
a/pytorch/libs/egs/speech_augment.py b/pytorch/libs/egs/speech_augment.py old mode 100644 new mode 100755 index 547fe61..3920f1e --- a/pytorch/libs/egs/speech_augment.py +++ b/pytorch/libs/egs/speech_augment.py @@ -3,6 +3,7 @@ # Importing libraries import math +from typing import Optional,List import numpy as np import random import sklearn @@ -23,10 +24,10 @@ reverberate, ) from libs.support.utils import batch_pad_right - +import torch.distributed as dist class NoiseDataset(Dataset): - def __init__(self, csv_file, sorting="original", max_len=None): + def __init__(self, csv_file, sorting="original", max_len=None, filt_min=None): head = pd.read_csv(csv_file, sep=" ", nrows=0).columns @@ -37,6 +38,11 @@ def __init__(self, csv_file, sorting="original", max_len=None): data = pd.read_csv(csv_file, sep=" ",header=0) + + if filt_min: + data =data[data['duration']>filt_min] + + if sorting == "decending": data = data.sort_values(by=['duration'], ascending=False) elif sorting == "ascending": @@ -47,11 +53,13 @@ def __init__(self, csv_file, sorting="original", max_len=None): pass self.path = data['wav'].values.astype(np.string_) + + if max_len: assert max_len > 0.0 self.lens=data['tot_frame'].values self.sr=data['sr'].values - + self.max_len = max_len del data @@ -74,6 +82,7 @@ def __len__(self): return len(self.path) + class ReproducibleRandomSampler(RandomSampler): """A modification of RandomSampler which always returns the same values. @@ -195,6 +204,7 @@ class AddNoise(torch.nn.Module): def __init__( self, csv_file=None, + add_filt_min=None, sorting="random", num_workers=0, snr_low=0, @@ -207,6 +217,7 @@ def __init__( super().__init__() self.csv_file = csv_file + self.add_filt_min=add_filt_min self.sorting = sorting self.num_workers = num_workers self.snr_low = snr_low @@ -287,7 +298,9 @@ def _load_noise(self, lengths, max_length): # Create a data loader for the noise wavforms if self.csv_file is not None: dataset = NoiseDataset( - self.csv_file, sorting=self.sorting) + self.csv_file, + filt_min = self.add_filt_min, + sorting=self.sorting) shuffle = (self.sorting == "random") if torch.distributed.is_initialized(): sampler = torch.utils.data.distributed.DistributedSampler( @@ -325,7 +338,9 @@ def _load_noise(self, lengths, max_length): # Ensure noise batch is long enough elif noise_batch.size(1) < max_length: - padding = (0, max_length - noise_batch.size(1)) + pad = max_length - noise_batch.size(1) + left_padding = torch.randint(high = pad+1, size=(1,))[0] + padding = (left_padding,pad-left_padding) noise_batch = torch.nn.functional.pad(noise_batch, padding) # Select a random starting location in the waveform @@ -429,12 +444,12 @@ def __init__( self.csv_file = csv_file self.sorting = sorting self.reverb_prob = reverb_prob - self.rir_scale_factor = rir_scale_factor # Create a data loader for the RIR waveforms dataset = NoiseDataset( - self.csv_file, sorting=self.sorting) + self.csv_file, + sorting=self.sorting) shuffle = (self.sorting == "random") if torch.distributed.is_initialized(): sampler = torch.utils.data.distributed.DistributedSampler( @@ -443,7 +458,8 @@ def __init__( else: sampler = None self.data_loader = make_dataloader( - dataset, shuffle=shuffle, sampler=sampler + dataset, shuffle=shuffle, + sampler=sampler ) self.rir_data = iter(self.data_loader) @@ -546,6 +562,7 @@ class AddBabble(torch.nn.Module): def __init__( self, csv_file=None, + add_filt_min=None, sorting="random", num_workers=0, snr_low=0, @@ -559,6 +576,7 @@ def __init__( super().__init__() self.csv_file = csv_file + 
self.add_filt_min = add_filt_min self.sorting = sorting self.num_workers = num_workers self.snr_low = snr_low @@ -651,7 +669,10 @@ def _load_noise(self, lengths, max_length): # Create a data loader for the noise wavforms if self.csv_file is not None: dataset = NoiseDataset( - self.csv_file, sorting=self.sorting, max_len=self.babble_noise_max_len) + self.csv_file, + filt_min=self.add_filt_min, + sorting=self.sorting, + max_len=self.babble_noise_max_len) shuffle = (self.sorting == "random") if torch.distributed.is_initialized(): sampler = torch.utils.data.distributed.DistributedSampler( @@ -689,7 +710,9 @@ def _load_noise(self, lengths, max_length): # Ensure noise batch is long enough elif noise_batch.size(1) < max_length: - padding = (0, max_length - noise_batch.size(1)) + pad = max_length - noise_batch.size(1) + left_padding = torch.randint(high = pad+1, size=(1,))[0] + padding = (left_padding,pad-left_padding) noise_batch = torch.nn.functional.pad(noise_batch, padding) # Select a random starting location in the waveform @@ -991,6 +1014,56 @@ def forward(self, waveforms, lengths): return dropped_waveform +class RandomChunk(torch.nn.Module): + """Get segment. + Arguments + --------- + chunk_len : float + Get segment of utts, in senconds (s). + sample_rate : int + the sampling frequency of the input signal. + """ + def __init__( + self, + random_chunk=False, + chunk_len=2.015, + sample_rate=16000, + ): + super().__init__() + self.random_chunk=random_chunk + self.lens = int(chunk_len*sample_rate) + def forward(self, waveforms,lengths): + """ + Arguments + --------- + waveforms : tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : tensor + Shape should be a single dimension, `[batch]`. + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + if not self.random_chunk: + return waveforms,lengths + lengths=(lengths * waveforms.shape[1]) # [B] + shape = list(waveforms.shape) + shape[1] = self.lens + chunk_sig = torch.zeros(shape,device=lengths.device) + for i in range(shape[0]): + if lengths[i] > self.lens: + max_chop = (lengths[i] - self.lens).long() + start_index = torch.randint( + high=max_chop, size=(1,)) + chunk_sig[i] = waveforms[i,start_index: start_index + self.lens] + else: + repeat_num = math.ceil(self.lens/lengths[i]) + chunk_sig[i:i+1] = waveforms[i:i+1,: ].repeat(1,repeat_num)[:,:self.lens] + lengths = torch.ones(shape[0]) + if chunk_sig.shape!=48240: + print(chunk_sig.shape) + return chunk_sig, lengths + class DoClip(torch.nn.Module): """This function mimics audio clipping by clamping the input tensor. @@ -1041,6 +1114,48 @@ def forward(self, waveforms): return clipped_waveform +class SoxEffectTransform(torch.nn.Module): + effects: List[List[str]] + + def __init__(self, effects: List[List[str]],sample_rate:int): + + super().__init__() + self.effects = effects + self.sample_rate = sample_rate + def forward(self, waveforms: torch.Tensor): + """ + Arguments + --------- + waveforms : tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. 
+ """ + + wavs = [] + if self.effects == [[]]: + return waveforms + unsqueezed = False + if len(waveforms.shape)==2: + # add channel + waveforms=waveforms.unsqueeze(1) + unsqueezed = True + else: + waveforms = waveforms.transpose(1, 2) + for i,wav in enumerate(waveforms): + + wav,_ = torchaudio.sox_effects.apply_effects_tensor(wav, self.sample_rate, self.effects) + + wavs.append(wav.unsqueeze(0)) + wavs = torch.cat(wavs,dim=0) + + + if unsqueezed: + wavs=wavs.squeeze(1) + else: + wavs=wavs.transpose(1,2) + return wavs class SpeedPerturb(torch.nn.Module): """Slightly speed up or slow down an audio signal. @@ -1054,8 +1169,7 @@ class SpeedPerturb(torch.nn.Module): orig_freq : int The frequency of the original signal. speeds : list - The speeds that the signal should be changed to, as a percentage of the - original signal (i.e. `speeds` is divided by 100 to get a ratio). + A set of different speeds to use to perturb each batch. larger -> slower. perturb_prob : float The chance that the batch will be speed- perturbed. By default, every batch is perturbed. @@ -1064,27 +1178,52 @@ class SpeedPerturb(torch.nn.Module): """ def __init__( - self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0, keep_shape=True + self, orig_freq, speeds=[95, 100, 105], perturb_prob=1.0, keep_shape=True, perturb_type='resample', change_spk=False,spk_num=0 ): super().__init__() + assert perturb_type in ['resample','sox_speed','sox_tempo'] + self.orig_freq = orig_freq self.speeds = speeds self.perturb_prob = perturb_prob self.keep_shape = keep_shape + self.change_spk = change_spk - # Initialize index of perturbation - self.samp_index = 0 + if change_spk: + assert spk_num>0, "change_spk need total spk number." + + self.aug_spks = self._speed_to_speaker(speeds) + self.aug_spks = [spk_num*aug_spk for aug_spk in self.aug_spks] # Initialize resamplers - self.resamplers = [] + self.speeders = [] for speed in self.speeds: - config = { - "orig_freq": self.orig_freq, - "new_freq": self.orig_freq * speed // 100, - } - self.resamplers.append(Resample(**config)) + if perturb_type == 'resample': - def forward(self, waveform): + config = { + "orig_freq": self.orig_freq, + "new_freq": self.orig_freq * speed // 100, + } + self.speeders.append(Resample(**config)) + else: + + if perturb_type == 'sox_speed': + if speed==100: + effects = [[]] + else: + speed = round(100/speed,2) + effects = [['speed',str(speed)],['rate',str(orig_freq)]] + + elif perturb_type == 'sox_tempo': + if speed==100: + effects = [[]] + else: + speed = round(100/speed,2) + effects = [['tempo', str(speed)]] + else: + raise ValueError("unsupport perturb_type: {}".format(perturb_type)) + self.speeders.append(SoxEffectTransform(effects,orig_freq)) + def forward(self, waveform: torch.Tensor, spk_id: torch.Tensor=torch.ones((0), dtype=torch.long)): """ Arguments --------- @@ -1092,6 +1231,8 @@ def forward(self, waveform): Shape should be `[batch, time]` or `[batch, time, channels]`. lengths : tensor Shape should be a single dimension, `[batch]`. 
+ spk_id: tensor + Shape should be a single dimension, `[batch]` Returns ------- @@ -1100,11 +1241,13 @@ def forward(self, waveform): # Don't perturb (return early) 1-`perturb_prob` portion of the batches if torch.rand(1) > self.perturb_prob: - return waveform.clone() + return waveform.clone(),spk_id # Perform a random perturbation - self.samp_index = torch.randint(len(self.speeds), (1,))[0] - perturbed_waveform = self.resamplers[self.samp_index](waveform) + speed_index = torch.randint(len(self.speeds), (1,))[0] + perturbed_waveform = self.speeders[speed_index](waveform) + if self.change_spk: + spk_id = self.aug_spks[speed_index]+ spk_id if self.keep_shape: # Managing speed change @@ -1116,7 +1259,27 @@ def forward(self, waveform): zero_sig[:, 0: perturbed_waveform.shape[1] ] = perturbed_waveform perturbed_waveform = zero_sig - return perturbed_waveform + return perturbed_waveform,spk_id + + def get_spkid_aug(self): + sp_aug,spkid_aug =1,1 + if self.perturb_prob>0: + sp_aug = len(set(self.speeds)) + if self.change_spk: + spkid_aug=len(set(self.speeds)) + return spkid_aug,sp_aug + + def _speed_to_speaker(self,speeds): + assert 100 in speeds, "speed perturb with speaker aug need origin speed." + t = {} + spk_cont = 0 + for s in sorted(set(speeds),key = speeds.index): + if s ==100: + t[s]=0 + else: + spk_cont+=1 + t[s]=spk_cont + return [t[sp] for sp in speeds] class Resample(torch.nn.Module): @@ -1449,6 +1612,8 @@ class EnvCorrupt(torch.nn.Module): A prepared csv file for loading noise data, if None, means white noise. babble_csv : str A prepared csv file for loading babble data, if None, means simulated babble noise. + add_filt_min : float + Filt the short noises in when loading noises and babble from csv. noise_num_workers : int Number of workers to use for loading noises. babble_speaker_count : int @@ -1482,14 +1647,16 @@ def __init__( reverb_csv=None, noise_csv=None, babble_csv=None, + add_filt_min = None, babble_noise_max_len=2.0, noise_num_workers=0, babble_speaker_count=0, - babble_snr_low=0, - babble_snr_high=0, + babble_snr_low=13, + babble_snr_high=20, noise_snr_low=0, - noise_snr_high=0, - rir_scale_factor=1.0, + noise_snr_high=15, + pad_noise = False, + rir_scale_factor=1.0, **ops ): super().__init__() @@ -1506,23 +1673,27 @@ def __init__( self.add_babble = AddBabble( mix_prob=babble_prob, csv_file=babble_csv, + add_filt_min = add_filt_min, num_workers=noise_num_workers, speaker_count=babble_speaker_count, snr_low=babble_snr_low, snr_high=babble_snr_high, babble_noise_max_len=babble_noise_max_len, + pad_noise=pad_noise, ) if noise_prob > 0.0: self.add_noise = AddNoise( mix_prob=noise_prob, csv_file=noise_csv, + add_filt_min = add_filt_min, num_workers=noise_num_workers, snr_low=noise_snr_low, snr_high=noise_snr_high, + pad_noise=pad_noise, ) - def forward(self, waveforms, lengths): + def forward(self, waveforms, lengths,spk_id:torch.ones((0),dtype=torch.long)): """Returns the distorted waveforms. Arguments @@ -1543,7 +1714,7 @@ def forward(self, waveforms, lengths): if hasattr(self, "add_noise"): waveforms = self.add_noise(waveforms, lengths) - return waveforms + return waveforms,lengths, spk_id class TimeDomainSpecAugment(torch.nn.Module): @@ -1553,8 +1724,9 @@ class TimeDomainSpecAugment(torch.nn.Module): the time-domain. 1. Drop chunks of the audio (zero amplitude or white noise) - 2. Drop frequency bands (with band-drop filters) - 3. Speed peturbation (via resampling to slightly different rate) + 2. RandomChunk selection. + 3. 
Drop frequency bands (with band-drop filters) + 4. Speed peturbation (via resampling to slightly different rate) Arguments --------- @@ -1565,8 +1737,19 @@ class TimeDomainSpecAugment(torch.nn.Module): drop_chunk_prob : float from 0 to 1 The probability that a batch will have chunks dropped. speeds : list of ints - A set of different speeds to use to perturb each batch. - See ``speechbrain.processing.speech_augmentation.SpeedPerturb`` + A set of different speeds to use to perturb each batch. larger -> slower. + spk_num: int + The total speker num, for aug spkid if needed. + perturb_type: str + ['resample','sox_speed','sox_tempo'] + change_spk: bool + Whether aug spkid + keep_shape: bool + keep time length after speed perturb. + random_chunk: bool + random select chunks after speed perturb. + ramddom_chunsize: + random chunks length in seconds (s). sample_rate : int Sampling rate of the input waveforms. drop_freq_count_low : int @@ -1597,10 +1780,15 @@ class TimeDomainSpecAugment(torch.nn.Module): def __init__( self, perturb_prob=1.0, - drop_freq_prob=1.0, - drop_chunk_prob=1.0, - speeds=[95, 100, 105], + drop_freq_prob=0.0, + drop_chunk_prob=0.0, + speeds=[95, 100, 110], + spk_num=0, + perturb_type='resample', + change_spk=False, keep_shape=True, + random_chunk=False, + ramddom_chunsize=2.015, sample_rate=16000, drop_freq_count_low=0, drop_freq_count_high=3, @@ -1613,7 +1801,18 @@ def __init__( ): super().__init__() self.speed_perturb = SpeedPerturb( - perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds, keep_shape=keep_shape + perturb_prob=perturb_prob, + orig_freq=sample_rate, + speeds=speeds, + spk_num=spk_num, + perturb_type=perturb_type, + change_spk=change_spk, + keep_shape=keep_shape + ) + self.random_chunk = RandomChunk( + random_chunk = random_chunk, + chunk_len = ramddom_chunsize, + sample_rate = sample_rate ) self.drop_freq = DropFreq( drop_prob=drop_freq_prob, @@ -1629,7 +1828,7 @@ def __init__( noise_factor=drop_chunk_noise_factor, ) - def forward(self, waveforms, lengths): + def forward(self, waveforms, lengths, spk_id:torch.ones((0),dtype=torch.long)): """Returns the distorted waveforms. 
Arguments @@ -1639,11 +1838,18 @@ def forward(self, waveforms, lengths): """ # Augmentation with torch.no_grad(): - waveforms = self.speed_perturb(waveforms) + waveforms,spk_id = self.speed_perturb(waveforms,spk_id) + waveforms,lengths = self.random_chunk(waveforms,lengths) waveforms = self.drop_freq(waveforms) waveforms = self.drop_chunk(waveforms, lengths) - return waveforms + return waveforms,lengths,spk_id + + def get_spkid_aug(self): + + return self.speed_perturb.get_spkid_aug() + + class SpeechAug(torch.nn.Module): @@ -1675,16 +1881,17 @@ class SpeechAug(torch.nn.Module): signal, lens = speech_aug(signal,torch.ones(1)) """ - def __init__(self, aug_classes=[], mod="random"): + def __init__(self, spk_num=0,aug_classes=[], mod="random"): super().__init__() assert mod in ["random", "concat", "chain"] self.mod = mod self.augment = [] self.augment_name = [] - + self.spk_num=spk_num # define a weight of clean wav type random_weights = [1] if self.mod == 'random' else [] - + self.spkid_aug,self.sp_aug = 1,1 + spt_num=0 for aug_class in aug_classes: assert 'aug_type' in aug_class @@ -1709,59 +1916,76 @@ def __init__(self, aug_classes=[], mod="random"): if aug_type == 'Env': self.augment.append(EnvCorrupt(**aug_class)) if aug_type == 'Time': - self.augment.append(TimeDomainSpecAugment(**aug_class)) + td_aug = TimeDomainSpecAugment(spk_num=spk_num,**aug_class) + spkid_aug,sp_aug = td_aug.get_spkid_aug() + spt_num+=(int(spkid_aug>1)) + + if spt_num > 1: + raise ValueError("multi speaker id perturb setting, check your speech aug config") + if spkid_aug>1: + self.spkid_aug = spkid_aug + self.augment.append(td_aug) + + self.random_weight = torch.tensor( random_weights, dtype=torch.float) if random_weights else None - self.get_augment() + self.print_augment() - def forward(self, waveforms, lengths): + def forward(self, waveforms, lengths, spkid: torch.Tensor=torch.ones((0),dtype=torch.long)): if not self.augment: - return waveforms, lengths + return waveforms, lengths, spkid if self.mod == 'random': - return self._random_forward(waveforms, lengths) + return self._random_forward(waveforms, lengths,spkid) elif self.mod == 'chain': - return self._chain_forward(waveforms, lengths) + return self._chain_forward(waveforms, lengths,spkid) else: - return self._concat_forward(waveforms, lengths) + return self._concat_forward(waveforms, lengths,spkid) - def _random_forward(self, waveforms, lengths): + def _random_forward(self, waveforms, lengths,spkid): aug_idx = torch.multinomial(self.random_weight, 1)[0] if aug_idx == 0: - return waveforms, lengths + return waveforms, lengths,spkid else: - waves=self.augment[aug_idx-1](waveforms, lengths) + waves,lengths,spkid=self.augment[aug_idx-1](waveforms, lengths,spkid) if(torch.any((torch.isnan(waves)))): raise ValueError('random_1:{},type:{},typename:{}'.format(waveforms,self.augment[aug_idx-1],self.augment_name[aug_idx-1])) - return waves, lengths + return waves, lengths, spkid - def _concat_forward(self, waveforms, lengths): + def _concat_forward(self, waveforms, lengths,spkid): wavs_aug_tot = [] + spkids=[] + lens = [] wavs_aug_tot.append(waveforms.clone()) + spkids.append(spkid) + lens.append(lengths) for count, augment in enumerate(self.augment): - wavs_aug = augment(waveforms, lengths) + wavs_aug,len,spkid_a = augment(waveforms, lengths,spkid) if(torch.any((torch.isnan(wavs_aug)))): raise ValueError('concat:{},type:{},typename:{}'.format(waveforms,self.augment[count],self.augment_name[count])) wavs_aug_tot.append(wavs_aug) + spkids.append(spkid_a) + 
lens.append(len) waveforms = torch.cat(wavs_aug_tot, dim=0) - lengths = torch.cat([lengths] * len(wavs_aug_tot)) - return waveforms, lengths + lens = torch.cat(lens) + spkids = torch.cat(spkids, dim=0) + return waveforms, lens, spkids - def _chain_forward(self, waveforms, lengths): + def _chain_forward(self, waveforms, lengths,spkid): for count, augment in enumerate(self.augment): - waveforms = augment(waveforms, lengths) + waveforms,lengths,spkid = augment(waveforms, lengths,spkid) if(torch.any((torch.isnan(waveforms)))): raise ValueError('chian:{},type:{},typename:{}'.format(waveforms,self.augment[count],self.augment_name[count])) - return waveforms, lengths + return waveforms, lengths,spkid def get_num_concat(self): @@ -1770,7 +1994,13 @@ def get_num_concat(self): else: return 1 - def get_augment(self): + def get_spkid_aug(self): + + return self.spkid_aug + + + + def print_augment(self): if self.augment: print('speech augment type is {}.'.format(self.mod)) aug_dict=dict(zip(self.augment_name,self.augment)) diff --git a/pytorch/libs/nnet/activation.py b/pytorch/libs/nnet/activation.py index 5c91ccb..95c2e8f 100644 --- a/pytorch/libs/nnet/activation.py +++ b/pytorch/libs/nnet/activation.py @@ -20,6 +20,51 @@ def __init__(self): def forward(self, inputs): return inputs * torch.tanh(F.softplus(inputs)) +class Swish(torch.nn.Module): + """Construct an Swish object.""" + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y + + @staticmethod + def backward(ctx, y_grad: torch.Tensor) -> torch.Tensor: + s, y = ctx.saved_tensors + return (y * (1 - s) + s) * y_grad + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + else: + return DoubleSwishFunction.apply(x) + ## Wrapper ✿ def Nonlinearity(nonlinearity="relu", inplace=True, negative_slope=0.01): """A wrapper for activation. 
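The derivation in the DoubleSwishFunction docstring above can be checked numerically; a small standalone snippet (illustrative, not part of the patch) that verifies d/dx [x * sigmoid(x - 1)] == y * (1 - s) + s with plain autograd:

    import torch

    # Verify the derivative identity used by DoubleSwishFunction.backward:
    # with s = sigmoid(x - 1) and y = x * s, the gradient is y * (1 - s) + s.
    x = torch.linspace(-4.0, 4.0, steps=17, dtype=torch.double, requires_grad=True)
    s = torch.sigmoid(x - 1.0)
    y = x * s
    (grad,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(grad, (y * (1.0 - s) + s).detach())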
@@ -34,6 +79,13 @@ def Nonlinearity(nonlinearity="relu", inplace=True, negative_slope=0.01): activation = Mish() elif nonlinearity == 'tanh' : activation = torch.nn.Tanh() + elif nonlinearity == 'swish' : + func = getattr(torch.nn, "SiLU", Swish) + activation = func() + elif nonlinearity == 'gelu' : + activation = torch.nn.GELU(inplace=inplace) + elif nonlinearity == 'double_swish' : + activation = DoubleSwish() elif nonlinearity == "" or nonlinearity is None or nonlinearity == False: activation = None else: diff --git a/pytorch/libs/nnet/components.py b/pytorch/libs/nnet/components.py index eab1cf0..696ba47 100755 --- a/pytorch/libs/nnet/components.py +++ b/pytorch/libs/nnet/components.py @@ -12,7 +12,7 @@ from libs.support.utils import to_device import libs.support.utils as utils - +from libs.nnet.transformer.layer_norm import LayerNorm ### There are some basic custom components/layers. ### @@ -350,6 +350,7 @@ def add_relu_bn(self, output_dim=None, options:dict={}): "nonlinearity":'relu', "nonlinearity_params":{"inplace":True, "negative_slope":0.01}, "bn":True, + "ln_replace": False, "bn_params":{"momentum":0.1, "affine":True, "track_running_stats":True}, "special_init":True, "mode":'fan_out', @@ -369,13 +370,19 @@ def add_relu_bn(self, output_dim=None, options:dict={}): self.bn_relu = False self.activation = Nonlinearity(default_params["nonlinearity"], **default_params["nonlinearity_params"]) if default_params["bn"]: - self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"]) + if not default_params['ln_replace']: + self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"]) + else: + self.batchnorm = LayerNorm(output_dim,dim=1,eps=1e-5,learnabel_affine=default_params["bn_params"]["affine"]) else: # BN-ReLU # self.after_forward = self._bn_relu_forward self.bn_relu = True if default_params["bn"]: - self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"]) + if not default_params['ln_replace']: + self.batchnorm = torch.nn.BatchNorm1d(output_dim, **default_params["bn_params"]) + else: + self.batchnorm = LayerNorm(output_dim,dim=1,eps=1e-5) self.activation = Nonlinearity(default_params["nonlinearity"], **default_params["nonlinearity_params"]) if default_params["special_init"] and self.affine is not None and not default_params["jit_compile"]: @@ -463,7 +470,7 @@ class ReluBatchNormTdnnfLayer(_BaseActivationBatchNorm): def __init__(self, input_dim, output_dim, inner_size, context_size = 0, **options): super(ReluBatchNormTdnnfLayer, self).__init__() - self.affine = FTdnnBlock(input_dim, output_dim, inner_size, context_size) + self.affine = TdnnfBlock(input_dim, output_dim, inner_size, context_size) self.add_relu_bn(output_dim, options=options) @@ -842,4 +849,17 @@ def to(self, device): return self - +# Leo 2022-08-14 +class LabelSmoothing(torch.nn.Module): + """Label-smoothing for target tensor. 
+ + """ + def __init__( + self, + mean_norm=True, + std_norm=True, + ): + super().__init__() + self.mean_norm = mean_norm + self.std_norm = std_norm + self.eps = 1e-10 \ No newline at end of file diff --git a/pytorch/libs/nnet/loss.py b/pytorch/libs/nnet/loss.py index 895dfde..9f50c75 100644 --- a/pytorch/libs/nnet/loss.py +++ b/pytorch/libs/nnet/loss.py @@ -3,7 +3,7 @@ # Copyright xmuspeech (Author: Snowdar 2019-05-29) import numpy as np - +import math import torch import torch.nn.functional as F @@ -88,11 +88,11 @@ def compute_accuracy(self, outputs, targets): class SoftmaxLoss(TopVirtualLoss): """ An usual log-softmax loss with affine component. """ - def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False): + def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False,label_smoothing = 0.0): self.affine = TdnnAffine(input_dim, num_targets) self.t = t # temperature # CrossEntropyLoss() has included the LogSoftmax, so do not add this function extra. - self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction) + self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing=label_smoothing) # The special_init is not recommended in this loss component if special_init : @@ -120,12 +120,12 @@ class SoftmaxLoss_frame_phone_fix(TopVirtualLoss): #Zheng Li 2021-06-08 """ An usual log-softmax loss with affine component. """ - def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False): + def init(self, input_dim, num_targets, t=1, reduction='mean', special_init=False,label_smoothing = 0.0): self.affine = TdnnAffine(input_dim, num_targets) self.t = t # temperature self.num_phones = num_targets - 1 # CrossEntropyLoss() has included the LogSoftmax, so do not add this function extra. - self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction) + self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing=label_smoothing) # The special_init is not recommended in this loss component if special_init : @@ -187,6 +187,7 @@ def forward(self, inputs, targets): return self.loss_function(outputs, targets) + class MarginSoftmaxLoss(TopVirtualLoss): """Margin softmax loss. There are AM, AAM, Double-AM, SM1 (Snowdar Margin softmax loss), SM2 and SM3. @@ -225,12 +226,13 @@ def init(self, input_dim, num_targets, inter_loss=0., ring_loss=0., curricular=False, - reduction='mean', eps=1.0e-10, init=True): + reduction='mean', eps=1.0e-10, scale_init=False,label_smoothing = 0.0): self.input_dim = input_dim self.num_targets = num_targets self.weight = torch.nn.Parameter(torch.randn(num_targets, input_dim, 1)) self.s = s # scale factor with feature normalization + self.init_m = m self.m = m # margin self.t = t # temperature self.feature_normalize = feature_normalize @@ -241,8 +243,7 @@ def init(self, input_dim, num_targets, self.mhe_w = mhe_w self.inter_loss = inter_loss self.ring_loss = ring_loss - self.lambda_factor = 0 - + self.lambda_m = 1 self.curricular = CurricularMarginComponent() if curricular else None if self.ring_loss > 0: @@ -260,112 +261,207 @@ def init(self, input_dim, num_targets, There are some suggested s : {suggested_s} w.r.t p_target {p_target}.".format( s=self.s, suggested_s=suggested_s, p_target=p_target)) - self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction) - + self.loss_function = torch.nn.CrossEntropyLoss(reduction=reduction,label_smoothing = label_smoothing) + self.weight_scale =None + self.scale_init=scale_init # Init weight. 
- if init: + if scale_init: + initial_scale = torch.tensor(1.0).log() + self.weight_scale = torch.nn.Parameter(initial_scale.clone().detach()) + std = 0.1 + a = (3 ** 0.5) * std + torch.nn.init.uniform_(self.weight, -a, a) + fan_in = self.weight.shape[1] *self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + else: # torch.nn.init.xavier_normal_(self.weight, gain=1.0) torch.nn.init.normal_(self.weight, 0., 0.01) # It seems better. + def get_weight(self): + if self.scale_init: + return self.weight * self.weight_scale.exp() + else: + return self.weight def forward(self, inputs, targets): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ assert len(inputs.shape) == 3 assert inputs.shape[2] == 1 - + # with torch.cuda.amp.autocast(enabled=False): ## Normalize - normalized_x = F.normalize(inputs.squeeze(dim=2), dim=1) # [512, 512] [batch_size, output_dim] - normalized_weight = F.normalize(self.weight.squeeze(dim=2), dim=1) # [1000, 512] [target_num, output_dim] - cosine_theta = F.linear(normalized_x, normalized_weight) # Y = W*X - - if not self.feature_normalize : - self.s = inputs.norm(2, dim=1) # [batch-size, l2-norm] - # The accuracy must be reported before margin penalty added - self.posterior = (self.s.detach() * cosine_theta.detach()).unsqueeze(2) - else: - self.posterior = (self.s * cosine_theta.detach()).unsqueeze(2) - - if not self.training: - # For valid set. - outputs = self.s * cosine_theta - return self.loss_function(outputs, targets) - ## Margin Penalty - # cosine_theta [batch_size, num_class] - # targets.unsqueeze(1) [batch_size, 1] - cosine_theta_target = cosine_theta.gather(1, targets.unsqueeze(1)) - - if self.inter_loss > 0: - inter_cosine_theta = torch.softmax(self.s * cosine_theta, dim=1) - inter_cosine_theta_target = inter_cosine_theta.gather(1, targets.unsqueeze(1)) - inter_loss = torch.log((inter_cosine_theta.sum(dim=1) - inter_cosine_theta_target)/(self.num_targets - 1) + self.eps).mean() + normalized_x = F.normalize(inputs.squeeze(dim=2), dim=1) # [512, 512] [batch_size, output_dim] + normalized_weight = F.normalize(self.get_weight().squeeze(dim=2), dim=1) # [1000, 512] [target_num, output_dim] + + with torch.cuda.amp.autocast(enabled=False): + cosine_theta = F.linear(normalized_x, normalized_weight) # Y = W*X + + if not self.feature_normalize : + self.s = inputs.norm(2, dim=1) # [batch-size, l2-norm] + # The accuracy must be reported before margin penalty added + self.posterior = (self.s.detach() * cosine_theta.detach()).unsqueeze(2) + else: + self.posterior = (self.s * cosine_theta.detach()).unsqueeze(2) + + if not self.training: + # For valid set. 
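The `scale_init` branch above keeps the stored weight at a small fixed std and folds the desired magnitude into a log-domain scalar, so `get_weight()` returns `weight * weight_scale.exp()`. A simplified 2-D sketch of that reparameterization (the repo's weight is 3-D with a trailing singleton dim; the numbers below follow the same arithmetic):

```python
import math
import torch

fan_in, num_targets = 256, 1000
std = 0.1
a = math.sqrt(3) * std                           # uniform(-a, a) has std ≈ 0.1
weight = torch.empty(num_targets, fan_in).uniform_(-a, a)

# log-domain scale: start at log(1), then shift so the *effective* weight
# ends up with std ≈ 1/sqrt(fan_in), as in the scale_init branch above
weight_scale = torch.tensor(1.0).log() + math.log((fan_in ** -0.5) / std)

effective = weight * weight_scale.exp()
print(weight.std().item())       # ≈ 0.1, what the optimizer sees
print(effective.std().item())    # ≈ 1/sqrt(256) ≈ 0.0625, what the forward pass uses
```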
+ outputs = self.s * cosine_theta + return self.loss_function(outputs, targets) + + ## Margin Penalty + # cosine_theta [batch_size, num_class] + # targets.unsqueeze(1) [batch_size, 1] + cosine_theta_target = cosine_theta.gather(1, targets.unsqueeze(1)) + + if self.inter_loss > 0: + inter_cosine_theta = torch.softmax(self.s * cosine_theta, dim=1) + inter_cosine_theta_target = inter_cosine_theta.gather(1, targets.unsqueeze(1)) + inter_loss = torch.log((inter_cosine_theta.sum(dim=1) - inter_cosine_theta_target)/(self.num_targets - 1) + self.eps).mean() + + if self.method == "am": + penalty_cosine_theta = cosine_theta_target - self.m + if self.double: + double_cosine_theta = cosine_theta + self.m + elif self.method == "aam": + # Another implementation w.r.t cosine(theta+m) = cosine_theta * cos_m - sin_theta * sin_m + # penalty_cosine_theta = self.cos_m * cosine_theta_target - self.sin_m * torch.sqrt((1-cosine_theta_target**2).clamp(min=0.)) + penalty_cosine_theta = torch.cos(torch.acos(cosine_theta_target) + self.m) + if self.double: + double_cosine_theta = torch.cos(torch.acos(cosine_theta).add(-self.m)) + elif self.method == "sm1": + # penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target) * self.m + penalty_cosine_theta = (1 + self.m) * cosine_theta_target - self.m + elif self.method == "sm2": + penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target**2) * self.m + elif self.method == "sm3": + penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target)**2 * self.m + else: + raise ValueError("Do not support this {0} margin w.r.t [ am | aam | sm1 | sm2 | sm3 ]".format(self.method)) + + penalty_cosine_theta = self.lambda_m * penalty_cosine_theta + \ + (1 - self.lambda_m) * cosine_theta_target - if self.method == "am": - penalty_cosine_theta = cosine_theta_target - self.m - if self.double: - double_cosine_theta = cosine_theta + self.m - elif self.method == "aam": - # Another implementation w.r.t cosine(theta+m) = cosine_theta * cos_m - sin_theta * sin_m - # penalty_cosine_theta = self.cos_m * cosine_theta_target - self.sin_m * torch.sqrt((1-cosine_theta_target**2).clamp(min=0.)) - penalty_cosine_theta = torch.cos(torch.acos(cosine_theta_target) + self.m) if self.double: - double_cosine_theta = torch.cos(torch.acos(cosine_theta).add(-self.m)) - elif self.method == "sm1": - # penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target) * self.m - penalty_cosine_theta = (1 + self.m) * cosine_theta_target - self.m - elif self.method == "sm2": - penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target**2) * self.m - elif self.method == "sm3": - penalty_cosine_theta = cosine_theta_target - (1 - cosine_theta_target)**2 * self.m - else: - raise ValueError("Do not support this {0} margin w.r.t [ am | aam | sm1 | sm2 | sm3 ]".format(self.method)) - - penalty_cosine_theta = 1 / (1 + self.lambda_factor) * penalty_cosine_theta + \ - self.lambda_factor / (1 + self.lambda_factor) * cosine_theta_target - - if self.double: - cosine_theta = 1/(1+self.lambda_factor) * double_cosine_theta + self.lambda_factor/(1+self.lambda_factor) * cosine_theta - - if self.curricular is not None: - cosine_theta = self.curricular(cosine_theta, cosine_theta_target, penalty_cosine_theta) - - outputs = self.s * cosine_theta.scatter(1, targets.unsqueeze(1), penalty_cosine_theta) + cosine_theta = self.lambda_m * double_cosine_theta + (1 - self.lambda_m) * cosine_theta + + if self.curricular is not None: + cosine_theta = self.curricular(cosine_theta, cosine_theta_target, 
penalty_cosine_theta) + + # cosine_theta = cosine_theta.float() + # penalty_cosine_theta = penalty_cosine_theta.float() + + outputs = self.s * cosine_theta.scatter(1, targets.unsqueeze(1), penalty_cosine_theta) + + ## Other extra loss + # Final reported loss will be always higher than softmax loss for the absolute margin penalty and + # it is a lie about why we can not decrease the loss to a mininum value. We should not report the + # loss after margin penalty did but we really report this invalid loss to avoid computing the + # training loss twice. + + if self.ring_loss > 0: + ring_loss = torch.mean((self.s - self.r)**2)/2 + else: + ring_loss = 0. + + if self.mhe_loss: + sub_weight = normalized_weight - torch.index_select(normalized_weight, 0, targets).unsqueeze(dim=1) + # [N, C] + normed_sub_weight = sub_weight.norm(2, dim=2) + mask = torch.full_like(normed_sub_weight, True, dtype=torch.bool).scatter_(1, targets.unsqueeze(dim=1), False) + # [N, C-1] + normed_sub_weight_clean = torch.masked_select(normed_sub_weight, mask).reshape(targets.size()[0], -1) + # torch.mean means 1/(N*(C-1)) + the_mhe_loss = self.mhe_w * torch.mean((normed_sub_weight_clean**2).clamp(min=self.eps)**-1) + + return self.loss_function(outputs/self.t, targets) + the_mhe_loss + self.ring_loss * ring_loss + elif self.inter_loss > 0: + return self.loss_function(outputs/self.t, targets) + self.inter_loss * inter_loss + self.ring_loss * ring_loss + else: + return self.loss_function(outputs/self.t, targets) + self.ring_loss * ring_loss + + def step(self, lambda_m, add_margin=None): + self.lambda_m = lambda_m + if add_margin is not None: + self.m = max(0,self.init_m+add_margin) - ## Other extra loss - # Final reported loss will be always higher than softmax loss for the absolute margin penalty and - # it is a lie about why we can not decrease the loss to a mininum value. We should not report the - # loss after margin penalty did but we really report this invalid loss to avoid computing the - # training loss twice. - if self.ring_loss > 0: - ring_loss = torch.mean((self.s - self.r)**2)/2 + def extra_repr(self): + return '(~affine): (input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \ + 'margin={m}, s={s}, t={t}, feature_normalize={feature_normalize}, mhe_loss={mhe_loss}, mhe_w={mhe_w}, ' \ + 'eps={eps}, scale_init={scale_init})'.format(**self.__dict__) + + +# Leo 2022-11-08 +class MarginWarm(torch.nn.Module): + def __init__(self, + start_epoch, + end_epoch, + offset_margin = 0.0, + init_lambda = 1.0, + epoch_iter=None): + ''' + between start_epoch and end_epoch, the offset_margin is + exponentially increasing from offset_margin (usually negative) to 0. + And the lambda_t is linearly increasing from init_lambda to 1. 
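The margin branches in MarginSoftmaxLoss above all act on the target-class cosine only, and `lambda_m` blends the penalized logit with the raw one. A small numeric sketch of the AM/AAM/SM1 penalties and the blend (standalone arithmetic, not the repo's class):

```python
import torch

cos_t = torch.tensor(0.8)       # cosine of the target class
m = 0.2                         # margin

am  = cos_t - m                                  # additive margin: cos(theta) - m
aam = torch.cos(torch.acos(cos_t) + m)           # additive angular margin: cos(theta + m)
sm1 = (1 + m) * cos_t - m                        # the "sm1" variant above

print(am.item(), aam.item(), sm1.item())         # ≈ 0.60, 0.66, 0.76

# The effective target logit is a blend controlled by lambda_m (see MarginWarm below):
for lam in (0.0, 0.5, 1.0):
    blended = lam * aam + (1 - lam) * cos_t
    print(lam, round(blended.item(), 3))         # full margin only when lambda_m == 1
```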
+ It is designed to control the margin_softmaxloss through `margin + offset_margin` and + `penalty_cosine_theta = lambda * penalty_cosine_theta + (1 - lambda) * cosine_theta_target` + ''' + super().__init__() + if end_epoch < start_epoch: + raise ValueError("End_epoch should not smaller then start_epoch, but got end_epoch: {}, start_epoch:{}" + .format(end_epoch,start_epoch)) + assert abs(init_lambda - 0.5)<=0.5,"init_lambda should be in [0, 1]" + self.start_epoch = start_epoch + self.end_epoch = end_epoch + self.offset_margin = offset_margin + self.init_lambda = init_lambda + self.epoch_iter = epoch_iter + if epoch_iter: + self.update_step_range(epoch_iter) + + + def update_step_range(self,epoch_iter,overwrite=False): + if not overwrite and self.epoch_iter: + raise ValueError("epoch_iter has been set as {}, but overwrite = {} now".format(self.epoch_iter,overwrite)) else: - ring_loss = 0. - - if self.mhe_loss: - sub_weight = normalized_weight - torch.index_select(normalized_weight, 0, targets).unsqueeze(dim=1) - # [N, C] - normed_sub_weight = sub_weight.norm(2, dim=2) - mask = torch.full_like(normed_sub_weight, True, dtype=torch.bool).scatter_(1, targets.unsqueeze(dim=1), False) - # [N, C-1] - normed_sub_weight_clean = torch.masked_select(normed_sub_weight, mask).reshape(targets.size()[0], -1) - # torch.mean means 1/(N*(C-1)) - the_mhe_loss = self.mhe_w * torch.mean((normed_sub_weight_clean**2).clamp(min=self.eps)**-1) - - return self.loss_function(outputs/self.t, targets) + the_mhe_loss + self.ring_loss * ring_loss - elif self.inter_loss > 0: - return self.loss_function(outputs/self.t, targets) + self.inter_loss * inter_loss + self.ring_loss * ring_loss + self.epoch_iter = epoch_iter + self.increase_start_iter = (self.start_epoch - 1) * epoch_iter + self.fix_start_iter = (self.end_epoch - 1) * epoch_iter + self.step_range = self.fix_start_iter - self.increase_start_iter + + def get_increase_margin(self, cur_step): + initial_val = 1.0 + final_val = 1e-3 + + cur_pos = cur_step - self.increase_start_iter + + ratio = math.exp( + (cur_pos / self.step_range) * + math.log(final_val / (initial_val + 1e-6))) * initial_val + offset_margin = self.offset_margin * ratio + lambda_t = self.init_lambda + (cur_pos / self.step_range) * (1-self.init_lambda) + + return offset_margin, lambda_t + + def step(self, cur_step): + if self.epoch_iter<0 or not isinstance(self.epoch_iter, int): + raise ValueError("MarginWarm expected positive integer epoch_iter, but got {}" + .format(self.epoch_iter)) + if cur_step >= self.fix_start_iter: + return 0,1 + elif cur_step > self.increase_start_iter: + return self.get_increase_margin(cur_step) else: - return self.loss_function(outputs/self.t, targets) + self.ring_loss * ring_loss - - def step(self, lambda_factor): - self.lambda_factor = lambda_factor + return self.offset_margin,self.init_lambda def extra_repr(self): - return '(~affine): (input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \ - 'margin={m}, s={s}, t={t}, feature_normalize={feature_normalize}, mhe_loss={mhe_loss}, mhe_w={mhe_w}, ' \ - 'eps={eps})'.format(**self.__dict__) + return 'start_epoch={start_epoch}, end_epoch={end_epoch}, epoch_iter={epoch_iter}, ' \ + 'offset_margin={offset_margin} init_lambda={init_lambda})'.format(**self.__dict__) + class CurricularMarginComponent(torch.nn.Module): diff --git a/pytorch/libs/nnet/pooling.py b/pytorch/libs/nnet/pooling.py index 117bdc4..4fd119c 100644 --- a/pytorch/libs/nnet/pooling.py +++ b/pytorch/libs/nnet/pooling.py @@ -3,7 +3,7 
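MarginWarm above shrinks the (usually negative) margin offset exponentially toward 0 and raises lambda linearly toward 1 between start_epoch and end_epoch, measured in iterations. A hedged re-implementation of just that arithmetic, with hypothetical default values, to show the typical trajectory:

```python
import math

def margin_warm(cur_step, start_epoch=1, end_epoch=15, epoch_iter=1000,
                offset_margin=-0.2, init_lambda=0.0):
    """Mirror of the schedule above: returns (offset_margin, lambda_t) for one step."""
    inc_start = (start_epoch - 1) * epoch_iter
    fix_start = (end_epoch - 1) * epoch_iter
    if cur_step >= fix_start:
        return 0.0, 1.0
    if cur_step <= inc_start:
        return offset_margin, init_lambda
    pos, rng = cur_step - inc_start, fix_start - inc_start
    ratio = math.exp((pos / rng) * math.log(1e-3 / (1.0 + 1e-6)))   # decays 1.0 -> 1e-3
    return offset_margin * ratio, init_lambda + (pos / rng) * (1 - init_lambda)

for step in (0, 3500, 7000, 10500, 14000):
    off, lam = margin_warm(step)
    print(step, round(off, 4), round(lam, 2))
# The margin offset shrinks toward 0 while lambda grows toward 1, i.e. training
# starts close to plain softmax and ends at the full-margin objective.
```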
@@ # Copyright xmuspeech (Author: Snowdar 2019-05-29 2020-06-10) import numpy as np - +from typing import Optional import torch import torch.nn.functional as F @@ -29,32 +29,45 @@ def __init__(self, input_dim, stddev=True, unbiased=False, eps=1.0e-10): # Used for unbiased estimate of stddev self.unbiased = unbiased - def forward(self, inputs): + def forward(self, inputs, lengths:torch.Tensor = torch.ones((0),dtype=torch.long)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ assert len(inputs.shape) == 3 assert inputs.shape[1] == self.input_dim - # Get the num of frames - counts = inputs.shape[2] - - mean = inputs.sum(dim=2, keepdim=True) / counts - - if self.stddev : - if self.unbiased and counts > 1: - counts = counts - 1 - - # The sqrt (as follows) is deprecated because it results in Nan problem. - # std = torch.unsqueeze(torch.sqrt(torch.sum((inputs - mean)**2, dim=2) / counts), dim=2) - # There is a eps to solve this problem. - # Another method: Var is equal to std in "cat" way, actually. So, just use Var directly. - - var = torch.sum((inputs - mean)**2, dim=2, keepdim=True) / counts - std = torch.sqrt(var.clamp(min=self.eps)) - return torch.cat((mean, std), dim=1) + if lengths.size(0) > 0: + mean = [] + std = [] + for i in range(inputs.shape[0]): + act_len = lengths[i] + act_counts = act_len + mean_i = torch.mean(inputs[i,:,:act_len],dim=1,keepdim=True) + mean.append(mean_i) + if self.stddev : + if self.unbiased and act_len > 1: + act_counts = act_len - 1 + var = torch.sum((inputs[i,:,:act_len]-mean_i)**2, dim=1, keepdim=True)/act_counts + std.append(torch.sqrt(var.clamp(min=self.eps))) + + mean = torch.stack(mean) + out = mean + if self.stddev : + std_o = torch.stack(std) + out = torch.cat((mean, std_o), dim=1) else: - return mean + counts = inputs.shape[2] + mean = inputs.mean(dim=2, keepdim=True) + out = mean + if self.stddev: + if self.unbiased and counts > 1: + counts = counts - 1 + var = torch.sum((inputs - mean)**2, dim=2, keepdim=True) / counts + std = torch.sqrt(var.clamp(min=self.eps)) + out = torch.cat((mean, std), dim=1) + + return out + def get_output_dim(self): return self.output_dim diff --git a/pytorch/libs/nnet/transformer/__init__.py b/pytorch/libs/nnet/transformer/__init__.py index 79abb45..0dee3d0 100644 --- a/pytorch/libs/nnet/transformer/__init__.py +++ b/pytorch/libs/nnet/transformer/__init__.py @@ -3,9 +3,8 @@ # Copyright xmuspeech (Author: Snowdar 2020-02-21) -__all__ = ["TransformerEncoder", "repeat", "MultiSequential", "MultiHeadedAttention", "LayerNorm"] +__all__ = ["TransformerEncoder", "ConformerEncoder", "MultiHeadedAttention", "LayerNorm","ReConformerEncoder"] -from .repeat import * from .attention import * from .layer_norm import * -from .TransformerEncoder import * \ No newline at end of file +from .encoder import * \ No newline at end of file diff --git a/pytorch/libs/nnet/transformer/attention.py b/pytorch/libs/nnet/transformer/attention.py index 6402b63..7e110c8 100644 --- a/pytorch/libs/nnet/transformer/attention.py +++ b/pytorch/libs/nnet/transformer/attention.py @@ -1,45 +1,75 @@ # -*- coding:utf-8 -*- +# Copyright xmuspeech (Author: Leo 2022-07) +# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/attention.py -# Reference: https://github.com/espnet/espnet. 
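The new `lengths` argument in the length-aware StatisticsPooling above makes mean/std ignore padded frames. The per-utterance loop can also be written as a masked tensor computation; a sketch (equivalent in spirit, not the repo's code) for (batch, dim, frames) inputs:

```python
import torch

def masked_mean_std(x: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-10):
    """x: (B, D, T); lengths: (B,) valid frame counts. Returns (B, 2*D, 1)."""
    B, D, T = x.shape
    mask = (torch.arange(T, device=x.device)[None, :] < lengths[:, None]).float()
    mask = mask.unsqueeze(1)                                 # (B, 1, T)
    counts = lengths.view(B, 1, 1).float()
    mean = (x * mask).sum(dim=2, keepdim=True) / counts
    var = (((x - mean) * mask) ** 2).sum(dim=2, keepdim=True) / counts
    std = var.clamp(min=eps).sqrt()
    return torch.cat([mean, std], dim=1)

x = torch.randn(2, 80, 200)
lengths = torch.tensor([200, 150])            # second utterance has 50 padded frames
print(masked_mean_std(x, lengths).shape)      # torch.Size([2, 160, 1])
```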
-import math +"""Multi-Head Attention layer definition.""" -import numpy +from cmath import pi +import math +from typing import Optional, Tuple, Dict, Any import torch from torch import nn - +from libs.nnet.activation import Nonlinearity +from .scaling import ScaledLinear,ScaledConv1d,ActivationBalancer class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer + """Multi-Head Attention layer. - :param int n_head: the number of head s - :param int n_feat: the number of features - :param float dropout_rate: dropout rate + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix. + t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers. """ - - def __init__(self, n_head, n_feat, dropout_rate): - super(MultiHeadedAttention, self).__init__() + def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False,attention_norm_args: Optional[Dict[str, Any]]=None,re_scale=False): + """Construct an MultiHeadedAttention object.""" + super().__init__() assert n_feat % n_head == 0 # We assume d_v always equals d_k self.d_k = n_feat // n_head self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None + + self.conv_out = conv_out + + self.att_norm = AttentionNormalize(self.d_k, att_type = "mlh", **attention_norm_args) + + project = ScaledLinear if re_scale else nn.Linear + + self.linear_q = project(n_feat, n_feat) + self.linear_k = project(n_feat, n_feat) + self.linear_v = project(n_feat, n_feat) + if self.conv_out: + conv = ScaledConv1d if re_scale else nn.Conv1d + self.linear_out = conv(n_feat, n_feat, 3, + stride=1, padding=1) + else: + self.linear_out = ScaledLinear(n_feat, n_feat,initial_scale=0.25) if re_scale else nn.Linear(n_feat, n_feat) self.dropout = nn.Dropout(p=dropout_rate) - def forward(self, query, key, value, mask): - """Compute 'Scaled Dot Product Attention' + self.add_t5rel_bias = add_t5rel_bias + + self.t5rel_module = T5RelPositionBias(self.d_k**0.5) if self.add_t5rel_bias else None + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). 
- :param torch.Tensor query: (batch, time1, size) - :param torch.Tensor key: (batch, time2, size) - :param torch.Tensor value: (batch, time2, size) - :param torch.Tensor mask: (batch, time1, time2) - :param torch.nn.Dropout dropout: - :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model) - weighted by the query dot key attention (batch, head, time1, time2) """ n_batch = query.size(0) q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) @@ -49,16 +79,655 @@ def forward(self, query, key, value, mask): k = k.transpose(1, 2) # (batch, head, time2, d_k) v = v.transpose(1, 2) # (batch, head, time2, d_k) - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # (batch, head, time1, time2) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) - min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + return q, k, v + + + + + + def forward_attention(self, value: torch.Tensor, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + attn = self.att_norm(scores,mask) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + if self.conv_out: + x = self.linear_out(x.transpose(-1, 1)).transpose(-1, 1) + else: + x = self.linear_out(x) + + return x # (batch, time1, d_model) + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0)) -> torch.Tensor: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + Wenet. + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + + scores = torch.matmul(q, k.transpose(-2, -1)) + if self.add_t5rel_bias and self.t5rel_module is not None: + scores+=self.t5rel_module(scores) + + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. 
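forward_qkv above is pure shape bookkeeping: project, split into heads, transpose so the dot-product runs per head. A minimal sketch of that reshaping (note that in the rewritten code the 1/sqrt(d_k) scaling has moved into AttentionNormalize rather than the matmul; it is kept inline here only for readability):

```python
import torch

batch, time, n_feat, n_head = 2, 50, 256, 4
d_k = n_feat // n_head

proj_q = torch.nn.Linear(n_feat, n_feat)
proj_k = torch.nn.Linear(n_feat, n_feat)
x = torch.randn(batch, time, n_feat)

q = proj_q(x).view(batch, -1, n_head, d_k).transpose(1, 2)   # (batch, head, time, d_k)
k = proj_k(x).view(batch, -1, n_head, d_k).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5   # (batch, head, time1, time2)
print(scores.shape)                                          # torch.Size([2, 4, 50, 50])
```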
+ Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix. + t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers. + """ + def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False,attention_norm_args: Optional[Dict[str, Any]]=None,re_scale=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,re_scale) + # linear transformation for positional encoding + project = ScaledLinear if re_scale else nn.Linear + self.linear_pos = project(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, mask: torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0)): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. 
+ # matrix_bd = self.rel_shift(matrix_bd) + + scores = matrix_ac + matrix_bd # (batch, head, time1, time2) + if self.add_t5rel_bias and self.t5rel_module is not None: + scores+=self.t5rel_module(scores) + + return self.forward_attention(v, scores, mask) + +# RoPE (Leo 2022-07-25) +# reference: +# RoFormer: Enhanced Transformer with Rotary Position Embedding. +class RoPESelfAttention(MultiHeadedAttention): + def __init__(self, n_head: int, n_feat: int, dropout_rate: float,add_t5rel_bias: bool =False,conv_out: bool =False, attention_norm_args: Optional[Dict[str, Any]]=None, rotary_value: bool = True,re_scale=False): + """Construct an RelPositionMultiHeadedAttention object. + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + rotary_value (bool): add rotary positon to value tensor. + add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix. + t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share one T5 rel position matrix in all layers. + """ + super().__init__(n_head, n_feat, dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,re_scale) + self.rotary_value = rotary_value + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb:torch.Tensor = torch.empty(0)): + """Compute 'Scaled Dot Product Attention' with rotary positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + pos_emb : Positional embedding tensor(#batch, time2, size//2) of query and key_value, + each is a tuple contains sin part tensor and cos part tensor. + Returns: + torch.Tensor: Output tensor (#batch, time1, size). + """ + q, k, v = self.forward_qkv(query, key, value) + + + q = self.apply_rotary(q,pos_emb) + k = self.apply_rotary(k,pos_emb) + if self.rotary_value: + v = self.apply_rotary(v,pos_emb) + scores = torch.matmul(q, k.transpose(-2, -1)) + if self.add_t5rel_bias and self.t5rel_module is not None: + scores+=self.t5rel_module(scores) + + + return self.forward_attention(v, scores, mask) + + @staticmethod + def apply_rotary(x, sinusoidal_pos): + sin, cos = sinusoidal_pos.chunk(2,dim=-1) # (1, time1, d_model//2) + + x1, x2 = x[..., 0::2], x[..., 1::2] + + return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1) + +# T5RelPE (Leo 2022-07-25) +# reference: +# Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer +# https://arxiv.org/abs/1910.10683 +class T5RelPositionBias(torch.nn.Module): + """T5 rel_pos, which is trainalble bias matrix added to attention score. + Args: + scale (float): scale the bias, usually set to query_dim**0.5. + causal (bool): true, means omit the future. + num_buckets (int): relative to the length of sensitive area. + max_distance (int): when distance > max_distance,they share the same bias. 
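apply_rotary above rotates even/odd feature pairs by a position-dependent angle; the point of RoPE is that the rotated q·k score then depends only on the relative offset. A small property check under the same pairing (the exact sin/cos packing produced by the embedding module is assumed here, not taken from this hunk):

```python
import torch

def rope(x, sin, cos):
    # same pairing as apply_rotary: even indices paired with odd indices
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

d = 8
inv_freq = 10000 ** (-torch.arange(0, d, 2).float() / d)      # (d/2,)
q, k = torch.randn(d), torch.randn(d)

def score(pos_q, pos_k):
    sq, cq = torch.sin(pos_q * inv_freq), torch.cos(pos_q * inv_freq)
    sk, ck = torch.sin(pos_k * inv_freq), torch.cos(pos_k * inv_freq)
    return torch.dot(rope(q, sq, cq), rope(k, sk, ck))

print(torch.allclose(score(2.0, 5.0), score(12.0, 15.0)))     # True: only pos_q - pos_k matters
```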
+ """ + def __init__( + self, + scale, + causal = False, + num_buckets = 32, + max_distance = 128 + ): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = torch.nn.Embedding(num_buckets, 1) + + @staticmethod + def _relative_position_bucket( + relative_position, + causal = False, + num_buckets = 32, + max_distance = 128 + ): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, x): + """Get the rel bias from relative_position_bucket. + Args: + x (torch.Tensor): attention scores. (#batch, head, time1, time2). + Returns: + torch.Tensor: bias tensor. (#time1, time2). + + """ + i, j, device = *x.shape[-2:], x.device + # get q,k position + q_pos = torch.arange(i, dtype = torch.long, device = device) + k_pos = torch.arange(j, dtype = torch.long, device = device) + # calculate the relative position distance matrix (i, j) + rel_pos = k_pos.unsqueeze(0) - q_pos.unsqueeze(1) + rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + bias = values.squeeze(-1) + return bias * self.scale + + + + + +class OffsetScale(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(1,dim)) + self.beta = nn.Parameter(torch.zeros(1,dim)) + nn.init.xavier_uniform_(self.gamma) + + def forward(self, x): + out = x*self.gamma + self.beta + return out + +# (Leo 2022-08-10) +class GAU(nn.Module): + """Gated Attention Unit. now it just support selfatt for the whole sequence. + + Args: + hidden_dim (int): The size of hidden_dim, recommend 2*n_feat. + n_feat (int): The number of features. + d_k (int): Dim of query,key, default (128). + dropout_rate (float): Dropout rate. + activation_type (str): activation function. + add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix. + t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share a T5 rel position matrix in all layers. 
+ """ + def __init__(self, n_feat: int, hidden_dim:int ,d_k: int = 128 ,dropout_rate: float = 0., + conv_out: bool =False, + attention_norm_args: Optional[Dict[str, Any]]=None, + re_scale: bool=False, + activation_type='swish', + add_t5rel_bias: bool =False, + + ): + """Construct an MultiHeadedAttention object.""" + super().__init__() + self.d_k =d_k + self.conv_out = conv_out + project = ScaledLinear if re_scale else nn.Linear + banlancer = nn.Identity if re_scale else ActivationBalancer + self.to_gate = nn.Sequential( + project(n_feat, hidden_dim), + banlancer(), + Nonlinearity(activation_type) + ) + self.to_v = nn.Sequential( + project(n_feat, hidden_dim), + banlancer(), + Nonlinearity(activation_type) + ) + self.to_qk = nn.Sequential( + project(n_feat, d_k), + banlancer(), + Nonlinearity(activation_type) + ) + if self.conv_out: + conv = ScaledConv1d if re_scale else nn.Conv1d + self.to_out = nn.Sequential( + conv(hidden_dim, n_feat, 3, + stride=1, padding=1) + ) + else: + self.to_out = nn.Sequential( + project(hidden_dim, n_feat) + ) + self.att_norm = AttentionNormalize(self.d_k, att_type = "gau", **attention_norm_args) + self.dropout = nn.Dropout(dropout_rate) + self.scale_q = OffsetScale(d_k) + self.scale_k = OffsetScale(d_k) + + self.add_t5rel_bias = add_t5rel_bias + self.t5rel_module = T5RelPositionBias(self.d_k**0.5) if self.add_t5rel_bias else None + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + now it assume q = k = v + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, time2, d_k). + torch.Tensor: gate tensor, size + (#batch, time1, hidden_dim). + torch.Tensor: value tensor, size + (#batch, time2, hidden_dim). + """ + + u = self.to_gate(query) # (batch, time1, hidden_dim) + v = self.to_v(value) # (batch, time2, hidden_dim) + # here q is the whole sequence. + qk = self.to_qk(key) # (batch, time1, d_k) + q = self.scale_q(qk) + k = self.scale_k(qk) + + return q, k, u, v + + def forward_qkuv(self, query: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + now it assume q = k = v + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, time2, d_k). + torch.Tensor: gate tensor, size + (#batch, time1, hidden_dim). + torch.Tensor: value tensor, size + (#batch, time2, hidden_dim). + """ + + u = self.to_gate(query) # (batch, time1, hidden_dim) + v = self.to_v(query) # (batch, time2, hidden_dim) + # here q is the whole sequence. + qk = self.to_qk(query) # (batch, time1, d_k) + q = self.scale_q(qk) + k = self.scale_k(qk) + + return q, k, u, v + + def forward_attention(self, u: torch.Tensor, value: torch.Tensor, scores: torch.Tensor, + mask:torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)) -> torch.Tensor: + """Compute attention context vector. + + Args: + u (torch.Tensor): Transformed gate, size + (#batch, time1, hidden_dim). + value (torch.Tensor): Transformed value, size + (#batch, time2, hidden_dim). 
+ scores (torch.Tensor): Attention score, size + (#batch, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + attn = self.att_norm(scores,mask) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, time1, time2) (batch, time2, hidden_dim) -> (batch, time1, hidden_dim) + x = u * x + if self.conv_out: + x = self.to_out(x.transpose(-1, 1)).transpose(-1, 1) else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + x = self.to_out(x) + return x # (batch, time1, n_feat) + + def forward(self, query: torch.Tensor, key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0)) -> torch.Tensor: + """Compute Gated Attention. + now it just support selfatt for the whole sequence, + which means q = k = v. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + Wenet. + + + Returns: + torch.Tensor: Output tensor (#batch, time1, n_feat). + + """ + q, k, u, v = self.forward_qkuv(query) + + scores = torch.matmul(q, k.transpose(-2, -1)) + if self.add_t5rel_bias and self.t5rel_module is not None: + scores+=self.t5rel_module(scores) + + return self.forward_attention(u, v, scores, mask) + + +class RoPEGAU(GAU): + def __init__(self, n_feat: int, hidden_dim:int ,d_k: int = 128 ,dropout_rate: float = 0., + conv_out: bool =False, + attention_norm_args: Optional[Dict[str, Any]]=None, + re_scale: bool=False, + activation_type='swish', + add_t5rel_bias: bool =False, + ): + """Construct an RelPositionMultiHeadedAttention object. + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + rotary_value (bool): add rotary positon to value tensor. + add_t5rel_bias (bool): whether apply T5 rel_pos on attention score matrix. + t5rel_module (torch.nn.Module): T5Rel module instance, if not None, means share one T5 rel position matrix in all layers. + """ + super().__init__(n_feat, hidden_dim, d_k ,dropout_rate, + conv_out, + attention_norm_args, + re_scale, + activation_type, + add_t5rel_bias, + ) + + + def forward(self, query: torch.Tensor, key: Optional[torch.Tensor], + value: Optional[torch.Tensor], mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0)): + """Compute 'Scaled Dot Product Attention' with rotary positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). 
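GAU above replaces multi-head attention with a single shared d_k-dimensional attention map whose output is gated by u before the final projection. A minimal, hypothetical version of that data flow (OffsetScale and AttentionNormalize are omitted; plain softmax and SiLU stand in for the configurable pieces):

```python
import torch

B, T, n_feat, hidden, d_k = 2, 50, 256, 512, 128
x = torch.randn(B, T, n_feat)

to_gate = torch.nn.Sequential(torch.nn.Linear(n_feat, hidden), torch.nn.SiLU())
to_v    = torch.nn.Sequential(torch.nn.Linear(n_feat, hidden), torch.nn.SiLU())
to_qk   = torch.nn.Sequential(torch.nn.Linear(n_feat, d_k), torch.nn.SiLU())
to_out  = torch.nn.Linear(hidden, n_feat)

u = to_gate(x)                        # (B, T, hidden)  gate
v = to_v(x)                           # (B, T, hidden)  values
qk = to_qk(x)                         # (B, T, d_k)     shared q/k (per-branch OffsetScale omitted)
scores = torch.matmul(qk, qk.transpose(-2, -1)) / d_k ** 0.5
attn = torch.softmax(scores, dim=-1)              # (B, T, T), a single attention map
out = to_out(u * torch.matmul(attn, v))           # gate * (attn @ v), back to n_feat
print(out.shape)                                  # torch.Size([2, 50, 256])
```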
+ pos_emb : Positional embedding tensor(#batch, time2, size//2) of query and key_value, + each is a tuple contains sin part tensor and cos part tensor. + Returns: + torch.Tensor: Output tensor (#batch, time1, size). + """ + q, k, u, v = self.forward_qkuv(query) + + q = self.apply_rotary(q,pos_emb) + k = self.apply_rotary(k,pos_emb) + + scores = torch.matmul(q, k.transpose(-2, -1)) + if self.add_t5rel_bias and self.t5rel_module is not None: + scores+=self.t5rel_module(scores) + return self.forward_attention(u,v, scores, mask) + + @staticmethod + def apply_rotary(x, sinusoidal_pos): + sin, cos = sinusoidal_pos.chunk(2,dim=-1) # (1, time1, d_model//2) + x1, x2 = x[..., 0::2], x[..., 1::2] + + return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1) + + +# (Leo 2022-08-10) +class AttentionNormalize(nn.Module): + def __init__(self, + d_k: int, + att_type = "mlh", + scale_adapt: bool=False, + norm_method: str='softmax', + diag_mask: bool=False, + g_sa: bool=False, + train_len : int = 512, + dim: int=-1): + super().__init__() + self.method = norm_method + self.dim = dim + self.scale_adapt = scale_adapt + if self.scale_adapt: + self.scale = nn.Parameter(torch.log(torch.tensor(d_k**-0.5))) + else: + self.scale = torch.tensor(math.sqrt(d_k)) + self.diag_mask = diag_mask + self.att_type = att_type + self.g_sa = g_sa + self.omiga = torch.tensor(0.001) + self.bias = torch.zeros(1)-0.001 + if self.g_sa: + if 'softmax' not in norm_method: + raise ValueError("g_sa just support softmax form calculate now") + # self.omiga = nn.Parameter(torch.abs(nn.init.trunc_normal_(torch.empty(1)))+0.001) + # self.bias = nn.Parameter(-torch.abs(nn.init.trunc_normal_(torch.empty(1)))) + self.omiga = nn.Parameter(torch.tensor(0.001)) + self.bias = nn.Parameter(torch.zeros(1)-0.001) + # self.train_len = torch.tensor(math.log(train_len)) + + self.train_len = nn.Parameter(torch.tensor(math.log(train_len))) if self.method =="softmax_plus" else torch.empty(0) + def forward(self, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)): + if self.g_sa: + i, j=scores.shape[-2:] + device = scores.device + q_pos = torch.arange(j-i,j, dtype = torch.long, device = device) + k_pos = torch.arange(j, dtype = torch.long, device = device) + # calculate the relative position distance matrix (i, j) + dis_matrix = (k_pos.unsqueeze(0) - q_pos.unsqueeze(1))**2 + dis_matrix = -torch.abs(torch.abs(dis_matrix*self.omiga) - torch.abs(self.bias)) # (i, j) + scores = scores + dis_matrix + + if self.scale_adapt: + scores = scores*(self.scale.exp()) + else: + scores = scores/self.scale + if mask.size(2) > 0 : # time2 > 0 + mask = mask[..., :scores.size(-1)] + if self.diag_mask: + mask = (~torch.eye(scores.shape[-2],scores.shape[-1],dtype=torch.bool,device=scores.device)) & mask + if self.att_type == "gau": + mask = mask.eq(0) # (batch, *, time2) + else: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + + scores = scores.masked_fill(mask, -1e4) + attn = self.attention_normalize(scores, dim=self.dim,method=self.method).masked_fill( + mask, 0.0) # (batch, head, time1, time2) or # (batch, time1, time2) + else: + attn = self.attention_normalize(scores, dim=self.dim,method=self.method) # (batch, head, time1, time2) or (batch, time1, time2) + return attn + + def attention_normalize(self, + a: torch.Tensor, + dim: int=-1, + method: str='softmax'): + """attention score normalization + softmax + relu_plus: https://arxiv.org/abs/2202.10447 + softmax_plus: https://kexue.fm/archives/8823 。 + """ + 
assert method in ['softmax','relu_plus','softmax_plus'] + if method == 'softmax': + return torch.softmax(a, dim=dim) + + else: + mask = (a > -1e4).float() + l = torch.sum(mask ,dim=dim, keepdim=True).clamp_(1.) + + if method == 'relu_plus': + return torch.relu(a)**2 / l + elif method == 'softmax_plus': + + scale = torch.log(l) / self.train_len * mask + 1 - mask + + return torch.softmax(a * scale, dim=dim) + + else: + raise ValueError("check your attention norm ") - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, v) # (batch, head, time1, d_k) - x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) - return self.linear_out(x) # (batch, time1, d_model) + return a \ No newline at end of file diff --git a/pytorch/libs/nnet/transformer/convolution.py b/pytorch/libs/nnet/transformer/convolution.py new file mode 100644 index 0000000..1b441c8 --- /dev/null +++ b/pytorch/libs/nnet/transformer/convolution.py @@ -0,0 +1,231 @@ +# -*- coding:utf-8 -*- + +# Reference: https://github.com/espnet/espnet. + +"""ConvolutionModule definition.""" + + +from typing import Optional, Tuple +import torch +from torch import nn +from .scaling import ActivationBalancer,ScaledConv1d + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + causal (int): Whether use causal convolution or not + """ + + def __init__( + self, + channels: int, + kernel_size: int = 15, + activation:nn.Module = nn.ReLU(), + norm: str = 'batch_norm', + causal: bool=False, + activation_balancer: bool=False, + bias: bool=True + ): + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + assert norm in ['batch_norm', 'layer_norm', "basic_norm"] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + self.balancer1,self.balancer2 = None,None + self.activation_balancer = activation_balancer + if activation_balancer: + self.balancer1=ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + self.balancer2=ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + def forward(self, + x: torch.Tensor, + mask_pad: Optional[torch.Tensor] = None, + )-> torch.Tensor: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time) + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
+ """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # mask batch padding + if mask_pad is not None: + x.masked_fill_(~mask_pad, 0.0) + if self.lorder>0: + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + if self.activation_balancer and self.balancer1 is not None: + x = self.balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + + if self.activation_balancer and self.balancer2 is not None: + x = self.balancer2(x) + + x = self.activation(x) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad is not None: + x.masked_fill_(~mask_pad, 0.0) + return x.transpose(1, 2) + +class ReConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + causal (int): Whether use causal convolution or not + """ + + def __init__( + self, + channels: int, + kernel_size: int = 15, + activation:nn.Module = nn.ReLU(), + causal: bool=False, + bias: bool=True + ): + """Construct an ConvolutionModule object.""" + super(ReConvolutionModule, self).__init__() + + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + self.activation = activation + + self.balancer1=ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + self.balancer2=ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + def forward(self, + x: torch.Tensor, + mask_pad: Optional[torch.Tensor] = None, + )-> torch.Tensor: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time) + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
+ """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + # mask batch padding + if mask_pad is not None: + x.masked_fill_(~mask_pad, 0.0) + if self.lorder>0: + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + + x = self.balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.balancer2(x) + + x = self.activation(x) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad is not None: + x.masked_fill_(~mask_pad, 0.0) + return x.transpose(1, 2) \ No newline at end of file diff --git a/pytorch/libs/nnet/transformer/embedding.py b/pytorch/libs/nnet/transformer/embedding.py index b22ed7b..6346957 100644 --- a/pytorch/libs/nnet/transformer/embedding.py +++ b/pytorch/libs/nnet/transformer/embedding.py @@ -1,109 +1,198 @@ # -*- coding:utf-8 -*- - -# Reference: https://github.com/espnet/espnet. +# Copyright xmuspeech (Author: Leo 2022-07) +# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/embedding.py. """Positonal Encoding Module.""" import math -import torch - - -def _pre_hook(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - """Perform pre-hook in load_state_dict for backward compatibility. +from typing import Tuple - Note: - We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward compatibility. - - """ - k = prefix + "pe" - if k in state_dict: - state_dict.pop(k) +import torch +def get_abs_position(dim: int, + max_len: int) -> torch.Tensor: + abs_pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, dim, 2, dtype=torch.float32) * + -(math.log(10000.0) / dim)) + abs_pe[:, 0::2] = torch.sin(position * div_term) + abs_pe[:, 1::2] = torch.cos(position * div_term) + return abs_pe.unsqueeze(0) class PositionalEncoding(torch.nn.Module): - """Positional encoding.""" - - def __init__(self, d_model, dropout_rate, max_len=5000): - """Initialize class. - - :param int d_model: embedding dim - :param float dropout_rate: dropout rate - :param int max_len: maximum input length - - """ - super(PositionalEncoding, self).__init__() + """Positional encoding. 
+ + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + :param int att_h: invalid here,for compatibility to RoPositionalEncoding + :param bool rope_abs_plus: invalid here,for compatibility to RoPositionalEncoding + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + def __init__(self, + d_model: int, + dropout_rate: float, + att_h: int=4, + rope_abs_plus: bool=False, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - self._register_load_state_dict_pre_hook(_pre_hook) - - def extend_pe(self, x): - """Reset the positional encodings.""" - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * - -(math.log(10000.0) / self.d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor): + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int): position offset Returns: torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) - + torch.Tensor: for compatibility to RelPositionalEncoding """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, :x.size(1)] - return self.dropout(x) + assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + pos_emb = self.pe[:, offset:offset + x.size(1)] + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + def position_encoding(self, offset: int, size: int) -> torch.Tensor: + """ For getting encoding in a streaming fashion -class ScaledPositionalEncoding(PositionalEncoding): - """Scaled positional encoding module. + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. - See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf + Args: + offset (int): start offset + size (int): requried size of position encoding + Returns: + torch.Tensor: Corresponding encoding + """ + assert offset + size < self.max_len + return self.dropout(self.pe[:, offset:offset + size]) + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. 
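The rewritten PositionalEncoding precomputes one (1, max_len, d_model) sinusoid table and slices it by offset, which is what makes the streaming position_encoding(offset, size) helper possible. The table construction and offset slicing in isolation:

```python
import math
import torch

d_model, max_len = 256, 5000
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)                              # (1, max_len, d_model)

x = torch.randn(1, 32, d_model)
offset = 100                                      # e.g. frames already consumed when streaming
pos_emb = pe[:, offset:offset + x.size(1)]        # (1, 32, d_model)
out = x * math.sqrt(d_model) + pos_emb            # same xscale as the module above
print(out.shape, pos_emb.shape)
```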
+ See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + att_h (int): invalid here,for compatibility to RoPositionalEncoding + rope_abs_plus (bool): invalid here,for compatibility to RoPositionalEncoding + max_len (int): Maximum input length. """ - - def __init__(self, d_model, dropout_rate, max_len=5000): - """Initialize class. - - :param int d_model: embedding dim - :param float dropout_rate: dropout rate - :param int max_len: maximum input length - + def __init__(self, d_model: int, dropout_rate: float, att_h: int=4, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len=max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). """ - super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) - self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.pe[:, offset:offset + x.size(1)] + return self.dropout(x), self.dropout(pos_emb) - def reset_parameters(self): - """Reset parameters.""" - self.alpha.data = torch.tensor(1.0) - def forward(self, x): - """Add positional encoding. +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + def __init__(self, d_model: int, dropout_rate: float, att_h: int=4,rope_abs_plus: bool=False): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: int, size: int) -> torch.Tensor: + return torch.zeros(1, size, self.d_model) + + +# RoPE (Leo 2022-07-25) +# reference: +# RoFormer: Enhanced Transformer with Rotary Position Embedding. +class RoPositionalEncoding(PositionalEncoding): + """ Rotary positional encoding module. The cos features of sinusoidal are organized in + the 2nd half of the vector. + Args: + d_embed (int): Embedding dimension. + dropout_rate (float): Dropout rate. + att_h (int): attention head num. + rope_abs_plus (bool): rope plus abs. + max_len (int): Maximum input length. + """ + def __init__(self, d_embed: int, dropout_rate: float ,att_h: int=4 ,rope_abs_plus: bool=False, max_len: int = 5000 , d_roembed: int=-1): + """Initialize class.""" + super().__init__(d_embed, dropout_rate, max_len=max_len) + if d_roembed < 1: + assert (d_embed % att_h) % 2 == 0 + d_roembed = d_embed //att_h + else: + d_roembed = d_roembed + abs_rope = get_abs_position(d_roembed, self.max_len) + freq = torch.zeros_like(abs_rope) + sentinel = d_roembed // 2 + freq[...,0:sentinel] = abs_rope[...,0::2] + freq[...,sentinel:] = abs_rope[...,1::2] + self.abs_pe = self.pe + self.pe = freq + self.rope_abs_plus = rope_abs_plus + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. Args: - x (torch.Tensor): Input. Its shape is (batch, time, ...) - + x (torch.Tensor): Input tensor (batch, time, `*`). 
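The RoPE constructor above does not change the sinusoid values, it only regroups them: the interleaved sin/cos table from get_abs_position is split so all sin features sit in the first half of the last dimension and all cos features in the second half. A standalone sketch of that regrouping; interleaved_sinusoid is a local stand-in for get_abs_position and the dimensions are illustrative:

import math
import torch

def interleaved_sinusoid(dim: int, max_len: int) -> torch.Tensor:
    # Local stand-in for get_abs_position above: sin/cos interleaved along the last dim.
    pe = torch.zeros(max_len, dim)
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                         * -(math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)                 # (1, max_len, dim)

d_roembed, max_len = 64, 5000              # per-head dim (or gau_key for GAU layers)
abs_pe = interleaved_sinusoid(d_roembed, max_len)

# Reorder into the half-and-half layout the constructor above builds:
# [sin_0 .. sin_{d/2-1} | cos_0 .. cos_{d/2-1}] instead of sin/cos interleaved.
sentinel = d_roembed // 2
freq = torch.zeros_like(abs_pe)
freq[..., :sentinel] = abs_pe[..., 0::2]   # all sin features in the first half
freq[..., sentinel:] = abs_pe[..., 1::2]   # all cos features in the second half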
Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) - + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). """ - self.extend_pe(x) - x = x + self.alpha * self.pe[:, :x.size(1)] - return self.dropout(x) + assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.pe[:, offset:offset + x.size(1)] + if self.rope_abs_plus: + self.abs_pe = self.abs_pe.to(x.device) + abs_pe = self.abs_pe[:, offset:offset + x.size(1)] + x += abs_pe + return self.dropout(x), pos_emb + diff --git a/pytorch/libs/nnet/transformer/encoder.py b/pytorch/libs/nnet/transformer/encoder.py new file mode 100644 index 0000000..9e4989d --- /dev/null +++ b/pytorch/libs/nnet/transformer/encoder.py @@ -0,0 +1,1104 @@ +# -*- coding:utf-8 -*- +# Copyright xmuspeech (Author: Leo 2022-07) +# Reference: https://github.com/wenet-e2e/wenet +import torch +import math +from typing import Tuple, List, Optional +from .attention import ( + MultiHeadedAttention, + RelPositionMultiHeadedAttention, + RoPESelfAttention, + RoPEGAU, + GAU +) +from .embedding import ( + PositionalEncoding, + RelPositionalEncoding, + NoPositionalEncoding, + RoPositionalEncoding + ) +from .encoder_layer import ( + TransformerEncoderLayer, + ConformerEncoderLayer, + ReConformerEncoderLayer +) +from .layer_norm import BasicNorm, LayerNorm,Trans_Bat +from .multi_layer_conv import ( + Conv1dLinear, + MultiLayeredConv1d, +) +from .convolution import ConvolutionModule, ReConvolutionModule +from .positionwise_feed_forward import PositionwiseFeedForward +from .subsampling import ( + LinearNoSubsampling, + Conv2dSubsampling2, + Conv2dSubsampling4, + ReConv2dSubsampling4, + SVConv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, + SVConv2dSubsampling2, + FrameMerger +) +import torch.nn as nn +from .mask import add_optional_chunk_mask +from .mask import make_pad_mask +from libs.nnet.activation import Nonlinearity + +class BaseEncoder(torch.nn.Module): + """ + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of encoder blocks + :param int aux_layer_period: the period for randomcombiner/mfa. + :param int aux_layer_start: the start position for randomcombiner/mfa. + default :1. e.g. 3 means from 1/3 of the total block nums suggest for randomcombiner, mfa suggest set to 1+num_blocks. + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str input_layer: input layer type + :param str pos_enc_type: Encoder positional encoding layer type. + opitonal [abs_pos, rel_pos, no_pos, rot_pos] + :param bool rotary_value: whether apply rot_pos on value vector when use rot_pos, which contains abs position information. + :param bool rope_abs_plus: whether apply abs_pos when use rot_pos. + :param bool add_t5rel_bias: whether apply t5_rel position in attention score. + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + :param str positionwise_layer_type: positionwise feedforward type + :param int positionwise_conv_kernel_size: Kernel size of positionwise conv1d layer. + :param int static_chunk_size: chunk size for static chunk training and decoding. + :param bool use_dynamic_chunk: whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dynamic chunk size(use_dynamic_chunk = True) + :param bool use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + :param str comnbiner_type: combine the output of encoder with its sublayers. + opitonal [norm, mfa, random_frame, random_layer] + """ + aux_layers: List[int] + def __init__(self, idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + mlp_head=False, + num_blocks=6, + aux_layer_period = 3, + aux_layer_start = 1, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + attention_conv_out = False, + attention_norm_args = {}, + input_layer="conv2d", + pos_enc_type="abs_pos", + rotary_value=True, + rope_abs_plus = False, + add_t5rel_bias = False, + att_type: str= 'multi', + gau_units: int = 512, + gau_key: int = 64, + normalize_before=True, + norm_type = "layer_norm", + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + activation_type='relu', + activation_balancer=False, + static_chunk_size: int = 0, + left_chunk_size: int = -1, + use_dynamic_chunk: bool = False, + use_dynamic_left_chunk: bool = False, + combiner_type: str="norm", + re_scale: bool=False): + super().__init__() + + + self.att_type = att_type + self.mlp_head = mlp_head + + + if att_type != 'gau' and positionwise_layer_type == 'gau': + assert gau_key == attention_dim // attention_heads + pos_enc_dict = {} + if pos_enc_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_type == "rel_pos": + pos_enc_class = RelPositionalEncoding + elif pos_enc_type == "rot_pos": + pos_enc_class = RoPositionalEncoding + pos_enc_dict['rope_abs_plus']= rope_abs_plus + if self.att_type == 'gau': + pos_enc_dict['d_roembed']= gau_key + else: + pos_enc_dict['att_h']= attention_heads + elif pos_enc_type == "no_pos": + pos_enc_class = NoPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_type) + + if input_layer == "linear": + subsampling_class = LinearNoSubsampling + elif input_layer == "conv2d2": + subsampling_class = SVConv2dSubsampling2 + elif input_layer == "conv2d": + subsampling_class = Conv2dSubsampling4 + elif input_layer == "re_conv2d": + subsampling_class = ReConv2dSubsampling4 + elif input_layer == "conv2d6": + subsampling_class = Conv2dSubsampling6 + elif input_layer == "conv2d8": + subsampling_class = Conv2dSubsampling8 + else: + raise ValueError("unknown input_layer: " + input_layer) + self.embed = subsampling_class(idim, attention_dim, dropout_rate,mlp_head,pos_enc_class(attention_dim,positional_dropout_rate, **pos_enc_dict)) + self.pos_enc_type = pos_enc_type + self.positionwise_layer, self.positionwise_layer_args = self.get_positionwise_layer( + positionwise_layer_type, + attention_dim, + linear_units, + dropout_rate, + positionwise_conv_kernel_size, + activation_type, + activation_balancer, + re_scale, + attention_norm_args=attention_norm_args + ) + + if self.att_type == 'gau': + self.selfattn_layer, self.selfattn_layer_args = self.get_gau_layer( + pos_enc_type, + attention_dim, + gau_units, + gau_key, + attention_dropout_rate, + attention_conv_out, + 
attention_norm_args=attention_norm_args, + re_scale=re_scale + ) + else: + self.selfattn_layer, self.selfattn_layer_args = self.get_selfattn_layer( + pos_enc_type, + attention_heads, + attention_dim, + attention_dropout_rate, + add_t5rel_bias, + attention_conv_out, + rotary_value, + attention_norm_args=attention_norm_args, + re_scale=re_scale + ) + + self.aux_layers,self.combiner = self.get_combiner( + num_blocks, + aux_layer_period, + aux_layer_start, + combiner_type + ) + + self.normalize_before = normalize_before + self._output_size = attention_dim*len(self.aux_layers) if combiner_type=="mfa" else attention_dim + self.after_norm = None + if self.normalize_before or combiner_type=="mfa": + if norm_type == "batch_norm": + self.use_layer_norm=False + self.after_norm = Trans_Bat(self._output_size) + else: + self.use_layer_norm=True + if norm_type == "layer_norm": + self.after_norm = LayerNorm(self._output_size,eps=1e-5) + else: + self.after_norm = BasicNorm(self._output_size) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.left_chunk_size = left_chunk_size + + def get_positionwise_layer( + self, + positionwise_layer_type="linear", + attention_dim=256, + linear_units=2048, + dropout_rate=0.1, + positionwise_conv_kernel_size=1, + activation_type='relu', + activation_balancer=False, + re_scale = False, + gau_key = 64, + attention_norm_args = {} + ): + """Define positionwise layer.""" + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = (attention_dim, linear_units, dropout_rate, activation_type, activation_balancer,re_scale) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + activation_type, + activation_balancer, + re_scale + ) + + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + activation_type, + activation_balancer, + re_scale + ) + + elif positionwise_layer_type == "gau": + positionwise_layer, positionwise_layer_args = self.get_gau_layer( + self.pos_enc_type, + attention_dim, + linear_units, + gau_key, + dropout_rate, + attention_norm_args = attention_norm_args, + re_scale=re_scale + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + return positionwise_layer, positionwise_layer_args + + + def get_selfattn_layer( + self, + pos_enc_type="abs_pos", + attention_heads=4, + attention_dim=256, + attention_dropout_rate=0.0, + add_t5rel_bias=False, + conv_out=False, + rotary_value = False, + attention_norm_args={}, + re_scale=False + ): + """Define selfattn layer.""" + selfattn_layer_args = (attention_heads,attention_dim,attention_dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,re_scale) + + if pos_enc_type == "rel_pos": + selfattn_layer = RelPositionMultiHeadedAttention + elif pos_enc_type == "rot_pos": + selfattn_layer = RoPESelfAttention + selfattn_layer_args=(attention_heads,attention_dim,attention_dropout_rate,add_t5rel_bias,conv_out,attention_norm_args,rotary_value,re_scale) + else: + selfattn_layer = MultiHeadedAttention + + return selfattn_layer, selfattn_layer_args + + def get_gau_layer( + self, + pos_enc_type="abs_pos", + attention_dim=256, + hidden_dim = 512, + d_qk = 64, + 
attention_dropout_rate=0.0, + conv_out=False, + attention_norm_args={}, + re_scale=False + ): + """Define gau layer.""" + + if pos_enc_type == "abs_pos": + selfattn_layer = GAU + else: + selfattn_layer = RoPEGAU + selfattn_layer_args=(attention_dim,hidden_dim,d_qk,attention_dropout_rate,conv_out,attention_norm_args,re_scale) + + return selfattn_layer, selfattn_layer_args + + def get_combiner(self, + num_blocks: int, + aux_layer_period: int = 3, + aux_layer_start: int = 1, + combiner_type="norm" + ) -> Tuple[List[int],torch.nn.Module]: + """Define combiner layer.""" + assert combiner_type in ["norm", "mfa", "random_frame", "random_layer"], "unknown combiner_type {}".format(combiner_type) + assert aux_layer_period,aux_layer_start > 0 + assert num_blocks > 0 + aux_layers=list( + range( + num_blocks // aux_layer_start, + num_blocks - 1, + aux_layer_period, + ) + ) + assert len(set(aux_layers)) == len(aux_layers) + assert num_blocks - 1 not in aux_layers + aux_layers = aux_layers + [num_blocks - 1] + combiner = RandomCombine( + aux_layers=aux_layers, + combiner_type=combiner_type, + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + return aux_layers,combiner + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = -1, + decoding_left_chunk_size: int = -2, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor + + :param torch.Tensor xs: input tensor (B, T, D) + :param torch.Tensor xs_lens: input length (B) + :param int decoding_chunk_size: decoding chunk size for dynamic chunk + 0: use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + :param int decoding_left_chunk_size: + <-1: default , the number of left chunks is self.left_chunk_size. + -1: use full left chunk + 0: no left chunk + >0: the number of left chunks + : param float warmup: + Model level warmup, a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. 
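A quick worked example of the aux-layer selection in get_combiner above, and of the resulting output_size when the combiner is "mfa"; the concrete numbers are illustrative, not defaults taken from any recipe:

# Worked example of the aux-layer selection in get_combiner above.
num_blocks, aux_layer_period, aux_layer_start = 12, 3, 3

aux_layers = list(range(num_blocks // aux_layer_start,   # start at block 12 // 3 = 4
                        num_blocks - 1,                   # stop before the last block
                        aux_layer_period))                # take every 3rd block
aux_layers = aux_layers + [num_blocks - 1]                # the last block is always included
print(aux_layers)                                         # [4, 7, 10, 11]

# With combiner_type="mfa" the selected outputs are concatenated on the feature axis,
# so _output_size = attention_dim * len(aux_layers); otherwise it stays attention_dim.
attention_dim = 256
print(attention_dim * len(aux_layers))                    # 1024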
+ :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + if decoding_left_chunk_size <= -2: + left_chunk_size = self.left_chunk_size + else: + left_chunk_size = decoding_left_chunk_size + T = xs.size(1) + if self.mlp_head: + T+=1 + + xs_lens+=1 + + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + + xs, pos_emb, masks = self.embed(xs, masks) + + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + left_chunk_size) + + out= [] + + for i,layer in enumerate(self.encoders): + xs, chunk_masks = layer(xs, chunk_masks, pos_emb, mask_pad, warmup = warmup) + + if i in self.aux_layers: + out.append(xs) + if len(out)>0: + xs = self.combiner(out) + + if self.after_norm is not None: + if not self.use_layer_norm: + xs = xs.transpose(1, 2) + xs = self.after_norm(xs) + if not self.use_layer_norm: + xs = xs.transpose(1, 2) + return xs, masks + + +class TransformerEncoder(BaseEncoder): + """Transformer encoder module.""" + def __init__( + self, + idim: int, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + mlp_head: bool=False, + num_blocks: int = 6, + aux_layer_period: int = 3, + aux_layer_start:int = 1, + dropout_rate: float = 0.1, + layer_dropout: float = 0., + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + attention_conv_out : bool = False, + attention_norm_args : dict = {}, + input_layer: str = "conv2d", + pos_enc_type: str = "abs_pos", + rotary_value: bool = True, + rope_abs_plus: bool = False, + add_t5rel_bias: bool = False, + att_type: str= 'multi', + gau_units: int = 512, + gau_key: int = 64, + normalize_before: bool = True, + norm_type: str = "layer_norm", + concat_after: bool = False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + activation_type='relu', + activation_balancer: bool=False, + static_chunk_size: int = 0, + left_chunk_size: int = -1, + use_dynamic_chunk: bool = False, + use_dynamic_left_chunk: bool = False, + combiner_type: str="norm", + re_scale: bool=False, + convfnn_blocks: int=0, + **args + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + + super().__init__(idim, attention_dim, attention_heads, linear_units, + mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate, + positional_dropout_rate, attention_dropout_rate, attention_conv_out, + attention_norm_args, input_layer, pos_enc_type, rotary_value, rope_abs_plus, + add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type, + concat_after, positionwise_layer_type, positionwise_conv_kernel_size, + activation_type, activation_balancer, static_chunk_size, left_chunk_size, + use_dynamic_chunk, use_dynamic_left_chunk,combiner_type, re_scale) + pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args + if positionwise_layer_type == "gau": + pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer( + self.pos_enc_type, + attention_dim, + linear_units, + gau_key, + dropout_rate, + conv_out =True, + attention_norm_args = attention_norm_args, + re_scale=re_scale + ) + + pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer( + pos_enc_type, + attention_dim, + gau_units, + gau_key, + attention_dropout_rate, + conv_out =True, + attention_norm_args=attention_norm_args, + re_scale=re_scale + ) + else: + pre_positionwise_layer = MultiLayeredConv1d + pre_positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + activation_type, + activation_balancer, + re_scale + ) + + encoders = [] + for _ in range(convfnn_blocks): + encoders.append(TransformerEncoderLayer( + attention_dim, + pre_selfattn_layer(*pre_selfattn_layer_args), + pre_positionwise_layer(*pre_positionwise_layer_args), + dropout_rate, + layer_dropout, + normalize_before, + norm_type, + positionwise_layer_type, + concat_after, + )) + for _ in range(num_blocks-convfnn_blocks): + encoders.append(TransformerEncoderLayer( + attention_dim, + self.selfattn_layer(*self.selfattn_layer_args), + self.positionwise_layer(*self.positionwise_layer_args),dropout_rate, + layer_dropout,normalize_before, norm_type,positionwise_layer_type,concat_after)) + + self.encoders = torch.nn.ModuleList(encoders) + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + def __init__( + self, + idim: int, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + mlp_head: bool=False, + num_blocks: int = 6, + aux_layer_period: int = 3, + aux_layer_start: int = 1, + dropout_rate: float = 0.1, + layer_dropout: float = 0., + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + attention_conv_out : bool = False, + attention_norm_args : dict = {}, + input_layer: str = "conv2d", + pos_enc_type: str = "rel_pos", + rotary_value: bool = True, + rope_abs_plus: bool = False, + add_t5rel_bias: bool = False, + att_type: str= 'multi', + gau_units: int = 512, + gau_key: int = 64, + normalize_before: bool = True, + norm_type: str = "layer_norm", + concat_after: bool = False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=3, + activation_type: str = "swish", + activation_balancer: bool=False, + static_chunk_size: int = 0, + left_chunk_size: int = -1, + use_dynamic_chunk: bool = False, + use_dynamic_left_chunk: bool = False, + macaron_style: bool = True, + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + combiner_type: str="norm", + re_scale: bool=False, + convfnn_blocks: int=0, + **args + ): + """Construct ConformerEncoder + + Args: + input_size to 
use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + """ + + super().__init__(idim, attention_dim, attention_heads, linear_units, + mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate, + positional_dropout_rate, attention_dropout_rate, attention_conv_out, + attention_norm_args,input_layer, pos_enc_type, rotary_value, rope_abs_plus, + add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type, + concat_after, positionwise_layer_type, positionwise_conv_kernel_size, + activation_type, activation_balancer,static_chunk_size, left_chunk_size, + use_dynamic_chunk, use_dynamic_left_chunk,combiner_type,re_scale) + + activation = Nonlinearity(activation_type) + assert activation is not None + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation, + cnn_module_norm, causal, activation_balancer) + pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args + if positionwise_layer_type == "gau": + pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer( + self.pos_enc_type, + attention_dim, + linear_units, + gau_key, + dropout_rate, + conv_out =True, + attention_norm_args = attention_norm_args, + re_scale=re_scale + ) + + pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer( + pos_enc_type, + attention_dim, + gau_units, + gau_key, + attention_dropout_rate, + conv_out =True, + attention_norm_args=attention_norm_args, + re_scale=re_scale + ) + else: + pre_positionwise_layer = MultiLayeredConv1d + pre_positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + activation_type, + activation_balancer, + re_scale + ) + + encoders = [] + for _ in range(convfnn_blocks): + encoders.append(ConformerEncoderLayer( + attention_dim, + pre_selfattn_layer(*pre_selfattn_layer_args), + pre_positionwise_layer(*pre_positionwise_layer_args), + pre_positionwise_layer( + *pre_positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + layer_dropout, + normalize_before, + norm_type, + positionwise_layer_type, + concat_after, + )) + + for _ in range(num_blocks-convfnn_blocks): + encoders.append(ConformerEncoderLayer( + attention_dim, + self.selfattn_layer(*self.selfattn_layer_args), + self.positionwise_layer(*self.positionwise_layer_args), + self.positionwise_layer( + *self.positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + layer_dropout, + normalize_before, + norm_type, + positionwise_layer_type, + concat_after, + )) + + self.encoders = torch.nn.ModuleList(encoders) + +class ReConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + def __init__( + self, + idim: int, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + mlp_head: bool=False, + num_blocks: int = 6, + aux_layer_period: int = 3, + aux_layer_start: int = 1, + dropout_rate: float = 0.1, + layer_dropout: float = 0., + 
positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + attention_conv_out : bool = False, + attention_norm_args : dict = {}, + input_layer: str = "re_conv2d", + pos_enc_type: str = "rel_pos", + rotary_value: bool = True, + rope_abs_plus: bool = False, + add_t5rel_bias: bool = False, + att_type: str= 'multi', + gau_units: int = 512, + gau_key: int = 64, + normalize_before: bool = False, + norm_type: str = "basic_norm", + concat_after: bool = False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=3, + activation_type: str = "double_swish", + activation_balancer: bool=True, + static_chunk_size: int = 0, + left_chunk_size: int = -1, + use_dynamic_chunk: bool = False, + use_dynamic_left_chunk: bool = False, + macaron_style: bool = True, + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + combiner_type: str="norm", + re_scale: bool=True, + convfnn_blocks: int=0, + **args + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + """ + assert normalize_before == False,"set normalize_before=False." + assert norm_type == "basic_norm","set norm_type=basic_norm." + assert re_scale == True, "reconformer set res_scale=True." + assert activation_balancer == True + super().__init__(idim, attention_dim, attention_heads, linear_units, + mlp_head, num_blocks, aux_layer_period, aux_layer_start, dropout_rate, + positional_dropout_rate, attention_dropout_rate, attention_conv_out, + attention_norm_args,input_layer, pos_enc_type, rotary_value, rope_abs_plus, + add_t5rel_bias, att_type, gau_units, gau_key, normalize_before, norm_type, + concat_after, positionwise_layer_type, positionwise_conv_kernel_size, + activation_type, activation_balancer,static_chunk_size, left_chunk_size, + use_dynamic_chunk, use_dynamic_left_chunk,combiner_type,re_scale) + + activation = Nonlinearity(activation_type) + assert activation is not None + # convolution module definition + convolution_layer = ReConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation, + causal) + pre_selfattn_layer, pre_selfattn_layer_args = self.selfattn_layer,self.selfattn_layer_args + if positionwise_layer_type == "gau": + pre_positionwise_layer,pre_positionwise_layer_args = self.get_gau_layer( + self.pos_enc_type, + attention_dim, + linear_units, + gau_key, + dropout_rate, + conv_out =True, + attention_norm_args = attention_norm_args, + re_scale=re_scale + ) + + pre_selfattn_layer, pre_selfattn_layer_args = self.get_gau_layer( + pos_enc_type, + attention_dim, + gau_units, + gau_key, + attention_dropout_rate, + conv_out =True, + attention_norm_args=attention_norm_args, + re_scale=re_scale + ) + else: + pre_positionwise_layer = MultiLayeredConv1d + pre_positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + activation_type, + activation_balancer, + re_scale + ) + + encoders = [] + for _ in range(convfnn_blocks): + encoders.append(ReConformerEncoderLayer( + attention_dim, + pre_selfattn_layer(*pre_selfattn_layer_args), + 
pre_positionwise_layer(*pre_positionwise_layer_args), + pre_positionwise_layer( + *pre_positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + layer_dropout, + positionwise_layer_type, + concat_after, + )) + + for _ in range(num_blocks-convfnn_blocks): + encoders.append(ReConformerEncoderLayer( + attention_dim, + self.selfattn_layer(*self.selfattn_layer_args), + self.positionwise_layer(*self.positionwise_layer_args), + self.positionwise_layer( + *self.positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + layer_dropout, + positionwise_layer_type, + concat_after, + )) + + self.encoders = torch.nn.ModuleList(encoders) + + + +# RandomCombine conformer in k2, for deeper training. +# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + aux_layers: list, + combiner_type: str = "norm", + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). 
+ """ + super().__init__() + self.num_inputs = len(aux_layers) + self.aux_layers = aux_layers + + assert self.num_inputs >= 1 + if combiner_type in ["random_frame", "random_layer"]: + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + self.combiner_type = combiner_type + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + + if self.combiner_type == "mfa": + return self.forward_mfa(inputs) + elif self.combiner_type == "random_frame": + return self.forward_rand_frame(inputs) + elif self.combiner_type == "random_layer": + return self.forward_rand_layer(inputs) + return self.forward_norm(inputs) + + def forward_mfa(self, inputs: List[torch.Tensor]) -> torch.Tensor: + return torch.cat(inputs,dim=-1) + + def forward_norm(self, inputs: List[torch.Tensor]) -> torch.Tensor: + return inputs[-1] + + def forward_rand_frame(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting() or self.num_inputs==1: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def forward_rand_layer(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (B, T, C) + Returns: + A Tensor of shape (B, T, C). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting() or self.num_inputs==1: + return inputs[-1] + + + + num_channels = inputs[0].shape[-1] + num_b = inputs[0].shape[0] + + ndim = inputs[0].ndim + # stacked_inputs: (B, T, C, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim) + + # weights: (B, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_b + ) + + weights = weights.reshape(num_b,1, num_inputs, 1) + + # ans: (B, T, C, 1) + ans = torch.matmul(stacked_inputs, weights) + + # ans: (B, T, C) + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... 
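A self-contained sketch of the per-frame combination performed by forward_rand_frame above, assuming the mixed (non-pure) branch described in the class docstring: Gaussian-noised log-weights biased toward the final layer are softmaxed and then used to mix the stacked layer outputs. Shapes and hyperparameters below are illustrative:

import math
import torch

num_inputs, stddev, final_weight = 4, 2.0, 0.5
B, T, C = 2, 50, 256
layer_outs = [torch.randn(B, T, C) for _ in range(num_inputs)]

# Bias the last layer so its expected weight is roughly final_weight.
final_log_weight = math.log((final_weight / (1 - final_weight)) * (num_inputs - 1))

num_frames = B * T
logprobs = torch.randn(num_frames, num_inputs) * stddev
logprobs[:, -1] += final_log_weight
weights = logprobs.softmax(dim=1)                       # (num_frames, num_inputs), rows sum to 1

# Combine: stack to (num_frames, C, num_inputs) and take a weighted sum per frame.
stacked = torch.stack(layer_outs, dim=3).reshape(num_frames, C, num_inputs)
combined = torch.matmul(stacked, weights.unsqueeze(2))  # (num_frames, C, 1)
combined = combined.reshape(B, T, C)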
+ # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. 
+ """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + def extra_repr(self): + s = ('{combiner_type}, num_inputs_layer={num_inputs}, aux_layers={aux_layers}') + if "random" in self.combiner_type: + s += ', final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}, final_log_weight={final_log_weight}' + return s.format(**self.__dict__) + +# def scale_args(layer_args,num=2,scale=2): + # new_args=[] + # for i,arg in enumerate(layer_args): + # if i x + att(x) """ - def __init__(self, size, self_attn, feed_forward, dropout_rate, - normalize_before=True, concat_after=False): - super(EncoderLayer, self).__init__() + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float=0.1, + layer_dropout: float=0., + normalize_before: bool = True, + norm_type: str = "layer_norm", + positionwise_layer_type: str="linear", + concat_after: bool = False + ): + super(TransformerEncoderLayer, self).__init__() self.self_attn = self_attn self.feed_forward = feed_forward - self.norm1 = LayerNorm(size) - self.norm2 = LayerNorm(size) + if norm_type == "batch_norm": + self.use_layer_norm = False + + norm_tp = Trans_Bat + else: + self.use_layer_norm = True + if norm_type == "layer_norm": + norm_tp = LayerNorm + else: + norm_tp = BasicNorm + self.norm1 = norm_tp(size) + self.norm2 = norm_tp(size) self.dropout = nn.Dropout(dropout_rate) + self.layer_dropout = layer_dropout self.size = size self.normalize_before = normalize_before self.concat_after = concat_after + self.concat_linear = None + self.positionwise_layer_type = positionwise_layer_type if self.concat_after: self.concat_linear = nn.Linear(size + size, size) - def forward(self, x, mask): + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute encoded features :param torch.Tensor x: encoded source features (batch, max_time_in, size) :param torch.Tensor mask: mask for x (batch, max_time_in) + :param torch.Tensor pos_emb: + :param torch.Tensor mask_pad: does not used in transformer layer, just for unified api with conformer. 
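Before the forward passes below, a tiny sketch of the warmup/layer-dropout blend they all share: early in training the layer output is down-weighted, and with probability layer_dropout the layer is almost bypassed (alpha = 0.1). The helper name layer_blend is illustrative; the constants follow the code below:

import torch

def layer_blend(x_new: torch.Tensor, x_orig: torch.Tensor,
                warmup: float, layer_dropout: float,
                training: bool = True) -> torch.Tensor:
    # warmup grows from 0 toward >= 1.0 over training; below 0.9 the scale is < 1.
    warmup_scale = min(0.1 + warmup, 1.0)
    if training:
        alpha = warmup_scale if torch.rand(()).item() <= (1.0 - layer_dropout) else 0.1
    else:
        alpha = 1.0
    return x_new if alpha == 1.0 else alpha * x_new + (1 - alpha) * x_orig

x_orig = torch.randn(2, 50, 256)
x_new = x_orig + torch.randn(2, 50, 256)        # pretend this is the layer's output
out = layer_blend(x_new, x_orig, warmup=0.3, layer_dropout=0.075)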
:rtype: Tuple[torch.Tensor, torch.Tensor] """ + warmup_scale = min(0.1 + warmup, 1.0) + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + x_orig = x residual = x + if self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) x = self.norm1(x) - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + x_att = self.self_attn(x, x, x, mask, pos_emb) + + if self.concat_linear is not None: + x_concat = torch.cat((x,x_att), dim=-1) x = residual + self.concat_linear(x_concat) else: - x = residual + self.dropout(self.self_attn(x, x, x, mask)) + x = residual + self.dropout(x_att) + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) x = self.norm1(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) residual = x + if self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) x = self.norm2(x) - x = residual + self.dropout(self.feed_forward(x)) + if not self.use_layer_norm: + x = x.transpose(1, 2) + if self.positionwise_layer_type == 'gau': + x = self.feed_forward(x, x, x, mask, pos_emb) + else: + x = self.feed_forward(x) + x = residual + self.dropout(x) + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) x = self.norm2(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + if alpha != 1.0: + x = alpha * x + (1 - alpha) * x_orig + + return x, mask + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module + + :param int size: input dim + :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward: + feed forward module + :param feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + :param conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + :param float dropout_rate: dropout rate + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float=0.1, + layer_dropout: float=0., + normalize_before: bool = True, + norm_type: str = "layer_norm", + positionwise_layer_type: str="linear", + concat_after: bool = False + ): + super(ConformerEncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + if norm_type == "batch_norm": + self.use_layer_norm = False + + norm_tp = Trans_Bat + else: + self.use_layer_norm = True + if norm_type == "layer_norm": + norm_tp = LayerNorm + else: + norm_tp = BasicNorm + self.norm_ff = norm_tp(size) + self.norm_mha = norm_tp(size) + if feed_forward_macaron is not None: + self.norm_ff_macaron = norm_tp(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = norm_tp(size) # for the CNN module + self.norm_final = norm_tp(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.layer_dropout = layer_dropout + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + self.positionwise_layer_type = positionwise_layer_type + self.concat_linear = None + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoded features + + :param torch.Tensor x: encoded source features (batch, max_time_in, size) + :param torch.Tensor mask: mask for x (batch, max_time_in, max_time_in) + :param torch.Tensor pos_emb: + :param torch.Tensor mask_pad: batch padding mask used for conv module (batch, 1, time) + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + warmup_scale = min(0.1 + warmup, 1.0) + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + x_orig = x + if self.feed_forward_macaron is not None: + residual = x + + if self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_ff_macaron(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + if self.positionwise_layer_type == 'gau': + x = self.feed_forward_macaron(x, x, x, mask, pos_emb) + else: + x = self.feed_forward_macaron(x) + x = residual + self.ff_scale * self.dropout(x) + + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_ff_macaron(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + # MHA + residual = x + + if self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_mha(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + x_att = self.self_attn(x, x, x, mask, pos_emb) + + if self.concat_linear is not None: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_mha(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + # convolution module + if self.conv_module is not None: + residual = x + + if self.normalize_before: + if not self.use_layer_norm: + x 
= x.transpose(1, 2) + x = self.norm_conv(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + x = self.conv_module(x, mask_pad) + x = residual + self.dropout(x) + + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_conv(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + # FFN + residual = x + + if self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_ff(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + if self.positionwise_layer_type == 'gau': + x = self.feed_forward(x, x, x, mask, pos_emb) + else: + x = self.feed_forward(x) + x = residual + self.ff_scale * self.dropout(x) + + + if not self.normalize_before: + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_ff(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + if self.conv_module is not None: + + if not self.use_layer_norm: + x = x.transpose(1, 2) + x = self.norm_final(x) + if not self.use_layer_norm: + x = x.transpose(1, 2) + + if alpha != 1.0: + x = alpha * x + (1 - alpha) * x_orig + return x, mask + + +class ReConformerEncoderLayer(nn.Module): + """Encoder layer module + + :param int size: input dim + :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward: + feed forward module + :param feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + :param conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + :param float dropout_rate: dropout rate + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float=0.1, + layer_dropout: float=0.075, + positionwise_layer_type: str="linear", + concat_after: bool = False + ): + super(ReConformerEncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.dropout = nn.Dropout(dropout_rate) + self.layer_dropout = layer_dropout + self.size = size + self.concat_after = concat_after + self.positionwise_layer_type = positionwise_layer_type + self.concat_linear = None + if self.concat_after: + self.concat_linear = ScaledLinear(size + size, size) + self.norm_final = BasicNorm(size) + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+ self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoded features + + :param torch.Tensor x: encoded source features (batch, max_time_in, size) + :param torch.Tensor mask: mask for x (batch, max_time_in, max_time_in) + :param torch.Tensor pos_emb: + :param torch.Tensor mask_pad: batch padding mask used for conv module (batch, 1, time) + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + warmup_scale = min(0.1 + warmup, 1.0) + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + x_orig = x + if self.feed_forward_macaron is not None: + residual = x + + if self.positionwise_layer_type == 'gau': + x = self.feed_forward_macaron(x, x, x, mask, pos_emb) + else: + x = self.feed_forward_macaron(x) + x = residual + self.dropout(x) + + + # MHA + + x_att = self.self_attn(x, x, x, mask, pos_emb) + + if self.concat_linear is not None: + x_concat = torch.cat((x, x_att), dim=-1) + x = x + self.concat_linear(x_concat) + else: + x = x + self.dropout(x_att) + + + # convolution module + if self.conv_module is not None: + residual = x + + + x = self.conv_module(x, mask_pad) + x = residual + self.dropout(x) + + # FFN + residual = x + + + if self.positionwise_layer_type == 'gau': + x = self.feed_forward(x, x, x, mask, pos_emb) + else: + x = self.feed_forward(x) + x = residual + self.dropout(x) + + + x = self.norm_final(self.balancer(x)) + + if alpha != 1.0: + x = alpha * x + (1 - alpha) * x_orig return x, mask diff --git a/pytorch/libs/nnet/transformer/layer_norm.py b/pytorch/libs/nnet/transformer/layer_norm.py index 66e81fd..0c28a05 100644 --- a/pytorch/libs/nnet/transformer/layer_norm.py +++ b/pytorch/libs/nnet/transformer/layer_norm.py @@ -3,8 +3,57 @@ # Reference: https://github.com/espnet/espnet. import torch +import torch.nn.functional as F +class Trans_Bat(torch.nn.BatchNorm1d): + """BatchNorm1d module + :param int nout: output dim size + :param int dim: dimension to be normalized + :param bool transpose: transpose T and F. + :param float eps: invalid, compatible with LayerNorm. + """ + + def __init__(self, nout, transpose=False,eps=1e-12,learnabel_affine: bool = True): + super(Trans_Bat,self).__init__(nout) + self.norm = torch.nn.BatchNorm1d(nout,affine=learnabel_affine) + self.transpose = transpose + def forward(self, x): + """Apply BatchNorm1d normalization + + :param torch.Tensor x: input tensor + :return: batch normalized tensor + :rtype torch.Tensor + """ + + if self.transpose: + return self.norm(x.transpose(1, -1)).transpose(1, -1) + else: + return self.norm(x) + +# class LayerNorm(torch.nn.Module): +# """Layer normalization module + +# :param int nout: output dim size +# :param int dim: dimension to be normalized +# """ +# def __init__( +# self, +# nout: int, +# dim: int = -1, # CAUTION: see documentation. 
+# eps: float = 1e-5, +# learnabel_affine: bool = True, +# ) -> None: +# super(LayerNorm, self).__init__() +# self.dim = dim + +# self.norm = torch.nn.LayerNorm(nout,eps=eps,elementwise_affine=learnabel_affine) + +# def forward(self,x): +# if self.dim == -1: +# return self.norm(x) + +# return self.norm(x.transpose(1, -1)).transpose(1, -1) class LayerNorm(torch.nn.LayerNorm): """Layer normalization module @@ -12,8 +61,8 @@ class LayerNorm(torch.nn.LayerNorm): :param int dim: dimension to be normalized """ - def __init__(self, nout, dim=-1): - super(LayerNorm, self).__init__(nout, eps=1e-12) + def __init__(self, nout, dim=-1, eps=1e-5,learnabel_affine: bool = True): + super(LayerNorm, self).__init__(nout, eps=eps,elementwise_affine=learnabel_affine) self.dim = dim def forward(self, x): @@ -24,5 +73,56 @@ def forward(self, x): :rtype torch.Tensor """ if self.dim == -1: - return super(LayerNorm, self).forward(x) - return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps) + return F.layer_norm( + x.transpose(1, -1), self.normalized_shape, self.weight, self.bias, self.eps).transpose(1, -1) + # return super().forward(x) + # return super().forward(x.transpose(1, -1)).transpose(1, -1) +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. 
+ eps: float = 0.25, + learn_eps: bool = True, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() + ) ** -0.5 + return x * scales \ No newline at end of file diff --git a/pytorch/libs/nnet/transformer/mask.py b/pytorch/libs/nnet/transformer/mask.py new file mode 100755 index 0000000..3c34268 --- /dev/null +++ b/pytorch/libs/nnet/transformer/mask.py @@ -0,0 +1,136 @@ +# -*- coding:utf-8 -*- + +# Reference: https://github.com/wenet-e2e/wenet + +import torch + +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, static_chunk_size: int, + left_chunk_size: int): + """ Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: use random dynamic chunk. + <0: use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + left_chunk_size: number of left chunks. + >=0: use left_chunk_size + <0: use all left chunks + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = left_chunk_size + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. 
+ chunk_size = torch.randint(1, max_len, (1, )).item() + num_left_chunks = -1 + if chunk_size > max_len // 2: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = torch.randint(0, max_left_chunks, + (1, )).item() + chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = left_chunk_size + chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + else: + chunk_masks = masks + return chunk_masks + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask \ No newline at end of file diff --git a/pytorch/libs/nnet/transformer/multi_layer_conv.py b/pytorch/libs/nnet/transformer/multi_layer_conv.py index 3c21042..5128a0e 100644 --- a/pytorch/libs/nnet/transformer/multi_layer_conv.py +++ b/pytorch/libs/nnet/transformer/multi_layer_conv.py @@ -7,7 +7,8 @@ """Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" import torch - +from libs.nnet.activation import Nonlinearity +from .scaling import ActivationBalancer,ScaledConv1d,ScaledLinear class MultiLayeredConv1d(torch.nn.Module): """Multi-layered conv1d for Transformer block. @@ -20,7 +21,7 @@ class MultiLayeredConv1d(torch.nn.Module): """ - def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False): """Initialize MultiLayeredConv1d module. Args: @@ -31,13 +32,18 @@ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): """ super(MultiLayeredConv1d, self).__init__() - self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, + conv = ScaledConv1d if re_scale else torch.nn.Conv1d + self.w_1 = conv(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2) - self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, + self.w_2 = conv(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2) self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = Nonlinearity(activation_type) + self.balancer = None + if activation_balancer: + self.balancer = ActivationBalancer(channel_dim=1) - def forward(self, x): + def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)): """Calculate forward propagation. 
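Putting the two helpers together (a sketch assuming make_pad_mask and subsequent_chunk_mask above are in scope): the padding mask supplies the (B, 1, L) term, the chunk mask the (1, L, L) term, and their broadcasted AND is the (B, L, L) attention mask that add_optional_chunk_mask returns in the static case.

import torch

lengths = torch.tensor([5, 3, 2])
pad = make_pad_mask(lengths)              # True on padded positions
masks = (~pad).unsqueeze(1)               # (B, 1, L): True on real frames

# Static 2-frame chunks with unlimited left context, i.e. the
# static_chunk_size=2, left_chunk_size=-1 path above.
chunk = subsequent_chunk_mask(int(lengths.max()), 2, num_left_chunks=-1)
att_mask = masks & chunk.unsqueeze(0)     # (B, L, L) boolean attention mask
print(att_mask.shape)                     # torch.Size([3, 5, 5])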
Args: @@ -47,7 +53,10 @@ def forward(self, x): Tensor: Batch of output tensors (B, *, hidden_chans) """ - x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + x = self.w_1(x.transpose(-1, 1)) + if self.balancer is not None: + x=self.balancer(x) + x = self.activation(x).transpose(-1, 1) return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) @@ -58,7 +67,7 @@ class Conv1dLinear(torch.nn.Module): """ - def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False): """Initialize Conv1dLinear module. Args: @@ -69,12 +78,17 @@ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): """ super(Conv1dLinear, self).__init__() - self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, + conv = ScaledConv1d if re_scale else torch.nn.Conv1d + + self.w_1 = conv(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2) - self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.w_2 = ScaledLinear(hidden_chans, in_chans,initial_scale=0.25) if re_scale else torch.nn.Linear(hidden_chans, in_chans) self.dropout = torch.nn.Dropout(dropout_rate) - - def forward(self, x): + self.activation = Nonlinearity(activation_type) + self.balancer = None + if activation_balancer: + self.balancer = ActivationBalancer(channel_dim=1) + def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)): """Calculate forward propagation. Args: @@ -84,5 +98,9 @@ def forward(self, x): Tensor: Batch of output tensors (B, *, hidden_chans) """ - x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + x = self.w_1(x.transpose(-1, 1)) + if self.balancer is not None: + x=self.balancer(x) + x = self.activation(x).transpose(-1, 1) return self.w_2(self.dropout(x)) + diff --git a/pytorch/libs/nnet/transformer/positionwise_feed_forward.py b/pytorch/libs/nnet/transformer/positionwise_feed_forward.py index cf70c87..7f2ff01 100644 --- a/pytorch/libs/nnet/transformer/positionwise_feed_forward.py +++ b/pytorch/libs/nnet/transformer/positionwise_feed_forward.py @@ -2,7 +2,11 @@ # Reference: https://github.com/espnet/espnet. 
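Both convolutional feed-forward variants above follow the same data flow; a stripped-down sketch with plain Conv1d and ReLU (no ScaledConv1d, Nonlinearity or ActivationBalancer) showing the transposes around the channel-first convolutions. The extra x1/x2/mask/pos_embed arguments in the patched forward() signatures appear to be placeholders for a unified layer interface and are not used in this flow.

import torch

B, T, D, H, k = 2, 50, 256, 1024, 3
x = torch.randn(B, T, D)                      # (batch, time, channels)
w_1 = torch.nn.Conv1d(D, H, k, stride=1, padding=(k - 1) // 2)
w_2 = torch.nn.Conv1d(H, D, k, stride=1, padding=(k - 1) // 2)
drop = torch.nn.Dropout(0.1)

h = w_1(x.transpose(-1, 1))                   # Conv1d expects (B, C, T)
h = torch.relu(h).transpose(-1, 1)            # back to (B, T, H)
y = w_2(drop(h).transpose(-1, 1)).transpose(-1, 1)
print(y.shape)                                # torch.Size([2, 50, 256])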
+from tkinter import N +from tkinter.messagebox import NO import torch +from libs.nnet.activation import Nonlinearity +from .scaling import ActivationBalancer,ScaledLinear class PositionwiseFeedForward(torch.nn.Module): """Positionwise feed forward @@ -12,11 +16,20 @@ class PositionwiseFeedForward(torch.nn.Module): :param float dropout_rate: dropout rate """ - def __init__(self, idim, hidden_units, dropout_rate): + def __init__(self, idim, hidden_units, dropout_rate,activation_type='relu',activation_balancer=False,re_scale=False): super(PositionwiseFeedForward, self).__init__() - self.w_1 = torch.nn.Linear(idim, hidden_units) - self.w_2 = torch.nn.Linear(hidden_units, idim) + self.w_1 = ScaledLinear(idim, hidden_units) if re_scale else torch.nn.Linear(idim, hidden_units) + self.w_2 = ScaledLinear(hidden_units, idim,initial_scale=0.25) if re_scale else torch.nn.Linear(hidden_units, idim) self.dropout = torch.nn.Dropout(dropout_rate) - - def forward(self, x): - return self.w_2(self.dropout(torch.relu(self.w_1(x)))) + + self.activation = Nonlinearity(activation_type) + self.balancer = None + if activation_balancer: + self.balancer = ActivationBalancer(channel_dim=-1) + assert self.activation is not None + def forward(self, x,x1:torch.Tensor = torch.empty(0),x2:torch.Tensor = torch.empty(0),mask:torch.Tensor = torch.empty(0),pos_embed:torch.Tensor = torch.empty(0)): + x = self.w_1(x) + if self.balancer is not None: + x=self.balancer(x) + return self.w_2(self.dropout(self.activation(x))) + # return self.w_2(self.dropout(self.activation(self.w_1(x)))) diff --git a/pytorch/libs/nnet/transformer/scaling.py b/pytorch/libs/nnet/transformer/scaling.py new file mode 100755 index 0000000..6d7debf --- /dev/null +++ b/pytorch/libs/nnet/transformer/scaling.py @@ -0,0 +1,499 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao) +# For reworked conformer +# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py + + +import collections +from itertools import repeat +from typing import Optional, Tuple + +import torch +import torch.backends.cudnn.rnn as rnn +import torch.nn as nn +from torch import _VF, Tensor + + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + ) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + + # sum_dims = [d for d in range(x.ndim) if d != channel_dim] + # The above line is not torch scriptable for torch 1.6.0 + # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa + sum_dims = [] + for d in range(x.ndim): + if d != channel_dim: + sum_dims.append(d) + + xgt0 = x > 0 + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) + + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs + + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) + ctx.max_factor = max_factor + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors + dtype = x_grad.dtype + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) + + neg_delta_grad = x_grad.abs() * (factor + scale_factor) + return x_grad - neg_delta_grad, None, None, None, None, None, None + + + + + +class ScaledLinear(nn.Linear): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + initial_speed: this affects how fast the parameter will + learn near the start of training; you can set it to a + value less than one if you suspect that a module + is contributing to instability near the start of training. + Nnote: regardless of the use of this option, it's best to + use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. 
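The reparameterization described in this docstring is easy to state in plain PyTorch (a minimal sketch, not the class itself): the stored weight and bias are multiplied by exp(weight_scale) and exp(bias_scale) on every forward, so the scales live in log space and stay positive.

import torch
import torch.nn as nn
import torch.nn.functional as F

lin = nn.Linear(16, 32)
weight_scale = nn.Parameter(torch.tensor(0.5).log())   # log-domain scale
bias_scale = nn.Parameter(torch.tensor(0.5).log())

x = torch.randn(4, 16)
y = F.linear(x, lin.weight * weight_scale.exp(), lin.bias * bias_scale.exp())
# With equal weight and bias scales this is just a global rescale of the layer.
assert torch.allclose(y, 0.5 * lin(x), atol=1e-6)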
+ """ + + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + if self.bias is None or self.bias_scale is None: + return None + else: + return self.bias * self.bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) + + +class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledConv1d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + + self.bias_scale: Optional[nn.Parameter] # for torchscript + + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + bias = self.bias + bias_scale = self.bias_scale + if bias is None or bias_scale is None: + return None + else: + return bias * bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + (0,), + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class ScaledConv2d(nn.Conv2d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledConv2d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # 
Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + # see https://github.com/pytorch/pytorch/issues/24135 + bias = self.bias + bias_scale = self.bias_scale + if bias is None or bias_scale is None: + return None + else: + return bias * bias_scale.exp() + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + (0, 0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) + + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + Args: + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. 
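Concretely, the two per-channel statistics this balancer regulates can be computed as below (a sketch; the module itself is an identity in the forward pass and only adjusts gradients during training):

import torch

x = torch.randn(8, 64, 100)                   # (batch, channels, time)
channel_dim = 1
sum_dims = [d for d in range(x.ndim) if d != channel_dim]

proportion_positive = (x > 0).float().mean(dim=sum_dims)   # kept within [min_positive, max_positive]
mean_abs = x.abs().mean(dim=sum_dims)                      # kept within [min_abs, max_abs]
print(proportion_positive.shape, mean_abs.shape)           # torch.Size([64]) twice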
+ """ + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): + super(ActivationBalancer, self).__init__() + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + else: + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + s, y = ctx.saved_tensors + return (y * (1 - s) + s) * y_grad + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). 
+ """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + else: + return DoubleSwishFunction.apply(x) + + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) + + + + + +if __name__ == "__main__": + _test_activation_balancer_sign() + # _test_activation_balancer_magnitude() + # _test_basic_norm() + # _test_double_swish_deriv() diff --git a/pytorch/libs/nnet/transformer/subsampling.py b/pytorch/libs/nnet/transformer/subsampling.py index 085f50c..c4d3df4 100644 --- a/pytorch/libs/nnet/transformer/subsampling.py +++ b/pytorch/libs/nnet/transformer/subsampling.py @@ -2,44 +2,568 @@ # Reference: https://github.com/espnet/espnet. +from turtle import xcor import torch +from typing import Tuple +import torch.nn as nn +from .scaling import ScaledLinear,ScaledConv2d,ActivationBalancer +from libs.nnet.activation import DoubleSwish +from .layer_norm import BasicNorm +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. 
+ Args: + message (str): Message for error catch + actual_size (int): the short size that cannot pass the subsampling + limit (int): the limit size for subsampling + """ + + def __init__(self, message, actual_size, limit): + """Construct a TooShortUttError for error handler.""" + super().__init__(message) + self.actual_size = actual_size + self.limit = limit -from .embedding import PositionalEncoding +def check_short_utt(ins, size): + """Check if the utterance is too short for subsampling.""" + if isinstance(ins, Conv2dSubsampling2) and size < 7: + return True, 7 + if isinstance(ins, Conv2dSubsampling4) and size < 7: + return True, 7 + if isinstance(ins, Conv2dSubsampling6) and size < 11: + return True, 11 + if isinstance(ins, Conv2dSubsampling8) and size < 15: + return True, 15 + return False, -1 -class Conv2dSubsampling(torch.nn.Module): - """Convolutional 2D subsampling (to 1/4 length) +class LinearNoSubsampling(torch.nn.Module): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. - :param int idim: input dim - :param int odim: output dim - :param flaot dropout_rate: dropout rate """ + def __init__(self, idim: int, odim: int, dropout_rate: float,mlp_head: bool, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + # torch.nn.LayerNorm(odim, eps=1e-5), + # torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.subsampling_rate = 1 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). - def __init__(self, idim, odim, dropout_rate): - super(Conv2dSubsampling, self).__init__() + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + + x = torch.cat((cls_tokens,x),dim=1) + + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask + + +class Conv2dSubsampling4(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. 
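The length limits used in check_short_utt above follow directly from the valid (unpadded) 3x3, stride-2 convolutions: two of them need at least 7 input frames to leave one output frame. A small arithmetic sketch:

def conv_out(t, kernel=3, stride=2):
    # Output length of a valid (no-padding) convolution along one axis.
    return (t - kernel) // stride + 1

for T in (6, 7, 16, 100):
    t1 = conv_out(T)
    t2 = conv_out(t1) if t1 >= 3 else 0
    print(T, "->", t2)   # 6 -> 0 (too short), 7 -> 1, 16 -> 3, 100 -> 24
# The same arithmetic on the frequency axis gives the Linear input width
# odim * (((idim - 1) // 2 - 1) // 2) used by Conv2dSubsampling4.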
+ """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling4, self).__init__() self.conv = torch.nn.Sequential( torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), - torch.nn.ReLU() + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + ) + self.pos_enc = pos_enc_class + self.subsampling_rate = 4 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + def __getitem__(self, key): + """Get item. + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + +class ReConv2dSubsampling4(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, + idim: int, + odim: int, + dropout_rate: float, + mlp_head: bool, + pos_enc_class: torch.nn.Module, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ): + """Construct an Conv2dSubsampling object.""" + assert idim >= 7 + super(ReConv2dSubsampling4, self).__init__() + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((idim - 1) // 2 - 1) // 2), odim + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. 
+ self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + self.pos_enc = pos_enc_class + self.subsampling_rate = 4 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + +class SVConv2dSubsampling4(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling object.""" + super(SVConv2dSubsampling4, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, (2,1)), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, (2,1)), + torch.nn.ReLU(), ) self.out = torch.nn.Sequential( - torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), - PositionalEncoding(odim, dropout_rate) + torch.nn.Linear(odim * (idim-4) , odim) ) + self.pos_enc = pos_enc_class + self.subsampling_rate = 4 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] - def forward(self, x, x_mask): - """Subsample x + def __getitem__(self, key): + """Get item. + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling2(torch.nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. 
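The SV variants above use stride (2, 1), so only time is subsampled while the frequency axis just loses 2 bins per 3x3 valid convolution; that is where the Linear(odim * (idim - 4), odim) width comes from. A quick sketch of the shapes:

def conv_out(t, kernel=3, stride=1):
    return (t - kernel) // stride + 1

idim, T = 80, 200
t_out = conv_out(conv_out(T, stride=2), stride=2)      # time: roughly T // 4
f_out = conv_out(conv_out(idim, stride=1), stride=1)   # frequency: idim - 4
print(t_out, f_out)                                    # 49 76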
+ """ - :param torch.Tensor x: input tensor - :param torch.Tensor x_mask: input mask - :return: subsampled x and mask - :rtype Tuple[torch.Tensor, torch.Tensor] + def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling2 object.""" + super(Conv2dSubsampling2, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim)) + self.pos_enc = pos_enc_class + self.subsampling_rate = 2 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - if x_mask is None: - return x, None - return x, x_mask[:, :, :-2:2][:, :, :-2:2] + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:1] + + def __getitem__(self, key): + """Get item. + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + +class SVConv2dSubsampling2(torch.nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc_class (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling2 object.""" + super(SVConv2dSubsampling2, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, (2,1)), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (idim-4), odim)) + self.pos_enc = pos_enc_class + self.subsampling_rate = 2 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. 
+ """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:1] + # return x, pos_emb, x_mask[:, :, :-2:2] + def __getitem__(self, key): + """Get item. + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + +class Conv2dSubsampling6(torch.nn.Module): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super(Conv2dSubsampling6, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), + ) + self.pos_enc = pos_enc_class + self.subsampling_rate = 6 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(torch.nn.Module): + """Convolutional 2D subsampling (to 1/8 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, mlp_head: bool, pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super(Conv2dSubsampling8, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + self.mlp_head = mlp_head + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, odim)) if self.mlp_head else torch.empty(0) + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). 
+ x_mask (torch.Tensor): Input mask (#batch, 1, time). + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if self.mlp_head: + b,_,_ = x.shape + cls_tokens = self.cls_token.repeat(b,1,1) + x = torch.cat((cls_tokens,x),dim=1) + x, pos_emb = self.pos_enc(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] + + +class FrameMerger(nn.Module): + def __init__(self, dim: int, + normalize_before: bool = True, + norm_type: str = "layer_norm",): + super().__init__() + # self.conv = torch.nn.Sequential( + # torch.nn.Conv1d(dim, dim*2, 3,2), + # torch.nn.ReLU(), + # torch.nn.Conv1d(dim*2, dim*2, 3, 2), + # torch.nn.ReLU(), + # ) + self.conv = torch.nn.Sequential( + torch.nn.Conv1d(dim, 2*dim, 3,2), + torch.nn.ReLU(), + torch.nn.Conv1d(2*dim, 2*dim, 3, 2), + torch.nn.ReLU(), + ) + # self.conv = torch.nn.Sequential( + # torch.nn.Conv2d(1, dim, 3, (2,1)), + # torch.nn.ReLU(), + # torch.nn.Conv2d(dim, dim, 3, (2,1)), + # torch.nn.ReLU(), + # ) + + self.normalize_before = normalize_before + if self.normalize_before: + self.norm = nn.LayerNorm(dim) + else: + self.norm = nn.LayerNorm(dim*2) + + def forward(self, x,pos_emb:torch.Tensor,mask:torch.Tensor,offset: int = 0): + if self.normalize_before: + x = self.norm(x) + x = x.transpose(1, 2) + x = self.conv(x) + x = x.transpose(1, 2) + if not self.normalize_before: + x = self.norm(x) + + return x,pos_emb[:, offset:offset + x.size(1)],mask[:, :, :-2:2][:, :, :-2:2] \ No newline at end of file diff --git a/pytorch/libs/support/utils.py b/pytorch/libs/support/utils.py index a9800a9..5f74073 100755 --- a/pytorch/libs/support/utils.py +++ b/pytorch/libs/support/utils.py @@ -4,6 +4,7 @@ import sys, os import math +import yaml import random import logging import shutil @@ -181,6 +182,7 @@ def create_model_from_py(model_blueprint, model_creation=""): return model_module else: model = eval("model_module.{0}".format(model_creation)) + return model @@ -214,7 +216,7 @@ def create_model_dir(model_dir:str, model_blueprint:str, stage=-1): os.makedirs("{0}/config".format(model_dir), exist_ok=True) os.makedirs("{0}/checkpoint_info".format(model_dir), exist_ok=True) if is_main_training(): - if stage < 0 and model_blueprint != config_model_blueprint: + if stage <= 0 and model_blueprint != config_model_blueprint: shutil.copy(model_blueprint, config_model_blueprint) else: while(True): @@ -246,7 +248,7 @@ def read_file_to_list(file_path, every_bytes=10000000): return list -def write_list_to_file(this_list, file_path, mod='w'): +def write_list_to_file(this_list, file_path, mod='w', yml=False): """ @mod: could be 'w' or 'a' """ @@ -254,8 +256,14 @@ def write_list_to_file(this_list, file_path, mod='w'): this_list = [this_list] with open(file_path, mod) as writer : - writer.write('\n'.join(str(x) for x in this_list)) - writer.write('\n') + if yml: + for x in this_list: + yaml.dump(x,writer) + + else: + + writer.write('\n'.join(str(x) for x in this_list)) + writer.write('\n') def save_checkpoint(checkpoint_path, **kwargs): @@ -309,9 +317,11 @@ def key_to_value(adict, key, return_none=True): def assign_params_dict(default_params:dict, params:dict, force_check=False, support_unknow=False): + default_params = copy.deepcopy(default_params) - default_keys = set(default_params.keys()) + default_keys = 
set(default_params.keys()) + # Should keep force_check=False to use support_unknow if force_check: for key in param.keys(): @@ -484,7 +494,7 @@ def get_free_port(ip="127.0.0.1"): return s.getsockname()[1] # https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/utils/data_utils.py -def batch_pad_right(tensors: list, mode="constant", value=0): +def batch_pad_right(tensors: list, mode="constant", value=0,val_index=-1): """Given a list of torch tensors it batches them together by padding to the right on each dimension in order to get same length for all. @@ -543,10 +553,10 @@ def batch_pad_right(tensors: list, mode="constant", value=0): t, max_shape, mode=mode, value=value ) batched.append(padded) - valid.append(valid_percent[0]) + valid.append(valid_percent[val_index]) batched = torch.stack(batched) - + return batched, torch.tensor(valid) @@ -589,7 +599,6 @@ def pad_right_to( valid_vals.append(tensor.shape[j] / target_shape[j]) i -= 1 j += 1 - tensor = torch.nn.functional.pad(tensor, pads, mode=mode, value=value) return tensor, valid_vals diff --git a/pytorch/libs/training/lr_scheduler_online.py b/pytorch/libs/training/lr_scheduler_online.py old mode 100644 new mode 100755 index dbd3082..7377c43 --- a/pytorch/libs/training/lr_scheduler_online.py +++ b/pytorch/libs/training/lr_scheduler_online.py @@ -7,6 +7,7 @@ import math import numpy as np import torch +from typing import Union from torch.optim.lr_scheduler import _LRScheduler from .optim import * @@ -39,7 +40,8 @@ def __init__(self, optimizer, params:dict={}): "1cycle.epochs":None, "1cycle.steps_per_epoch":None, "1cycle.pct_start":0.3, - "1cycle.anneal_strategy":'linear', + "1cycle.warmup_steps":None, + "1cycle.anneal_strategy":'cos', # ["cos", "linear"] "1cycle.cycle_momentum":False, "1cycle.base_momentum":0.85, "1cycle.max_momentum":0.95, @@ -53,6 +55,11 @@ def __init__(self, optimizer, params:dict={}): "warmR.log_decay":False, "warmR.lr_decay_step":1, + "noam.warmup_steps": 2000, + "noam.step_decay": False, + "noam.step_size": 34000, + "noam.step_rate": 0.5, + "reduceP.metric":'valid_acc', "reduceP.check_interval":0, "reduceP.factor":0.5, @@ -79,13 +86,23 @@ def __init__(self, optimizer, params:dict={}): self.step_total=int(step_up + step_down) self.lr_scheduler = torch.optim.lr_scheduler.CyclicLR(base_optimizer, base_lr, max_lr, **split_params["cyclic"]) elif self.name == "1cycle": + warmup_steps = split_params["1cycle"].pop("warmup_steps") + pct_start = split_params["1cycle"].pop("pct_start") max_lr = split_params["1cycle"].pop("learn_rate") - self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, **split_params["1cycle"]) + self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, pct_start=pct_start,**split_params["1cycle"]) + self.step_total = self.lr_scheduler.total_steps + if warmup_steps is not None: + pct_start = warmup_steps/self.step_total + self.lr_scheduler = optim.lr_scheduler.OneCycleLR(base_optimizer, max_lr, pct_start=pct_start,**split_params["1cycle"]) elif self.name == "warmR": self.T_0 = split_params["warmR"].pop("T_max") self.T_mult = split_params["warmR"]["T_mult"] self.lr_decay_step = split_params["warmR"].pop("lr_decay_step") self.lr_scheduler = CosineAnnealingWarmRestarts(base_optimizer, self.T_0, **split_params["warmR"]) + elif self.name == "noam": + warmup_steps = split_params["noam"].pop("warmup_steps") + self.lr_scheduler = WarmupLR(base_optimizer, warmup_steps, **split_params["noam"]) + pass elif self.name == "reduceP": self.check_interval = 
split_params["reduceP"].pop("check_interval") self.metric = split_params["reduceP"].pop("metric") @@ -120,7 +137,7 @@ def is_cycle_point(self, training_point): else: if math.log(max(0.05, (epoch / self.T_0 * (self.T_mult - 1) + 1)), self.T_mult)%1==0 and epoch>0: return False if (self.lr_decay_step == 0 and training_point[1]>1) else True - if self.name=="cyclic": + if self.name in ["cyclic", "1cycle"] : return True if training_point[2]%self.step_total==0 and training_point[2]>0 else False @@ -137,12 +154,16 @@ def step(self, training_point=None, valid_metric=None): elif self.name == "cyclic": self.lr_scheduler.step() elif self.name == "1cycle": - self.lr_scheduler.step(training_point[2]) + self.lr_scheduler.step() elif self.name == "reduceP": # Sample a point in which the metrics of valid are computed and adjust learning rate at this point. if self.is_reduce_point(training_point): metric = valid_metric[0] if self.metric == "valid_loss" else valid_metric[1] self.lr_scheduler.step(metric) + elif self.name == "noam": + self.lr_scheduler.step() + else: + raise ValueError("unknown lr_scheduler: " + self.name) ## Learn rate scheduler ✿ class CosineAnnealingWarmRestarts(_LRScheduler): @@ -174,7 +195,7 @@ class CosineAnnealingWarmRestarts(_LRScheduler): Base lr decay has been added. [Snowdar 2019-08-29] """ - def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, factor=1.0, log_decay=False, last_epoch=-1): + def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, factor=1.0, log_decay=False, last_epoch=-1, warmup_steps=5000): if T_0 <= 0 or not isinstance(T_0, int): raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) if T_mult <=0: # or not isinstance(T_mult, int): @@ -251,3 +272,67 @@ def step(self, epoch=None): self.last_epoch = math.floor(epoch) for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + + +# Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/scheduler.py +class WarmupLR(_LRScheduler): + """The WarmupLR scheduler + + This scheduler is almost same as NoamLR Scheduler except for following + difference: + + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + + Note that the maximum lr equals to optimizer.lr in this scheduler. + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + step_decay: bool=False, + step_size: int=80000, + step_rate: float = 0.5, + last_epoch: int = -1, + ): + self.warmup_steps = warmup_steps + self.step_decay = step_decay + self.step_size = step_size + self.step_rate = step_rate + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + if step_num>> from optimizer import RAdam - >>> optimizer = RAdam(model.parameters(), lr=0.001) - Note, here the weight decay is not L2 regularization. 
- ''' - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), N_sma_threshhold=4, eps=1e-8, weight_decay=0): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - self.N_sma_threshhold = N_sma_threshhold - self.buffer = [[None, None, None] for ind in range(10)] - super(RAdam, self).__init__(params, defaults) - - def __setstate__(self, state): - super(RAdam, self).__setstate__(state) - - def step(self, closure=None): - - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data.float() - if grad.is_sparse: - raise RuntimeError('RAdam does not support sparse gradients') - - p_data_fp32 = p.data.float() - - state = self.state[p] - - if len(state) == 0: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) - else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] - - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - exp_avg.mul_(beta1).add_((1 - beta1) * grad) - - state['step'] += 1 - buffered = self.buffer[int(state['step'] % 10)] - if state['step'] == buffered[0]: - N_sma, step_size = buffered[1], buffered[2] - else: - buffered[0] = state['step'] - beta2_t = beta2 ** state['step'] - N_sma_max = 2 / (1 - beta2) - 1 - N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) - buffered[1] = N_sma - if N_sma > self.N_sma_threshhold: - step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) - else: - step_size = group['lr'] / (1 - beta1 ** state['step']) - buffered[2] = step_size - - if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'] * p_data_fp32) - - if N_sma > self.N_sma_threshhold: - denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) - else: - p_data_fp32.add_(-step_size * exp_avg) - - p.data.copy_(p_data_fp32) - - return loss - class Ralamb(Optimizer): '''https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py @@ -819,3 +760,223 @@ def step(self, closure=None): p.data.add_(-group['lr'] * exp_avg) return loss + +# Sharpness-Aware Minimization for Efficiently Improving Generalization. 
+# https://openreview.net/pdf?id=6Tm1mposlrM +# https://github.com/davda54/sam + +class SAM(Optimizer): + def __init__(self, optimizer, rho=0.05, adaptive=False) -> None: + assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" + defaults = dict(rho=rho, adaptive=adaptive) + + super(SAM, self).__init__(optimizer.param_groups, defaults) + self.base_optimizer = optimizer + + self.defaults.update(self.base_optimizer.defaults) + + @torch.no_grad() + def first_step(self, zero_grad=False): + grad_norm = self._grad_norm() + for group in self.param_groups: + scale = group["rho"] / (grad_norm + 1e-12) + + for p in group["params"]: + if p.grad is None: continue + self.state[p]["old_p"] = p.data.clone() + self.state[p]["cache"] = True + e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad=False): + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: continue + p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" + self.state[p]["cache"] = False + + self.base_optimizer.step() # do the actual "sharpness-aware" update + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def back_w(self): + for group in self.param_groups: + for p in group["params"]: + p_cache = self.state[p].get('cache',False) + if p_cache: + p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" + self.state[p]["cache"] = False + + @torch.no_grad() + def step(self,stage=1): + if stage==1: + self.first_step(zero_grad=True) + return True + else: + self.second_step(zero_grad=True) + return True + + def _grad_norm(self): + shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism + norm = torch.norm( + torch.stack([ + ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) + for group in self.param_groups for p in group["params"] + if p.grad is not None + ]), + p=2 + ) + return norm + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.base_optimizer.param_groups = self.param_groups + + +class Eve(Optimizer): + r""" + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. 
+ target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0 <= weight_decay <= 0.1: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "AdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). 
+ is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) + p.mul_(1 - (weight_decay * is_above_target_rms)) + p.addcdiv_(exp_avg, denom, value=-step_size) + + # Constrain the range of scalar weights + if p.numel() == 1: + p.clamp_(min=-10, max=2) + + return loss \ No newline at end of file diff --git a/pytorch/libs/training/trainer_online.py b/pytorch/libs/training/trainer_online.py old mode 100644 new mode 100755 index 9876603..60a0b95 --- a/pytorch/libs/training/trainer_online.py +++ b/pytorch/libs/training/trainer_online.py @@ -15,6 +15,7 @@ import numpy as np import yaml import torch +import torch.distributed as dist from torch.utils.data import DataLoader from contextlib import nullcontext from .reporter import Reporter_new as Reporter @@ -60,8 +61,8 @@ class _BaseTrainer(): def __init__(self, package, stop_early=False): default_elements = {"data": None, "model": None, "optimizer": None, "lr_scheduler": None} - default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10, - "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 10.0, "use_amp": False, "accum_grad": 1, + default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10, "warmup_steps":0, + "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 5.0, "use_amp": False, "accum_grad": 1, "compute_accuracy": True, "compute_valid_accuracy": True, "compute_batch_num_valid": 1, "suffix": "params", "nan_debug": False, "skip_nan_batch": True, "use_tensorboard": True} @@ -90,6 +91,7 @@ def __init__(self, package, stop_early=False): self.training_point = copy.deepcopy([self.params["start_epoch"], 0, 0]) self.cycle_point = 0 # for cycle training. + self.warmup_steps = self.params["warmup_steps"] # model level warm up. just for transformer. def select_device(self): return utils.select_model_device(self.elements["model"], self.params["use_gpu"], gpu_id=self.params["gpu_id"], benchmark=self.params["benchmark"]) @@ -140,7 +142,11 @@ def init_training(self): pass # Now, it means use the raw initial model # Select device - + + # for k,v in model.named_parameters(): + # if "train_len" in k: + # print(v) + # assert 1==0 model = self.select_device() # Original model is built in libs.nnet.framework.TopVirtualNnet, and it is not available after @@ -215,7 +221,11 @@ def train_one_batch(self, batch, step_lr=True): model = self.elements["model"] model_forward = self.elements["model_forward"] optimizer = self.elements["optimizer"] + cur_step = self.training_point[2] + # model level warm up. just for transformer. + warmup = cur_step / self.warmup_steps if self.warmup_steps >0 else 1.0 + warmup = torch.FloatTensor([warmup]) if not model.training: model.train() @@ -229,8 +239,13 @@ def train_one_batch(self, batch, step_lr=True): map_location="cpu")) self.elements["model"].to(device) else: - inputs, targets = batch + inputs, targets, feats_lens = batch + feats_lens = (feats_lens*inputs.shape[2]).long() + input_list = [inputs,feats_lens] + # model level warm up. just for transformer. + if self.warmup_steps >0: + input_list.append(warmup) context = None # Disable gradient synchronizations across DDP processes. 
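For reference, a small self-contained sketch (not part of the patch) of the RMS-conditional shrinkage Eve applies above: weight decay only kicks in once a tensor's RMS exceeds `target_rms`, so small "scaled" parameters are left untouched.

```python
# Illustrative sketch of Eve's conditional weight decay (values are made up).
import torch

target_rms, weight_decay = 0.1, 1e-3

def decay_once(p: torch.Tensor) -> torch.Tensor:
    # Decay only when RMS(p) > target_rms, i.e. ||p|| > target_rms * sqrt(numel).
    is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
    return p * (1 - weight_decay * float(is_above_target_rms))

small = torch.full((256, 256), 0.05)   # RMS 0.05 < 0.1 -> untouched
large = torch.full((256, 256), 0.50)   # RMS 0.50 > 0.1 -> shrunk by (1 - 1e-3)
print(torch.equal(decay_once(small), small))   # True
print(torch.equal(decay_once(large), large))   # False
```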
# Within this context, gradients will be accumulated on module @@ -246,13 +261,17 @@ def train_one_batch(self, batch, step_lr=True): # Managing automatic mixed precision (Leo 2021-11-08) with torch.cuda.amp.autocast(self.scaler is not None): - loss = model.get_loss(model_forward(inputs), targets)/self.params["accum_grad"] + loss = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"] # loss = model_forward(inputs, targets)/self.params["accum_grad"] if self.params["use_amp"]: self.scaler.scale(loss).backward() else: loss.backward() + # for name, param in model.named_parameters(): + # if param.grad is None: + # print(name) + # sys.exit() loss.detach() # For safe. loss = loss.item()*self.params["accum_grad"] accuracy = None @@ -320,13 +339,15 @@ def compute_validation(self, data_loader): num_samples = 0 with torch.no_grad(): for idx,this_data in enumerate(data_loader): - inputs, targets = this_data + inputs, targets, feats_lens = this_data + feats_lens = (feats_lens*inputs.shape[2]).long() num_utts = targets.size(0) + input_list = [inputs,feats_lens] if num_utts == 0: continue # in valid stage, DO NOT call ddp model, for ddp model is in JOIN context wrapper. # Leo 2022-02-03 - loss += model.get_loss(model(inputs), + loss += model.get_loss(model(*input_list), targets).item() * len(targets) # loss += model_forward(inputs,targets).item() * len(targets) @@ -434,20 +455,32 @@ def run(self): model_context = model_forward.join(throw_on_early_termination=True) else: model_context = nullcontext() + + device = utils.get_device(self.elements["model"]) + stop_training = torch.zeros(1).to(device) + with model_context: - for this_epoch in range(start_epoch, epochs): - self.training_point[0]+=1 + for this_epoch in range(start_epoch, epochs+5): + # In case the uneven data in different ranks when ddp training, + # here we design a more epoch for sub-ranks to ensure the main rank can broadcasting. + if utils.is_main_training: + if this_epoch == epochs: # skip the last+1 epoch + stop_training = torch.ones(1).to(device) + + self.training_point[0]+=1 data.train_loader.dataset.set_epoch(this_epoch) # with model_context: + if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur() for _, batch in enumerate(data.train_loader, 0): # It is important for reporter. 
- if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur() - + dist.barrier() + if utils.use_ddp():dist.all_reduce(stop_training,op=dist.ReduceOp.SUM) + if stop_training: + break self.training_point[1] +=1 - num_utts = batch[0].size(0) if num_utts == 0: @@ -459,9 +492,12 @@ def run(self): loss, acc = self.train_one_batch(batch) model.backward_step(*self.training_point) - + if stop_training: + break if utils.is_main_training():self.save_model(train_lr=self.train_lr) self.training_point[1] =0 + + # dist.barrier() if utils.is_main_training():self.reporter.finish() if utils.is_main_training(): final_model_name = "{}_cycle".format(self.cycle_point) if self.cycle_point else epochs @@ -475,6 +511,9 @@ def run(self): if not isinstance(e, KeyboardInterrupt):traceback.print_exc() sys.exit(1) + + + @for_lr_finder_new def lr_finder_compute(self, train_batch): model = self.elements["model"] diff --git a/pytorch/libs/training/trainer_online_sam.py b/pytorch/libs/training/trainer_online_sam.py new file mode 100755 index 0000000..244041d --- /dev/null +++ b/pytorch/libs/training/trainer_online_sam.py @@ -0,0 +1,639 @@ +# -*- coding:utf-8 -*- + +# Copyright xmuspeech (Author: Snowdar 2019-11-25 +# Leo 2022-01-15) +import os +import sys +import re +import logging +import copy +import math +import time +import traceback +import progressbar +import pandas as pd +import numpy as np +import yaml +import torch +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.data import DataLoader +import torch.distributed as dist +from contextlib import nullcontext +from .reporter import Reporter_new as Reporter +from .lr_scheduler_online import LRSchedulerWrapper +from .lr_finder import for_lr_finder_new + +import libs.support.utils as utils + +# Wrap stderr before logger init. +progressbar.streams.wrap_stderr() + +# Logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +""" +This is the structure of Package: + +Package( + Elements{ + data:Bunch + model:TopVirtualNnet + optimizer:Optimizer + lr_scheduler:LR_Scheduler + }, + + Params{ + model_dir:str + exist_model:str + start_epoch:int + epochs:int + ... + } + ) + +training_point (this epoch, iter in epoch, global step) +""" + +# Trainer ✿ + + +class _BaseTrainer(): + def __init__(self, package, stop_early=False): + default_elements = {"data": None, "model": None, + "optimizer": None, "lr_scheduler": None} + default_params = {"model_dir": "", "model_blueprint": "", "exist_model": "", "start_epoch": 0, "epochs": 10,"warmup_steps":0, + "use_gpu": True, "gpu_id": "", "benchmark": True, "max_change": 5.0, "use_amp": False, "accum_grad": 1, + "compute_accuracy": True, "compute_valid_accuracy": True, "compute_batch_num_valid": 1, + "suffix": "params", "nan_debug": False, "skip_nan_batch": True, "use_tensorboard": True} + + elements, params = package + self.elements = utils.assign_params_dict(default_elements, elements) + self.params = utils.assign_params_dict( + default_params, params, support_unknow=True) + + assert self.elements["data"] is not None + assert self.elements["model"] is not None + assert self.elements["optimizer"] is not None + + assert self.params["model_dir"] != "" + assert self.params["model_blueprint"] != "" + assert self.params["accum_grad"] == 1 ,"Sharpness Aware Minimization do not support accum_grad now." 
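As a quick reference for how the two-stage interface of the SAM wrapper above is meant to be driven (the trainer below does the same thing with AMP, DDP no_sync and NaN guards on top), here is a minimal sketch with a toy model; `model`, `inputs` and `targets` are placeholders, and `SAM` is the class added earlier in this patch.

```python
# Minimal sketch of one SAM update using the wrapper defined earlier in this patch.
import torch

model = torch.nn.Linear(80, 10)
criterion = torch.nn.CrossEntropyLoss()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = SAM(base_optimizer, rho=0.05, adaptive=False)

inputs, targets = torch.randn(4, 80), torch.randint(0, 10, (4,))

# Pass 1: gradient at w, then perturb the weights to w + e(w).
criterion(model(inputs), targets).backward()
optimizer.step(stage=1)   # first_step(): caches w, adds e(w), zeroes grads

# Pass 2: gradient at w + e(w), restore w and take the actual update.
criterion(model(inputs), targets).backward()
optimizer.step(stage=2)   # second_step(): restores w, calls base_optimizer.step()
```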
+ + self.elements["model_forward"] = self.elements["model"] + self.params["start_epoch"] = max(0, self.params["start_epoch"]) + self.params["accum_grad"] = max(1,self.params["accum_grad"]) + self.use_ddp = utils.use_ddp() + self.stop_early = stop_early # To do. + + # Automatic mixed precision init.(Leo 2021-11-08) + self.scaler = torch.cuda.amp.GradScaler() if self.params["use_amp"] else None + + # (epoch, iter in epoch, global step) + self.training_point = copy.deepcopy([self.params["start_epoch"], 0, 0]) + self.cycle_point = 0 # for cycle training. + + self.warmup_steps = self.params["warmup_steps"] # model level warm up. just for transformer. + + def select_device(self): + return utils.select_model_device(self.elements["model"], self.params["use_gpu"], + gpu_id=self.params["gpu_id"], benchmark=self.params["benchmark"]) + + def init_training(self): + model = self.elements["model"] + start_epoch = self.params["start_epoch"] + exist_model = self.params["exist_model"] + model_dir = self.params["model_dir"] + model_blueprint = self.params["model_blueprint"] + suffix = self.params["suffix"] + + if start_epoch <= 0 and utils.is_main_training(): + model_creation = model.get_model_creation() + utils.write_nnet_config( + model_blueprint, model_creation, "{0}/config/nnet.config".format(model_dir)) + + # Recover checkpoint | Tansform learning | Initialize parametes + if start_epoch > 0: + # This train_stage is equal to number of completed epoch + if utils.is_main_training(): + logger.info( + "Recover training from {0} epoch.".format(start_epoch)) + model.load_state_dict(torch.load('{0}/{1}.{2}'.format(model_dir, start_epoch, suffix), + map_location="cpu")) + # info_path = '{0}/{1}/{2}.{3}'.format( + # model_dir, "checkpoint_info", start_epoch, "info") + info_log_path = '{0}/{1}/{2}.{3}'.format( + model_dir, "checkpoint_info", start_epoch, "yaml") + if os.path.exists(info_log_path): + # info = torch.load(info_path) + # self.elements["optimizer"].load_state_dict(info['optimizer']) + # for state in self.elements["optimizer"].values(): + # for k, v in state.items(): + # if isinstance(v, torch.Tensor): + # state[k] = v.cuda() + with open(info_log_path, 'r') as fin: + info = yaml.load(fin, Loader=yaml.FullLoader) + self.training_point[2] = info['step'] + elif os.path.exists(exist_model): + if utils.is_main_training(): + logger.info( + "Use {0} as the initial model to start transform-training.".format(exist_model)) + model.load_transform_state_dict( + torch.load(exist_model, map_location="cpu")) + else: + # Just use the raw initial model or initialize it again by some initial functions here + pass # Now, it means use the raw initial model + + # Select device + + model = self.select_device() + + # Original model is built in libs.nnet.framework.TopVirtualNnet, and it is not available after + # wrapped by DistributedDataParallel. So, to call functions of TopVirtualNnet conveniently, the + # self.elements["model_forward"] is set here to name DistributedDataParallel. 
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel): + + self.elements["model"] = model.module + self.elements["model_forward"] = model + + + def save_model(self, mod="epoch",train_lr=None,valid_loss=None): + assert mod in ["epoch", "iter", "cycle"] + if mod == "epoch": + model_name = self.training_point[0] + elif mod == "iter": + model_name = "{}.{}".format( + self.training_point[0], self.training_point[1]) + else: + model_name = "{}_cycle".format(self.cycle_point) + model_path = '{0}/{1}.{2}'.format( + self.params["model_dir"], model_name, self.params["suffix"]) + + # info = { + # 'optimizer': self.elements["optimizer"].state_dict(), + # 'step': self.training_point[2], + # } + info_log = { + 'train_lr': train_lr if train_lr else "see train.csv", + "next_lr": self.elements["optimizer"].state_dict()['param_groups'][0]['lr'], + 'epoch': self.training_point[0], + 'iter in epoch': self.training_point[1], + 'step': self.training_point[2], + 'valid_loss':valid_loss if valid_loss else "see train.csv" + } + # info_path = '{0}/{1}/{2}.{3}'.format( + # self.params["model_dir"], "checkpoint_info", model_name, "info") + # info_log_path = re.sub('.info$', '.yaml', info_path) + info_log_path = '{0}/{1}/{2}.{3}'.format( + self.params["model_dir"], "checkpoint_info", model_name, "yaml") + logger.info("Save model to {0}. \n epoch/iter: {1}/{2}. cur_step: {3}".format(model_path, self.training_point[0], + self.training_point[1], self.training_point[2])) + torch.save(self.elements["model"].state_dict(), model_path) + # torch.save(info, info_path) + with open(info_log_path, 'w') as fout: + data = yaml.dump(info_log) + fout.write(data) + + def run(self): + raise NotImplementedError + + @for_lr_finder_new + def lr_finder_compute(self, train_batch): + # Only train_batch parameter for it's always main metric. + raise NotImplementedError + + def run_lr_finder(self): + # Implement this function w.r.t self.lr_finder_compute(). + raise NotImplementedError + + +class SimpleTrainer(_BaseTrainer): + """One input and one output. + """ + + def __init__(self, *args, **kwargs): + super(SimpleTrainer, self).__init__(*args, **kwargs) + self.num_batch=0 + def train_one_batch(self, batch, step_lr=True): + """A normal training core without fetching data from iterator. + """ + model = self.elements["model"] + model_forward = self.elements["model_forward"] + optimizer = self.elements["optimizer"] + device = utils.get_device(self.elements["model"]) + + # model level warm up. just for transformer. + cur_step = self.training_point[2] + warmup = cur_step / self.warmup_steps if self.warmup_steps >0 else 1.0 + warmup = torch.FloatTensor([warmup]) + if not model.training: + model.train() + + if self.params["nan_debug"]: + device = utils.get_device(self.elements["model"]) + inputs = torch.load( + "{0}/nan.batch".format(self.params["model_dir"])).to(device) + targets = torch.load( + "{0}/nan.targets".format(self.params["model_dir"])).to(device) + self.elements["model"].load_state_dict(torch.load("{0}/nan.params".format(self.params["model_dir"]), + map_location="cpu")) + self.elements["model"].to(device) + else: + inputs, targets, feats_lens = batch + feats_lens = (feats_lens*inputs.shape[2]).long() + input_list = [inputs,feats_lens] + + # model level warm up. just for transformer. + if self.warmup_steps >0: + input_list.append(warmup) + context = None + + # Disable gradient synchronizations across DDP processes in first background. 
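A short note on the new `feats_lens` handling above: the loader is assumed here to yield per-utterance lengths as fractions of the padded frame axis, which the trainer converts to absolute frame counts before calling the model. Roughly:

```python
# Sketch of the feats_lens conversion (the relative-length convention is an
# assumption, inferred from the rescaling by inputs.shape[2] in the trainer).
import torch

inputs = torch.zeros(2, 80, 300)            # (batch, feat_dim, padded_frames)
feats_lens = torch.tensor([1.0, 0.5])       # relative lengths from the batcher
feats_lens = (feats_lens * inputs.shape[2]).long()
print(feats_lens.tolist())                  # [300, 150]
```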
+ if self.use_ddp : + context = model_forward.no_sync + + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + + # first forward-backward pass + with context(): + # Managing automatic mixed precision (Leo 2021-11-08) + with torch.cuda.amp.autocast(self.scaler is not None): + + enable_running_stats(model_forward) + + loss = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"] + # loss = model_forward(inputs, targets)/self.params["accum_grad"] + + if self.params["use_amp"]: + self.scaler.scale(loss).backward() + + else: + loss.backward() + + loss.detach() # For safe. + loss = loss.item()*self.params["accum_grad"] + accuracy = model.get_accuracy(targets) if self.params["compute_accuracy"] else None + # for name, param in model.named_parameters(): + # if param.grad is None: + # print(name) + # sys.exit() + + + + # Use mixed precision training (Leo 2021-11-08) + if self.params["use_amp"]: + self.scaler.unscale_(optimizer) + if not self.modify_grad() and not self.params["skip_nan_batch"]: + torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"])) + torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"])) + torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"])) + raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0], + self.training_point[1], "{0}/nan.*".format(self.params["model_dir"]))) + else: + flag_step1 = self.scaler.step(optimizer,stage=1) + + self.scaler.update() + else: + if not self.modify_grad(): + if not self.params["skip_nan_batch"]: + torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"])) + torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"])) + torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"])) + raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0], + self.training_point[1], "{0}/nan.*".format(self.params["model_dir"]))) + else: + flag_step1=optimizer.step(stage=1) + if flag_step1: + stop_flag = torch.zeros(1).to(device) + else: + stop_flag = torch.ones(1).to(device) + if utils.use_ddp(): + dist.all_reduce(stop_flag,op=dist.ReduceOp.SUM) + + dist.barrier() + stop_flag = bool(stop_flag) + flag_step2 = False + if not stop_flag: + + # Managing automatic mixed precision (Leo 2021-11-08) + with torch.cuda.amp.autocast(self.scaler is not None): + disable_running_stats(model_forward) + loss1 = model.get_loss(model_forward(*input_list), targets)/self.params["accum_grad"] + # loss = model_forward(inputs, targets)/self.params["accum_grad"] + + if self.params["use_amp"]: + self.scaler.scale(loss1).backward() + else: + loss1.backward() + loss1.detach() + + # Use mixed precision training (Leo 2021-11-08) + if self.params["use_amp"]: + self.scaler.unscale_(optimizer) + if not self.modify_grad() and not self.params["skip_nan_batch"]: + torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"])) + torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"])) + torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"])) + raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0], + self.training_point[1], 
"{0}/nan.*".format(self.params["model_dir"]))) + else: + flag_step2 = self.scaler.step(optimizer,stage=2) + + self.scaler.update() + else: + if not self.modify_grad(): + if not self.params["skip_nan_batch"]: + torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"])) + torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"])) + torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"])) + raise RuntimeError('There is Nan problem in epoch/iter: {0}/{1} (nan batch and params are saved in {2})'.format(self.training_point[0], + self.training_point[1], "{0}/nan.*".format(self.params["model_dir"]))) + else: + flag_step2 = optimizer.step(stage=2) + + + self.training_point[2] += 1 # update step + + if step_lr: + self.step_lr(loss,accuracy,optimizer,self.elements["lr_scheduler"]) + + if flag_step1 and flag_step2 is not True: + optimizer.back_w() + optimizer.zero_grad() + self.num_batch += 1 + + return loss, accuracy + + # clip grad + def modify_grad(self): + grad_norm = torch.nn.utils.clip_grad_norm_( + self.elements["model"].parameters(), self.params["max_change"]) + if not torch.isfinite(grad_norm): + logger.warning("Grad:{0} is not finite in epoch/iter: {1}/{2}".format(grad_norm, self.training_point[0],self.training_point[1])) + if self.params["nan_debug"]: + raise RuntimeError( + "[NOT OK] Nan is still found in this debug.") + return False + else: + if self.params["nan_debug"]: + raise RuntimeError( + "[OK] There is no nan found for this debug.") + return True + + def compute_validation(self, data_loader): + """A normal evaluation core. + """ + model = self.elements["model"] + model_forward = self.elements["model_forward"] + train_status = model.training # Record status. + model.eval() + + loss = 0. + accuracy = 0. if self.params["compute_valid_accuracy"] else None + num_samples = 0 + with torch.no_grad(): + for idx,this_data in enumerate(data_loader): + inputs, targets, feats_lens = this_data + feats_lens = (feats_lens*inputs.shape[2]).long() + num_utts = targets.size(0) + input_list = [inputs,feats_lens] + if num_utts == 0: + continue + # in valid stage, DO NOT call ddp model, for ddp model is in JOIN context wrapper. + # Leo 2022-02-03 + loss += model.get_loss(model(*input_list), + targets).item() * len(targets) + + # loss += model_forward(inputs,targets).item() * len(targets) + num_samples += len(targets) + + if self.params["compute_valid_accuracy"]: + # This will occupy extra GPU memory. + accuracy += model.get_accuracy(targets) * len(targets) + if idx > self.params["compute_batch_num_valid"]-1: + break + avg_loss = loss/num_samples + avg_accuracy = accuracy / num_samples if self.params["compute_valid_accuracy"] else None + if train_status: + model.train() + + return avg_loss, avg_accuracy + + def step_lr(self,train_loss,train_acc,base_optimizer,lr_scheduler): + + # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler + # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau + # and some simple schedulers whose step() parameter is 'epoch' only are supported. 
+ + valid_dataloader=self.elements["data"].valid_loader + + lr_scheduler_params = { + "training_point": self.training_point} + valid_loss = None + valid_computed = False + if lr_scheduler.name == "reduceP" and lr_scheduler.is_reduce_point(self.training_point): + assert valid_dataloader is not None + valid_loss, valid_acc = self.compute_validation(valid_dataloader) + lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc) + valid_computed = True + if utils.is_main_training(): + if valid_computed or (valid_dataloader is not None and self.reporter.is_report(self.training_point)): + if not valid_computed: + valid_loss, valid_acc = self.compute_validation(valid_dataloader) + + valid_computed = False + # real_snapshot is set for tensorboard to avoid workspace problem + real_snapshot = {"train_loss": train_loss, "valid_loss": valid_loss, + "train_acc": train_acc*100, "valid_acc": valid_acc*100} + snapshot = {"train_loss": "{0:.6f}".format(train_loss), "valid_loss": "{0:.6f}".format(valid_loss), + "train_acc": "{0:.2f}".format(train_acc*100), "valid_acc": "{0:.2f}".format(valid_acc*100), + "total_dur":self.origin_total_dur,"num_sample":self.num_sample,"real": real_snapshot} + # For ReduceLROnPlateau. + lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc) + else: + real_snapshot = { + "train_loss": train_loss, "train_acc": train_acc*100} + snapshot = {"train_loss": "{0:.6f}".format(train_loss), "valid_loss": "", + "train_acc": "{0:.2f}".format(train_acc*100), "valid_acc": "", + "total_dur":self.origin_total_dur,"num_sample":self.num_sample,"real": real_snapshot} + training_point = (self.training_point[0],self.training_point[1],self.training_point[2]) + self.train_lr = base_optimizer.state_dict()['param_groups'][0]['lr'] + self.reporter.update(snapshot,training_point,self.train_lr) + if lr_scheduler is not None: + # It is not convenient to wrap lr_scheduler (doing). + if isinstance(lr_scheduler, LRSchedulerWrapper): + lr_scheduler.step(**lr_scheduler_params) + if utils.is_main_training(): + current_lr = base_optimizer.state_dict()['param_groups'][0]['lr'] + if lr_scheduler.name == "reduceP": + if current_lr < self.last_lr: + self.last_lr = current_lr + self.save_model(mod="iter",train_lr=self.train_lr,valid_loss=valid_loss) + elif current_lr <= lr_scheduler.min_lr and lr_scheduler.is_reduce_point(self.training_point): + self.save_model(mod="iter",train_lr=self.train_lr,valid_loss=valid_loss) + + if lr_scheduler.is_cycle_point(self.training_point): + self.cycle_point+=1 + self.save_model(mod="cycle",train_lr=self.train_lr,valid_loss=valid_loss) + else: + # For some pytorch lr_schedulers, but it is not available for all. + lr_scheduler.step(self.training_point[0]) + + + def run(self): + """Main function to start a training process. + """ + try: + self.init_training() + + if utils.is_main_training(): + self.reporter = Reporter(self) + + start_epoch = self.params["start_epoch"] + epochs = self.params["epochs"] + data = self.elements["data"] + model = self.elements["model"] + # See init_training. 
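The run() loop below (like its counterpart in trainer_online.py) coordinates early stopping across DDP ranks by summing a one-element flag, so sub-ranks that were handed extra epochs leave the loop at the same batch boundary as the main rank. In isolation the handshake looks like this (a sketch, assuming the default process group is already initialised):

```python
# Sketch of the rank-synchronised stop flag used in run().
import torch
import torch.distributed as dist

def should_stop(stop_training: torch.Tensor) -> bool:
    # Every rank calls this at the same point in the batch loop.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
        dist.all_reduce(stop_training, op=dist.ReduceOp.SUM)
    return bool(stop_training.item())
```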
+ model_forward = self.elements["model_forward"] + self.train_lr = self.elements["optimizer"].state_dict()['param_groups'][0]['lr'] + self.last_lr = self.elements["optimizer"].state_dict()['param_groups'][0]['lr'] + + if utils.is_main_training(): + logger.info("Training will run for {0} epochs.".format(epochs)) + if utils.is_main_training() and self.params["accum_grad"] > 1: + logger.info("using accumulate grad,accum_num: {}".format( + self.params["accum_grad"])) + + if isinstance(model_forward, torch.nn.parallel.DistributedDataParallel): + + model_context = model_forward.join(throw_on_early_termination=True) + else: + model_context = nullcontext() + device = utils.get_device(self.elements["model"]) + stop_training = torch.zeros(1).to(device) + with model_context: + for this_epoch in range(start_epoch, epochs+5): + # In case the uneven data in different ranks when ddp training, + # here we design a more epoch for sub-ranks to ensure the main rank can broadcasting. + if utils.is_main_training: + if this_epoch == epochs: # skip the last+1 epoch + stop_training = torch.ones(1).to(device) + + self.training_point[0]+=1 + data.train_loader.dataset.set_epoch(this_epoch) + if utils.is_main_training() and self.training_point[1]==0: self.origin_total_dur,self.num_sample=data.train_loader.dataset.get_data_dur() + + + # with model_context: + + for _, batch in enumerate(data.train_loader, 0): + # It is important for reporter. + dist.barrier() + if utils.use_ddp():dist.all_reduce(stop_training,op=dist.ReduceOp.SUM) + if stop_training: + break + self.training_point[1] +=1 + + num_utts = batch[0].size(0) + + if num_utts == 0: + continue + if model.use_step: + step_point = (self.training_point[0],self.training_point[2]) + model.step_iter(*step_point) + + loss, acc = self.train_one_batch(batch) + + model.backward_step(*self.training_point) + if stop_training: + break + if utils.is_main_training():self.save_model(train_lr=self.train_lr) + self.training_point[1] =0 + + + if utils.is_main_training():self.reporter.finish() + if utils.is_main_training(): + final_model_name = "{}_cycle".format(self.cycle_point) if self.cycle_point else epochs + final_model_path = os.path.join(self.params["model_dir"],'final.params') + if os.path.exists(final_model_path) or os.path.islink(final_model_path): + os.remove(final_model_path) + + os.symlink('{0}/{1}.{2}'.format(self.params["model_dir"], final_model_name, self.params["suffix"]), final_model_path) + except BaseException as e: + if utils.use_ddp():utils.cleanup_ddp() + if not isinstance(e, KeyboardInterrupt):traceback.print_exc() + sys.exit(1) + + @for_lr_finder_new + def lr_finder_compute(self, train_batch): + model = self.elements["model"] + if model.use_step: + step_point = (self.training_point[0],self.training_point[2]) + model.step_iter(*step_point) + loss, acc = self.train_one_batch(train_batch,step_lr=False) + model.backward_step(*self.training_point) + valid_loss, valid_acc=0,0 + if utils.is_main_training(): + valid_loss, valid_acc = self.compute_validation( + self.elements["data"].valid_loader) + return ["train_loss", "train_acc", "valid_loss", "valid_acc"], [loss, acc, valid_loss, valid_acc] + + def run_lr_finder(self, save_file: str, comment=None, init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98): + self.init_training() + log_dir = self.params["model_dir"] + "/log/" # For tensorboardX + if comment is not None: + save_file = comment + "-" + save_file + save_file = log_dir + save_file + log_lrs, values_matrix = 
self.lr_finder_compute(self.elements["data"].train_loader, self.elements["optimizer"], + init_lr=init_lr, final_lr=final_lr, num_iters=num_iters, beta=beta, + log_dir=log_dir, comment=comment) + + if utils.is_main_training(): + df = pd.DataFrame(np.vstack([log_lrs, values_matrix]).T, + columns=["log_lr", "train_loss", "train_acc", "valid_loss", "valid_acc"]) + logger.info("Save lr finder values to {}.".format(save_file)) + df.to_csv(save_file) + if utils.use_ddp():utils.cleanup_ddp() + sys.exit(1) + +class MultitaskTrainer(_BaseTrainer): + """One input and multi-output corresponding to different tasks, such as one + is speaker classfication and another is phones classfication. + """ + pass + + +class GANTrainer(_BaseTrainer): + """This is for GAN. + """ + pass + + +# Function ✿ +def add_gaussian_noise_to_grad(model, t, n=0.1, gamma=0.55): + """ADDING GRADIENT NOISE IMPROVES LEARNING FOR VERY DEEP NETWORKS. + """ + var = n/(1+t)**gamma + for param in model.params(): + param.grad += to_device(model, torch.normal(0, var, param.size())) + + + +def disable_running_stats(model): + def _disable(module): + if isinstance(module, _BatchNorm): + module.backup_momentum = module.momentum + module.momentum = 0 + + model.apply(_disable) + +def enable_running_stats(model): + def _enable(module): + if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): + module.momentum = module.backup_momentum + + model.apply(_enable) \ No newline at end of file diff --git a/pytorch/model/ecapa_tdnn_xvector.py b/pytorch/model/ecapa_tdnn_xvector.py old mode 100644 new mode 100755 index 5f79f1e..bf7022d --- a/pytorch/model/ecapa_tdnn_xvector.py +++ b/pytorch/model/ecapa_tdnn_xvector.py @@ -1,3 +1,9 @@ + +# Copyright xmuspeech (Author: Leo 2022-05-27) +# refs: +# 1. ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification +# https://arxiv.org/abs/2005.07143 + import torch import torch.nn as nn import torch.nn.functional as F @@ -6,10 +12,6 @@ sys.path.insert(0, 'subtools/pytorch') import libs.support.utils as utils from libs.nnet import * -# Copyright xmuspeech (Author: Leo 2022-05-27) -# refs: -# 1. 
ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification -# https://arxiv.org/abs/2005.07143 class Res2NetBlock(torch.nn.Module): @@ -34,7 +36,7 @@ class Res2NetBlock(torch.nn.Module): >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3) >>> out_tensor = layer(inp_tensor).transpose(1, 2) >>> out_tensor.shape - torch.Size([8, 120, 64]) + torch.Size([8, 120, 64]) """ def __init__( @@ -231,6 +233,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin "curricular": False} default_step_params = { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0}, "T": None, "m": False, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4, "s": False, "s_tuple": (30, 12), "s_list": None, @@ -332,6 +336,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin if margin_loss: self.loss = MarginSoftmaxLoss( embd_dim, num_targets, **margin_loss_params) + if self.use_step and self.step_params["margin_warm"]: + self.margin_warm = MarginWarm(**step_params["margin_warm_conf"]) else: self.loss = SoftmaxLoss(embd_dim, num_targets) # self.loss = AngleLoss(embd_dim,num_targets) @@ -339,7 +345,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin self.loss, self.mixup) if mixup else None # An example to using transform-learning without initializing loss.affine parameters self.transform_keys = ["layer2", "layer3", - "layer4", "conv", "stats", "fc1", "fc2"] + "layer4", "stats", "mfa", "bn_stats", "fc1", "fc2", "loss"] if margin_loss and transfer_from == "softmax_loss": # For softmax_loss to am_softmax_loss @@ -348,7 +354,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin @torch.jit.unused @utils.for_device_free - def forward(self, x): + def forward(self, x, x_len: torch.Tensor=torch.empty(0)): x = self.layer1(x) x1 = self.layer2(x) x2 = self.layer3(x+x1) @@ -442,27 +448,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc return xvector @torch.jit.export - def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 10000, isMatrix: bool = True): - if isMatrix: - input = torch.unsqueeze(input, dim=0) - input = input.transpose(1, 2) - num_frames = input.shape[2] - num_split = (num_frames + maxChunk - 1) // maxChunk - split_size = num_frames // num_split - offset = 0 - embedding_stats = torch.zeros(1, self.embd_dim, 1) - for _ in range(0, num_split-1): - this_embedding = self.extract_embedding_jit( - input[:, :, offset:offset+split_size], position) - offset += split_size - embedding_stats += split_size*this_embedding - - last_embedding = self.extract_embedding_jit( - input[:, :, offset:], position) - - embedding = (embedding_stats + (num_frames-offset) - * last_embedding) / num_frames - return torch.squeeze(embedding.transpose(1, 2)).cpu() + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += 
split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() @torch.jit.export def embedding_dim(self) -> int: @@ -488,7 +495,8 @@ def step(self, epoch, this_iter, epoch_batchs): current_postion = epoch*epoch_batchs + this_iter lambda_factor = max(self.step_params["lambda_0"], self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) @@ -509,10 +517,15 @@ def step(self, epoch, this_iter, epoch_batchs): def step_iter(self, epoch, cur_step): # For iterabledataset if self.use_step: + if self.step_params["margin_warm"]: + offset_margin, lambda_m = self.margin_warm.step(cur_step) + lambda_m = max(1e-3,lambda_m) + self.loss.step(lambda_m,offset_margin) if self.step_params["m"]: lambda_factor = max(self.step_params["lambda_0"], self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) diff --git a/pytorch/model/extended_xvector.py b/pytorch/model/extended_xvector.py old mode 100644 new mode 100755 index 85d5d31..c539144 --- a/pytorch/model/extended_xvector.py +++ b/pytorch/model/extended_xvector.py @@ -49,8 +49,9 @@ def init(self, inputs_dim, num_targets, extend=True, nonlinearity="relu", self.transform_keys = ["tdnn1","tdnn2","tdnn3","tdnn4","tdnn5","stats","tdnn6","tdnn7", "ex_tdnn1","ex_tdnn2","ex_tdnn3","ex_tdnn4","ex_tdnn5"] + @torch.jit.unused @utils.for_device_free - def forward(self, inputs): + def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ diff --git a/pytorch/model/factored_xvector.py b/pytorch/model/factored_xvector.py old mode 100644 new mode 100755 index 9699bc8..0ed7053 --- a/pytorch/model/factored_xvector.py +++ b/pytorch/model/factored_xvector.py @@ -61,7 +61,7 @@ def init(self, inputs_dim, num_targets, nonlinearity="relu", semi_orth=True,embd @torch.jit.unused @utils.for_device_free - def forward(self, inputs): + def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ @@ -158,24 +158,28 @@ def extract_embedding_jit(self, inputs: torch.Tensor, position: str = 'far') -> return xvector @torch.jit.export - def extract_embedding_whole(self,input:torch.Tensor,position:str='far',maxChunk:int=10000,isMatrix:bool=True): - if isMatrix: - input=torch.unsqueeze(input,dim=0) - input=input.transpose(1,2) - num_frames = input.shape[2] - num_split = (num_frames + maxChunk - 1) // maxChunk - split_size = num_frames // num_split - offset=0 - embedding_stats = torch.zeros(1,self.embd_dim,1) - for _ in range(0, num_split-1): - this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position) - offset += split_size - embedding_stats += 
split_size*this_embedding - - last_embedding = self.extract_embedding_jit(input[:, :, offset:],position) - - embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames - return torch.squeeze(embedding.transpose(1,2)).cpu() + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() @torch.jit.export def embedding_dim(self) -> int: diff --git a/pytorch/model/multi_task_xvector_fix.py b/pytorch/model/multi_task_xvector_fix.py old mode 100644 new mode 100755 index efc681d..c5f653d --- a/pytorch/model/multi_task_xvector_fix.py +++ b/pytorch/model/multi_task_xvector_fix.py @@ -323,7 +323,8 @@ def step(self, epoch, this_iter, epoch_batchs): current_postion = epoch*epoch_batchs + this_iter lambda_factor = max(self.step_params["lambda_0"], self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) - self.loss_spk.step(lambda_factor) + lambda_m = 1/(1 + lambda_factor) + self.loss_spk.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) @@ -346,7 +347,8 @@ def step_iter(self, epoch, cur_step): if self.step_params["m"]: lambda_factor = max(self.step_params["lambda_0"], self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_m = 1/(1 + lambda_factor) + self.loss_spk.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) diff --git a/pytorch/model/repvgg_xvector.py b/pytorch/model/repvgg_xvector.py old mode 100644 new mode 100755 index fe0e200..4106bce --- a/pytorch/model/repvgg_xvector.py +++ b/pytorch/model/repvgg_xvector.py @@ -12,7 +12,7 @@ from libs.nnet import * class RepVggXvector(TopVirtualNnet): - """ A repvgg vector framework """ + """ A repvgg vector framework """ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropout=0.,training=True, extracted_embedding="near", deploy=False,repvgg_config={}, pooling="statistics", pooling_params={}, fc1=False, fc1_params={}, fc2_params={}, margin_loss=False, margin_loss_params={}, @@ -21,13 +21,13 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou ## Params. 
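All of the `extract_embedding_whole` rewrites in this patch follow the same pattern: split a long utterance into chunks of at most `maxChunk` frames, embed each chunk, and average the chunk embeddings weighted by chunk length. Restated outside the models, with `embed_fn` standing in for `extract_embedding_jit`:

```python
# Standalone sketch of the chunked whole-utterance embedding extraction.
import torch

def chunked_embedding(embed_fn, feats: torch.Tensor, max_chunk: int = 4000) -> torch.Tensor:
    # feats: (1, feat_dim, num_frames); embed_fn maps a chunk to an embedding tensor.
    num_frames = feats.shape[2]
    num_split = (num_frames + max_chunk - 1) // max_chunk
    split_size = num_frames // num_split
    offset, stats = 0, None
    for _ in range(num_split - 1):
        emb = embed_fn(feats[:, :, offset:offset + split_size]) * split_size
        stats = emb if stats is None else stats + emb
        offset += split_size
    last = embed_fn(feats[:, :, offset:]) * (num_frames - offset)
    stats = last if stats is None else stats + last
    return stats / num_frames

# e.g. embed_fn = lambda chunk: chunk.mean(dim=2)   # toy stand-in for a model
```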
default_repvgg_config = { "auto_model" : False, - "auto_model_name" : "RepVGG_B0", - "block": "RepVGG", + "auto_model_name" : "RepVGG_A1", + "block": "RepSPK", "repvgg_params":{ - "num_blocks": [4, 6, 16, 1], + "num_blocks": [2, 4, 14, 1], "strides":[1,1,2,2,2], "base_width": 32, - "width_multiplier":[1, 1, 1, 2.5], + "width_multiplier": [1, 1, 1, 2.5], "norm_layer_params":{"momentum":0.5, "affine":True}, "override_groups_map": None, "use_se": False, @@ -57,6 +57,8 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou } default_step_params = { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0}, "T":None, "m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4, "s":False, "s_tuple":(30, 12), "s_list":None, @@ -125,13 +127,15 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou if training : if margin_loss: self.loss = MarginSoftmaxLoss(embd_dim, num_targets, **margin_loss_params) + if self.use_step and self.step_params["margin_warm"]: + self.margin_warm = MarginWarm(**step_params["margin_warm_conf"]) elif adacos: self.loss = AdaCos(embd_dim,num_targets) else: self.loss = SoftmaxLoss(embd_dim, num_targets) # An example to using transform-learning without initializing loss.affine parameters - self.transform_keys = ["repvgg", "stats", "fc1", "fc2"] + self.transform_keys = ["repvgg", "stats", "fc1", "fc2","loss"] # self.transform_keys = ["resnet"] if margin_loss and transfer_from == "softmax_loss": @@ -140,7 +144,7 @@ def init(self, inputs_dim, num_targets, embd_dim=256,aug_dropout=0., tail_dropou @torch.jit.unused @utils.for_device_free - def forward(self, x): + def forward(self, x, x_len: torch.Tensor=torch.empty(0)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ @@ -233,24 +237,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc return xvector @torch.jit.export - def extract_embedding_whole(self,input:torch.Tensor,position:str='near',maxChunk:int=10000,isMatrix:bool=True): - if isMatrix: - input=torch.unsqueeze(input,dim=0) - input=input.transpose(1,2) - num_frames = input.shape[2] - num_split = (num_frames + maxChunk - 1) // maxChunk - split_size = num_frames // num_split - offset=0 - embedding_stats = torch.zeros(1,self.embd_dim,1) - for _ in range(0, num_split-1): - this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position) - offset += split_size - embedding_stats+=split_size*this_embedding - - last_embedding = self.extract_embedding_jit(input[:, :, offset:],position) - - embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames - return torch.squeeze(embedding.transpose(1,2)).cpu() + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + 
embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() @torch.jit.export def embedding_dim(self) -> int: @@ -277,9 +285,10 @@ def step(self, epoch, this_iter, epoch_batchs): if self.use_step: if self.step_params["m"]: current_postion = epoch*epoch_batchs + this_iter - lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_factor = max(self.step_params["lambda_0"], + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) @@ -287,10 +296,12 @@ def step(self, epoch, this_iter, epoch_batchs): T_i = T_i * epoch_batchs if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] @@ -298,20 +309,26 @@ def step(self, epoch, this_iter, epoch_batchs): def step_iter(self, epoch, cur_step): # For iterabledataset if self.use_step: + if self.step_params["margin_warm"]: + offset_margin, lambda_m = self.margin_warm.step(cur_step) + lambda_m = max(1e-3,lambda_m) + self.loss.step(lambda_m,offset_margin) if self.step_params["m"]: lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) - if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] diff --git a/pytorch/model/resnet_xvector.py b/pytorch/model/resnet_xvector.py old mode 100644 new mode 100755 index 952eed5..d8f41ff --- a/pytorch/model/resnet_xvector.py +++ b/pytorch/model/resnet_xvector.py @@ -29,8 +29,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin "head_conv":True, "head_conv_params":{"kernel_size":3, "stride":1, "padding":1}, "head_maxpool":False, "head_maxpool_params":{"kernel_size":3, "stride":1, "padding":1}, "block":"BasicBlock", - "layers":[3, 4, 6, 3], - "planes":[32, 64, 128, 256], # a.k.a channels. + "layers": [3, 4, 6, 3], + "planes": [32, 64, 128, 256], # a.k.a channels. 
"use_se": False, "se_ratio": 4, "convXd":2, @@ -63,6 +63,8 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin } default_step_params = { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0}, "T":None, "m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4, "s":False, "s_tuple":(30, 12), "s_list":None, @@ -85,7 +87,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin ## Nnet. self.aug_dropout = torch.nn.Dropout2d(p=aug_dropout) if aug_dropout > 0 else None - self.cmvn_=InputSequenceNormalization(**cmvn_params) if cmvn else None + self.cmvn_=InputSequenceNormalization(**cmvn_params) if cmvn else torch.nn.Identity() # [batch, 1, feats-dim, frames] for 2d and [batch, feats-dim, frames] for 1d. # Should keep the channel/plane is always in 1-dim of tensor (index-0 based). @@ -126,10 +128,14 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin if training : if margin_loss: self.loss = MarginSoftmaxLoss(resnet_params["planes"][3], num_targets, **margin_loss_params) + if self.use_step and self.step_params["margin_warm"]: + self.margin_warm = MarginWarm(**step_params["margin_warm_conf"]) else: self.loss = SoftmaxLoss(resnet_params["planes"][3], num_targets) # An example to using transform-learning without initializing loss.affine parameters + # self.transform_keys = ["resnet", "stats", "fc1", "fc2"] + self.transform_keys = ["resnet", "stats", "fc1", "fc2","loss.weight"] if margin_loss and transfer_from == "softmax_loss": @@ -138,12 +144,12 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin @torch.jit.unused @utils.for_device_free - def forward(self, x): + def forward(self, x, x_len: torch.Tensor=torch.empty(0)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ - x = self.auto(self.cmvn_,x) + x = self.self.cmvn_(x) x = self.auto(self.aug_dropout, x) # This auto function is equal to "x = layer(x) if layer is not None else x" for convenience. # [samples-index, frames-dim-index, frames-index] -> [samples-index, 1, frames-dim-index, frames-index] @@ -181,7 +187,7 @@ def extract_embedding(self, x): return: an 1-dimensional vector after processed by decorator """ # Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv. - x = self.auto(self.cmvn_,x) + x = self.cmvn_(x) x = x.unsqueeze(1) if self.convXd == 2 else x x = self.resnet(x) x = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) if self.convXd == 2 else x @@ -204,10 +210,10 @@ def extract_embedding(self, x): def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor: """ x: a 3-dimensional tensor with batch-dim = 1 or normal features matrix - return: an 1-dimensional vector after processed by decorator + return: an 1-dimensional vector after processed by decorator """ - x = self.auto(self.cmvn_,x) + x = self.cmvn_(x) # Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv. 
x = x.unsqueeze(1) if self.convXd == 2 else x x = self.resnet(x) @@ -230,24 +236,28 @@ def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torc return xvector @torch.jit.export - def extract_embedding_whole(self,input:torch.Tensor,position:str='near',maxChunk:int=10000,isMatrix:bool=True): - if isMatrix: - input=torch.unsqueeze(input,dim=0) - input=input.transpose(1,2) - num_frames = input.shape[2] - num_split = (num_frames + maxChunk - 1) // maxChunk - split_size = num_frames // num_split - offset=0 - embedding_stats = torch.zeros(1,self.embd_dim,1) - for _ in range(0, num_split-1): - this_embedding = self.extract_embedding_jit(input[:, :, offset:offset+split_size],position) - offset += split_size - embedding_stats += split_size*this_embedding - - last_embedding = self.extract_embedding_jit(input[:, :, offset:],position) - - embedding = (embedding_stats + (num_frames-offset) * last_embedding) / num_frames - return torch.squeeze(embedding.transpose(1,2)).cpu() + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() @torch.jit.export def embedding_dim(self) -> int: @@ -274,9 +284,10 @@ def step(self, epoch, this_iter, epoch_batchs): if self.use_step: if self.step_params["m"]: current_postion = epoch*epoch_batchs + this_iter - lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_factor = max(self.step_params["lambda_0"], + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) @@ -284,10 +295,12 @@ def step(self, epoch, this_iter, epoch_batchs): T_i = T_i * epoch_batchs if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] @@ -295,20 +308,26 @@ def step(self, epoch, this_iter, epoch_batchs): def step_iter(self, epoch, cur_step): # For iterabledataset if self.use_step: + if self.step_params["margin_warm"]: + offset_margin, lambda_m = self.margin_warm.step(cur_step) + lambda_m = max(1e-3,lambda_m) + 
self.loss.step(lambda_m,offset_margin) if self.step_params["m"]: lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) - if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] diff --git a/pytorch/model/snowdar_xvector.py b/pytorch/model/snowdar_xvector.py old mode 100644 new mode 100755 index 2a21f42..1d6b8d2 --- a/pytorch/model/snowdar_xvector.py +++ b/pytorch/model/snowdar_xvector.py @@ -47,7 +47,7 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False, "share":True, "affine_layers":1, "hidden_size":64, - "context":[0], + "context": [0], "stddev":True, "temperature":False, "fixed":True, @@ -65,6 +65,8 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False, } default_step_params = { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0}, "T":None, "m":False, "lambda_0":0, "lambda_b":1000, "alpha":5, "gamma":1e-4, "s":False, "s_tuple":(30, 12), "s_list":None, @@ -153,6 +155,8 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False, if training : if margin_loss: self.loss = MarginSoftmaxLoss(512, num_targets, **margin_loss_params) + if self.use_step and self.step_params["margin_warm"]: + self.margin_warm = MarginWarm(**step_params["margin_warm_conf"]) else: self.loss = SoftmaxLoss(512, num_targets) @@ -161,14 +165,14 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False, # An example to using transform-learning without initializing loss.affine parameters self.transform_keys = ["tdnn1","tdnn2","tdnn3","tdnn4","tdnn5","stats","tdnn6","tdnn7", "ex_tdnn1","ex_tdnn2","ex_tdnn3","ex_tdnn4","ex_tdnn5", - "se1","se2","se3","se4","loss"] + "se1","se2","se3","se4", "loss"] if margin_loss and transfer_from == "softmax_loss": # For softmax_loss to am_softmax_loss self.rename_transform_keys = {"loss.affine.weight":"loss.weight"} - + @torch.jit.unused @utils.for_device_free - def forward(self, inputs): + def forward(self, inputs, x_len: torch.Tensor=torch.empty(0)): """ @inputs: a 3-dimensional tensor (a batch), including [samples-index, frames-dim-index, frames-index] """ @@ -272,6 +276,67 @@ def extract_embedding(self, inputs): return xvector + def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor: + """ + inputs: a 3-dimensional tensor with batch-dim = 1 or normal features matrix + return: an 1-dimensional vector after processed by decorator + """ + + # Tensor shape is not modified in libs.nnet.resnet.py for calling free, such as using this framework in cv. 
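The step()/step_iter() changes above all replace the raw `lambda_factor` handed to `loss.step()` with `lambda_m = 1/(1 + lambda_factor)`, so the value passed to the margin loss now ramps from roughly 0 towards 1 as training progresses. With the default hyper-parameters the schedule looks like this (a sketch; printed values are rounded):

```python
# Sketch of the annealed interpolation factor used by step_iter().
lambda_0, lambda_b, gamma, alpha = 0, 1000, 1e-4, 5

def lambda_m(cur_step: int) -> float:
    lambda_factor = max(lambda_0, lambda_b * (1 + gamma * cur_step) ** (-alpha))
    return 1 / (1 + lambda_factor)

for step in (0, 1_000, 10_000, 100_000):
    print(step, round(lambda_m(step), 4))
# 0 -> 0.001, 1000 -> 0.0016, 10000 -> 0.031, 100000 -> 0.9938
```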
+ x = x.unsqueeze(1) + x = self.repvgg(x) + x = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) + x = self.stats(x) + + if position == "far" and self.fc1 is not None: + xvector = self.fc1.affine(x) + elif position == "near_affine": + if self.fc1 is not None: + x=self.fc1(x) + xvector = self.fc2.affine(x) + elif position == "near": + if self.fc1 is not None: + x=self.fc1(x) + xvector = self.fc2(x) + # xvector = F.normalize(xvector) + + else: + raise TypeError("Expected far or near position, but got {}".format(position)) + + return xvector + + @torch.jit.export + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() + + @torch.jit.export + def embedding_dim(self) -> int: + """ Export interface for c++ call, return embedding dim of the model + """ + + return self.embd_dim + + def get_warmR_T(self,T_0, T_mult, epoch): n = int(math.log(max(0.05, (epoch / T_0 * (T_mult - 1) + 1)), T_mult)) T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1) @@ -288,9 +353,10 @@ def step(self, epoch, this_iter, epoch_batchs): if self.use_step: if self.step_params["m"]: current_postion = epoch*epoch_batchs + this_iter - lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) - self.loss.step(lambda_factor) + lambda_factor = max(self.step_params["lambda_0"], + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) @@ -298,32 +364,39 @@ def step(self, epoch, this_iter, epoch_batchs): T_i = T_i * epoch_batchs if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] - def step_iter(self, epoch, cur_step): # For iterabledataset if self.use_step: + if self.step_params["margin_warm"]: + offset_margin, lambda_m = self.margin_warm.step(cur_step) + lambda_m = max(1e-3,lambda_m) + self.loss.step(lambda_m,offset_margin) if self.step_params["m"]: lambda_factor = max(self.step_params["lambda_0"], - self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) - 
self.loss.step(lambda_factor) + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) - if self.step_params["t"]: - self.loss.t = self.compute_decay_value(*self.step_params["t_tuple"], T_cur, T_i) + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) if self.step_params["p"]: - self.aug_dropout.p = self.compute_decay_value(*self.step_params["p_tuple"], T_cur, T_i) + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) if self.step_params["s"]: self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] diff --git a/pytorch/model/transformer_xvector.py b/pytorch/model/transformer_xvector.py new file mode 100755 index 0000000..4bf2371 --- /dev/null +++ b/pytorch/model/transformer_xvector.py @@ -0,0 +1,485 @@ +# Copyright xmuspeech (Author: Leo 2022-07-18) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import sys +sys.path.insert(0, 'subtools/pytorch') +import libs.support.utils as utils +from libs.nnet import * + +def compute_statistics(x, m, dim: int=2, stddev: bool=True,eps: float=1e-5): + mean = (m * x).sum(dim) + + + if stddev: + # std = torch.sqrt( + # (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps) + # ) + std = torch.sqrt( + (torch.sum(m * (x ** 2), dim=dim) - mean ** 2).clamp(eps) + ) + else: + std = torch.empty(0) + return mean, std + +class AttentiveStatsPool(nn.Module): + def __init__(self, in_dim, hidden_size=128, time_attention=False, stddev=True): + super().__init__() + self.stddev = stddev + self.output_dim = in_dim*2 if self.stddev else in_dim + self.time_attention = time_attention + accept_dim = in_dim + # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 
+ if time_attention: + + accept_dim = in_dim*3 if self.stddev else in_dim*2 + + norm = LayerNorm(hidden_size,dim=1,eps=1e-5) + self.attention = nn.Sequential( + nn.Conv1d(accept_dim, hidden_size, kernel_size=1), + nn.ReLU(), + norm, + nn.Tanh(), + nn.Conv1d(hidden_size, in_dim, kernel_size=1) + ) + + # gn_num = 2 if self.stddev else 1 + # self.norm_stats = nn.GroupNorm(gn_num,self.output_dim ) + self.norm_stats = LayerNorm(self.output_dim,dim=1) + + + def forward(self, x, mask: torch.Tensor = torch.ones((0, 0, 0))): + B, C ,T = x.shape + + if mask.size(2) == 0 : + mask = torch.ones((B, 1, T)).to(x.device) + + if self.time_attention: + total = mask.sum(dim=2, keepdim=True).float() + mean, std = compute_statistics(x, mask / total,stddev = self.stddev) + mean = mean.unsqueeze(2).repeat(1, 1, T) + if self.stddev: + std = std.unsqueeze(2).repeat(1, 1, T) + x_in = [x,mean,std] + else: + x_in = [x,mean] + + x_in = torch.cat(x_in, dim=1) + + else: + x_in = x + + alpha = self.attention(x_in) + + alpha = alpha.masked_fill(mask == 0, float("-inf")) + + alpha = F.softmax(alpha, dim=2) + + mean, std = compute_statistics(x, alpha,stddev = self.stddev) + if self.stddev: + + out = torch.cat([mean, std], dim=1).unsqueeze(2) + else: + out = mean.unsqueeze(2) + + return self.norm_stats(out) + def get_output_dim(self): + return self.output_dim + + +class TransformerXvector(TopVirtualNnet): + def init(self, inputs_dim, num_targets, embd_dim=256,training=True, + extracted_embedding="near", mixup=False, mixup_alpha=1.0, pooling="ecpa-attentive", pooling_params={}, + transformer_type="conformer", transformer_params={},tansformer_out={}, fc1=False, fc1_params={}, fc2_params={}, + margin_loss=True, margin_loss_params={}, lsm_weight=0.0,use_step=False, step_params={}, transfer_from="softmax_loss",wenet_transfer=False): + + default_transformer_params = { + "attention_dim": 256, + "att_type": 'multi', # [multi, gau] gau attention don't suppport rel_pos. + "attention_heads": 4, + "gau_key": 64, # gau key dim. + "gau_units": 512, + "num_blocks": 6, + "dropout_rate": 0.1, + "layer_dropout":0., + "positionwise_layer_type": 'linear', # [linear, conv1d, conv1d-linear, gau, re_conv1d] + "positional_dropout_rate": 0.1, + "linear_units": 2048, + "positionwise_conv_kernel_size": 3, + "attention_dropout_rate": 0.0, + "attention_norm_args": { + "norm_method": "softmax", # [softmax, relu_plus, softmax_plus] + "train_len":300., # for softmax_plus. 
+ }, + "input_layer": "conv2d", # [linear, conv2d2, conv2d, re_conv2d, conv2d6, conv2d8] + "pos_enc_type": "abs_pos", # [abs_pos, no_pos, rot_pos, rel_pos] + "cnn_module_kernel": 15, # for conformer + "use_cnn_module": True, # for conformer + "cnn_module_norm": 'layer_norm', # for conformer ['batch_norm', 'layer_norm'] + "static_chunk_size": 0, + "left_chunk_size": -1, + "use_dynamic_chunk": False, + "use_dynamic_left_chunk": False, + "combiner_type": "norm", # [norm, mfa, random_frame, random_layer] + "convfnn_blocks": 0 + } + default_tansformer_out = { + "out_dim": 1536, + "nonlinearity": 'swish', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "ln_replace": True, # replace BN with LN + "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True} + } + + default_pooling_params = { + "hidden_size": 128, + "time_attention": False, + "stddev": True, + } + + default_fc_params = { + "nonlinearity": 'relu', "nonlinearity_params": {"inplace": True}, + "bn-relu": False, + "bn": True, + "ln_replace": True, # replace BN with LN + "bn_params": {"momentum": 0.5, "affine": True, "track_running_stats": True} + } + + default_margin_loss_params = { + "method": "am", "m": 0.2, + "feature_normalize": True, "s": 30, + "double": False, + "mhe_loss": False, "mhe_w": 0.01, + "inter_loss": 0., + "ring_loss": 0., + "curricular": False + } + + default_step_params = { + "margin_warm":False, + "margin_warm_conf":{"start_epoch":5.,"end_epoch":10.,"offset_margin":-0.2,"init_lambda":0.0}, + "T": None, + "m": True, "lambda_0": 0, "lambda_b": 1000, "alpha": 5, "gamma": 1e-4, + "s": False, "s_tuple": (30, 12), "s_list": None, + "t": False, "t_tuple": (0.5, 1.2), + "p": False, "p_tuple": (0.5, 0.1) + } + + self.use_step = use_step + self.step_params = step_params + self.extracted_embedding = extracted_embedding + + transformer_params = utils.assign_params_dict( + default_transformer_params, transformer_params,support_unknow=True) + tansformer_out = utils.assign_params_dict(default_tansformer_out, tansformer_out) + pooling_params = utils.assign_params_dict( + default_pooling_params, pooling_params) + fc1_params = utils.assign_params_dict(default_fc_params, fc1_params) + fc2_params = utils.assign_params_dict(default_fc_params, fc2_params) + margin_loss_params = utils.assign_params_dict( + default_margin_loss_params, margin_loss_params) + step_params = utils.assign_params_dict( + default_step_params, step_params) + self.embd_dim = embd_dim + self.mixup = Mixup(alpha=mixup_alpha) if mixup else None + + if transformer_type == "transformer": + transformer_backbone = TransformerEncoder + elif transformer_type == "conformer": + transformer_backbone = ConformerEncoder + elif transformer_type == "re_conformer": + transformer_backbone = ReConformerEncoder + else: + raise ValueError("unknown transformer_type: " + transformer_type) + self.transformer = transformer_backbone(inputs_dim,**transformer_params) + + self.transform_out = ReluBatchNormTdnnLayer(self.transformer.output_size(),tansformer_out["out_dim"],**tansformer_out) + # Pooling + stddev = pooling_params.pop("stddev") + + if pooling == "ecpa-attentive": + self.stats = AttentiveStatsPool( + tansformer_out["out_dim"], stddev=stddev,**pooling_params) + + self.fc1 = ReluBatchNormTdnnLayer( + self.stats.get_output_dim(), embd_dim, **fc1_params) if fc1 else None + else: + raise ValueError("Only supoort asp for conformer now.") + + + + if fc1: + fc2_in_dim = embd_dim + else: + fc2_in_dim = self.stats.get_output_dim() + self.fc2 = 
ReluBatchNormTdnnLayer(fc2_in_dim, embd_dim, **fc2_params) + + + + # print("num_targets---------------",num_targets) + # Loss + # Do not need when extracting embedding. + if training: + if margin_loss: + self.loss = MarginSoftmaxLoss( + embd_dim, num_targets, label_smoothing=lsm_weight,**margin_loss_params) + if self.use_step and self.step_params["margin_warm"]: + self.margin_warm = MarginWarm(**step_params["margin_warm_conf"]) + + else: + self.loss = SoftmaxLoss(embd_dim, num_targets,label_smoothing=lsm_weight) + # self.loss = AngleLoss(embd_dim,num_targets) + self.wrapper_loss = MixupLoss( + self.loss, self.mixup) if mixup else None + # An example to using transform-learning without initializing loss.affine parameters + self.transform_keys = ["transformer", "transform_out", "stats", "fc1", "fc2", "loss"] + + if margin_loss and transfer_from == "softmax_loss": + # For softmax_loss to am_softmax_loss + self.rename_transform_keys = { + "loss.affine.weight": "loss.weight"} + self.wenet_transfer = wenet_transfer + def load_transform_state_dict(self, state_dict): + """It is used in transform-learning. + """ + assert isinstance(self.transform_keys, list) + assert isinstance(self.rename_transform_keys, dict) + remaining = {} + for k,v in state_dict.items(): + + # if "train_len" in k: + # print(k,v) + if self.wenet_transfer: + + k = k.replace("encoder.","transformer.") + + # k = k.replace("embed.","noembed.") + if k.split('.')[0] in self.transform_keys or k in self.transform_keys: + k = utils.key_to_value(self.rename_transform_keys, k, False) + remaining[k] = v + # for k in remaining.keys(): + # print(k) + + # assert 1==0 + + self.load_state_dict(remaining, strict=False) + return self + + + @torch.jit.unused + @utils.for_device_free + def forward(self, x, x_len,warmup: torch.Tensor=torch.FloatTensor([1.0])): + # [samples-index, frames-dim-index, frames-index] -> [samples-index, frames-index, frames-dim-index] + x = x.transpose(1,2) + + x, masks = self.transformer(x,x_len,warmup=float(warmup)) + + x = x.transpose(1,2) + + x = self.transform_out(x) + + x = self.stats(x,masks) + if len(x.shape) != 3: + x = x.unsqueeze(dim=2) + + with torch.cuda.amp.autocast(enabled=False): + x = self.auto(self.fc1, x) + x = self.fc2(x) + + return x + + @utils.for_device_free + def get_loss(self, inputs, targets): + """Should call get_loss() after forward() with using Xvector model function. + e.g.: + m=Xvector(20,10) + loss=m.get_loss(m(inputs),targets) + + model.get_loss [custom] -> loss.forward [custom] + | + v + model.get_accuracy [custom] -> loss.get_accuracy [custom] -> loss.compute_accuracy [static] -> loss.predict [static] + """ + if self.wrapper_loss is not None: + return self.wrapper_loss(inputs, targets) + else: + return self.loss(inputs, targets) + + @utils.for_device_free + def get_accuracy(self, targets): + """Should call get_accuracy() after get_loss(). 
+ @return: return accuracy + """ + if self.wrapper_loss is not None: + return self.wrapper_loss.get_accuracy(targets) + else: + return self.loss.get_accuracy(targets) + + @for_extract_embedding(maxChunk=300, isMatrix=True) + def extract_embedding(self, x): + x_lens = torch.LongTensor([x.shape[2]]).to(x.device) + + x = x.transpose(1,2) + x, _ = self.transformer(x,x_lens) + + x = x.transpose(1,2) + x = self.transform_out(x) + x = self.stats(x) + + if len(x.shape) != 3: + x = x.unsqueeze(dim=2) + if self.extracted_embedding == "far": + assert self.fc1 is not None + xvector = self.fc1.affine(x) + elif self.extracted_embedding == "near_affine": + x = self.auto(self.fc1, x) + xvector = self.fc2.affine(x) + elif self.extracted_embedding == "near": + x = self.auto(self.fc1, x) + xvector = self.fc2(x) + else: + raise TypeError("Expected far or near position, but got {}".format( + self.extracted_embedding)) + return xvector + + + + + def extract_embedding_jit(self, x: torch.Tensor, position: str = 'near') -> torch.Tensor: + x_lens = torch.tensor([x.shape[2]]).to(x.device) + + x = x.transpose(1,2) + + x, _ = self.transformer(x,x_lens) + + x = x.transpose(1,2) + + x = self.transform_out(x) + x = self.stats(x) + if len(x.shape) != 3: + x = x.unsqueeze(dim=2) + if position == "far" and self.fc1 is not None: + xvector = self.fc1.affine(x) + elif position == "near_affine": + if self.fc1 is not None: + x = self.fc1(x) + xvector = self.fc2.affine(x) + elif position == "near": + if self.fc1 is not None: + x = self.fc1(x) + xvector = self.fc2(x) + else: + raise TypeError("Expected far or near position, but got {}".format( + self.extracted_embedding)) + return xvector + + @torch.jit.export + def extract_embedding_whole(self, input: torch.Tensor, position: str = 'near', maxChunk: int = 4000, isMatrix: bool = True): + with torch.no_grad(): + if isMatrix: + input = torch.unsqueeze(input, dim=0) + input = input.transpose(1, 2) + num_frames = input.shape[2] + num_split = (num_frames + maxChunk - 1) // maxChunk + split_size = num_frames // num_split + offset = 0 + embedding_stats = torch.zeros(1, self.embd_dim, 1).to(input.device) + for _ in range(0, num_split-1): + this_embedding = self.extract_embedding_jit( + input[:, :, offset:offset+split_size], position) + offset += split_size + embedding_stats += split_size*this_embedding + + last_embedding = self.extract_embedding_jit( + input[:, :, offset:], position) + + embedding = (embedding_stats + (num_frames-offset) + * last_embedding) / num_frames + return torch.squeeze(embedding.transpose(1, 2)).cpu() + + @torch.jit.export + def embedding_dim(self) -> int: + """ Export interface for c++ call, return embedding dim of the model + """ + return self.embd_dim + + def get_warmR_T(self, T_0, T_mult, epoch): + n = int(math.log(max(0.05, (epoch / T_0 * (T_mult - 1) + 1)), T_mult)) + T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1) + T_i = T_0 * T_mult ** (n) + return T_cur, T_i + + def compute_decay_value(self, start, end, T_cur, T_i): + # Linear decay in every cycle time. + return start - (start - end)/(T_i-1) * (T_cur % T_i) + + def step(self, epoch, this_iter, epoch_batchs): + # Heated up for t and s. + # Decay for margin and dropout p. 
+ if self.use_step: + if self.step_params["m"]: + current_postion = epoch*epoch_batchs + this_iter + lambda_factor = max(self.step_params["lambda_0"], + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*current_postion)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) + + if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): + T_cur, T_i = self.get_warmR_T(*self.step_params["T"], epoch) + T_cur = T_cur*epoch_batchs + this_iter + T_i = T_i * epoch_batchs + + if self.step_params["t"]: + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) + + if self.step_params["p"]: + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) + + if self.step_params["s"]: + self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] + + def step_iter(self, epoch, cur_step): + # For iterabledataset + if self.use_step: + if self.step_params["margin_warm"]: + offset_margin, lambda_m = self.margin_warm.step(cur_step) + lambda_m = max(1e-3,lambda_m) + self.loss.step(lambda_m,offset_margin) + if self.step_params["m"]: + lambda_factor = max(self.step_params["lambda_0"], + self.step_params["lambda_b"]*(1+self.step_params["gamma"]*cur_step)**(-self.step_params["alpha"])) + lambda_m = 1/(1 + lambda_factor) + self.loss.step(lambda_m) + + if self.step_params["T"] is not None and (self.step_params["t"] or self.step_params["p"]): + T_cur, T_i = self.get_warmR_T(*self.step_params["T"], cur_step) + + if self.step_params["t"]: + self.loss.t = self.compute_decay_value( + *self.step_params["t_tuple"], T_cur, T_i) + + if self.step_params["p"]: + self.aug_dropout.p = self.compute_decay_value( + *self.step_params["p_tuple"], T_cur, T_i) + + if self.step_params["s"]: + self.loss.s = self.step_params["s_tuple"][self.step_params["s_list"][epoch]] + + +if __name__ == '__main__': + # Input size: batch_size * feat_dim * seq_len * + timer = utils.Timer() + x = torch.zeros(1000, 80) + + + model = TransformerXvector(inputs_dim=80, num_targets=1211,training=False) + out = model.extract_embedding(x) + total = sum(p.numel() for p in model.parameters()) + print(model) + print(total) + print(out.shape) diff --git a/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh b/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh old mode 100644 new mode 100755 index 7dc1979..9fe7ac2 --- a/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh +++ b/pytorch/pipeline/extract_xvectors_for_pytorch_new.sh @@ -14,6 +14,7 @@ de_silence=false amp_th=100 use_gpu=false gpu_id="" +max_chunk=10000 force=false sleep_time=3 feat_config=config/feat_conf.yaml @@ -87,8 +88,8 @@ if [ $stage -le 1 ]; then pids="" for g in $(seq $nj); do $cmd --gpu 1 ${dir}/log/extract.$g.log \ - python3 subtools/pytorch/pipeline/onestep/extract_embeddings_new.py --use-gpu=$use_gpu --gpu-id="$gpu_id" \ - --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th \ + python3 subtools/pytorch/pipeline/onestep/extract_embeddings_online.py --use-gpu=$use_gpu --gpu-id="$gpu_id" \ + --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th --max-chunk=$max_chunk \ --feat-config=$srcdir/$feat_config --nnet-config=$srcdir/$nnet_config \ "$srcdir/$model" "`echo $wavs | sed s/JOB/$g/g`" "`echo $output | sed s/JOB/$g/g`" || exit 1 & sleep $sleep_time @@ -98,8 +99,8 @@ if [ $stage -le 1 ]; then wait else $cmd JOB=1:$nj ${dir}/log/extract.JOB.log \ - python3 subtools/pytorch/pipeline/onestep/extract_embeddings_new.py 
--use-gpu="false" \ - --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th \ + python3 subtools/pytorch/pipeline/onestep/extract_embeddings_online.py --use-gpu="false" \ + --data-type=$data_type --de-silence=$de_silence --amp-th=$amp_th --max-chunk=$max_chunk \ --feat-config=$srcdir/$feat_config --nnet-config=$srcdir/$nnet_config \ "$srcdir/$model" "$wavs" "$output" || exit 1; fi diff --git a/pytorch/pipeline/onestep/extract_embeddings_new.py b/pytorch/pipeline/onestep/extract_embeddings_online.py old mode 100644 new mode 100755 similarity index 88% rename from pytorch/pipeline/onestep/extract_embeddings_new.py rename to pytorch/pipeline/onestep/extract_embeddings_online.py index 4284d5b..17894b3 --- a/pytorch/pipeline/onestep/extract_embeddings_new.py +++ b/pytorch/pipeline/onestep/extract_embeddings_online.py @@ -43,7 +43,11 @@ parser.add_argument("--de-silence", type=str, action=kaldi_common.StrToBoolAction, default=False, choices=["true", "false"], help="Vad or not") parser.add_argument("--amp-th", type=int, default=50, - help="De_silence threshold (16bit)") + help="De_silence threshold (16bit)") + +parser.add_argument("--max-chunk", type=int, default=10000, + help="Select chun_size of features when extracting xvector") + parser.add_argument("--feat-config",type=str,default="",help="The config yaml of feat extraction") parser.add_argument("--use-gpu", type=str, default='true', @@ -80,12 +84,14 @@ raise ValueError("Expected nnet_config or (model_blueprint, model_creation) to exist.") model = utils.create_model_from_py(model_blueprint, model_creation) - + position = model.extracted_embedding model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=False) + # Select device model = utils.select_model_device(model, args.use_gpu, gpu_id=args.gpu_id) - + devc = utils.get_device(model) + model.eval() feature_extraction_conf = None if args.data_type != "kaldi": feat_config=args.feat_config @@ -102,7 +108,7 @@ de_sil_conf={} de_sil_conf["min_eng"]=args.amp_th dataset=WavEgsXvector(args.feats_rspecifier,feat_conf=feature_extraction_conf,data_type=args.data_type,de_silence=args.de_silence,de_sil_conf=de_sil_conf) - data_loader = DataLoader(dataset, batch_size=None,num_workers=2, prefetch_factor=500) + data_loader = DataLoader(dataset, batch_size=None,num_workers=2, prefetch_factor=100) timer = utils.Timer() with kaldi_io.open_or_fd(args.vectors_wspecifier, 'wb') as w: cnt=0 @@ -111,11 +117,16 @@ pbar=tqdm(total=tot_len, position=0,ascii=True,miniters=tot_len/100,dynamic_ncols=True) for idx,sample in enumerate(data_loader): key = sample['keys'][0] - feats = sample['feats'][0] + feats = sample['feats'][0].to(devc) + total_dur += feats.size(0)*0.01 timer.reset() - embedding = model.extract_embedding(feats) + embedding = model.extract_embedding_whole(feats,position=position,maxChunk=args.max_chunk) + # embedding1 = model.forward_1(feats) + # print(embedding) + # print(embedding1) + # assert 1==0 extract_time+=timer.elapse() if cnt%500==0: pbar.update(500) diff --git a/pytorch/pipeline/onestep/prepare_speechaug_csv.py b/pytorch/pipeline/onestep/prepare_speechaug_csv.py old mode 100644 new mode 100755 index 898d131..470847e --- a/pytorch/pipeline/onestep/prepare_speechaug_csv.py +++ b/pytorch/pipeline/onestep/prepare_speechaug_csv.py @@ -225,9 +225,14 @@ def prepare_aug_csv(items, csv_file, max_length=None): os.remove(filename) for i in range(int(duration / max_length)): start = int(max_length * i * rate) - stop = int( - min(max_length * (i + 1), duration) * rate - ) + 
if i == int(duration / max_length) -1: + stop = int( + duration * rate + ) + else: + stop = int( + max_length * (i + 1) * rate + ) new_filename = ( filename[: -len(f".{ext}")] + f"_{i}.{ext}" ) @@ -264,17 +269,17 @@ def concat_csv(out_file,*csv_files): conflict_handler='resolve') # Options - parser.add_argument("--openrir-folder", type=str, default='/tsdata/ASR', + parser.add_argument("--openrir-folder", type=str, default='/data', help="where has openslr rir.") - parser.add_argument("--musan-folder", type=str, default='/tsdata/ASR', + parser.add_argument("--musan-folder", type=str, default='/data', help="where has openslr musan.") - parser.add_argument("--savewav-folder", type=str, default='/work1/ldx/speech_aug_2_new', + parser.add_argument("--savewav-folder", type=str, default='/export/yourpath/speech_aug_6', help="noise clips for online speechaug, set it in SSD.") parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction, default=True, choices=["true", "false"], help="force clear") - parser.add_argument("--max-noise-len", type=float, default=2.015, + parser.add_argument("--max-noise-len", type=float, default=6.015, help="the maximum noise length in seconds. Noises longer than this will be cut into pieces") parser.add_argument("csv_aug_folder", type=str, help="csv file folder.") diff --git a/recipe/voxcelebSRC/README.md b/recipe/voxcelebSRC/README.md new file mode 100755 index 0000000..e9db8d4 --- /dev/null +++ b/recipe/voxcelebSRC/README.md @@ -0,0 +1,57 @@ +## Reports +### Results of ResNet34 +* Egs = Voxceleb2_dev(online random aug) + sequential sampling(2s) +* Optimization = [SGD (lr = 0.04) + ReduceLROnPlateau] x 4 GPUs (total batch-size=512) +* ResNet34 (channels = 32, 64, 128, 256) + Stats-Pooling + FC-BN + AM-Softmax (margin = 0.2) + AMP training +* Back-end = near + Cosine + +| EER% | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean | +|:-----|:------:|:------------:|:------:|:------------:|:------:|:------------:| +| Submean | 1.071 | 0.920 | 1.257 | 1.135 | 2.205 | 2.072 | +| AS-Norm | 0.970 | 0.819 | - | - | - | - | +
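The Submean and AS-Norm rows above (and in the tables that follow) come from plain cosine scoring, optionally followed by adaptive score normalization against a cohort. Below is a minimal NumPy sketch of top-N adaptive s-norm; the helper names are illustrative only, the real pipeline is driven by `gather_results_from_epochs.sh` at the end of this patch.

```python
# Hedged sketch: cosine scoring with adaptive score normalization (AS-Norm).
# Function and variable names are illustrative, not the toolkit's actual API.
import numpy as np

def length_norm(x):
    """L2-normalize embeddings along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def asnorm_score(enroll, test, cohort, top_n=300):
    """Cosine score normalized against the top-n closest cohort embeddings."""
    e, t, c = length_norm(enroll), length_norm(test), length_norm(cohort)
    raw = float(np.dot(e, t))                       # plain cosine score

    e_cohort = np.sort(c @ e)[-top_n:]              # enroll vs. cohort, keep top-n scores
    t_cohort = np.sort(c @ t)[-top_n:]              # test vs. cohort, keep top-n scores

    z = (raw - e_cohort.mean()) / (e_cohort.std() + 1e-8)
    s = (raw - t_cohort.mean()) / (t_cohort.std() + 1e-8)
    return 0.5 * (z + s)                            # symmetric adaptive s-norm

# Toy usage with random vectors standing in for x-vectors:
rng = np.random.default_rng(0)
enroll, test = rng.normal(size=256), rng.normal(size=256)
cohort = rng.normal(size=(5000, 256))
print(asnorm_score(enroll, test, cohort, top_n=300))
```

In the scoring loop at the end of this patch, the cohort is taken from the training set (voxceleb2_dev) with `--score-norm-method "asnorm" --top-n 300`.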
+ +### Results of ECAPA-TDNN +* Egs = Voxceleb2_dev(online random aug) + random chunk(2s) +* Optimization = [adamW (lr = 1e-8 - 1e-3) + cyclic for 3 cycles with the triangular2 strategy] x 4 GPUs (total batch-size=512) +* ECAPA-TDNN (channels = 1024) + FC-BN + AAM-Softmax (margin = 0.2) +* Back-end = near + Cosine + +| EER% | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean | +|:-----|:------:|:------------:|:------:|:------------:|:------:|:------------:| +| Submean | 1.045 | 0.904 | 1.330 | 1.211 | 2.430 | 2.303 | +| AS-Norm | 0.991 | 0.856 | - | - | - | - | +
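The cyclic schedule listed above sweeps the AdamW learning rate between 1e-8 and 1e-3 and halves the peak every cycle (triangular2). A minimal sketch with `torch.optim.lr_scheduler.CyclicLR`; the stand-in model, dummy objective, and `step_size_up` are illustrative assumptions, not the launcher's actual settings.

```python
# Minimal sketch: AdamW + CyclicLR("triangular2"), lr swept 1e-8 -> 1e-3.
import torch

model = torch.nn.Linear(192, 5994)                    # stand-in for the speaker model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=1e-8, max_lr=1e-3,
    step_size_up=1000,                                # iterations from base_lr up to max_lr
    mode="triangular2",                               # peak lr is halved every full cycle
    cycle_momentum=False)                             # AdamW has no classic momentum to cycle

for it in range(6000):                                # roughly 3 full cycles
    optimizer.zero_grad()
    loss = model(torch.randn(8, 192)).pow(2).mean()   # dummy objective for illustration
    loss.backward()
    optimizer.step()
    scheduler.step()                                  # update lr once per iteration
```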
+ + +### Results of Conformer +* Egs = Voxceleb2_dev(online random aug) + random chunk(3s) +* Optimization = [adamW (lr = 1e-6 - 1e-3) + 1cycle] x 4 GPUs (total batch-size=512) +* Conformer + FC-Swish-LN + ASP + FC-LN + AAM-Softmax (margin = 0.2) +* Back-end = near + Cosine +* LM: Large-Margin Fine-tune (margin: 0.2 --> 0.5, chunk: 6s) + +| Config | | vox1-O | vox1-O-clean | vox1-E | vox1-E-clean | vox1-H | vox1-H-clean | +|:---------------------------- |:------:|:------:|:------------:|:------:|:------------:|:------:|:------------:| +| 6L-256D-4H-4Sub (50 epochs) | Cosine | 1.204 | 1.074 | 1.386 | 1.267 | 2.416 | 2.294 | +| | AS-Norm | 1.092 | 0.952 | - | - | - | - | +| $\quad+$ SAM training | Cosine | 1.103 | 0.984 | 1.350 | 1.234 | 2.380 | 2.257 | +| | LM | 1.034 | 0.899 | 1.181 | 1.060 | 2.079 | 1.953 | +| | AS-Norm | 0.943 | 0.792 | - | - | - | - | +| 6L-256D-4H-2Sub (30 epochs) | Cosine | 1.066 | 0.915 | 1.298 | 1.177 | 2.167 | 2.034 | +| | LM | 1.029 | 0.888 | 1.160 | 1.043 | 1.923 | 1.792 | +| | AS-Norm | 0.949 | 0.792 | - | - | - | - | +
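The LM rows above reuse the same AAM-Softmax loss with a larger margin (0.2 -> 0.5) and 6 s chunks. A stripped-down sketch of additive-angular-margin logits follows, assuming a plain ArcFace-style formulation; the toolkit's `MarginSoftmaxLoss` adds options (margin annealing via `step()`, curricular margins, etc.) not shown here.

```python
# Hedged sketch of AAM-Softmax (ArcFace-style) logits.
import math
import torch
import torch.nn.functional as F

def aam_logits(embeddings, weight, labels, margin=0.2, scale=30.0):
    """Return s*cos(theta + m) for the target class and s*cos(theta) otherwise."""
    cosine = F.linear(F.normalize(embeddings), F.normalize(weight))      # (B, C) cosines
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=1e-7))
    phi = cosine * math.cos(margin) - sine * math.sin(margin)            # cos(theta + m)
    one_hot = F.one_hot(labels, num_classes=weight.size(0)).to(cosine.dtype)
    return scale * (one_hot * phi + (1.0 - one_hot) * cosine)

# Toy usage; for large-margin fine-tuning only `margin` grows (0.2 -> 0.5).
emb = torch.randn(4, 256)
weight = torch.randn(5994, 256)                                          # class prototypes
labels = torch.tensor([0, 1, 2, 3])
loss = F.cross_entropy(aam_logits(emb, weight, labels, margin=0.5), labels)
```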
+ +### Results of RTF +* RTF is evaluated on LibTorch-based runtime, see `subtools/runtime` +* One thread is used for CPU threading and TorchScript inference. +* CPU: Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz. + +| Model | Config | Params | RTF | +|:-----|:------ |:------:|:---:| +| ResNet34 | base32 | 6.80M | 0.090 | +| ECAPA | C1024 | 16.0M | 0.071 | +| | C512 | 6.53M | 0.030 | +| Conformer | 6L-256D-4H-4Sub | 18.8M | 0.025 | +| | 6L-256D-4H-2Sub | 22.5M | 0.070 | \ No newline at end of file diff --git a/recipe/voxcelebSRC/runVoxcelebSRC_online.sh b/recipe/voxcelebSRC/runVoxcelebSRC_online.sh old mode 100644 new mode 100755 index a2b8c6d..14b0675 --- a/recipe/voxcelebSRC/runVoxcelebSRC_online.sh +++ b/recipe/voxcelebSRC/runVoxcelebSRC_online.sh @@ -68,50 +68,50 @@ subtools/newCopyData.sh $prefix "voxceleb2_dev voxceleb1" # [5] Sample egs. It will do cmn and vad firstly and then remove invalid utts. Finally, # it samples egs to fixed chunk-size with instance sampling. -subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=0 --endstage=2 +# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=0 --endstage=2 # subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=0 --endstage=2 # subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=0 --endstage=2 +subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=0 --endstage=2 # [6] Train a thin Resnet34 model with AM-Softmax loss and 8 GPUs will be used to accelerate training -subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4 -# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4 -# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3,4 +# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3 +# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3 +# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=3 --endstage=3 --gpu-id=0,1,2,3 +subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=3 --endstage=3 --gpu-id=0,1,2,3 # [7] Extract near xvectors for voxceleb1 and voxceleb2_dev -subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=4 +# subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runResnetXvector_online.py --stage=4 # subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runRepvggXvector.py --stage=4 # subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runEcapaXvector_online.py --stage=4 +subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=4 + +# [8] Large-Margin Fine-tune +subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector_LM.py --stage=3 --endstage=3 --gpu-id=0,1,2,3 + +# [9] Extract xvectors +subtools/runPytorchLauncher.sh subtools/pytorch/launcher/runTransformerXvector.py --stage=4 + ### Back-end scoring -# [14] Score with submean + Cosine + AS-Norm processes +# [14] Score with submean + Cosine + AS-Norm processes. 
+ +# tasks="vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean" +# for task in $tasks;do +# score_norm=false +# [ "$task" == "vox1-O" ] && score_norm=true +# [ "$task" == "vox1-O-clean" ] && score_norm=true +# subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean true \ +# --vectordir "exp/resnet34_fbank80_online" --task $task --epochs "40" --positions "near" --trainset voxceleb2_dev_vad \ +# --score-norm $score_norm --score-norm-method "asnorm" --top-n 100 --cohort-set voxceleb2_dev_vad +# done + tasks="vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean" for task in $tasks;do score_norm=false [ "$task" == "vox1-O" ] && score_norm=true [ "$task" == "vox1-O-clean" ] && score_norm=true - subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean true \ - --vectordir "exp/resnet34_fbank80_online" --task $task --epochs "40" --positions "near" --trainset voxceleb2_dev_vad \ - --score-norm $score_norm --score-norm-method "asnorm" --top-n 100 --cohort-set voxceleb2_dev_vad + subtools/recipe/voxcelebSRC/gather_results_from_epochs.sh --prefix $prefix --score cosine --submean false \ + --vectordir "exp/conformer_6L256D4H_4sub_lm" --task $task --epochs "4" --positions "near" --trainset voxceleb2_dev \ + --score-norm $score_norm --score-norm-method "asnorm" --top-n 300 --cohort-set voxceleb2_dev done -#### Report #### -# Egs = Voxceleb2_dev(online random aug) + sequential sampling -# Optimization = [SGD (lr = 0.01) + ReduceLROnPlateau] x 4 GPUs (total batch-size=512) -# Resnet34 (channels = 32, 64, 128, 256) + Stats-Pooling + FC-ReLU-BN-FC-BN + AM-Softmax (margin = 0.2) + AMP training -# -# Back-end = near + Cosine -# -# EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean -# Submean 1.071 0.920 1.257 1.135 2.205 2.072 -# AS-Norm 0.970 0.819 - - - - -# - -# Egs = Voxceleb2_dev(online random aug) + random chunk -# ECAPA-TDNN (channels = 1024) + FC-ReLU-BN-FC-BN + AAM-Softmax (margin = 0.2) -# Optimization = [adamW (lr = 1e-8 - 1e-3) + cyclic for 3 cycle with triangular2 strategy] x 4 GPUs (total batch-size=512) -# Back-end = near + Cosine -# -# EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean -# Submean 1.045 0.904 1.330 1.211 2.430 2.303 -# AS-Norm 0,991 0.856 - - - - -#