[pytx] Implement a new cleaner PDQ index solution #1698

Merged 8 commits on Nov 25, 2024
Changes from 7 commits
135 changes: 135 additions & 0 deletions python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Implementation of SignalTypeIndex abstraction for PDQ
"""

import typing as t
import faiss
import numpy as np


from threatexchange.signal_type.index import (
IndexMatchUntyped,
SignalSimilarityInfoWithIntDistance,
SignalTypeIndex,
T as IndexT,
)
from threatexchange.signal_type.pdq.pdq_utils import (
BITS_IN_PDQ,
PDQ_CONFIDENT_MATCH_THRESHOLD,
convert_pdq_strings_to_ndarray,
)

PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]


class PDQIndex2(SignalTypeIndex[IndexT]):
"""
Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

This is a redo of the existing PDQ index, designed to be simpler and to fix
hard-to-squash bugs in the original implementation. It is intended to replace
the original index in pytx 2.0.
"""

def __init__(
self,
index: t.Optional[faiss.Index] = None,
entries: t.Iterable[t.Tuple[str, IndexT]] = (),
Contributor:
I think we can probably leave this off and rely on build

*,
threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
) -> None:
super().__init__()
self.threshold = threshold

if index is None:
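# Default to a flat L2 index (exhaustive search) over all 256 PDQ bit dimensions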
index = faiss.IndexFlatL2(BITS_IN_PDQ)
self._index = _PDQFaissIndex(index)

# Maps each hash to its Faiss id
self._deduper: t.Dict[str, int] = {}
# Entry mapping: the position of each entry list matches the Faiss id of its hash
self._idx_to_entries: t.List[t.List[IndexT]] = []

self.add_all(entries=entries)

def __len__(self) -> int:
return len(self._idx_to_entries)

def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
"""
Look up the given hash against the index, returning all entries within the match threshold.
"""
results: t.List[PDQIndexMatch[IndexT]] = []
matches_list: t.List[t.Tuple[int, int]] = self._index.search(
queries=[hash], threshold=self.threshold
)

for match, distance in matches_list:
entries = self._idx_to_entries[match]
# Create match objects for each entry
results.extend(
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(distance=int(distance)),
entry,
)
for entry in entries
)
return results

def add(self, signal_str: str, entry: IndexT) -> None:
self.add_all(((signal_str, entry),))

def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
for h, i in entries:
existing_faiss_id = self._deduper.get(h)
if existing_faiss_id is None:
self._index.add([h])
self._idx_to_entries.append([i])
next_id = len(self._deduper)  # Faiss assigns ids sequentially from 0, so the next id is the current count
self._deduper[h] = next_id
else:
# This hash is already in Faiss; just record the additional entry so we don't insert duplicate vectors
self._idx_to_entries[existing_faiss_id].append(i)


class _PDQFaissIndex:
"""
A wrapper around the faiss index for pickle serialization
"""

def __init__(self, faiss_index: faiss.Index) -> None:
self.faiss_index = faiss_index

def add(self, pdq_strings: t.Sequence[str]) -> None:
"""
Add PDQ hashes to the FAISS index.
"""
vectors = convert_pdq_strings_to_ndarray(pdq_strings)
self.faiss_index.add(vectors)

def search(
self, queries: t.Sequence[str], threshold: int
) -> t.List[t.Tuple[int, int]]:
"""
Search the FAISS index for matches to the given PDQ queries.
"""
query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
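# Bit vectors are 0/1, so squared L2 distance equals Hamming distance; a radius of threshold + 1 keeps matches at distance <= threshold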
limits, distances, indices = self.faiss_index.range_search(
query_array, threshold + 1
)

results: t.List[t.Tuple[int, int]] = []
for i in range(len(queries)):
matches = [idx.item() for idx in indices[limits[i] : limits[i + 1]]]
dists = [dist for dist in distances[limits[i] : limits[i + 1]]]
for j in range(len(matches)):
results.append((matches[j], dists[j]))
return results

def __getstate__(self):
return faiss.serialize_index(self.faiss_index)

def __setstate__(self, data):
self.faiss_index = faiss.deserialize_index(data)
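For context, a minimal usage sketch of the new index (illustrative only, not part of the diff; it assumes only the APIs shown above: PDQIndex2, query results exposing metadata and similarity_info.distance, and the pickle hooks on _PDQFaissIndex):

import pickle

from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.signal import PdqSignal

# Index five random PDQ hashes, using their list position as metadata
hashes = [PdqSignal.get_random_signal() for _ in range(5)]
index = PDQIndex2(entries=[(h, i) for i, h in enumerate(hashes)])
# (Reviewers suggest relying on the inherited build() classmethod instead,
# e.g. PDQIndex2.build(entries), rather than the entries= constructor arg.)

# A hash already in the index matches itself at distance 0
matches = index.query(hashes[0])
print([(m.metadata, m.similarity_info.distance) for m in matches])

# The faiss wrapper serializes via faiss.serialize_index, so the whole
# index survives a pickle round trip
restored = pickle.loads(pickle.dumps(index))
assert len(restored.query(hashes[0])) == len(matches)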
python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py
@@ -1,8 +1,13 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import typing as t

BITS_IN_PDQ = 256
PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4)
# Hashes of distance less than or equal to this threshold are considered a 'match'
PDQ_CONFIDENT_MATCH_THRESHOLD = 31


def simple_distance_binary(bin_a: str, bin_b: str) -> int:
@@ -49,3 +54,18 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool:
"""
distance = simple_distance(pdq_hex_a, pdq_hex_b)
return distance <= threshold


def convert_pdq_strings_to_ndarray(pdq_strings: t.Iterable[str]) -> np.ndarray:
"""
Convert multiple PDQ hash strings to a numpy array.
"""
binary_arrays = []
for pdq_str in pdq_strings:
if len(pdq_str) != PDQ_HEX_STR_LEN:
raise ValueError("PDQ hash string must be 64 hex characters long")
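# 64 hex chars -> 32 bytes -> 256 unpacked bits (one uint8 per bit)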
hash_bytes = bytes.fromhex(pdq_str)
binary_array = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8))
binary_arrays.append(binary_array)

return np.array(binary_arrays, dtype=np.uint8)
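As a quick sanity check of the new helper (a sketch only, not part of the diff; the all-ones and all-zeros hex strings are just illustrative inputs):

from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray

arr = convert_pdq_strings_to_ndarray(["f" * 64, "0" * 64])
assert arr.shape == (2, 256)      # one row of 256 bits per hash
assert arr.dtype.name == "uint8"  # each element is a single 0/1 bit
assert arr[0].sum() == 256 and arr[1].sum() == 0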
python-threatexchange/threatexchange/signal_type/pdq/signal.py
@@ -12,7 +12,10 @@
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.signal_type import signal_base
from threatexchange.signal_type.pdq.pdq_utils import simple_distance
from threatexchange.signal_type.pdq.pdq_utils import (
simple_distance,
PDQ_CONFIDENT_MATCH_THRESHOLD,
)
from threatexchange.exchanges.impl.fb_threatexchange_signal import (
HasFbThreatExchangeIndicatorType,
)
@@ -42,8 +45,6 @@ class PdqSignal(
INDICATOR_TYPE = "HASH_PDQ"

# This may need to be updated (TODO make more configurable)
# Hashes of distance less than or equal to this threshold are considered a 'match'
PDQ_CONFIDENT_MATCH_THRESHOLD = 31
# Images with less than quality 50 are too unreliable to match on
QUALITY_THRESHOLD = 50

@@ -0,0 +1,144 @@
import typing as t
import numpy as np
import random

from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.pdq.pdq_utils import simple_distance


def _generate_sample_hashes(size: int, seed: int = 42):
Contributor:
42
[image: cover of "The Hitchhiker's Guide to the Galaxy"]

Contributor Author (@haianhng31, Nov 21, 2024):
I'm sorry, what do you mean by the image :P ?

Contributor:
42 is the "Answer to the Ultimate Question of Life, The Universe, and Everything" in the book series "The Hitchhiker's Guide to the Galaxy", of which the image is on the cover. I thought you were a fan :P
Read more: https://en.wikipedia.org/wiki/Phrases_from_The_Hitchhiker%27s_Guide_to_the_Galaxy

random.seed(seed)
return [PdqSignal.get_random_signal() for _ in range(size)]


SAMPLE_HASHES = _generate_sample_hashes(100)
Contributor:
This works for me!



def _brute_force_match(
base: t.List[str], query: str, threshold: int = 31
) -> t.Set[t.Tuple[int, int]]:
matches = set()

for i, base_hash in enumerate(base):
distance = simple_distance(base_hash, query)
if distance <= threshold:
matches.add((i, distance))
return matches


def _generate_random_hash_with_distance(hash: str, distance: int) -> str:
if not (0 <= distance <= 256):
raise ValueError("Distance must be between 0 and 256")

hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary
bits = list(hash_bits)
positions = random.sample(
range(256), distance
) # Randomly select unique positions to flip
for pos in positions:
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex

return modified_hash


def test_pdq_index():
# Make sure base_hashes and query_hashes have at least 10 similar hashes
base_hashes = SAMPLE_HASHES
query_hashes = SAMPLE_HASHES[:10] + _generate_sample_hashes(10)

brute_force_matches = {
query_hash: _brute_force_match(base_hashes, query_hash)
for query_hash in query_hashes
}

index = PDQIndex2()
for i, base_hash in enumerate(base_hashes):
index.add(base_hash, i)
Contributor:
nit: PDQIndex2.build()


for query_hash in query_hashes:
expected_indices = brute_force_matches[query_hash]
index_results = index.query(query_hash)

result_indices: t.Set[t.Tuple[t.Any, int]] = {
(result.metadata, result.similarity_info.distance)
for result in index_results
}

assert result_indices == expected_indices, (
f"Mismatch for hash {query_hash}: "
f"Expected {expected_indices}, Got {result_indices}"
)


def test_pdq_index_with_exact_distance():
thresholds: t.List[int] = [10, 31, 50]

indexes = [
PDQIndex2(
entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES],
threshold=thres,
)
for thres in thresholds
]

distances: t.List[int] = [0, 1, 20, 30, 31, 60]
query_hash = SAMPLE_HASHES[0]

for i in range(len(indexes)):
index = indexes[i]

for dist in distances:
query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
results = index.query(query_hash_w_dist)
result_indices = {result.similarity_info.distance for result in results}
if dist <= thresholds[i]:
assert dist in result_indices


def test_empty_index_query():
"""Test querying an empty index."""
index = PDQIndex2()

# Query should return empty list
results = index.query(PdqSignal.get_random_signal())
assert len(results) == 0


def test_sample_set_no_match():
"""Test no matches in sample set."""
index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES])
results = index.query("b" * 64)
assert len(results) == 0


def test_duplicate_handling():
"""Test how the index handles duplicate entries."""
index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES])

# Add same hash multiple times
index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)])

results = index.query(SAMPLE_HASHES[0])

# Should find all entries associated with the hash
assert len(results) == 4
for result in results:
assert result.similarity_info.distance == 0


def test_one_entry_sample_index():
"""Test how the index handles when it only has one entry."""
index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)])

matching_test_hash = SAMPLE_HASHES[0]  # This hash is already in the index
unmatching_test_hash = SAMPLE_HASHES[1]

results = index.query(matching_test_hash)
# Should find 1 entry associated with the hash
assert len(results) == 1
assert results[0].similarity_info.distance == 0

results = index.query(unmatching_test_hash)
assert len(results) == 0
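One coverage gap worth noting: none of these tests exercise the pickle path added via _PDQFaissIndex.__getstate__ / __setstate__. A possible follow-up test, sketched as a suggestion (the test name and structure are hypothetical):

import pickle


def test_pickle_round_trip():
    """Pickling and unpickling the index should preserve query results."""
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(SAMPLE_HASHES)])
    restored = pickle.loads(pickle.dumps(index))

    expected = {
        (r.metadata, r.similarity_info.distance) for r in index.query(SAMPLE_HASHES[0])
    }
    actual = {
        (r.metadata, r.similarity_info.distance) for r in restored.query(SAMPLE_HASHES[0])
    }
    assert actual == expected
    assert len(restored) == len(index)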