-
Notifications
You must be signed in to change notification settings - Fork 297
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Benchmark ANN index for Jaccard (#210)
- Loading branch information
Showing
16 changed files
with
964 additions
and
261 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.sqlite | ||
*.inp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Benchmarking ANN Indexes for Jaccard Distance | ||
|
||
Download datasets from: [Set Similarity Search Benchmarks](https://github.com/ekzhu/set-similarity-search-benchmarks). | ||
|
||
Use gzip to decompress the datasets. | ||
|
||
``` | ||
gzip -d *.gz | ||
``` | ||
|
||
## Set Size Distribution | ||
|
||
``` | ||
python plot_set_distribution.py orkut_ge10.inp FLICKR-london2y_dup_dr.inp --output-dir plots | ||
``` | ||
|
||
![Set size distribution](plots/set_size_distribution.png) | ||
|
||
## Run Benchmarks | ||
|
||
For example, for Orkut and Flickr datasets, run: | ||
|
||
``` | ||
python topk_benchmark.py --index-set-file orkut_ge10.inp --query-set-file orkut_ge10.inp --query-sample-ratio 0.01 --output orkut.sqlite | ||
python topk_benchmark.py --index-set-file FLICKR-london2y_dup_dr.inp --query-set-file FLICKR-london2y_dup_dr.inp --query-sample-ratio 0.01 --output flickr.sqlite | ||
``` | ||
|
||
The results are stored in the SQLite databases `orkut.sqlite` and `flickr.sqlite`. | ||
|
||
## Plot Results | ||
|
||
``` | ||
python plot_topk_benchmark.py orkut.sqlite --output-dir plots --max-distance-at-k 0.74 1.0 | ||
python plot_topk_benchmark.py flickr.sqlite --output-dir plots --max-distance-at-k 0.1 1.0 | ||
``` | ||
|
||
The plots are stored in the `plots` directory. | ||
|
||
Queries Per Second (QPS) vs. Recall for Orkut, Maximum Distance at K = 1.00 (i.e. all queries are selected). | ||
|
||
![QPS vs. Recall Orkut](plots/k100/orkut_qps_recall_1.00.png) | ||
|
||
Queries Per Second (QPS) vs. Recall for Orkut, Maximum Distance at K = 0.74 (i.e. only queries whose top-k results all have distances no greater than 0.74 are selected). | ||
|
||
![QPS vs. Recall Orkut](plots/k100/orkut_qps_recall_0.74.png) | ||
|
||
Indexing Time vs. Recall. | ||
|
||
![Indexing vs. Recall Orkut](plots/k100/orkut_indexing_recall_1.00.png) | ||
|
||
|
||
## Distance Distribution | ||
|
||
``` | ||
python plot_distance_distribution.py orkut.sqlite flickr.sqlite --output-dir plots | ||
``` | ||
|
||
![Distance distribution](plots/jaccard_distances_at_k.png) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import json | ||
import os | ||
import argparse | ||
import sqlite3 | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def load_ground_truth_distances(benchmark_result, k=None):
    """Read the Jaccard distances of the first ground truth run.

    Args:
        benchmark_result: Path to a SQLite benchmark database that has a
            `runs` table (columns `key`, `name`) and a `results` table
            (columns `run_key`, `result`).
        k: If given, only the first `k` entries of every result are kept.

    Returns:
        A list with one list of distances (1.0 - similarity) per stored
        result. The inner lists may differ in length.

    Raises:
        ValueError: If the database has no run named 'ground_truth'.
    """
    conn = sqlite3.connect(benchmark_result)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """SELECT key
            FROM runs
            WHERE name == 'ground_truth'""",
        )
        run_keys = [row[0] for row in cursor]
        if not run_keys:
            raise ValueError(f"No ground truth run found in {benchmark_result}")

        # Only the first ground truth run is used.
        cursor.execute(
            """SELECT result FROM results WHERE run_key = ?""", (run_keys[0],)
        )
        distances = []
        for row in cursor:
            # Each result is a JSON list of (key, similarity) pairs.
            result = json.loads(row[0])
            if k is not None:
                result = result[:k]
            distances.append([1.0 - similarity for _, similarity in result])
    finally:
        # Always release the database handle, even when a query fails.
        conn.close()
    return distances


def pad_distances(distances):
    """Turn ragged per-query distance lists into a 2-D array.

    Rows shorter than the longest one are right-padded with 1.0, the
    distance that corresponds to a similarity of zero.

    Args:
        distances: A non-empty list of lists of floats.

    Returns:
        A float array of shape (len(distances), longest row length).

    Raises:
        ValueError: If `distances` is empty.
    """
    if not distances:
        raise ValueError("No results to compute the distance distribution from")
    max_length = max(len(row) for row in distances)
    padded = np.full((len(distances), max_length), 1.0)
    for i, row in enumerate(distances):
        padded[i, : len(row)] = row
    return padded


def main(argv=None):
    """Plot the quartiles of the Jaccard distance at each rank K.

    One band (25-75%) and one median line are drawn per benchmark result
    database, and the figure is saved into the output directory.

    Args:
        argv: Command-line arguments; defaults to `sys.argv[1:]`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark_result", nargs="+", type=str)
    parser.add_argument("--k", type=int, required=False)
    parser.add_argument("--output-dir", default="plots")
    args = parser.parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)

    # Compute quartiles at k from the ground truth run of every database.
    quartiles_at_k = {}
    for benchmark_result in args.benchmark_result:
        distances = pad_distances(
            load_ground_truth_distances(benchmark_result, args.k)
        )
        name = os.path.splitext(os.path.basename(benchmark_result))[0]
        pcts = [25, 50, 75]
        quartiles_at_k[name] = (pcts, np.percentile(distances, pcts, axis=0))

    # Plot.
    if args.k is not None:
        filename = f"jaccard_distances_at_k_k={args.k}.png"
    else:
        filename = "jaccard_distances_at_k.png"
    plt.figure()
    for name, (pcts, sims) in quartiles_at_k.items():
        lower, upper = sims[0], sims[-1]
        xs = np.arange(1, len(lower) + 1)
        plt.fill_between(xs, lower, upper, alpha=0.2, label=f"{name} (25-75%)")
        plt.plot(xs, sims[1], label=f"{name} (50%)")
    plt.xlabel("K")
    plt.ylabel("Median Jaccard Distance at K")
    plt.title("Jaccard Distances at K")
    plt.legend()
    plt.savefig(os.path.join(args.output_dir, filename))
    plt.close()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import argparse | ||
import os | ||
from matplotlib import pyplot as plt | ||
import numpy as np | ||
|
||
import numpy as np | ||
from utils import read_set_sizes_from_file | ||
|
||
def log_spaced_bins(set_sizes, num_edges=30):
    """Return logarithmically spaced histogram bin edges for `set_sizes`.

    The edges run from the smallest to the largest set size. Both ends are
    clamped to at least 1 so that a zero-size set cannot produce a
    `log10(0) = -inf` edge.

    Args:
        set_sizes: A non-empty sequence of set sizes.
        num_edges: Number of bin edges to generate.

    Returns:
        A 1-D array of `num_edges` increasing bin edges.
    """
    lower = max(np.min(set_sizes), 1)
    upper = max(np.max(set_sizes), lower)
    return np.logspace(np.log10(lower), np.log10(upper), num_edges)


def main(argv=None):
    """Plot overlaid set size histograms for one or more set files.

    A dashed vertical line and a text label mark the mean set size of each
    file. The figure is saved as `set_size_distribution.png` in the output
    directory.

    Args:
        argv: Command-line arguments; defaults to `sys.argv[1:]`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "sets_files", nargs="+", help="The input files for reading sets from."
    )
    parser.add_argument("--output-dir", default=".", help="The output directory.")
    args = parser.parse_args(argv)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(args.output_dir, exist_ok=True)

    plt.figure()
    for sets_file in args.sets_files:
        name = os.path.splitext(os.path.basename(sets_file))[0]
        print(f"Processing {sets_file}...")
        set_sizes = read_set_sizes_from_file(sets_file)
        # Plot the distribution of set sizes histogram.
        h, bins, patches = plt.hist(
            set_sizes,
            bins=log_spaced_bins(set_sizes),
            alpha=0.5,
            label=name,
        )
        # Reuse the histogram's color for the mean marker of the same file.
        color = patches[0].get_facecolor()
        mean = np.mean(set_sizes)
        plt.axvline(mean, color=color, linestyle="dashed", linewidth=1)
        plt.text(
            mean,
            0.5 * np.max(h),
            f"Mean Size = {mean:.2f}",
            color="k",
        )
    plt.xscale("log")
    # plt.yscale("log")
    plt.legend()
    plt.xlabel("Set size")
    plt.ylabel("Number of sets")
    plt.title("Distribution of Set Sizes")
    plt.savefig(
        os.path.join(args.output_dir, "set_size_distribution.png"), bbox_inches="tight"
    )
    plt.close()


if __name__ == "__main__":
    main()
Oops, something went wrong.