Benchmark ANN index for Jaccard (#210)
ekzhu authored Jun 30, 2023
1 parent fd9e56b commit ebe4ca4
Showing 16 changed files with 964 additions and 261 deletions.
2 changes: 2 additions & 0 deletions benchmark/indexes/jaccard/.gitignore
@@ -0,0 +1,2 @@
*.sqlite
*.inp
59 changes: 59 additions & 0 deletions benchmark/indexes/jaccard/README.md
@@ -0,0 +1,59 @@
# Benchmarking ANN Indexes for Jaccard Distance

Download datasets from: [Set Similarity Search Benchmarks](https://github.com/ekzhu/set-similarity-search-benchmarks).

Use gzip to decompress the datasets.

```
gzip -d *.gz
```

## Set Size Distribution

```
python plot_set_distribution.py orkut_ge10.inp FLICKR-london2y_dup_dr.inp --output-dir plots
```

![Set size distribution](plots/set_size_distribution.png)

## Run Benchmarks

For example, for the Orkut and Flickr datasets, run:

```
python topk_benchmark.py --index-set-file orkut_ge10.inp --query-set-file orkut_ge10.inp --query-sample-ratio 0.01 --output orkut.sqlite
python topk_benchmark.py --index-set-file FLICKR-london2y_dup_dr.inp --query-set-file FLICKR-london2y_dup_dr.inp --query-sample-ratio 0.01 --output flickr.sqlite
```

The results are stored in the SQLite databases `orkut.sqlite` and `flickr.sqlite`.

## Plot Results

```
python plot_topk_benchmark.py orkut.sqlite --output-dir plots --max-distance-at-k 0.74 1.0
python plot_topk_benchmark.py flickr.sqlite --output-dir plots --max-distance-at-k 0.1 1.0
```

The plots are stored in the `plots` directory.

Queries Per Second (QPS) vs. Recall for Orkut, Maximum Distance at K = 1.00 (i.e. all queries are selected).

![QPS vs. Recall Orkut](plots/k100/orkut_qps_recall_1.00.png)

Queries Per Second (QPS) vs. Recall for Orkut, Maximum Distance at K = 0.74 (i.e. only queries whose top-k results' distances are all no greater than 0.74 are selected).

![QPS vs. Recall Orkut](plots/k100/orkut_qps_recall_0.74.png)

Indexing Time vs. Recall.

![Indexing vs. Recall Orkut](plots/k100/orkut_indexing_recall_1.00.png)


## Distance Distribution

```
python plot_distance_distribution.py orkut.sqlite flickr.sqlite --output-dir plots
```

![Distance distribution](plots/jaccard_distances_at_k.png)

29 changes: 16 additions & 13 deletions benchmark/indexes/jaccard/exact.py
@@ -7,14 +7,15 @@

 def _query_jaccard_topk(index, query, k):
     """Query the search index for the best k candidates."""
-    assert(index.similarity_threshold == 0.0)
+    assert index.similarity_threshold == 0.0
     s1 = [index.order[token] for token in query if token in index.order]
     # Get the number of occurrences of candidates in the posting lists.
-    counter = collections.Counter(
-        i for token in s1 for i, _ in index.index[token])
+    counter = collections.Counter(i for token in s1 for i, _ in index.index[token])
     # Compute the Jaccard similarities based on the counts.
-    candidates = [(i, float(c) / float(len(s1) + len(index.sets[i]) - c))
-                  for (i, c) in counter.items()]
+    candidates = [
+        (i, float(c) / float(len(s1) + len(index.sets[i]) - c))
+        for (i, c) in counter.items()
+    ]
     # Sort candidates based on similarities.
     candidates.sort(key=lambda x: x[1], reverse=True)
     # Return the top-k candidates.
@@ -27,20 +28,22 @@ def search_jaccard_topk(index_data, query_data, k):
     print("Building jaccard search index.")
     start = time.perf_counter()
     # Build the search index with the 0 threshold to index all tokens.
-    index = SearchIndex(index_sets, similarity_func_name="jaccard",
-                        similarity_threshold=0.0)
-    duration = time.perf_counter() - start
-    print("Finished building index in {:.3f}.".format(duration))
+    index = SearchIndex(
+        index_sets, similarity_func_name="jaccard", similarity_threshold=0.0
+    )
+    indexing_time = time.perf_counter() - start
+    print("Finished building index in {:.3f}.".format(indexing_time))
     times = []
     results = []
     for query_set, query_key in zip(query_sets, query_keys):
         start = time.perf_counter()
-        result = [[index_keys[i], similarity]
-                  for i, similarity in _query_jaccard_topk(index, query_set, k)]
+        result = [
+            [index_keys[i], similarity]
+            for i, similarity in _query_jaccard_topk(index, query_set, k)
+        ]
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
         sys.stdout.write("\rQueried {} sets.".format(len(results)))
     sys.stdout.write("\n")
-    return (results, times)
+    return (indexing_time, results, times)
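For context on `_query_jaccard_topk`: it derives Jaccard similarity purely from posting-list hit counts. If the query has |Q| indexed tokens and a candidate set S shares c of them, then |Q ∩ S| = c and |Q ∪ S| = |Q| + |S| - c, so J = c / (|Q| + |S| - c). A minimal self-contained sketch of the same counting trick (`jaccard_topk_from_postings` is a hypothetical name, not part of this module):

```python
import collections

def jaccard_topk_from_postings(sets, query, k):
    # Inverted index: token -> ids of the sets containing it.
    postings = collections.defaultdict(list)
    for i, s in enumerate(sets):
        for token in s:
            postings[token].append(i)
    # c = number of tokens each candidate shares with the query.
    counter = collections.Counter(
        i for token in query for i in postings.get(token, ())
    )
    # |Q ∩ S| = c and |Q ∪ S| = |Q| + |S| - c, so J = c / (|Q| + |S| - c).
    candidates = [
        (i, c / (len(query) + len(sets[i]) - c)) for i, c in counter.items()
    ]
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates[:k]

# {2, 3, 4} shares c = 2 tokens with {1, 2, 3}: J = 2 / (3 + 3 - 2) = 0.5.
print(jaccard_topk_from_postings([{1, 2, 3}, {5, 6}], {2, 3, 4}, k=2))
```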
25 changes: 14 additions & 11 deletions benchmark/indexes/jaccard/hnsw.py
@@ -11,28 +11,31 @@ def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
     (query_sets, query_keys) = query_data
     print("Building HNSW Index.")
     start = time.perf_counter()
-    index = nmslib.init(method="hnsw", space="jaccard_sparse",
-                        data_type=nmslib.DataType.OBJECT_AS_STRING)
+    index = nmslib.init(
+        method="hnsw",
+        space="jaccard_sparse",
+        data_type=nmslib.DataType.OBJECT_AS_STRING,
+    )
     index.addDataPointBatch(
-        [" ".join(str(v) for v in s) for s in index_sets],
-        range(len(index_keys)))
+        [" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
+    )
     index.createIndex(index_params)
-    end = time.perf_counter()
-    print("Indexing time: {:.3f}.".format(end-start))
+    indexing_time = time.perf_counter() - start
+    print("Indexing time: {:.3f}.".format(indexing_time))
     print("Querying.")
     times = []
     results = []
     index.setQueryTimeParams({"efSearch": index_params["efConstruction"]})
     for query_set, query_key in zip(query_sets, query_keys):
         start = time.perf_counter()
         result, _ = index.knnQuery(" ".join(str(v) for v in query_set), k)
-        result = [[index_keys[i], compute_jaccard(query_set, index_sets[i])]
-                  for i in result]
-        result.sort(key=lambda x : x[1], reverse=True)
+        result = [
+            [index_keys[i], compute_jaccard(query_set, index_sets[i])] for i in result
+        ]
+        result.sort(key=lambda x: x[1], reverse=True)
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
         sys.stdout.write(f"\rQueried {len(results)} sets")
     sys.stdout.write("\n")
-    return (results, times)
+    return (indexing_time, results, times)
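For context: nmslib's `jaccard_sparse` space consumes each set as a space-separated string of token ids, and the benchmark reranks the ids returned by `knnQuery` with exact Jaccard. A standalone sketch of that flow; the HNSW parameter values (`M`, `efConstruction`, `post`, `efSearch`) are illustrative assumptions, not the benchmark's actual `index_params`:

```python
import nmslib

# Each set is encoded as a space-separated string of token ids,
# the input format expected by the jaccard_sparse space.
sets = [[1, 2, 3], [2, 3, 4], [10, 11]]
data = [" ".join(str(v) for v in s) for s in sets]

index = nmslib.init(
    method="hnsw",
    space="jaccard_sparse",
    data_type=nmslib.DataType.OBJECT_AS_STRING,
)
index.addDataPointBatch(data, range(len(data)))
index.createIndex({"M": 16, "efConstruction": 100, "post": 0})
index.setQueryTimeParams({"efSearch": 100})

# knnQuery returns approximate neighbor ids and their Jaccard distances.
ids, distances = index.knnQuery(" ".join(str(v) for v in [2, 3]), k=2)
print(list(zip(ids, distances)))
```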
31 changes: 20 additions & 11 deletions benchmark/indexes/jaccard/lsh.py
@@ -12,30 +12,39 @@ def search_lsh_jaccard_topk(index_data, query_data, b, r, k):
     num_perm = b * r
     print("Building LSH Index.")
     start = time.perf_counter()
-    index = MinHashLSH(num_perm=num_perm, params=(b, r))
+    index = MinHashLSH(
+        num_perm=num_perm,
+        params=(b, r),
+    )
     # Use the indices of the indexed sets as keys in LSH.
     for i in range(len(index_keys)):
-        index.insert(i, index_minhashes[num_perm][i])
-    end = time.perf_counter()
-    print("Indexing time: {:.3f}.".format(end-start))
+        index.insert(
+            i,
+            index_minhashes[num_perm][i],
+            check_duplication=False,
+        )
+    indexing_time = time.perf_counter() - start
+    print("Indexing time: {:.3f}.".format(indexing_time))
     print("Querying.")
     times = []
     results = []
-    for query_minhash, query_key, query_set in \
-            zip(query_minhashes[num_perm], query_keys, query_sets):
+    for query_minhash, query_key, query_set in zip(
+        query_minhashes[num_perm], query_keys, query_sets
+    ):
         start = time.perf_counter()
         result = index.query(query_minhash)
         # Recover the retrieved indexed sets and
         # compute the exact Jaccard similarities.
-        result = [[index_keys[i], compute_jaccard(query_set, index_sets[i])]
-                  for i in result]
+        result = [
+            [index_keys[i], compute_jaccard(query_set, index_sets[i])] for i in result
+        ]
         # Sort by similarity.
-        result.sort(key=lambda x : x[1], reverse=True)
+        result.sort(key=lambda x: x[1], reverse=True)
         # Take the first k.
         result = result[:k]
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
         sys.stdout.write(f"\rQueried {len(results)} sets")
     sys.stdout.write("\n")
-    return (results, times)
+    return (indexing_time, results, times)
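For context on `params=(b, r)`: MinHashLSH splits the `num_perm = b * r` MinHash values into b bands of r rows each, and a candidate with true Jaccard similarity s collides with the query in at least one band with probability 1 - (1 - s^r)^b; the newly passed `check_duplication=False` just skips the duplicate-key check on insert. A small sketch of this standard S-curve (textbook LSH analysis, not code from this commit):

```python
def collision_probability(s, b, r):
    # A band matches only if all r rows agree, which happens w.p. s**r;
    # the set becomes a candidate if at least one of the b bands matches.
    return 1.0 - (1.0 - s**r) ** b

# With num_perm = 256 split as b=32 bands of r=8 rows, the curve is steep
# around the threshold (1/b)**(1/r) ~= 0.65.
for s in (0.4, 0.6, 0.8):
    print(s, round(collision_probability(s, b=32, r=8), 3))
```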
23 changes: 12 additions & 11 deletions benchmark/indexes/jaccard/lshforest.py
@@ -17,27 +17,28 @@ def search_lshforest_jaccard_topk(index_data, query_data, b, r, k):
     for i in range(len(index_keys)):
         index.add(i, index_minhashes[num_perm][i])
     index.index()
-    end = time.perf_counter()
-    print("Indexing time: {:.3f}.".format(end-start))
+    indexing_time = time.perf_counter() - start
+    print("Indexing time: {:.3f}.".format(indexing_time))
     print("Querying.")
     times = []
     results = []
-    for query_minhash, query_key, query_set in \
-            zip(query_minhashes[num_perm], query_keys, query_sets):
+    for query_minhash, query_key, query_set in zip(
+        query_minhashes[num_perm], query_keys, query_sets
+    ):
         start = time.perf_counter()
-        result = index.query(query_minhash, k*2)
+        result = index.query(query_minhash, k * 2)
         # Recover the retrieved indexed sets and
         # compute the exact Jaccard similarities.
-        result = [[index_keys[i], compute_jaccard(query_set, index_sets[i])]
-                  for i in result]
+        result = [
+            [index_keys[i], compute_jaccard(query_set, index_sets[i])] for i in result
+        ]
         # Sort by similarity.
-        result.sort(key=lambda x : x[1], reverse=True)
+        result.sort(key=lambda x: x[1], reverse=True)
         # Take the top k.
         result = result[:k]
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
         sys.stdout.write(f"\rQueried {len(results)} sets")
     sys.stdout.write("\n")
-    return (results, times)
+    return (indexing_time, results, times)
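For context: both LSH variants over-fetch candidates (the forest queries for `k * 2`) and rerank them by exact similarity via `compute_jaccard` from the benchmark's utils. A minimal sketch of that rerank step, assuming `compute_jaccard` has the standard set-based form (the actual utils implementation may differ):

```python
def compute_jaccard(x, y):
    # Exact Jaccard similarity of two token collections (assumed form).
    x, y = set(x), set(y)
    if not x and not y:
        return 0.0
    return len(x & y) / len(x | y)

def rerank(candidate_ids, query_set, index_sets, index_keys, k):
    # Re-score approximate candidates exactly, then keep the top k.
    scored = [
        [index_keys[i], compute_jaccard(query_set, index_sets[i])]
        for i in candidate_ids
    ]
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:k]

# {2, 3, 4} vs. {1, 2, 3}: J = 2 / 4 = 0.5.
print(rerank([0, 1], {2, 3, 4}, [{1, 2, 3}, {9}], ["a", "b"], k=1))  # [['a', 0.5]]
```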
71 changes: 71 additions & 0 deletions benchmark/indexes/jaccard/plot_distance_distribution.py
@@ -0,0 +1,71 @@
import json
import os
import argparse
import sqlite3

import matplotlib.pyplot as plt
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument("benchmark_result", nargs="+", type=str)
parser.add_argument("--k", type=int, required=False)
parser.add_argument("--output-dir", default="plots")
args = parser.parse_args()

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

# Obtain ground truth runs from the benchmark results.
quartiles_at_k = {}
for benchmark_result in args.benchmark_result:
    conn = sqlite3.connect(benchmark_result)
    cursor = conn.cursor()
    cursor.execute(
        """SELECT key
        FROM runs
        WHERE name == 'ground_truth'""",
    )
    run_keys = [row[0] for row in cursor]

    # Load results for the first run.
    distances = []
    cursor.execute("""SELECT result FROM results WHERE run_key = ?""", (run_keys[0],))
    for row in cursor:
        result = json.loads(row[0])
        if args.k is not None:
            result = result[: args.k]
        distances.append([1.0 - similarity for _, similarity in result])

    # Pad distances with the maximum value (1.0).
    max_length = max([len(x) for x in distances])
    distances = np.array(
        [
            np.pad(x, (0, max_length - len(x)), "constant", constant_values=1.0)
            for x in distances
        ]
    )

    # Compute quartiles at k.
    name = os.path.splitext(os.path.basename(benchmark_result))[0]
    pcts = [25, 50, 75]
    quartiles_at_k[name] = (pcts, np.percentile(distances, pcts, axis=0))


# Plot.
if args.k is not None:
    filename = f"jaccard_distances_at_k_k={args.k}.png"
else:
    filename = "jaccard_distances_at_k.png"
plt.figure()
for name, (pcts, sims) in quartiles_at_k.items():
    lower, upper = sims[0], sims[-1]
    xs = np.arange(1, len(lower) + 1)
    plt.fill_between(xs, lower, upper, alpha=0.2, label=f"{name} (25-75%)")
    plt.plot(xs, sims[1], label=f"{name} (50%)")
plt.xlabel("K")
plt.ylabel("Median Jaccard Distance at K")
plt.title("Jaccard Distances at K")
plt.legend()
plt.savefig(os.path.join(args.output_dir, filename))
plt.close()
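For context on the padding step above: queries whose ground-truth result lists are shorter than the longest are padded with distance 1.0 (the maximum Jaccard distance), so the percentiles at each rank are taken over the same number of queries. A tiny example of the same `np.pad` plus `np.percentile` pattern:

```python
import numpy as np

# Two queries with ragged result lists; pad the shorter one with 1.0
# so every rank k has one distance per query.
distances = [[0.1, 0.3, 0.6], [0.2, 0.5]]
max_length = max(len(x) for x in distances)
padded = np.array(
    [np.pad(x, (0, max_length - len(x)), "constant", constant_values=1.0)
     for x in distances]
)
# 25th/50th/75th percentile of the distance at each rank (columns).
print(np.percentile(padded, [25, 50, 75], axis=0))
```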
45 changes: 45 additions & 0 deletions benchmark/indexes/jaccard/plot_set_distribution.py
@@ -0,0 +1,45 @@
import argparse
import os

from matplotlib import pyplot as plt
import numpy as np

from utils import read_set_sizes_from_file

parser = argparse.ArgumentParser()
parser.add_argument(
    "sets_files", nargs="+", help="The input files for reading sets from."
)
parser.add_argument("--output-dir", default=".", help="The output directory.")
args = parser.parse_args()

plt.figure()
for sets_file in args.sets_files:
    name = os.path.splitext(os.path.basename(sets_file))[0]
    print(f"Processing {sets_file}...")
    set_sizes = read_set_sizes_from_file(sets_file)
    # Plot the distribution of set sizes histogram.
    h, bins, patches = plt.hist(
        set_sizes,
        bins=np.logspace(np.log10(np.min(set_sizes)), np.log10(np.max(set_sizes)), 30),
        alpha=0.5,
        label=name,
    )
    color = patches[0].get_facecolor()
    mean = np.mean(set_sizes)
    plt.axvline(mean, color=color, linestyle="dashed", linewidth=1)
    plt.text(
        mean,
        0.5 * np.max(h),
        f"Mean Size = {mean:.2f}",
        color="k",
    )
plt.xscale("log")
# plt.yscale("log")
plt.legend()
plt.xlabel("Set size")
plt.ylabel("Number of sets")
plt.title("Distribution of Set Sizes")
plt.savefig(
    os.path.join(args.output_dir, "set_size_distribution.png"), bbox_inches="tight"
)