# Hugging Face dataset streaming support (#177)

* streaming dataset
* no stride / offset for streaming (a streamed corpus is read sequentially, so stride/offset sharding does not apply)
* add recipe example for streaming
Showing 5 changed files with 229 additions and 1 deletion.
# Example of using a Hugging Face streaming dataset

## Based on: https://arxiv.org/pdf/2408.06537

"Introducing the NewsPaLM MBR and QE Dataset: LLM-Generated High-Quality Parallel Data Outperforms Traditional Web-Crawled Data"

### Get the vocab and BPE model on HF

https://huggingface.co/eole-nlp/NewsPalmSynthetic-ENDE

Copy these files (see the sketch after this list):

* ende.vocab2
* subwords.en_de.bpe
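
A minimal download sketch, assuming the `huggingface_hub` CLI is installed (repo and file names are those listed above):

```
pip install huggingface_hub
huggingface-cli download eole-nlp/NewsPalmSynthetic-ENDE ende.vocab2 subwords.en_de.bpe --local-dir .
```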

### Optionally, get the trained model to test it

From the same Hugging Face repository:

* config.json
* vocab.json
* model.00.safetensors

## Train with the yaml config file

```
eole train -c newspalm-synthetic-hfstreaming.yaml
```

## Translate the newstest2023 test set

```
eole predict -c inference.yaml --src newstest2023-src.en --output newstest2023-hyp.de
```

Then you can score the output with sacrebleu and/or comet.
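
A minimal scoring sketch, assuming the `sacrebleu` and `unbabel-comet` packages are installed; `newstest2023-ref.de` is the reference file used as the valid set in the training config:

```
sacrebleu newstest2023-ref.de -i newstest2023-hyp.de -m bleu
comet-score -s newstest2023-src.en -t newstest2023-hyp.de -r newstest2023-ref.de --model Unbabel/wmt22-comet-da
```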

Scoring with Unbabel/wmt22-comet-da gives: 81.90

You can compare with Table 5, lines 2a) to 2d), of the paper: https://arxiv.org/pdf/2408.06537
## inference.yaml

```
# Model info
model_path: "Path to model.00.safetensors"
# Inference
max_length: 1024
max_length_ratio: 3
world_size: 1
gpu_ranks: [0]
batch_type: tokens
batch_size: 16384
compute_dtype: fp16
beam_size: 4
n_best: 1
report_time: true
self_attn_backend: "pytorch"
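# the source file is passed at predict time via --src (see the README), hence none here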
src: none
```
## recipes/NewsPalm-synthetic/newspalm-synthetic-hfstreaming.yaml (132 additions)

```
seed: 1234
share_vocab: true
src_vocab: "ende.vocab2"

src_words_min_frequency: 1
vocab_size_multiple: 8
report_every: 100
skip_empty_level: silent
valid_metrics: ["BLEU"]
scoring_debug: True

# transforms config
transforms: [onmt_tokenize, filtertoolong]
transforms_configs:
  onmt_tokenize:
    #### Subword
    src_subword_type: bpe
    src_subword_model: "subwords.en_de.bpe"
    src_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}
    tgt_subword_type: bpe
    tgt_subword_model: "subwords.en_de.bpe"
    tgt_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}

  filtertoolong:
    src_seq_length: 1024
    tgt_seq_length: 1024

# Corpus opts:
data:
  synth-mbr-decoded-sentlevel:
    # 997834 ex - 315MB
    path_src: "hf://eole-nlp/synth-mbr-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-mbr-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-mbr-decoded-sentlevel/sco"
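    # 'sco' holds per-example quality scores; presumably consumed by training.score_threshold below to drop low-scored pairs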
    transforms: [onmt_tokenize, filtertoolong]
    weight: 12

  synth-greedy-decoded-sentlevel:
    # 832709 ex - 250MB
    path_src: "hf://eole-nlp/synth-greedy-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-greedy-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-greedy-decoded-sentlevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 10

  synth-qe-reranked-doclevel:
    # 417102 ex - 970MB
    path_src: "hf://eole-nlp/synth-qe-reranked-doclevel/en"
    path_tgt: "hf://eole-nlp/synth-qe-reranked-doclevel/de"
    path_sco: "hf://eole-nlp/synth-qe-reranked-doclevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 1

  synth-greedy-decoded-doclevel:
    # 857937 ex - 1.7GB
    path_src: "hf://eole-nlp/europarl-v10.de-en/en"
    path_tgt: "hf://eole-nlp/europarl-v10.de-en/de"
    path_sco: "hf://eole-nlp/europarl-v10.de-en/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 2

  valid:
    path_src: "newstest2023-src.en"
    path_tgt: "newstest2023-ref.de"
    transforms: [onmt_tokenize]

training:
  # General opts
  torch_compile: false

  model_path: "6-6-16-1024-4096-hfstreaming"
  keep_checkpoint: 50
  save_checkpoint_steps: 5000
  average_decay: 0.0005
  train_steps: 51000
  valid_steps: 100

  # Batching
  bucket_size: 10000
  num_workers: 4
  prefetch_factor: 400
  world_size: 1
  gpu_ranks: [0]
  batch_type: "tokens"
  batch_size: 12144
  valid_batch_size: 8192
  batch_size_multiple: 1
  accum_count: [6, 6, 6]
  accum_steps: [0, 15000, 30000]
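  # gradient accumulation: accum_count micro-batches are summed per optimizer step, with milestones at accum_steps (constant 6 here)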

  # Optimization
  compute_dtype: "fp16"
  apex_opt_level: ""
  optim: "adamw"
  reset_optim: "all"
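  # noam decay: effective LR is roughly learning_rate * hidden_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), so a nominal value of 1 is typical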
  learning_rate: 1
  warmup_steps: 6000
  decay_method: "noam"
  adam_beta2: 0.998
  max_grad_norm: 1
  label_smoothing: 0.1
  param_init_method: "xavier_uniform"
  normalization: "tokens"

  dropout_steps: [0, 15000, 30000]
  dropout: [0.1, 0.1, 0.1]
  attention_dropout: [0.0, 0.0, 0.0]
  score_threshold: 0.65

  freeze_decoder: false
  freeze_encoder: false

model:
  architecture: "transformer"
  layers: 6
  hidden_size: 1024
  heads: 16
  transformer_ff: 4096
  add_qkvbias: false
  add_ffnbias: true
  mlp_activation_fn: gated-silu
  add_estimator: false
  share_decoder_embeddings: true
  share_embeddings: true
  layer_norm: standard
  norm_eps: 1e-6
  rope_config:
    rotary_interleave: false
  embeddings:
    word_vec_size: 1024
    position_encoding_type: "Rotary"
    freeze_word_vecs_dec: false
```