diff --git a/applications/neural_search/recall/in_batch_negative/README.md b/applications/neural_search/recall/in_batch_negative/README.md
index ed04bd15b4e9..aa07d6908480 100644
--- a/applications/neural_search/recall/in_batch_negative/README.md
+++ b/applications/neural_search/recall/in_batch_negative/README.md
@@ -229,6 +229,9 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
 * `recall_num`: the number of similar texts recalled for each query text
 * `similar_text_pair_file`: the evaluation set built from pairs of similar texts
 * `corpus_file`: the recall corpus data
+* `use_recompute`: enable the recompute strategy to save GPU memory, a technique that trades time for space
+* `use_gradient_cache`: enable the gradient cache strategy to save GPU memory, likewise trading time for space
+* `chunk_numbers`: a gradient cache parameter; the number of chunks a single batch is split into for execution
 
 You can also use the bash script:
diff --git a/applications/neural_search/recall/in_batch_negative/batch_negative/model.py b/applications/neural_search/recall/in_batch_negative/batch_negative/model.py
index 911fe0b4d360..050beb62f613 100644
--- a/applications/neural_search/recall/in_batch_negative/batch_negative/model.py
+++ b/applications/neural_search/recall/in_batch_negative/batch_negative/model.py
@@ -60,14 +60,14 @@ def forward(self,
                                    title_cls_embedding,
                                    transpose_y=True)
 
-        # substract margin from all positive samples cosine_sim()
+        # Subtract the margin from all positive samples' cosine_sim()
         margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
                                   fill_value=self.margin,
                                   dtype=paddle.get_default_dtype())
 
         cosine_sim = cosine_sim - paddle.diag(margin_diag)
 
-        # scale cosine to ease training converge
+        # Scale the cosine similarity to ease training convergence
         cosine_sim *= self.sacle
 
         labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
@@ -76,3 +76,56 @@ def forward(self,
         loss = F.cross_entropy(input=cosine_sim, label=labels)
 
         return loss
+
+
+class SemanticIndexCacheNeg(SemanticIndexBase):
+
+    def __init__(self,
+                 pretrained_model,
+                 dropout=None,
+                 margin=0.3,
+                 scale=30,
+                 output_emb_size=None):
+        super().__init__(pretrained_model, dropout, output_emb_size)
+        self.margin = margin
+        # Scale the cosine similarity to ease convergence
+        self.scale = scale
+
+    def forward(self,
+                query_input_ids,
+                title_input_ids,
+                query_token_type_ids=None,
+                query_position_ids=None,
+                query_attention_mask=None,
+                title_token_type_ids=None,
+                title_position_ids=None,
+                title_attention_mask=None):
+
+        query_cls_embedding = self.get_pooled_embedding(query_input_ids,
+                                                        query_token_type_ids,
+                                                        query_position_ids,
+                                                        query_attention_mask)
+
+        title_cls_embedding = self.get_pooled_embedding(title_input_ids,
+                                                        title_token_type_ids,
+                                                        title_position_ids,
+                                                        title_attention_mask)
+
+        cosine_sim = paddle.matmul(query_cls_embedding,
+                                   title_cls_embedding,
+                                   transpose_y=True)
+
+        # Subtract the margin from all positive samples' cosine_sim()
+        margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
+                                  fill_value=self.margin,
+                                  dtype=cosine_sim.dtype)
+
+        cosine_sim = cosine_sim - paddle.diag(margin_diag)
+
+        # Scale the cosine similarity to ease training convergence
+        cosine_sim *= self.scale
+
+        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
+        labels = paddle.reshape(labels, shape=[-1, 1])
+
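+        # Return the similarity logits, labels and both pooled embeddings so
+        # that the gradient-cache training loop can compute the loss and cache
+        # gradients outside of this forward pass.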
+        return [cosine_sim, labels, query_cls_embedding, title_cls_embedding]
diff --git a/applications/neural_search/recall/in_batch_negative/scripts/run_build_index.sh b/applications/neural_search/recall/in_batch_negative/scripts/run_build_index.sh
index 857302c334a1..9920a045b9dc 100755
--- a/applications/neural_search/recall/in_batch_negative/scripts/run_build_index.sh
+++ b/applications/neural_search/recall/in_batch_negative/scripts/run_build_index.sh
@@ -1,6 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # GPU version
 root_dir="checkpoints/inbatch"
-python -u -m paddle.distributed.launch --gpus "3" --log_dir "recall_log/" \
+python -u -m paddle.distributed.launch --gpus "0" --log_dir "recall_log/" \
     recall.py \
     --device gpu \
     --recall_result_dir "recall_result_dir" \
@@ -11,7 +25,7 @@ python -u -m paddle.distributed.launch --gpus "3" --log_dir "recall_log/" \
     --hnsw_ef 100 \
     --batch_size 64 \
     --output_emb_size 256\
-    --max_seq_length 60 \
+    --max_seq_length 64 \
     --recall_num 50 \
     --similar_text_pair "recall/dev.csv" \
     --corpus_file "recall/corpus.csv"
diff --git a/applications/neural_search/recall/in_batch_negative/train_batch_neg.py b/applications/neural_search/recall/in_batch_negative/train_batch_neg.py
index 10bead311455..fe48bd49ffbc 100644
--- a/applications/neural_search/recall/in_batch_negative/train_batch_neg.py
+++ b/applications/neural_search/recall/in_batch_negative/train_batch_neg.py
@@ -17,14 +17,17 @@
 import time
 
 import numpy as np
 import paddle
+import paddle.nn.functional as F
 from functools import partial
+
 from paddlenlp.utils.log import logger
 from paddlenlp.data import Tuple, Pad
 from paddlenlp.datasets import load_dataset, MapDataset
 from paddlenlp.transformers import AutoModel, AutoTokenizer
 from paddlenlp.transformers import LinearDecayWithWarmup
+
 from base_model import SemanticIndexBase
-from batch_negative.model import SemanticIndexBatchNeg
+from batch_negative.model import SemanticIndexBatchNeg, SemanticIndexCacheNeg
 from data import read_text_pair, convert_example, create_dataloader, gen_id2corpus, gen_text_file
 from ann_util import build_index
 
@@ -89,6 +92,15 @@
                     help="evaluate_result")
 parser.add_argument('--evaluate',
                     action='store_true',
                     help='whether evaluate while training')
+parser.add_argument("--use_amp", action="store_true", help="Whether to use AMP.")
+parser.add_argument("--amp_loss_scale", default=32768, type=float, help="The value of scale_loss for fp16. This is only used for AMP training.")
+parser.add_argument("--use_recompute",
+                    action='store_true',
+                    help="Use recompute to scale up the batch size and save GPU memory.")
+parser.add_argument("--use_gradient_cache",
+                    action='store_true',
+                    help="Use gradient cache to scale up the batch size and save GPU memory.")
+parser.add_argument("--chunk_numbers", type=int, default=50, help="The number of chunks a batch is split into when gradient cache is enabled.")
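+# Note: batch_size must be divisible by chunk_numbers; each sub-batch then
+# contains batch_size / chunk_numbers examples.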
 args = parser.parse_args()
 # yapf: enable
@@ -161,6 +173,179 @@ def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,
     return float(recall_N[1])
 
 
+def train(train_data_loader, model, optimizer, lr_scheduler, rank,
+          corpus_data_loader, query_data_loader, recall_result_file, text_list,
+          id2corpus, tokenizer):
+    global_step = 0
+    best_recall = 0.0
+    tic_train = time.time()
+    for epoch in range(1, args.epochs + 1):
+        for step, batch in enumerate(train_data_loader, start=1):
+            query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
+
+            loss = model(query_input_ids=query_input_ids,
+                         title_input_ids=title_input_ids,
+                         query_token_type_ids=query_token_type_ids,
+                         title_token_type_ids=title_token_type_ids)
+
+            global_step += 1
+            if global_step % args.log_steps == 0 and rank == 0:
+                print(
+                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
+                    % (global_step, epoch, step, loss, args.log_steps /
+                       (time.time() - tic_train)))
+                tic_train = time.time()
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.clear_grad()
+            if not args.evaluate:
+                if global_step % args.save_steps == 0 and rank == 0:
+                    save_dir = os.path.join(args.save_dir,
+                                            "model_%d" % global_step)
+                    if not os.path.exists(save_dir):
+                        os.makedirs(save_dir)
+                    save_param_path = os.path.join(save_dir,
+                                                   'model_state.pdparams')
+                    paddle.save(model.state_dict(), save_param_path)
+                    tokenizer.save_pretrained(save_dir)
+        if args.evaluate and rank == 0:
+            print("evaluating")
+            recall_5 = evaluate(model, corpus_data_loader, query_data_loader,
+                                recall_result_file, text_list, id2corpus)
+            if recall_5 > best_recall:
+                best_recall = recall_5
+
+                save_dir = os.path.join(args.save_dir, "model_best")
+                if not os.path.exists(save_dir):
+                    os.makedirs(save_dir)
+                save_param_path = os.path.join(save_dir, 'model_state.pdparams')
+                paddle.save(model.state_dict(), save_param_path)
+                tokenizer.save_pretrained(save_dir)
+                with open(os.path.join(save_dir, "train_result.txt"),
+                          'a',
+                          encoding='utf-8') as fp:
+                    fp.write('epoch=%d, global_step: %d, recall: %s\n' %
+                             (epoch, global_step, recall_5))
+
+
+def gradient_cache_train(train_data_loader, model, optimizer, lr_scheduler,
+                         rank, tokenizer):
+
+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.amp_loss_scale)
+
+    if args.batch_size % args.chunk_numbers == 0:
+        chunk_numbers = args.chunk_numbers
+    else:
+        raise Exception(
+            f"batch_size ({args.batch_size}) must be divisible by chunk_numbers ({args.chunk_numbers})."
+        )
+
+    def split(inputs, chunk_numbers, axis=0):
+        # Fall back to one-example chunks when a batch (e.g. the last,
+        # incomplete one) cannot be divided evenly.
+        if inputs.shape[0] % chunk_numbers == 0:
+            return paddle.split(inputs, chunk_numbers, axis=0)
+        else:
+            return paddle.split(inputs, inputs.shape[0], axis=0)
+
+    global_step = 0
+    tic_train = time.time()
+    for epoch in range(1, args.epochs + 1):
+        for step, batch in enumerate(train_data_loader, start=1):
+            # Split the large batch into several sub-batches
+            chunked_x = [split(t, chunk_numbers, axis=0) for t in batch]
+            sub_batchs = [list(s) for s in zip(*chunked_x)]
+
+            all_grads = []
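+            # Cache, per sub-batch: the gradients of the loss w.r.t. the
+            # similarity logits, the CUDA RNG state (so dropout can be replayed
+            # identically in the recomputation pass below), and the pooled
+            # query/title embeddings.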
+            all_CUDA_rnd_state = []
+            all_query = []
+            all_title = []
+
+            for sub_batch in sub_batchs:
+                all_reps = []
+                all_labels = []
+                sub_query_input_ids, sub_query_token_type_ids, sub_title_input_ids, sub_title_token_type_ids = sub_batch
+                with paddle.amp.auto_cast(
+                        args.use_amp,
+                        custom_white_list=["layer_norm", "softmax", "gelu"]):
+
+                    with paddle.no_grad():
+                        sub_CUDA_rnd_state = paddle.framework.random.get_cuda_rng_state(
+                        )
+                        all_CUDA_rnd_state.append(sub_CUDA_rnd_state)
+                        sub_cosine_sim, sub_label, query_embedding, title_embedding = model(
+                            query_input_ids=sub_query_input_ids,
+                            title_input_ids=sub_title_input_ids,
+                            query_token_type_ids=sub_query_token_type_ids,
+                            title_token_type_ids=sub_title_token_type_ids)
+                        all_reps.append(sub_cosine_sim)
+                        all_labels.append(sub_label)
+                        all_title.append(title_embedding)
+                        all_query.append(query_embedding)
+
+                model_reps = paddle.concat(all_reps, axis=0)
+                model_title = paddle.concat(all_title)
+                model_query = paddle.concat(all_query)
+
+                # The forward pass above ran under no_grad, so these tensors
+                # carry no graph; detach them and mark stop_gradient=False so
+                # that loss.backward() computes gradients w.r.t. the logits.
+                model_title = model_title.detach()
+                model_query = model_query.detach()
+
+                model_query.stop_gradient = False
+                model_title.stop_gradient = False
+                model_reps.stop_gradient = False
+
+                model_label = paddle.concat(all_labels, axis=0)
+                loss = F.cross_entropy(input=model_reps, label=model_label)
+                loss.backward()
+                # Store the gradients w.r.t. the similarity logits
+                all_grads.append(model_reps.grad)
+
+            for sub_batch, CUDA_state, grad in zip(sub_batchs,
+                                                   all_CUDA_rnd_state,
+                                                   all_grads):
+
+                sub_query_input_ids, sub_query_token_type_ids, sub_title_input_ids, sub_title_token_type_ids = sub_batch
+                paddle.framework.random.set_cuda_rng_state(CUDA_state)
+                # Recompute the forward propagation
+                sub_cosine_sim, sub_label, query_embedding, title_embedding = model(
+                    query_input_ids=sub_query_input_ids,
+                    title_input_ids=sub_title_input_ids,
+                    query_token_type_ids=sub_query_token_type_ids,
+                    title_token_type_ids=sub_title_token_type_ids)
+                # Chain rule: couple the recomputed logits with the cached
+                # gradients so backward() propagates them into the encoder
+                surrogate = paddle.dot(sub_cosine_sim, grad)
+                # Backward propagation
+                if args.use_amp:
+                    scaled = scaler.scale(surrogate)
+                    scaled.backward()
+                else:
+                    surrogate.backward()
+                # Update model parameters
+                if args.use_amp:
+                    scaler.minimize(optimizer, scaled)
+                else:
+                    optimizer.step()
+
+            global_step += 1
+            if global_step % args.log_steps == 0 and rank == 0:
+                print(
+                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
+                    % (global_step, epoch, step, loss, args.log_steps /
+                       (time.time() - tic_train)))
+                tic_train = time.time()
+
+            lr_scheduler.step()
+            optimizer.clear_grad()
+
+            if global_step % args.save_steps == 0 and rank == 0:
+                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
+                if not os.path.exists(save_dir):
+                    os.makedirs(save_dir)
+                save_param_path = os.path.join(save_dir, 'model_state.pdparams')
+                paddle.save(model.state_dict(), save_param_path)
+                tokenizer.save_pretrained(save_dir)
+
+
 def do_train():
     paddle.set_device(args.device)
     rank = paddle.distributed.get_rank()
@@ -173,7 +358,8 @@ def do_train():
                                 data_path=args.train_set_file,
                                 lazy=False)
 
-    pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)
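+    # With use_recompute enabled, intermediate activations are dropped during
+    # the forward pass and recomputed in the backward pass, trading extra
+    # computation for lower GPU memory usage.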
+    pretrained_model = AutoModel.from_pretrained(
+        args.model_name_or_path, enable_recompute=args.use_recompute)
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
 
@@ -197,11 +383,16 @@ def do_train():
         batch_size=args.batch_size,
         batchify_fn=batchify_fn,
         trans_fn=trans_func)
-
-    model = SemanticIndexBatchNeg(pretrained_model,
-                                  margin=args.margin,
-                                  scale=args.scale,
-                                  output_emb_size=args.output_emb_size)
+    if args.use_gradient_cache:
+        model = SemanticIndexCacheNeg(pretrained_model,
+                                      margin=args.margin,
+                                      scale=args.scale,
+                                      output_emb_size=args.output_emb_size)
+    else:
+        model = SemanticIndexBatchNeg(pretrained_model,
+                                      margin=args.margin,
+                                      scale=args.scale,
+                                      output_emb_size=args.output_emb_size)
 
     if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
         state_dict = paddle.load(args.init_from_ckpt)
@@ -262,57 +453,13 @@ def do_train():
         weight_decay=args.weight_decay,
         apply_decay_param_fun=lambda x: x in decay_params)
 
-    global_step = 0
-    best_recall = 0.0
-    tic_train = time.time()
-    for epoch in range(1, args.epochs + 1):
-        for step, batch in enumerate(train_data_loader, start=1):
-            query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
-
-            loss = model(query_input_ids=query_input_ids,
-                         title_input_ids=title_input_ids,
-                         query_token_type_ids=query_token_type_ids,
-                         title_token_type_ids=title_token_type_ids)
-
-            global_step += 1
-            if global_step % args.log_steps == 0 and rank == 0:
-                print(
-                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
-                    % (global_step, epoch, step, loss, 10 /
-                       (time.time() - tic_train)))
-                tic_train = time.time()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.clear_grad()
-            if not args.evaluate:
-                if global_step % args.save_steps == 0 and rank == 0:
-                    save_dir = os.path.join(args.save_dir,
-                                            "model_%d" % global_step)
-                    if not os.path.exists(save_dir):
-                        os.makedirs(save_dir)
-                    save_param_path = os.path.join(save_dir,
-                                                   'model_state.pdparams')
-                    paddle.save(model.state_dict(), save_param_path)
-                    tokenizer.save_pretrained(save_dir)
-        if args.evaluate and rank == 0:
-            print("evaluating")
-            recall_5 = evaluate(model, corpus_data_loader, query_data_loader,
-                                recall_result_file, text_list, id2corpus)
-            if recall_5 > best_recall:
-                best_recall = recall_5
-
-                save_dir = os.path.join(args.save_dir, "model_best")
-                if not os.path.exists(save_dir):
-                    os.makedirs(save_dir)
-                save_param_path = os.path.join(save_dir, 'model_state.pdparams')
-                paddle.save(model.state_dict(), save_param_path)
-                tokenizer.save_pretrained(save_dir)
-                with open(os.path.join(save_dir, "train_result.txt"),
-                          'a',
-                          encoding='utf-8') as fp:
-                    fp.write('epoch=%d, global_step: %d, recall: %s\n' %
-                             (epoch, global_step, recall_5))
+    if args.use_gradient_cache:
+        gradient_cache_train(train_data_loader, model, optimizer, lr_scheduler,
+                             rank, tokenizer)
+    else:
+        train(train_data_loader, model, optimizer, lr_scheduler, rank,
+              corpus_data_loader, query_data_loader, recall_result_file,
+              text_list, id2corpus, tokenizer)
 
 
 if __name__ == "__main__":
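
For reference, a hypothetical launch command combining the new options (values are illustrative, not taken from this patch; batch_size must be divisible by chunk_numbers):

python -u -m paddle.distributed.launch --gpus "0" \
    train_batch_neg.py \
    --device gpu \
    --batch_size 64 \
    --chunk_numbers 8 \
    --use_gradient_cache \
    --use_recompute \
    --use_amp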