Add datapipe for masked lm
luozhouyang committed Nov 21, 2021
1 parent 8c74253 commit 1757dc6
Showing 10 changed files with 406 additions and 0 deletions.
4 changes: 4 additions & 0 deletions smile_datasets/__init__.py
@@ -1,6 +1,10 @@
import logging

from smile_datasets.dataset import Datapipe
from smile_datasets.mlm.dataset import DatapipeForMaskedLanguageModel, DatasetForMaskedLanguageModel
from smile_datasets.mlm.example import ExampleForMaskedLanguageModel
from smile_datasets.mlm.masking_strategy import WholeWordMask
from smile_datasets.mlm.parsers import ParserForMaskedLanguageModel
from smile_datasets.question_answering.dataset import DatapipeForQuestionAnswering, DatasetForQuestionAnswering
from smile_datasets.question_answering.example import ExampleForQuestionAnswering
from smile_datasets.question_answering.parsers import ParserForQuestionAnswering
13 changes: 13 additions & 0 deletions smile_datasets/dataset.py
@@ -28,14 +28,27 @@ class Datapipe(abc.ABC):

@classmethod
def from_tfrecord_files(cls, input_files, **kwargs) -> tf.data.Dataset:
"""Build tf.data.Dataset from tfrecord files."""
raise NotImplementedError()

@classmethod
def from_jsonl_files(cls, input_files, **kwargs) -> tf.data.Dataset:
"""Build tf.data.Dataset from jsonl files."""
raise NotImplementedError()

@classmethod
def from_instances(cls, instances, **kwargs) -> tf.data.Dataset:
"""Build tf.data.Dataset from json instances."""
raise NotImplementedError()

@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs) -> tf.data.Dataset:
"""Build tf.data.Dataset from instances of Dataset."""
raise NotImplementedError()

@classmethod
def from_examples(cls, examples, **kwargs) -> tf.data.Dataset:
"""Build tf.data.Dataset from list of examples."""
raise NotImplementedError()

def __call__(
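Every task-specific datapipe in this commit follows the same pattern: a classmethod builds a raw tf.data.Dataset, then an instance of the class is called on it to apply the shared transformations (filtering, dict conversion, padding and batching). A minimal sketch of that contract, assuming the constructor and __call__ accept the same **kwargs used throughout this commit (the subclass and data are illustrative, not part of the diff):

import tensorflow as tf

from smile_datasets.dataset import Datapipe


class DatapipeForToyTask(Datapipe):
    """Illustrative subclass: sequences of token ids, no labels."""

    @classmethod
    def from_examples(cls, examples, **kwargs) -> tf.data.Dataset:
        # examples: list of equal-length token-id lists, e.g. [[1, 2, 3], [4, 5, 6]]
        dataset = tf.data.Dataset.from_tensor_slices(tf.constant(examples, dtype=tf.int32))
        d = cls(**kwargs)
        return d(dataset, **kwargs)  # __call__ applies the shared transformations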
Empty file added smile_datasets/mlm/__init__.py
Empty file.
183 changes: 183 additions & 0 deletions smile_datasets/mlm/dataset.py
@@ -0,0 +1,183 @@
import abc
import logging
from typing import List

import tensorflow as tf
from smile_datasets import utils
from smile_datasets.dataset import Datapipe, Dataset
from smile_datasets.mlm.parsers import ParserForMaskedLanguageModel
from tokenizers import BertWordPieceTokenizer

from . import readers
from .example import ExampleForMaskedLanguageModel


class DatasetForMaskedLanguageModel(Dataset):
""" """

def __len__(self):
return super().__len__()

def __getitem__(self, index):
return super().__getitem__(index)

def save_tfrecord(self, output_files, **kwargs):
"""Convert examples to tfrecord"""

def _encode(example: ExampleForMaskedLanguageModel):
feature = {
"input_ids": utils.int64_feature([int(x) for x in example.input_ids]),
"segment_ids": utils.int64_feature([int(x) for x in example.segment_ids]),
"attention_mask": utils.int64_feature([int(x) for x in example.attention_mask]),
"masked_ids": utils.int64_feature([int(x) for x in example.masked_ids]),
"masked_pos": utils.int64_feature([int(x) for x in example.masked_pos]),
}
return feature

utils.save_tfrecord(iter(self), _encode, output_files, **kwargs)
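Any sequence-like subclass can be serialized this way; a minimal in-memory sketch (the subclass is illustrative, not part of the diff):

class InMemoryDatasetForMaskedLanguageModel(DatasetForMaskedLanguageModel):
    """Illustrative dataset backed by a list of parsed examples."""

    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index]


# dataset = InMemoryDatasetForMaskedLanguageModel(parsed_examples)
# dataset.save_tfrecord("mlm.tfrecord")  # output path is illustrative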


class DatapipeForMaskedLanguageModel(Datapipe):
    """Datapipe that builds tf.data.Dataset for masked language model."""

@classmethod
def from_tfrecord_files(cls, input_files, **kwargs) -> tf.data.Dataset:
dataset = utils.read_tfrecord_files(input_files, **kwargs)
# parse example
num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
features = {
"input_ids": tf.io.VarLenFeature(tf.int64),
"segment_ids": tf.io.VarLenFeature(tf.int64),
"attention_mask": tf.io.VarLenFeature(tf.int64),
"masked_ids": tf.io.VarLenFeature(tf.int64),
"masked_pos": tf.io.VarLenFeature(tf.int64),
}
dataset = dataset.map(
lambda x: tf.io.parse_example(x, features),
num_parallel_calls=num_parallel_calls,
).prefetch(buffer_size)
dataset = dataset.map(
lambda x: (
tf.cast(tf.sparse.to_dense(x["input_ids"]), tf.int32),
tf.cast(tf.sparse.to_dense(x["segment_ids"]), tf.int32),
tf.cast(tf.sparse.to_dense(x["attention_mask"]), tf.int32),
tf.cast(tf.sparse.to_dense(x["masked_ids"]), tf.int32),
tf.cast(tf.sparse.to_dense(x["masked_pos"]), tf.int32),
),
num_parallel_calls=num_parallel_calls,
).prefetch(buffer_size)
# do transformation
d = cls(**kwargs)
return d(dataset, **kwargs)

@classmethod
def from_jsonl_files(
cls, input_files, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs
) -> tf.data.Dataset:
instances = readers.read_jsonl_files(input_files, **kwargs)
return cls.from_instances(instances, tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)

@classmethod
def from_instances(cls, instances, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs) -> tf.data.Dataset:
        parser = ParserForMaskedLanguageModel(tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)
        # pop once before the loop; popping inside the loop would apply the
        # configured max_sequence_length to the first instance only
        max_sequence_length = kwargs.pop("max_sequence_length", 512)
        examples = []
        for instance in instances:
            if not instance:
                continue
            e = parser.parse(instance, max_sequence_length=max_sequence_length, **kwargs)
            if not e:
                continue
            examples.append(e)
return cls.from_examples(examples, **kwargs)

@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs) -> tf.data.Dataset:
        examples = [e for e in dataset if e]
return cls.from_examples(examples, **kwargs)

@classmethod
def from_examples(cls, examples: List[ExampleForMaskedLanguageModel], verbose=True, **kwargs) -> tf.data.Dataset:
"Parse examples to tf.data.Dataset"
if not examples:
logging.warning("examples is empty or null, skipped to build dataset.")
return None
if verbose:
n = min(5, len(examples))
for i in range(n):
logging.info("Showing NO.%d example: %s", i, examples[i])

        def _to_dataset(x, dtype=tf.int32):
            x = tf.ragged.constant(x, dtype=dtype)
            d = tf.data.Dataset.from_tensor_slices(x)
            # identity map so downstream ops see dense, variable-length tensors
            d = d.map(lambda x: x)
            return d

        # convert examples to dataset
dataset = tf.data.Dataset.zip(
(
_to_dataset([e.input_ids for e in examples], dtype=tf.int32),
_to_dataset([e.segment_ids for e in examples], dtype=tf.int32),
_to_dataset([e.attention_mask for e in examples], dtype=tf.int32),
_to_dataset([e.masked_ids for e in examples], dtype=tf.int32),
_to_dataset([e.masked_pos for e in examples], dtype=tf.int32),
)
)
# do transformation
d = cls(**kwargs)
return d(dataset, **kwargs)

def _filter(self, dataset: tf.data.Dataset, do_filter=True, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
if not do_filter:
return dataset
dataset = dataset.filter(lambda a, b, c, x, y: tf.size(a) <= max_sequence_length)
return dataset

def _to_dict(self, dataset: tf.data.Dataset, to_dict=True, **kwargs) -> tf.data.Dataset:
num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
if not to_dict:
dataset = dataset.map(
lambda a, b, c, x, y: ((a, b, c), (x, y)),
num_parallel_calls=num_parallel_calls,
).prefetch(buffer_size)
return dataset
dataset = dataset.map(
lambda a, b, c, x, y: ({"input_ids": a, "segment_ids": b, "attention_mask": c}, {"masked_ids": x, "masked_pos": y}),
num_parallel_calls=num_parallel_calls,
).prefetch(buffer_size)
return dataset

def _fixed_padding(self, dataset: tf.data.Dataset, pad_id=0, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
maxlen = tf.constant(max_sequence_length, dtype=tf.int32)
pad_id = tf.constant(pad_id, dtype=tf.int32)
# fmt: off
padded_shapes = kwargs.get("padded_shapes", ([maxlen, ], [maxlen, ], [maxlen, ], [maxlen, ], [maxlen]))
padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
# fmt: on
dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
return dataset

def _batch_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
pad_id = tf.constant(pad_id, dtype=tf.int32)
# fmt: off
padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
# fmt: on
dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
return dataset

def _bucket_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
pad_id = tf.constant(pad_id, dtype=tf.int32)
# fmt: off
padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
# fmt: on
dataset = utils.bucketing_and_padding(
dataset,
bucket_fn=lambda a, b, c, x, y: tf.size(a),
padded_shapes=padded_shapes,
padding_values=padding_values,
**kwargs,
)
return dataset
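Taken together, a minimal end-to-end sketch of this datapipe (paths, batch size and sequence length are illustrative; batch_size is assumed to be picked up from **kwargs by utils.batching_and_padding):

from smile_datasets import DatapipeForMaskedLanguageModel

# each line of train.jsonl is a JSON object with a "sequence" field
train_dataset = DatapipeForMaskedLanguageModel.from_jsonl_files(
    input_files=["train.jsonl"],  # illustrative path
    vocab_file="vocab.txt",       # BERT wordpiece vocab, illustrative path
    batch_size=32,
    max_sequence_length=128,
)
# with to_dict=True (the default), each batch is a pair of dicts:
# ({"input_ids", "segment_ids", "attention_mask"}, {"masked_ids", "masked_pos"})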
6 changes: 6 additions & 0 deletions smile_datasets/mlm/example.py
@@ -0,0 +1,6 @@
from collections import namedtuple

ExampleForMaskedLanguageModel = namedtuple(
"ExampleForMaskedLanguageModel",
["tokens", "input_ids", "segment_ids", "attention_mask", "masked_ids", "masked_pos"],
)
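For reference, a toy instance showing how the fields relate (token ids follow the standard BERT vocab and are illustrative):

example = ExampleForMaskedLanguageModel(
    tokens=["[CLS]", "hello", "[MASK]", "[SEP]"],  # tokens after masking
    input_ids=[101, 7592, 103, 102],               # ids of the masked tokens
    segment_ids=[0, 0, 0, 0],
    attention_mask=[1, 1, 1, 1],
    masked_ids=[101, 7592, 2088, 102],             # ids of the original tokens ("world" was masked)
    masked_pos=[0, 0, 1, 0],                       # 1 marks positions selected for prediction
)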
88 changes: 88 additions & 0 deletions smile_datasets/mlm/masking_strategy.py
@@ -0,0 +1,88 @@
import abc
import random
from collections import namedtuple

ResultForMasking = namedtuple("ResultForMasking", ["origin_tokens", "masked_tokens", "masked_indexes"])


class AbstractMaskingStrategy(abc.ABC):
"""Abstract masking strategy"""

@abc.abstractmethod
def __call__(self, tokens, **kwargs) -> ResultForMasking:
raise NotImplementedError()


class WholeWordMask(AbstractMaskingStrategy):
"""Default masking strategy from BERT."""

def __init__(self, vocabs, change_prob=0.15, mask_prob=0.8, rand_prob=0.1, keep_prob=0.1, max_predictions=20, **kwargs):
self.vocabs = vocabs
self.change_prob = change_prob
self.mask_prob = mask_prob / (mask_prob + rand_prob + keep_prob)
self.rand_prob = rand_prob / (mask_prob + rand_prob + keep_prob)
self.keep_prob = keep_prob / (mask_prob + rand_prob + keep_prob)
self.max_predictions = max_predictions

def __call__(self, tokens, max_sequence_length=512, **kwargs) -> ResultForMasking:
tokens = self._truncate_sequence(tokens, max_sequence_length - 2)
if not tokens:
return None
num_to_predict = min(self.max_predictions, max(1, round(self.change_prob * len(tokens))))
cand_indexes = self._collect_candidates(tokens)
# copy original tokens
masked_tokens = [x for x in tokens]
masked_indexes = [0] * len(tokens)
for piece_indexes in cand_indexes:
if sum(masked_indexes) >= num_to_predict:
break
if sum(masked_indexes) + len(piece_indexes) > num_to_predict:
continue
if any(masked_indexes[idx] == 1 for idx in piece_indexes):
continue
for index in piece_indexes:
masked_indexes[index] = 1
masked_tokens[index] = self._masking_tokens(index, tokens, self.vocabs)

# add special tokens
tokens = ["[CLS]"] + tokens + ["[SEP]"]
masked_tokens = ["[CLS]"] + masked_tokens + ["[SEP]"]
masked_indexes = [0] + masked_indexes + [0]
assert len(tokens) == len(masked_tokens) == len(masked_indexes)
return ResultForMasking(origin_tokens=tokens, masked_tokens=masked_tokens, masked_indexes=masked_indexes)

def _masking_tokens(self, index, tokens, vocabs, **kwargs):
# 80% of the time, replace with [MASK]
if random.random() < self.mask_prob:
return "[MASK]"
        # 10% of the time, keep the original token; the conditional probability
        # of keeping (given no [MASK]) must use keep_prob, not rand_prob
        p = self.keep_prob / (self.rand_prob + self.keep_prob)
        if random.random() < p:
return tokens[index]
# 10% of the time, replace with random word
masked_token = vocabs[random.randint(0, len(vocabs) - 1)]
return masked_token

def _collect_candidates(self, tokens):
        cand_indexes = []
for idx, token in enumerate(tokens):
if cand_indexes and token.startswith("##"):
cand_indexes[-1].append(idx)
continue
cand_indexes.append([idx])
random.shuffle(cand_indexes)
return cand_indexes

def _truncate_sequence(self, tokens, max_tokens=512, **kwargs):
while len(tokens) > max_tokens:
if len(tokens) > max_tokens:
tokens.pop(0)
                # truncate the whole word
while tokens and tokens[0].startswith("##"):
tokens.pop(0)
if len(tokens) > max_tokens:
while tokens and tokens[-1].startswith("##"):
tokens.pop()
if tokens:
tokens.pop()
return tokens
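A small usage sketch of the strategy (the tiny vocab and seed are illustrative; in practice vocabs comes from the tokenizer, as in the parser below):

import random

random.seed(0)  # only to make the sketch repeatable

vocabs = ["deep", "learning", "model", "##s", "hello"]  # illustrative vocab
masking = WholeWordMask(vocabs=vocabs, change_prob=0.15, max_predictions=20)
result = masking(tokens=["deep", "learning", "model", "##s"], max_sequence_length=512)
# result.origin_tokens  -> ["[CLS]", "deep", "learning", "model", "##s", "[SEP]"]
# result.masked_tokens  -> same list with one whole word masked, kept, or randomly replaced
# result.masked_indexes -> 1 at each position selected for prediction, 0 elsewhere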
38 changes: 38 additions & 0 deletions smile_datasets/mlm/parsers.py
@@ -0,0 +1,38 @@
from tokenizers import BertWordPieceTokenizer

from .example import ExampleForMaskedLanguageModel
from .masking_strategy import WholeWordMask


class ParserForMaskedLanguageModel:
    """Parse instances into ExampleForMaskedLanguageModel examples."""

@classmethod
def from_vocab_file(cls, vocab_file, **kwargs):
return cls(tokenizer=None, vocab_file=vocab_file, **kwargs)

@classmethod
def from_tokenizer(cls, tokenizer: BertWordPieceTokenizer, **kwargs):
return cls(tokenizer=tokenizer, vocab_file=None, **kwargs)

def __init__(self, tokenizer: BertWordPieceTokenizer, vocab_file=None, do_lower_case=True, **kwargs) -> None:
assert tokenizer or vocab_file, "`tokenizer` or `vocab_file` must be provided!"
self.tokenizer = tokenizer or BertWordPieceTokenizer.from_file(vocab_file, lowercase=do_lower_case)
self.vocabs = list(self.tokenizer.get_vocab().keys())
self.masking = WholeWordMask(vocabs=self.vocabs, **kwargs)

def parse(self, instance, max_sequence_length=512, **kwargs) -> ExampleForMaskedLanguageModel:
sequence = instance["sequence"]
        # set add_special_tokens=False here; the masking strategy adds the special tokens itself
        encoding = self.tokenizer.encode(sequence, add_special_tokens=False)
        masking_results = self.masking(tokens=encoding.tokens, max_sequence_length=max_sequence_length, **kwargs)
        if not masking_results:
            return None
        origin_tokens, masked_tokens = masking_results.origin_tokens, masking_results.masked_tokens
example = ExampleForMaskedLanguageModel(
tokens=masked_tokens,
input_ids=[self.tokenizer.token_to_id(x) for x in masked_tokens],
segment_ids=[0] * len(masked_tokens),
attention_mask=[1] * len(masked_tokens),
masked_ids=[self.tokenizer.token_to_id(x) for x in origin_tokens],
masked_pos=masking_results.masked_indexes,
)
return example
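And a minimal parsing sketch (the vocab path is illustrative):

parser = ParserForMaskedLanguageModel.from_vocab_file("vocab.txt", do_lower_case=True)
example = parser.parse({"sequence": "Smile datasets adds a masked LM datapipe."}, max_sequence_length=128)
# example.input_ids and example.masked_ids are aligned, with [CLS]/[SEP] included;
# example.masked_pos marks the positions the model should predict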
17 changes: 17 additions & 0 deletions smile_datasets/mlm/readers.py
@@ -0,0 +1,17 @@
import json


def read_jsonl_files(input_files, sequence_key="sequence", **kwargs):
if isinstance(input_files, str):
input_files = [input_files]
for f in input_files:
with open(f, mode="rt", encoding="utf-8") as fin:
for line in fin:
line = line.strip()
if not line:
continue
data = json.loads(line)
instance = {"sequence": data[sequence_key]}
yield instance
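The reader expects one JSON object per line; a sketch of the expected file and usage (the path is illustrative):

# train.jsonl:
#   {"sequence": "the first pretraining sentence"}
#   {"sequence": "another pretraining sentence"}

for instance in read_jsonl_files("train.jsonl"):
    print(instance)  # {"sequence": "..."}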
Empty file added tests/mlm_tests/__init__.py
Empty file.