Skip to content

Commit 5a7da98

Browse files
authored
Fix python examples fail (#108)
Use margin check and add tests to pipeline. Fixes #107.
1 parent e26732d commit 5a7da98

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

.github/workflows/cibuildwheel.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,9 @@ jobs:
6060
working-directory: ${{ runner.temp }}
6161
run: python -m unittest discover -s ${GITHUB_WORKSPACE}/bindings/python
6262

63+
- name: Run examples
64+
env:
65+
PYTHONPATH: ${{ runner.temp }}/usr
66+
CTEST_OUTPUT_ON_FAILURE: 1
67+
working-directory: ${{ runner.temp }}
68+
run: python -m unittest discover -p "example*.py" -s ${GITHUB_WORKSPACE}/examples/python

examples/python/example_vamana.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
# [imports]
2222

2323
DEBUG_MODE = False
24-
def assert_equal(lhs, rhs, message: str = ""):
24+
def assert_equal(lhs, rhs, message: str = "", epsilon = 0.05):
2525
if DEBUG_MODE:
2626
print(f"{message}: {lhs} == {rhs}")
2727
else:
28-
assert lhs == rhs, message
28+
assert lhs < rhs + epsilon, message
29+
assert lhs > rhs - epsilon, message
2930

3031
def run_test_float(index, queries, groundtruth):
3132
expected = {
@@ -79,7 +80,6 @@ def run_test_build_two_level4_8(index, queries, groundtruth):
7980
test_data_dir = None
8081

8182
def run():
82-
8383
# ###
8484
# Generating test data
8585
# ###
@@ -159,7 +159,7 @@ def run():
159159
# Compare with the groundtruth.
160160
recall = svs.k_recall_at(groundtruth, I, 10, 10)
161161
print(f"Recall = {recall}")
162-
assert(recall == 0.8288)
162+
assert_equal(recall, 0.8288)
163163
# [perform-queries]
164164

165165
# [search-window-size]
@@ -213,7 +213,7 @@ def run():
213213
# Compare with the groundtruth.
214214
recall = svs.k_recall_at(groundtruth, I, 10, 10)
215215
print(f"Recall = {recall}")
216-
assert(recall == 0.8288)
216+
assert_equal(recall, 0.8288)
217217
# [loading]
218218

219219
##### Begin Test

examples/python/example_vamana_dynamic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
# [imports]
2323

2424
DEBUG_MODE = False
25-
def assert_equal(lhs, rhs, message: str = ""):
25+
def assert_equal(lhs, rhs, message: str = "", epsilon = 0.05):
2626
if DEBUG_MODE:
2727
print(f"{message}: {lhs} == {rhs}")
2828
else:
29-
assert lhs == rhs, message
29+
assert lhs < rhs + epsilon, message
30+
assert lhs > rhs - epsilon, message
3031

3132
def run_test_float(index, queries, groundtruth):
3233
expected = {
@@ -118,7 +119,7 @@ def run():
118119
# Compare with the groundtruth.
119120
recall = svs.k_recall_at(groundtruth, I, 10, 10)
120121
print(f"Recall = {recall}")
121-
assert(recall == 0.8202)
122+
assert_equal(recall, 0.8202)
122123
# [perform-queries]
123124

124125
##### Begin Test
@@ -158,8 +159,7 @@ def run():
158159
# Compare with the groundtruth.
159160
recall = svs.k_recall_at(groundtruth, I, 10, 10)
160161
print(f"Recall = {recall}")
161-
assert(recall == 0.8202)
162-
162+
assert_equal(recall, 0.8202)
163163

164164
##### Begin Test
165165
run_test_float(index, queries, groundtruth)

0 commit comments

Comments
 (0)