From e9e823dae154bc8d370538d630ca8fb841cd40c9 Mon Sep 17 00:00:00 2001 From: Patrick Renner Date: Wed, 2 Oct 2024 15:29:50 -0500 Subject: [PATCH] update funcs to remove breaking change --- .../OpenMediaMatch/blueprints/development.py | 14 +++++++--- .../src/OpenMediaMatch/blueprints/matching.py | 27 ++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py index 3f3ed763b..6e9abc0dd 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/development.py @@ -10,7 +10,10 @@ from werkzeug.exceptions import HTTPException from OpenMediaMatch.blueprints.hashing import hash_media -from OpenMediaMatch.blueprints.matching import lookup_signal +from OpenMediaMatch.blueprints.matching import ( + lookup_signal, + lookup_signal_with_distance, +) from OpenMediaMatch.utils.flask_utils import api_error_handler from OpenMediaMatch.utils import dev_utils @@ -50,6 +53,11 @@ def query_media(): return signal_type_to_signal_map abort(500, "Something went wrong while hashing the provided media.") + include_distance = request.args.get("include_distance", "false").strip().lower() not in ("", "0", "false", "no", "off") + lookup_signal_func = ( + lookup_signal_with_distance if include_distance else lookup_signal + ) + # Check if signal_type is an option in the map of hashes signal_type_name = request.args.get("signal_type") if signal_type_name is not None: @@ -59,14 +67,14 @@ def query_media(): f"Requested signal type '{signal_type_name}' is not supported for the provided " "media.", ) - return lookup_signal( + return lookup_signal_func( signal_type_to_signal_map[signal_type_name], signal_type_name ) return { "matches": list( itertools.chain( *map( - lambda x: lookup_signal(x[1], x[0])["matches"], + lambda x: lookup_signal_func(x[1], x[0])["matches"], signal_type_to_signal_map.items(), ),
) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index b63f030af..27cb5d51d 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -17,6 +17,7 @@ from threatexchange.signal_type.signal_base import SignalType from threatexchange.signal_type.index import SignalTypeIndex +from threatexchange.signal_type.index import IndexMatch from OpenMediaMatch.background_tasks.development import get_apscheduler from OpenMediaMatch.storage import interface @@ -96,14 +97,19 @@ def raw_lookup(): * Signal value (the hash) * Optional list of banks to restrict search to Output: - * List of matching with content_id and distance values + * List of matches with content_id and, when include_distance is set, distance values """ signal = require_request_param("signal") signal_type_name = require_request_param("signal_type") - return lookup_signal(signal, signal_type_name) + include_distance = request.args.get("include_distance", "false").strip().lower() not in ("", "0", "false", "no", "off") + lookup_signal_func = ( + lookup_signal_with_distance if include_distance else lookup_signal + ) + + return lookup_signal_func(signal, signal_type_name) -def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str]]: +def query_index(signal: str, signal_type_name: str) -> t.Sequence[IndexMatch[int]]: storage = get_storage() signal_type = _validate_and_transform_signal_type(signal_type_name, storage) @@ -119,6 +125,18 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str current_app.logger.debug("[lookup_signal] querying index") results = index.query(signal) current_app.logger.debug("[lookup_signal] query complete") + return results + + +def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: + results = query_index(signal, signal_type_name) + return {"matches": [m.metadata for m in results]} + + +def 
lookup_signal_with_distance( + signal: str, signal_type_name: str +) -> dict[str, list[dict[str, str]]]: + results = query_index(signal, signal_type_name) return { "matches": [ { @@ -320,5 +338,6 @@ def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: return entry.index return None + def is_in_pytest(): - return "pytest" in sys.modules \ No newline at end of file + return "pytest" in sys.modules