Skip to content

Commit

Permalink
Add Gradient Cache&Recompute into Neural Search (PaddlePaddle#3697)
Browse files Browse the repository at this point in the history
* Add Gradient Cache&Recompute into Neural Search

* Update README.md

* Optimize gradient cache code
  • Loading branch information
w5688414 authored Nov 16, 2022
1 parent 9ab3ea4 commit 39c4f76
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 62 deletions.
3 changes: 3 additions & 0 deletions applications/neural_search/recall/in_batch_negative/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
* `recall_num`: 对 1 个文本召回的相似文本数量
* `similar_text_pair_file`: 由相似文本对构成的评估集
* `corpus_file`: 召回库数据 corpus_file
* `use_recompute`: 使用Recompute策略,用于节省显存,是一种以时间换空间的技术
* `use_gradient_cache`: 使用Gradient Cache策略,用于节省显存,是一种以时间换空间的技术
* `chunk_numbers`: 使用Gradient Cache策略的参数,表示的是同一个批次的样本分几次执行

也可以使用bash脚本:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def forward(self,
title_cls_embedding,
transpose_y=True)

# substract margin from all positive samples cosine_sim()
# Substract margin from all positive samples cosine_sim()
margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
fill_value=self.margin,
dtype=paddle.get_default_dtype())

cosine_sim = cosine_sim - paddle.diag(margin_diag)

# scale cosine to ease training converge
# Scale cosine to ease training converge
cosine_sim *= self.sacle

labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
Expand All @@ -76,3 +76,56 @@ def forward(self,
loss = F.cross_entropy(input=cosine_sim, label=labels)

return loss


class SemanticIndexCacheNeg(SemanticIndexBase):

def __init__(self,
pretrained_model,
dropout=None,
margin=0.3,
scale=30,
output_emb_size=None):
super().__init__(pretrained_model, dropout, output_emb_size)
self.margin = margin
# Used scaling cosine similarity to ease converge
self.sacle = scale

def forward(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):

query_cls_embedding = self.get_pooled_embedding(query_input_ids,
query_token_type_ids,
query_position_ids,
query_attention_mask)

title_cls_embedding = self.get_pooled_embedding(title_input_ids,
title_token_type_ids,
title_position_ids,
title_attention_mask)

cosine_sim = paddle.matmul(query_cls_embedding,
title_cls_embedding,
transpose_y=True)

# Substract margin from all positive samples cosine_sim()
margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
fill_value=self.margin,
dtype=cosine_sim.dtype)

cosine_sim = cosine_sim - paddle.diag(margin_diag)

# Scale cosine to ease training converge
cosine_sim *= self.sacle

labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
labels = paddle.reshape(labels, shape=[-1, 1])

return [cosine_sim, labels, query_cls_embedding, title_cls_embedding]
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# GPU version
root_dir="checkpoints/inbatch"
python -u -m paddle.distributed.launch --gpus "3" --log_dir "recall_log/" \
python -u -m paddle.distributed.launch --gpus "0" --log_dir "recall_log/" \
recall.py \
--device gpu \
--recall_result_dir "recall_result_dir" \
Expand All @@ -11,7 +25,7 @@ python -u -m paddle.distributed.launch --gpus "3" --log_dir "recall_log/" \
--hnsw_ef 100 \
--batch_size 64 \
--output_emb_size 256\
--max_seq_length 60 \
--max_seq_length 64 \
--recall_num 50 \
--similar_text_pair "recall/dev.csv" \
--corpus_file "recall/corpus.csv"
Expand Down
Loading

0 comments on commit 39c4f76

Please sign in to comment.