Skip to content

Commit

Permalink
replace use_fast_tokenizer with tokenizer_args
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Dec 22, 2020
1 parent 55756ad commit 28d6f90
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions sentence_transformers/cross_encoder/CrossEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class CrossEncoder():
def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, use_fast_tokenizer:bool = None):
def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, tokenizer_args:Dict = {}):
"""
A CrossEncoder takes exactly two sentences / texts as input and either predicts
a score or label for this sentence pair. It can for example predict the similarity of the sentence pair
Expand All @@ -27,7 +27,7 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
:param num_labels: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score 0...1. If > 1, it outputs several scores that can be soft-maxed to get probability scores for the different classes.
:param max_length: Max length for input sequences. Longer sequences will be truncated. If None, max length of the model will be used
:param device: Device that should be used for the model. If None, it will use CUDA if available.
:param use_fast_tokenizer: Use fast tokenizer from hugging face.
:param tokenizer_args: Arguments passed to AutoTokenizer
"""

self.config = AutoConfig.from_pretrained(model_name)
Expand All @@ -42,10 +42,6 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
self.config.num_labels = num_labels

self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
tokenizer_args = {}
if use_fast_tokenizer is not None:
tokenizer_args['use_fast'] = use_fast_tokenizer

self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)

self.max_length = max_length
Expand Down

0 comments on commit 28d6f90

Please sign in to comment.