From 016cc50e06649ddf57a5534672171e6fa6dd7cff Mon Sep 17 00:00:00 2001
From: Erik Ziegler
Date: Wed, 28 Jul 2021 16:36:23 +0200
Subject: [PATCH] Add pre-training code

---
 training/README.md           |  26 ++++
 training/__init__.py         |   0
 training/data_preparation.py | 177 +++++++++++++++++++++++++
 training/example.ini         |  23 ++++
 training/load_config.py      |  40 ++++++
 training/losses.py           |  19 +++
 training/pretraining.py      | 246 +++++++++++++++++++++++++++++++++++
 training/requirements.txt    |   8 ++
 training/tokenization.py     | 136 +++++++++++++++++++
 9 files changed, 675 insertions(+)
 create mode 100644 training/README.md
 create mode 100644 training/__init__.py
 create mode 100644 training/data_preparation.py
 create mode 100644 training/example.ini
 create mode 100644 training/load_config.py
 create mode 100644 training/losses.py
 create mode 100644 training/pretraining.py
 create mode 100644 training/requirements.txt
 create mode 100644 training/tokenization.py

diff --git a/training/README.md b/training/README.md
new file mode 100644
index 0000000..f7c9a86
--- /dev/null
+++ b/training/README.md
@@ -0,0 +1,26 @@
+# FNet PyTorch
+
+## Pre-Training
+
+You can pre-train an FNet from a checkpoint or from scratch.
+
+Keep in mind that you can also always use the official implementation for training and convert the resulting checkpoint.
+
+### Setup
+
+1) Create a virtualenv and install the dependencies:
+
+```bash
+pip install -r training/requirements.txt
+```
+
+2) Copy the `example.ini` and configure it to your needs.
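+
+For example, to work from a copy named `myconfig.ini` in the repository root:
+
+```bash
+cp training/example.ini myconfig.ini
+```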
+
+### Start a pre-training
+
+Run the training from this repository's root directory:
+
+```bash
+python -m training.pretraining --config myconfig.ini
+```
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/data_preparation.py b/training/data_preparation.py
new file mode 100644
index 0000000..a458a52
--- /dev/null
+++ b/training/data_preparation.py
@@ -0,0 +1,177 @@
+"""
+Heavily inspired by https://github.com/google-research/google-research/blob/master/f_net/input_pipeline.py
+"""
+from typing import Iterator, TypedDict
+
+import tensorflow_datasets as tfds
+import numpy as np
+import torch
+
+from .tokenization import Tokenizer
+
+np.random.seed(0)
+
+
+class NSPData(TypedDict):
+    input_ids: torch.Tensor
+    type_ids: torch.Tensor
+    nsp_labels: int
+
+
+class PreTrainData(TypedDict):
+    input_ids: torch.Tensor
+    original_input_ids: torch.Tensor
+    type_ids: torch.Tensor
+    mlm_positions: torch.Tensor
+    mlm_ids: torch.Tensor
+    mlm_weights: torch.Tensor
+    nsp_labels: int
+
+
+def pretraining_data_gen(
+    tokenizer: Tokenizer,
+    batch_size: int,
+    max_seq_length: int,
+    device: torch.device,
+    max_predictions_per_seq=80,
+    masking_rate=0.15,
+    mask_token_proportion=0.8,
+    random_token_proportion=0.1
+) -> Iterator[PreTrainData]:
+    ignore_ids = [tokenizer.cls_id, tokenizer.sep_id, tokenizer.pad_id]
+    ignore_ids = torch.LongTensor(ignore_ids)[:, None]
+
+    normal_tokens = [t for t in range(tokenizer.vocab_size) if t not in tokenizer.special_tokens()]
+
+    gen = _nsp_data_gen(tokenizer, max_seq_length)
+
+    samples = []
+    for sample in gen:
+        sample: PreTrainData = sample
+        num_tokens = torch.sum(sample["input_ids"] != tokenizer.pad_id).item()
+        prediction_mask = torch.all(sample["input_ids"] != ignore_ids, dim=0)
+        cand_indices = torch.arange(prediction_mask.shape[0], dtype=torch.long)[prediction_mask]
+        num_to_predict = min(max_predictions_per_seq, max(1, int(num_tokens * masking_rate)))
+
+        if len(cand_indices) == 0:
+            continue
+
+        mlm_positions = torch.LongTensor(np.sort(np.random.choice(cand_indices, num_to_predict, replace=False)))
+        mlm_ids = sample["input_ids"][mlm_positions]
+        mlm_weights = torch.ones(num_to_predict, dtype=torch.float32)
+
+        # Mask out tokens: 80% [MASK], 10% random token, 10% left unchanged
+        for position in mlm_positions:
+            rand = np.random.random()
+            if rand < mask_token_proportion:
+                replace_token_id = tokenizer.mask_id
+            elif rand < mask_token_proportion + random_token_proportion:
+                replace_token_id = np.random.choice(normal_tokens, 1).item()
+            else:
+                replace_token_id = sample["input_ids"][position]
+            sample["input_ids"][position] = replace_token_id
+
+        # Pad the prediction slots out to max_predictions_per_seq
+        mlm_positions_out = torch.zeros(max_predictions_per_seq, dtype=torch.long)
+        mlm_ids_out = torch.zeros(max_predictions_per_seq, dtype=torch.long)
+        mlm_weights_out = torch.zeros(max_predictions_per_seq, dtype=torch.float32)
+
+        mlm_weights_out[:num_to_predict] = mlm_weights
+        mlm_positions_out[:num_to_predict] = mlm_positions
+        mlm_ids_out[:num_to_predict] = mlm_ids
+
+        sample["mlm_positions"] = mlm_positions_out
+        sample["mlm_ids"] = mlm_ids_out
+        sample["mlm_weights"] = mlm_weights_out
+
+        samples.append(sample)
+
+        if len(samples) == batch_size:
+            yield samples_to_batch(samples, device)
+            samples = []
+
+
+def _nsp_data_gen(
+    tokenizer: Tokenizer,
+    max_seq_length: int
+) -> Iterator[NSPData]:
+    ds = tfds.load(name='wikipedia/20201201.en', split="train", shuffle_files=True)
+    ds = ds.repeat()
+    ds = ds.shuffle(1024)
+    ds = ds.batch(16)
+
+    for batch in tfds.as_numpy(ds):
+        for text in batch["text"]:
+            text = str(text, "utf-8")
+            lines = [tokenizer.encode_as_ids(line) for line in text.splitlines()]
+            j = 0
+            while j < len(lines) - 1:
+                if len(lines[j]) + len(lines[j + 1]) > max_seq_length - 3:
+                    j += 1
+                    continue
+
+                input_ids = torch.full((max_seq_length,), tokenizer.pad_id, dtype=torch.long)
+                type_ids = torch.full((max_seq_length,), 1, dtype=torch.long)
+
+                selected_lines = concat_lines_until_max(lines[j:], max_seq_length)
+                j += len(selected_lines)
+
+                pivot = np.random.randint(1, len(selected_lines))
+                datum = [tokenizer.cls_id]
+
+                # With probability 0.5 keep the original order (label 0),
+                # otherwise swap the two halves (label 1).
+                if np.random.random() < 0.5:
+                    for tokens in selected_lines[:pivot]:
+                        datum.extend(tokens)
+                    datum.append(tokenizer.sep_id)
+                    type_ids[:len(datum)] = 0
+                    for tokens in selected_lines[pivot:]:
+                        datum.extend(tokens)
+                    datum.append(tokenizer.sep_id)
+                    next_sentence_label = 0
+                    type_ids[len(datum):] = 0
+                else:
+                    for tokens in selected_lines[pivot:]:
+                        datum.extend(tokens)
+                    datum.append(tokenizer.sep_id)
+                    type_ids[:len(datum)] = 0
+                    for tokens in selected_lines[:pivot]:
+                        datum.extend(tokens)
+                    datum.append(tokenizer.sep_id)
+                    next_sentence_label = 1
+                    type_ids[len(datum):] = 0
+
+                input_ids[:len(datum)] = torch.LongTensor(datum)
+
+                yield {
+                    "input_ids": input_ids,
+                    "type_ids": type_ids,
+                    "nsp_labels": next_sentence_label,
+                }
+
+
+def concat_lines_until_max(lines, max_len):
+    cum_len = 0
+    k = 0
+    for k in range(len(lines)):
+        cum_len += len(lines[k])
+        if cum_len > max_len - 3:
+            k -= 1
+            break
+    return lines[:k + 1]
+
+
+def samples_to_batch(samples, device):
+    batch_size = len(samples)
+    batch = {}
+    for key in samples[0].keys():
+        value = samples[0][key]
+        if isinstance(value, torch.Tensor):
+            batch[key] = torch.zeros((batch_size, value.shape[0]), dtype=value.dtype).to(device)
+        else:
+            batch[key] = torch.zeros(batch_size, dtype=(torch.long if isinstance(value, int) else torch.float32)).to(device)
+    for i, sample in enumerate(samples):
+        for key in batch.keys():
+            batch[key][i] = sample[key]
+    return batch
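+
+
+if __name__ == "__main__":
+    # Minimal smoke test for the pipeline (a sketch: assumes a sentencepiece
+    # model at `spm.model`; adjust the path, then run
+    # `python -m training.data_preparation`).
+    from .tokenization import SentencePieceTokenizer
+
+    tokenizer = SentencePieceTokenizer("spm.model", max_seq_length=128)
+    gen = pretraining_data_gen(tokenizer, batch_size=2, max_seq_length=128,
+                               device=torch.device("cpu"))
+    batch = next(gen)
+    for key, value in batch.items():
+        print(key, tuple(value.shape), value.dtype)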
diff --git a/training/example.ini b/training/example.ini
new file mode 100644
index 0000000..826a4cb
--- /dev/null
+++ b/training/example.ini
@@ -0,0 +1,23 @@
+[general]
+experiment_name=my_experiment
+gpu_id=0
+
+[model]
+fnet_config=
+fnet_checkpoint=
+
+[tokenizer]
+# sentencepiece, wordpiece or huggingface
+type=sentencepiece
+# path to .model file for sentencepiece
+# path to vocab file for wordpiece
+vocab=
+# name of tokenizer (only for huggingface type)
+hf_name=
+
+[training]
+learning_rate=1e-4
+train_batch_size=64
+eval_batch_size=64
+eval_frequency=500
+eval_steps=1000
diff --git a/training/load_config.py b/training/load_config.py
new file mode 100644
index 0000000..4458bfc
--- /dev/null
+++ b/training/load_config.py
@@ -0,0 +1,40 @@
+import os
+from configparser import ConfigParser
+
+from tabulate import tabulate
+
+
+def load_config(config_path: str):
+    if not os.path.exists(config_path):
+        raise Exception('configuration file {} does not exist'.format(config_path))
+
+    configparser = ConfigParser()
+    configparser.read(config_path)
+
+    config = {}
+
+    # == general ==
+    config['experiment_name'] = configparser.get('general', 'experiment_name', fallback='unnamed')
+    config['gpu_id'] = configparser.getint('general', 'gpu_id', fallback=-1)
+
+    # == model ==
+    config['fnet_config'] = configparser.get('model', 'fnet_config')
+    config['fnet_checkpoint'] = configparser.get('model', 'fnet_checkpoint')
+
+    # == training ==
+    config['learning_rate'] = configparser.getfloat('training', 'learning_rate')
+    config['train_batch_size'] = configparser.getint('training', 'train_batch_size')
+    config['eval_batch_size'] = configparser.getint('training', 'eval_batch_size')
+    config['eval_frequency'] = configparser.getint('training', 'eval_frequency')
+    config['eval_steps'] = configparser.getint('training', 'eval_steps')
+
+    # == tokenizer ==
+    config['tokenizer'] = {}
+    config['tokenizer']['type'] = configparser.get('tokenizer', 'type')
+    config['tokenizer']['vocab'] = configparser.get('tokenizer', 'vocab')
+    config['tokenizer']['hf_name'] = configparser.get('tokenizer', 'hf_name')
+
+    return config
+
+
+def print_config(config):
+    print(tabulate(config.items(), tablefmt='grid'))
diff --git a/training/losses.py b/training/losses.py
new file mode 100644
index 0000000..fb008a1
--- /dev/null
+++ b/training/losses.py
@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+
+
+class MLMWeightedCELoss(nn.Module):
+    def __init__(self):
+        super(MLMWeightedCELoss, self).__init__()
+        self.log_softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, logits, targets, weights):
+        # logits: (num_predictions, vocab), targets: one-hot (num_predictions, vocab),
+        # weights: (num_predictions,), 0.0 for padded prediction slots
+        log_probs = self.log_softmax(logits)
+        loss = -torch.sum(targets * log_probs, dim=-1) * weights
+        loss = loss.sum() / weights.sum()
+        return loss
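+
+
+# Quick sanity check (a sketch with dummy shapes; vocab size 100 is arbitrary):
+#
+#   logits = torch.randn(8, 100)                       # (num_predictions, vocab)
+#   targets = torch.nn.functional.one_hot(torch.randint(0, 100, (8,)), 100)
+#   weights = torch.ones(8)                            # 0.0 would mark padded slots
+#   loss = MLMWeightedCELoss()(logits, targets, weights)
+#
+# Slots with weight 0 contribute nothing; the sum is normalized by weights.sum().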
diff --git a/training/pretraining.py b/training/pretraining.py
new file mode 100644
index 0000000..a2e6b0b
--- /dev/null
+++ b/training/pretraining.py
@@ -0,0 +1,246 @@
+import logging
+import os
+import json
+from argparse import ArgumentParser
+from datetime import datetime
+from time import time
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from fnet import FNetForPreTraining
+from tabulate import tabulate
+from torch.optim import Adam
+from torch.nn import CrossEntropyLoss
+from torch.utils.tensorboard import SummaryWriter
+from warmup_scheduler import GradualWarmupScheduler
+
+from .data_preparation import pretraining_data_gen
+from .losses import MLMWeightedCELoss
+from .tokenization import get_tokenizer
+from .load_config import load_config, print_config
+
+report_frequency = 100
+analyze_prediction_frequency = 200
+num_warmup_samples = 2_560_000
+
+logging.basicConfig(level=logging.INFO)
+
+
+def pretraining(config: Dict):
+    # gpu_id falls back to -1 in load_config, which means CPU here
+    device = torch.device('cpu') if config['gpu_id'] < 0 else torch.device(f"cuda:{config['gpu_id']}")
+
+    logging.info(f"Loading FNet config {config['fnet_config']}")
+    with open(config['fnet_config']) as f:
+        fnet_config = json.load(f)
+
+    max_seq_len = fnet_config['max_position_embeddings']
+
+    model = FNetForPreTraining(fnet_config)
+
+    if config['fnet_checkpoint']:
+        logging.info(f"Loading FNet pre-training checkpoint {config['fnet_checkpoint']}")
+        state_dict = torch.load(config['fnet_checkpoint'], map_location=torch.device('cpu'))
+        model.load_state_dict(state_dict)
+
+    tokenizer = get_tokenizer(config['tokenizer'], max_seq_len)
+
+    optimizer = Adam(model.parameters(), lr=config['learning_rate'], eps=1e-6, weight_decay=0.01)
+    warmup_steps = int(num_warmup_samples / config['train_batch_size'])
+    warmup_scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_steps)
+
+    logging.info(f'Scheduled {warmup_steps} warm-up steps ({num_warmup_samples} samples)')
+
+    mlm_criterion = MLMWeightedCELoss()
+    nsp_criterion = CrossEntropyLoss()
+
+    log_dir = 'experiments'
+    date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
+    experiment_name = f"{date_str}_{config['experiment_name']}"
+    train_writer = SummaryWriter(os.path.join(log_dir, experiment_name, 'train'))
+    eval_writer = SummaryWriter(os.path.join(log_dir, experiment_name, 'eval'))
+
+    train_gen = pretraining_data_gen(tokenizer, config['train_batch_size'], max_seq_len, device)
+    eval_gen = pretraining_data_gen(tokenizer, config['eval_batch_size'], max_seq_len, device)
+
+    model.to(device)
+    model.train()
+
+    step = 0
+
+    logging.info('Starting training')
+
+    for batch in train_gen:
+        step_start = time()
+        optimizer.zero_grad()
+
+        pred = model(
+            input_ids=batch['input_ids'],
+            type_ids=batch['type_ids'],
+            mlm_positions=batch['mlm_positions']
+        )
+
+        losses = get_loss(batch, pred, mlm_criterion, nsp_criterion)
+        loss = losses['loss']
+        loss.backward()
+        optimizer.step()
+
+        train_writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], step)
+        warmup_scheduler.step()
+
+        for name, value in losses.items():
+            train_writer.add_scalar(f'Loss/{name}', value.item(), step)
+
+        step_end = time()
+        duration = step_end - step_start
+
+        train_writer.add_scalar('StepDuration', duration, step)
+        train_writer.add_scalar('StepsPerSecond', 1 / duration, step)
+
+        if step % report_frequency == 0 and step > 0:
+            logging.info(f'Step {step}, Loss {loss.item()}')
+
+        if step % analyze_prediction_frequency == 0 and step > 0:
+            analyze_prediction(batch, pred, tokenizer)
+
+        if step % config['eval_frequency'] == 0 and step > 0:
+            evaluate(model, eval_writer, eval_gen, config, mlm_criterion, nsp_criterion, step, tokenizer)
+            export_checkpoint(model, step, experiment_name)
+
+        step += 1
+
+
+def get_loss(batch, pred, mlm_criterion, nsp_criterion):
+    mlm_loss = mlm_criterion(
+        pred['mlm_logits'].flatten(0, 1),
+        F.one_hot(batch['mlm_ids'].ravel(), pred['mlm_logits'].shape[-1]),
+        batch['mlm_weights'].ravel()
+    )
+
+    nsp_loss = nsp_criterion(pred['nsp_logits'], batch['nsp_labels'])
+    loss = mlm_loss + nsp_loss
+
+    return {
+        'mlm_loss': mlm_loss,
+        'nsp_loss': nsp_loss,
+        'loss': loss
+    }
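+
+
+# Shape walkthrough for get_loss (illustration):
+#   mlm_logits (batch, preds, vocab) --flatten(0, 1)--> (batch * preds, vocab)
+#   mlm_ids    (batch, preds)        --ravel()--------> (batch * preds,)
+# F.one_hot turns the flattened ids into targets that line up row-for-row with
+# the flattened logits, and mlm_weights zeroes out padded prediction slots.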
+
+
+def export_checkpoint(model, step, experiment_name):
+    name = f'pretraining_model_step_{step}.statedict.pt'
+    checkpoints_dir = os.path.join('exports', experiment_name)
+    if not os.path.exists(checkpoints_dir):
+        os.makedirs(checkpoints_dir)
+    torch.save(model.state_dict(), os.path.join(checkpoints_dir, name))
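+
+
+# The exported state dict can later be passed back in via `fnet_checkpoint` in
+# the config; restoring it is the usual PyTorch pattern (sketch, placeholder path):
+#
+#   state_dict = torch.load('exports/<experiment>/pretraining_model_step_500.statedict.pt',
+#                           map_location='cpu')
+#   model.load_state_dict(state_dict)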
+
+
+def evaluate(model, eval_writer, eval_gen, config, mlm_criterion, nsp_criterion, train_step, tokenizer):
+    logging.info('Running evaluation')
+    model.eval()
+    with torch.no_grad():
+        eval_step = 0
+        losses, mlm_losses, nsp_losses = 0, 0, 0
+        mlm_hits, nsp_hits = 0, 0
+        mlm_total, nsp_total = 0, 0
+
+        for batch in eval_gen:
+            pred = model(
+                input_ids=batch['input_ids'],
+                type_ids=batch['type_ids'],
+                mlm_positions=batch['mlm_positions']
+            )
+            metrics = get_loss(batch, pred, mlm_criterion, nsp_criterion)
+            hits = get_hits(batch, pred)
+
+            mlm_hits += hits['mlm_hits']
+            nsp_hits += hits['nsp_hits']
+            mlm_total += hits['mlm_total']
+            nsp_total += hits['nsp_total']
+
+            losses += metrics['loss']
+            mlm_losses += metrics['mlm_loss']
+            nsp_losses += metrics['nsp_loss']
+
+            eval_step += 1
+
+            if eval_step % report_frequency == 0:
+                logging.info(f'Eval step {eval_step}')
+
+            if eval_step % analyze_prediction_frequency == 0:
+                analyze_prediction(batch, pred, tokenizer)
+
+            if eval_step >= config['eval_steps']:
+                break
+
+        losses = (('MLM', mlm_losses / eval_step), ('NSP', nsp_losses / eval_step), ('Total', losses / eval_step))
+        averages = (('MLM', mlm_hits / mlm_total), ('NSP', nsp_hits / nsp_total))
+
+        logging.info(f'Losses: {losses}')
+        logging.info(f'Averages: {averages}')
+
+        if eval_writer:
+            for name, value in losses:
+                eval_writer.add_scalar(f'Loss/{name}', value, train_step)
+            for name, value in averages:
+                eval_writer.add_scalar(f'Accuracy/{name}', value, train_step)
+
+    model.train()
+
+
+def get_hits(batch, pred):
+    predicted_mlm_ids = pred['mlm_logits'].flatten(0, 1).argmax(-1)
+    mlm_hits = torch.sum((predicted_mlm_ids == batch['mlm_ids'].ravel()) * batch['mlm_weights'].ravel()).item()
+    nsp_hits = torch.sum(pred['nsp_logits'].argmax(-1) == batch['nsp_labels']).item()
+
+    return {
+        'mlm_hits': mlm_hits,
+        'mlm_total': batch['mlm_weights'].ravel().sum().item(),
+        'nsp_hits': nsp_hits,
+        'nsp_total': batch['nsp_labels'].shape[0]
+    }
+
+
+def analyze_prediction(batch, pred, tokenizer):
+    logging.info('Printing qualitative result')
+
+    input_ids = batch['input_ids'][0]
+    mlm_positions = batch['mlm_positions'][0]
+    non_null_positions = mlm_positions[mlm_positions != 0]
+    correct_ids = input_ids.detach().clone()
+    correct_ids[non_null_positions] = batch['mlm_ids'][0][:len(non_null_positions)]
+    predicted_ids = input_ids.detach().clone()
+    predicted_ids[non_null_positions] = pred['mlm_logits'][0].argmax(-1)[:len(non_null_positions)]
+
+    print('Input text', '\n', tokenizer.decode(input_ids.cpu().numpy().tolist()), '\n')
+    print('Correct text', '\n', tokenizer.decode(correct_ids.cpu().numpy().tolist()), '\n')
+    print('Predicted text', '\n', tokenizer.decode(predicted_ids.cpu().numpy().tolist()), '\n')
+    print('Type Ids', '\n', batch['type_ids'][0], '\n')
+
+    print(tabulate([(
+        tokenizer.decode([input_ids[idx].item()]),
+        tokenizer.decode([correct_ids[idx].item()]),
+        tokenizer.decode([predicted_ids[idx].item()]),
+        'X' if correct_ids[idx] == predicted_ids[idx] else '',
+    ) for idx in non_null_positions], tablefmt='grid', headers=('input', 'correct', 'predicted', 'hit')))
+
+    hits = torch.sum(correct_ids[non_null_positions] == predicted_ids[non_null_positions]).item()
+    accuracy = hits / len(non_null_positions)
+
+    print(f'{hits} / {len(non_null_positions)} Hits ({accuracy:.2f} accuracy)')
+
+    print('NSP Truth:', batch['nsp_labels'][0].item())
+    print('NSP Pred:', pred['nsp_logits'][0].argmax(-1).item())
+
+    print()
+
+
+if __name__ == '__main__':
+    argparser = ArgumentParser()
+    argparser.add_argument('--config', required=True, help='Path to config file')
+    args, _ = argparser.parse_known_args()
+
+    config = load_config(args.config)
+    print_config(config)
+
+    pretraining(config)
diff --git a/training/requirements.txt b/training/requirements.txt
new file mode 100644
index 0000000..243b8e2
--- /dev/null
+++ b/training/requirements.txt
@@ -0,0 +1,8 @@
+torch>=1.9.0
+tabulate>=0.8.9
+tensorflow_datasets>=4.0.0
+tokenizers>=0.10.3
+numpy>=1.18.5
+sentencepiece>=0.1.91
+# tensorflow backs tensorflow_datasets; transformers provides AutoTokenizer
+tensorflow>=2.3.0
+transformers>=4.0.0
+git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
+git+https://github.com/erksch/fnet-pytorch.git
\ No newline at end of file
diff --git a/training/tokenization.py b/training/tokenization.py
new file mode 100644
index 0000000..23ff738
--- /dev/null
+++ b/training/tokenization.py
@@ -0,0 +1,136 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, TypedDict
+
+import torch
+import sentencepiece as spm
+from tokenizers import Tokenizer as HFTokenizer, BertWordPieceTokenizer
+from transformers import AutoTokenizer
+
+
+class EncodedText(TypedDict):
+    input_ids: torch.Tensor
+    type_ids: torch.Tensor
+
+
+class Tokenizer(ABC):
+    pad_id: int
+    sep_id: int
+    mask_id: int
+    cls_id: int
+    vocab_size: int
+
+    def special_tokens(self) -> List[int]:
+        return [self.mask_id, self.cls_id, self.sep_id, self.pad_id]
+
+    @abstractmethod
+    def encode_as_ids(self, text: str) -> List[int]:
+        pass
+
+    @abstractmethod
+    def decode(self, ids: List[int]) -> str:
+        pass
+
+    @abstractmethod
+    def encode(self, texts: List[str]) -> EncodedText:
+        pass
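+
+
+# Usage contract (illustration, assuming a vocab with [CLS]/[SEP]/[MASK]/[PAD]):
+#
+#   ids = tokenizer.encode_as_ids("some text")    # raw ids, no special tokens
+#   enc = tokenizer.encode(["text a", "text b"])  # [CLS] a... [SEP] b... [SEP]
+#   enc["input_ids"], enc["type_ids"]             # type_ids: 0 = first segment, 1 = second
+#
+# The sentencepiece implementation additionally pads/truncates to max_seq_length.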
+
+
+class HuggingFaceTokenizer(Tokenizer):
+    def __init__(self, tokenizer: HFTokenizer):
+        self.tokenizer = tokenizer
+
+        self.cls_id = self.tokenizer.token_to_id("[CLS]")
+        self.sep_id = self.tokenizer.token_to_id("[SEP]")
+        self.mask_id = self.tokenizer.token_to_id("[MASK]")
+        self.pad_id = self.tokenizer.token_to_id("[PAD]")
+        self.vocab_size = self.tokenizer.get_vocab_size()
+
+    def decode(self, ids: List[int]) -> str:
+        return self.tokenizer.decode(ids, skip_special_tokens=False)
+
+    def encode_as_ids(self, text: str) -> List[int]:
+        return self.tokenizer.encode(text, is_pretokenized=False, add_special_tokens=False).ids
+
+    def encode(self, texts: List[str]) -> EncodedText:
+        if len(texts) > 2:
+            raise Exception("Hugging Face tokenizer can only encode two texts")
+        elif len(texts) == 2:
+            sequence, pair = texts
+        else:
+            sequence, pair = texts[0], None
+
+        encoding = self.tokenizer.encode(sequence, pair)
+
+        return {"input_ids": encoding.ids, "type_ids": encoding.type_ids}
+
+
+class SentencePieceTokenizer(Tokenizer):
+    def __init__(self, vocab_file, max_seq_length):
+        super().__init__()
+
+        self.tokenizer = spm.SentencePieceProcessor()
+        self.tokenizer.Load(vocab_file)
+        self.tokenizer.SetEncodeExtraOptions("")
+
+        self.vocab_size = self.tokenizer.GetPieceSize()
+
+        self.cls_id = self.tokenizer.PieceToId("[CLS]")
+        self.sep_id = self.tokenizer.PieceToId("[SEP]")
+        self.mask_id = self.tokenizer.PieceToId("[MASK]")
+        self.pad_id = self.tokenizer.pad_id()
+
+        self.max_seq_length = max_seq_length
+
+    def special_tokens(self) -> List[int]:
+        eos_id = self.tokenizer.eos_id()
+        bos_id = self.tokenizer.bos_id()
+        return [self.mask_id, self.cls_id, self.sep_id, self.pad_id, bos_id, eos_id]
+
+    def decode(self, ids: List[int]) -> str:
+        return self.tokenizer.DecodeIdsWithCheck(ids)
+
+    def encode_as_ids(self, text: str) -> List[int]:
+        return self.tokenizer.EncodeAsIds(text)
+
+    def encode(self, texts: List[str]) -> EncodedText:
+        input_ids_out = torch.full([self.max_seq_length], self.pad_id, dtype=torch.long)
+        type_ids_out = torch.full([self.max_seq_length], 0, dtype=torch.long)
+
+        input_ids = [self.cls_id]
+        type_ids = [0]
+
+        # first segment (and its [SEP]) gets type 0, later segments get type 1
+        for i, text in enumerate(texts):
+            tokens = self.tokenizer.EncodeAsIds(text) + [self.sep_id]
+            input_ids.extend(tokens)
+            type_ids.extend([0 if i == 0 else 1] * len(tokens))
+
+        # truncate
+        input_ids = input_ids[:self.max_seq_length]
+        type_ids = type_ids[:self.max_seq_length]
+
+        # pad
+        input_ids_out[:len(input_ids)] = torch.LongTensor(input_ids)
+        type_ids_out[:len(type_ids)] = torch.LongTensor(type_ids)
+
+        return {'input_ids': input_ids_out, 'type_ids': type_ids_out}
+
+
+def get_tokenizer(tokenizer_config: Dict, max_seq_length: int):
+    tokenizer_type = tokenizer_config['type']
+
+    if tokenizer_type == 'sentencepiece':
+        if not tokenizer_config['vocab']:
+            raise Exception("No vocab given")
+        return SentencePieceTokenizer(tokenizer_config['vocab'], max_seq_length)
+
+    if tokenizer_type == 'wordpiece':
+        if not tokenizer_config['vocab']:
+            raise Exception("No vocab given")
+        tokenizer = BertWordPieceTokenizer.from_file(tokenizer_config['vocab'])
+        return HuggingFaceTokenizer(tokenizer)
+
+    if tokenizer_type == 'huggingface':
+        if not tokenizer_config['hf_name']:
+            raise Exception("No name given")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_config['hf_name'])
+        return HuggingFaceTokenizer(tokenizer._tokenizer)
+
+    raise Exception(f"Unexpected tokenizer type {tokenizer_type}")
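+
+
+# Example usage of get_tokenizer (a sketch; the .model path is a placeholder):
+#
+#   config = {'type': 'sentencepiece', 'vocab': 'spm.model', 'hf_name': ''}
+#   tokenizer = get_tokenizer(config, max_seq_length=512)
+#   ids = tokenizer.encode_as_ids("FNet mixes tokens with Fourier transforms.")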