Commit 1757dc6 (1 parent: 8c74253)
Showing 10 changed files with 406 additions and 0 deletions.
Empty file.
@@ -0,0 +1,183 @@
import abc
import logging
from typing import List

import tensorflow as tf
from smile_datasets import utils
from smile_datasets.dataset import Datapipe, Dataset
from smile_datasets.mlm.parsers import ParserForMasledLanguageModel
from tokenizers import BertWordPieceTokenizer

from . import readers
from .example import ExampleForMaskedLanguageModel


class DatasetForMaskedLanguageModel(Dataset):
    """Dataset of examples for masked language modeling."""

    def __len__(self):
        return super().__len__()

    def __getitem__(self, index):
        return super().__getitem__(index)

    def save_tfrecord(self, output_files, **kwargs):
        """Convert examples to tfrecord files."""

        def _encode(example: ExampleForMaskedLanguageModel):
            feature = {
                "input_ids": utils.int64_feature([int(x) for x in example.input_ids]),
                "segment_ids": utils.int64_feature([int(x) for x in example.segment_ids]),
                "attention_mask": utils.int64_feature([int(x) for x in example.attention_mask]),
                "masked_ids": utils.int64_feature([int(x) for x in example.masked_ids]),
                "masked_pos": utils.int64_feature([int(x) for x in example.masked_pos]),
            }
            return feature

        utils.save_tfrecord(iter(self), _encode, output_files, **kwargs)


class DatapipeForMaksedLanguageModel(Datapipe):
    """Build tf.data.Dataset pipelines for masked language modeling."""

    @classmethod
    def from_tfrecord_files(cls, input_files, **kwargs) -> tf.data.Dataset:
        dataset = utils.read_tfrecord_files(input_files, **kwargs)
        # parse serialized examples
        num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
        buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
        features = {
            "input_ids": tf.io.VarLenFeature(tf.int64),
            "segment_ids": tf.io.VarLenFeature(tf.int64),
            "attention_mask": tf.io.VarLenFeature(tf.int64),
            "masked_ids": tf.io.VarLenFeature(tf.int64),
            "masked_pos": tf.io.VarLenFeature(tf.int64),
        }
        dataset = dataset.map(
            lambda x: tf.io.parse_example(x, features),
            num_parallel_calls=num_parallel_calls,
        ).prefetch(buffer_size)
        dataset = dataset.map(
            lambda x: (
                tf.cast(tf.sparse.to_dense(x["input_ids"]), tf.int32),
                tf.cast(tf.sparse.to_dense(x["segment_ids"]), tf.int32),
                tf.cast(tf.sparse.to_dense(x["attention_mask"]), tf.int32),
                tf.cast(tf.sparse.to_dense(x["masked_ids"]), tf.int32),
                tf.cast(tf.sparse.to_dense(x["masked_pos"]), tf.int32),
            ),
            num_parallel_calls=num_parallel_calls,
        ).prefetch(buffer_size)
        # do transformation
        d = cls(**kwargs)
        return d(dataset, **kwargs)

    @classmethod
    def from_jsonl_files(
        cls, input_files, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs
    ) -> tf.data.Dataset:
        instances = readers.read_jsonl_files(input_files, **kwargs)
        return cls.from_instances(instances, tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)

    @classmethod
    def from_instances(cls, instances, tokenizer: BertWordPieceTokenizer = None, vocab_file=None, **kwargs) -> tf.data.Dataset:
        parser = ParserForMasledLanguageModel(tokenizer=tokenizer, vocab_file=vocab_file, **kwargs)
        # pop once, before the loop, so every instance is parsed with the same value
        max_sequence_length = kwargs.pop("max_sequence_length", 512)
        examples = []
        for instance in instances:
            if not instance:
                continue
            e = parser.parse(instance, max_sequence_length=max_sequence_length, **kwargs)
            if not e:
                continue
            examples.append(e)
        return cls.from_examples(examples, **kwargs)

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs) -> tf.data.Dataset:
        examples = [e for _, e in enumerate(dataset) if e]
        return cls.from_examples(examples, **kwargs)

    @classmethod
    def from_examples(cls, examples: List[ExampleForMaskedLanguageModel], verbose=True, **kwargs) -> tf.data.Dataset:
        """Parse examples to a tf.data.Dataset."""
        if not examples:
            logging.warning("examples is empty or None, skipping dataset build.")
            return None
        if verbose:
            n = min(5, len(examples))
            for i in range(n):
                logging.info("Showing NO.%d example: %s", i, examples[i])

        def _to_dataset(x, dtype=tf.int32):
            x = tf.ragged.constant(x, dtype=dtype)
            d = tf.data.Dataset.from_tensor_slices(x)
            d = d.map(lambda x: x)
            return d

        # convert examples to dataset
        dataset = tf.data.Dataset.zip(
            (
                _to_dataset([e.input_ids for e in examples], dtype=tf.int32),
                _to_dataset([e.segment_ids for e in examples], dtype=tf.int32),
                _to_dataset([e.attention_mask for e in examples], dtype=tf.int32),
                _to_dataset([e.masked_ids for e in examples], dtype=tf.int32),
                _to_dataset([e.masked_pos for e in examples], dtype=tf.int32),
            )
        )
        # do transformation
        d = cls(**kwargs)
        return d(dataset, **kwargs)

    def _filter(self, dataset: tf.data.Dataset, do_filter=True, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
        if not do_filter:
            return dataset
        dataset = dataset.filter(lambda a, b, c, x, y: tf.size(a) <= max_sequence_length)
        return dataset

    def _to_dict(self, dataset: tf.data.Dataset, to_dict=True, **kwargs) -> tf.data.Dataset:
        num_parallel_calls = kwargs.get("num_parallel_calls", utils.AUTOTUNE)
        buffer_size = kwargs.get("buffer_size", utils.AUTOTUNE)
        if not to_dict:
            dataset = dataset.map(
                lambda a, b, c, x, y: ((a, b, c), (x, y)),
                num_parallel_calls=num_parallel_calls,
            ).prefetch(buffer_size)
            return dataset
        dataset = dataset.map(
            lambda a, b, c, x, y: ({"input_ids": a, "segment_ids": b, "attention_mask": c}, {"masked_ids": x, "masked_pos": y}),
            num_parallel_calls=num_parallel_calls,
        ).prefetch(buffer_size)
        return dataset

    def _fixed_padding(self, dataset: tf.data.Dataset, pad_id=0, max_sequence_length=512, **kwargs) -> tf.data.Dataset:
        maxlen = tf.constant(max_sequence_length, dtype=tf.int32)
        pad_id = tf.constant(pad_id, dtype=tf.int32)
        # fmt: off
        padded_shapes = kwargs.get("padded_shapes", ([maxlen, ], [maxlen, ], [maxlen, ], [maxlen, ], [maxlen, ]))
        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
        # fmt: on
        dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
        return dataset

    def _batch_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
        pad_id = tf.constant(pad_id, dtype=tf.int32)
        # fmt: off
        padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
        # fmt: on
        dataset = utils.batching_and_padding(dataset, padded_shapes, padding_values, **kwargs)
        return dataset

    def _bucket_padding(self, dataset: tf.data.Dataset, pad_id=0, **kwargs) -> tf.data.Dataset:
        pad_id = tf.constant(pad_id, dtype=tf.int32)
        # fmt: off
        padded_shapes = kwargs.get("padded_shapes", ([None, ], [None, ], [None, ], [None, ], [None, ]))
        padding_values = kwargs.get("padding_values", (pad_id, pad_id, pad_id, pad_id, pad_id))
        # fmt: on
        dataset = utils.bucketing_and_padding(
            dataset,
            bucket_fn=lambda a, b, c, x, y: tf.size(a),
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            **kwargs,
        )
        return dataset
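
A brief usage sketch for the datapipe above. The import path, the file paths, and the assumption that batching and padding options are applied by the Datapipe base class (not part of this diff) are unverified here.

from smile_datasets.mlm import DatapipeForMaksedLanguageModel  # import path assumed

train_ds = DatapipeForMaksedLanguageModel.from_jsonl_files(
    input_files=["data/train.jsonl"],  # hypothetical JSONL corpus, one {"sequence": ...} object per line
    vocab_file="data/vocab.txt",       # hypothetical BERT wordpiece vocab
    max_sequence_length=128,
)
# With to_dict=True (the default), each element is a pair of dicts:
# ({"input_ids", "segment_ids", "attention_mask"}, {"masked_ids", "masked_pos"}).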
@@ -0,0 +1,6 @@
from collections import namedtuple

ExampleForMaskedLanguageModel = namedtuple(
    "ExampleForMaskedLanguageModel",
    ["tokens", "input_ids", "segment_ids", "attention_mask", "masked_ids", "masked_pos"],
)
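
For reference, an illustrative toy record with hand-written values: input_ids are the ids of the masked token sequence, masked_ids the ids of the original tokens, and masked_pos flags masked positions with 1. The ids below are made up.

example = ExampleForMaskedLanguageModel(
    tokens=["[CLS]", "the", "[MASK]", "sat", "[SEP]"],  # masked token sequence
    input_ids=[101, 1996, 103, 2938, 102],              # toy ids for the masked tokens
    segment_ids=[0, 0, 0, 0, 0],
    attention_mask=[1, 1, 1, 1, 1],
    masked_ids=[101, 1996, 4937, 2938, 102],            # toy ids for the original tokens
    masked_pos=[0, 0, 1, 0, 0],                         # 1 marks a masked position
)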
@@ -0,0 +1,88 @@
import abc
import random
from collections import namedtuple

ResultForMasking = namedtuple("ResultForMasking", ["origin_tokens", "masked_tokens", "masked_indexes"])


class AbstractMaskingStrategy(abc.ABC):
    """Abstract masking strategy."""

    @abc.abstractmethod
    def __call__(self, tokens, **kwargs) -> ResultForMasking:
        raise NotImplementedError()


class WholeWordMask(AbstractMaskingStrategy):
    """Default whole-word masking strategy from BERT."""

    def __init__(self, vocabs, change_prob=0.15, mask_prob=0.8, rand_prob=0.1, keep_prob=0.1, max_predictions=20, **kwargs):
        self.vocabs = vocabs
        self.change_prob = change_prob
        # normalize mask/rand/keep probabilities so they sum to 1
        self.mask_prob = mask_prob / (mask_prob + rand_prob + keep_prob)
        self.rand_prob = rand_prob / (mask_prob + rand_prob + keep_prob)
        self.keep_prob = keep_prob / (mask_prob + rand_prob + keep_prob)
        self.max_predictions = max_predictions

    def __call__(self, tokens, max_sequence_length=512, **kwargs) -> ResultForMasking:
        # reserve two positions for [CLS] and [SEP]
        tokens = self._truncate_sequence(tokens, max_sequence_length - 2)
        if not tokens:
            return None
        num_to_predict = min(self.max_predictions, max(1, round(self.change_prob * len(tokens))))
        cand_indexes = self._collect_candidates(tokens)
        # copy original tokens
        masked_tokens = [x for x in tokens]
        masked_indexes = [0] * len(tokens)
        for piece_indexes in cand_indexes:
            if sum(masked_indexes) >= num_to_predict:
                break
            if sum(masked_indexes) + len(piece_indexes) > num_to_predict:
                continue
            if any(masked_indexes[idx] == 1 for idx in piece_indexes):
                continue
            for index in piece_indexes:
                masked_indexes[index] = 1
                masked_tokens[index] = self._masking_tokens(index, tokens, self.vocabs)

        # add special tokens
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        masked_tokens = ["[CLS]"] + masked_tokens + ["[SEP]"]
        masked_indexes = [0] + masked_indexes + [0]
        assert len(tokens) == len(masked_tokens) == len(masked_indexes)
        return ResultForMasking(origin_tokens=tokens, masked_tokens=masked_tokens, masked_indexes=masked_indexes)

    def _masking_tokens(self, index, tokens, vocabs, **kwargs):
        # 80% of the time, replace with [MASK]
        if random.random() < self.mask_prob:
            return "[MASK]"
        # 10% of the time, keep the original token
        # (keep_prob fraction of the remaining probability mass)
        p = self.keep_prob / (self.rand_prob + self.keep_prob)
        if random.random() < p:
            return tokens[index]
        # 10% of the time, replace with a random word
        masked_token = vocabs[random.randint(0, len(vocabs) - 1)]
        return masked_token

    def _collect_candidates(self, tokens):
        # group wordpiece indexes so that a whole word is masked together
        cand_indexes = [[]]
        for idx, token in enumerate(tokens):
            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(idx)
                continue
            cand_indexes.append([idx])
        random.shuffle(cand_indexes)
        return cand_indexes

    def _truncate_sequence(self, tokens, max_tokens=512, **kwargs):
        while len(tokens) > max_tokens:
            if len(tokens) > max_tokens:
                tokens.pop(0)
                # truncate the whole word at the front
                while tokens and tokens[0].startswith("##"):
                    tokens.pop(0)
            if len(tokens) > max_tokens:
                while tokens and tokens[-1].startswith("##"):
                    tokens.pop()
                if tokens:
                    tokens.pop()
        return tokens
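
A small, self-contained sketch of exercising the masking strategy above on a wordpiece token list; the toy vocabulary and the fixed seed are only for illustration.

import random

random.seed(42)  # only to make this sketch repeatable
toy_vocabs = ["the", "cat", "##s", "sat", "on", "mat"]  # toy vocabulary used for random replacement
masking = WholeWordMask(vocabs=toy_vocabs, change_prob=0.15)
result = masking(tokens=["the", "cat", "##s", "sat", "on", "the", "mat"])
# result.origin_tokens and result.masked_tokens both start with [CLS] and end with [SEP];
# result.masked_indexes marks each masked position with 1.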
@@ -0,0 +1,38 @@
from tokenizers import BertWordPieceTokenizer

from .example import ExampleForMaskedLanguageModel
from .masking_strategy import WholeWordMask


class ParserForMasledLanguageModel:
    """Parse raw instances into ExampleForMaskedLanguageModel records."""

    @classmethod
    def from_vocab_file(cls, vocab_file, **kwargs):
        return cls(tokenizer=None, vocab_file=vocab_file, **kwargs)

    @classmethod
    def from_tokenizer(cls, tokenizer: BertWordPieceTokenizer, **kwargs):
        return cls(tokenizer=tokenizer, vocab_file=None, **kwargs)

    def __init__(self, tokenizer: BertWordPieceTokenizer, vocab_file=None, do_lower_case=True, **kwargs) -> None:
        assert tokenizer or vocab_file, "`tokenizer` or `vocab_file` must be provided!"
        self.tokenizer = tokenizer or BertWordPieceTokenizer.from_file(vocab_file, lowercase=do_lower_case)
        self.vocabs = list(self.tokenizer.get_vocab().keys())
        self.masking = WholeWordMask(vocabs=self.vocabs, **kwargs)

    def parse(self, instance, max_sequence_length=512, **kwargs) -> ExampleForMaskedLanguageModel:
        sequence = instance["sequence"]
        # add_special_tokens=False here: the masking strategy adds [CLS] and [SEP] itself
        encoding = self.tokenizer.encode(sequence, add_special_tokens=False)
        masking_results = self.masking(tokens=encoding.tokens, max_sequence_length=max_sequence_length, **kwargs)
        if not masking_results:
            # the masking strategy returns None for empty token sequences
            return None
        origin_tokens, masked_tokens = masking_results.origin_tokens, masking_results.masked_tokens
        example = ExampleForMaskedLanguageModel(
            tokens=masked_tokens,
            input_ids=[self.tokenizer.token_to_id(x) for x in masked_tokens],
            segment_ids=[0] * len(masked_tokens),
            attention_mask=[1] * len(masked_tokens),
            masked_ids=[self.tokenizer.token_to_id(x) for x in origin_tokens],
            masked_pos=masking_results.masked_indexes,
        )
        return example
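
A hedged usage sketch for the parser above; the vocab file path is hypothetical, and extra keyword arguments such as change_prob are simply forwarded to WholeWordMask.

parser = ParserForMasledLanguageModel.from_vocab_file("data/vocab.txt", change_prob=0.15)  # hypothetical vocab path
example = parser.parse({"sequence": "The cat sat on the mat."}, max_sequence_length=128)
if example:
    print(example.tokens)      # wordpiece tokens with some positions replaced by [MASK]
    print(example.masked_pos)  # 1 marks the positions the model must predict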
@@ -0,0 +1,17 @@
import json
import logging
import os


def read_jsonl_files(input_files, sequence_key="sequence", **kwargs):
    if isinstance(input_files, str):
        input_files = [input_files]
    for f in input_files:
        with open(f, mode="rt", encoding="utf-8") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                instance = {"sequence": data[sequence_key]}
                yield instance
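
The reader expects one JSON object per line and keeps only the field named by sequence_key. A short sketch with a hypothetical corpus path:

# corpus.jsonl (hypothetical) contains lines such as:
#   {"sequence": "The cat sat on the mat."}
#   {"sequence": "Another pre-training sentence."}
for instance in read_jsonl_files("data/corpus.jsonl"):
    print(instance["sequence"])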
Empty file.