-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pytx] Implement a new cleaner PDQ index solution #1698
Changes from 5 commits
300df97
9fd7744
3cf072b
cd44cbd
a5bfb6d
0a32d65
610af2e
2ea86a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
""" | ||
Implementation of SignalTypeIndex abstraction for PDQ by wrapping | ||
hashing.pdq_faiss_matcher. | ||
""" | ||
|
||
import typing as t | ||
import faiss | ||
import numpy as np | ||
|
||
|
||
from threatexchange.signal_type.index import ( | ||
IndexMatchUntyped, | ||
SignalSimilarityInfoWithIntDistance, | ||
SignalTypeIndex, | ||
T as IndexT, | ||
) | ||
from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from threatexchange.signal_type.pdq.pdq_utils import ( | ||
BITS_IN_PDQ, | ||
convert_pdq_strings_to_ndarray, | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
|
||
PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT] | ||
|
||
|
||
class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

    Each distinct PDQ hash occupies exactly one slot in the underlying
    faiss index; duplicate adds of the same hash are folded into the
    existing slot's entry list, since faiss itself cannot deduplicate.
    """

    def __init__(
        self,
        index: t.Optional[faiss.Index] = None,
        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
        *,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
    ) -> None:
        super().__init__()
        self.threshold = threshold

        if index is None:
            index = faiss.IndexFlatL2(BITS_IN_PDQ)
        self.index = _PDQFaissIndex(index)

        # Maps each distinct hash string to its faiss id.
        self._deduper: t.Dict[str, int] = {}
        # Position i holds every entry stored under faiss id i.
        self._idx_to_entries: t.List[t.List[IndexT]] = []

        self.add_all(entries=entries)

    def __len__(self) -> int:
        # Counts distinct hashes, not total entries.
        return len(self._idx_to_entries)

    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """
        Look up entries against the index, up to the max supported distance.
        """
        found: t.List[t.Tuple[int, int]] = self.index.search(
            queries=[hash], threshold=self.threshold
        )
        # Fan each faiss hit out to every entry stored under that hash.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(dist)),
                entry,
            )
            for faiss_id, dist in found
            for entry in self._idx_to_entries[faiss_id]
        ]

    def add(self, signal_str: str, entry: IndexT) -> None:
        self.add_all(((signal_str, entry),))

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        for signal_str, entry in entries:
            faiss_id = self._deduper.get(signal_str)
            if faiss_id is not None:
                # Hash already indexed: faiss can't hold duplicates, so
                # just attach the new entry to the existing slot.
                self._idx_to_entries[faiss_id].append(entry)
            else:
                self.index.add([signal_str])
                self._idx_to_entries.append([entry])
                # Faiss assigns ids sequentially from 0, so the next id
                # equals the number of hashes stored so far.
                self._deduper[signal_str] = len(self._deduper)
|
||
|
||
class _PDQFaissIndex:
    """
    A wrapper around the faiss index for pickle serialization.
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.
        """
        vectors = convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD
    ) -> t.List[t.Tuple[int, int]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Returns a flat list of (faiss_id, distance) pairs across all
        queries, with distance <= threshold.
        """
        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
        # range_search's radius is exclusive; +1 makes `threshold` inclusive.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.Tuple[int, int]] = []
        for i in range(len(queries)):
            span = slice(limits[i], limits[i + 1])
            # Cast numpy scalars (int64 ids, float32 distances) to plain
            # ints so the result actually matches the declared type.
            results.extend(
                (int(idx), int(dist))
                for idx, dist in zip(indices[span], distances[span])
            )
        return results

    def __getstate__(self):
        # faiss indexes aren't picklable directly; serialize to a buffer.
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import typing as t | ||
import numpy as np | ||
import random | ||
|
||
from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 | ||
from threatexchange.signal_type.pdq.signal import PdqSignal | ||
from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray | ||
|
||
SAMPLE_HASHES = [PdqSignal.get_random_signal() for _ in range(100)] | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _brute_force_match(
    base: t.List[str], query: str, threshold: int = 32
) -> t.Set[int]:
    """
    Return the indices of `base` hashes within `threshold` hamming
    distance of `query`, by exhaustive pairwise comparison.

    Converts the whole corpus in a single call and compares with one
    vectorized pass, instead of one conversion + Python-level compare
    per base hash per query (which made the caller's loop quadratic
    in conversion work).
    """
    if not base:
        return set()
    query_arr = convert_pdq_strings_to_ndarray([query])[0]
    base_arrs = convert_pdq_strings_to_ndarray(base)
    # Hamming distance = number of differing bit positions per row.
    hamming = np.count_nonzero(base_arrs != query_arr, axis=1)
    return set(np.flatnonzero(hamming <= threshold).tolist())
|
||
|
||
def _generate_random_hash_with_distance(hash: str, distance: int) -> str: | ||
if len(hash) != 64 or not all(c in "0123456789abcdef" for c in hash.lower()): | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError("Hash must be a 64-character hexadecimal string") | ||
if distance < 0 or distance > 256: | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError("Distance must be between 0 and 256") | ||
|
||
hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary | ||
bits = list(hash_bits) | ||
positions = random.sample( | ||
range(256), distance | ||
) # Randomly select unique positions to flip | ||
for pos in positions: | ||
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions | ||
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex | ||
Dcallies marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return modified_hash | ||
|
||
|
||
def test_pdq_index():
    """End-to-end check: index results must agree with brute force."""
    # Make sure base_hashes and query_hashes have at least 100 similar hashes
    base_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(1000)]
    query_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(10000)]

    brute_force_matches = {
        q: _brute_force_match(base_hashes, q) for q in query_hashes
    }

    index = PDQIndex2()
    for entry_id, base_hash in enumerate(base_hashes):
        index.add(base_hash, entry_id)

    for q in query_hashes:
        expected_indices = brute_force_matches[q]
        result_indices = {match.metadata for match in index.query(q)}

        assert result_indices == expected_indices, (
            f"Mismatch for hash {q}: "
            f"Expected {expected_indices}, Got {result_indices}"
        )
|
||
|
||
def test_pdq_index_with_exact_distance():
    """
    Probes generated at known distances must be reported at exactly that
    distance when within the index threshold, and never beyond it.
    """
    thresholds: t.List[int] = [10, 31, 50]
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per hash.
    indexes: t.List[PDQIndex2] = [
        PDQIndex2(
            entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)],
            threshold=thres,
        )
        for thres in thresholds
    ]

    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
    query_hash = SAMPLE_HASHES[0]

    for threshold, index in zip(thresholds, indexes):
        for dist in distances:
            probe = _generate_random_hash_with_distance(query_hash, dist)
            results = index.query(probe)
            result_distances = {r.similarity_info.distance for r in results}
            if dist <= threshold:
                assert dist in result_distances
            else:
                # Nothing past the threshold should surface at this distance.
                assert dist not in result_distances
|
||
|
||
def test_empty_index_query():
    """Test querying an empty index."""
    empty_index = PDQIndex2()

    # No entries were added, so any probe must come back empty.
    assert not empty_index.query(PdqSignal.get_random_signal())
|
||
|
||
def test_sample_set_no_match():
    """Test no matches in sample set."""
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per hash.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])
    # A constant hash is (overwhelmingly likely) far from every random sample.
    results = index.query("b" * 64)
    assert len(results) == 0
|
||
|
||
def test_duplicate_handling():
    """Test how the index handles duplicate entries."""
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per hash.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])

    # Add same hash multiple times
    index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)])

    results = index.query(SAMPLE_HASHES[0])

    # Should find all entries associated with the hash: 1 original + 3 dupes
    assert len(results) == 4
    for result in results:
        assert result.similarity_info.distance == 0
|
||
|
||
def test_one_entry_sample_index():
    """Test how the index handles when it only has one entry."""
    index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)])

    # The only hash in the index must match itself at distance 0.
    hits = index.query(SAMPLE_HASHES[0])
    assert len(hits) == 1
    assert hits[0].similarity_info.distance == 0

    # A different random sample hash should not match the lone entry.
    assert len(index.query(SAMPLE_HASHES[1])) == 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: out of date