[pytx] ncmec: store checkpoint on large fetches (#1731)
prenner authored Jan 30, 2025
1 parent 2c6210d commit 35c206e
Showing 6 changed files with 251 additions and 40 deletions.
python-threatexchange/threatexchange/exchanges/clients/ncmec/hash_api.py
@@ -565,7 +565,11 @@ def get_entries(
)

def get_entries_iter(
self, *, start_timestamp: int = 0, end_timestamp: int = 0
self,
*,
start_timestamp: int = 0,
end_timestamp: int = 0,
checkpointed_paging_url: str = "",
) -> t.Iterator[GetEntriesResponse]:
"""
A simple wrapper around get_entries to keep fetching until complete.
@@ -574,7 +578,7 @@ def get_entries_iter(
much of the data you have fetched. @see get_entries
"""
has_more = True
next_ = ""
next_ = checkpointed_paging_url
while has_more:
result = self.get_entries(
start_timestamp=start_timestamp,
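For orientation, a minimal sketch of how a caller might use the new checkpointed_paging_url parameter to resume a long fetch. The credentials, environment, import path, and saved URL value below are illustrative assumptions, not part of this commit:

from threatexchange.exchanges.clients.ncmec.hash_api import (
    NCMECEnvironment,
    NCMECHashAPI,
)

client = NCMECHashAPI("user", "pass", NCMECEnvironment.Industry)

# A paging URL previously captured from a GetEntriesResponse.next value
# (hypothetical example value).
saved_paging_url = "/v2/entries?from=2017-10-20T00%3A00%3A00.000Z&start=2001&size=1000"

for response in client.get_entries_iter(
    start_timestamp=0,
    end_timestamp=1664496000,
    checkpointed_paging_url=saved_paging_url,
):
    # Each response is one page; response.next is "" once paging is done.
    print(len(response.updates), response.next)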
python-threatexchange/threatexchange/exchanges/clients/ncmec/tests/data.py
@@ -23,6 +23,14 @@
"&to=2017-10-30T00%3A00%3A00.000Z&start=4001&size=1000&max=5000"
)

ENTRIES_NO_DATA_XML = """
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<queryResult xmlns="https://hashsharing.ncmec.org/hashsharing/v2">
<images count="0"/>
<videos count="0"/>
</queryResult>
""".strip()

ENTRIES_XML = """
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<queryResult xmlns="https://hashsharing.ncmec.org/hashsharing/v2">
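A quick sketch of what the new empty-result fixture encodes; the ElementTree namespace handling below is my own illustration, not code from this commit:

import xml.etree.ElementTree as ET

from threatexchange.exchanges.clients.ncmec.tests.data import ENTRIES_NO_DATA_XML

NS = {"v2": "https://hashsharing.ncmec.org/hashsharing/v2"}
root = ET.fromstring(ENTRIES_NO_DATA_XML)
# Zero images and zero videos, i.e. nothing to page through.
assert root.find("v2:images", NS).get("count") == "0"
assert root.find("v2:videos", NS).get("count") == "0"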
python-threatexchange/threatexchange/exchanges/clients/ncmec/tests/test_hash_api.py
@@ -14,6 +14,7 @@
)
from threatexchange.exchanges.clients.ncmec.tests.data import (
ENTRIES_LARGE_FINGERPRINTS,
ENTRIES_NO_DATA_XML,
ENTRIES_XML,
ENTRIES_XML2,
ENTRIES_XML3,
@@ -78,6 +79,30 @@ def api(monkeypatch: pytest.MonkeyPatch):
return api


@pytest.fixture
def empty_api_response(monkeypatch: pytest.MonkeyPatch):
api = NCMECHashAPI("fake_user", "fake_pass", NCMECEnvironment.test_Industry)

def _mock_get_impl(url: str, **params):
content = ENTRIES_NO_DATA_XML
resp = requests.Response()
resp._content = content.encode()
resp.status_code = 200
resp.content # Set the rest of Response's internal state
return resp

session = None
session = Mock(
strict_spec=["get", "__enter__", "__exit__"],
get=_mock_get_impl,
_put=Mock(),
__enter__=lambda _: session,
__exit__=lambda *args: None,
)
monkeypatch.setattr(api, "_get_session", lambda: session)
return api


def assert_first_entry(entry: NCMECEntryUpdate) -> None:
assert entry.id == "image1"
assert entry.member_id == 42
72 changes: 58 additions & 14 deletions python-threatexchange/threatexchange/exchanges/impl/ncmec_api.py
@@ -39,21 +39,43 @@ class NCMECCheckpoint(

# The biggest value of "to", and the next "from"
get_entries_max_ts: int
# A URL to fetch the next page of results
# Only reference this value through the `paging_url` property
_paging_url: str = ""
# A timestamp of the last fetch, used together with _paging_url.
# NCMEC suggests not storing paging URLs long term, so we consider them
# invalid 12 hours after last_fetch_time.
last_fetch_time: int = field(default_factory=lambda: int(time.time()))

def get_progress_timestamp(self) -> t.Optional[int]:
return self.get_entries_max_ts

@property
def paging_url(self) -> str:
PAGING_URL_EXPIRATION = 12 * 60 * 60
if int(time.time()) - self.last_fetch_time < PAGING_URL_EXPIRATION:
return self._paging_url
return ""

@classmethod
def from_ncmec_fetch(cls, response: api.GetEntriesResponse) -> "NCMECCheckpoint":
"""Synthesizes a checkpoint from the API response"""
return cls(response.max_timestamp)
return cls(response.max_timestamp, response.next, int(time.time()))

def __setstate__(self, d: t.Dict[str, t.Any]) -> None:
"""Implemented for pickle version compatibility."""
# 0.99.0 => 1.0.0:
### field 'max_timestamp' renamed to 'get_entries_max_ts'
if "max_timestamp" in d:
d["get_entries_max_ts"] = d.pop("max_timestamp")

# 1.0.0 => 1.2.3:
# Add last_fetch_time
# note: the default_factory value was not being set correctly when
# reading from pickle
if not "last_fetch_time" in d:
d["last_fetch_time"] = int(time.time())

self.__dict__ = d


@@ -240,8 +262,10 @@ def fetch_iter(
the cursor
"""
start_time = 0
current_paging_url = ""
if checkpoint is not None:
start_time = checkpoint.get_entries_max_ts
current_paging_url = checkpoint.paging_url
# Avoid being exactly at end time for updates showing up multiple
# times in the fetch, since entries are not ordered by time
end_time = int(time.time()) - 5
@@ -274,13 +298,17 @@ def log(event: str) -> None:
duration = max(1, duration) # Infinite loop defense
# Don't fetch past our designated end
current_end = min(end_time, current_start + duration)
updates: t.List[api.NCMECEntryUpdate] = []
entry = None
for i, entry in enumerate(
client.get_entries_iter(
start_timestamp=current_start, end_timestamp=current_end
start_timestamp=current_start,
end_timestamp=current_end,
checkpointed_paging_url=current_paging_url,
)
):
if i == 0: # First batch, check for overfetch
if (
i == 0 and not current_paging_url
): # First batch, check for overfetch when not using a checkpoint
if (
entry.estimated_entries_in_range > self.MAX_FETCH_SIZE
and duration > 1
@@ -303,16 +331,26 @@ def log(event: str) -> None:
# Our entry estimation (based on the cursor parameters)
# occasionally seems to over-estimate
log(f"est {entry.estimated_entries_in_range} entries")
elif i % 100 == 0:
# If we get down to one second, we can potentially be
# fetching an arbitrary large amount of data in one go,
# so log something occasionally
log(f"large fetch ({i}), up to {len(updates)}")
updates.extend(entry.updates)

if i % 100 == 5:
# On large fetches, log a notice every once in a while
log(f"large fetch ({i}) with {len(entry.updates)} updates.")

yield state.FetchDelta(
{f"{entry.member_id}-{entry.id}": entry for entry in entry.updates},
NCMECCheckpoint(
get_entries_max_ts=current_start,
_paging_url=entry.next,
),
)

else: # AKA a successful fetch
# If we're hovering near the single-fetch limit for a period
# of time, we can likely safely expand our range.
if len(updates) < api.NCMECHashAPI.ENTRIES_PER_FETCH * 2:
if (
entry
and len(entry.updates) < api.NCMECHashAPI.ENTRIES_PER_FETCH * 2
):
low_fetch_counter += 1
if low_fetch_counter >= self.FETCH_SHRINK_FACTOR:
log("multiple low fetches, increasing duration")
@@ -321,16 +359,22 @@
low_fetch_counter = 0
# If we are not quite at our limit, but getting close to it,
# pre-emptively shrink to try and stay under the limit
elif len(updates) > self.MAX_FETCH_SIZE / self.FETCH_SHRINK_FACTOR:
elif (
entry
and len(entry.updates)
> self.MAX_FETCH_SIZE / self.FETCH_SHRINK_FACTOR
):
log("close to overfetch limit, reducing duration")
duration //= self.FETCH_SHRINK_FACTOR
low_fetch_counter = 0
else: # Not too small, not too large, just right
low_fetch_counter = 0

yield state.FetchDelta(
{f"{entry.member_id}-{entry.id}": entry for entry in updates},
NCMECCheckpoint(current_end),
{},
NCMECCheckpoint(get_entries_max_ts=current_end),
)
current_paging_url = ""
current_start = current_end

@classmethod
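Taken together, a consumer loop can persist the checkpoint yielded with every delta, so a crash during a large fetch resumes from the paging URL instead of refetching the whole window. A hedged sketch, assuming the hypothetical load_checkpoint, save_checkpoint, and store_updates helpers (everything else mirrors the tests in this commit):

import pickle

from threatexchange.exchanges.clients.ncmec.hash_api import NCMECEnvironment
from threatexchange.exchanges.impl.ncmec_api import (
    NCMECCollabConfig,
    NCMECSignalExchangeAPI,
)

collab = NCMECCollabConfig(NCMECEnvironment.Industry, "Test")
exchange = NCMECSignalExchangeAPI(collab, "user", "pass")

checkpoint = load_checkpoint()  # hypothetical: returns None on first run
for delta in exchange.fetch_iter([], checkpoint):
    store_updates(delta.updates)  # hypothetical persistence of entry updates
    # Every delta now carries a checkpoint; mid-window ones include a
    # paging URL valid for roughly 12 hours, so save it each iteration.
    save_checkpoint(pickle.dumps(delta.checkpoint))  # hypothetical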
141 changes: 120 additions & 21 deletions python-threatexchange/threatexchange/exchanges/impl/tests/test_ncmec.py
@@ -3,7 +3,11 @@
import typing as t
import pytest

from threatexchange.exchanges.clients.ncmec.tests.test_hash_api import api
from threatexchange.exchanges.clients.ncmec.tests.test_hash_api import (
api,
empty_api_response,
)
from threatexchange.exchanges.fetch_state import FetchDelta
from threatexchange.exchanges.impl.ncmec_api import (
NCMECCollabConfig,
NCMECSignalExchangeAPI,
@@ -27,34 +31,101 @@ def exchange(api: NCMECHashAPI, monkeypatch: pytest.MonkeyPatch):
return signal_exchange


@pytest.fixture
def empty_exchange(empty_api_response: NCMECHashAPI, monkeypatch: pytest.MonkeyPatch):
collab = NCMECCollabConfig(NCMECEnvironment.Industry, "Test")
signal_exchange = NCMECSignalExchangeAPI(collab, "user", "pass")
monkeypatch.setattr(
signal_exchange, "get_client", lambda _environment: empty_api_response
)
return signal_exchange


def assert_delta(
delta: FetchDelta,
updates: set[str],
progress_timestamp: int,
is_stale: bool,
get_entries_max_ts: int,
paging_url: str,
) -> None:
assert set(delta.updates) == updates
assert len(delta.updates) == len(updates)
assert delta.checkpoint.get_progress_timestamp() == progress_timestamp
assert delta.checkpoint.is_stale() is is_stale
assert delta.checkpoint.get_entries_max_ts == get_entries_max_ts
assert delta.checkpoint.paging_url == paging_url


def test_fetch(exchange: NCMECSignalExchangeAPI, monkeypatch: pytest.MonkeyPatch):
frozen_time = 1664496000
monkeypatch.setattr("time.time", lambda: frozen_time)
it = exchange.fetch_iter([], None)
# Since our test data from test_hash_api is all in one fetch sequence,
# we'd have to craft some specialized data to get the NCMECSignalAPI to
# split it into multiple updates
total_updates: t.Dict[str, NCMECEntryUpdate] = {}

# Fetch 1
delta = next(it, None)
assert delta is not None
assert len(delta.updates) == 7
total_updates: t.Dict[str, NCMECEntryUpdate] = {}
exchange.naive_fetch_merge(total_updates, delta.updates)
assert_delta(
delta,
{
"42-image1",
"43-image4",
"42-video1",
"42-video4",
},
0,
False,
0,
"/v2/entries?from=2017-10-20T00%3A00%3A00.000Z&to=2017-10-30T00%3A00%3A00.000Z&start=2001&size=1000&max=3000",
)

assert delta.checkpoint.get_progress_timestamp() == frozen_time - 5
assert delta.checkpoint.is_stale() is False
assert delta.checkpoint.get_entries_max_ts == frozen_time - 5

assert set(delta.updates) == {
"43-image4",
"42-image1",
"42-video1",
"42-video4",
"42-image10",
"101-willdelete",
"101-willupdate",
}
# Fetch 2
delta = next(it, None)
assert delta is not None
exchange.naive_fetch_merge(total_updates, delta.updates)

assert_delta(
delta,
{"42-image10"},
0,
False,
0,
"/v2/entries?from=2017-10-20T00%3A00%3A00.000Z&to=2017-10-30T00%3A00%3A00.000Z&start=3001&size=1000&max=4000",
)

# Fetch 3
delta = next(it, None)
assert delta is not None
exchange.naive_fetch_merge(total_updates, delta.updates)
assert_delta(
delta,
{"101-willupdate", "101-willdelete"},
0,
False,
0,
"/v2/entries?from=2017-10-20T00%3A00%3A00.000Z&to=2017-10-30T00%3A00%3A00.000Z&start=4001&size=1000&max=5000",
)

# Fetch 4
delta = next(it, None)
assert delta is not None
exchange.naive_fetch_merge(total_updates, delta.updates)
assert_delta(delta, {"101-willupdate", "101-willdelete"}, 0, False, 0, "")

# No more data, but one final checkpoint
delta = next(it, None)
assert delta is not None
expected_progress_timestamp = frozen_time - 5
assert_delta(
delta,
set(),
expected_progress_timestamp,
False,
expected_progress_timestamp,
"",
)

as_signals = NCMECSignalExchangeAPI.naive_convert_to_signal_type(
[VideoMD5Signal], exchange.collab, total_updates
@@ -63,8 +134,8 @@ def test_fetch(exchange: NCMECSignalExchangeAPI, monkeypatch: pytest.MonkeyPatch
"b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1": NCMECSignalMetadata({42: set()}),
"facefacefacefacefacefacefaceface": NCMECSignalMetadata({101: {"A2"}}),
}
## No more data
assert next(it, None) is None

assert next(it, None) is None # We fetched everything

# Test esp_id filter
collab = NCMECCollabConfig(NCMECEnvironment.Industry, "Test")
Expand All @@ -90,3 +161,31 @@ def test_fetch(exchange: NCMECSignalExchangeAPI, monkeypatch: pytest.MonkeyPatch
"b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1": NCMECSignalMetadata({42: set()}),
"facefacefacefacefacefacefaceface": NCMECSignalMetadata({101: {"A2"}}),
}


def test_empty_fetch(
empty_exchange: NCMECSignalExchangeAPI, monkeypatch: pytest.MonkeyPatch
):
frozen_time = 1664496000
monkeypatch.setattr("time.time", lambda: frozen_time)
it = empty_exchange.fetch_iter([], None)
# No updates
delta = next(it, None)
assert delta is not None
assert_delta(delta, set(), 0, False, 0, "")

# No more data, but one final checkpoint
delta = next(it, None)
assert delta is not None
expected_progress_timestamp = frozen_time - 5
assert_delta(
delta,
set(),
expected_progress_timestamp,
False,
expected_progress_timestamp,
"",
)

delta = next(it, None)
assert delta is None # We fetched everything