From 1757dc6866c8dcb50e4f56f5bada6d26559d6c07 Mon Sep 17 00:00:00 2001
From: luozhouyang
Date: Sun, 21 Nov 2021 22:16:54 +0800
Subject: [PATCH] Add datapipe for masked lm

---
 smile_datasets/__init__.py             |   4 +
 smile_datasets/dataset.py              |  13 ++
 smile_datasets/mlm/__init__.py         |   0
 smile_datasets/mlm/dataset.py          | 183 +++++++++++++++++++++++++
 smile_datasets/mlm/example.py          |   6 +
 smile_datasets/mlm/masking_strategy.py |  88 ++++++++++++
 smile_datasets/mlm/parsers.py          |  38 +++++
 smile_datasets/mlm/readers.py          |  17 +++
 tests/mlm_tests/__init__.py            |   0
 tests/mlm_tests/dataset_test.py        |  57 ++++++++
 10 files changed, 406 insertions(+)
 create mode 100644 smile_datasets/mlm/__init__.py
 create mode 100644 smile_datasets/mlm/dataset.py
 create mode 100644 smile_datasets/mlm/example.py
 create mode 100644 smile_datasets/mlm/masking_strategy.py
 create mode 100644 smile_datasets/mlm/parsers.py
 create mode 100644 smile_datasets/mlm/readers.py
 create mode 100644 tests/mlm_tests/__init__.py
 create mode 100644 tests/mlm_tests/dataset_test.py

diff --git a/smile_datasets/__init__.py b/smile_datasets/__init__.py
index a82eb3f..cca7d1a 100644
--- a/smile_datasets/__init__.py
+++ b/smile_datasets/__init__.py
@@ -1,6 +1,10 @@
 import logging
 
 from smile_datasets.dataset import Datapipe
+from smile_datasets.mlm.dataset import DatapipeForMaksedLanguageModel, DatasetForMaskedLanguageModel
+from smile_datasets.mlm.example import ExampleForMaskedLanguageModel
+from smile_datasets.mlm.masking_strategy import WholeWordMask
+from smile_datasets.mlm.parsers import ParserForMasledLanguageModel
 from smile_datasets.question_answering.dataset import DatapipeForQuestionAnswering, DatasetForQuestionAnswering
 from smile_datasets.question_answering.example import ExampleForQuestionAnswering
 from smile_datasets.question_answering.parsers import ParserForQuestionAnswering
diff --git a/smile_datasets/dataset.py b/smile_datasets/dataset.py
index c74dc3a..e62cb38 100644
--- a/smile_datasets/dataset.py
+++ b/smile_datasets/dataset.py
@@ -28,14 +28,27 @@ class Datapipe(abc.ABC):
 
     @classmethod
     def from_tfrecord_files(cls, input_files, **kwargs) -> tf.data.Dataset:
+        """Build tf.data.Dataset from tfrecord files."""
+        raise NotImplementedError()
+
+    @classmethod
+    def from_jsonl_files(cls, input_files, **kwargs) -> tf.data.Dataset:
+        """Build tf.data.Dataset from jsonl files."""
+        raise NotImplementedError()
+
+    @classmethod
+    def from_instances(cls, instances, **kwargs) -> tf.data.Dataset:
+        """Build tf.data.Dataset from json instances."""
         raise NotImplementedError()
 
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs) -> tf.data.Dataset:
+        """Build tf.data.Dataset from instances of Dataset."""
         raise NotImplementedError()
 
     @classmethod
     def from_examples(cls, examples, **kwargs) -> tf.data.Dataset:
+        """Build tf.data.Dataset from list of examples."""
         raise NotImplementedError()
 
     def __call__(
diff --git a/smile_datasets/mlm/__init__.py b/smile_datasets/mlm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/smile_datasets/mlm/dataset.py b/smile_datasets/mlm/dataset.py
new file mode 100644
index 0000000..963cadc
--- /dev/null
+++ b/smile_datasets/mlm/dataset.py
@@ -0,0 +1,183 @@
+import abc
+import logging
+from typing import List
+
+import tensorflow as tf
+from smile_datasets import utils
+from smile_datasets.dataset import Datapipe, Dataset
+from smile_datasets.mlm.parsers import ParserForMasledLanguageModel
+from tokenizers import BertWordPieceTokenizer
+
+from . import readers
+from .example import ExampleForMaskedLanguageModel
+
+
+class DatasetForMaskedLanguageModel(Dataset):
+    """ """
+
+    def __len__(self):
+        return super().__len__()
+
+    def __getitem__(self, index):
+        return super().__getitem__(index)
+
+    def save_tfrecord(self, output_files, **kwargs):
+        """Convert examples to tfrecord"""
+
+        def _encode(example: ExampleForMaskedLanguageModel):
+            feature = {
+                "input_ids": utils.int64_feature([int(x) for x in example.input_ids]),
+                "segment_ids": utils.int64_feature([int(x) for x in example.segment_ids]),
+                "attention_mask": utils.int64_feature([int(x) for x in example.attention_mask]),
+                "masked_ids": utils.int64_feature([int(x) for x in example.masked_ids]),
+                "masked_pos": utils.int64_feature([int(x) for x in example.masked_pos]),
+            }
+            return feature
+
+        utils.save_tfrecord(iter(self), _encode, output_files, **kwargs)
+
+
+class DatapipeForMaksedLanguageModel(Datapipe):
+    """ """
+
+    @classmethod
+    def from_tfrecord_files(cls, input_files, **kwargs) -> tf.data.Dataset:
+        dataset = utils.read_tfrecord_files(input_files, **kwargs)
+        # parse example
+        num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
+        buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
+        features = {
+            "input_ids": tf.io.VarLenFeature(tf.int64),
+            "segment_ids": tf.io.VarLenFeature(tf.int64),
+            "attention_mask": tf.io.VarLenFeature(tf.int64),
+            "masked_ids": tf.io.VarLenFeature(tf.int64),
+            "masked_pos": tf.io.VarLenFeature(tf.int64),
+        }
+        dataset = dataset.map(
+            lambda x: tf.io.parse_example(x, features),
+            num_parallel_calls=num_parallel_calls,
+        ).prefetch(buffer_size)
+        dataset = dataset.map(
+            lambda x: (
+                tf.cast(tf.sparse.to_dense(x["input_ids"]), tf.int32),
+                tf.cast(tf.sparse.to_dense(x["segment_ids"]), tf.int32),
+                tf.cast(tf.sparse.to_dense(x["attention_mask"]), tf.int32),
+                tf.cast(tf.sparse.to_dense(x["masked_ids"]), tf.int32),
+                tf.cast(tf.sparse.to_dense(x["masked_pos"]), tf.int32),
+            ),
+            num_parallel_calls=num_parallel_calls,
+        ).prefetch(buffer_size)
+        # do transformation
+        d = cls(**kwargs)
+        return d(dataset, **kwargs)
+
+    @classmethod
+    def from_jsonl_files(
+        cls, input_files, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs
+    ) -> tf.data.Dataset:
+        instances = readers.read_jsonl_files(input_files, **kwargs)
+        return cls.from_instances(instances, tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)
+
+    @classmethod
+    def from_instances(cls, instances, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs) -> tf.data.Dataset:
+        parser = ParserForMasledLanguageModel(tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)
+        examples = []
+        for instance in instances:
+            if not instance:
+                continue
+            e = parser.parse(instance, **kwargs)
+            if not e:
+                continue
+            examples.append(e)
+        return cls.from_examples(examples, **kwargs)
+
+    @classmethod
+    def from_dataset(cls, dataset: Dataset, **kwargs) -> tf.data.Dataset:
+        examples = [e for _, e in enumerate(dataset) if e]
+        return cls.from_examples(examples, **kwargs)
+
+    @classmethod
+    def from_examples(cls, examples: List[ExampleForMaskedLanguageModel], verbose=True, **kwargs) -> tf.data.Dataset:
+        "Parse examples to tf.data.Dataset"
+        if not examples:
+            logging.warning("examples is empty or None, skipping dataset build.")
+            return None
+        if verbose:
+            n = min(5, len(examples))
+            for i in range(n):
+                logging.info("Showing NO.%d example: %s", i, examples[i])
+
+        def _to_dataset(x, dtype=tf.int32):
+            x = tf.ragged.constant(x, dtype=dtype)
+            d = tf.data.Dataset.from_tensor_slices(x)
+            d = d.map(lambda x: x)
+            return d
+
+        # convert examples to dataset
+        dataset = tf.data.Dataset.zip(
+            (
+                _to_dataset([e.input_ids for e in examples], dtype=tf.int32),
+                _to_dataset([e.segment_ids for e in examples], dtype=tf.int32),
+                _to_dataset([e.attention_mask for e in examples], dtype=tf.int32),
+                _to_dataset([e.masked_ids for e in examples], dtype=tf.int32),
+                _to_dataset([e.masked_pos for e in examples], dtype=tf.int32),
+            )
+        )
+        # do transformation
+        d = cls(**kwargs)
+        return d(dataset, **kwargs)
+
+    def _filter(self, dataset: tf.data.Dataset, do_filter=True, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
+        if not do_filter:
+            return dataset
+        dataset = dataset.filter(lambda a, b, c, x, y: tf.size(a) <= max_sequence_length)
+        return dataset
+
+    def _to_dict(self, dataset: tf.data.Dataset, to_dict=True, **kwargs) -> tf.data.Dataset:
+        num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
+        buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
+        if not to_dict:
+            dataset = dataset.map(
+                lambda a, b, c, x, y: ((a, b, c), (x, y)),
+                num_parallel_calls=num_parallel_calls,
+            ).prefetch(buffer_size)
+            return dataset
+        dataset = dataset.map(
+            lambda a, b, c, x, y: ({"input_ids": a, "segment_ids": b, "attention_mask": c}, {"masked_ids": x, "masked_pos": y}),
+            num_parallel_calls=num_parallel_calls,
+        ).prefetch(buffer_size)
+        return dataset
+
+    def _fixed_padding(self, dataset: tf.data.Dataset, pad_id=0, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
+        maxlen = tf.constant(max_sequence_length, dtype=tf.int32)
+        pad_id = tf.constant(pad_id, dtype=tf.int32)
+        # fmt: off
+        padded_shapes = kwargs.get("padded_shapes", ([maxlen, ], [maxlen, ], [maxlen, ], [maxlen, ], [maxlen, ]))
+        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
+        # fmt: on
+        dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
+        return dataset
+
+    def _batch_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
+        pad_id = tf.constant(pad_id, dtype=tf.int32)
+        # fmt: off
+        padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
+        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
+        # fmt: on
+        dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
+        return dataset
+
+    def _bucket_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
+        pad_id = tf.constant(pad_id, dtype=tf.int32)
+        # fmt: off
+        padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
+        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
+        # fmt: on
+        dataset = utils.bucketing_and_padding(
+            dataset,
+            bucket_fn=lambda a, b, c, x, y: tf.size(a),
+            padded_shapes=padded_shapes,
+            padding_values=padding_values,
+            **kwargs,
+        )
+        return dataset
diff --git a/smile_datasets/mlm/example.py b/smile_datasets/mlm/example.py
new file mode 100644
index 0000000..cabdbea
--- /dev/null
+++ b/smile_datasets/mlm/example.py
@@ -0,0 +1,6 @@
+from collections import namedtuple
+
+ExampleForMaskedLanguageModel = namedtuple(
+    "ExampleForMaskedLanguageModel",
+    ["tokens", "input_ids", "segment_ids", "attention_mask", "masked_ids", "masked_pos"],
+)
diff --git a/smile_datasets/mlm/masking_strategy.py b/smile_datasets/mlm/masking_strategy.py
new file mode 100644
index 0000000..e3d9bac
--- /dev/null
+++ b/smile_datasets/mlm/masking_strategy.py
@@ -0,0 +1,88 @@
+import abc
+import random
+from collections import namedtuple
+
+ResultForMasking = namedtuple("ResultForMasking", ["origin_tokens", "masked_tokens", "masked_indexes"])
+
+
+class AbstractMaskingStrategy(abc.ABC):
+    """Abstract masking strategy"""
+
+    @abc.abstractmethod
+    def __call__(self, tokens, **kwargs) -> ResultForMasking:
+        raise NotImplementedError()
+
+
+class WholeWordMask(AbstractMaskingStrategy):
+    """Default masking strategy from BERT."""
+
+    def __init__(self, vocabs, change_prob=0.15, mask_prob=0.8, rand_prob=0.1, keep_prob=0.1, max_predictions=20, **kwargs):
+        self.vocabs = vocabs
+        self.change_prob = change_prob
+        self.mask_prob = mask_prob / (mask_prob + rand_prob + keep_prob)
+        self.rand_prob = rand_prob / (mask_prob + rand_prob + keep_prob)
+        self.keep_prob = keep_prob / (mask_prob + rand_prob + keep_prob)
+        self.max_predictions = max_predictions
+
+    def __call__(self, tokens, max_sequence_length=512, **kwargs) -> ResultForMasking:
+        tokens = self._truncate_sequence(tokens, max_sequence_length - 2)
+        if not tokens:
+            return None
+        num_to_predict = min(self.max_predictions, max(1, round(self.change_prob * len(tokens))))
+        cand_indexes = self._collect_candidates(tokens)
+        # copy original tokens
+        masked_tokens = [x for x in tokens]
+        masked_indexes = [0] * len(tokens)
+        for piece_indexes in cand_indexes:
+            if sum(masked_indexes) >= num_to_predict:
+                break
+            if sum(masked_indexes) + len(piece_indexes) > num_to_predict:
+                continue
+            if any(masked_indexes[idx] == 1 for idx in piece_indexes):
+                continue
+            for index in piece_indexes:
+                masked_indexes[index] = 1
+                masked_tokens[index] = self._masking_tokens(index, tokens, self.vocabs)
+
+        # add special tokens
+        tokens = ["[CLS]"] + tokens + ["[SEP]"]
+        masked_tokens = ["[CLS]"] + masked_tokens + ["[SEP]"]
+        masked_indexes = [0] + masked_indexes + [0]
+        assert len(tokens) == len(masked_tokens) == len(masked_indexes)
+        return ResultForMasking(origin_tokens=tokens, masked_tokens=masked_tokens, masked_indexes=masked_indexes)
+
+    def _masking_tokens(self, index, tokens, vocabs, **kwargs):
+        # 80% of the time, replace with [MASK]
+        if random.random() < self.mask_prob:
+            return "[MASK]"
+        # 10% of the time, keep original
+        p = self.keep_prob / (self.rand_prob + self.keep_prob)
+        if random.random() < p:
+            return tokens[index]
+        # 10% of the time, replace with random word
+        masked_token = vocabs[random.randint(0, len(vocabs) - 1)]
+        return masked_token
+
+    def _collect_candidates(self, tokens):
+        cand_indexes = [[]]
+        for idx, token in enumerate(tokens):
+            if cand_indexes and token.startswith("##"):
+                cand_indexes[-1].append(idx)
+                continue
+            cand_indexes.append([idx])
+        random.shuffle(cand_indexes)
+        return cand_indexes
+
+    def _truncate_sequence(self, tokens, max_tokens=512, **kwargs):
+        while len(tokens) > max_tokens:
+            if len(tokens) > max_tokens:
+                tokens.pop(0)
+                # truncate whole word
+                while tokens and tokens[0].startswith("##"):
+                    tokens.pop(0)
+            if len(tokens) > max_tokens:
+                while tokens and tokens[-1].startswith("##"):
+                    tokens.pop()
+                if tokens:
+                    tokens.pop()
+        return tokens
diff --git a/smile_datasets/mlm/parsers.py b/smile_datasets/mlm/parsers.py
new file mode 100644
index 0000000..1bb565a
--- /dev/null
+++ b/smile_datasets/mlm/parsers.py
@@ -0,0 +1,38 @@
+from tokenizers import BertWordPieceTokenizer
+
+from .example import ExampleForMaskedLanguageModel
+from .masking_strategy import WholeWordMask
+
+
+class ParserForMasledLanguageModel:
+    """ """
+
+    @classmethod
+    def from_vocab_file(cls, vocab_file, **kwargs):
+        return cls(tokenizer=None, vocab_file=vocab_file, **kwargs)
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: BertWordPieceTokenizer, **kwargs):
+        return cls(tokenizer=tokenizer, vocab_file=None, **kwargs)
+
+    def __init__(self, tokenizer: BertWordPieceTokenizer, vocab_file=None, do_lower_case=True, **kwargs) -> None:
+        assert tokenizer or vocab_file, "`tokenizer` or `vocab_file` must be provided!"
+        self.tokenizer = tokenizer or BertWordPieceTokenizer.from_file(vocab_file, lowercase=do_lower_case)
+        self.vocabs = list(self.tokenizer.get_vocab().keys())
+        self.masking = WholeWordMask(vocabs=self.vocabs, **kwargs)
+
+    def parse(self, instance, max_sequence_length=512, **kwargs) -> ExampleForMaskedLanguageModel:
+        sequence = instance["sequence"]
+        # set add_special_tokens=False here; the masking strategy adds the special tokens itself
+        encoding = self.tokenizer.encode(sequence, add_special_tokens=False)
+        masking_results = self.masking(tokens=encoding.tokens, max_sequence_length=max_sequence_length, **kwargs)
+        origin_tokens, masked_tokens = masking_results.origin_tokens, masking_results.masked_tokens
+        example = ExampleForMaskedLanguageModel(
+            tokens=masked_tokens,
+            input_ids=[self.tokenizer.token_to_id(x) for x in masked_tokens],
+            segment_ids=[0] * len(masked_tokens),
+            attention_mask=[1] * len(masked_tokens),
+            masked_ids=[self.tokenizer.token_to_id(x) for x in origin_tokens],
+            masked_pos=masking_results.masked_indexes,
+        )
+        return example
diff --git a/smile_datasets/mlm/readers.py b/smile_datasets/mlm/readers.py
new file mode 100644
index 0000000..b044bad
--- /dev/null
+++ b/smile_datasets/mlm/readers.py
@@ -0,0 +1,17 @@
+import json
+import logging
+import os
+
+
+def read_jsonl_files(input_files, sequence_key="sequence", **kwargs):
+    if isinstance(input_files, str):
+        input_files = [input_files]
+    for f in input_files:
+        with open(f, mode="rt", encoding="utf-8") as fin:
+            for line in fin:
+                line = line.strip()
+                if not line:
+                    continue
+                data = json.loads(line)
+                instance = {"sequence": data[sequence_key]}
+                yield instance
diff --git a/tests/mlm_tests/__init__.py b/tests/mlm_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/mlm_tests/dataset_test.py b/tests/mlm_tests/dataset_test.py
new file mode 100644
index 0000000..d49038c
--- /dev/null
+++ b/tests/mlm_tests/dataset_test.py
@@ -0,0 +1,57 @@
+import unittest
+
+from smile_datasets.mlm import readers
+from smile_datasets.mlm.dataset import DatapipeForMaksedLanguageModel, DatasetForMaskedLanguageModel
+from smile_datasets.mlm.example import ExampleForMaskedLanguageModel
+from smile_datasets.mlm.parsers import ParserForMasledLanguageModel
+
+
+class MyDatasetForMaskedLanguageModel(DatasetForMaskedLanguageModel):
+    """ """
+
+    def __init__(self, input_files, vocab_file, **kwargs) -> None:
+        super().__init__()
+        examples = []
+        parser = ParserForMasledLanguageModel.from_vocab_file(vocab_file, **kwargs)
+        for instance in readers.read_jsonl_files(input_files, **kwargs):
+            e = parser.parse(instance, **kwargs)
+            if not e:
+                continue
+            examples.append(e)
+        self.examples = examples
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, index):
+        return self.examples[index]
+
+
+class DatasetTest(unittest.TestCase):
+    """ """
+
+    def test_dataset_from_jsonl_files(self):
+        d = MyDatasetForMaskedLanguageModel("testdata/mlm.jsonl", vocab_file="testdata/vocab.txt")
+        print("Showing examples:")
print("Showing examples:") + for _, e in enumerate(d): + print(e) + + print("Load datapipe from dataset:") + dataset = DatapipeForMaksedLanguageModel.from_dataset(d) + print(next(iter(dataset))) + + print("Save to tfrecord:") + d.save_tfrecord("testdata/mlm.tfrecord") + + print("Load datapipe from tfrecord:") + dataset = DatapipeForMaksedLanguageModel.from_tfrecord_files("testdata/mlm.tfrecord") + print(next(iter(dataset))) + + def test_datapipe_from_jsonl_files(self): + print() + dataset = DatapipeForMaksedLanguageModel.from_jsonl_files("testdata/mlm.jsonl", vocab_file="testdata/vocab.txt") + print(next(iter(dataset))) + + +if __name__ == "__main__": + unittest.main()