From d373e69da1312401d3b8976948c43d4da99f73e8 Mon Sep 17 00:00:00 2001 From: dcallies Date: Thu, 11 Jan 2024 10:28:14 -0500 Subject: [PATCH 1/2] [omm] Fix seed data commands; add to ui --- open-media-match/src/OpenMediaMatch/app.py | 37 +++---------------- .../src/OpenMediaMatch/blueprints/ui.py | 13 +++++++ .../templates/components/dev_mode_bar.html.j2 | 14 +++++-- 3 files changed, 28 insertions(+), 36 deletions(-) diff --git a/open-media-match/src/OpenMediaMatch/app.py b/open-media-match/src/OpenMediaMatch/app.py index 6b9d1fbcd..15328f11a 100644 --- a/open-media-match/src/OpenMediaMatch/app.py +++ b/open-media-match/src/OpenMediaMatch/app.py @@ -35,6 +35,7 @@ from OpenMediaMatch.persistence import get_storage from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui from OpenMediaMatch.storage.interface import BankConfig +from OpenMediaMatch.utils import dev_utils def _is_debug_mode(): @@ -177,18 +178,9 @@ def site_map(): return routes @app.cli.command("seed") - def seed_data(): - """Insert plausible-looking data into the database layer""" - from threatexchange.signal_type.pdq.signal import PdqSignal - - bank_name = "SEED_BANK" - - storage = get_storage() - storage.bank_update(BankConfig(name=bank_name, matching_enabled_ratio=1.0)) - - for st in (PdqSignal, VideoMD5Signal): - for example in st.get_examples(): - storage.bank_add_content(bank_name, {st.get_name(): example}) + def seed_data() -> None: + """Add sample data API connection""" + dev_utils.seed_sample() @app.cli.command("big-seed") @click.option("-b", "--banks", default=100, show_default=True) @@ -198,26 +190,7 @@ def seed_enourmous(banks: int, seeds: int) -> None: Seed the database with a large number of banks and hashes It will generate n banks and put n/m hashes on each bank """ - storage = get_storage() - - types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal] - - for i in range(banks): - # create bank - bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0) - storage.bank_update(bank, create=True) - - # Add hashes - for _ in range(seeds // banks): - # grab randomly either PDQ or MD5 signal - signal_type = random.choice(types) - random_hash = signal_type.get_random_signal() - - storage.bank_add_content( - bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash} - ) - - print("Finished adding hashes to", bank.name) + dev_utils.seed_banks_random(banks, seeds) @app.cli.command("fetch") def fetch(): diff --git a/open-media-match/src/OpenMediaMatch/blueprints/ui.py b/open-media-match/src/OpenMediaMatch/blueprints/ui.py index af3678c34..7c0d4c4b6 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/ui.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/ui.py @@ -8,6 +8,7 @@ from OpenMediaMatch.blueprints import matching, curation, hashing from OpenMediaMatch.persistence import get_storage +from OpenMediaMatch.utils import dev_utils from OpenMediaMatch.storage.postgres.flask_utils import reset_tables from OpenMediaMatch.storage.postgres.database import db from OpenMediaMatch.utils.time_utils import duration_to_human_str @@ -124,6 +125,18 @@ def upload(): return {"hashes": signals, "banks": sorted(banks)} +@bp.route("/seed_sample", methods=["POST"]) +def seed_sample(): + dev_utils.seed_sample() + return redirect("./") + + +@bp.route("/seed_banks", methods=["POST"]) +def seed_banks(): + dev_utils.seed_banks_random() + return redirect("./") + + @bp.route("/factory_reset", methods=["POST"]) def factory_reset(): reset_tables() diff --git a/open-media-match/src/OpenMediaMatch/templates/components/dev_mode_bar.html.j2 b/open-media-match/src/OpenMediaMatch/templates/components/dev_mode_bar.html.j2 index ce6dbb630..688f5e5bb 100644 --- a/open-media-match/src/OpenMediaMatch/templates/components/dev_mode_bar.html.j2 +++ b/open-media-match/src/OpenMediaMatch/templates/components/dev_mode_bar.html.j2 @@ -2,11 +2,17 @@ \ No newline at end of file + From 07a58942ce505f08871b686ee8e6f83dc2133631 Mon Sep 17 00:00:00 2001 From: dcallies Date: Thu, 11 Jan 2024 14:17:44 -0500 Subject: [PATCH 2/2] fixup cli --- .../src/OpenMediaMatch/utils/dev_utils.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 open-media-match/src/OpenMediaMatch/utils/dev_utils.py diff --git a/open-media-match/src/OpenMediaMatch/utils/dev_utils.py b/open-media-match/src/OpenMediaMatch/utils/dev_utils.py new file mode 100644 index 000000000..7d7e98c7e --- /dev/null +++ b/open-media-match/src/OpenMediaMatch/utils/dev_utils.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import typing as t + +from threatexchange.signal_type.pdq.signal import PdqSignal +from threatexchange.signal_type.md5 import VideoMD5Signal +from threatexchange.exchanges.collab_config import CollaborationConfigBase +from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI +from threatexchange.signal_type.signal_base import SignalType, CanGenerateRandomSignal + +from OpenMediaMatch import persistence +from OpenMediaMatch.storage.interface import BankConfig + + +def seed_sample() -> None: + storage = persistence.get_storage() + storage.exchange_update( + CollaborationConfigBase( + name="SEED_SAMPLE", + api=StaticSampleSignalExchangeAPI.get_name(), + enabled=True, + ), + create=True, + ) + + +def seed_banks_random(banks: int = 2, seeds: int = 10000) -> None: + """ + Seed the database with a large number of banks and hashes + It will generate n banks and put n/m hashes on each bank + """ + storage = persistence.get_storage() + + types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal] + + for i in range(banks): + # create bank + bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0) + storage.bank_update(bank, create=True) + + # Add hashes + for i in range(seeds // banks): + signal_type = types[i % len(types)] + random_hash = signal_type.get_random_signal() + + storage.bank_add_content( + bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash} + )