Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytx] Implement a new cleaner PDQ index solution #1698

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Implementation of SignalTypeIndex abstraction for PDQ by wrapping
hashing.pdq_faiss_matcher.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: out of date

"""

import typing as t
import faiss
import numpy as np


from threatexchange.signal_type.index import (
IndexMatchUntyped,
SignalSimilarityInfoWithIntDistance,
SignalTypeIndex,
T as IndexT,
)
from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
from threatexchange.signal_type.pdq.pdq_utils import (
BITS_IN_PDQ,
convert_pdq_strings_to_ndarray,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]


class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.
    """

    def __init__(
        self,
        index: t.Optional[faiss.Index] = None,
        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
        *,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
    ) -> None:
        super().__init__()
        self.threshold = threshold

        # Default to a flat L2 index; on 0/1 vectors squared L2 equals
        # the Hamming distance between hashes.
        if index is None:
            index = faiss.IndexFlatL2(BITS_IN_PDQ)
        self.index = _PDQFaissIndex(index)

        # hash string -> its position inside the faiss index (dedupes hashes)
        self._deduper: t.Dict[str, int] = {}
        # faiss position -> every entry stored under that hash
        self._idx_to_entries: t.List[t.List[IndexT]] = []

        self.add_all(entries=entries)

    def __len__(self) -> int:
        # One slot per unique hash added to the index.
        return len(self._idx_to_entries)

    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """
        Look up entries against the index, up to the max supported distance.
        """
        found = self.index.search(queries=[hash], threshold=self.threshold)
        # Fan each faiss hit out to all entries registered for that hash.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(dist)),
                entry,
            )
            for faiss_id, dist in found
            for entry in self._idx_to_entries[faiss_id]
        ]

    def add(self, signal_str: str, entry: IndexT) -> None:
        self.add_all(((signal_str, entry),))

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        for signal_str, entry in entries:
            faiss_id = self._deduper.get(signal_str)
            if faiss_id is not None:
                # Faiss cannot hold the same vector under two ids, so
                # additional entries for a known hash just share its slot.
                self._idx_to_entries[faiss_id].append(entry)
                continue
            self.index.add([signal_str])
            self._idx_to_entries.append([entry])
            # Faiss assigns ids sequentially from 0, so the next id equals
            # the number of unique hashes seen so far.
            self._deduper[signal_str] = len(self._deduper)


class _PDQFaissIndex:
    """
    A wrapper around the faiss index for pickle serialization
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.
        """
        vectors = convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD
    ) -> t.List[t.Tuple[int, int]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Returns a flat list of (faiss_id, distance) pairs across all queries.
        Distances are converted to native ints (squared L2 over 0/1 vectors is
        exactly the integer Hamming distance), so the declared return type is
        accurate — previously numpy float scalars leaked out.
        """
        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
        # range_search's radius is exclusive, so +1 keeps distances <= threshold.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.Tuple[int, int]] = []
        for i in range(len(queries)):
            start, end = limits[i], limits[i + 1]
            results.extend(
                (int(idx), int(dist))
                for idx, dist in zip(indices[start:end], distances[start:end])
            )
        return results

    def __getstate__(self):
        # faiss indexes are not picklable directly; serialize to a byte buffer.
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import typing as t

BITS_IN_PDQ = 256
PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4)

Expand Down Expand Up @@ -49,3 +52,19 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool:
"""
distance = simple_distance(pdq_hex_a, pdq_hex_b)
return distance <= threshold


def convert_pdq_strings_to_ndarray(pdq_strings: t.Sequence[str]) -> np.ndarray:
    """
    Convert multiple PDQ hash strings to a numpy array.

    Args:
        pdq_strings: hex-encoded PDQ hashes, each PDQ_HEX_STR_LEN chars long.

    Returns:
        uint8 array of shape (len(pdq_strings), BITS_IN_PDQ), one bit per cell.
        The shape is 2D even for empty input (the previous implementation
        returned shape (0,) for an empty sequence, which faiss cannot ingest).

    Raises:
        ValueError: if any string has the wrong length, or contains non-hex
            characters (via bytes.fromhex).
    """
    if not all(len(pdq_str) == PDQ_HEX_STR_LEN for pdq_str in pdq_strings):
        # Use the constant in the message so it can't drift from the check.
        raise ValueError(
            f"All PDQ hash strings must be {PDQ_HEX_STR_LEN} hex characters long"
        )

    # Decode all hashes into one contiguous buffer and unpack to bits in a
    # single vectorized pass instead of a per-hash append loop.
    all_bytes = b"".join(bytes.fromhex(pdq_str) for pdq_str in pdq_strings)
    bits = np.unpackbits(np.frombuffer(all_bytes, dtype=np.uint8))
    return bits.reshape(len(pdq_strings), BITS_IN_PDQ)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)
from threatexchange.signal_type.pdq.pdq_index import PDQIndex

PDQ_CONFIDENT_MATCH_THRESHOLD = 31


class PdqSignal(
signal_base.SimpleSignalType,
Expand All @@ -43,7 +45,7 @@ class PdqSignal(

# This may need to be updated (TODO make more configurable)
# Hashes of distance less than or equal to this threshold are considered a 'match'
PDQ_CONFIDENT_MATCH_THRESHOLD = 31
PDQ_CONFIDENT_MATCH_THRESHOLD = PDQ_CONFIDENT_MATCH_THRESHOLD
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
# Images with less than quality 50 are too unreliable to match on
QUALITY_THRESHOLD = 50

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import typing as t
import numpy as np
import random

from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray

SAMPLE_HASHES = [PdqSignal.get_random_signal() for _ in range(100)]
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved


def _brute_force_match(
base: t.List[str], query: str, threshold: int = 32
) -> t.Set[int]:
matches = set()
query_arr = convert_pdq_strings_to_ndarray([query])[0]

for i, base_hash in enumerate(base):
base_arr = convert_pdq_strings_to_ndarray([base_hash])[0]
distance = np.count_nonzero(query_arr != base_arr)
if distance <= threshold:
matches.add(i)
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
return matches


def _generate_random_hash_with_distance(hash: str, distance: int) -> str:
if len(hash) != 64 or not all(c in "0123456789abcdef" for c in hash.lower()):
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Hash must be a 64-character hexadecimal string")
if distance < 0 or distance > 256:
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Distance must be between 0 and 256")

hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary
bits = list(hash_bits)
positions = random.sample(
range(256), distance
) # Randomly select unique positions to flip
for pos in positions:
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex
Dcallies marked this conversation as resolved.
Show resolved Hide resolved

return modified_hash


def test_pdq_index():
    """Index results must agree exactly with a brute-force Hamming oracle."""
    # Include SAMPLE_HASHES in the queries so at least 100 guaranteed matches exist.
    base_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(1000)]
    query_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(10000)]

    # Pass threshold=31 explicitly: it must match PDQIndex2's default
    # (PDQ_CONFIDENT_MATCH_THRESHOLD), or pairs at exactly distance 32
    # would make the equality assertion flaky.
    brute_force_matches = {
        query_hash: _brute_force_match(base_hashes, query_hash, threshold=31)
        for query_hash in query_hashes
    }

    index = PDQIndex2()
    # Bulk-add via add_all; entry metadata is the hash's position in base_hashes.
    index.add_all((h, i) for i, h in enumerate(base_hashes))

    for query_hash in query_hashes:
        expected_indices = brute_force_matches[query_hash]
        result_indices = {result.metadata for result in index.query(query_hash)}

        assert result_indices == expected_indices, (
            f"Mismatch for hash {query_hash}: "
            f"Expected {expected_indices}, Got {result_indices}"
        )


def test_pdq_index_with_exact_distance():
    """A hash at a known distance is found iff the distance is within threshold."""
    thresholds: t.List[int] = [10, 31, 50]
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per entry.
    entries = [(h, i) for i, h in enumerate(SAMPLE_HASHES)]
    indexes: t.List[PDQIndex2] = [
        PDQIndex2(entries=entries, threshold=thres) for thres in thresholds
    ]

    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
    query_hash = SAMPLE_HASHES[0]

    for threshold, index in zip(thresholds, indexes):
        for dist in distances:
            query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
            results = index.query(query_hash_w_dist)
            # These are match *distances*, not indices (the old name was misleading).
            result_distances = {result.similarity_info.distance for result in results}
            if dist <= threshold:
                assert dist in result_distances


def test_empty_index_query():
    """Test querying an empty index."""
    index = PDQIndex2()

    # An index with no entries can never produce a match.
    results = index.query(PdqSignal.get_random_signal())
    assert len(results) == 0


def test_sample_set_no_match():
    """Test no matches in sample set."""
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per entry.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])
    results = index.query("b" * 64)
    assert len(results) == 0


def test_duplicate_handling():
    """Test how the index handles duplicate entries."""
    # enumerate avoids the O(n^2) SAMPLE_HASHES.index(h) lookup per entry.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])

    # Add the same hash three more times under fresh entry ids.
    index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)])

    results = index.query(SAMPLE_HASHES[0])

    # Original entry + 3 duplicates, all exact matches at distance 0.
    assert len(results) == 4
    for result in results:
        assert result.similarity_info.distance == 0


def test_one_entry_sample_index():
    """Test how the index handles when it only has one entry."""
    index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)])

    # The banked hash matches itself exactly once, at distance 0.
    hits = index.query(SAMPLE_HASHES[0])
    assert len(hits) == 1
    assert hits[0].similarity_info.distance == 0

    # Any other sample hash should be far away and return nothing.
    misses = index.query(SAMPLE_HASHES[1])
    assert len(misses) == 0