Make gptxdata optional dependency
Also sort imports
janEbert committed Aug 21, 2023
1 parent 4e99bb4 · commit a90c3d7
Showing 1 changed file with 24 additions and 1 deletion: megatron/tokenizer/tokenizer.py
@@ -4,10 +4,29 @@
 
 from abc import ABC
 from abc import abstractmethod
-from gptxdata.tokenization import HFTokenizer, SPTokenizer, PretrainedHFTokenizer
 
+try:
+    from gptxdata.tokenization import (
+        HFTokenizer,
+        PretrainedHFTokenizer,
+        SPTokenizer,
+    )
+except ImportError:
+    HFTokenizer = None
+    PretrainedHFTokenizer = None
+    SPTokenizer = None
+
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
 
 
+def _assert_gptx_tokenizer_available(tokenizer_name, tokenizer_cls):
+    assert tokenizer_cls is not None, (
+        f'Please install `gptxdata` to use {tokenizer_name}, e.g., with '
+        f'`pip install git+https://github.com/OpenGPTX/opengptx_data.git`'
+    )
+
+
 def build_tokenizer(args):
     """Initialize tokenizer."""
     if args.rank == 0:
@@ -39,10 +58,14 @@ def build_tokenizer(args):
         assert args.vocab_size is not None
         tokenizer = _NullTokenizer(args.vocab_size)
     elif args.tokenizer_type == "OpenGPTX-HFTokenizer":
+        _assert_gptx_tokenizer_available(args.tokenizer_type, HFTokenizer)
         tokenizer = HFTokenizer.instantiate_from_file_or_name(model_file_or_name=args.tokenizer_model)
     elif args.tokenizer_type == "OpenGPTX-PretrainedHFTokenizer":
+        _assert_gptx_tokenizer_available(
+            args.tokenizer_type, PretrainedHFTokenizer)
         tokenizer = PretrainedHFTokenizer.instantiate_from_file_or_name(model_file_or_name=args.tokenizer_model)
     elif args.tokenizer_type == "OpenGPTX-SPTokenizer":
+        _assert_gptx_tokenizer_available(args.tokenizer_type, SPTokenizer)
         tokenizer = SPTokenizer.instantiate_from_file_or_name(model_file_or_name=args.tokenizer_model)
     else:
         raise NotImplementedError('{} tokenizer is not '
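The change follows the common optional-import pattern: attempt the import at module load, fall back to None sentinels on ImportError, and only check the sentinel, with an actionable install hint, when a tokenizer that needs the package is actually requested. Below is a minimal, self-contained sketch of that pattern; the package name `fancytok` and the helpers `_assert_available` and `load_tokenizer` are hypothetical stand-ins for illustration, not part of this commit.

# Sketch of the optional-dependency pattern applied in this commit.
# NOTE: `fancytok` is a hypothetical package used only for illustration.

try:
    from fancytok import FancyTokenizer  # optional dependency
except ImportError:
    FancyTokenizer = None  # sentinel: package is not installed


def _assert_available(name, cls):
    # Fail fast with an install hint instead of a confusing error later.
    assert cls is not None, (
        f'Please install the optional dependency to use {name}, '
        f'e.g., with `pip install fancytok`.'
    )


def load_tokenizer(kind):
    if kind == 'fancy':
        _assert_available(kind, FancyTokenizer)
        return FancyTokenizer()
    raise NotImplementedError(f'{kind} tokenizer is not implemented.')

With the package absent, load_tokenizer('fancy') raises an AssertionError carrying the pip command while every other code path keeps working, which mirrors how _assert_gptx_tokenizer_available guards the three OpenGPTX tokenizer branches above.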
