[pytx] Pytest: unittest => pytest Conversion (#1742)
b8zhong authored Feb 1, 2025
1 parent e351a0e commit c4dd84e
Showing 6 changed files with 187 additions and 190 deletions.
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest

import pytest
from threatexchange.signal_type import (
    md5,
    raw_text,
@@ -11,40 +10,47 @@
)
from threatexchange.signal_type.pdq import signal
from threatexchange.signal_type.signal_base import TextHasher
import typing as t


SIGNAL_TYPES_TO_TEST = [
    md5.VideoMD5Signal,
    signal.PdqSignal,
    raw_text.RawTextSignal,
    trend_query.TrendQuerySignal,
    url_md5.UrlMD5Signal,
    url.URLSignal,
]


class SignalTypeHashTest(unittest.TestCase):
    # TODO - maybe make a metaclass for this to automatically detect?
    SIGNAL_TYPES_TO_TEST = [
        md5.VideoMD5Signal,
        signal.PdqSignal,
        raw_text.RawTextSignal,
        trend_query.TrendQuerySignal,
        url_md5.UrlMD5Signal,
        url.URLSignal,
    ]

    def test_signal_names_unique(self):
        seen = {}
        for s in self.SIGNAL_TYPES_TO_TEST:
            name = s.get_name()
            assert (
                name not in seen
            ), f"Two signal types share the same name: {s!r} and {seen[name]!r}"
            seen[name] = s

    def test_signal_types_have_content(self):
        for s in self.SIGNAL_TYPES_TO_TEST:
            assert s.get_content_types(), "{s!r} has no content types"

    def test_str_hashers_have_impl(self):
        text_hashers = [
            s for s in self.SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)
        ]
        for s in text_hashers:
            assert s.hash_from_str(
                "test string"
            ), "{s!r} produced no output from hasher"


def test_signal_names_unique():
    """
    Sanity check for signal type hashing methods.
    Verify uniqueness of signal type names across all signal types.
    """
    seen: dict[str, t.Any] = {}
    for signal_type in SIGNAL_TYPES_TO_TEST:
        name = signal_type.get_name()
        assert (
            name not in seen
        ), f"Two signal types share the same name: {signal_type!r} and {seen[name]}"
        seen[name] = signal_type


@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST)
def test_signal_types_have_content(signal_type):
    """
    Ensure that each signal type has associated content types.
    """
    assert signal_type.get_content_types(), f"{signal_type!r} has no content types"


@pytest.mark.parametrize(
    "signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)]
)
def test_str_hashers_have_impl(signal_type):
    """
    Check that each TextHasher has an implementation that produces output.
    """
    assert signal_type.hash_from_str(
        "test string"
    ), f"{signal_type!r} produced no output from hasher"
@@ -1,23 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pathlib

import pytest
from threatexchange.signal_type.md5 import VideoMD5Signal

TEST_FILE = pathlib.Path(__file__).parent.parent.parent.parent.joinpath(
    "data", "sample-b.jpg"
)


class VideoMD5SignalTestCase(unittest.TestCase):
    def setUp(self):
        self.a_file = open(TEST_FILE, "rb")

    def tearDown(self):
        self.a_file.close()

    def test_can_hash_simple_files(self):
        assert "d35c785545392755e7e4164457657269" == VideoMD5Signal.hash_from_bytes(
            self.a_file.read()
        ), "MD5 hash does not match"


def test_can_hash_simple_files():
    """
    Test that the VideoMD5Signal produces the expected hash.
    """
    with open(TEST_FILE, "rb") as f:
        file_content = f.read()

    expected_hash = "d35c785545392755e7e4164457657269"
    computed_hash = VideoMD5Signal.hash_from_bytes(file_content)
    assert computed_hash == expected_hash, "MD5 hash does not match"
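
The rewritten test reads the file eagerly and compares against a fixed digest. Judging by the name, VideoMD5Signal.hash_from_bytes is presumably a thin wrapper over a plain MD5 hex digest; a hedged sketch of the equivalent stdlib call (an assumption, not taken from the diff):

import hashlib

# Assumption: equivalent to VideoMD5Signal.hash_from_bytes for raw bytes.
def md5_hex(data: bytes) -> str:
    return hashlib.md5(data).hexdigest()

# Under that assumption, md5_hex of data/sample-b.jpg's bytes would equal
# "d35c785545392755e7e4164457657269".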
197 changes: 85 additions & 112 deletions python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py
@@ -1,8 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pickle
import typing as t
import pytest
import functools

from threatexchange.signal_type.index import (
@@ -13,139 +11,72 @@
test_entries = [
    (
        "0000000000000000000000000000000000000000000000000000000000000000",
        dict(
            {
                "hash_type": "pdq",
                "system_id": 9,
            }
        ),
        {"hash_type": "pdq", "system_id": 9},
    ),
    (
        "000000000000000000000000000000000000000000000000000000000000ffff",
        dict(
            {
                "hash_type": "pdq",
                "system_id": 8,
            }
        ),
        {"hash_type": "pdq", "system_id": 8},
    ),
    (
        "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f",
        dict(
            {
                "hash_type": "pdq",
                "system_id": 7,
            }
        ),
        {"hash_type": "pdq", "system_id": 7},
    ),
    (
        "f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0",
        dict(
            {
                "hash_type": "pdq",
                "system_id": 6,
            }
        ),
        {"hash_type": "pdq", "system_id": 6},
    ),
    (
        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
        dict(
            {
                "hash_type": "pdq",
                "system_id": 5,
            }
        ),
        {"hash_type": "pdq", "system_id": 5},
    ),
]
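
The five entries are spaced at known Hamming distances, which is what the distance assertions later in the file rely on. A quick sketch of how those distances fall out of the hex strings (the helper name is ours, not from the codebase):

# Hamming distance between two hex-encoded 256-bit PDQ hashes.
def hamming_distance(hex_a: str, hex_b: str) -> int:
    return bin(int(hex_a, 16) ^ int(hex_b, 16)).count("1")

# The all-zero entry and the "...ffff" entry differ in exactly 16 bits,
# matching the distance-16 match asserted below.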


class TestPDQIndex(unittest.TestCase):
    def setUp(self):
        self.index = PDQIndex.build(test_entries)

    def assertEqualPDQIndexMatchResults(
        self, result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
    ):
        self.assertEqual(
            len(result), len(expected), "search results not of expected length"
        )

        accum_type = t.Dict[int, t.Set[int]]

        # Between python 3.8.6 and 3.8.11, something caused the order of results
        # from the index to change. This was noticed for items which had the
        # same distance. To allow for this, we convert result and expected
        # arrays from
        # [PDQIndexMatch, PDQIndexMatch] to { distance: <set of PDQIndexMatch.metadata hash> }
        # This allows you to compare [PDQIndexMatch A, PDQIndexMatch B] with
        # [PDQIndexMatch B, PDQIndexMatch A] as long as A.distance == B.distance.
        def quality_indexed_dict_reducer(
            acc: accum_type, item: PDQIndexMatch
        ) -> accum_type:
            acc[item.similarity_info.distance] = acc.get(
                item.similarity_info.distance, set()
            )
            # Instead of storing the unhashable item.metadata dict, store its
            # hash so we can compare using self.assertSetEqual
            acc[item.similarity_info.distance].add(hash(frozenset(item.metadata)))
            return acc

        # Convert results to distance -> set of metadata map
        distance_to_result_items_map: accum_type = functools.reduce(
            quality_indexed_dict_reducer, result, {}
        )

        # Convert expected to distance -> set of metadata map
        distance_to_expected_items_map: accum_type = functools.reduce(
            quality_indexed_dict_reducer, expected, {}
        )

        assert len(distance_to_expected_items_map) == len(
            distance_to_result_items_map
        ), "Unequal number of items in expected and results."

        for distance, result_items in distance_to_result_items_map.items():
            assert (
                distance in distance_to_expected_items_map
            ), f"Unexpected distance {distance} found"
            self.assertSetEqual(result_items, distance_to_expected_items_map[distance])

    def test_search_index_for_matches(self):
        entry_hash = test_entries[1][0]
        result = self.index.query(entry_hash)
        self.assertEqualPDQIndexMatchResults(
            result,
            [
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
                ),
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
                ),
            ],
        )

    def test_search_index_with_no_match(self):
        query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
        result = self.index.query(query_hash)
        self.assertEqualPDQIndexMatchResults(result, [])

    def test_supports_pickling(self):
        pickled_data = pickle.dumps(self.index)
        assert pickled_data != None, "index does not support pickling to a data stream"

        reconstructed_index = pickle.loads(pickled_data)
        assert (
            reconstructed_index != None
        ), "index does not support unpickling from data stream"
        assert (
            reconstructed_index.index.faiss_index != self.index.index.faiss_index
        ), "unpickling should create it's own faiss index in memory"

        query = test_entries[0][0]
        result = reconstructed_index.query(query)
        self.assertEqualPDQIndexMatchResults(
            result,
            [
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
                ),
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
                ),
            ],
        )


@pytest.fixture
def index():
    return PDQIndex.build(test_entries)


def assert_equal_pdq_index_match_results(
    result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
    assert len(result) == len(expected), "Search results not of expected length"

    def quality_indexed_dict_reducer(
        acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch
    ) -> t.Dict[int, t.Set[int]]:
        acc[item.similarity_info.distance] = acc.get(
            item.similarity_info.distance, set()
        )
        acc[item.similarity_info.distance].add(hash(frozenset(item.metadata)))
        return acc

    distance_to_result_items_map: t.Dict[int, t.Set[int]] = functools.reduce(
        quality_indexed_dict_reducer, result, {}
    )
    distance_to_expected_items_map: t.Dict[int, t.Set[int]] = functools.reduce(
        quality_indexed_dict_reducer, expected, {}
    )

    assert len(distance_to_expected_items_map) == len(
        distance_to_result_items_map
    ), "Unequal number of distance groups"

    for distance, result_items in distance_to_result_items_map.items():
        assert (
            distance in distance_to_expected_items_map
        ), f"Unexpected distance {distance} found in results"
        assert result_items == distance_to_expected_items_map[distance], (
            f"Mismatch at distance {distance}. "
            f"Expected: {distance_to_expected_items_map[distance]}, Got: {result_items}"
        )


@pytest.mark.parametrize(
    "entry_hash, expected_matches",
    [
        (
            test_entries[1][0],
            [
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
                ),
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
                ),
            ],
        ),
        (
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            [],
        ),
    ],
)
def test_search_index(index, entry_hash, expected_matches):
    result = index.query(entry_hash)
    assert_equal_pdq_index_match_results(result, expected_matches)
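
assert_equal_pdq_index_match_results compares match lists order-insensitively because, as the removed comment notes, equally distant results came back in different orders across Python 3.8.x point releases. A self-contained sketch of the same grouping idea on plain (distance, metadata) tuples; note that this sketch hashes metadata.items(), which also captures values, whereas the helper above hashes only the dict's keys via frozenset(item.metadata):

import functools
import typing as t

def group_by_distance(
    matches: t.List[t.Tuple[int, t.Dict[str, int]]]
) -> t.Dict[int, t.Set[int]]:
    def reducer(acc, match):
        distance, metadata = match
        # Hash the metadata so it can live in a set despite dicts
        # being unhashable.
        acc.setdefault(distance, set()).add(hash(frozenset(metadata.items())))
        return acc

    return functools.reduce(reducer, matches, {})

a = [(0, {"system_id": 9}), (16, {"system_id": 8})]
b = [(16, {"system_id": 8}), (0, {"system_id": 9})]  # same matches, reordered
assert group_by_distance(a) == group_by_distance(b)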


def test_partial_match_below_threshold(index):
    query_hash = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000"
    result = index.query(query_hash)
    assert_equal_pdq_index_match_results(result, [])


def test_supports_pickling(index):
    pickled_data = pickle.dumps(index)
    assert pickled_data is not None, "Index does not support pickling to a data stream"

    reconstructed_index = pickle.loads(pickled_data)
    assert (
        reconstructed_index is not None
    ), "Index does not support unpickling from data stream"
    assert (
        reconstructed_index.index.faiss_index != index.index.faiss_index
    ), "Unpickling should create its own FAISS index in memory"

    assert len(reconstructed_index) == len(
        index
    ), "Index size mismatch after unpickling"

    query = test_entries[0][0]
    result = reconstructed_index.query(query)
    assert_equal_pdq_index_match_results(
        result,
        [
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]),
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]),
        ],
    )
@@ -104,12 +104,14 @@ def test_pdq_index_with_exact_distance():
def test_serialize_deserialize_index():
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])
    index: PDQIndex2 = PDQIndex2(
        entries=[(h, base_hashes.index(h)) for h in base_hashes]
    )

    buffer = io.BytesIO()
    index.serialize(buffer)
    buffer.seek(0)
    deserialized_index = PDQIndex2.deserialize(buffer)
    deserialized_index: PDQIndex2 = PDQIndex2.deserialize(buffer)

    assert isinstance(deserialized_index, PDQIndex2)
    assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2)
@@ -120,7 +122,7 @@ def test_serialize_deserialize_index():

def test_empty_index_query():
    """Test querying an empty index."""
    index = PDQIndex2()
    index: PDQIndex2 = PDQIndex2()

    # Query should return empty list
    results = index.query(PdqSignal.get_random_signal())
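
The commit's core pattern replaces unittest's setUp/tearDown (such as the removed VideoMD5SignalTestCase methods and TestPDQIndex.setUp) with pytest fixtures or plain context managers. A minimal sketch of that pattern, including teardown-after-yield (names are illustrative, not from the diff; tmp_path is pytest's built-in temporary-directory fixture):

import pytest

@pytest.fixture
def sample_file(tmp_path):
    # Setup: write a throwaway file to pytest's temp directory.
    path = tmp_path / "sample.bin"
    path.write_bytes(b"hello")
    with open(path, "rb") as f:
        yield f  # the test body runs at this point
    # Teardown: leaving the with-block closes the file after the test.

def test_reads_sample(sample_file):
    assert sample_file.read() == b"hello"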
