-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pytx] Implement a new cleaner PDQ index solution #1698
Changes from 7 commits
300df97
9fd7744
3cf072b
cd44cbd
a5bfb6d
0a32d65
610af2e
2ea86a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
""" | ||
Implementation of SignalTypeIndex abstraction for PDQ | ||
""" | ||
|
||
import typing as t | ||
import faiss | ||
import numpy as np | ||
|
||
|
||
from threatexchange.signal_type.index import ( | ||
IndexMatchUntyped, | ||
SignalSimilarityInfoWithIntDistance, | ||
SignalTypeIndex, | ||
T as IndexT, | ||
) | ||
from threatexchange.signal_type.pdq.pdq_utils import ( | ||
BITS_IN_PDQ, | ||
PDQ_CONFIDENT_MATCH_THRESHOLD, | ||
convert_pdq_strings_to_ndarray, | ||
) | ||
|
||
# Match result type for PDQ: similarity info carries the integer PDQ distance.
PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]
|
||
|
||
class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

    This is a redo of the existing PDQ index,
    designed to be simpler and fix hard-to-squash bugs in the existing implementation.
    Purpose of this class: to replace the original index in pytx 2.0
    """

    def __init__(
        self,
        index: t.Optional[faiss.Index] = None,
        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
        *,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
    ) -> None:
        super().__init__()
        self.threshold = threshold

        # Default to an exact (flat) L2 index over the 256 PDQ bits.
        self._index = _PDQFaissIndex(
            faiss.IndexFlatL2(BITS_IN_PDQ) if index is None else index
        )

        # Maps a PDQ hex string to its faiss id (faiss cannot hold duplicates).
        self._deduper: t.Dict[str, int] = {}
        # Position i holds every entry whose hash was assigned faiss id i.
        self._idx_to_entries: t.List[t.List[IndexT]] = []

        self.add_all(entries=entries)

    def __len__(self) -> int:
        # One slot per unique hash in the index.
        return len(self._idx_to_entries)

    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """
        Look up entries against the index, up to the threshold.
        """
        found = self._index.search(queries=[hash], threshold=self.threshold)
        # Fan each faiss hit out to every entry stored under that hash.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(distance)),
                entry,
            )
            for faiss_id, distance in found
            for entry in self._idx_to_entries[faiss_id]
        ]

    def add(self, signal_str: str, entry: IndexT) -> None:
        self.add_all(((signal_str, entry),))

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        for signal_str, entry in entries:
            faiss_id = self._deduper.get(signal_str)
            if faiss_id is not None:
                # Hash already indexed: faiss can't hold duplicate vectors,
                # so just record the extra entry under the existing id.
                self._idx_to_entries[faiss_id].append(entry)
            else:
                self._index.add([signal_str])
                self._idx_to_entries.append([entry])
                # Faiss assigns ids sequentially from 0, which matches our
                # running count of unique hashes.
                self._deduper[signal_str] = len(self._deduper)
|
||
|
||
class _PDQFaissIndex:
    """
    A wrapper around the faiss index for pickle serialization
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.
        """
        vectors = convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int
    ) -> t.List[t.Tuple[int, int]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Returns (faiss_id, distance) pairs for every indexed hash within
        `threshold` of any query, concatenated across all queries.
        """
        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
        # range_search's radius is exclusive, so +1 makes `threshold` inclusive.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.Tuple[int, int]] = []
        # limits[i]:limits[i+1] delimits the i-th query's hits in the flat arrays.
        for i in range(len(queries)):
            start, end = limits[i], limits[i + 1]
            for idx, dist in zip(indices[start:end], distances[start:end]):
                # Cast numpy scalars so the result really is (int, int),
                # matching the declared return type (previously raw numpy
                # floats leaked through for the distance).
                results.append((int(idx), int(dist)))
        return results

    def __getstate__(self):
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import typing as t | ||
import numpy as np | ||
import random | ||
|
||
from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 | ||
from threatexchange.signal_type.pdq.signal import PdqSignal | ||
from threatexchange.signal_type.pdq.pdq_utils import simple_distance | ||
|
||
|
||
def _generate_sample_hashes(size: int, seed: int = 42):
    """Return `size` random PDQ hashes, deterministic for a given seed."""
    random.seed(seed)
    hashes = []
    for _ in range(size):
        hashes.append(PdqSignal.get_random_signal())
    return hashes
|
||
|
||
# Shared fixture: 100 random PDQ hashes, deterministic via the default seed.
SAMPLE_HASHES = _generate_sample_hashes(100)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works for me! |
||
|
||
|
||
def _brute_force_match(
    base: t.List[str], query: str, threshold: int = 32
) -> t.Set[t.Tuple[int, int]]:
    """Exhaustively compute {(index, distance)} for base hashes within threshold."""
    return {
        (i, dist)
        for i, base_hash in enumerate(base)
        if (dist := simple_distance(base_hash, query)) <= threshold
    }
|
||
|
||
def _generate_random_hash_with_distance(hash: str, distance: int) -> str: | ||
if not (0 <= distance <= 256): | ||
raise ValueError("Distance must be between 0 and 256") | ||
|
||
hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary | ||
bits = list(hash_bits) | ||
positions = random.sample( | ||
range(256), distance | ||
) # Randomly select unique positions to flip | ||
for pos in positions: | ||
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions | ||
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex | ||
Dcallies marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return modified_hash | ||
|
||
|
||
def test_pdq_index():
    """Index query results must agree exactly with a brute-force scan."""
    base_hashes = SAMPLE_HASHES
    # Make sure base_hashes and query_hashes have at least 10 similar hashes
    query_hashes = SAMPLE_HASHES[:10] + _generate_sample_hashes(10)

    index = PDQIndex2()
    for entry_id, base_hash in enumerate(base_hashes):
        index.add(base_hash, entry_id)

    for query_hash in query_hashes:
        expected_indices = _brute_force_match(base_hashes, query_hash)

        result_indices: t.Set[t.Tuple[t.Any, int]] = set()
        for match in index.query(query_hash):
            result_indices.add((match.metadata, match.similarity_info.distance))

        assert result_indices == expected_indices, (
            f"Mismatch for hash {query_hash}: "
            f"Expected {expected_indices}, Got {result_indices}"
        )
|
||
|
||
def test_pdq_index_with_exact_distance():
    """A query at a known distance should match whenever it is within the threshold."""
    thresholds: t.List[int] = [10, 31, 50]

    # Build (hash, id) pairs with enumerate: O(n), and correct even if the
    # sample set ever contained duplicates (list.index would return the
    # first occurrence and was O(n^2) overall).
    entries = [(h, i) for i, h in enumerate(SAMPLE_HASHES)]
    indexes = [
        PDQIndex2(entries=entries, threshold=thres)
        for thres in thresholds
    ]

    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
    query_hash = SAMPLE_HASHES[0]

    for threshold, index in zip(thresholds, indexes):
        for dist in distances:
            query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
            results = index.query(query_hash_w_dist)
            result_indices = {result.similarity_info.distance for result in results}
            if dist <= threshold:
                assert dist in result_indices
|
||
|
||
def test_empty_index_query():
    """Test querying an empty index."""
    empty_index = PDQIndex2()

    # With nothing indexed, any query must come back empty.
    assert not empty_index.query(PdqSignal.get_random_signal())
|
||
|
||
def test_sample_set_no_match():
    """Test no matches in sample set."""
    # enumerate is O(n) and correct with duplicate hashes, unlike the
    # previous SAMPLE_HASHES.index(h), which was O(n^2).
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])
    results = index.query("b" * 64)
    assert len(results) == 0
|
||
|
||
def test_duplicate_handling():
    """Test how the index handles duplicate entries."""
    # enumerate is O(n) and correct with duplicate hashes, unlike the
    # previous SAMPLE_HASHES.index(h), which was O(n^2).
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])

    # Add the same hash three more times under distinct entry ids.
    index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)])

    results = index.query(SAMPLE_HASHES[0])

    # Should find all entries associated with the hash: original + 3 dupes,
    # every one at distance 0.
    assert len(results) == 4
    for result in results:
        assert result.similarity_info.distance == 0
|
||
|
||
def test_one_entry_sample_index():
    """Test how the index handles when it only has one entry."""
    index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)])

    matching_test_hash = SAMPLE_HASHES[0]  # This is the existing hash in index
    unmatching_test_hash = SAMPLE_HASHES[1]

    hit_results = index.query(matching_test_hash)
    # Exactly the one stored entry, at distance 0.
    assert len(hit_results) == 1
    assert hit_results[0].similarity_info.distance == 0

    # A different random hash should fall outside the match threshold.
    assert len(index.query(unmatching_test_hash)) == 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can probably leave this off and rely on
build