diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py
index 83ec3b23f..fab41f554 100644
--- a/sentence_transformers/__init__.py
+++ b/sentence_transformers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.0.2"
+__version__ = "1.0.3"
 __DOWNLOAD_SERVER__ = 'http://sbert.net/models/'
 from .datasets import SentencesDataset, ParallelSentencesDataset
 from .LoggingHandler import LoggingHandler
diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index f11fe13fd..dab29ff20 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -73,6 +73,7 @@ def paraphrase_mining(model,
                       sentences: List[str],
                       show_progress_bar: bool = False,
                       batch_size:int = 32,
+                      *args,
                       **kwargs):
     """
     Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all
@@ -93,7 +94,7 @@ def paraphrase_mining(model,
     # Compute embedding for the sentences
     embeddings = model.encode(sentences, show_progress_bar=show_progress_bar, batch_size=batch_size, convert_to_tensor=True)
 
-    return paraphrase_mining_embeddings(embeddings, **kwargs)
+    return paraphrase_mining_embeddings(embeddings, *args, **kwargs)
 
 
 def paraphrase_mining_embeddings(embeddings: Tensor,
diff --git a/setup.py b/setup.py
index c14c15086..1172eeb27 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
 
 setup(
     name="sentence-transformers",
-    version="1.0.2",
+    version="1.0.3",
     author="Nils Reimers",
     author_email="info@nils-reimers.de",
     description="Sentence Embeddings using BERT / RoBERTa / XLM-R",
@@ -15,7 +15,7 @@
     long_description_content_type="text/markdown",
     license="Apache License 2.0",
     url="https://github.com/UKPLab/sentence-transformers",
-    download_url="https://github.com/UKPLab/sentence-transformers/archive/v1.0.2.zip",
+    download_url="https://github.com/UKPLab/sentence-transformers/archive/v1.0.3.zip",
     packages=find_packages(),
     install_requires=[
         'transformers>=3.1.0,<5.0.0',
diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
index 4d3b0a61a..13b834cc4 100644
--- a/tests/test_evaluator.py
+++ b/tests/test_evaluator.py
@@ -34,8 +34,7 @@ def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
         assert np.abs(max_acc - sklearn_acc) < 1e-6
 
     def test_LabelAccuracyEvaluator(self):
-
-
+        """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
         model = SentenceTransformer('paraphrase-distilroberta-base-v1')
 
         nli_dataset_path = 'datasets/AllNLI.tsv.gz'
@@ -59,3 +58,11 @@ def test_LabelAccuracyEvaluator(self):
         evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
         acc = evaluator(model)
         assert acc > 0.2
+
+    def test_ParaphraseMiningEvaluator(self):
+        """Tests that the ParaphraseMiningEvaluator can be loaded"""
+        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
+        sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
+        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0,1), (2,3)])
+        score = data_eval(model)
+        assert score > 0.99
\ No newline at end of file
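
For illustration, a minimal sketch of what the widened paraphrase_mining signature allows: extra arguments given to util.paraphrase_mining are now forwarded to util.paraphrase_mining_embeddings positionally (via the new *args) as well as by keyword (via the existing **kwargs). The sketch assumes the paraphrase_mining_embeddings options of the released 1.0.x API (e.g. top_k) and that the pretrained model can be downloaded.

    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer('paraphrase-distilroberta-base-v1')
    sentences = ["Hello World", "Hello World!", "The cat is on the table", "On the table the cat is"]

    # top_k is not a parameter of paraphrase_mining itself; it is passed through
    # to paraphrase_mining_embeddings, which scores and ranks the candidate pairs.
    pairs = util.paraphrase_mining(model, sentences, top_k=5)

    # paraphrase_mining returns [score, i, j] triples sorted by decreasing score.
    for score, i, j in pairs:
        print(f"{score:.3f}  {sentences[i]!r} <-> {sentences[j]!r}")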