Hugging Face dataset streaming support (#177)
* streaming dataset
* no stride / offset for streaming
* add recipe example for streaming
vince62s authored Jan 16, 2025
1 parent 02a9dbe commit d2992b7
Showing 5 changed files with 229 additions and 1 deletion.
2 changes: 1 addition & 1 deletion eole/config/data.py
@@ -210,7 +210,7 @@ def _validate_vocab_config(self, build_vocab_only=False):
     @staticmethod
     def _validate_file(file_path, info):
         """Check `file_path` is valid or raise `IOError`."""
-        if file_path == "dummy":
+        if file_path == "dummy" or file_path.startswith("hf://"):
             # hack to allow creating objects with required fields
             pass
         elif not os.path.isfile(file_path):
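In practice this means dataset paths of the form `hf://<org>/<dataset>/<field>` now pass config validation even though they are not local files. A minimal standalone sketch of the patched check (the helper name and error message below are illustrative; only the `hf://` prefix rule comes from the diff):

```
import os

def validate_path(file_path, info):
    # mirrors the patched _validate_file: "dummy" placeholders and Hugging Face
    # streaming identifiers skip the local-file existence check
    if file_path == "dummy" or file_path.startswith("hf://"):
        pass
    elif not os.path.isfile(file_path):
        raise IOError(f"please check the path of your {info} file: {file_path}")

validate_path("hf://eole-nlp/synth-mbr-decoded-sentlevel/en", "train src")  # passes without any local file
validate_path("data/train.en", "train src")  # raises IOError unless the file exists
```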
39 changes: 39 additions & 0 deletions eole/inputters/text_corpus.py
@@ -1,12 +1,14 @@
"""Module that contain shard utils for dynamic data."""

import os
import re
from eole.utils.logging import logger
from eole.constants import CorpusName, CorpusTask
from eole.transforms import TransformPipe
from eole.inputters.text_utils import transform_bucket
from contextlib import contextmanager
import itertools
from datasets import load_dataset


@contextmanager
@@ -102,11 +104,38 @@ def __init__(self, name, src, tgt, sco=None, align=None):
self.sco = sco
self.align = align

def _is_hf_dataset(self, path):
"""
Check whether a given path refers to a Hugging Face dataset.
Matches the 'hf://' prefix (the dataset is assumed to be used in streaming mode)
and the trailing '/field' component that selects the language / score field.
"""
pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
if isinstance(path, str):
return re.match(pattern, path)
else:
return None

def _load_hf_dataset(self, path):
"""
Load a Hugging Face dataset from the given identifier.
Strips the 'hf://' prefix to get the dataset repository name and loads it in streaming mode;
the trailing '/field' component is used by the caller to select the language / score field.
"""
pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
dataset_name = re.match(pattern, self.src).group(1)
return load_dataset(dataset_name, split="train", streaming=True)

def load(self, offset=0, stride=1):
"""
Load the corpus and iterate over its lines.
`offset` and `stride` allow iterating over only every `stride`-th example, starting from `offset`.
For local files, every worker opens the files in exactly the same way,
so the stride / offset rule is needed to avoid processing the same example several times.
For HF streaming mode, make sure there are more shards than workers, ideally with the
shard count a multiple of the worker count (e.g. 16 shards for 4 workers on big datasets).
The shards are then iterated automatically, since HF locks a shard while it is in use.
"""

def make_ex(sline, tline, scoline, align):
@@ -133,6 +162,16 @@ def make_ex(sline, tline, scoline, align):
if scoline is None:
scoline = 1.0
yield make_ex(sline, tline, scoline, align)

elif self._is_hf_dataset(self.src):
# `src` is a Hugging Face dataset identifier: stream the dataset and select columns by the trailing '/field' of each path
dataset = self._load_hf_dataset(self.src)
for i, example in enumerate(dataset):
sline = example.get(self.src.split("/")[-1])
tline = example.get(self.tgt.split("/")[-1])
scoline = example.get(self.sco.split("/")[-1], 1.0)
yield make_ex(sline, tline, scoline, None)

else:
with exfile_open(self.src, mode="rb") as fs, exfile_open(self.tgt, mode="rb") as ft, exfile_open(
self.sco, mode="rb"
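Outside of eole, the new `hf://<org>/<dataset>/<field>` convention maps onto the `datasets` library as sketched below (a minimal standalone example; it assumes the `eole-nlp/synth-mbr-decoded-sentlevel` dataset used in the recipe further down is accessible):

```
import re
from datasets import load_dataset

path_src = "hf://eole-nlp/synth-mbr-decoded-sentlevel/en"
path_tgt = "hf://eole-nlp/synth-mbr-decoded-sentlevel/de"

# "hf://<org>/<dataset>/<field>": group(1) is the dataset repo, group(2) the column
pattern = r"hf://([^/]+/[^/]+)/([^/]+)"
dataset_name, src_field = re.match(pattern, path_src).groups()
tgt_field = path_tgt.split("/")[-1]

# streaming=True returns an IterableDataset: examples are fetched lazily, shard by shard
dataset = load_dataset(dataset_name, split="train", streaming=True)
for example in dataset:
    sline, tline = example.get(src_field), example.get(tgt_field)
    # ... build the training example from sline / tline ...
    break
```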
42 changes: 42 additions & 0 deletions recipes/NewsPalm-synthetic/README.md
@@ -0,0 +1,42 @@
# Example of using a Hugging Face streaming dataset

## Based on: https://arxiv.org/pdf/2408.06537

Introducing the NewsPaLM MBR and QE Dataset: LLM-Generated High-Quality Parallel Data Outperforms Traditional Web-Crawled Data

### Get the vocab and BPE model on HF

https://huggingface.co/eole-nlp/NewsPalmSynthetic-ENDE

Copy these files:
* ende.vocab2
* subwords.en_de.bpe


### Optionally, get the trained model to test it


* config.json
* vocab.json
* model.00.safetensors

## Train with the YAML config file

```
eole train -c newspalm-synthetic-hfstreaming.yaml
```

## Translate the newstest2023 test set

```
eole predict -c inference.yaml --src newstest2023-src.en --output newstest2023-hyp.de
```

Then you can score with sacrebleu and/or COMET.
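
For example, using the standard `sacrebleu` and `comet-score` command-line tools (the exact flags may need adapting to your installed versions):

```
sacrebleu newstest2023-ref.de -i newstest2023-hyp.de -m bleu
comet-score -s newstest2023-src.en -t newstest2023-hyp.de -r newstest2023-ref.de --model Unbabel/wmt22-comet-da
```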


Scoring with Unbabel/wmt22-comet-da gives: 81.90

You can compare this to Table 5, lines 2a) to 2d), of the paper: https://arxiv.org/pdf/2408.06537
15 changes: 15 additions & 0 deletions recipes/NewsPalm-synthetic/inference.yaml
@@ -0,0 +1,15 @@
# Model info
model_path: "Path to model.00.safetensors"
# Inference
max_length: 1024
max_length_ratio: 3
world_size: 1
gpu_ranks: [0]
batch_type: tokens
batch_size: 16384
compute_dtype: fp16
beam_size: 4
n_best: 1
report_time: true
self_attn_backend: "pytorch"
src: none
132 changes: 132 additions & 0 deletions recipes/NewsPalm-synthetic/newspalm-synthetic-hfstreaming.yaml
@@ -0,0 +1,132 @@
seed: 1234
share_vocab: true
src_vocab: "ende.vocab2"

src_words_min_frequency: 1
vocab_size_multiple: 8
report_every: 100
skip_empty_level: silent
valid_metrics: ["BLEU"]
scoring_debug: True

# transforms config
transforms: [onmt_tokenize, filtertoolong]
transforms_configs:
  onmt_tokenize:
    #### Subword
    src_subword_type: bpe
    src_subword_model: "subwords.en_de.bpe"
    src_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}
    tgt_subword_type: bpe
    tgt_subword_model: "subwords.en_de.bpe"
    tgt_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders": True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens": True, "segment_case": True, "segment_numbers": True, "segment_alphabet_change": True}

  filtertoolong:
    src_seq_length: 1024
    tgt_seq_length: 1024

# Corpus opts:
data:
  synth-mbr-decoded-sentlevel:
    # 997834 ex - 315MB
    path_src: "hf://eole-nlp/synth-mbr-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-mbr-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-mbr-decoded-sentlevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 12

  synth-greedy-decoded-sentlevel:
    # 832709 ex - 250MB
    path_src: "hf://eole-nlp/synth-greedy-decoded-sentlevel/en"
    path_tgt: "hf://eole-nlp/synth-greedy-decoded-sentlevel/de"
    path_sco: "hf://eole-nlp/synth-greedy-decoded-sentlevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 10

  synth-qe-reranked-doclevel:
    # 417102 ex - 970MB
    path_src: "hf://eole-nlp/synth-qe-reranked-doclevel/en"
    path_tgt: "hf://eole-nlp/synth-qe-reranked-doclevel/de"
    path_sco: "hf://eole-nlp/synth-qe-reranked-doclevel/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 1

  synth-greedy-decoded-doclevel:
    # 857937 ex - 1.7GB
    path_src: "hf://eole-nlp/europarl-v10.de-en/en"
    path_tgt: "hf://eole-nlp/europarl-v10.de-en/de"
    path_sco: "hf://eole-nlp/europarl-v10.de-en/sco"
    transforms: [onmt_tokenize, filtertoolong]
    weight: 2

  valid:
    path_src: "newstest2023-src.en"
    path_tgt: "newstest2023-ref.de"
    transforms: [onmt_tokenize]

training:
  # General opts
  torch_compile: false

  model_path: "6-6-16-1024-4096-hfstreaming"
  keep_checkpoint: 50
  save_checkpoint_steps: 5000
  average_decay: 0.0005
  train_steps: 51000
  valid_steps: 100

  # Batching
  bucket_size: 10000
  num_workers: 4
  prefetch_factor: 400
  world_size: 1
  gpu_ranks: [0]
  batch_type: "tokens"
  batch_size: 12144
  valid_batch_size: 8192
  batch_size_multiple: 1
  accum_count: [6, 6, 6]
  accum_steps: [0, 15000, 30000]

  # Optimization
  compute_dtype: "fp16"
  apex_opt_level: ""
  optim: "adamw"
  reset_optim: "all"
  learning_rate: 1
  warmup_steps: 6000
  decay_method: "noam"
  adam_beta2: 0.998
  max_grad_norm: 1
  label_smoothing: 0.1
  param_init_method: "xavier_uniform"
  normalization: "tokens"

  dropout_steps: [0, 15000, 30000]
  dropout: [0.1, 0.1, 0.1]
  attention_dropout: [0.0, 0.0, 0.0]
  score_threshold: 0.65

  freeze_decoder: false
  freeze_encoder: false

model:
  architecture: "transformer"
  layers: 6
  hidden_size: 1024
  heads: 16
  transformer_ff: 4096
  add_qkvbias: false
  add_ffnbias: true
  mlp_activation_fn: gated-silu
  add_estimator: false
  share_decoder_embeddings: true
  share_embeddings: true
  layer_norm: standard
  norm_eps: 1e-6
  rope_config:
    rotary_interleave: false
  embeddings:
    word_vec_size: 1024
    position_encoding_type: "Rotary"
    freeze_word_vecs_dec: false
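
Note on streaming and `num_workers: 4` above: as the `load()` docstring recommends, each streamed dataset should expose at least as many shards as workers, ideally a multiple of the worker count. A quick way to check, assuming a recent `datasets` release where `IterableDataset.n_shards` is available:

```
from datasets import load_dataset

for name in [
    "eole-nlp/synth-mbr-decoded-sentlevel",
    "eole-nlp/synth-greedy-decoded-sentlevel",
    "eole-nlp/synth-qe-reranked-doclevel",
]:
    ds = load_dataset(name, split="train", streaming=True)
    # n_shards should be >= num_workers (ideally a multiple of it)
    # so that workers never contend for the same shard
    print(name, ds.n_shards)
```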
