-
Notifications
You must be signed in to change notification settings - Fork 5
/
test_compare.py
38 lines (27 loc) · 940 Bytes
/
test_compare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from functools import lru_cache

import numpy as np
from hypothesis import given, settings
from hypothesis.strategies import lists, text
from sentence_transformers import SentenceTransformer

from run_onnx import DefaultEmbeddingModel
@lru_cache(maxsize=1)
def _onnx_model():
    """Return the ONNX-backed embedding model, constructed once per process."""
    # NOTE(review): assumes a DefaultEmbeddingModel instance is safe to reuse
    # across calls (no per-call state) — confirm against run_onnx.
    return DefaultEmbeddingModel()


@lru_cache(maxsize=1)
def _st_model():
    """Return the reference SentenceTransformer model, loaded once per process."""
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def _run_and_compare(texts, *, atol=1e-6, rtol=1e-5):
    """Embed *texts* with both backends and assert the embeddings agree.

    Both models are cached, so only the first call pays the model-loading
    cost. Previously every Hypothesis example reloaded both models inside the
    timed region of the test, which made the ``deadline`` settings
    effectively unattainable. The very first example of a session still
    includes the one-time load.

    Args:
        texts: List of strings to embed.
        atol: Absolute tolerance (default matches the original check).
        rtol: Relative tolerance (default matches ``np.allclose``'s default,
            which the original check used implicitly).

    Raises:
        AssertionError: If the two backends return different shapes or any
            element differs beyond ``atol + rtol * |reference|``.
    """
    result_onnx = np.asarray(_onnx_model()(texts))
    result_st = np.asarray(_st_model().encode(texts))
    # np.allclose broadcasts, so e.g. (384,) vs (1, 384) would silently
    # "match"; require identical shapes before comparing values.
    assert result_onnx.shape == result_st.shape, (
        f"shape mismatch: onnx {result_onnx.shape} vs "
        f"sentence-transformers {result_st.shape}"
    )
    # Same element-wise criterion as np.allclose (equal_nan=False keeps NaNs
    # failing, as before), but with a diagnostic diff on failure.
    np.testing.assert_allclose(
        result_onnx, result_st, rtol=rtol, atol=atol, equal_nan=False
    )
@given(text())
@settings(deadline=1500)
def test_compare_single(sample):
    """A single arbitrary string embeds identically under ONNX and the reference model."""
    # Parameter is named ``sample`` (not ``text``) so it does not shadow the
    # imported ``text`` strategy; positional @given binds it regardless of name.
    _run_and_compare([sample])
@given(texts=lists(text(), min_size=1))
@settings(deadline=1500)
def test_compare_lists(texts):
    """Non-empty batches of arbitrary strings embed identically under both backends."""
    _run_and_compare(texts=texts)
@given(texts=lists(text(), min_size=50))
@settings(deadline=5000)
def test_compare_large_lists(texts):
    """Batches of at least 50 arbitrary strings embed identically under both backends."""
    _run_and_compare(texts=texts)
@given(text(min_size=100))
@settings(deadline=5000)
def test_compare_large_text(sample):
    """A single string of at least 100 characters embeds identically under both backends."""
    # ``sample`` rather than ``text`` avoids shadowing the imported strategy.
    # NOTE(review): long inputs may exceed the tokenizer's max sequence length,
    # so this presumably also checks that both backends truncate the same
    # way — confirm against run_onnx's tokenizer settings.
    _run_and_compare([sample])