Skip to content

Commit

Permalink
update funcs to remove breaking change
Browse files Browse the repository at this point in the history
  • Loading branch information
prenner committed Oct 2, 2024
1 parent 982ae6a commit e9e823d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from werkzeug.exceptions import HTTPException

from OpenMediaMatch.blueprints.hashing import hash_media
from OpenMediaMatch.blueprints.matching import lookup_signal
from OpenMediaMatch.blueprints.matching import (
lookup_signal,
lookup_signal_with_distance,
)
from OpenMediaMatch.utils.flask_utils import api_error_handler

from OpenMediaMatch.utils import dev_utils
Expand Down Expand Up @@ -50,6 +53,11 @@ def query_media():
return signal_type_to_signal_map
abort(500, "Something went wrong while hashing the provided media.")

include_distance = bool(request.args.get("include_distance", False)) == True
lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

# Check if signal_type is an option in the map of hashes
signal_type_name = request.args.get("signal_type")
if signal_type_name is not None:
Expand All @@ -59,14 +67,14 @@ def query_media():
f"Requested signal type '{signal_type_name}' is not supported for the provided "
"media.",
)
return lookup_signal(
return lookup_signal_func(
signal_type_to_signal_map[signal_type_name], signal_type_name
)
return {
"matches": list(
itertools.chain(
*map(
lambda x: lookup_signal(x[1], x[0])["matches"],
lambda x: lookup_signal_func(x[1], x[0])["matches"],
signal_type_to_signal_map.items(),
),
)
Expand Down
27 changes: 23 additions & 4 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.signal_type.index import IndexMatch

from OpenMediaMatch.background_tasks.development import get_apscheduler
from OpenMediaMatch.storage import interface
Expand Down Expand Up @@ -96,14 +97,19 @@ def raw_lookup():
* Signal value (the hash)
* Optional list of banks to restrict search to
Output:
* List of matching with content_id and distance values
* List of matching with content_id and, if included, distance values
"""
signal = require_request_param("signal")
signal_type_name = require_request_param("signal_type")
return lookup_signal(signal, signal_type_name)
include_distance = bool(request.args.get("include_distance", False)) == True
lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

return lookup_signal_func(signal, signal_type_name)


def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str]]:
def query_index(signal: str, signal_type_name: str) -> IndexMatch:
storage = get_storage()
signal_type = _validate_and_transform_signal_type(signal_type_name, storage)

Expand All @@ -119,6 +125,18 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, dict[str, str
current_app.logger.debug("[lookup_signal] querying index")
results = index.query(signal)
current_app.logger.debug("[lookup_signal] query complete")
return results


def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:
results = query_index(signal, signal_type_name)
return {"matches": [m.metadata for m in results]}


def lookup_signal_with_distance(
signal: str, signal_type_name: str
) -> dict[str, dict[str, str]]:
results = query_index(signal, signal_type_name)
return {
"matches": [
{
Expand Down Expand Up @@ -320,5 +338,6 @@ def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None:
return entry.index
return None


def is_in_pytest():
return "pytest" in sys.modules
return "pytest" in sys.modules

0 comments on commit e9e823d

Please sign in to comment.