
Commit

added option for tokenizer to split on special tokens (#176)
soldni authored Jul 13, 2024
1 parent 7318c44 commit ac91597
Showing 7 changed files with 200,612 additions and 5 deletions.
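
For orientation before the file-by-file diff: the new option forwards the `encode_special_tokens` switch of the Hugging Face `tokenizers` library (hence the version bump in pyproject.toml below). A minimal sketch of the behavior being toggled, assuming the loaded tokenizer registers `<s>` as a special token; the checkpoint name is illustrative only:

from tokenizers import Tokenizer

# Illustrative checkpoint; any fast tokenizer with registered special tokens
# (e.g. "<s>") behaves the same way.
tok = Tokenizer.from_pretrained("allenai/dolma2-tokenizer")
text = "hello <s> world"

# Default behavior: "<s>" is matched as a single special token and kept intact.
tok.encode_special_tokens = False
print(tok.encode(text, add_special_tokens=False).tokens)

# With the new option enabled: "<s>" is treated as ordinary text and split
# into regular sub-word pieces (e.g. "<", "s", ">").
tok.encode_special_tokens = True
print(tok.encode(text, add_special_tokens=False).tokens)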
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"rich",
"s3fs>=2023.6.0",
"smart-open",
"tokenizers>=0.15.0,<1.0.0",
"tokenizers>=0.19.1,<1.0.0",
"tqdm",
"uniseg",
"numpy",
5 changes: 5 additions & 0 deletions python/dolma/cli/tokenizer.py
@@ -46,6 +46,10 @@ class TokenizerConfig:
default=True,
help="Whether to use the fast tokenizer. If False, it requires the transformers library to be installed.",
)
encode_special_tokens: bool = field(
default=False,
help="Whether to encode special tokens in the tokenized output, e.g. splitting '<s>' into '<', 's', '>'.",
)

def __post__init__(self):
logger = get_logger(__file__)
@@ -208,6 +212,7 @@ def run(cls, parsed_config: TokenizationConfig):
eos_token_id=parsed_config.tokenizer.eos_token_id,
pad_token_id=parsed_config.tokenizer.pad_token_id,
segment_before_tokenization=parsed_config.tokenizer.segment_before_tokenization,
encode_special_tokens=parsed_config.tokenizer.encode_special_tokens,
dtype=parsed_config.dtype,
seed=parsed_config.seed,
metadata_dir=work_dirs.output,
5 changes: 5 additions & 0 deletions python/dolma/tokenizer/executor.py
@@ -86,6 +86,9 @@ def process_single(cls, source_path: str, destination_path: str, queue: QueueTyp
# flag to control whether to segment the documents before tokenization
tokenizer_kwargs["segment_before_tokenization"] = kwargs.pop("segment_before_tokenization", None) or False

# whether to split the special tokens into separate tokens, e.g. <s> -> < s >
tokenizer_kwargs["encode_special_tokens"] = kwargs.pop("encode_special_tokens", None) or False

# this is useful for making sure the queue does not grow too much
cpu_count = multiprocessing.cpu_count()

Expand Down Expand Up @@ -293,6 +296,7 @@ def tokenize_in_parallel(
eos_token_id: Optional[int] = 50279,
pad_token_id: Optional[int] = 1,
segment_before_tokenization: bool = False,
encode_special_tokens: bool = False,
seed: int = 3920,
metadata_dir: Optional[str] = None,
max_size: int = 1024 * 1024 * 1024,
@@ -371,6 +375,7 @@ def tokenize_in_parallel(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
segment_before_tokenization=segment_before_tokenization,
encode_special_tokens=encode_special_tokens,
tokenizer_name_or_path=tokenizer_name_or_path,
sample_ring_prop=sample_ring_prop,
use_fast_tokenizer=use_fast_tokenizer,
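
The `kwargs.pop(..., None) or False` idiom in process_single above normalizes both a missing key and an explicit None to False, while letting an explicit True through. A small standalone illustration:

from typing import Any, Dict

def pop_flag(kwargs: Dict[str, Any], name: str) -> bool:
    # Missing key -> None -> False; explicit None -> False; True stays True.
    return kwargs.pop(name, None) or False

assert pop_flag({}, "encode_special_tokens") is False
assert pop_flag({"encode_special_tokens": None}, "encode_special_tokens") is False
assert pop_flag({"encode_special_tokens": True}, "encode_special_tokens") is True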
19 changes: 15 additions & 4 deletions python/dolma/tokenizer/tokenizer.py
@@ -75,16 +75,14 @@ def __init__(
truncate_to: Optional[int] = None,
truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
segment_before_tokenization: bool = False,
encode_special_tokens: bool = False,
):
self.base_tokenizer = base_tokenizer
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.is_fast = isinstance(self.base_tokenizer, BaseTokenizer)

if self.is_fast:
self.base_tokenizer.no_truncation()

if self.pad_token_id is None:
logger.warning("No pad token ID provided; using EOS token ID %s.", eos_token_id)
self.pad_token_id = eos_token_id
@@ -95,6 +93,17 @@

self.config = self.get_base_tokenizer_config()
self.dtype = np.min_scalar_type(self.vocab_size - 1)
self.encode_special_tokens = encode_special_tokens

@property
def encode_special_tokens(self) -> bool:
return bool(getattr(self, "_encode_special_tokens", False))

@encode_special_tokens.setter
def encode_special_tokens(self, value: bool):
self._encode_special_tokens = value
if self.is_fast:
self.base_tokenizer.encode_special_tokens = value # pyright: ignore

@cached_property
def tokenizer_has_prefix(self) -> bool:
@@ -314,7 +323,9 @@ def encode_batch(
fast_batch = self.base_tokenizer.encode_batch(inputs, add_special_tokens=False)
batch_encoding = [e.ids for e in fast_batch]
else:
slow_batch = self.base_tokenizer(inputs, add_special_tokens=False) # pyright: ignore
slow_batch = self.base_tokenizer(
inputs, add_special_tokens=False, split_special_tokens=self.encode_special_tokens
) # pyright: ignore
batch_encoding = slow_batch.input_ids

all_input_ids = []
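
A rough sketch of how the two encode_batch branches above line up: the fast path relies on the `encode_special_tokens` property set through the wrapper, while the slow path passes `split_special_tokens` on each call. This assumes a checkpoint that ships both a fast and a slow tokenizer; names are illustrative:

from tokenizers import Tokenizer
from transformers import AutoTokenizer

texts = ["hello <s> world"]

# Fast path: the flag is a property on tokenizers.Tokenizer, set once.
fast = Tokenizer.from_pretrained("allenai/dolma2-tokenizer")
fast.encode_special_tokens = True
fast_ids = [enc.ids for enc in fast.encode_batch(texts, add_special_tokens=False)]

# Slow path: the flag is passed per call as `split_special_tokens`.
slow = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer", use_fast=False)
slow_ids = slow(texts, add_special_tokens=False, split_special_tokens=True).input_ids

# With the flag on, both paths tokenize "<s>" as ordinary text rather than
# emitting its single special-token id.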
27 changes: 27 additions & 0 deletions scripts/make_olmo2_tokenizer.py
@@ -0,0 +1,27 @@
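# Sanity check: verify that the Hugging Face "allenai/dolma2-tokenizer"
# reproduces tiktoken's gpt-4 encoding (and its round-trip decoding) on a
# 10k-document streaming sample of FineWeb.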
import tqdm
from datasets import load_dataset
import tiktoken
from transformers import GPT2TokenizerFast

hf_tokenizer = GPT2TokenizerFast.from_pretrained("allenai/dolma2-tokenizer")
og_tokenizer = tiktoken.encoding_for_model("gpt-4")

# dataset = load_dataset("xnli", "all_languages")
dataset = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", streaming=True)


cnt = 10_000
for item in tqdm.tqdm(dataset["train"]):
encoded1 = og_tokenizer.encode(item["text"])
encoded2 = hf_tokenizer.encode(item["text"])

assert encoded1 == encoded2, f'encoding "{item["text"]}" is incorrect. "{encoded1}" != "{encoded2}"'

decoded1 = og_tokenizer.decode(encoded1)
decoded2 = hf_tokenizer.decode(encoded2, skip_special_tokens=True)

assert decoded1 == decoded2, f'decoding "{item["text"]}" is incorrect. "{decoded1}" != "{decoded2}"'

cnt -= 1
if cnt == 0:
break