From 43a705b9396c30414c8aceb68f5e2e76abc32467 Mon Sep 17 00:00:00 2001 From: Patrick Renner Date: Wed, 2 Oct 2024 10:23:33 -0500 Subject: [PATCH 1/5] add similarity to raw_lookup endpoint --- .../src/OpenMediaMatch/blueprints/matching.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index a68c82044..d2b4c384e 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -95,14 +95,14 @@ def raw_lookup(): * Signal value (the hash) * Optional list of banks to restrict search to Output: - * List of matching content items + * List of matching with content_id and distance values """ signal = require_request_param("signal") signal_type_name = require_request_param("signal_type") return lookup_signal(signal, signal_type_name) -def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: +def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str]]: storage = get_storage() signal_type = _validate_and_transform_signal_type(signal_type_name, storage) @@ -118,7 +118,15 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: current_app.logger.debug("[lookup_signal] querying index") results = index.query(signal) current_app.logger.debug("[lookup_signal] query complete") - return {"matches": [m.metadata for m in results]} + return { + "matches": [ + { + "content_id": m.metadata, + "distance": m.similarity_info.pretty_str(), + } + for m in results + ] + } def _validate_and_transform_signal_type( From 982ae6a1c39913ded58362431e42e6e3201f35ee Mon Sep 17 00:00:00 2001 From: Andrew Dillon Date: Wed, 2 Oct 2024 12:40:23 -0500 Subject: [PATCH 2/5] Make tests pass locally --- .../src/OpenMediaMatch/blueprints/matching.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index d2b4c384e..b63f030af 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -7,6 +7,7 @@ from dataclasses import dataclass import datetime import random +import sys import typing as t import time @@ -308,9 +309,16 @@ def index_cache_is_stale() -> bool: def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: entry = _get_index_cache().get(signal_type.get_name()) + + if entry is not None and is_in_pytest(): + entry.reload_if_needed(get_storage()) + if entry is None: current_app.logger.debug("[lookup_signal] no cache, loading index") return get_storage().get_signal_type_index(signal_type) if entry.is_ready: return entry.index return None + +def is_in_pytest(): + return "pytest" in sys.modules \ No newline at end of file From e9e823dae154bc8d370538d630ca8fb841cd40c9 Mon Sep 17 00:00:00 2001 From: Patrick Renner Date: Wed, 2 Oct 2024 15:29:50 -0500 Subject: [PATCH 3/5] update funcs to remove breaking change --- .../OpenMediaMatch/blueprints/development.py | 14 +++++++--- .../src/OpenMediaMatch/blueprints/matching.py | 27 ++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py index 3f3ed763b..6e9abc0dd 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py @@ -10,7 +10,10 @@ from werkzeug.exceptions import HTTPException from OpenMediaMatch.blueprints.hashing import hash_media -from OpenMediaMatch.blueprints.matching import lookup_signal +from OpenMediaMatch.blueprints.matching import ( + lookup_signal, + lookup_signal_with_distance, +) from OpenMediaMatch.utils.flask_utils import api_error_handler from OpenMediaMatch.utils import dev_utils @@ -50,6 +53,11 @@ def query_media(): return signal_type_to_signal_map abort(500, "Something went wrong while hashing the provided media.") + include_distance = bool(request.args.get("include_distance", False)) == True + lookup_signal_func = ( + lookup_signal_with_distance if include_distance else lookup_signal + ) + # Check if signal_type is an option in the map of hashes signal_type_name = request.args.get("signal_type") if signal_type_name is not None: @@ -59,14 +67,14 @@ def query_media(): f"Requested signal type '{signal_type_name}' is not supported for the provided " "media.", ) - return lookup_signal( + return lookup_signal_func( signal_type_to_signal_map[signal_type_name], signal_type_name ) return { "matches": list( itertools.chain( *map( - lambda x: lookup_signal(x[1], x[0])["matches"], + lambda x: lookup_signal_func(x[1], x[0])["matches"], signal_type_to_signal_map.items(), ), ) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index b63f030af..27cb5d51d 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -17,6 +17,7 @@ from threatexchange.signal_type.signal_base import SignalType from threatexchange.signal_type.index import SignalTypeIndex +from threatexchange.signal_type.index import IndexMatch from OpenMediaMatch.background_tasks.development import get_apscheduler from OpenMediaMatch.storage import interface @@ -96,14 +97,19 @@ def raw_lookup(): * Signal value (the hash) * Optional list of banks to restrict search to Output: - * List of matching with content_id and distance values + * List of matching with content_id and, if included, distance values """ signal = require_request_param("signal") signal_type_name = require_request_param("signal_type") - return lookup_signal(signal, signal_type_name) + include_distance = bool(request.args.get("include_distance", False)) == True + lookup_signal_func = ( + lookup_signal_with_distance if include_distance else lookup_signal + ) + + return lookup_signal_func(signal, signal_type_name) -def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str]]: +def query_index(signal: str, signal_type_name: str) -> IndexMatch: storage = get_storage() signal_type = _validate_and_transform_signal_type(signal_type_name, storage) @@ -119,6 +125,18 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str current_app.logger.debug("[lookup_signal] querying index") results = index.query(signal) current_app.logger.debug("[lookup_signal] query complete") + return results + + +def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: + results = query_index(signal, signal_type_name) + return {"matches": [m.metadata for m in results]} + + +def lookup_signal_with_distance( + signal: str, signal_type_name: str +) -> dict[str, dict[str, str]]: + results = query_index(signal, signal_type_name) return { "matches": [ { @@ -320,5 +338,6 @@ def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: return entry.index return None + def is_in_pytest(): - return "pytest" in sys.modules \ No newline at end of file + return "pytest" in sys.modules From f27623d066f2f13286b9c80edea3736f42626087 Mon Sep 17 00:00:00 2001 From: Patrick Renner Date: Wed, 2 Oct 2024 16:23:20 -0500 Subject: [PATCH 4/5] try verifying indexing directly in test_api --- .../src/OpenMediaMatch/blueprints/matching.py | 7 ------- .../src/OpenMediaMatch/tests/test_api.py | 3 +++ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index 27cb5d51d..2c8949467 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -328,16 +328,9 @@ def index_cache_is_stale() -> bool: def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: entry = _get_index_cache().get(signal_type.get_name()) - if entry is not None and is_in_pytest(): - entry.reload_if_needed(get_storage()) - if entry is None: current_app.logger.debug("[lookup_signal] no cache, loading index") return get_storage().get_signal_type_index(signal_type) if entry.is_ready: return entry.index return None - - -def is_in_pytest(): - return "pytest" in sys.modules diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py index 6b9ed728b..071fb6d39 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py @@ -2,6 +2,7 @@ import typing as t +from OpenMediaMatch.blueprints.matching import IndexCache, _get_index_cache from flask.testing import FlaskClient from flask import Flask @@ -192,6 +193,8 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient): # Build index build_all_indices(storage, storage, storage) + cache = _get_index_cache().get("pdq") + cache.reload_if_needed(get_storage()) # Test against first image post_response = client.get( From 89f5c7ddef9f1554e9460f841a38801ee03a5d19 Mon Sep 17 00:00:00 2001 From: Patrick Renner Date: Wed, 2 Oct 2024 16:24:26 -0500 Subject: [PATCH 5/5] Revert "try verifying indexing directly in test_api" This reverts commit f27623d066f2f13286b9c80edea3736f42626087. --- .../src/OpenMediaMatch/blueprints/matching.py | 7 +++++++ .../src/OpenMediaMatch/tests/test_api.py | 3 --- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index 2c8949467..27cb5d51d 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -328,9 +328,16 @@ def index_cache_is_stale() -> bool: def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: entry = _get_index_cache().get(signal_type.get_name()) + if entry is not None and is_in_pytest(): + entry.reload_if_needed(get_storage()) + if entry is None: current_app.logger.debug("[lookup_signal] no cache, loading index") return get_storage().get_signal_type_index(signal_type) if entry.is_ready: return entry.index return None + + +def is_in_pytest(): + return "pytest" in sys.modules diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py index 071fb6d39..6b9ed728b 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py @@ -2,7 +2,6 @@ import typing as t -from OpenMediaMatch.blueprints.matching import IndexCache, _get_index_cache from flask.testing import FlaskClient from flask import Flask @@ -193,8 +192,6 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient): # Build index build_all_indices(storage, storage, storage) - cache = _get_index_cache().get("pdq") - cache.reload_if_needed(get_storage()) # Test against first image post_response = client.get(