monoT5 fine-tuning process #222

Open
yixuan-qiao opened this issue Sep 21, 2021 · 25 comments

Comments

@yixuan-qiao

Is there any plan to make the fine-tuning process of monoT5 & duoT5 public?

We followed the steps in the paper to fine-tune T5-base and T5-3B in PyTorch using model parallelism and data parallelism. So far we have finished the base model and obtained an NDCG@10 of 0.62, compared with your 0.68. For the 3B model, 67k steps have been trained so far, and the best NDCG@10 is still below 0.68. We are curious whether there are other training strategies we should know about, such as warmup steps, optimizer (Adafactor?), dropout ratio, etc.

@rodrigonogueira4
Member

We will be sharing a PyTorch training script in a week or two that gets close to the original TF training.

@TuozhenLiu

TuozhenLiu commented Sep 22, 2021

> We will be sharing a PyTorch training script in a week or two that gets close to the original TF training.

Would you mind sharing the Adafactor config in advance? We're following the Hugging Face PyTorch version of the config and are confused about some details, such as whether to use "scale-parameter", "weight-decay", and the "lr-warm-up-strategy". Many thanks.

@vjeronymo2
Contributor

Hi guys,
You can check all of these parameters and the (working) draft of the fine-tuning script here: https://github.com/vjeronymo2/pygaggle/blob/master/pygaggle/run/finetune_monot5.py
I'll open a PR soon.

@yixuan-qiao
Author

Thanks for releasing this excellent work.

In finetune_monot5.py, I found that base_model is castorini/monot5-base-msmarco-10k, whose description on Hugging Face reads as follows:

> This model is a T5-base reranker fine-tuned on the MS MARCO passage dataset for 10k steps (or 1 epoch).

I am confused: is this model already fine-tuned for 1 epoch from the original Google T5-base? But 10k steps does not seem consistent with 1 epoch. The same holds for castorini/monot5-base-msmarco:

> This model is a T5-base reranker fine-tuned on the MS MARCO passage dataset for 100k steps (or 10 epochs).

I'm not sure if there's something wrong with the name. Does the overall fine-tuning process take 10 epochs, i.e., 1000K steps in total? Besides, I just found that the training strategy is very different from the one in the paper.

@rodrigonogueira4
Member

The MS MARCO dataset has ~530k query-positive passage pairs. We don't count negatives because they are virtually infinite. Using a batch size of 128, half of which are positive passages, we do approximately one epoch over the positive examples after training for 10k steps (64*10k = 640k positives seen).
Does that make sense?
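As a quick sanity check of this arithmetic, the few lines below recompute how many positives are seen after 10k steps. The constants are just the numbers quoted in the comment; nothing here comes from the pygaggle code.

```python
# Sanity check of the "10k steps ~ 1 epoch" arithmetic (numbers from the comment).
BATCH_SIZE = 128
POSITIVES_PER_BATCH = BATCH_SIZE // 2  # half of every batch is positive passages
NUM_POSITIVE_PAIRS = 530_000           # approximate query-positive pairs in MS MARCO
STEPS = 10_000

positives_seen = POSITIVES_PER_BATCH * STEPS
epochs_over_positives = positives_seen / NUM_POSITIVE_PAIRS

print(f"positives seen: {positives_seen:,}")                  # positives seen: 640,000
print(f"passes over positives: {epochs_over_positives:.2f}")  # passes over positives: 1.21
```

Seeing 640k positives against ~530k pairs is roughly 1.2 passes, which is why 10k steps is described as approximately one epoch.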

@yixuan-qiao
Author

It's amazing that seeing roughly all query-positive passage pairs once can almost match the final performance. Just curious: how did you get the 10k checkpoint from T5-base?

@vjeronymo2
Contributor

vjeronymo2 commented Sep 24, 2021

> In finetune_monot5.py, I found that base_model is castorini/monot5-base-msmarco-10k, whose description on Hugging Face reads as follows

Oh, thanks for bringing this up! The default is supposed to be the regular 't5-base' model, not 'castorini/monot5-base-msmarco-10k', which has already been finetuned. Sorry for the confusion.

> Just curious: how did you get the 10k checkpoint from T5-base?

You can train the model with the first 640k lines of triples.train.small.tsv, which results in 1280k samples (640k positives + 640k negatives). The batch size we're using is 128, so the total number of steps is 1280k/128 = 10k, i.e., 1 epoch.
Let me know if this doesn't answer your question.
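A minimal sketch of this data preparation, assuming each line of triples.train.small.tsv is `query<TAB>positive<TAB>negative` and using the monoT5 input template `Query: ... Document: ... Relevant:` with `true`/`false` targets from the monoT5 paper. The function names are hypothetical; this is not the code in finetune_monot5.py.

```python
from typing import Iterable, Iterator, Tuple

BATCH_SIZE = 128


def monot5_examples(lines: Iterable[str]) -> Iterator[Tuple[str, str]]:
    """Yield (input_text, target_text) pairs from triples.train.small.tsv lines.

    Each line is assumed to be "query<TAB>positive<TAB>negative", so it
    produces one positive ("true") and one negative ("false") example.
    """
    for line in lines:
        query, positive, negative = line.rstrip("\n").split("\t")
        yield f"Query: {query} Document: {positive} Relevant:", "true"
        yield f"Query: {query} Document: {negative} Relevant:", "false"


def num_steps(num_lines: int, batch_size: int = BATCH_SIZE) -> int:
    """Steps for one pass: two examples per triple, divided into batches."""
    return (2 * num_lines) // batch_size


# Usage on a toy triple (hypothetical text).
toy = ["what is t5\tT5 is a text-to-text transformer.\tBananas are yellow.\n"]
for source, target in monot5_examples(toy):
    print(target, "<-", source)

print(num_steps(640_000))  # 10000
```

Each triple yields two examples, so 640k lines give 1280k examples, which is 10k steps at batch size 128.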

@yixuan-qiao
Author

Got it, many thanks. ;-)
Is there a plan to release the T5-3B fine-tuning script? We tried using NVIDIA/Megatron-LM to fine-tune the 3B model, but the performance is very poor; we are still working on the problems.

@rodrigonogueira4
Member

We don't plan to release the script for fine-tuning 3B, as we also found it quite complicated to do with current PyTorch frameworks. In that case, I highly recommend using Mesh TensorFlow.

@yixuan-qiao
Author

It seems that Mesh TF works better with TPUs than with GPU clusters. We have almost replicated the results of the monot5-base model with our model- and data-parallel framework, using the parameters you provided. :-)

We used the same config to train a monot5-3b model, but it did not work as well as expected. Would you mind sharing some training parameters or strategies for the 3B model, such as learning rate, warmup steps, weight decay, dropout ratio, etc.?

@rodrigonogueira4
Member

Hi @yixuan-qiao, sorry for taking so long. Here is the command we used to fine-tune T5-3B on MS MARCO:

t5_mesh_transformer  \
  --tpu="<tpu_name>" \
  --gcp_project="<your_project_here>" \
  --tpu_zone="<tpu_zone>" \
  --model_dir="<model_dir>" \
  --gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/3B/model.ckpt-1000000'" \
  --gin_file="dataset.gin" \
  --gin_file="models/bi_v1.gin" \
  --gin_file="gs://t5-data/pretrained_models/3B/operative_config.gin" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 8" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
  --gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \
  --gin_param="tsv_dataset_fn.filename = '<your_dataset_here>'" \
  --gin_file="learning_rate_schedules/constant_0_001.gin" \
  --gin_param="run.train_steps = 1100000" \
  --gin_param="tokens_per_batch = 65536"
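Two numbers in this command line up with the rest of the thread. The sketch below assumes that run.train_steps counts the 1M pretraining steps restored from init_checkpoint (consistent with the 100k-step figure on the monot5 model cards quoted earlier), and that inputs are padded or packed to 512 tokens, which is an assumption since the real length comes from the gin files. Under those assumptions the run is 100k fine-tuning steps with 128 sequences per batch.

```python
# Hypothetical sanity check of the t5_mesh_transformer settings above.
PRETRAINED_STEPS = 1_000_000   # from init_checkpoint ".../model.ckpt-1000000"
TOTAL_TRAIN_STEPS = 1_100_000  # run.train_steps
TOKENS_PER_BATCH = 65_536      # tokens_per_batch
INPUT_LENGTH = 512             # ASSUMPTION: input sequence length set in the gin files

finetune_steps = TOTAL_TRAIN_STEPS - PRETRAINED_STEPS
sequences_per_batch = TOKENS_PER_BATCH // INPUT_LENGTH

print(finetune_steps)       # 100000
print(sequences_per_batch)  # 128
```

If the assumptions hold, the batch of 128 matches the batch size used for the PyTorch monoT5-base recipe discussed earlier.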

@yixuan-qiao
Author

We reproduced the performance of monoT5-3b. Thanks a lot!!!

As you said in the paper, I used the output of monoT5 as the input to duoT5. Specifically, for each query, I took the top 50 passages according to the monot5-3b score and built 50*49 = 2450 ordered pairs, for a total of 12.8M training examples.
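To make the pair count concrete, the sketch below enumerates ordered pairs from a toy top-50 list. Order matters because duoT5 sees the two passages in fixed Document0/Document1 positions, so (d_i, d_j) and (d_j, d_i) are distinct inputs. The function name and the toy ranking are hypothetical.

```python
from itertools import permutations
from typing import List, Tuple


def duo_pairs(ranked_doc_ids: List[str], top_k: int = 50) -> List[Tuple[str, str]]:
    """All ordered (d_i, d_j) pairs with i != j among the top_k monoT5 results."""
    return list(permutations(ranked_doc_ids[:top_k], 2))


# Toy monoT5 ranking for a single query (hypothetical doc ids).
ranking = [f"doc{i}" for i in range(1000)]
pairs = duo_pairs(ranking)
print(len(pairs))  # 2450, i.e. 50 * 49
```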

After training for about 50K iterations, I got about 0.72 ndcg@5 and 0.71 ndcg@5, much lower than yours.

Is the process of constructing second stage training data as mentioned above, or are there other key points that I haven’t noticed?

@rodrigonogueira4
Member

Hi @yixuan-qiao, for training duoT5, we used the original triples.train.small from MS MARCO as in the training of duoBERT: https://github.com/castorini/duobert#training-duobert

That is, triples of <query, negative_doc, positive_doc> or <query, positive_doc, negative_doc> are given as input to the model, where negative_doc and positive_doc are from triples.train.small.tsv.
Note that during training, duoT5 never sees a triple of <query, negative_doc, negative_doc> or <query, positive_doc, positive_doc>.
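A minimal sketch of building duoT5 examples from one line of triples.train.small.tsv, assuming the `Query: ... Document0: ... Document1: ... Relevant:` template with target `true` when Document0 is the more relevant passage. Emitting both orderings of every triple is an assumption; the comment only says either order may be used. The helper name is hypothetical.

```python
from typing import List, Tuple


def duot5_examples(query: str, positive: str, negative: str) -> List[Tuple[str, str]]:
    """Build both orderings of a (positive, negative) pair for duoT5.

    The target is "true" when Document0 is the more relevant passage. Pairs of
    two positives or two negatives are never produced, since their target
    would be undefined.
    """
    template = "Query: {q} Document0: {d0} Document1: {d1} Relevant:"
    return [
        (template.format(q=query, d0=positive, d1=negative), "true"),
        (template.format(q=query, d0=negative, d1=positive), "false"),
    ]


for source, target in duot5_examples("what is t5", "T5 is a text-to-text model.", "Bananas are yellow."):
    print(target, "<-", source)
```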

@yixuan-qiao
Author

> Note that during training, duoT5 never sees a triple of <query, negative_doc, negative_doc> or <query, positive_doc, positive_doc>.

I am not sure why duoT5 cannot see such triples. It takes the following format as input:
Query: q Document0: d_i Document1: d_j Relevant: true (or false)

But in terms of the loss function, duoT5 indeed uses the LM loss, not a triplet loss.

@rodrigonogueira4
Member

During training, duoT5 cannot see two positive or two negative documents together, because how would you define the target? That is, if p_{i,j} = 1 means that doc_i is more relevant than doc_j, which probability should we use if doc_i and doc_j are both relevant (or both non-relevant)?

@HansiZeng

> The MS MARCO dataset has ~530k query-positive passage pairs. We don't count negatives because they are virtually infinite. Using a batch size of 128, half of which are positive passages, we do approximately one epoch over the positive examples after training for 10k steps (64*10k = 640k positives seen). Does that make sense?

(1) Are the negative documents sampled from BM25, or sampled randomly?
(2) In each epoch, do you resample the negative documents or reuse the negatives from the previous epoch?

@rodrigonogueira4
Member

Hi @HansiZeng,

(1) We use the original triples.train.small.tsv, which has a complicated sampling procedure. We exchanged some emails with them a while ago, and IIRC, they use a BM25-ish version to sample the negatives, not from the 8.8M passages but from the inner join of all top-1000 passages retrieved for all training queries.

(2) In each epoch, a new negative is sampled.
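The sketch below illustrates drawing a fresh negative for each positive pair every epoch. It assumes a per-query candidate list (for example, a BM25 top-1000) is already available and does not attempt to reproduce the sampling Microsoft used to build triples.train.small.tsv; all names are hypothetical.

```python
import random
from typing import Dict, List, Set, Tuple


def sample_epoch(
    positives: List[Tuple[str, str]],   # (query_id, positive_doc_id) pairs
    candidates: Dict[str, List[str]],   # query_id -> candidate passages, e.g. BM25 top-1000
    relevant: Dict[str, Set[str]],      # query_id -> all judged-relevant doc ids
    epoch: int,
    seed: int = 0,
) -> List[Tuple[str, str, str]]:
    """Return shuffled (query_id, positive, negative) triples for one epoch.

    Every positive pair appears exactly once per epoch, and its negative is
    drawn again each epoch from the query's candidates, excluding relevant docs.
    """
    rng = random.Random(seed * 1_000_003 + epoch)  # deterministic, but different per epoch
    triples = []
    for qid, pos in positives:
        pool = [d for d in candidates[qid] if d not in relevant[qid]]
        triples.append((qid, pos, rng.choice(pool)))
    rng.shuffle(triples)
    return triples


positives = [("q1", "d1"), ("q2", "d7")]
candidates = {"q1": ["d1", "d2", "d3", "d4"], "q2": ["d5", "d6", "d7", "d8"]}
relevant = {"q1": {"d1"}, "q2": {"d7"}}
for epoch in range(3):
    print(epoch, sample_epoch(positives, candidates, relevant, epoch))
```

Seeding per epoch keeps runs reproducible while still giving each epoch a different draw of negatives.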

@HansiZeng
Copy link

Hi @rodrigonogueira4,
I appreciate your response, but I still have a question about the triples.train.small.tsv file. I noticed that it only has about 400K relevant query-passage pairs, fewer than the 532K relevant query-passage pairs in the full training set. Is there a way to select negative passages for the remaining 132K relevant pairs that are not included in triples.train.small.tsv?
In your paper (https://aclanthology.org/2020.findings-emnlp.63.pdf), it seems that all 532K relevant query-passage pairs are used in training.

@HansiZeng

Can you confirm if in every epoch each relevant query-document pair is seen exactly once, and 10 epochs equate to seeing the same relevant pair 10 times?

@rodrigonogueira4
Member

Re: 532k vs 400k, that is a mistake in the paper: the model was finetuned on 400k positive pairs, all from triples.train.small. If we finetune on the 532k from the "full" triples.train, the effectiveness drops a bit (surprisingly).

@rodrigonogueira4
Member

rodrigonogueira4 commented Mar 13, 2023

> Can you confirm if in every epoch each relevant query-document pair is seen exactly once, and 10 epochs equate to seeing the same relevant pair 10 times?

That's right

@HansiZeng

> https://github.com/vjeronymo2/pygaggle/blob/master/pygaggle/run/finetune_monot5.py

Hi @rodrigonogueira4,
I have a question about constructing negative documents for each query. Could you please help me clarify the following points?
(1) Apart from using negative documents from the BM25 top-1000, are there any other sampling techniques they use for each query?
(2) I am a little unclear about your previous response: "they use a BM25-ish version to sample the negatives not from the 8.8M passages but from the inner join of all top-1000 passages retrieved for all training queries." Could you please provide more details or explain this further?

@rodrigonogueira4
Member

Hi @HansiZeng,

(1) No, but just as a warning: we were never able to generate negatives as good as the ones in triples.train.small.tsv. We tried multiple strategies (dense + BM25, avoiding sampling from the top 10, etc.), but there seems to be something special in the way MS constructed the original triples.train.

(2) I don't fully understand their negative sampling method but perhaps @lintool can explain or have a pointer?
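For reference, one of the strategies mentioned above (sampling negatives from a retrieval run while skipping the top 10) can be sketched as follows. As the comment notes, negatives built this way did not match the quality of those in triples.train.small.tsv; the function and its defaults are hypothetical.

```python
import random
from typing import List, Optional, Set


def sample_negative(
    ranked: List[str],                  # retrieval run for one query, best first
    relevant: Set[str],                 # judged-relevant doc ids for that query
    skip_top: int = 10,                 # never sample from the highest-ranked results
    depth: int = 1000,                  # only sample from the top `depth` results
    rng: Optional[random.Random] = None,
) -> str:
    """Draw one negative from ranks skip_top+1..depth, excluding relevant docs."""
    rng = rng or random.Random()
    pool = [d for d in ranked[skip_top:depth] if d not in relevant]
    if not pool:
        raise ValueError("no candidate negatives left after filtering")
    return rng.choice(pool)


ranked = [f"d{i}" for i in range(1000)]  # toy BM25 ranking, d0 is rank 1
print(sample_negative(ranked, relevant={"d0", "d15"}, rng=random.Random(42)))
```

Skipping the very top of the ranking is meant to avoid unjudged relevant passages being used as negatives, at the cost of somewhat easier negatives.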

@inderjeetnair

> https://github.com/castorini/duobert#training-duobert

The link to the triples dataset there is no longer valid. Could you please point me to the new link where the dataset can be downloaded?

@fangguo1

fangguo1 commented Aug 4, 2023

@vjeronymo2 Thanks for making finetune_monot5.py available for reference. However, since this file uses the Trainer class, the keys in 'dataset_train' must match the argument names of the T5 model's forward function; otherwise the Trainer will ignore them. Thus I think the key 'text' in 'dataset_train' should be changed to 'input_ids'.
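The behavior described here comes from the Trainer's `remove_unused_columns` option (enabled by default), which drops dataset columns whose names are not parameters of the model's `forward`. The dependency-free sketch below mimics that filtering with `inspect.signature` on a hypothetical stand-in for the T5 forward signature; it is not the Trainer's actual implementation. If a custom collator is meant to tokenize `text` itself, setting `remove_unused_columns=False` in `TrainingArguments` is another option.

```python
import inspect
from typing import Any, Callable, Dict


def t5_forward_stub(input_ids=None, attention_mask=None,
                    decoder_input_ids=None, labels=None):
    """Hypothetical stand-in for (a subset of) a T5 model's forward() parameters."""
    return None


def keep_model_columns(example: Dict[str, Any],
                       forward: Callable = t5_forward_stub) -> Dict[str, Any]:
    """Keep only keys that name a parameter of `forward`, in the spirit of
    the Trainer's remove_unused_columns=True behavior."""
    allowed = set(inspect.signature(forward).parameters)
    return {key: value for key, value in example.items() if key in allowed}


raw = {"text": "Query: q Document: d Relevant:", "labels": [1, 2]}  # toy label ids
print(keep_model_columns(raw))  # {'labels': [1, 2]}  ("text" is silently dropped)
```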


7 participants