From 11dcbb9e850f8bd1c878afb49fc8467e44a12d36 Mon Sep 17 00:00:00 2001 From: Dylan Uys Date: Thu, 5 Dec 2024 13:44:30 -0800 Subject: [PATCH 1/3] Release 2.0.0 (#124) * Validator Proxy Response Update (#103) * adding rich arg, adding coldkeys and hotokeys * moving rich to payload from headers * bump version --------- Co-authored-by: benliang99 * Two new image models: SDXL finetuned on Midjourney, and SD finetuned on anime images * Added required StableDiffusionPipeline import * Updated transformers version to fix tokenizer initialization error * GPU Specification (#108) * Made gpu id specification consistent across synthetic image generation models * Changed gpu_id to device * Docstring grammar * add neuron.device to SyntheticImageGenerator init * Fixed variable names * adding device to start_validator.sh * deprecating old/biased random prompt generation * properly clear gpu of moderation pipeline * simplifying usage of self.device * fixing moderation pipeline device * explicitly defining model/tokenizer for moderation pipeline to avoid accelerate auto device management * deprecating random prompt generation --------- Co-authored-by: benliang99 * Update __init__.py bump version * removing logging * old logging removed * adding check for state file in case it is deleted somehow * removing remaining random prompt generation code * [Testnet] Video Challenges V1 (#111) * simple video challenge implementation wip * dummy multimodal miner * constants reorg * updating verify_models script with t2v * fixing MODEL_PIPELINE init * cleanup * __init__.py * hasattr fix * num_frames must be divisible by 8 * fixing dict iteration * dummy response for videos * fixing small bugs * fixing video logging and compression * apply image transforms uniformly to frames of video * transform list of tensor to pil for synapse prep * cleaning up vali forward * miner function signatures to use Synapse base class instead of ImageSynapse * vali requirements imageio and moviepy * attaching separate video and image forward functions * separating blacklist and priority fns for image/video synapses * pred -> prediction * initial synth video challenge flow * initial video cache implementation * video cache cleanup * video zip downloads * wip fairly large refactor of data generation, functionality and form * generalized hf zip download fn * had claude improve video_cache formatting * vali forward cleanup * cleanup + turning back on randomness for real/fake * fix relative import * wip moving video datasets to vali config * Adding optimization flags to vali config * check if captioning model already loaded * async SyntheticDataGenerator wip * async zip download * ImageCache wip * proper gpu clearing for moderation pipeline * sdg cleanup * new cache system WIP * image/video cache updates * cleaning up unused metadata arg, improving logging * fixed frame sampling, parquet image extraction, image sampling * synth data cache wip * Moving sgd to its own pm2 process * synthetic data gen memory management update * mochi-1-preview * util cleanup, new requirements * ensure SyntheticDataGenerator process waits for ImageCache to populate * adding new t2i models from main * Fixing t2v model output saving * miner cleanup * Moving tall model weights to bitmind hf org * removing test video pkl * fixing circular import * updating usage of hf_hub_download according to some breaking huggingface_hub changes * adding ffmpeg to vali reqs * adding back in video models in async generation after testing * renaming UCF directory to DFB, since it now contains TALL * remaining renames for UCF -> DFB * pyffmpegg * video compatible data augmentations * Default values for level, data_aug_params for failure case * switching image challenges back on * using sample variable to store data for all challenge types * disabling sequential_cpu_offload for CogVideoX5b * logging metadata fields to w&b * log challenge metadata * bump version * adding context manager for generation w different dtypes * variable name fix in ComposeWithTransforms * fixing broken DFB stuff in tall_detector.py * removing unnecessary logging * fixing outdated variable names * cache refactor; moving shared functionality to BaseCache * finally automating w&b project setting * improving logs * improving validator forward structure * detector ABC cleanup + function headers * adding try except for miner performance history loading * fixing import * cleaning up vali logging * pep8 formatting video_utils * cleaning up start_validator.sh, starting validator process before data gen * shortening vali challenge timer * moving data generation management to its own script & added w&B logging * run_data_generator.py * fixing full_path variable name * changing w&b name for data generator * yaml > json gang * simplifying ImageCache.sample to always return one sample * adding option to skip a challenge if no data are available in cache * adding config vars for image/video detector * cleaning up miner class, moving blacklist/priority to base * updating call to image_cache.sample() * fixing mochi gen to 84 frames * fixing video data padding for miners * updating setup script to create new .env file * fixing weight loading after detector refactor * model/detector separation for TALL & modifying base DFB code to allow device configuration * standardizing video detector input to a frames tensor * separation of concerns; moving all video preprocessing to detector class * pep8 cleanup * reformatting if statements * temporarily removing initial dataset class * standardizing config loading across video and image models * finished VideoDataloader and supporting components * moved save config file out of trian script * backwards compatibility for ucf training * moving data augmentation from RealFakeDataset to Dataset subclasses for video aug support * cleaning up data augmentation and target_image_size * import cleanup * gitignore update * fixing typos picked up by flake8 * fixing function name ty flake8 * fixing test fixtures * disabling pytests for now, some are broken after refactor and its 4am * fixing image_size for augmentations * Updated validator gpu requirements (#113) * splitting rewards over image and video (#112) * Update README.md (#110) * combining requirements files * Combined requirements installation * Improved formatting, added checks to prevent overwriting existing .env files. * Re-added endpoint options * Fixed incorrect diffusers install * Fixed missing initialization of miner performance trackers * [Testnet] Docs Updates (#114) * docs updates * mining docs update * Removed deprecated requirements files from github tests (#118) * [Testnet] Async Cache Updates (#119) * breaking out cache updates into their own process * adding retries for loading vali info * moving device config to data generation process * typo * removing old run_updater init arg, fixing dataset indexing * only download 1 zip to start to provide data for vali on first boot * cache deletion functionality * log cache size * name images with dataset prefix * Increased minimum and recommended storage (#120) * [Testnet] Data download cleanup (#121) * moving download_data.py to base_miner/datasets * removing unused args in download_data * constants -> config * docs updates for new paths * updating outdated fn headers * pep8 * use png codec, sample by framerate + num frames * fps, min_fps, max_fps parameterization of sample * return fps and num frames * Fix registry module imports (#123) * Fix registry module imports * Fixing config loading issues * fixing frame sampling * bugfix * print label on testnet * reenabling model verification * update detector class names * Fixing config_name arg for camo * fixing detector config in camo * fixing ref to self.config_name * udpate default frame rate * vidoe dataset creation example * default config for video datasets * update default num_videosg --------- Co-authored-by: Andrew * Update README.md * README title * removing samples from cache * README * fixing cache removal (#125) * Fixed tensor not being set to device for video challenges, causing errors when using cuda (#126) * Mainnet Prep (#127) * resetting challenge timer to 60s * fix logging for miner history loading * randomize model order, log gen time * remove frame limit * separate logging to after data check * generate with batch=1 first for diverse data availability * load v1 history path for smooth transition to new incentive * prune extracted cache * swapping url open-images for jpg * removing unused config args * shortening cache refresh timer * cache optimizations * typo * better variable naming * default to autocast * log num files in cache along with GB * surfacing max size gb variables * cooked typo * Fixed wrong validation split key string causing no transform to be applied * Changed detector arg to be required * fixing hotkey reset check * removing logline * clamp mcc at 0 so video doesn't negatively impact performant image miners * typo * improving cache logs * prune after clear * only update relevant tracker in reward * improved logging, turned off cache removal in sample() --------- Co-authored-by: Andrew --------- Co-authored-by: benliang99 Co-authored-by: Andrew Co-authored-by: Kenobi <108417131+kenobijon@users.noreply.github.com> --- .github/workflows/ci.yml | 10 +- .gitignore | 7 +- README.md | 29 +- autoupdate_miner_steps.sh | 2 +- autoupdate_validator_steps.sh | 5 +- base_miner/{UCF => DFB}/README.md | 0 base_miner/{UCF => DFB}/config/__init__.py | 0 base_miner/DFB/config/constants.py | 19 + base_miner/DFB/config/helpers.py | 81 ++ base_miner/DFB/config/tall.yaml | 89 ++ base_miner/{UCF => DFB}/config/ucf.yaml | 4 +- base_miner/{UCF => DFB}/config/xception.yaml | 0 base_miner/{UCF => DFB}/detectors/__init__.py | 3 +- .../{UCF => DFB}/detectors/base_detector.py | 0 base_miner/DFB/detectors/tall_detector.py | 1019 +++++++++++++++++ .../{UCF => DFB}/detectors/ucf_detector.py | 42 +- base_miner/{UCF => DFB}/logger.py | 0 base_miner/{UCF => DFB}/loss/__init__.py | 0 .../{UCF => DFB}/loss/abstract_loss_func.py | 0 .../loss/contrastive_regularization.py | 0 .../{UCF => DFB}/loss/cross_entropy_loss.py | 0 base_miner/{UCF => DFB}/loss/l1_loss.py | 0 base_miner/{UCF => DFB}/metrics/__init__.py | 0 .../metrics/base_metrics_class.py | 0 base_miner/{UCF => DFB}/metrics/registry.py | 0 base_miner/{UCF => DFB}/metrics/utils.py | 0 base_miner/{UCF => DFB}/networks/__init__.py | 0 base_miner/{UCF => DFB}/networks/xception.py | 0 base_miner/{UCF => DFB}/optimizor/LinearLR.py | 0 base_miner/{UCF => DFB}/optimizor/SAM.py | 0 base_miner/{UCF => DFB}/train_detector.py | 311 ++--- base_miner/{UCF => DFB}/trainer/trainer.py | 5 +- base_miner/NPR/train_detector.py | 31 +- base_miner/UCF/config/constants.py | 15 - base_miner/UCF/config/train_config.yaml | 9 - base_miner/__init__.py | 3 - base_miner/config.py | 42 + base_miner/datasets/__init__.py | 4 + base_miner/datasets/base_dataset.py | 79 ++ base_miner/datasets/create_video_dataset.py | 305 +++++ .../datasets}/download_data.py | 164 +-- base_miner/datasets/image_dataset.py | 113 ++ .../datasets}/real_fake_dataset.py | 35 +- .../data.py => base_miner/datasets/util.py | 32 +- base_miner/datasets/video_dataset.py | 116 ++ base_miner/deepfake_detectors/__init__.py | 7 +- .../deepfake_detectors/camo_detector.py | 10 +- .../deepfake_detectors/configs/tall.yaml | 3 + .../deepfake_detectors/configs/ucf.yaml | 2 +- .../deepfake_detectors/configs/ucf_face.yaml | 2 +- .../deepfake_detectors/deepfake_detector.py | 145 ++- base_miner/deepfake_detectors/npr_detector.py | 26 +- .../deepfake_detectors/tall_detector.py | 51 + base_miner/deepfake_detectors/ucf_detector.py | 76 +- base_miner/gating_mechanisms/face_gate.py | 4 +- .../gating_mechanisms/gating_mechanism.py | 2 +- bitmind/__init__.py | 2 +- bitmind/base/miner.py | 111 +- bitmind/base/neuron.py | 4 - bitmind/base/validator.py | 60 +- bitmind/constants.py | 134 --- bitmind/image_dataset.py | 159 --- bitmind/miner/predict.py | 21 - bitmind/protocol.py | 170 ++- .../README.md | 0 bitmind/synthetic_data_generation/__init__.py | 1 + .../image_annotation_generator.py | 244 ++++ .../image_utils.py | 4 +- .../synthetic_data_generation/prompt_utils.py | 39 + .../synthetic_data_generator.py | 387 +++++++ .../image_annotation_generator.py | 344 ------ .../synthetic_image_generator.py | 295 ----- .../utils/annotation_utils.py | 58 - .../utils/hugging_face_utils.py | 81 -- .../utils/stress_test.py | 66 -- bitmind/utils/config.py | 62 +- bitmind/{ => utils}/image_transforms.py | 199 ++-- bitmind/utils/mock.py | 28 +- bitmind/utils/video_utils.py | 26 + bitmind/validator/__init__.py | 2 - bitmind/validator/cache/__init__.py | 3 + bitmind/validator/cache/base_cache.py | 261 +++++ bitmind/validator/cache/download.py | 164 +++ bitmind/validator/cache/extract.py | 197 ++++ bitmind/validator/cache/image_cache.py | 137 +++ bitmind/validator/cache/util.py | 77 ++ bitmind/validator/cache/video_cache.py | 212 ++++ bitmind/validator/config.py | 236 ++++ bitmind/validator/forward.py | 202 ++-- .../validator/miner_performance_tracker.py | 4 +- bitmind/validator/model_utils.py | 37 + bitmind/validator/reward.py | 61 +- bitmind/validator/scripts/__init__.py | 0 .../validator/scripts/run_cache_updater.py | 73 ++ .../validator/scripts/run_data_generator.py | 52 + bitmind/validator/scripts/util.py | 82 ++ bitmind/validator/verify_models.py | 19 +- bitmind/validator/video_utils.py | 106 ++ create_video_dataset_example.sh | 14 + docs/Incentive.md | 10 +- docs/Mining.md | 23 +- docs/Validating.md | 39 +- min_compute.yml | 31 +- neurons/miner.py | 223 ++-- neurons/validator.py | 86 +- neurons/validator_proxy.py | 5 +- requirements-miner.txt | 5 - requirements-validator.txt | 5 - requirements.txt | 46 +- run_neuron.py | 2 +- setup_env.sh | 111 ++ setup_miner_env.sh | 39 - setup_validator_env.sh | 43 - start_miner.sh | 12 +- start_validator.sh | 41 +- tests/fixtures/image_transforms.py | 12 +- tests/validator/test_generate_image.py | 8 +- tests/validator/test_verify_models.py | 4 +- 118 files changed, 5766 insertions(+), 2315 deletions(-) rename base_miner/{UCF => DFB}/README.md (100%) rename base_miner/{UCF => DFB}/config/__init__.py (100%) create mode 100644 base_miner/DFB/config/constants.py create mode 100644 base_miner/DFB/config/helpers.py create mode 100644 base_miner/DFB/config/tall.yaml rename base_miner/{UCF => DFB}/config/ucf.yaml (96%) rename base_miner/{UCF => DFB}/config/xception.yaml (100%) rename base_miner/{UCF => DFB}/detectors/__init__.py (78%) rename base_miner/{UCF => DFB}/detectors/base_detector.py (100%) create mode 100644 base_miner/DFB/detectors/tall_detector.py rename base_miner/{UCF => DFB}/detectors/ucf_detector.py (92%) rename base_miner/{UCF => DFB}/logger.py (100%) rename base_miner/{UCF => DFB}/loss/__init__.py (100%) rename base_miner/{UCF => DFB}/loss/abstract_loss_func.py (100%) rename base_miner/{UCF => DFB}/loss/contrastive_regularization.py (100%) rename base_miner/{UCF => DFB}/loss/cross_entropy_loss.py (100%) rename base_miner/{UCF => DFB}/loss/l1_loss.py (100%) rename base_miner/{UCF => DFB}/metrics/__init__.py (100%) rename base_miner/{UCF => DFB}/metrics/base_metrics_class.py (100%) rename base_miner/{UCF => DFB}/metrics/registry.py (100%) rename base_miner/{UCF => DFB}/metrics/utils.py (100%) rename base_miner/{UCF => DFB}/networks/__init__.py (100%) rename base_miner/{UCF => DFB}/networks/xception.py (100%) rename base_miner/{UCF => DFB}/optimizor/LinearLR.py (100%) rename base_miner/{UCF => DFB}/optimizor/SAM.py (100%) rename base_miner/{UCF => DFB}/train_detector.py (54%) rename base_miner/{UCF => DFB}/trainer/trainer.py (98%) delete mode 100644 base_miner/UCF/config/constants.py delete mode 100644 base_miner/UCF/config/train_config.yaml create mode 100644 base_miner/config.py create mode 100644 base_miner/datasets/__init__.py create mode 100644 base_miner/datasets/base_dataset.py create mode 100644 base_miner/datasets/create_video_dataset.py rename {bitmind => base_miner/datasets}/download_data.py (53%) create mode 100644 base_miner/datasets/image_dataset.py rename {bitmind => base_miner/datasets}/real_fake_dataset.py (82%) rename bitmind/utils/data.py => base_miner/datasets/util.py (86%) create mode 100644 base_miner/datasets/video_dataset.py create mode 100644 base_miner/deepfake_detectors/configs/tall.yaml create mode 100644 base_miner/deepfake_detectors/tall_detector.py delete mode 100644 bitmind/constants.py delete mode 100644 bitmind/image_dataset.py delete mode 100644 bitmind/miner/predict.py rename bitmind/{synthetic_image_generation => synthetic_data_generation}/README.md (100%) create mode 100644 bitmind/synthetic_data_generation/__init__.py create mode 100644 bitmind/synthetic_data_generation/image_annotation_generator.py rename bitmind/{synthetic_image_generation/utils => synthetic_data_generation}/image_utils.py (97%) create mode 100644 bitmind/synthetic_data_generation/prompt_utils.py create mode 100644 bitmind/synthetic_data_generation/synthetic_data_generator.py delete mode 100644 bitmind/synthetic_image_generation/image_annotation_generator.py delete mode 100644 bitmind/synthetic_image_generation/synthetic_image_generator.py delete mode 100644 bitmind/synthetic_image_generation/utils/annotation_utils.py delete mode 100644 bitmind/synthetic_image_generation/utils/hugging_face_utils.py delete mode 100644 bitmind/synthetic_image_generation/utils/stress_test.py rename bitmind/{ => utils}/image_transforms.py (69%) create mode 100644 bitmind/utils/video_utils.py create mode 100644 bitmind/validator/cache/__init__.py create mode 100644 bitmind/validator/cache/base_cache.py create mode 100644 bitmind/validator/cache/download.py create mode 100644 bitmind/validator/cache/extract.py create mode 100644 bitmind/validator/cache/image_cache.py create mode 100644 bitmind/validator/cache/util.py create mode 100644 bitmind/validator/cache/video_cache.py create mode 100644 bitmind/validator/config.py create mode 100644 bitmind/validator/model_utils.py create mode 100644 bitmind/validator/scripts/__init__.py create mode 100644 bitmind/validator/scripts/run_cache_updater.py create mode 100644 bitmind/validator/scripts/run_data_generator.py create mode 100644 bitmind/validator/scripts/util.py create mode 100644 bitmind/validator/video_utils.py create mode 100755 create_video_dataset_example.sh delete mode 100644 requirements-miner.txt delete mode 100644 requirements-validator.txt create mode 100755 setup_env.sh delete mode 100755 setup_miner_env.sh delete mode 100755 setup_validator_env.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 419d56a1..c8f67697 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,15 +28,13 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest pytest-asyncio pip install -r requirements.txt - pip install -r requirements-miner.txt - pip install -r requirements-validator.txt - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - # run tests in tests/ dir and only fail if there are failures or errors - pytest tests/ --verbose --failed-first --exitfirst --disable-warnings + #- name: Test with pytest + # run: | + # # run tests in tests/ dir and only fail if there are failures or errors + # pytest tests/ --verbose --failed-first --exitfirst --disable-warnings diff --git a/.gitignore b/.gitignore index 3bda75c6..521104fb 100644 --- a/.gitignore +++ b/.gitignore @@ -164,7 +164,10 @@ data/ checkpoints/ .requirements_installed base_miner/NPR/weights/* -base_miner/UCF/weights/* -base_miner/UCF/logs/* +base_miner/NPR/logs/* +base_miner/DFB/weights/* +base_miner/DFB/logs/* miner_eval.py *.env +*~ +wandb/ \ No newline at end of file diff --git a/README.md b/README.md index db0c0a28..e772166a 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,13 @@

BitMind Logo

-

BitMind Subnet (Bittensor Subnet 34)

+

BitMind Subnet (Bittensor Subnet 34)
Deepfake Detection

+ ![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg) Welcome to the BitMind Subnet! This repository contains all the necessary information to get started, understand our subnet architecture, and contribute. -## The BitMind Platform - -The [BitMind platform](https://app.bitmindlabs.ai/) offers a best-in-class developer experience for Bittensor miners. - -⚡ **Access Compute**: We offer a wide variety of CPU and GPU options
-⚡ **Develop in VSCode**: Develop in a feature-rich IDE (we support Jupyter too if you hate rich features)
-⚡ **Fully Managed Devops:** No more tinkering with networking configuration - register and deploy your miner in just a few clicks
-⚡ **Monitor Emissions:** View the emissions for all of your miners in our Miner Dashboard - ## Quick Links @@ -25,12 +17,12 @@ The [BitMind platform](https://app.bitmindlabs.ai/) offers a best-in-class devel - [Project Structure and Terminology 📖](docs/Glossary.md) - [Contributor Guide 🤝](docs/Contributor_Guide.md) -**IMPORTANT**: If you are new to Bittensor, we recommend familiarizing yourself with the basics on the [Bittensor Website](https://bittensor.com/) before proceeding to the [Setup Guide](docs/Setup.md) page. +**IMPORTANT**: If you are new to Bittensor, we recommend familiarizing yourself with the basics on the [Bittensor Website](https://bittensor.com/) before proceeding. ## Identifying AI-Generated Media with a Decentralized Framework **Overview:** -The BitMind Subnet leverages advanced generative and discriminative AI models within the Bittensor network to detect AI-generated images. This platform is engineered on a decentralized, incentive-driven framework to enhance trustworthiness and stimulate continuous technological advancement. +The BitMind Subnet leverages advanced generative and discriminative AI models within the Bittensor network to detect AI-generated images and videos. This platform is engineered on a decentralized, incentive-driven framework to enhance trustworthiness and stimulate continuous technological advancement. **Purpose:** The proliferation of generative AI models has significantly increased the production of high-quality synthetic media, presenting challenges in distinguishing these from authentic content. The BitMind Subnet addresses this challenge by providing robust detection mechanisms to maintain the integrity of digital media. @@ -47,13 +39,22 @@ The proliferation of generative AI models has significantly increased the produc **Core Components:** - **Miners:** Tasked with running binary classifiers that discern between genuine and AI-generated content. - - **Research Integration:** We systematically update our detection models and methodologies in response to emerging academic research, offering resources like training codes and model weights to our community. + - **Research Integration:** We systematically update our detection models and methodologies in response to emerging academic research, offering resources like training code, model weights and datasets to our community. - **Validators:** Responsible for challenging miners with a balanced mix of real and synthetic images, drawn from a diverse pool of sources. - - **Resource Expansion:** We are committed to enhancing the validators' capabilities by increasing the diversity and volume of the image pool, which supports rigorous testing and validation processes. + - **Resource Expansion:** We continuously add new datasets and generative models to our validators in order to maximize the coverage of the types of media our miners are incentivized to detect. **Subnet Architecture Diagram** ![Subnet Architecture](static/Subnet-Arch.png) +## The BitMind Platform + +The [BitMind platform](https://app.bitmindlabs.ai/) offers a best-in-class developer experience for Bittensor miners. + +⚡ **Access Compute**: We offer a wide variety of CPU and GPU options
+⚡ **Develop in VSCode**: Develop in a feature-rich IDE (we support Jupyter too if you hate rich features)
+⚡ **Fully Managed Devops:** No more tinkering with networking configuration - register and deploy your miner in just a few clicks
+⚡ **Monitor Emissions:** View the emissions for all of your miners in our Miner Dashboard + ## Community

diff --git a/autoupdate_miner_steps.sh b/autoupdate_miner_steps.sh index 1d207e66..ea6a7af5 100755 --- a/autoupdate_miner_steps.sh +++ b/autoupdate_miner_steps.sh @@ -5,5 +5,5 @@ # THIS FILE ITSELF MAY CHANGE FROM UPDATE TO UPDATE, SO WE CAN DYNAMICALLY FIX ANY ISSUES echo $CONDA_PREFIX -$CONDA_PREFIX/bin/pip install -e . +./setup_env.sh echo "Autoupdate steps complete :)" diff --git a/autoupdate_validator_steps.sh b/autoupdate_validator_steps.sh index 8a46173d..ea6a7af5 100755 --- a/autoupdate_validator_steps.sh +++ b/autoupdate_validator_steps.sh @@ -5,8 +5,5 @@ # THIS FILE ITSELF MAY CHANGE FROM UPDATE TO UPDATE, SO WE CAN DYNAMICALLY FIX ANY ISSUES echo $CONDA_PREFIX -$CONDA_PREFIX/bin/pip install -e . -$CONDA_PREFIX/bin/pip install -r requirements-validator.txt -$CONDA_PREFIX/bin/python bitmind/download_data.py -$CONDA_PREFIX/bin/python bitmind/validator/verify_models.py +./setup_env.sh echo "Autoupdate steps complete :)" diff --git a/base_miner/UCF/README.md b/base_miner/DFB/README.md similarity index 100% rename from base_miner/UCF/README.md rename to base_miner/DFB/README.md diff --git a/base_miner/UCF/config/__init__.py b/base_miner/DFB/config/__init__.py similarity index 100% rename from base_miner/UCF/config/__init__.py rename to base_miner/DFB/config/__init__.py diff --git a/base_miner/DFB/config/constants.py b/base_miner/DFB/config/constants.py new file mode 100644 index 00000000..2ae2373e --- /dev/null +++ b/base_miner/DFB/config/constants.py @@ -0,0 +1,19 @@ +import os + +CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__)) +BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/DFB/ +WEIGHTS_DIR = os.path.join(BASE_PATH, "weights") + +CONFIG_PATHS = { + 'UCF': os.path.join(CONFIGS_DIR, "ucf.yaml"), + 'TALL': os.path.join(CONFIGS_DIR, "tall.yaml") +} + +HF_REPOS = { + "UCF": "bitmind/ucf", + "TALL": "bitmind/tall" +} + +BACKBONE_CKPT = "xception_best.pth" + +DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(BASE_PATH, "../../bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat")) \ No newline at end of file diff --git a/base_miner/DFB/config/helpers.py b/base_miner/DFB/config/helpers.py new file mode 100644 index 00000000..557bf896 --- /dev/null +++ b/base_miner/DFB/config/helpers.py @@ -0,0 +1,81 @@ +import yaml + + +def save_config(config, outputs_dir): + """ + Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved. + Also, lists like 'mean' and 'std' are saved in flow style (on a single line). + + Args: + config (dict): The configuration dictionary to save. + outputs_dir (str): The directory path where the files will be saved. + """ + + def is_basic_type(value): + """ + Check if a value is a basic data type that can be saved in YAML. + Basic types include int, float, str, bool, list, and dict. + """ + return isinstance(value, (int, float, str, bool, list, dict, type(None))) + + def filter_dict(data_dict): + """ + Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects). + """ + if not isinstance(data_dict, dict): + return data_dict + + filtered_dict = {} + for key, value in data_dict.items(): + if isinstance(value, dict): + # Recursively filter nested dictionaries + nested_dict = filter_dict(value) + if nested_dict: # Only add non-empty dictionaries + filtered_dict[key] = nested_dict + elif is_basic_type(value): + # Add if the value is a basic type + filtered_dict[key] = value + else: + # Skip the key if the value is not a basic type (e.g., an object) + print(f"Skipping key '{key}' because its value is of type {type(value)}") + + return filtered_dict + + def save_dict_to_yaml(data_dict, file_path): + """ + Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object. + Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style. + + Args: + data_dict (dict): The dictionary to save. + file_path (str): The local file path where the YAML file will be saved. + """ + + # Custom representer for lists to force flow style (compact lists) + class FlowStyleList(list): + pass + + def flow_style_list_representer(dumper, data): + return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True) + + yaml.add_representer(FlowStyleList, flow_style_list_representer) + + # Preprocess specific lists to be in flow style + if 'mean' in data_dict: + data_dict['mean'] = FlowStyleList(data_dict['mean']) + if 'std' in data_dict: + data_dict['std'] = FlowStyleList(data_dict['std']) + + try: + # Filter the dictionary + filtered_dict = filter_dict(data_dict) + + # Save the filtered dictionary as YAML + with open(file_path, 'w') as f: + yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList + print(f"Filtered dictionary successfully saved to {file_path}") + except Exception as e: + print(f"Error saving dictionary to YAML: {e}") + + # Save as YAML + save_dict_to_yaml(config, outputs_dir + '/config.yaml') \ No newline at end of file diff --git a/base_miner/DFB/config/tall.yaml b/base_miner/DFB/config/tall.yaml new file mode 100644 index 00000000..96de6a86 --- /dev/null +++ b/base_miner/DFB/config/tall.yaml @@ -0,0 +1,89 @@ +# model setting +pretrained: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth # path to a pre-trained model, if using one +model_name: tall # model name + +mask_grid_size: 16 +num_classes: 2 +embed_dim: 128 +mlp_ratio: 4.0 +patch_size: 4 +window_size: [14, 14, 14, 7] +depths: [2, 2, 18, 2] +num_heads: [4, 8, 16, 32] +ape: true # use absolution position embedding +thumbnail_rows: 2 +drop_rate: 0 +drop_path_rate: 0.1 + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 4 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 4 # number of frames in each clip, should be square number of an integer +dataset_type: tall + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.485, 0.456, 0.406] +std: [0.229, 0.224, 0.225] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations \ No newline at end of file diff --git a/base_miner/UCF/config/ucf.yaml b/base_miner/DFB/config/ucf.yaml similarity index 96% rename from base_miner/UCF/config/ucf.yaml rename to base_miner/DFB/config/ucf.yaml index 40eb4b26..cee1097f 100644 --- a/base_miner/UCF/config/ucf.yaml +++ b/base_miner/DFB/config/ucf.yaml @@ -2,7 +2,9 @@ log_dir: ../debug_logs/ucf # model setting -pretrained: ../weights/xception-best.pth # path to a pre-trained model, if using one +pretrained: + hf_repo: bm_ucf + filename: xception-best.pth model_name: ucf # model name backbone_name: xception # backbone name encoder_feat_dim: 512 # feature dimension of the backbone diff --git a/base_miner/UCF/config/xception.yaml b/base_miner/DFB/config/xception.yaml similarity index 100% rename from base_miner/UCF/config/xception.yaml rename to base_miner/DFB/config/xception.yaml diff --git a/base_miner/UCF/detectors/__init__.py b/base_miner/DFB/detectors/__init__.py similarity index 78% rename from base_miner/UCF/detectors/__init__.py rename to base_miner/DFB/detectors/__init__.py index 6059a264..cbaeaf92 100644 --- a/base_miner/UCF/detectors/__init__.py +++ b/base_miner/DFB/detectors/__init__.py @@ -8,4 +8,5 @@ from metrics.registry import DETECTOR -from .ucf_detector import UCFDetector \ No newline at end of file +from .ucf_detector import UCFDetector +from .tall_detector import TALLDetector \ No newline at end of file diff --git a/base_miner/UCF/detectors/base_detector.py b/base_miner/DFB/detectors/base_detector.py similarity index 100% rename from base_miner/UCF/detectors/base_detector.py rename to base_miner/DFB/detectors/base_detector.py diff --git a/base_miner/DFB/detectors/tall_detector.py b/base_miner/DFB/detectors/tall_detector.py new file mode 100644 index 00000000..8a175fa3 --- /dev/null +++ b/base_miner/DFB/detectors/tall_detector.py @@ -0,0 +1,1019 @@ +""" +# author: Kangran Zhao +# email: kangranzhao@link.cuhk.edu.cn +# date: 2023-0822 +# description: Class for the TALLDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{xu2023tall, + title={TALL: Thumbnail Layout for Deepfake Video Detection}, + author={Xu, Yuting and Liang, Jian and Jia, Gengyun and Yang, Ziming and Zhang, Yanhao and He, Ran}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={22658--22668}, + year={2023} +} +""" + +import logging +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from torch.hub import load_state_dict_from_url + +from .base_detector import AbstractDetector +from base_miner.DFB.detectors import DETECTOR +from base_miner.DFB.loss import LOSSFUNC +from base_miner.DFB.metrics.base_metrics_class import calculate_metrics_for_train + +_logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='tall') +class TALLDetector(AbstractDetector): + def __init__(self, config, device='cuda'): + super().__init__() + self.device = device + self.model = self.build_backbone(config).to(self.device) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + model_kwargs = dict( + num_classes=config['num_classes'], + embed_dim=config['embed_dim'], + mlp_ratio=config['mlp_ratio'], + patch_size=config['patch_size'], + window_size=config['window_size'], + depths=config['depths'], + num_heads=config['num_heads'], + ape=config['ape'], + thumbnail_rows=config['thumbnail_rows'], + drop_rate=config['drop_rate'], + drop_path_rate=config['drop_path_rate'], + use_checkpoint=False, + bottleneck=False, + duration=config['clip_size'], + device=self.device + ) + + default_cfg = { + 'url': config['pretrained'], + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': .9, + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + } + + backbone = SwinTransformer(img_size=config['resolution'], **model_kwargs) + backbone.default_cfg = default_cfg + + load_pretrained( + backbone, + num_classes=config['num_classes'], + in_chans=model_kwargs.get('in_chans', 3), + filter_fn=_conv_filter, + img_size=config['resolution'], + pretrained_window_size=7, + pretrained_model='' + ) + + return backbone + + def build_loss(self, config): + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + bs, t, c, h, w = data_dict['image'].shape + inputs = data_dict['image'].view(bs, t * c, h, w) + pred = self.model(inputs) + return pred + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].long() + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + pred = self.features(data_dict) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': prob} + return pred_dict + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """Partition input tensor into windows. + + Args: + x: Input tensor of shape (B, H, W, C) + window_size (int): Size of each window + + Returns: + windows: Output tensor of shape (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """Reverse window partitioning. + + Args: + windows: Input tensor of shape (num_windows*B, window_size, window_size, C) + window_size (int): Size of each window + H (int): Height of original image + W (int): Width of original image + + Returns: + x: Output tensor of shape (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, + window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + + It supports both shifted and non-shifted window attention. + + Args: + dim (int): Number of input channels + window_size (tuple[int]): Height and width of window + num_heads (int): Number of attention heads + qkv_bias (bool, optional): Add learnable bias to query, key, value. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, + qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # Define parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) + + # Get pair-wise relative position index for each token in window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward pass. + + Args: + x: Input features with shape (num_windows*B, N, C) + mask: (0/-inf) mask with shape (num_windows, Wh*Ww, Wh*Ww) or None + + Returns: + Output tensor after attention + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + attn = attn + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + """Extra string representation.""" + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + """Calculate FLOPs for one window.""" + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None) + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None) + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward_attn(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2) + ) + else: + shifted_x = x + + # partition windows + # nW*B, window_size, window_size, C + x_windows = window_partition(shifted_x, self.window_size) + # nW*B, window_size*window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # nW*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + return x + + def forward_mlp(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_attn, x) + else: + x = self.forward_attn(x) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_mlp, x) + else: + x = x + self.forward_mlp(x) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, " + f"num_heads={self.num_heads}, window_size={self.window_size}, " + f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """Forward pass. + + Args: + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + bottleneck=False + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + bottleneck=bottleneck if i == depth - 1 else False, + use_checkpoint=use_checkpoint + ) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, + dim=dim, + norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + img_size=(224, 224), + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None + ): + super().__init__() + # img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * ( + self.patch_size[0] * self.patch_size[1] + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__( + self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, thumbnail_rows=1, bottleneck=False, device='cuda', **kwargs + ): + super().__init__() + + self.duration = duration # 4 + self.num_classes = num_classes # 2 + self.num_layers = len(depths) # [2, 2, 18, 2] + self.embed_dim = embed_dim # 128 + self.ape = ape # True + self.patch_norm = patch_norm # False + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio # 4 = default + self.thumbnail_rows = thumbnail_rows # 2 + self.device = device + + self.img_size = img_size # 224 + self.window_size = ([window_size for _ in depths] if not isinstance(window_size, list) + else window_size) + + self.frame_padding = self.duration % thumbnail_rows # 0 + if self.frame_padding != 0: + self.frame_padding = self.thumbnail_rows - self.frame_padding + self.duration += self.frame_padding + + # split image into non-overlapping patches + thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows) # (2, 2) + thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1]) + + self.patch_embed = PatchEmbed( + img_size=(img_size, img_size), + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + num_patches = self.patch_embed.num_patches # 16 + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution # [56, 56] + + # absolute position embedding + if self.ape: # True + self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim)) + trunc_normal_(self.frame_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + bottleneck=bottleneck + ) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = (nn.Linear(self.num_features, num_classes) + if num_classes > 0 else nn.Identity()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'frame_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def create_thumbnail(self, x): + input_size = x.shape[-2:] + if input_size != to_2tuple(self.img_size): + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') + x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', + th=self.thumbnail_rows, c=3) + return x + + def pad_frames(self, x): + frame_num = self.duration - self.frame_padding + x = x.view((-1, 3 * frame_num) + x.size()[2:]) + x_padding = torch.zeros((x.shape[0], 3 * self.frame_padding) + + x.size()[2:]).to(self.device) + x = torch.cat((x, x_padding), dim=1) + assert x.shape[1] == 3 * self.duration, ( + 'frame number %d not the same as adjusted input size %d' % + (x.shape[1], 3 * self.duration)) + + return x + + # need to find a better way to do this, maybe torch.fold? + def create_image_pos_embed(self): + img_rows, img_cols = self.patches_resolution # (56, 56) + _, _, T = self.frame_pos_embed.shape # (1, 4, embed) + rows = img_rows // self.thumbnail_rows # 28 + cols = img_cols // (self.duration // self.thumbnail_rows) # 28 + img_pos_embed = torch.zeros(img_rows, img_cols, T).to(self.device) # [56, 56, embed] + for i in range(self.duration): + r_indx = (i // self.thumbnail_rows) * rows + c_indx = (i % self.thumbnail_rows) * cols + img_pos_embed[r_indx:r_indx + rows, c_indx:c_indx + cols] = self.frame_pos_embed[0, i] + + return img_pos_embed.reshape(-1, T) # [56*56, embed] + + def forward_features(self, x): + if self.frame_padding > 0: + x = self.pad_frames(x) + else: + x = x.view((-1, 3 * self.duration) + x.size()[2:]) + + x = self.create_thumbnail(x) + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') # [B, 3, 224, 224] + + x = self.patch_embed(x) # [B, 56*56, embed] + if self.ape: + img_pos_embed = self.create_image_pos_embed() + x = x + img_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += (self.num_features * self.patches_resolution[0] * + self.patches_resolution[1] // (2 ** self.num_layers)) + flops += self.num_features * self.num_classes + return flops + +def load_pretrained( + model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, + num_patches=196, pretrained_window_size=7, pretrained_model="", strict=True +): + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL is invalid, using random initialization.") + return + + if len(pretrained_model) == 0: + state_dict = load_state_dict_from_url(cfg['url'], map_location='cpu') + else: + try: + state_dict = torch.load(pretrained_model)['model'] + except: + state_dict = torch.load(pretrained_model) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg['first_conv'] + _logger.info( + 'Converting first conv (%s) pretrained weights from 3 to 1 channel', + conv1_name + ) + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + elif in_chans != 3: + conv1_name = cfg['first_conv'] + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + _logger.warning( + 'Deleting first conv (%s) from pretrained weights.', + conv1_name + ) + del state_dict[conv1_name + '.weight'] + strict = False + else: + _logger.info( + 'Repeating first conv (%s) weights in channel dim.', + conv1_name + ) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= (3 / float(in_chans)) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + + classifier_name = cfg['classifier'] + if num_classes == 1000 and cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != cfg['num_classes']: + # discard fully connected for all other differences + del state_dict['model'][classifier_name + '.weight'] + del state_dict['model'][classifier_name + '.bias'] + strict = False + ''' + ## Resizing the positional embeddings in case they don't match + if img_size != cfg['input_size'][1]: + pos_embed = state_dict['pos_embed'] + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + state_dict['pos_embed'] = new_pos_embed + ''' + + # remove window_size related parameters + window_size = (model.window_size)[0] + print(pretrained_window_size, window_size) + + new_state_dict = state_dict['model'].copy() + for key in state_dict['model']: + if 'attn_mask' in key: + del new_state_dict[key] + + if 'relative_position_index' in key: + del new_state_dict[key] + + # resize it + if 'relative_position_bias_table' in key: + pretrained_table = state_dict['model'][key] + pretrained_table_size = int(math.sqrt(pretrained_table.shape[0])) + table_size = int(math.sqrt(model.state_dict()[key].shape[0])) + if pretrained_table_size != table_size: + table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size) + table = nn.functional.interpolate(table, size=table_size, mode='bilinear') + table = table.view(-1, table_size * table_size).permute(1, 0) + new_state_dict[key] = table + + for key in model.state_dict(): + if 'bottleneck_norm' in key: + attn_key = key.replace('bottleneck_norm', 'norm1') + # print (key, attn_key) + new_state_dict[key] = new_state_dict[attn_key] + + print('loading weights....') + ## Loading the weights + model.load_state_dict(new_state_dict, strict=False) + + +def _conv_filter(state_dict, patch_size=4): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + if v.shape[-1] != patch_size: + patch_size = v.shape[-1] + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict \ No newline at end of file diff --git a/base_miner/UCF/detectors/ucf_detector.py b/base_miner/DFB/detectors/ucf_detector.py similarity index 92% rename from base_miner/UCF/detectors/ucf_detector.py rename to base_miner/DFB/detectors/ucf_detector.py index 4ffc9b39..51b16d81 100644 --- a/base_miner/UCF/detectors/ucf_detector.py +++ b/base_miner/DFB/detectors/ucf_detector.py @@ -43,13 +43,15 @@ from metrics.base_metrics_class import calculate_metrics_for_train -from .base_detector import AbstractDetector -from UCF.detectors import DETECTOR -from networks import BACKBONE -from loss import LOSSFUNC +from DFB.detectors.base_detector import AbstractDetector +from DFB.detectors import DETECTOR +from DFB.networks import BACKBONE +from DFB.loss import LOSSFUNC +from DFB.config.constants import WEIGHTS_DIR logger = logging.getLogger(__name__) + @DETECTOR.register_module(module_name='ucf') class UCFDetector(AbstractDetector): def __init__(self, config): @@ -99,20 +101,32 @@ def __init__(self, config): ) def build_backbone(self, config): - current_dir = os.path.dirname(os.path.abspath(__file__)) - pretrained_path = os.path.join(current_dir, config['pretrained']) # prepare the backbone backbone_class = BACKBONE[config['backbone_name']] model_config = config['backbone_config'] backbone = backbone_class(model_config) - # if donot load the pretrained weights, fail to get good results - state_dict = torch.load(pretrained_path) - for name, weights in state_dict.items(): - if 'pointwise' in name: - state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) - state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} - backbone.load_state_dict(state_dict, False) - logger.info('Load pretrained model successfully!') + + if 'pretrained' in config: + pretrained_path = config['pretrained'] + if isinstance(pretrained_path, dict): + if 'local_path' in pretrained_path: + pretrained_path = pretrained_path['local_path'] + elif 'filename' in pretrained_path: + pretrained_path = pretrained_path['filename'] + else: + pretrained_path = pretrained_path.split('/')[-1] + + if not os.path.isabs(pretrained_path): + pretrained_path = os.path.join(WEIGHTS_DIR, pretrained_path) + + logger.info(f"Loading {pretrained_path}") + state_dict = torch.load(pretrained_path) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') return backbone def build_loss(self, config): diff --git a/base_miner/UCF/logger.py b/base_miner/DFB/logger.py similarity index 100% rename from base_miner/UCF/logger.py rename to base_miner/DFB/logger.py diff --git a/base_miner/UCF/loss/__init__.py b/base_miner/DFB/loss/__init__.py similarity index 100% rename from base_miner/UCF/loss/__init__.py rename to base_miner/DFB/loss/__init__.py diff --git a/base_miner/UCF/loss/abstract_loss_func.py b/base_miner/DFB/loss/abstract_loss_func.py similarity index 100% rename from base_miner/UCF/loss/abstract_loss_func.py rename to base_miner/DFB/loss/abstract_loss_func.py diff --git a/base_miner/UCF/loss/contrastive_regularization.py b/base_miner/DFB/loss/contrastive_regularization.py similarity index 100% rename from base_miner/UCF/loss/contrastive_regularization.py rename to base_miner/DFB/loss/contrastive_regularization.py diff --git a/base_miner/UCF/loss/cross_entropy_loss.py b/base_miner/DFB/loss/cross_entropy_loss.py similarity index 100% rename from base_miner/UCF/loss/cross_entropy_loss.py rename to base_miner/DFB/loss/cross_entropy_loss.py diff --git a/base_miner/UCF/loss/l1_loss.py b/base_miner/DFB/loss/l1_loss.py similarity index 100% rename from base_miner/UCF/loss/l1_loss.py rename to base_miner/DFB/loss/l1_loss.py diff --git a/base_miner/UCF/metrics/__init__.py b/base_miner/DFB/metrics/__init__.py similarity index 100% rename from base_miner/UCF/metrics/__init__.py rename to base_miner/DFB/metrics/__init__.py diff --git a/base_miner/UCF/metrics/base_metrics_class.py b/base_miner/DFB/metrics/base_metrics_class.py similarity index 100% rename from base_miner/UCF/metrics/base_metrics_class.py rename to base_miner/DFB/metrics/base_metrics_class.py diff --git a/base_miner/UCF/metrics/registry.py b/base_miner/DFB/metrics/registry.py similarity index 100% rename from base_miner/UCF/metrics/registry.py rename to base_miner/DFB/metrics/registry.py diff --git a/base_miner/UCF/metrics/utils.py b/base_miner/DFB/metrics/utils.py similarity index 100% rename from base_miner/UCF/metrics/utils.py rename to base_miner/DFB/metrics/utils.py diff --git a/base_miner/UCF/networks/__init__.py b/base_miner/DFB/networks/__init__.py similarity index 100% rename from base_miner/UCF/networks/__init__.py rename to base_miner/DFB/networks/__init__.py diff --git a/base_miner/UCF/networks/xception.py b/base_miner/DFB/networks/xception.py similarity index 100% rename from base_miner/UCF/networks/xception.py rename to base_miner/DFB/networks/xception.py diff --git a/base_miner/UCF/optimizor/LinearLR.py b/base_miner/DFB/optimizor/LinearLR.py similarity index 100% rename from base_miner/UCF/optimizor/LinearLR.py rename to base_miner/DFB/optimizor/LinearLR.py diff --git a/base_miner/UCF/optimizor/SAM.py b/base_miner/DFB/optimizor/SAM.py similarity index 100% rename from base_miner/UCF/optimizor/SAM.py rename to base_miner/DFB/optimizor/SAM.py diff --git a/base_miner/UCF/train_detector.py b/base_miner/DFB/train_detector.py similarity index 54% rename from base_miner/UCF/train_detector.py rename to base_miner/DFB/train_detector.py index 9e877b1a..b5abf663 100644 --- a/base_miner/UCF/train_detector.py +++ b/base_miner/DFB/train_detector.py @@ -34,35 +34,45 @@ import torch.distributed as dist from torch.utils.data import DataLoader -from optimizor.SAM import SAM -from optimizor.LinearLR import LinearDecayLR - -from trainer.trainer import Trainer -from detectors import DETECTOR -from metrics.utils import parse_metric_for_print -from logger import create_logger, RankFilter +from base_miner.DFB.optimizor.SAM import SAM +from base_miner.DFB.optimizor.LinearLR import LinearDecayLR +from base_miner.DFB.config.helpers import save_config +from base_miner.DFB.trainer.trainer import Trainer +from base_miner.DFB.detectors import DETECTOR +from base_miner.DFB.metrics.utils import parse_metric_for_print +from base_miner.DFB.logger import create_logger, RankFilter from huggingface_hub import hf_hub_download # BitMind imports (not from original Deepfake Bench repo) -from bitmind.utils.data import load_and_split_datasets, create_real_fake_datasets -from bitmind.image_transforms import base_transforms, random_aug_transforms, ucf_transforms -from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META -from config.constants import ( - CONFIG_PATH, +from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets +from base_miner.config import VIDEO_DATASETS, IMAGE_DATASETS, FACE_IMAGE_DATASETS +from bitmind.utils.image_transforms import ( + get_base_transforms, + get_random_augmentations, + get_ucf_base_transforms, + get_tall_base_transforms +) +from base_miner.DFB.config.constants import ( + CONFIG_PATHS, WEIGHTS_DIR, - HF_REPO, - BACKBONE_CKPT + HF_REPOS ) +TRANSFORM_FNS = { + 'UCF': get_ucf_base_transforms, + 'TALL': get_tall_base_transforms +} + parser = argparse.ArgumentParser(description='Process some paths.') -parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file') +parser.add_argument('--detector', type=str, choices=['UCF', 'TALL'], required=True, help='Detector name') +parser.add_argument('--modality', type=str, default='image', choices=['image', 'video']) parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False) parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True) parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True) parser.add_argument("--ddp", action='store_true', default=False) -parser.add_argument('--device', type=str, choices=['cpu', 'gpu'], default='gpu', +parser.add_argument('--device', type=str, default='cuda', help='Specify whether to use CPU or GPU. Defaults to GPU if available.') parser.add_argument('--gpu_id', type=int, default=0, help='Specify the GPU ID to use if using GPU. Defaults to 0.') parser.add_argument('--workers', type=int, default=os.cpu_count() - 1, @@ -71,58 +81,6 @@ args = parser.parse_args() -def set_device(device=args.device, gpu_id=args.gpu_id): - """ - Determine the device to use based on user input and system availability. - - Parameters: - device_arg (str, optional): The device specified by the user ('cpu', 'gpu', or None). - Defaults to None, in which case it automatically chooses. - gpu_id (int, optional): The specific GPU ID to set if using a GPU (defaults to 0). - - Returns: - torch.device: The device to be used (either 'cuda' or 'cpu'). - """ - if device == 'cpu': - return torch.device("cpu") - elif device == 'gpu': - if torch.cuda.is_available(): - torch.cuda.set_device(gpu_id) # Set the GPU ID - return torch.device(f"cuda:{gpu_id}") - else: - print("Warning: GPU specified but not available. Falling back to CPU.") - return torch.device("cpu") - else: - # Default: Use GPU if available, otherwise fall back to CPU - if torch.cuda.is_available(): - torch.cuda.set_device(gpu_id) - return torch.device(f"cuda:{gpu_id}") - else: - return torch.device("cpu") - - -def ensure_backbone_is_available(logger, - weights_dir=WEIGHTS_DIR, - model_filename=BACKBONE_CKPT, - hugging_face_repo_name=HF_REPO): - - destination_path = Path(weights_dir) / Path(model_filename) - if not destination_path.parent.exists(): - destination_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Created directory {destination_path.parent}.") - if not destination_path.exists(): - model_path = hf_hub_download(hugging_face_repo_name, model_filename) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = torch.load(model_path, map_location=device) - torch.save(model, destination_path) - del model - if torch.cuda.is_available(): - torch.cuda.empty_cache() - logger.info(f"Downloaded backbone {model_filename} to {destination_path}.") - else: - logger.info(f"{model_filename} backbone already present at {destination_path}.") - - def init_seed(config): if config['manualSeed'] is None: config['manualSeed'] = random.randint(1, 10000) @@ -132,28 +90,13 @@ def init_seed(config): torch.cuda.manual_seed_all(config['manualSeed']) -def custom_collate_fn(batch): - images, labels, source_labels = zip(*batch) - - images = torch.stack(images, dim=0) # Stack image tensors into a single tensor - labels = torch.LongTensor(labels) - source_labels = torch.LongTensor(source_labels) - - data_dict = { - 'image': images, - 'label': labels, - 'label_spe': source_labels, - 'landmark': None, - 'mask': None - } - return data_dict - - def prepare_datasets(config, logger): start_time = log_start_time(logger, "Loading and splitting individual datasets") - fake_datasets = load_and_split_datasets(config['dataset_meta']['fake']) - real_datasets = load_and_split_datasets(config['dataset_meta']['real']) + fake_datasets = load_and_split_datasets( + config['dataset_meta']['fake'], modality=config['modality'], split_transforms=config['split_transforms']) + real_datasets = load_and_split_datasets( + config['dataset_meta']['real'], modality=config['modality'], split_transforms=config['split_transforms']) log_finish_time(logger, "Loading and splitting individual datasets", start_time) @@ -161,10 +104,7 @@ def prepare_datasets(config, logger): train_dataset, val_dataset, test_dataset, source_label_mapping = create_real_fake_datasets( real_datasets, fake_datasets, - config['split_transforms']['train'], - config['split_transforms']['validation'], - config['split_transforms']['test'], - source_labels=True, + source_labels=True, # TODO UCF Only group_sources_by_name=True) log_finish_time(logger, "Creating real fake dataset splits", start_time) @@ -175,7 +115,7 @@ def prepare_datasets(config, logger): shuffle=True, num_workers=config['workers'], drop_last=True, - collate_fn=custom_collate_fn) + collate_fn=train_dataset.collate_fn) val_loader = torch.utils.data.DataLoader( val_dataset, @@ -183,7 +123,7 @@ def prepare_datasets(config, logger): shuffle=True, num_workers=config['workers'], drop_last=True, - collate_fn=custom_collate_fn) + collate_fn=val_dataset.collate_fn) test_loader = torch.utils.data.DataLoader( test_dataset, @@ -191,7 +131,7 @@ def prepare_datasets(config, logger): shuffle=True, num_workers=config['workers'], drop_last=True, - collate_fn=custom_collate_fn) + collate_fn=train_dataset.collate_fn) print(f"Train size: {len(train_loader.dataset)}") print(f"Validation size: {len(val_loader.dataset)}") @@ -284,137 +224,53 @@ def log_finish_time(logger, process_name, start_time): logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds") -def save_config(config, outputs_dir): - """ - Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved. - Also, lists like 'mean' and 'std' are saved in flow style (on a single line). - - Args: - config (dict): The configuration dictionary to save. - outputs_dir (str): The directory path where the files will be saved. - """ - - def is_basic_type(value): - """ - Check if a value is a basic data type that can be saved in YAML. - Basic types include int, float, str, bool, list, and dict. - """ - return isinstance(value, (int, float, str, bool, list, dict, type(None))) - - def filter_dict(data_dict): - """ - Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects). - """ - if not isinstance(data_dict, dict): - return data_dict - - filtered_dict = {} - for key, value in data_dict.items(): - if isinstance(value, dict): - # Recursively filter nested dictionaries - nested_dict = filter_dict(value) - if nested_dict: # Only add non-empty dictionaries - filtered_dict[key] = nested_dict - elif is_basic_type(value): - # Add if the value is a basic type - filtered_dict[key] = value - else: - # Skip the key if the value is not a basic type (e.g., an object) - print(f"Skipping key '{key}' because its value is of type {type(value)}") - - return filtered_dict - - def save_dict_to_yaml(data_dict, file_path): - """ - Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object. - Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style. - - Args: - data_dict (dict): The dictionary to save. - file_path (str): The local file path where the YAML file will be saved. - """ - - # Custom representer for lists to force flow style (compact lists) - class FlowStyleList(list): - pass - - def flow_style_list_representer(dumper, data): - return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True) - - yaml.add_representer(FlowStyleList, flow_style_list_representer) - - # Preprocess specific lists to be in flow style - if 'mean' in data_dict: - data_dict['mean'] = FlowStyleList(data_dict['mean']) - if 'std' in data_dict: - data_dict['std'] = FlowStyleList(data_dict['std']) - - try: - # Filter the dictionary - filtered_dict = filter_dict(data_dict) - - # Save the filtered dictionary as YAML - with open(file_path, 'w') as f: - yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList - print(f"Filtered dictionary successfully saved to {file_path}") - except Exception as e: - print(f"Error saving dictionary to YAML: {e}") - - # Save as YAML - save_dict_to_yaml(config, outputs_dir + '/config.yaml') - - def main(): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + detector_config_path = CONFIG_PATHS[args.detector] + # parse options and load config - with open(args.detector_path, 'r') as f: + with open(detector_config_path, 'r') as f: config = yaml.safe_load(f) - with open(os.getcwd() + '/config/train_config.yaml', 'r') as f: - config2 = yaml.safe_load(f) - if 'label_dict' in config: - config2['label_dict']=config['label_dict'] - config.update(config2) + config['log_dir'] = os.getcwd() + config['device'] = args.device + config['modality'] = args.modality config['workers'] = args.workers - config['device'] = set_device(args.device, args.gpu_id) config['gpu_id'] = args.gpu_id - if config['dry_run']: - config['nEpochs'] = 0 - config['save_feat'] = False - if args.epochs: config['nEpochs'] = args.epochs + tforms = TRANSFORM_FNS.get(args.detector, None)((256, 256)) config['split_transforms'] = { - 'train': ucf_transforms, - 'validation': ucf_transforms, - 'test': ucf_transforms + 'train': tforms, + 'validation': tforms, + 'test': tforms } - config['dataset_meta'] = FACE_TRAINING_DATASET_META if args.faces_only else DATASET_META + if config['modality'] == 'video': + config['dataset_meta'] = VIDEO_DATASETS + elif config['modality'] == 'image': + if args.faces_only: + config['dataset_meta'] = FACE_IMAGE_DATASETS + else: + config['dataset_meta'] = IMAGE_DATASETS + dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets] config['train_dataset'] = dataset_names config['save_ckpt'] = args.save_ckpt config['save_feat'] = args.save_feat - - if config['lmdb']: - config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' - + # create logger timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') - - outputs_dir = os.path.join( - config['log_dir'], - config['model_name'] + '_' + timenow - ) - + outputs_dir = os.path.join(config['log_dir'], 'logs', config['model_name'] + '_' + timenow) + config['log_dir'] = outputs_dir + os.makedirs(outputs_dir, exist_ok=True) logger = create_logger(os.path.join(outputs_dir, 'training.log')) - config['log_dir'] = outputs_dir logger.info('Save log to {}'.format(outputs_dir)) config['ddp']= args.ddp @@ -437,29 +293,37 @@ def main(): ) logger.addFilter(RankFilter(0)) - ensure_backbone_is_available( - logger=logger, - model_filename=config['pretrained'].split('/')[-1], - hugging_face_repo_name='bitmind/bm-ucf' - ) - - # prepare the model (detector) + # download weights if huggingface repo provided. + # Note: TALL currently skips this and downloads from github + pretrained_config = config.get('pretrained', {}) + if not isinstance(pretrained_config, str): + hf_repo = pretrained_config.get('hf_repo') + weights_filename = pretrained_config.get('filename') + if hf_repo and weights_filename: + local_path = Path(WEIGHTS_DIR) / weights_filename + if not local_path.exists(): + model_path = hf_hub_download( + repo_id=hf_repo, + filename=weights_filename, + local_dir=WEIGHTS_DIR + ) + logger.info(f"Downloaded {hf_repo}/{weights_filename} to {model_path}") + else: + model_path = local_path + logger.info(f"{model_path} exists, skipping download") + config['pretrained']['local_path'] = str(model_path) + else: + logger.info("Pretrain config is a url, falling back to detector-specific download") + + # prepare model and trainer model_class = DETECTOR[config['model_name']] model = model_class(config).to(config['device']) - - # prepare the optimizer - optimizer = choose_optimizer(model, config) - # prepare the scheduler + optimizer = choose_optimizer(model, config) scheduler = choose_scheduler(config, optimizer) - - # prepare the metric metric_scoring = choose_metric(config) - - # prepare the trainer trainer = Trainer(config, model, config['device'], optimizer, scheduler, logger, metric_scoring) - # print configuration logger.info("--------------- Configuration ---------------") params_string = "Parameters: \n" for key, value in config.items(): @@ -474,10 +338,10 @@ def main(): for epoch in range(config['start_epoch'], config['nEpochs'] + 1): trainer.model.epoch = epoch best_metric = trainer.train_epoch( - epoch, - train_data_loader=train_loader, - validation_data_loaders={'val':val_loader} - ) + epoch, + train_data_loader=train_loader, + validation_data_loaders={'val':val_loader} + ) if best_metric is not None: logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!") @@ -488,10 +352,7 @@ def main(): start_time = log_start_time(logger, "Test") trainer.eval(eval_data_loaders={'test':test_loader}, eval_stage="test") log_finish_time(logger, "Test", start_time) - - # update - if 'svdd' in config['model_name']: - model.update_R(epoch) + if scheduler is not None: scheduler.step() @@ -504,4 +365,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/base_miner/UCF/trainer/trainer.py b/base_miner/DFB/trainer/trainer.py similarity index 98% rename from base_miner/UCF/trainer/trainer.py rename to base_miner/DFB/trainer/trainer.py index b8b0dff7..d4287c84 100644 --- a/base_miner/UCF/trainer/trainer.py +++ b/base_miner/DFB/trainer/trainer.py @@ -227,7 +227,7 @@ def train_epoch( losses, predictions=self.train_step(data_dict) # update learning rate - if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']: + if self.config.get('SWA', False) and epoch>self.config['swa_start']: self.swa_model.update_parameters(self.model) # compute training metric for each batch data @@ -246,7 +246,7 @@ def train_epoch( # run tensorboard to visualize the training process if iteration % 300 == 0 and self.config['gpu_id']==0: - if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']): + if self.config.get('SWA', False) and (epoch>self.config['swa_start'] or self.config['dry_run']): self.scheduler.step() # info for loss loss_str = f"Iter: {step_cnt} " @@ -331,7 +331,6 @@ def eval_one_dataset(self, data_loader): data_dict[key] = data_dict[key].cuda() # model forward without considering gradient computation predictions = self.inference(data_dict) #dict with keys cls, feat - label_lists += list(data_dict['label'].cpu().detach().numpy()) # Get the predicted class for each sample in the batch _, predicted_classes = torch.max(predictions['cls'], dim=1) diff --git a/base_miner/NPR/train_detector.py b/base_miner/NPR/train_detector.py index bef1f668..64d4e608 100644 --- a/base_miner/NPR/train_detector.py +++ b/base_miner/NPR/train_detector.py @@ -1,6 +1,4 @@ from tensorboardX import SummaryWriter -from validate import validate -from networks.trainer import Trainer from torch.utils.data import DataLoader import numpy as np import os @@ -8,10 +6,12 @@ import random import torch -from bitmind.constants import DATASET_META -from bitmind.image_transforms import base_transforms, random_aug_transforms -from bitmind.utils.data import load_and_split_datasets, create_real_fake_datasets -from options import TrainOptions +from base_miner.NPR.validate import validate +from base_miner.NPR.networks.trainer import Trainer +from base_miner.config import IMAGE_DATASETS as DATASET_META +from base_miner.NPR.options import TrainOptions +from bitmind.utils.image_transforms import get_base_transforms, get_random_augmentations +from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets def seed_torch(seed=1029): @@ -34,14 +34,19 @@ def main(): val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) # RealFakeDataseta will limit the number of images sampled per dataset to the length of the smallest dataset - real_datasets = load_and_split_datasets(DATASET_META['real']) - fake_datasets = load_and_split_datasets(DATASET_META['fake']) + base_transforms = get_base_transforms() + random_augs = get_random_augmentations() + split_transforms = { + 'train': random_augs, + 'validation': base_transforms, + 'test': base_transforms + } + real_datasets = load_and_split_datasets( + DATASET_META['real'], modality='image', split_transforms=split_transforms) + fake_datasets = load_and_split_datasets( + DATASET_META['fake'], modality='image', split_transforms=split_transforms) train_dataset, val_dataset, test_dataset = create_real_fake_datasets( - real_datasets, - fake_datasets, - train_transforms=random_aug_transforms, - val_transforms=base_transforms, - test_transforms=base_transforms) + real_datasets, fake_datasets) train_loader = DataLoader( train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=lambda d: tuple(d)) diff --git a/base_miner/UCF/config/constants.py b/base_miner/UCF/config/constants.py deleted file mode 100644 index 61d2ad6a..00000000 --- a/base_miner/UCF/config/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -# Path to the directory containing the constants.py file -CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__)) - -# The base directory for UCF-related files, i.e., UCF directory -UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/UCF/ -# Absolute paths for the required files and directories -CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml") # Path to the ucf.yaml file -WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/") # Path to pretrained weights directory - -HF_REPO = "bitmind/ucf" -BACKBONE_CKPT = "xception_best.pth" - -DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(UCF_BASE_PATH, "../../bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat")) \ No newline at end of file diff --git a/base_miner/UCF/config/train_config.yaml b/base_miner/UCF/config/train_config.yaml deleted file mode 100644 index cd94d867..00000000 --- a/base_miner/UCF/config/train_config.yaml +++ /dev/null @@ -1,9 +0,0 @@ -mode: train -lmdb: True -dry_run: false -rgb_dir: './datasets/rgb' -lmdb_dir: './datasets/lmdb' -dataset_json_folder: './preprocessing/dataset_json' -SWA: False -save_avg: True -log_dir: ./logs/training/ \ No newline at end of file diff --git a/base_miner/__init__.py b/base_miner/__init__.py index 77486091..e69de29b 100644 --- a/base_miner/__init__.py +++ b/base_miner/__init__.py @@ -1,3 +0,0 @@ -from .registry import DETECTOR_REGISTRY, GATE_REGISTRY -from .deepfake_detectors import NPRDetector, UCFDetector, CAMODetector -from .gating_mechanisms import FaceGate, GatingMechanism \ No newline at end of file diff --git a/base_miner/config.py b/base_miner/config.py new file mode 100644 index 00000000..e88e822d --- /dev/null +++ b/base_miner/config.py @@ -0,0 +1,42 @@ +from pathlib import Path + +HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface' +TARGET_IMAGE_SIZE = (256, 256) + + +IMAGE_DATASETS = { + "real": [ + {"path": "bitmind/bm-real"}, + {"path": "bitmind/open-images-v7"}, + {"path": "bitmind/celeb-a-hq"}, + {"path": "bitmind/ffhq-256"}, + {"path": "bitmind/MS-COCO-unique-256"} + ], + "fake": [ + {"path": "bitmind/bm-realvisxl"}, + {"path": "bitmind/bm-mobius"}, + {"path": "bitmind/bm-sdxl"} + ] +} + +# see bitmind-subnet/create_video_dataset_example.sh +VIDEO_DATASETS = { + "real": [ + {"path": ""} + ], + "fake": [ + {"path": ""} + ] +} + +FACE_IMAGE_DATASETS = { + "real": [ + {"path": "bitmind/ffhq-256_training_faces", "name": "base_transforms"}, + {"path": "bitmind/celeb-a-hq_training_faces", "name": "base_transforms"} + + ], + "fake": [ + {"path": "bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces", "name": "base_transforms"}, + {"path": "bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces", "name": "base_transforms"} + ] +} diff --git a/base_miner/datasets/__init__.py b/base_miner/datasets/__init__.py new file mode 100644 index 00000000..78111baa --- /dev/null +++ b/base_miner/datasets/__init__.py @@ -0,0 +1,4 @@ +from .base_dataset import BaseDataset +from .image_dataset import ImageDataset +from .video_dataset import VideoDataset +from .real_fake_dataset import RealFakeDataset diff --git a/base_miner/datasets/base_dataset.py b/base_miner/datasets/base_dataset.py new file mode 100644 index 00000000..3dcc8887 --- /dev/null +++ b/base_miner/datasets/base_dataset.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from datasets import Dataset +from typing import Optional +from torchvision.transforms import Compose + +from base_miner.datasets.download_data import load_huggingface_dataset + + +class BaseDataset(ABC): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + transforms: Optional[Compose] = None + ): + """Base class for dataset implementations. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. + huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + """ + self.huggingface_dataset_path = None + self.huggingface_dataset_split = huggingface_dataset_split + self.huggingface_dataset_name = None + self.dataset = None + self.transforms = transforms + + if huggingface_dataset_path is None and huggingface_dataset is None: + raise ValueError("Either huggingface_dataset_path or huggingface_dataset must be provided.") + + # If a dataset is directly provided, use it + if huggingface_dataset is not None: + self.dataset = huggingface_dataset + self.huggingface_dataset_path = self.dataset.info.dataset_name + self.huggingface_dataset_name = self.dataset.info.config_name + try: + self.huggingface_dataset_split = list(self.dataset.info.splits.keys())[0] + except AttributeError as e: + self.huggingface_data_split = None + + else: + # Store the initialization parameters + self.huggingface_dataset_path = huggingface_dataset_path + self.huggingface_dataset_name = huggingface_dataset_name + self.dataset = load_huggingface_dataset( + huggingface_dataset_path, + huggingface_dataset_split, + huggingface_dataset_name, + download_mode) + + @abstractmethod + def __getitem__(self, index: int) -> dict: + """Get an item from the dataset. + + Args: + index (int): Index of the item to retrieve. + + Returns: + dict: Dictionary containing the item data. + """ + pass + + @abstractmethod + def __len__(self) -> int: + """Get the length of the dataset. + + Returns: + int: Length of the dataset. + """ + pass diff --git a/base_miner/datasets/create_video_dataset.py b/base_miner/datasets/create_video_dataset.py new file mode 100644 index 00000000..c834cdf7 --- /dev/null +++ b/base_miner/datasets/create_video_dataset.py @@ -0,0 +1,305 @@ +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Union, Tuple +from multiprocessing import Pool, cpu_count + +import cv2 +import glob +import os + +import argparse +from PIL import Image +from datasets import Dataset, DatasetInfo, Image as HFImage, Split +from datasets.features import Features, Sequence, Value +from tqdm import tqdm + + +def process_single_video(args: Tuple[Path, Path, int, Optional[int], bool]) -> Tuple[str, int]: + """ + Extract frames from a single video + + Args: + args: Tuple containing (video_file, output_dir, frame_rate, max_frames, overwrite) + + Returns: + Tuple of (video_name, number_of_frames_saved) + """ + video_file, output_dir, frame_rate, max_frames, overwrite = args + video_name = video_file.stem + video_output_dir = output_dir / video_name + + if video_output_dir.exists() and not overwrite: + return video_name, 0 + + video_output_dir.mkdir(parents=True, exist_ok=True) + + video_capture = cv2.VideoCapture(str(video_file)) + frame_idx = 0 + saved_frame_count = 0 + + while True: + success, frame = video_capture.read() + if not success or (max_frames is not None and saved_frame_count >= max_frames): + break + + if frame_idx % frame_rate == 0: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + frame_filename = video_output_dir / f"frame_{frame_idx:05d}.png" + pil_image.save(frame_filename) + saved_frame_count += 1 + + frame_idx += 1 + + video_capture.release() + return video_name, saved_frame_count + + +def extract_frames_from_videos( + input_dir: Union[str, Path], + output_dir: Union[str, Path], + num_videos: Optional[int] = None, + frame_rate: int = 1, + max_frames: Optional[int] = None, + overwrite: bool = False, + num_workers: Optional[int] = None +) -> None: + """ + Extract frames from videos (mp4s -> directories of PILs) using multiprocessing + + Args: + input_dir: Directory containing input MP4 files + output_dir: Directory where extracted frames will be saved + num_videos: Number of videos to process. If None, processes all videos + frame_rate: Extract one frame every 'frame_rate' frames + max_frames: Maximum number of frames to extract per video + overwrite: If True, overwrites existing frame directories + num_workers: Number of worker processes to use. If None, uses CPU count + """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + video_files = list(input_dir.glob("*.mp4")) + if num_videos is not None: + video_files = video_files[:num_videos] + + if not num_workers: + num_workers = cpu_count() + + print(f'Processing {len(video_files)} videos using {num_workers} workers') + + # Prepare arguments for each video + process_args = [ + (video_file, output_dir, frame_rate, max_frames, overwrite) + for video_file in video_files + ] + + # Process videos in parallel + with Pool(num_workers) as pool: + results = list(tqdm( + pool.imap(process_single_video, process_args), + total=len(video_files), + desc="Extracting frames" + )) + + # Print results + for video_name, frame_count in results: + if frame_count > 0: + print(f"Extracted {frame_count} frames from {video_name}") + else: + print(f"Skipped {video_name} (already exists)") + + +def create_video_frames_dataset( + frames_dir: Union[str, Path], + dataset_name: str = "video_frames", + validate_frames: bool = False, + delete_corrupted: bool = False, +) -> Dataset: + """Create a HuggingFace dataset from a directory of video frames.""" + frames_dir = Path(frames_dir) + video_data: Dict[str, Dict[str, List]] = defaultdict(lambda: {'frames': [], 'frame_numbers': []}) + + for video_dir in tqdm(sorted(os.listdir(frames_dir)), desc='processing video frames'): + video_path = frames_dir / video_dir + + if not video_path.is_dir(): + continue + + image_files = [] + for ext in ('*.png', '*.jpg', '*.jpeg'): + image_files.extend(glob.glob(str(video_path / ext))) + + image_files.sort() + + # Validate images before adding them to the dataset + if validate_frames: + valid_frames = [] + valid_frame_numbers = [] + for img_path in tqdm(image_files, desc="Checking image files"): + try: + # Attempt to fully load the image to verify it's valid + with Image.open(img_path) as img: + img.load() # Force load the image data + frame_num = int(Path(img_path).stem.split('_')[1]) + valid_frames.append(img_path) + valid_frame_numbers.append(frame_num) + except Exception as e: + print(f"Skipping corrupted image {img_path}: {str(e)}") + if delete_corrupted: + print(f"Deleting {img_path} (delete_corrupted = true)") + Path(img_path).unlink() + continue + if valid_frames: # Only add videos that have valid frames + video_data[video_dir]['frames'] = valid_frames + video_data[video_dir]['frame_numbers'] = valid_frame_numbers + else: + video_data[video_dir]['frames'] = image_files + video_data[video_dir]['frame_numbers'] = list(range(len(image_files))) + print(video_data[video_dir]['frames'][:10]) + print(video_data[video_dir]['frame_numbers'][:10]) + + dataset_dict = { + "video_id": [], + "frames": [], + "frame_numbers": [], + "num_frames": [] + } + + for video_id, data in video_data.items(): + if data['frames']: # Double check we have frames + dataset_dict["video_id"].append(video_id) + dataset_dict["frames"].append(data["frames"]) + dataset_dict["frame_numbers"].append(data["frame_numbers"]) + dataset_dict["num_frames"].append(len(data["frames"])) + + features = Features({ + "video_id": Value("string"), + "frames": Sequence(Value("string")), + "frame_numbers": Sequence(Value("int64")), + "num_frames": Value("int64") + }) + + dataset_info = DatasetInfo( + description="Video frames dataset", + features=features, + supervised_keys=None, + homepage="", + citation="", + task_templates=None, + dataset_name=dataset_name + ) + + # Create dataset with validated images + dataset = Dataset.from_dict( + dataset_dict, + info=dataset_info, + features=features + ) + + # Convert to HuggingFace image format with error handling + def convert_frames_to_images(example): + converted_frames = [] + for frame_path in example["frames"]: + try: + converted_frames.append(HFImage().encode_example(frame_path)) + except Exception as e: + print(f"Error converting {frame_path}: {str(e)}") + continue + example["frames"] = converted_frames + return example + + #dataset = dataset.map(convert_frames_to_images) + return dataset + + +def main() -> None: + """Parse command line arguments and run the dataset creation pipeline.""" + parser = argparse.ArgumentParser( + description="Extract frames from videos and create a HuggingFace dataset." + ) + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Path to the directory containing input MP4 files." + ) + parser.add_argument( + "--frames_dir", + type=str, + required=True, + help="Path to the directory where extracted frames will be saved." + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Path where the HuggingFace dataset will be saved." + ) + parser.add_argument( + "--num_videos", + type=int, + default=None, + help="Number of videos to process. If not specified, processes all videos." + ) + parser.add_argument( + "--frame_rate", + type=int, + default=5, + help="Extract one frame every 'frame_rate' frames." + ) + parser.add_argument( + "--max_frames", + type=int, + default=None, + help="Maximum number of frames to extract per video." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="If set, overwrites existing frame directories." + ) + parser.add_argument( + "--skip_extraction", + action="store_true", + help="If set, skips the frame extraction step and only creates the dataset." + ) + parser.add_argument( + "--dataset_name", + type=str, + default="video_frames", + help="Name for the local HuggingFace dataset to be created." + ) + + args = parser.parse_args() + + if not args.skip_extraction: + print("Extracting frames from videos...") + extract_frames_from_videos( + input_dir=args.input_dir, + output_dir=args.frames_dir, + num_videos=args.num_videos, + frame_rate=args.frame_rate, + max_frames=args.max_frames, + overwrite=args.overwrite + ) + + print("\nCreating HuggingFace dataset...") + dataset = create_video_frames_dataset( + args.frames_dir, + dataset_name=args.dataset_name + ) + print(dataset.info) + print(f"\nSaving dataset to {args.dataset_dir}") + dataset.save_to_disk(args.dataset_dir) + + print(f"\nDataset creation complete!") + print(f"Total number of videos: {len(dataset)}") + print(f"Features: {dataset.features}") + print("Frame counts:", dataset["num_frames"]) + print(f"Dataset name: {dataset.info.dataset_name}") + + +if __name__ == "__main__": + main() diff --git a/bitmind/download_data.py b/base_miner/datasets/download_data.py similarity index 53% rename from bitmind/download_data.py rename to base_miner/datasets/download_data.py index fcb68559..485a88ab 100644 --- a/bitmind/download_data.py +++ b/base_miner/datasets/download_data.py @@ -11,35 +11,50 @@ import glob import requests -from bitmind.constants import DATASET_META, HUGGINGFACE_CACHE_DIR +from base_miner.config import IMAGE_DATASETS, HUGGINGFACE_CACHE_DIR datasets.logging.set_verbosity_warning() datasets.disable_progress_bar() +from datasets import load_dataset, load_from_disk +from typing import Optional +import os + def load_huggingface_dataset( path: str, split: str = 'train', name: Optional[str] = None, - download_mode: str = 'reuse_cache_if_exists', + download_mode: str = 'reuse_cache_if_exists' ) -> datasets.Dataset: - """ - Load a dataset from Hugging Face or a local directory. + """Load a dataset from Hugging Face or a local directory. Args: - path (str): Path to the dataset or 'imagefolder:' for image folder. Can either be to a publicly - hosted huggingface datset with the format / or a local directory with the format - imagefolder: - split (str, optional): Name of the dataset split to load (default: None). - Make sure to check what splits are available for the datasets you're working with. - name (str, optional): Name of the dataset (if loading from Hugging Face, default: None). - Some huggingface datasets provide various subets of different sizes, which can be accessed via thi - parameter. - download_mode (str, optional): Download mode for the dataset (if loading from Hugging Face, default: None). - can be None or "force_redownload" + path (str): Path to dataset. Can be: + - A Hugging Face dataset path (/) + - An image folder path (imagefolder:) + - A local path to a saved dataset (for load_from_disk) + split (str, optional): Dataset split to load (default: 'train') + name (str, optional): Dataset configuration name (default: None) + download_mode (str, optional): Download mode for Hugging Face datasets + Returns: - Union[dict, load_dataset.Dataset]: The loaded dataset or a specific split of the dataset as requested. + Dataset: The loaded dataset or requested split """ + # Check if it's a local path suitable for load_from_disk + if not path.startswith('imagefolder:') and os.path.exists(path): + try: + # Look for dataset artifacts that indicate this is a saved dataset + dataset_files = {'dataset_info.json', 'state.json', 'data'} + path_contents = set(os.listdir(path)) + if dataset_files.intersection(path_contents): + dataset = load_from_disk(path) + if split is None: + return dataset + return dataset[split] + except Exception as e: + print(f"Attempted load_from_disk but failed: {e}") + if 'imagefolder' in path: _, directory = path.split(':') if name: @@ -55,13 +70,11 @@ def load_huggingface_dataset( if split is None: return dataset - return dataset[split] def download_image(url: str) -> Image.Image: - """ - Download an image from a URL. + """Download an image from a URL. Args: url (str): The URL of the image to download. @@ -74,37 +87,11 @@ def download_image(url: str) -> Image.Image: if response.status_code == 200: image_data = BytesIO(response.content) return Image.open(image_data) - else: #print(f"Failed to download image: {response.status_code}") return None -def clear_cache(cache_dir): - """Clears lock files and incomplete downloads from the cache directory.""" - # Find lock and incomplete files - lock_files = glob.glob(cache_dir + "/*lock") - incomplete_files = glob.glob(cache_dir + "/downloads/**/*.incomplete", recursive=True) - try: - if lock_files: - subprocess.run(["rm", *lock_files], check=True) - if incomplete_files: - for file in incomplete_files: - os.remove(file) - print("Hugging Face cache lock files cleared successfully.") - except Exception as e: - print(f"Failed to clear Hugging Face cache lock files: {e}") - - -def fix_permissions(path): - """Attempts to fix permission issues on a given path.""" - try: - subprocess.run(["chmod", "-R", "775", path], check=True) - print(f"Fixed permissions for {path}") - except subprocess.CalledProcessError as e: - print(f"Failed to fix permissions for {path}: {e}") - - def download_dataset( dataset_path: str, dataset_name: str, @@ -112,26 +99,29 @@ def download_dataset( cache_dir: str, max_wait: int = 300 ): - """ Downloads the datasets present in datasets.json with exponential backoff - download_mode: either 'force_redownload' or 'use_cache_if_exists' - cache_dir: huggingface cache directory. ~/.cache/huggingface by default + """Downloads the datasets present in datasets.json with exponential backoff. + + Args: + dataset_path (str): Path to the dataset on Hugging Face + dataset_name (str): Name/config of the dataset subset + download_mode (str): Either 'force_redownload' or 'use_cache_if_exists' + cache_dir (str): Huggingface cache directory. ~/.cache/huggingface by default + max_wait (int, optional): Maximum wait time between retries in seconds. Defaults to 300. + + Returns: + Dataset: The downloaded Hugging Face dataset """ - retry_wait = 10 # initial wait time in seconds - attempts = 0 # initialize attempts counter + retry_wait = 10 # initial wait time in seconds + attempts = 0 print(f"Downloading {dataset_path} (subset={dataset_name}) dataset...") while True: try: - if dataset_name: - dataset = load_dataset(dataset_path, - name=dataset_name, #config/subset name - cache_dir=cache_dir, - download_mode=download_mode, - trust_remote_code=True) - else: - dataset = load_dataset(dataset_path, - cache_dir=cache_dir, - download_mode=download_mode, - trust_remote_code=True) + dataset = load_dataset( + dataset_path, + name=dataset_name, # config/subset name + cache_dir=cache_dir, + download_mode=download_mode, + trust_remote_code=True) break except Exception as e: print(e) @@ -141,7 +131,7 @@ def download_dataset( file_path = str(e).split(": '")[1].rstrip("'") print(f"Permission error at {file_path}, attempting to fix...") fix_permissions(file_path) # Attempt to fix permissions directly - clear_cache(cache_dir) # Clear cache to remove any incomplete or locked files + clean_cache(cache_dir) # Clear cache to remove any incomplete or locked files else: print(f"Unexpected error, stopping retries for {dataset_path}") raise e @@ -158,9 +148,42 @@ def download_dataset( return dataset +def clean_cache(cache_dir): + """Clears lock files and incomplete downloads from the cache directory. + + Args: + cache_dir (str): Path to the Hugging Face cache directory + """ + lock_files = glob.glob(os.path.join(cache_dir, "*lock")) + incomplete_files = glob.glob(os.path.join(cache_dir, "downloads", "**", "*.incomplete"), recursive=True) + try: + if lock_files: + subprocess.run(["rm", *lock_files], check=True) + if incomplete_files: + for file in incomplete_files: + os.remove(file) + print("Hugging Face cache lock files cleared successfully.") + except Exception as e: + print(f"Failed to clear Hugging Face cache lock files: {e}") + + +def fix_permissions(path): + """Attempts to fix permission issues on a given path. + + Args: + path (str): Path to fix permissions for + """ + try: + subprocess.run(["chmod", "-R", "775", path], check=True) + print(f"Fixed permissions for {path}") + except subprocess.CalledProcessError as e: + print(f"Failed to fix permissions for {path}: {e}") + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Download Hugging Face datasets for validator challenge generation and miner training.') parser.add_argument('--force_redownload', action='store_true', help='force redownload of datasets') + parser.add_argument('--modality', default='image', choices=['video', 'image'], help='download image or video datasets') parser.add_argument('--cache_dir', type=str, default=HUGGINGFACE_CACHE_DIR, help='huggingface cache directory') args = parser.parse_args() @@ -169,8 +192,17 @@ def download_dataset( download_mode = "force_redownload" os.makedirs(args.cache_dir, exist_ok=True) - clear_cache(args.cache_dir) # Clear the cache of lock and incomplete files. - - for dataset_type in DATASET_META: - for dataset in DATASET_META[dataset_type]: - download_dataset(dataset['path'], dataset.get('name', None), download_mode, args.cache_dir) + clean_cache(args.cache_dir) # Clear the cache of lock and incomplete files. + + if args.modality == 'image': + dataset_meta = IMAGE_DATASETS + #elif args.modality == 'video': + # dataset_meta = VIDEO_DATASET_META + + for dataset_type in dataset_meta: + for dataset in dataset_meta[dataset_type]: + download_dataset( + dataset_path=dataset['path'], + dataset_name=dataset.get('name', None), + download_mode=download_mode, + cache_dir=args.cache_dir) diff --git a/base_miner/datasets/image_dataset.py b/base_miner/datasets/image_dataset.py new file mode 100644 index 00000000..09aa4897 --- /dev/null +++ b/base_miner/datasets/image_dataset.py @@ -0,0 +1,113 @@ +from typing import Optional +from datasets import Dataset +from PIL import Image +from io import BytesIO +from torchvision.transforms import Compose + +from .download_data import download_image +from .base_dataset import BaseDataset + + +class ImageDataset(BaseDataset): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + transforms: Optional[Compose] = None, + ): + """Initialize the ImageDataset. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. + huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + """ + super().__init__( + huggingface_dataset_path=huggingface_dataset_path, + huggingface_dataset_split=huggingface_dataset_split, + huggingface_dataset_name=huggingface_dataset_name, + huggingface_dataset=huggingface_dataset, + download_mode=download_mode, + transforms=transforms + ) + + def __getitem__(self, index: int) -> dict: + """ + Get an item (image and ID) from the dataset. + + Args: + index (int): Index of the item to retrieve. + + Returns: + dict: Dictionary containing 'image' (PIL image) and 'id' (str). + """ + """ + Load an image from self.dataset. Expects self.dataset[i] to be a dictionary containing either 'image' or 'url' + as a key. + - The value associated with the 'image' key should be either a PIL image or a b64 string encoding of + the image. + - The value associated with the 'url' key should be a url that hosts the image (as in + dalle-mini/open-images) + + Args: + index (int): Index of the image in the dataset. + + Returns: + dict: Dictionary containing 'image' (PIL image) and 'id' (str). + """ + sample = self.dataset[int(index)] + if 'url' in sample: + image = download_image(sample['url']) + image_id = sample['url'] + elif 'image_url' in sample: + image = download_image(sample['image_url']) + image_id = sample['image_url'] + elif 'image' in sample: + if isinstance(sample['image'], Image.Image): + image = sample['image'] + elif isinstance(sample['image'], bytes): + image = Image.open(BytesIO(sample['image'])) + else: + raise NotImplementedError + + image_id = '' + if 'name' in sample: + image_id = sample['name'] + elif 'filename' in sample: + image_id = sample['filename'] + + image_id = image_id if image_id != '' else index + + else: + raise NotImplementedError + + # remove alpha channel if download didnt 404 + if image is not None: + image = image.convert('RGB') + + if self.transforms is not None: + image = self.transforms(image) + + return { + 'image': image, + 'id': image_id, + 'source': self.huggingface_dataset_path + } + + def __len__(self) -> int: + """ + Get the length of the dataset. + + Returns: + int: Length of the dataset. + """ + return len(self.dataset) + diff --git a/bitmind/real_fake_dataset.py b/base_miner/datasets/real_fake_dataset.py similarity index 82% rename from bitmind/real_fake_dataset.py rename to base_miner/datasets/real_fake_dataset.py index 4387d9be..0fb320a8 100644 --- a/bitmind/real_fake_dataset.py +++ b/base_miner/datasets/real_fake_dataset.py @@ -8,7 +8,6 @@ def __init__( self, real_image_datasets: list, fake_image_datasets: list, - transforms=None, fake_prob=0.5, source_label_mapping=None ): @@ -24,7 +23,6 @@ def __init__( """ self.real_image_datasets = real_image_datasets self.fake_image_datasets = fake_image_datasets - self.transforms = transforms self.fake_prob = fake_prob self.source_label_mapping = source_label_mapping @@ -55,21 +53,15 @@ def __getitem__(self, index: int) -> tuple: label = 1.0 else: source = self.real_image_datasets[np.random.randint(0, len(self.real_image_datasets))] - imgs, idx = source.sample(1) - image = imgs[0]['image'] - index = idx[0] + #imgs, idx = source.sample(1) + image = source[index]['image'] + #image = imgs[0]['image'] + #index = idx[0] label = 0.0 self._history['source'].append(source.huggingface_dataset_path) self._history['label'].append(label) self._history['index'].append(index) - - try: - if self.transforms is not None: - image = self.transforms(image) - except Exception as e: - print(e) - print(source.huggingface_dataset_path, index) if self.source_label_mapping: source_label = self.source_label_mapping[source.huggingface_dataset_path] @@ -94,4 +86,21 @@ def reset(self): 'source': [], 'index': [], 'label': [], - } \ No newline at end of file + } + + @staticmethod + def collate_fn(batch): + images, labels, source_labels = zip(*batch) + + images = torch.stack(images, dim=0) # Stack image tensors into a single tensor + labels = torch.LongTensor(labels) + source_labels = torch.LongTensor(source_labels) + + data_dict = { + 'image': images, + 'label': labels, + 'label_spe': source_labels, + 'landmark': None, + 'mask': None + } + return data_dict \ No newline at end of file diff --git a/bitmind/utils/data.py b/base_miner/datasets/util.py similarity index 86% rename from bitmind/utils/data.py rename to base_miner/datasets/util.py index 058faba8..7ce46550 100644 --- a/bitmind/utils/data.py +++ b/base_miner/datasets/util.py @@ -1,13 +1,11 @@ -from typing import Optional, Union, List, Tuple, Dict +from typing import List, Tuple, Dict import torchvision.transforms as transforms import numpy as np import datasets -import requests import datasets -from bitmind.download_data import load_huggingface_dataset -from bitmind.real_fake_dataset import RealFakeDataset -from bitmind.image_dataset import ImageDataset +from base_miner.datasets.download_data import load_huggingface_dataset +from base_miner.datasets import ImageDataset, VideoDataset, RealFakeDataset datasets.logging.set_verbosity_error() datasets.disable_progress_bar() @@ -17,8 +15,11 @@ def split_dataset(dataset): # Split data into train, validation, test and return the three splits dataset = dataset.shuffle(seed=42) + if 'train' in dataset: + dataset = dataset['train'] + split_dataset = {} - train_test_split = dataset['train'].train_test_split(test_size=0.2, seed=42) + train_test_split = dataset.train_test_split(test_size=0.2, seed=42) split_dataset['train'] = train_test_split['train'] temp_dataset = train_test_split['test'] @@ -30,7 +31,11 @@ def split_dataset(dataset): return split_dataset['train'], split_dataset['validation'], split_dataset['test'] -def load_and_split_datasets(dataset_meta: list) -> Dict[str, List[ImageDataset]]: +def load_and_split_datasets( + dataset_meta: list, + modality: str, + split_transforms: Dict[str, transforms.Compose] = {}, +) -> Dict[str, List[ImageDataset]]: """ Helper function to load and split dataset into train, validation, and test sets. @@ -56,7 +61,12 @@ def load_and_split_datasets(dataset_meta: list) -> Dict[str, List[ImageDataset]] train_ds, val_ds, test_ds = split_dataset(dataset) for split, data in zip(splits, [train_ds, val_ds, test_ds]): - image_dataset = ImageDataset(huggingface_dataset=data) + if modality == 'image': + image_dataset = ImageDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None)) + elif modality == 'video': + image_dataset = VideoDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None)) + else: + raise NotImplementedError(f'Unsupported modality: {modality}') datasets[split].append(image_dataset) split_lengths = ', '.join([f"{split} len={len(datasets[split][0])}" for split in splits]) @@ -105,9 +115,6 @@ def create_source_label_mapping( def create_real_fake_datasets( real_datasets: Dict[str, List[ImageDataset]], fake_datasets: Dict[str, List[ImageDataset]], - train_transforms: transforms.Compose = None, - val_transforms: transforms.Compose = None, - test_transforms: transforms.Compose = None, source_labels: bool = False, group_sources_by_name: bool = False) -> Tuple[RealFakeDataset, ...]: """ @@ -131,19 +138,16 @@ def create_real_fake_datasets( train_dataset = RealFakeDataset( real_image_datasets=real_datasets['train'], fake_image_datasets=fake_datasets['train'], - transforms=train_transforms, source_label_mapping=source_label_mapping) val_dataset = RealFakeDataset( real_image_datasets=real_datasets['validation'], fake_image_datasets=fake_datasets['validation'], - transforms=val_transforms, source_label_mapping=source_label_mapping) test_dataset = RealFakeDataset( real_image_datasets=real_datasets['test'], fake_image_datasets=fake_datasets['test'], - transforms=test_transforms, source_label_mapping=source_label_mapping) if source_labels: diff --git a/base_miner/datasets/video_dataset.py b/base_miner/datasets/video_dataset.py new file mode 100644 index 00000000..814e2bbc --- /dev/null +++ b/base_miner/datasets/video_dataset.py @@ -0,0 +1,116 @@ +""" +Author: Zhiyuan Yan +Email: zhiyuanyan@link.cuhk.edu.cn +Date: 2023-03-30 +Description: Abstract Base Class for all types of deepfake datasets. +""" + +import os +import cv2 +from PIL import Image +import sys +import yaml +import numpy as np +from copy import deepcopy +import random +import torch +from torch import nn +from torch.utils import data +from torchvision.utils import save_image +from torchvision.transforms import Compose +from einops import rearrange +from typing import List, Tuple, Optional +from datasets import Dataset + +from .base_dataset import BaseDataset + + +class VideoDataset(BaseDataset): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + max_frames_per_video: Optional[int] = 4, + transforms: Optional[Compose] = None + ): + """Initialize the ImageDataset. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. + huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + """ + super().__init__( + huggingface_dataset_path=huggingface_dataset_path, + huggingface_dataset_split=huggingface_dataset_split, + huggingface_dataset_name=huggingface_dataset_name, + huggingface_dataset=huggingface_dataset, + download_mode=download_mode, + transforms=transforms, + ) + self.max_frames = max_frames_per_video + + def __getitem__(self, index): + """Return the data point at the given index. + + Args: + index (int): The index of the data point. + no_norm (bool): Whether to skip normalization. + + Returns: + tuple: Contains image tensor, label tensor, landmark tensor, + and mask tensor. + """ + image_paths = self.dataset[index]['frames'] + + if not isinstance(image_paths, list): + image_paths = [image_paths] + + images = [] + for image_path in image_paths[:self.max_frames]: + try: + img = Image.open(image_path) + images.append(img) + except Exception as e: + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + + if self.transforms is not None: + images = self.transforms(images) + + # Stack images along the time dimension (frame_dim) + image_tensors = torch.stack(images, dim=0) # Shape: [frame_dim, C, H, W] + + frames, channels, height, width = image_tensors.shape + x = torch.randint(0, width, (1,)).item() + y = torch.randint(0, height, (1,)).item() + mask_grid_size = 16 + x1 = max(x - mask_grid_size // 2, 0) + x2 = min(x + mask_grid_size // 2, width) + y1 = max(y - mask_grid_size // 2, 0) + y2 = min(y + mask_grid_size // 2, height) + image_tensors[:, :, y1:y2, x1:x2] = -1 + + return { + 'image': image_tensors, # Shape: [frame_dim, C, H, W] + 'id': self.dataset[index]['video_id'], + 'source': self.huggingface_dataset_path + } + + + def __len__(self) -> int: + """ + Get the length of the dataset. + + Returns: + int: Length of the dataset. + """ + return len(self.dataset['video_id']) \ No newline at end of file diff --git a/base_miner/deepfake_detectors/__init__.py b/base_miner/deepfake_detectors/__init__.py index 08e0636f..37529bfa 100644 --- a/base_miner/deepfake_detectors/__init__.py +++ b/base_miner/deepfake_detectors/__init__.py @@ -1,4 +1,5 @@ from .deepfake_detector import DeepfakeDetector -from .npr_detector import NPRDetector -from .ucf_detector import UCFDetector -from .camo_detector import CAMODetector \ No newline at end of file +from .npr_detector import NPRImageDetector +from .ucf_detector import UCFImageDetector +from .camo_detector import CAMOImageDetector +from .tall_detector import TALLVideoDetector diff --git a/base_miner/deepfake_detectors/camo_detector.py b/base_miner/deepfake_detectors/camo_detector.py index d30f474b..8408ce3a 100644 --- a/base_miner/deepfake_detectors/camo_detector.py +++ b/base_miner/deepfake_detectors/camo_detector.py @@ -8,7 +8,7 @@ @DETECTOR_REGISTRY.register_module(module_name='CAMO') -class CAMODetector(DeepfakeDetector): +class CAMOImageDetector(DeepfakeDetector): """ This DeepfakeDetector subclass implements Content-Aware Model Orchestration (CAMO), a mixture-of-experts approach to the binary classification of @@ -21,17 +21,17 @@ class CAMODetector(DeepfakeDetector): Attributes: model_name (str): Name of the detector instance. - config (str): Name of the YAML file in deepfake_detectors/config/ to load + config_name (str): Name of the YAML file in deepfake_detectors/config/ to load attributes from. device (str): The type of device ('cpu' or 'cuda'). """ - def __init__(self, model_name: str = 'CAMO', config: str = 'camo.yaml', device: str = 'cpu'): + def __init__(self, model_name: str = 'CAMO', config_name: str = 'camo.yaml', device: str = 'cpu'): """ Initialize the CAMODetector with dynamic model selection based on config. """ self.detectors = {} - super().__init__(model_name, config, device) + super().__init__(model_name, config_name, device) gate_names = [ content_type for content_type in self.content_type @@ -50,7 +50,7 @@ def load_model(self): if model_name in DETECTOR_REGISTRY: self.detectors[content_type] = DETECTOR_REGISTRY[model_name]( model_name=f'{model_name}_{content_type.capitalize()}', - config=detector_config, + config_name=detector_config, device=self.device ) else: diff --git a/base_miner/deepfake_detectors/configs/tall.yaml b/base_miner/deepfake_detectors/configs/tall.yaml new file mode 100644 index 00000000..d1248a24 --- /dev/null +++ b/base_miner/deepfake_detectors/configs/tall.yaml @@ -0,0 +1,3 @@ +hf_repo: 'bitmind/tall' # Hugging Face repository for downloading model files +config_name: 'tall.yaml' # pre-trained configuration file in HuggingFace +weights: 'tall_trainFF_testCDF.pth' # UCF model checkpoint in HuggingFace \ No newline at end of file diff --git a/base_miner/deepfake_detectors/configs/ucf.yaml b/base_miner/deepfake_detectors/configs/ucf.yaml index a8ae7667..db9978c8 100644 --- a/base_miner/deepfake_detectors/configs/ucf.yaml +++ b/base_miner/deepfake_detectors/configs/ucf.yaml @@ -1,4 +1,4 @@ # UCFDetector Generalist Configuration hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files -train_config: 'bm-general-config-v1.yaml' # pre-trained configuration file in HuggingFace +config_name: 'bm-general-config-v1.yaml' # pre-trained configuration file in HuggingFace weights: 'bm-general-v1.pth' # UCF model checkpoint in HuggingFace \ No newline at end of file diff --git a/base_miner/deepfake_detectors/configs/ucf_face.yaml b/base_miner/deepfake_detectors/configs/ucf_face.yaml index 9de77210..4cd4c5b6 100644 --- a/base_miner/deepfake_detectors/configs/ucf_face.yaml +++ b/base_miner/deepfake_detectors/configs/ucf_face.yaml @@ -1,4 +1,4 @@ # UCFDetector Face Expert Configuration hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files -train_config: 'bm-faces-config-v1.yaml' # pre-trained configuration file in HuggingFace +config_name: 'bm-faces-config-v1.yaml' # pre-trained configuration file in HuggingFace weights: 'bm-faces-v1.pth' # UCF model checkpoint in HuggingFace diff --git a/base_miner/deepfake_detectors/deepfake_detector.py b/base_miner/deepfake_detectors/deepfake_detector.py index 4c4af4ca..2aaff586 100644 --- a/base_miner/deepfake_detectors/deepfake_detector.py +++ b/base_miner/deepfake_detectors/deepfake_detector.py @@ -1,85 +1,154 @@ -import typing from abc import ABC, abstractmethod from pathlib import Path -import yaml +from typing import Optional, Dict, Any + import torch +import yaml +import bittensor as bt from PIL import Image +from huggingface_hub import hf_hub_download + +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR class DeepfakeDetector(ABC): - """ - Abstract base class for detecting deepfake images via binary classification. + """Abstract base class for detecting deepfake images via binary classification. This class is intended to be subclassed by detector implementations - using different underying model architectures, routing via gates, or + using different underlying model architectures, routing via gates, or configurations. - + Attributes: model_name (str): Name of the detector instance. - config (str): Name of the YAML file in deepfake_detectors/config/ to load - instance attributes from. + config_name (Optional[str]): Name of the YAML file in deepfake_detectors/config/ + to load instance attributes from. device (str): The type of device ('cpu' or 'cuda'). + hf_repo (str): Hugging Face repository name for model weights. """ - - def __init__(self, model_name: str, config = None, device: str = 'cpu'): + + def __init__( + self, + model_name: str, + config_name: Optional[str] = None, + device: str = 'cpu' + ) -> None: + """Initialize the DeepfakeDetector. + + Args: + model_name: Name of the detector instance. + config: Optional name of configuration file to load. + device: Device to run the model on ('cpu' or 'cuda'). + """ self.model_name = model_name - self.device = torch.device(device if device == 'cuda' and torch.cuda.is_available() else 'cpu') - if config: - self.load_and_apply_config(config) + self.device = torch.device( + device if device == 'cuda' and torch.cuda.is_available() else 'cpu' + ) + + if config_name: + print(f"Configuring with {config_name}") + self.set_class_attrs(config_name) + self.load_model_config() + self.load_model() @abstractmethod - def load_model(self): - """ - Load the model. Specific loading implementations will be defined in subclasses. + def load_model(self) -> None: + """Load the model weights and architecture. + + This method should be implemented by subclasses to define their specific + model loading logic. """ pass - def preprocess(self, image: Image) -> torch.Tensor: - """ - Preprocess the image for model inference. - + def preprocess(self, image: Image.Image) -> torch.Tensor: + """Preprocess the image for model inference. + Args: - image (PIL.Image): The image to preprocess. - extra_data (dict, optional): Any additional data required for preprocessing. + image: The input image to preprocess. Returns: - torch.Tensor: The preprocessed image tensor. + The preprocessed image as a tensor ready for model input. """ # General preprocessing, to be overridden if necessary in subclasses pass @abstractmethod - def __call__(self, image: Image) -> float: - """ - Perform inference with the model. + def __call__(self, image: Image.Image) -> float: + """Perform inference with the model. Args: - image (PIL.Image): The preprocessed image. + image: The preprocessed input image. Returns: - float: The model's prediction (or other relevant result). + The model's prediction score (typically between 0 and 1). """ + pass + + def set_class_attrs(self, detector_config: str) -> None: + """Load detector configuration from YAML file and set attributes. - def load_and_apply_config(self, detector_config): - """ - Load detector configuration from YAML file and set corresponding attributes dynamically. - Args: - config_path (str): Path to the YAML configuration file. + detector_config: Path to the YAML configuration file or filename + in the configs directory. + + Raises: + Exception: If there is an error loading or parsing the config file. """ if Path(detector_config).exists(): detector_config_file = Path(detector_config) else: - detector_config_file = Path(__file__).resolve().parent / Path('configs/' + detector_config) + detector_config_file = ( + Path(__file__).resolve().parent / Path('configs/' + detector_config) + ) + try: - with open(detector_config_file, 'r') as file: + with open(detector_config_file, 'r', encoding='utf-8') as file: config_dict = yaml.safe_load(file) # Set class attributes dynamically from the config dictionary for key, value in config_dict.items(): - setattr(self, key, value) # Dynamically create self.key = value - + print('k:v', key, value) + setattr(self, key, value) + except Exception as e: print(f"Error loading detector configurations from {detector_config_file}: {e}") - raise \ No newline at end of file + raise + + def ensure_weights_are_available( + self, + weights_dir: str, + weights_filename: str + ) -> None: + """Ensure model weights are downloaded and available locally. + + Downloads weights from Hugging Face Hub if not found locally. + + Args: + weights_dir: Directory to store/find the weights. + weights_filename: Name of the weights file. + """ + destination_path = Path(weights_dir) / Path(weights_filename) + if not Path(weights_dir).exists(): + Path(weights_dir).mkdir(parents=True, exist_ok=True) + + if not destination_path.exists(): + print(f"Downloading {weights_filename} from {self.hf_repo} " + f"to {weights_dir}") + hf_hub_download(self.hf_repo, weights_filename, local_dir=weights_dir) + + def load_model_config(self): + try: + destination_path = Path(CONFIGS_DIR) / Path(self.config_name) + if not destination_path.exists(): + local_config_path = hf_hub_download(self.hf_repo, self.config_name, local_dir=CONFIGS_DIR) + print(f"Downloaded {self.hf_repo}/{self.config_name} to {local_config_path}") + with Path(local_config_path).open('r') as f: + self.config = yaml.safe_load(f) + else: + print(f"Loading local config from {destination_path}") + with destination_path.open('r') as f: + self.config = yaml.safe_load(f) + print(f"Loaded: {self.config}") + except Exception as e: + # some models such as NPR don't have an additional config file + bt.logging.warning("No additional train config loaded.") diff --git a/base_miner/deepfake_detectors/npr_detector.py b/base_miner/deepfake_detectors/npr_detector.py index 30b1782d..a5fe3500 100644 --- a/base_miner/deepfake_detectors/npr_detector.py +++ b/base_miner/deepfake_detectors/npr_detector.py @@ -4,45 +4,37 @@ from pathlib import Path from huggingface_hub import hf_hub_download from base_miner.NPR.networks.resnet import resnet50 -from bitmind.image_transforms import base_transforms +from bitmind.utils.image_transforms import get_base_transforms from base_miner.deepfake_detectors import DeepfakeDetector -from base_miner import DETECTOR_REGISTRY +from base_miner.registry import DETECTOR_REGISTRY from base_miner.NPR.config.constants import WEIGHTS_DIR @DETECTOR_REGISTRY.register_module(module_name='NPR') -class NPRDetector(DeepfakeDetector): +class NPRImageDetector(DeepfakeDetector): """ DeepfakeDetector subclass that initializes a pretrained NPR model for binary classification of fake and real images. Attributes: model_name (str): Name of the detector instance. - config (str): Name of the YAML file in deepfake_detectors/config/ to load + config_name (str): Name of the YAML file in deepfake_detectors/config/ to load attributes from. device (str): The type of device ('cpu' or 'cuda'). """ - def __init__(self, model_name: str = 'NPR', config: str = 'npr.yaml', device: str = 'cpu'): - super().__init__(model_name, config, device) + def __init__(self, model_name: str = 'NPR', config_name: str = 'npr.yaml', device: str = 'cpu'): + super().__init__(model_name, config_name, device) + self.transforms = get_base_transforms() def load_model(self): """ Load the ResNet50 model with the specified weights for deepfake detection. """ - self.ensure_weights_are_available(self.weights) + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) self.model = resnet50(num_classes=1) self.model.load_state_dict(torch.load(Path(WEIGHTS_DIR) / self.weights, map_location=self.device)) self.model.eval() - - def ensure_weights_are_available(self, weight_filename): - destination_path = Path(WEIGHTS_DIR) / Path(weight_filename) - if not destination_path.parent.exists(): - destination_path.parent.mkdir(parents=True, exist_ok=True) - if not destination_path.exists(): - model_path = hf_hub_download(self.hf_repo, weight_filename) - model = torch.load(model_path, map_location=torch.device(self.device)) - torch.save(model, destination_path) def preprocess(self, image: Image) -> torch.Tensor: """ @@ -54,7 +46,7 @@ def preprocess(self, image: Image) -> torch.Tensor: Returns: torch.Tensor: The preprocessed image tensor. """ - image_tensor = base_transforms(image).unsqueeze(0).float() + image_tensor = self.transforms(image).unsqueeze(0).float() return image_tensor def __call__(self, image: Image) -> float: diff --git a/base_miner/deepfake_detectors/tall_detector.py b/base_miner/deepfake_detectors/tall_detector.py new file mode 100644 index 00000000..7b4bd40a --- /dev/null +++ b/base_miner/deepfake_detectors/tall_detector.py @@ -0,0 +1,51 @@ +import torch +from pathlib import Path + +import bittensor as bt +from base_miner.registry import DETECTOR_REGISTRY +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR +from base_miner.DFB.detectors import DETECTOR, TALLDetector +from base_miner.deepfake_detectors import DeepfakeDetector +from bitmind.utils.video_utils import pad_frames + + +@DETECTOR_REGISTRY.register_module(module_name="TALL") +class TALLVideoDetector(DeepfakeDetector): + def __init__( + self, + model_name: str = "TALL", + config_name: str = "tall.yaml", + device: str = "cpu", + ): + super().__init__(model_name, config_name, device) + + total_params = sum(p.numel() for p in self.tall.model.parameters()) + trainable_params = sum( + p.numel() for p in self.tall.model.parameters() if p.requires_grad + ) + bt.logging.info('device:', self.device) + bt.logging.info(total_params, "parameters") + bt.logging.info(trainable_params, "trainable parameters") + + def load_model(self): + # download weights from hf if not available locally + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) + bt.logging.info(f"Loaded config: {self.config}") + self.tall = TALLDetector(self.config, self.device) + + # load weights + checkpoint_path = Path(WEIGHTS_DIR) / self.weights + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.tall.load_state_dict(checkpoint, strict=True) + self.tall.model.eval() + + def preprocess(self, frames_tensor): + """ Prepare input data dict for TALLDetector """ + frames_tensor = pad_frames(frames_tensor, 4) + return {'image': frames_tensor} + + def __call__(self, frames_tensor): + input_data = self.preprocess(frames_tensor) + with torch.no_grad(): + output_data = self.tall.forward(input_data, inference=True) + return output_data['prob'][0] diff --git a/base_miner/deepfake_detectors/ucf_detector.py b/base_miner/deepfake_detectors/ucf_detector.py index ce30ed4d..92b40b54 100644 --- a/base_miner/deepfake_detectors/ucf_detector.py +++ b/base_miner/deepfake_detectors/ucf_detector.py @@ -4,91 +4,65 @@ import random import warnings warnings.filterwarnings("ignore", category=FutureWarning) +from huggingface_hub import hf_hub_download from pathlib import Path - +from PIL import Image +import torchvision.transforms as transforms +import torch.backends.cudnn as cudnn +import bittensor as bt import numpy as np import torch -import torch.backends.cudnn as cudnn -import torchvision.transforms as transforms import yaml -from PIL import Image -from huggingface_hub import hf_hub_download import gc -from base_miner.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR -from base_miner.gating_mechanisms import FaceGate - -from base_miner.UCF.detectors import DETECTOR +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR from base_miner.deepfake_detectors import DeepfakeDetector -from base_miner import DETECTOR_REGISTRY, GATE_REGISTRY +from base_miner.DFB.detectors import UCFDetector +from base_miner.registry import DETECTOR_REGISTRY -import bittensor as bt @DETECTOR_REGISTRY.register_module(module_name='UCF') -class UCFDetector(DeepfakeDetector): +class UCFImageDetector(DeepfakeDetector): """ DeepfakeDetector subclass that initializes a pretrained UCF model for binary classification of fake and real images. Attributes: model_name (str): Name of the detector instance. - config (str): Name of the YAML file in deepfake_detectors/config/ to load + config_name (str): Name of the YAML file in deepfake_detectors/config/ to load attributes from. device (str): The type of device ('cpu' or 'cuda'). """ - def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml', device: str = 'cpu'): - super().__init__(model_name, config, device) - - def ensure_weights_are_available(self, weight_filename): - destination_path = Path(WEIGHTS_DIR) / Path(weight_filename) - if not destination_path.parent.exists(): - destination_path.parent.mkdir(parents=True, exist_ok=True) - if not destination_path.exists(): - model_path = hf_hub_download(self.hf_repo, weight_filename) - model = torch.load(model_path, map_location=self.device) - torch.save(model, destination_path) - - def load_train_config(self): - destination_path = Path(CONFIGS_DIR) / Path(self.train_config) - - if not destination_path.exists(): - local_config_path = hf_hub_download(self.hf_repo, self.train_config) - print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}") - config_dict = {} - with open(local_config_path, 'r') as f: - config_dict = yaml.safe_load(f) - with open(destination_path, 'w') as f: - yaml.dump(config_dict, f, default_flow_style=False) - with destination_path.open('r') as f: - return yaml.safe_load(f) - else: - print(f"Loaded local config from {destination_path}") - with destination_path.open('r') as f: - return yaml.safe_load(f) + def __init__(self, model_name: str = 'UCF', config_name: str = 'ucf.yaml', device: str = 'cpu'): + super().__init__(model_name, config_name, device) def init_cudnn(self): - if self.train_config.get('cudnn'): + if self.config.get('cudnn'): cudnn.benchmark = True def init_seed(self): - seed_value = self.train_config.get('manualSeed') + seed_value = self.config.get('manualSeed') if seed_value: random.seed(seed_value) torch.manual_seed(seed_value) torch.cuda.manual_seed_all(seed_value) def load_model(self): - self.train_config = self.load_train_config() self.init_cudnn() self.init_seed() - self.ensure_weights_are_available(self.weights) - self.ensure_weights_are_available(self.train_config['pretrained'].split('/')[-1]) - model_class = DETECTOR[self.train_config['model_name']] - bt.logging.info(f"Loaded config from training run: {self.train_config}") - self.model = model_class(self.train_config).to(self.device) + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) + pretrained = self.config['pretrained'] + if isinstance(pretrained, dict) and 'filename' in pretrained: + pretrained = pretrained['filename'] + else: + pretrained = pretrained.split('/')[-1] + + self.ensure_weights_are_available(WEIGHTS_DIR, pretrained) + self.model = UCFDetector(self.config).to(self.device) self.model.eval() weights_path = Path(WEIGHTS_DIR) / self.weights + bt.logging.info(f"Loading checkpoint {weights_path}") checkpoint = torch.load(weights_path, map_location=self.device) try: self.model.load_state_dict(checkpoint, strict=True) @@ -121,7 +95,7 @@ def preprocess(self, image, res=256): transform = transforms.Compose([ transforms.Resize((res, res), interpolation=Image.LANCZOS), # Resize image to specified resolution. transforms.ToTensor(), # Convert the image to a PyTorch tensor. - transforms.Normalize(mean=self.train_config['mean'], std=self.train_config['std']) # Normalize the image tensor. + transforms.Normalize(mean=self.config['mean'], std=self.config['std']) # Normalize the image tensor. ]) # Apply transformations and add a batch dimension for model inference. diff --git a/base_miner/gating_mechanisms/face_gate.py b/base_miner/gating_mechanisms/face_gate.py index b8854e8e..d433056d 100644 --- a/base_miner/gating_mechanisms/face_gate.py +++ b/base_miner/gating_mechanisms/face_gate.py @@ -4,8 +4,8 @@ import dlib from base_miner.gating_mechanisms import Gate -from base_miner.UCF.config.constants import DLIB_FACE_PREDICTOR_PATH -from base_miner import GATE_REGISTRY +from base_miner.DFB.config.constants import DLIB_FACE_PREDICTOR_PATH +from base_miner.registry import GATE_REGISTRY from base_miner.gating_mechanisms.utils import get_face_landmarks, align_and_crop_face diff --git a/base_miner/gating_mechanisms/gating_mechanism.py b/base_miner/gating_mechanisms/gating_mechanism.py index ca4301cd..58ff7050 100644 --- a/base_miner/gating_mechanisms/gating_mechanism.py +++ b/base_miner/gating_mechanisms/gating_mechanism.py @@ -1,5 +1,5 @@ from PIL import Image -from base_miner import GATE_REGISTRY +from base_miner.registry import GATE_REGISTRY class GatingMechanism: diff --git a/bitmind/__init__.py b/bitmind/__init__.py index 68285ed0..8b0f87ca 100644 --- a/bitmind/__init__.py +++ b/bitmind/__init__.py @@ -18,7 +18,7 @@ # DEALINGS IN THE SOFTWARE. -__version__ = "1.2.9" +__version__ = "2.0.0" version_split = __version__.split(".") __spec_version__ = ( (1000 * int(version_split[0])) diff --git a/bitmind/base/miner.py b/bitmind/base/miner.py index e812eac4..e63d70a1 100644 --- a/bitmind/base/miner.py +++ b/bitmind/base/miner.py @@ -20,6 +20,7 @@ import threading import argparse import traceback +import typing import bittensor as bt @@ -53,17 +54,9 @@ def __init__(self, config=None): bt.logging.warning( "You are allowing non-registered entities to send requests to your miner. This is a security risk." ) - # The axon handles request processing, allowing validators to send this miner requests. - self.axon = bt.axon(wallet=self.wallet, config=self.config() if callable(self.config) else self.config) - # Attach determiners which functions are called when servicing a request. - bt.logging.info(f"Attaching forward function to miner axon.") - self.axon.attach( - forward_fn=self.forward, - blacklist_fn=self.blacklist, - priority_fn=self.priority, - ) - bt.logging.info(f"Axon created: {self.axon}") + # attach miner-specific functions in subclass __init__ + self.axon = bt.axon(wallet=self.wallet, config=self.config() if callable(self.config) else self.config) # Instantiate runners self.should_exit: bool = False @@ -192,3 +185,101 @@ def resync_metagraph(self): # Sync the metagraph. self.metagraph.sync(subtensor=self.subtensor) + + async def blacklist( + self, synapse: bt.Synapse + ) -> typing.Tuple[bool, str]: + """ + Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should + define the logic for blacklisting requests based on your needs and desired security parameters. + + Blacklist runs before the synapse data has been deserialized (i.e. before synapse.data is available). + The synapse is instead contructed via the headers of the request. It is important to blacklist + requests before they are deserialized to avoid wasting resources on requests that will be ignored. + + Args: + synapse (bt.Synapse): A synapse object constructed from the headers of the incoming request. + + Returns: + Tuple[bool, str]: A tuple containing a boolean indicating whether the synapse's hotkey is blacklisted, + and a string providing the reason for the decision. + + This function is a security measure to prevent resource wastage on undesired requests. It should be enhanced + to include checks against the metagraph for entity registration, validator status, and sufficient stake + before deserialization of synapse data to minimize processing overhead. + + Example blacklist logic: + - Reject if the hotkey is not a registered entity within the metagraph. + - Consider blacklisting entities that are not validators or have insufficient stake. + + In practice it would be wise to blacklist requests from entities that are not validators, or do not have + enough stake. This can be checked via metagraph.S and metagraph.validator_permit. You can always attain + the uid of the sender via a metagraph.hotkeys.index( synapse.dendrite.hotkey ) call. + + Otherwise, allow the request to be processed further. + """ + if synapse.dendrite is None or synapse.dendrite.hotkey is None: + bt.logging.warning("Received a request without a dendrite or hotkey.") + return True, "Missing dendrite or hotkey" + + # TODO(developer): Define how miners should blacklist requests. + uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey) + if ( + not self.config.blacklist.allow_non_registered + and synapse.dendrite.hotkey not in self.metagraph.hotkeys + ): + # Ignore requests from un-registered entities. + bt.logging.trace( + f"Blacklisting un-registered hotkey {synapse.dendrite.hotkey}" + ) + return True, "Unrecognized hotkey" + + if self.config.blacklist.force_validator_permit: + # If the config is set to force validator permit, then we should only allow requests from validators. + if not self.metagraph.validator_permit[uid]: + bt.logging.warning( + f"Blacklisting a request from non-validator hotkey {synapse.dendrite.hotkey}" + ) + return True, "Non-validator hotkey" + + bt.logging.trace( + f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}" + ) + return False, "Hotkey recognized!" + + async def priority(self, synapse: bt.Synapse) -> float: + """ + The priority function determines the order in which requests are handled. More valuable or higher-priority + requests are processed before others. You should design your own priority mechanism with care. + + This implementation assigns priority to incoming requests based on the calling entity's stake in the metagraph. + + Args: + synapse (bt.Synapse): The synapse object that contains metadata about the incoming request. + + Returns: + float: A priority score derived from the stake of the calling entity. + + Miners may recieve messages from multiple entities at once. This function determines which request should be + processed first. Higher values indicate that the request should be processed first. Lower values indicate + that the request should be processed later. + + Example priority logic: + - A higher stake results in a higher priority value. + """ + if synapse.dendrite is None or synapse.dendrite.hotkey is None: + bt.logging.warning("Received a request without a dendrite or hotkey.") + return 0.0 + + # TODO(developer): Define how miners should prioritize requests. + caller_uid = self.metagraph.hotkeys.index( + synapse.dendrite.hotkey + ) # Get the caller index. + + prirority = float( + self.metagraph.S[caller_uid] + ) # Return the stake as the priority. + bt.logging.trace( + f"Prioritizing {synapse.dendrite.hotkey} with value: ", prirority + ) + return prirority diff --git a/bitmind/base/neuron.py b/bitmind/base/neuron.py index 15c0374f..247afa16 100644 --- a/bitmind/base/neuron.py +++ b/bitmind/base/neuron.py @@ -108,10 +108,6 @@ def __init__(self, config=None): ) self.step = 0 - @abstractmethod - async def forward(self, synapse: bt.Synapse) -> bt.Synapse: - ... - @abstractmethod def run(self): ... diff --git a/bitmind/base/validator.py b/bitmind/base/validator.py index 47b055b7..498511cb 100644 --- a/bitmind/base/validator.py +++ b/bitmind/base/validator.py @@ -54,9 +54,15 @@ def add_args(cls, parser: argparse.ArgumentParser): def __init__(self, config=None): super().__init__(config=config) - self.history_cache_path = os.path.join( - self.config.neuron.full_path, "miner_performance_tracker.pkl") - + self.performance_trackers = { + 'image': None, + 'video': None + } + + self.image_history_cache_path = os.path.join( + self.config.neuron.full_path, "image_miner_performance_tracker.pkl") + self.video_history_cache_path = os.path.join( + self.config.neuron.full_path, "video_miner_performance_tracker.pkl") self.load_miner_history() # Save a copy of the hotkeys to local memory. @@ -169,9 +175,8 @@ def run(self): # Sync metagraph and potentially set weights. self.sync() - - self.step += 1 time.sleep(60) + self.step += 1 # If someone intentionally stops the validator, it'll safely terminate operations. except KeyboardInterrupt: @@ -376,22 +381,39 @@ def update_scores(self, rewards: np.ndarray, uids: List[int]): bt.logging.debug(f"Updated moving avg scores: {self.scores}") def save_miner_history(self): - bt.logging.info(f"Saving miner performance history to {self.history_cache_path}") - joblib.dump(self.performance_tracker, self.history_cache_path) + bt.logging.info(f"Saving miner performance history to {self.image_history_cache_path}") + joblib.dump(self.performance_trackers['image'], self.image_history_cache_path) + bt.logging.info(f"Saving miner performance history to {self.video_history_cache_path}") + joblib.dump(self.performance_trackers['video'], self.video_history_cache_path) def load_miner_history(self): - if os.path.exists(self.history_cache_path): - bt.logging.info(f"Loading miner performance history from {self.history_cache_path}") - self.performance_tracker = joblib.load(self.history_cache_path) - pred_history = self.performance_tracker.prediction_history - num_miners_history = len([ - uid for uid in pred_history - if len([p for p in pred_history[uid] if p != -1]) > 0 - ]) - bt.logging.info(f"Loaded history for {num_miners_history} miners") - else: - bt.logging.info(f"No miner performance history found at {self.history_cache_path} - starting fresh!") - self.performance_tracker = MinerPerformanceTracker() + def load(path): + if os.path.exists(path): + bt.logging.info(f"Loading miner performance history from {path}") + try: + tracker = joblib.load(path) + num_miners_history = len([ + uid for uid in tracker.prediction_history + if len([p for p in tracker.prediction_history[uid] if p != -1]) > 0 + ]) + bt.logging.info(f"Loaded history for {num_miners_history} miners") + except Exception as e: + bt.logging.error(f'Error loading miner performance tracker: {e}') + tracker = MinerPerformanceTracker() + else: + bt.logging.info(f"No miner performance history found at {path} - starting fresh!") + tracker = MinerPerformanceTracker() + return tracker + + try: + self.performance_trackers['image'] = load(self.image_history_cache_path) + except Exception as e: + # just for 2.0.0 upgrade for miner performance to carry over + v1_history_cache_path = os.path.join( + self.config.neuron.full_path, "miner_performance_tracker.pkl") + self.performance_trackers['image'] = load(v1_history_cache_path) + + self.performance_trackers['video'] = load(self.video_history_cache_path) def save_state(self): """Saves the state of the validator to a file.""" diff --git a/bitmind/constants.py b/bitmind/constants.py deleted file mode 100644 index d2dccdd4..00000000 --- a/bitmind/constants.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import torch - - -WANDB_PROJECT = 'bitmind-subnet' -WANDB_ENTITY = 'bitmindai' - -DATASET_META = { - "real": [ - {"path": "bitmind/bm-real"}, - {"path": "bitmind/open-images-v7"}, - {"path": "bitmind/celeb-a-hq"}, - {"path": "bitmind/ffhq-256"}, - {"path": "bitmind/MS-COCO-unique-256"} - ], - "fake": [ - {"path": "bitmind/bm-realvisxl"}, - {"path": "bitmind/bm-mobius"}, - {"path": "bitmind/bm-sdxl"} - ] -} - -FACE_TRAINING_DATASET_META = { - "real": [ - {"path": "bitmind/ffhq-256_training_faces", "name": "base_transforms"}, - {"path": "bitmind/celeb-a-hq_training_faces", "name": "base_transforms"} - - ], - "fake": [ - {"path": "bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces", "name": "base_transforms"}, - {"path": "bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces", "name": "base_transforms"} - ] -} - -VALIDATOR_DATASET_META = { - "real": [ - {"path": "bitmind/bm-real"}, - {"path": "bitmind/open-images-v7"}, - {"path": "bitmind/celeb-a-hq"}, - {"path": "bitmind/ffhq-256"}, - {"path": "bitmind/MS-COCO-unique-256"}, - {"path": "bitmind/AFHQ"}, - {"path": "bitmind/lfw"}, - {"path": "bitmind/caltech-256"}, - {"path": "bitmind/caltech-101"}, - {"path": "bitmind/dtd"} - ] -} - -VALIDATOR_MODEL_META = { - "diffusers": [ - { - "path": "stabilityai/stable-diffusion-xl-base-1.0", - "use_safetensors": True, - "torch_dtype": torch.float16, - "variant": "fp16", - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": "SG161222/RealVisXL_V4.0", - "use_safetensors": True, - "torch_dtype": torch.float16, - "variant": "fp16", - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": "Corcelio/mobius", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": 'black-forest-labs/FLUX.1-dev', - "use_safetensors": True, - "torch_dtype": torch.bfloat16, - "generate_args": { - "guidance_scale": 2, - "num_inference_steps": {"min": 50, "max": 125}, - "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"), - "height": [512, 768], - "width": [512, 768] - }, - "enable_cpu_offload": False, - "pipeline": "FluxPipeline" - }, - { - "path": "prompthero/openjourney-v4", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionPipeline" - }, - { - "path": "cagliostrolab/animagine-xl-3.1", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionXLPipeline" - } - ] -} - -HUGGINGFACE_CACHE_DIR = os.path.expanduser('~/.cache/huggingface') - -TARGET_IMAGE_SIZE = (256, 256) - -PROMPT_TYPES = ('annotation', 'none') - -# args for .from_pretrained -DIFFUSER_ARGS = { - m['path']: { - k: v for k, v in m.items() - if k not in ('path', 'pipeline', 'generate_args', 'enable_cpu_offload') - } for m in VALIDATOR_MODEL_META['diffusers'] -} - -GENERATE_ARGS = { - m['path']: m['generate_args'] - for m in VALIDATOR_MODEL_META['diffusers'] - if 'generate_args' in m -} - -DIFFUSER_CPU_OFFLOAD_ENABLED = { - m['path']: m.get('enable_cpu_offload', False) - for m in VALIDATOR_MODEL_META['diffusers'] -} - -DIFFUSER_PIPELINE = { - m['path']: m['pipeline'] for m in VALIDATOR_MODEL_META['diffusers'] if 'pipeline' in m -} - -DIFFUSER_NAMES = list(DIFFUSER_ARGS.keys()) - -IMAGE_ANNOTATION_MODEL = "Salesforce/blip2-opt-6.7b-coco" - -TEXT_MODERATION_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" diff --git a/bitmind/image_dataset.py b/bitmind/image_dataset.py deleted file mode 100644 index b37c894c..00000000 --- a/bitmind/image_dataset.py +++ /dev/null @@ -1,159 +0,0 @@ -from typing import List, Tuple -from datasets import Dataset -from PIL import Image -from io import BytesIO -import bittensor as bt -import numpy as np - -from bitmind.download_data import load_huggingface_dataset, download_image - - -class ImageDataset: - - def __init__( - self, - huggingface_dataset_path: str = None, - huggingface_dataset_split: str = 'train', - huggingface_dataset_name: str = None, - huggingface_dataset: Dataset = None, - download_mode: str = None - ): - """ - Args: - huggingface_dataset_path (str): Path to the Hugging Face dataset. Can either be to a publicly hosted - huggingface dataset (/) or a local directory (imagefolder:) - huggingface_dataset_split (str): Split of the dataset to load (default: 'train'). - Make sure to check what splits are available for the datasets you're working with. - huggingface_dataset_name (str): Name of the Hugging Face dataset (default: None). - Some huggingface datasets provide various subets of different sizes, which can be accessed via thi - parameter. - create_splits (bool): Whether to create dataset splits (default: False). - If the huggingface dataset hasn't been pre-split (i.e., it only contains "Train"), we split it here - randomly. - download_mode (str): Download mode for the dataset (default: None). - can be None or "force_redownload" - """ - assert huggingface_dataset_path is not None or huggingface_dataset is not None, \ - "Either huggingface_dataset_path or huggingface_dataset must be provided." - - if huggingface_dataset: - self.dataset = huggingface_dataset - self.huggingface_dataset_path = self.dataset.info.dataset_name - self.huggingface_dataset_split = list(self.dataset.info.splits.keys())[0] - self.huggingface_dataset_name = self.dataset.info.config_name - - else: - self.huggingface_dataset_path = huggingface_dataset_path - self.huggingface_dataset_name = huggingface_dataset_name - self.dataset = load_huggingface_dataset( - huggingface_dataset_path, - huggingface_dataset_split, - huggingface_dataset_name, - download_mode) - self.sampled_images_idx = [] - - def __getitem__(self, index: int) -> dict: - """ - Get an item (image and ID) from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: Dictionary containing 'image' (PIL image) and 'id' (str). - """ - return self._get_image(index) - - def __len__(self) -> int: - """ - Get the length of the dataset. - - Returns: - int: Length of the dataset. - """ - return len(self.dataset) - - def _get_image(self, index: int) -> dict: - """ - Load an image from self.dataset. Expects self.dataset[i] to be a dictionary containing either 'image' or 'url' - as a key. - - The value associated with the 'image' key should be either a PIL image or a b64 string encoding of - the image. - - The value associated with the 'url' key should be a url that hosts the image (as in - dalle-mini/open-images) - - Args: - index (int): Index of the image in the dataset. - - Returns: - dict: Dictionary containing 'image' (PIL image) and 'id' (str). - """ - sample = self.dataset[int(index)] - if 'url' in sample: - image = download_image(sample['url']) - image_id = sample['url'] - elif 'image_url' in sample: - image = download_image(sample['image_url']) - image_id = sample['image_url'] - elif 'image' in sample: - if isinstance(sample['image'], Image.Image): - image = sample['image'] - elif isinstance(sample['image'], bytes): - image = Image.open(BytesIO(sample['image'])) - else: - raise NotImplementedError - - image_id = '' - if 'name' in sample: - image_id = sample['name'] - elif 'filename' in sample: - image_id = sample['filename'] - - image_id = image_id if image_id != '' else index - - else: - raise NotImplementedError - - # remove alpha channel if download didnt 404 - if image is not None: - image = image.convert('RGB') - - return { - 'image': image, - 'id': image_id, - 'source': self.huggingface_dataset_path - } - - def sample(self, k: int = 1) -> Tuple[List[dict], List[int]]: - """ - Randomly sample k images from self.dataset. Includes retries for failed downloads, in the case that - self.dataset contains urls. - - Args: - k (int): Number of images to sample (default: 1). - - Returns: - Tuple[List[dict], List[int]]: A tuple containing a list of sampled images and their indices. - """ - sampled_images = [] - sampled_idx = [] - while k > 0: - attempts = len(self.dataset) // 2 - for i in range(attempts): - image_idx = np.random.randint(0, len(self.dataset)) - if image_idx not in self.sampled_images_idx: - break - if i >= attempts: - self.sampled_images_idx = [] - try: - image = self._get_image(image_idx) - if image['image'] is not None: - sampled_images.append(image) - sampled_idx.append(image_idx) - self.sampled_images_idx.append(image_idx) - k -= 1 - except Exception as e: - bt.logging.error(e) - continue - - return sampled_images, sampled_idx diff --git a/bitmind/miner/predict.py b/bitmind/miner/predict.py deleted file mode 100644 index 8beba11e..00000000 --- a/bitmind/miner/predict.py +++ /dev/null @@ -1,21 +0,0 @@ -from PIL import Image -import torch - -from bitmind.image_transforms import base_transforms - - -def predict(model: torch.nn.Module, image: Image.Image) -> float: - """ - Perform prediction using a given PyTorch model on an image. You may need to modify this - if you train a custom model. - - Args: - model (torch.nn.Module): The PyTorch model to use for prediction. - image (Image.Image): The input image as a PIL Image. - - Returns: - float: The predicted output value. - """ - image = base_transforms(image).unsqueeze(0).float() - out = model(image).sigmoid().flatten().tolist() - return out[0] \ No newline at end of file diff --git a/bitmind/protocol.py b/bitmind/protocol.py index 996e09dc..94851d94 100644 --- a/bitmind/protocol.py +++ b/bitmind/protocol.py @@ -17,33 +17,20 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -from pydantic import root_validator, validator +from typing import List +from pydantic import BaseModel, Field from torchvision import transforms from io import BytesIO from PIL import Image import bittensor as bt -import pydantic import base64 +import pydantic import torch +import zlib - -def prepare_image_synapse(image: Image): - """ - Prepares an image for use with ImageSynapse object. - - Args: - image (Image): The input image to be prepared. - - Returns: - ImageSynapse: An instance of ImageSynapse containing the encoded image and a default prediction value. - """ - if isinstance(image, torch.Tensor): - image = transforms.ToPILImage()(image.cpu().detach()) - - image_bytes = BytesIO() - image.save(image_bytes, format="JPEG") - b64_encoded_image = base64.b64encode(image_bytes.getvalue()) - return ImageSynapse(image=b64_encoded_image) +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import get_base_transforms +base_transforms = get_base_transforms(TARGET_IMAGE_SIZE) # ---- miner ---- @@ -61,6 +48,36 @@ def prepare_image_synapse(image: Image): # predictions = dendrite.query( ImageSynapse( images = b64_images ) ) # assert len(predictions) == len(b64_images) +def prepare_synapse(input_data, modality): + if isinstance(input_data, torch.Tensor): + input_data = transforms.ToPILImage()(input_data.cpu().detach()) + if isinstance(input_data, list) and isinstance(input_data[0], torch.Tensor): + for i, img in enumerate(input_data): + input_data[i] = transforms.ToPILImage()(img.cpu().detach()) + + if modality == 'image': + return prepare_image_synapse(input_data) + elif modality == 'video': + return prepare_video_synapse(input_data) + else: + raise NotImplementedError(f"Unsupported modality: {modality}") + + +def prepare_image_synapse(image: Image): + """ + Prepares an image for use with ImageSynapse object. + + Args: + image (Image): The input image to be prepared. + + Returns: + ImageSynapse: An instance of ImageSynapse containing the encoded image and a default prediction value. + """ + image_bytes = BytesIO() + image.save(image_bytes, format="JPEG") + b64_encoded_image = base64.b64encode(image_bytes.getvalue()) + return ImageSynapse(image=b64_encoded_image) + class ImageSynapse(bt.Synapse): """ @@ -73,6 +90,8 @@ class ImageSynapse(bt.Synapse): >.5 is considered generated/modified, <= 0.5 is considered real. """ + testnet_label: int = -1 # for easier miner eval on testnet + # Required request input, filled by sending dendrite caller. image: str = pydantic.Field( title="Image", @@ -99,3 +118,114 @@ def deserialize(self) -> float: prediction probabilities """ return self.prediction + + +def prepare_video_synapse(frames: List[Image.Image]): + """ + """ + frame_bytes = [] + for frame in frames: + buffer = BytesIO() + frame.save(buffer, format="JPEG") + frame_bytes.append(buffer.getvalue()) + + combined_bytes = b''.join(frame_bytes) + compressed_data = zlib.compress(combined_bytes) + encoded_data = base64.b85encode(compressed_data).decode('utf-8') + return VideoSynapse(video=encoded_data) + +class VideoSynapse(bt.Synapse): + """ + Naive initial VideoSynapse + Better option would be to modify the Dendrite interface to allow multipart/form-data here: + https://github.com/opentensor/bittensor/blob/master/bittensor/core/dendrite.py#L533 + Another higher lift option would be to look into Epistula or Fiber + """ + + testnet_label: int = -1 # for easier miner eval on testnet + + # Required request input, filled by sending dendrite caller. + video: str = pydantic.Field( + title="Video", + description="A wildly inefficient means of sending video data", + default="", + frozen=False + ) + + # Optional request output, filled by receiving axon. + prediction: float = pydantic.Field( + title="Prediction", + description="Probability that the image is AI generated/modified", + default=-1., + frozen=False + ) + + def deserialize(self) -> float: + """ + Deserialize the output. This method retrieves the response from + the miner, deserializes it and returns it as the output of the dendrite.query() call. + + Returns: + - float: The deserialized miner prediction + prediction probabilities + """ + return self.prediction + + +def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]: + """ + V1 of a function for decoding a VideoSynapse object back into a list of torch tensors. + + Args: + synapse: VideoSynapse object containing the encoded video data + + Returns: + List of torch tensors, each representing a frame from the video + """ + compressed_data = base64.b85decode(synapse.video.encode('utf-8')) + combined_bytes = zlib.decompress(compressed_data) + + # Split the combined bytes into individual JPEG files + # Look for JPEG markers: FF D8 (start) and FF D9 (end) + frames = [] + current_pos = 0 + data_length = len(combined_bytes) + + while current_pos < data_length: + # Find start of JPEG (FF D8) + while current_pos < data_length - 1: + if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD8: + break + current_pos += 1 + + if current_pos >= data_length - 1: + break + + start_pos = current_pos + + # Find end of JPEG (FF D9) + while current_pos < data_length - 1: + if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD9: + current_pos += 2 + break + current_pos += 1 + + if current_pos > start_pos: + # Extract the JPEG data + jpeg_data = combined_bytes[start_pos:current_pos] + try: + # Convert to PIL Image + img = Image.open(BytesIO(jpeg_data)) + # Convert to numpy array + frames.append(img) + except Exception as e: + print(f"Error processing frame: {e}") + continue + + bt.logging.info('transforming video inputs') + frames = base_transforms(frames) + + frames = torch.stack(frames, dim=0) + frames = frames.unsqueeze(0) + print(f'decoded video into tensor with shape {frames.shape}') + return frames diff --git a/bitmind/synthetic_image_generation/README.md b/bitmind/synthetic_data_generation/README.md similarity index 100% rename from bitmind/synthetic_image_generation/README.md rename to bitmind/synthetic_data_generation/README.md diff --git a/bitmind/synthetic_data_generation/__init__.py b/bitmind/synthetic_data_generation/__init__.py new file mode 100644 index 00000000..5c7fbce0 --- /dev/null +++ b/bitmind/synthetic_data_generation/__init__.py @@ -0,0 +1 @@ +from .synthetic_data_generator import SyntheticDataGenerator diff --git a/bitmind/synthetic_data_generation/image_annotation_generator.py b/bitmind/synthetic_data_generation/image_annotation_generator.py new file mode 100644 index 00000000..d2cbd11c --- /dev/null +++ b/bitmind/synthetic_data_generation/image_annotation_generator.py @@ -0,0 +1,244 @@ +import gc +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + Blip2ForConditionalGeneration, + Blip2Processor, + pipeline, + logging as transformers_logging, +) +from transformers.utils.logging import disable_progress_bar + +import bittensor as bt +from bitmind.validator.config import HUGGINGFACE_CACHE_DIR + +disable_progress_bar() + + +class ImageAnnotationGenerator: + """ + A class for generating and moderating image annotations using transformer models. + + This class provides functionality to generate descriptive captions for images + using BLIP2 models and optionally moderate the generated text using a separate + language model. + """ + + def __init__( + self, + model_name: str, + text_moderation_model_name: str, + device: str = 'cuda', + apply_moderation: bool = True + ) -> None: + """ + Initialize the ImageAnnotationGenerator with specific models and device settings. + + Args: + model_name: The name of the BLIP model for generating image captions. + text_moderation_model_name: The name of the model used for moderating + text descriptions. + device: The device to use. + apply_moderation: Flag to determine whether text moderation should be + applied to captions. + """ + self.model_name = model_name + self.processor = Blip2Processor.from_pretrained( + self.model_name, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + + self.apply_moderation = apply_moderation + self.text_moderation_model_name = text_moderation_model_name + self.text_moderation_pipeline = None + self.model = None + self.device = device + + def is_model_loaded(self) -> bool: + return self.model is not None + + def load_models(self) -> None: + """ + Load the necessary models for image annotation and text moderation onto + the specified device. + """ + if self.is_model_loaded(): + bt.logging.warning( + f"Image annotation model {self.model_name} is already loaded" + ) + return + + bt.logging.info(f"Loading image annotation model {self.model_name}") + self.model = Blip2ForConditionalGeneration.from_pretrained( + self.model_name, + torch_dtype=torch.float16, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + self.model.to(self.device) + bt.logging.info(f"Loaded image annotation model {self.model_name}") + bt.logging.info( + f"Loading annotation moderation model {self.text_moderation_model_name}..." + ) + if self.apply_moderation: + model = AutoModelForCausalLM.from_pretrained( + self.text_moderation_model_name, + torch_dtype=torch.bfloat16, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + + tokenizer = AutoTokenizer.from_pretrained( + self.text_moderation_model_name, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + model = model.to(self.device) + self.text_moderation_pipeline = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer + ) + bt.logging.info( + f"Loaded annotation moderation model {self.text_moderation_model_name}." + ) + + def clear_gpu(self) -> None: + """ + Clear GPU memory by moving models back to CPU and deleting them, + followed by collecting garbage. + """ + bt.logging.info("Clearing GPU memory after generating image annotation") + self.model.to('cpu') + del self.model + self.model = None + if self.text_moderation_pipeline: + self.text_moderation_pipeline.model.to('cpu') + del self.text_moderation_pipeline + self.text_moderation_pipeline = None + gc.collect() + torch.cuda.empty_cache() + + def moderate(self, description: str, max_new_tokens: int = 80) -> str: + """ + Use the text moderation pipeline to make the description more concise + and neutral. + + Args: + description: The text description to be moderated. + max_new_tokens: Maximum number of new tokens to generate in the + moderated text. + + Returns: + The moderated description text, or the original description if + moderation fails. + """ + messages = [ + { + "role": "system", + "content": ( + "[INST]You always concisely rephrase given descriptions, " + "eliminate redundancy, and remove all specific references to " + "individuals by name. You do not respond with anything other " + "than the revised description.[/INST]" + ) + }, + { + "role": "user", + "content": description + } + ] + try: + moderated_text = self.text_moderation_pipeline( + messages, + max_new_tokens=max_new_tokens, + pad_token_id=self.text_moderation_pipeline.tokenizer.eos_token_id, + return_full_text=False + ) + + if isinstance(moderated_text, list): + return moderated_text[0]['generated_text'] + + bt.logging.error("Moderated text did not return a list.") + return description + + except Exception as e: + bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True) + return description + + def generate( + self, + image: Image.Image, + max_new_tokens: int = 20, + verbose: bool = False + ) -> str: + """ + Generate a string description for a given image using prompt-based + captioning and building conversational context. + + Args: + image: The image for which the description is to be generated. + max_new_tokens: The maximum number of tokens to generate for each + prompt. + verbose: If True, additional logging information is printed. + + Returns: + A generated description of the image. + """ + if not verbose: + transformers_logging.set_verbosity_error() + + description = "" + prompts = [ + "An image of", + "The setting is", + "The background is", + "The image type/style is" + ] + + for i, prompt in enumerate(prompts): + description += prompt + ' ' + inputs = self.processor( + image, + text=description, + return_tensors="pt" + ).to(self.device, torch.float16) + + generated_ids = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens + ) + answer = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True + )[0].strip() + + if verbose: + bt.logging.info(f"{i}. Prompt: {prompt}") + bt.logging.info(f"{i}. Answer: {answer}") + + if answer: + answer = answer.rstrip(" ,;!?") + if not answer.endswith('.'): + answer += '.' + description += answer + ' ' + else: + description = description[:-len(prompt) - 1] + + if not verbose: + transformers_logging.set_verbosity_info() + + if description.startswith(prompts[0]): + description = description[len(prompts[0]):] + + description = description.strip() + if not description.endswith('.'): + description += '.' + + if self.apply_moderation: + moderated_description = self.moderate(description) + return moderated_description + + return description diff --git a/bitmind/synthetic_image_generation/utils/image_utils.py b/bitmind/synthetic_data_generation/image_utils.py similarity index 97% rename from bitmind/synthetic_image_generation/utils/image_utils.py rename to bitmind/synthetic_data_generation/image_utils.py index a01627b9..5c419537 100644 --- a/bitmind/synthetic_image_generation/utils/image_utils.py +++ b/bitmind/synthetic_data_generation/image_utils.py @@ -1,7 +1,8 @@ import PIL import os import json -from bitmind.constants import TARGET_IMAGE_SIZE +from bitmind.validator.config import TARGET_IMAGE_SIZE + def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL.Image.Image: """Resize the image to fit within specified dimensions while maintaining aspect ratio.""" @@ -20,6 +21,7 @@ def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS) return resized_image + def resize_images_in_directory(directory, target_width=TARGET_IMAGE_SIZE[0], target_height=TARGET_IMAGE_SIZE[1]): """ Resize all images in the specified directory to the target width and height. diff --git a/bitmind/synthetic_data_generation/prompt_utils.py b/bitmind/synthetic_data_generation/prompt_utils.py new file mode 100644 index 00000000..7c5ce81e --- /dev/null +++ b/bitmind/synthetic_data_generation/prompt_utils.py @@ -0,0 +1,39 @@ + + +def get_tokenizer_with_min_len(model): + """ + Returns the tokenizer with the smallest maximum token length from the 't2vis_model` object. + + If a second tokenizer exists, it compares both and returns the one with the smaller + maximum token length. Otherwise, it returns the available tokenizer. + + Returns: + tuple: A tuple containing the tokenizer and its maximum token length. + """ + # Check if a second tokenizer is available in the t2vis_model + if hasattr(model, 'tokenizer_2'): + if model.tokenizer.model_max_length > model.tokenizer_2.model_max_length: + return model.tokenizer_2, model.tokenizer_2.model_max_length + return model.tokenizer, model.tokenizer.model_max_length + + +def truncate_prompt_if_too_long(prompt: str, model): + """ + Truncates the input string if it exceeds the maximum token length when tokenized. + + Args: + prompt (str): The text prompt that may need to be truncated. + + Returns: + str: The original prompt if within the token limit; otherwise, a truncated version of the prompt. + """ + tokenizer, max_token_len = get_tokenizer_with_min_len(model) + tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings + if len(tokens['input_ids']) < max_token_len: + return prompt + + # Truncate tokens if they exceed the maximum token length, decode the tokens back to a string + truncated_prompt = tokenizer.decode(token_ids=tokens['input_ids'][:max_token_len-1], + skip_special_tokens=True) + tokens = tokenizer(truncated_prompt) + return truncated_prompt \ No newline at end of file diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py new file mode 100644 index 00000000..d140541e --- /dev/null +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -0,0 +1,387 @@ +import gc +import json +import os +import random +import time +import warnings +from pathlib import Path +from typing import Dict, Optional, Any, Union + +import bittensor as bt +import numpy as np +import torch +from diffusers.utils import export_to_video +from PIL import Image + +from bitmind.validator.config import ( + HUGGINGFACE_CACHE_DIR, + TEXT_MODERATION_MODEL, + IMAGE_ANNOTATION_MODEL, + T2VIS_MODELS, + T2VIS_MODEL_NAMES, + T2V_MODEL_NAMES, + T2I_MODEL_NAMES, + TARGET_IMAGE_SIZE, + select_random_t2vis_model, + get_modality +) +from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long +from bitmind.synthetic_data_generation.image_annotation_generator import ImageAnnotationGenerator +from bitmind.validator.cache import ImageCache + + +future_warning_modules_to_ignore = [ + 'diffusers', + 'transformers.tokenization_utils_base' +] + +for module in future_warning_modules_to_ignore: + warnings.filterwarnings("ignore", category=FutureWarning, module=module) + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' + + +class SyntheticDataGenerator: + """ + A class for generating synthetic images and videos based on text prompts. + + This class supports different prompt generation strategies and can utilize + various text-to-video (t2v) and text-to-image (t2i) models. + + Attributes: + use_random_t2vis_model: Whether to randomly select a t2v or t2i for each + generation task. + prompt_type: The type of prompt generation strategy ('random', 'annotation'). + prompt_generator_name: Name of the prompt generation model. + t2vis_model_name: Name of the t2v or t2i model. + image_annotation_generator: The generator object for annotating images if required. + output_dir: Directory to write generated data. + """ + + def __init__( + self, + t2vis_model_name: Optional[str] = None, + use_random_t2vis_model: bool = True, + prompt_type: str = 'annotation', + output_dir: Optional[Union[str, Path]] = None, + image_cache: Optional[ImageCache] = None, + device: str = 'cuda' + ) -> None: + """ + Initialize the SyntheticDataGenerator. + + Args: + t2vis_model_name: Name of the text-to-video or text-to-image model. + use_random_t2vis_model: Whether to randomly select models for generation. + prompt_type: The type of prompt generation strategy. + output_dir: Directory to write generated data. + device: Device identifier. + run_as_daemon: Whether to run generation in the background. + image_cache: Optional image cache instance. + + Raises: + ValueError: If an invalid model name is provided. + NotImplementedError: If an unsupported prompt type is specified. + """ + if not use_random_t2vis_model and t2vis_model_name not in T2VIS_MODEL_NAMES: + raise ValueError( + f"Invalid model name '{t2vis_model_name}'. " + f"Options are {T2VIS_MODEL_NAMES}" + ) + + self.use_random_t2vis_model = use_random_t2vis_model + self.t2vis_model_name = t2vis_model_name + self.t2vis_model = None + self.device = device + + if self.use_random_t2vis_model and t2vis_model_name is not None: + bt.logging.warning( + "t2vis_model_name will be ignored (use_random_t2vis_model=True)" + ) + self.t2vis_model_name = None + + self.prompt_type = prompt_type + if self.prompt_type == 'annotation': + self.image_annotation_generator = ImageAnnotationGenerator( + model_name=IMAGE_ANNOTATION_MODEL, + text_moderation_model_name=TEXT_MODERATION_MODEL + ) + else: + raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}") + + self.output_dir = Path(output_dir) if output_dir else None + if self.output_dir: + (self.output_dir / "video").mkdir(parents=True, exist_ok=True) + (self.output_dir / "image").mkdir(parents=True, exist_ok=True) + + self.image_cache = image_cache + + def batch_generate(self, batch_size: int = 5) -> None: + """ + Asynchronously generate synthetic data in batches. + + Args: + batch_size: Number of prompts to generate in each batch. + """ + prompts = [] + bt.logging.info(f"Generating {batch_size} prompts") + for i in range(batch_size): + image_sample = self.image_cache.sample() + bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}") + prompts.append(self.generate_prompt(image=image_sample['image'], clear_gpu=i==batch_size-1)) + bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}") + + + # shuffle and interleave models + t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES)) + t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES)) + model_names = [m for pair in zip(t2v_model_names, t2i_model_names) for m in pair] + for model_name in model_names: + modality = get_modality(model_name) + for i, prompt in enumerate(prompts): + bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}") + + # Generate image/video from current model and prompt + start = time.time() + output = self.run_t2vis(prompt, modality, t2vis_model_name=model_name) + + bt.logging.info(f'Writing to cache {self.output_dir}') + base_path = self.output_dir / modality / str(output['time']) + metadata = {k: v for k, v in output.items() if k != 'gen_output'} + base_path.with_suffix('.json').write_text(json.dumps(metadata)) + + if modality == 'image': + out_path = base_path.with_suffix('.png') + output['gen_output'].images[0].save(out_path) + elif modality == 'video': + bt.logging.info("Writing to cache") + out_path = str(base_path.with_suffix('.mp4')) + export_to_video( + output['gen_output'].frames[0], + out_path, + fps=30 + ) + bt.logging.info(f"Wrote to {out_path}") + + def generate( + self, + image: Optional[Image.Image] = None, + modality: str = 'image', + t2vis_model_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Generate synthetic data based on input parameters. + + Args: + image: Input image for annotation-based generation. + modality: Type of media to generate ('image' or 'video'). + + Returns: + Dictionary containing generated data information. + + Raises: + ValueError: If real_image is None when using annotation prompt type. + NotImplementedError: If prompt type is not supported. + """ + prompt = self.generate_prompt(image, clear_gpu=True) + bt.logging.info("Generating synthetic data...") + gen_data = self.run_t2vis(prompt, modality, t2vis_model_name) + self.clear_gpu() + return gen_data + + def generate_prompt( + self, + image: Optional[Image.Image] = None, + clear_gpu: bool = True + ) -> str: + """Generate a prompt based on the specified strategy.""" + bt.logging.info("Generating prompt") + if self.prompt_type == 'annotation': + if image is None: + raise ValueError( + "image can't be None if self.prompt_type is 'annotation'" + ) + self.image_annotation_generator.load_models() + prompt = self.image_annotation_generator.generate(image) + if clear_gpu: + self.image_annotation_generator.clear_gpu() + else: + raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}") + return prompt + + def run_t2vis( + self, + prompt: str, + modality: str, + t2vis_model_name: Optional[str] = None, + generate_at_target_size: bool = False, + + ) -> Dict[str, Any]: + """ + Generate synthetic data based on a text prompt. + + Args: + prompt: The text prompt used to inspire the generation. + generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions. + t2vis_model_name: Optional model name to use for generation. + + Returns: + Dictionary containing generated data and metadata. + + Raises: + RuntimeError: If generation fails. + """ + self.load_t2vis_model(t2vis_model_name) + model_config = T2VIS_MODELS[self.t2vis_model_name] + + bt.logging.info("Preparing generation arguments") + gen_args = model_config.get('generate_args', {}).copy() + + # Process generation arguments + for k, v in gen_args.items(): + if isinstance(v, dict): + gen_args[k] = np.random.randint( + gen_args[k]['min'], + gen_args[k]['max'] + ) + for dim in ('height', 'width'): + if isinstance(gen_args.get(dim), list): + gen_args[dim] = np.random.choice(gen_args[dim]) + + try: + if generate_at_target_size: + gen_args['height'] = TARGET_IMAGE_SIZE[0] + gen_args['width'] = TARGET_IMAGE_SIZE[1] + + truncated_prompt = truncate_prompt_if_too_long( + prompt, + self.t2vis_model + ) + + bt.logging.info(f"Generating media from prompt: {truncated_prompt}") + bt.logging.info(f"Generation args: {gen_args}") + start_time = time.time() + if model_config.get('use_autocast', True): + pretrained_args = model_config.get('from_pretrained_args', {}) + torch_dtype = pretrained_args.get('torch_dtype', torch.bfloat16) + with torch.autocast(self.device, torch_dtype, cache_enabled=False): + gen_output = self.t2vis_model( + prompt=truncated_prompt, + **gen_args + ) + else: + gen_output = self.t2vis_model( + prompt=truncated_prompt, + **gen_args + ) + gen_time = time.time() - start_time + + except Exception as e: + if generate_at_target_size: + bt.logging.error( + f"Attempt with custom dimensions failed, falling back to " + f"default dimensions. Error: {e}" + ) + try: + gen_output = self.t2vis_model(prompt=truncated_prompt) + gen_time = time.time() - start_time + except Exception as fallback_error: + bt.logging.error( + f"Failed to generate image with default dimensions after " + f"initial failure: {fallback_error}" + ) + raise RuntimeError( + f"Both attempts to generate image failed: {fallback_error}" + ) + else: + bt.logging.error(f"Image generation error: {e}") + raise RuntimeError(f"Failed to generate image: {e}") + + print(f"Finished generation in {gen_time/60} minutes") + return { + 'prompt': truncated_prompt, + 'prompt_long': prompt, + 'gen_output': gen_output, # image or video + 'time': time.time(), + 'model_name': self.t2vis_model_name, + 'gen_time': gen_time + } + + def load_t2vis_model(self, model_name: Optional[str] = None, modality: Optional[str] = None) -> None: + """Load a Hugging Face text-to-image or text-to-video model to a specific GPU.""" + if model_name is not None: + self.t2vis_model_name = model_name + elif self.use_random_t2vis_model or model_name == 'random': + model_name = select_random_t2vis_model(modality) + self.t2vis_model_name = model_name + + bt.logging.info(f"Loading {self.t2vis_model_name}") + + pipeline_cls = T2VIS_MODELS[model_name]['pipeline_cls'] + pipeline_args = T2VIS_MODELS[model_name]['from_pretrained_args'] + + self.t2vis_model = pipeline_cls.from_pretrained( + pipeline_args.get('base', model_name), + cache_dir=HUGGINGFACE_CACHE_DIR, + **pipeline_args, + add_watermarker=False + ) + + self.t2vis_model.set_progress_bar_config(disable=True) + + # Load scheduler if specified + if 'scheduler' in T2VIS_MODELS[model_name]: + sched_cls = T2VIS_MODELS[model_name]['scheduler']['cls'] + sched_args = T2VIS_MODELS[model_name]['scheduler']['from_config_args'] + self.t2vis_model.scheduler = sched_cls.from_config( + self.t2vis_model.scheduler.config, + **sched_args + ) + + # Configure model optimizations + model_config = T2VIS_MODELS[model_name] + if model_config.get('enable_model_cpu_offload', False): + bt.logging.info(f"Enabling cpu offload for {model_name}") + self.t2vis_model.enable_model_cpu_offload() + if model_config.get('enable_sequential_cpu_offload', False): + bt.logging.info(f"Enabling sequential cpu offload for {model_name}") + self.t2vis_model.enable_sequential_cpu_offload() + if model_config.get('vae_enable_slicing', False): + bt.logging.info(f"Enabling vae slicing for {model_name}") + try: + self.t2vis_model.vae.enable_slicing() + except Exception: + try: + self.t2vis_model.enable_vae_slicing() + except Exception: + bt.logging.warning(f"Could not enable vae slicing for {self.t2vis_model}") + if model_config.get('vae_enable_tiling', False): + bt.logging.info(f"Enabling vae tiling for {model_name}") + try: + self.t2vis_model.vae.enable_tiling() + except Exception: + try: + self.t2vis_model.enable_vae_tiling() + except Exception: + bt.logging.warning(f"Could not enable vae tiling for {self.t2vis_model}") + + self.t2vis_model.to(self.device) + bt.logging.info(f"Loaded {model_name} using {pipeline_cls.__name__}.") + + def clear_gpu(self) -> None: + """Clear GPU memory by deleting models and running garbage collection.""" + if self.t2vis_model is not None: + bt.logging.info( + "Deleting previous text-to-image or text-to-video model, " + "freeing memory" + ) + del self.t2vis_model + self.t2vis_model = None + gc.collect() + torch.cuda.empty_cache() + diff --git a/bitmind/synthetic_image_generation/image_annotation_generator.py b/bitmind/synthetic_image_generation/image_annotation_generator.py deleted file mode 100644 index ced54e43..00000000 --- a/bitmind/synthetic_image_generation/image_annotation_generator.py +++ /dev/null @@ -1,344 +0,0 @@ -# Transformer models -from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, pipeline -import torch - -# Logging and progress handling -from transformers import logging as transformers_logging -from transformers.utils.logging import disable_progress_bar - -from typing import Any, Dict, List, Tuple -import bittensor as bt -import PIL -import time -import torch -import gc - -from bitmind.image_dataset import ImageDataset -from bitmind.synthetic_image_generation.utils import image_utils -from bitmind.constants import HUGGINGFACE_CACHE_DIR - -disable_progress_bar() - - -class ImageAnnotationGenerator: - """ - A class responsible for generating text annotations for images using a transformer-based image captioning model. - It integrates text moderation to ensure the descriptions are concise and neutral. - - Attributes: - device (torch.device): The device (CPU or GPU) on which the models are loaded. - model_name (str): The name of the BLIP model for generating image captions. - processor (Blip2Processor): The processor associated with the BLIP model. - model (Blip2ForConditionalGeneration): The BLIP model used for generating image captions. - apply_moderation (bool): Flag to determine whether text moderation should be applied to captions. - text_moderation_model_name (str): The name of the model used for moderating text descriptions. - text_moderation_pipeline (pipeline): A Hugging Face pipeline for text moderation. - - Methods: - __init__(self, model_name: str, text_moderation_model_name: str, device: str = cuda, apply_moderation: bool = True): - Initializes the ImageAnnotationGenerator with the specified model, device, and moderation settings. - - load_models(self): - Loads the image annotation and text moderation models into memory. - - clear_gpu(self): - Clears GPU memory to ensure that no residual data remains that could affect further operations. - - moderate_description(self, description: str, max_new_tokens: int = 80) -> str: - Moderates the given description to make it more concise and neutral, using the text moderation model. - - generate_description(self, image: PIL.Image.Image, verbose: bool = False, max_new_tokens: int = 20) -> str: - Generates a description for the provided image using the image captioning model. - - generate_annotation(self, image_id, dataset_name: str, image: PIL.Image.Image, original_dimensions: tuple, resize: bool, verbose: int) -> dict: - Generates a text annotation for a given image, including handling image resizing and verbose logging. - - process_image(self, image_info: dict, dataset_name: str, image_index: int, resize: bool, verbose: int) -> Tuple[Any, float]: - Processes a single image from a dataset to generate its annotation and measures the time taken. - - generate_annotations(self, real_image_datasets: List[ImageDataset], verbose: int = 0, max_images: int = None, resize_images: bool = False) -> Dict[str, Dict[str, Any]]: - Generates text annotations for a batch of images from the specified datasets and calculates the average processing latency. - """ - def __init__( - self, model_name: str, text_moderation_model_name: str, device: str = "cuda", - apply_moderation: bool = True - ): - """ - Initializes the ImageAnnotationGenerator with specific models and device settings. - - Args: - model_name (str): The name of the BLIP model for generating image captions. - text_moderation_model_name (str): The name of the model used for moderating text descriptions. - device (str): Device to use for model inference. Defaults to "cuda". - apply_moderation (bool): Flag to determine whether text moderation should be applied to captions. - """ - self.device = device - self.model_name = model_name - self.processor = Blip2Processor.from_pretrained( - self.model_name, cache_dir=HUGGINGFACE_CACHE_DIR - ) - self.model = None - - self.apply_moderation = apply_moderation - self.text_moderation_model_name = text_moderation_model_name - self.text_moderation_pipeline = None - - def load_models(self): - """ - Loads the necessary models for image annotation and text moderation onto the specified device. - """ - self.model = Blip2ForConditionalGeneration.from_pretrained( - self.model_name, - torch_dtype=torch.float16, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - self.model.to(self.device) - if self.apply_moderation: - model = AutoModelForCausalLM.from_pretrained( - self.text_moderation_model_name, - torch_dtype=torch.bfloat16, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - - tokenizer = AutoTokenizer.from_pretrained( - self.text_moderation_model_name, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - model = model.to(self.device) - self.text_moderation_pipeline = pipeline( - "text-generation", - model=model, - tokenizer=tokenizer - ) - - def clear_gpu(self): - """ - Clears GPU memory by moving models back to CPU and deleting them, followed by collecting garbage. - """ - self.model.to('cpu') - del self.model - self.model = None - if self.text_moderation_pipeline: - self.text_moderation_pipeline.model.to('cpu') - del self.text_moderation_pipeline - self.text_moderation_pipeline = None - gc.collect() - torch.cuda.empty_cache() - - def moderate_description(self, description: str, max_new_tokens: int = 80) -> str: - """ - Uses the text moderation pipeline to make the description more concise and neutral. - """ - messages = [ - { - "role": "system", - "content": ("[INST]You always concisely rephrase given descriptions, eliminate redundancy, " - "and remove all specific references to individuals by name. You do not respond with" - "anything other than the revised description.[/INST]") - }, - { - "role": "user", - "content": description - } - ] - try: - moderated_text = self.text_moderation_pipeline(messages, max_new_tokens=max_new_tokens, - pad_token_id=self.text_moderation_pipeline.tokenizer.eos_token_id, - return_full_text=False) - - if isinstance(moderated_text, list): - return moderated_text[0]['generated_text'] - bt.logging.error("Failed to return moderated text.") - else: - bt.logging.error("Moderated text did not return a list.") - - return description # Fallback to the original description if no suitable entry is found - except Exception as e: - bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True) - return description # Return the original description as a fallback - - def generate_description(self, - image: PIL.Image.Image, - verbose: bool = False, - max_new_tokens: int = 20) -> str: - """ - Generates a string description for a given image by interfacing with a transformer - model using prompt-based captioning and building conversational context. - - Args: - image (PIL.Image.Image): The image for which the description is to be generated. - verbose (bool, optional): If True, additional logging information is printed. Defaults to False. - max_new_tokens (int, optional): The maximum number of tokens to generate for each prompt. Defaults to 20. - - Returns: - str: A generated description of the image. - """ - if not verbose: - transformers_logging.set_verbosity_error() - - description = "" - prompts = ["An image of", "The setting is", "The background is", "The image type/style is"] - for i, prompt in enumerate(prompts): - description += prompt + ' ' - inputs = self.processor(image, text=description, return_tensors="pt").to(self.device, torch.float16) - generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens) #GPT2Tokenizer - answer = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() - if verbose: - bt.logging.info(f"{i}. Prompt: {prompt}") - bt.logging.info(f"{i}. Answer: {answer}") - - if answer: - # Remove any ending spaces or punctuation that is not a period - answer = answer.rstrip(" ,;!?") - # Add a period at the end if it's not already there - if not answer.endswith('.'): - answer += '.' - - description += answer + ' ' - else: - description = description[:-len(prompt) - 1] - - if not verbose: - transformers_logging.set_verbosity_info() - - if description.startswith(prompts[0]): - description = description[len(prompts[0]):] - - # Remove any trailing spaces and ensure the description ends with a period - description = description.strip() - if not description.endswith('.'): - description += '.' - if self.apply_moderation: - moderated_description = self.moderate_description(description) - return moderated_description - return description - - def generate_annotation( - self, - image_id, - dataset_name: str, - image: PIL.Image.Image, - original_dimensions: tuple, - resize: bool, - verbose: int) -> dict: - """ - Generate a text annotation for a given image. - - Parameters: - image_id (int or str): The identifier for the image within the dataset. - dataset_name (str): The name of the dataset the image belongs to. - image (PIL.Image.Image): The image object that requires annotation. - original_dimensions (tuple): Original dimensions of the image as (width, height). - resize (bool): Allow image downsizing to maximum dimensions of (1280, 1280). - verbose (int): Verbosity level. - - Returns: - dict: Dictionary containing the annotation data. - """ - image_to_process = image.copy() - if resize: # Downsize if dimension(s) are greater than 1280 - image_to_process = image_utils.resize_image(image_to_process, 1280, 1280) - if verbose > 1 and image_to_process.size != image.size: - bt.logging.info(f"Resized {image_id}: {image.size} to {image_to_process.size}") - try: - description = self.generate_description(image_to_process, verbose > 2) - annotation = { - 'description': description, - 'original_dataset': dataset_name, - 'original_dimensions': f"{original_dimensions[0]}x{original_dimensions[1]}", - 'id': image_id - } - return annotation - except Exception as e: - if verbose > 1: - bt.logging.error(f"Error processing image {image_id} in {dataset_name}: {e}") - return None - - def process_image( - self, - image_info: dict, - dataset_name: str, - image_index: int, - resize: bool, - verbose: int) -> Tuple[Any, float]: - """ - Processes an individual image for annotation, including resizing and verbosity controls, - and calculates the time taken to process the image. - - Args: - image_info (dict): Dictionary containing image data and metadata. - dataset_name (str): The name of the dataset containing the image. - image_index (int): The index of the image within the dataset. - resize (bool): Whether to resize the image before processing. - verbose (int): Verbosity level for logging outputs. - - Returns: - Tuple[Any, float]: A tuple containing the generated annotation (or None if failed) and the time taken to process. - """ - - if image_info['image'] is None: - if verbose > 1: - bt.logging.debug(f"Skipping image {image_index} in dataset {dataset_name} due to missing image data.") - return None, 0 - - original_dimensions = image_info['image'].size - start_time = time.time() - annotation = self.generate_annotation(image_index, - dataset_name, - image_info['image'], - original_dimensions, - resize, - verbose) - time_elapsed = time.time() - start_time - - if annotation is None: - if verbose > 1: - bt.logging.debug(f"Failed to generate annotation for image {image_index} in dataset {dataset_name}") - return None, time_elapsed - - return annotation, time_elapsed - - def generate_annotations( - self, - real_image_datasets: - List[ImageDataset], - verbose: int = 0, - max_images: int = None, - resize_images: bool = False) -> Dict[str, Dict[str, Any]]: - """ - Generates text annotations for images in the given datasets, saves them in a specified directory, - and computes the average per image latency. Returns a dictionary of new annotations and the average latency. - - Parameters: - real_image_datasets (List[Any]): Datasets containing images. - verbose (int): Verbosity level for process messages (Most verbose = 3). - max_images (int): Maximum number of images to annotate. - resize_images (bool) : Allow image downsizing before captioning. - Sets max dimensions to (1280, 1280), maintaining aspect ratio. - - Returns: - Tuple[Dict[str, Dict[str, Any]], float]: A tuple containing the annotations dictionary and average latency. - """ - annotations = {} - total_time = 0 - total_processed_images = 0 - for dataset in real_image_datasets: - dataset_name = dataset.huggingface_dataset_path - processed_images = 0 - dataset_time = 0 - for j, image_info in enumerate(dataset): - annotation, time_elapsed = self.process_image(image_info, - dataset_name, - j, - resize_images, - verbose) - if annotation is not None: - annotations.setdefault(dataset_name, {})[image_info['id']] = annotation - total_time += time_elapsed - dataset_time += time_elapsed - processed_images += 1 - if max_images is not None and len(annotations[dataset_name]) >= max_images: - break - total_processed_images += processed_images - overall_average_latency = total_time / total_processed_images if total_processed_images else 0 - return annotations, overall_average_latency diff --git a/bitmind/synthetic_image_generation/synthetic_image_generator.py b/bitmind/synthetic_image_generation/synthetic_image_generator.py deleted file mode 100644 index 80a2a919..00000000 --- a/bitmind/synthetic_image_generation/synthetic_image_generator.py +++ /dev/null @@ -1,295 +0,0 @@ -from transformers import pipeline -from transformers import set_seed -from diffusers import StableDiffusionXLPipeline, FluxPipeline, StableDiffusionPipeline -import bittensor as bt -import numpy as np -import torch -import random -import time -import re -import gc -import os -import warnings - -from bitmind.constants import ( - TEXT_MODERATION_MODEL, - DIFFUSER_NAMES, - DIFFUSER_ARGS, - DIFFUSER_PIPELINE, - DIFFUSER_CPU_OFFLOAD_ENABLED, - GENERATE_ARGS, - PROMPT_TYPES, - IMAGE_ANNOTATION_MODEL, - TARGET_IMAGE_SIZE -) - -future_warning_modules_to_ignore = [ - 'diffusers', - 'transformers.tokenization_utils_base' -] - -for module in future_warning_modules_to_ignore: - warnings.filterwarnings("ignore", category=FutureWarning, module=module) - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' - -from transformers import pipeline, set_seed -import bittensor as bt - -from bitmind.synthetic_image_generation.image_annotation_generator import ImageAnnotationGenerator -from bitmind.constants import HUGGINGFACE_CACHE_DIR - - -class SyntheticImageGenerator: - """ - A class for generating synthetic images based on text prompts. Supports different prompt generation strategies - and can utilize various image diffuser models to create images. - - Attributes: - use_random_diffuser (bool): Whether to randomly select a diffuser for each generation task. - prompt_type (str): The type of prompt generation strategy (currently only supports 'annotation') - diffuser_name (str): Name of the image diffuser model. - image_annotation_generator (ImageAnnotationGenerator): The generator object for annotating images if required. - image_cache_dir (str): Directory to cache generated images. - device (str): Device to use for model inference. Defaults to "cuda". - """ - def __init__( - self, - prompt_type='annotation', - diffuser_name=DIFFUSER_NAMES[0], - use_random_diffuser=False, - image_cache_dir=None, - device="cuda" - ): - if prompt_type not in PROMPT_TYPES: - raise ValueError(f"Invalid prompt type '{prompt_type}'. Options are {PROMPT_TYPES}") - if not use_random_diffuser and diffuser_name not in DIFFUSER_NAMES: - raise ValueError(f"Invalid diffuser name '{diffuser_name}'. Options are {DIFFUSER_NAMES}") - - self.use_random_diffuser = use_random_diffuser - self.prompt_type = prompt_type - self.device = device - - self.diffuser = None - if self.use_random_diffuser and diffuser_name is not None: - bt.logging.warning("Warning: diffuser_name will be ignored (use_random_diffuser=True)") - self.diffuser_name = None - else: - self.diffuser_name = diffuser_name - - self.image_annotation_generator = None - if self.prompt_type == 'annotation': - self.image_annotation_generator = ImageAnnotationGenerator(model_name=IMAGE_ANNOTATION_MODEL, - text_moderation_model_name=TEXT_MODERATION_MODEL, - device = self.device) - else: - raise NotImplementedError(f"Unsupported prompt_type: {self.prompt_type}") - - self.image_cache_dir = image_cache_dir - if image_cache_dir is not None: - os.makedirs(self.image_cache_dir, exist_ok=True) - - def generate(self, k: int = 1, real_images=None) -> list: - """ - Generates k synthetic images. If self.prompt_type is 'annotation', a BLIP2 captioning pipeline is used - to produce prompts by captioning real images. If self.prompt_type is 'random', an LLM is used to generate - prompts. - - Args: - k (int): Number of images to generate. - - Returns: - list: List of dictionaries containing 'prompt', 'image', and 'id'. - """ - if self.prompt_type == 'annotation': - if real_images is None: - raise ValueError(f"real_images can't be None if self.prompt_type is 'annotation'") - prompts = [ - self.generate_image_caption(real_images[i]) - for i in range(k) - ] - else: - raise NotImplementedError - - if self.use_random_diffuser: - self.load_diffuser('random') - else: - self.load_diffuser(self.diffuser_name) - - gen_data = [] - for prompt in prompts: - image_data = self.generate_image(prompt) - if self.image_cache_dir is not None: - path = os.path.join(self.image_cache_dir, image_data['id']) - image_data['image'].save(path) - gen_data.append(image_data) - self.clear_gpu() # remove diffuser from gpu - - return gen_data - - def clear_gpu(self): - """ - Clears GPU memory by deleting the loaded diffuser and performing garbage collection. - """ - if self.diffuser is not None: - del self.diffuser - gc.collect() - torch.cuda.empty_cache() - self.diffuser = None - - def load_diffuser(self, diffuser_name) -> None: - """ - Loads a Hugging Face diffuser model to a specific GPU. - - Parameters: - diffuser_name (str): Name of the diffuser to load. - """ - if diffuser_name == 'random': - diffuser_name = np.random.choice(DIFFUSER_NAMES, 1)[0] - - self.diffuser_name = diffuser_name - pipeline_class = globals()[DIFFUSER_PIPELINE[diffuser_name]] - self.diffuser = pipeline_class.from_pretrained(diffuser_name, - cache_dir=HUGGINGFACE_CACHE_DIR, - **DIFFUSER_ARGS[diffuser_name], - add_watermarker=False) - self.diffuser.set_progress_bar_config(disable=True) - self.diffuser.to(self.device) - if DIFFUSER_CPU_OFFLOAD_ENABLED[diffuser_name]: - self.diffuser.enable_model_cpu_offload() - - def generate_image_caption(self, image_sample) -> str: - """ - Generates a descriptive caption for a given image sample. - - This function takes an image sample as input, processes the image using a pre-trained - model, and returns a generated caption describing the content of the image. - - Args: - image_sample (dict): A dictionary containing information about the image to be processed. - It includes: - - 'source' (str): The dataset or source name of the image. - - 'id' (int/str): The unique identifier of the image. - - Returns: - str: A descriptive caption generated for the input image. - """ - self.image_annotation_generator.load_models() - annotation = self.image_annotation_generator.process_image( - image_info=image_sample, - dataset_name=image_sample['source'], - image_index=image_sample['id'], - resize=False, - verbose=0 - )[0] - self.image_annotation_generator.clear_gpu() - return annotation['description'] - - def get_tokenizer_with_min_len(self): - """ - Returns the tokenizer with the smallest maximum token length from the 'diffuser` object. - - If a second tokenizer exists, it compares both and returns the one with the smaller - maximum token length. Otherwise, it returns the available tokenizer. - - Returns: - tuple: A tuple containing the tokenizer and its maximum token length. - """ - # Check if a second tokenizer is available in the diffuser - if hasattr(self.diffuser, 'tokenizer_2'): - if self.diffuser.tokenizer.model_max_length > self.diffuser.tokenizer_2.model_max_length: - return self.diffuser.tokenizer_2, self.diffuser.tokenizer_2.model_max_length - return self.diffuser.tokenizer, self.diffuser.tokenizer.model_max_length - - def truncate_prompt_if_too_long(self, prompt: str): - """ - Truncates the input string if it exceeds the maximum token length when tokenized. - - Args: - prompt (str): The text prompt that may need to be truncated. - - Returns: - str: The original prompt if within the token limit; otherwise, a truncated version of the prompt. - """ - tokenizer, max_token_len = self.get_tokenizer_with_min_len() - tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings - if len(tokens['input_ids']) < max_token_len: - return prompt - # Truncate tokens if they exceed the maximum token length, decode the tokens back to a string - truncated_prompt = tokenizer.decode(token_ids=tokens['input_ids'][:max_token_len-1], - skip_special_tokens=True) - tokens = tokenizer(truncated_prompt) - bt.logging.info("Truncated prompt to abide by token limit.") - return truncated_prompt - - def generate_image(self, prompt, name = None, generate_at_target_size = False) -> list: - """ - Generates a synthetic image based on a text prompt. This function can optionally adjust the generation args of the - diffusion model, such as dimensions and the number of inference steps. - - Args: - prompt (str): The text prompt used to inspire the image generation. - name (str, optional): An optional identifier for the generated image. If not provided, a timestamp-based - identifier is used. - generate_at_target_size (bool, optional): If True, the image is generated at the dimensions specified by the - TARGET_IMAGE_SIZE constant. Otherwise, dimensions are selected based on the diffuser's default or random settings. - - Returns: - dict: A dictionary containing: - - 'prompt': The possibly truncated version of the input prompt. - - 'image': The generated image object. - - 'id': The identifier of the generated image. - - 'gen_time': The time taken to generate the image, measured from the start of the process. - """ - # Generate a unique image name based on current time if not provided - image_name = name if name else f"{time.time():.0f}.jpg" - # Check if the prompt is too long - truncated_prompt = self.truncate_prompt_if_too_long(prompt) - gen_args = {} - - # Load generation arguments based on diffuser settings - if self.diffuser_name in GENERATE_ARGS: - gen_args = GENERATE_ARGS[self.diffuser_name].copy() - - if isinstance(gen_args.get('num_inference_steps'), dict): - gen_args['num_inference_steps'] = np.random.randint( - gen_args['num_inference_steps']['min'], - gen_args['num_inference_steps']['max']) - - for dim in ('height', 'width'): - if isinstance(gen_args.get(dim), list): - gen_args[dim] = np.random.choice(gen_args[dim]) - - try: - if generate_at_target_size: - #Attempt to generate an image with specified dimensions - gen_args['height'] = TARGET_IMAGE_SIZE[0] - gen_args['width'] = TARGET_IMAGE_SIZE[1] - # Record the time taken to generate the image - start_time = time.time() - # Generate image using the diffuser with appropriate arguments - gen_image = self.diffuser(prompt=truncated_prompt, num_images_per_prompt=1, **gen_args).images[0] - # Calculate generation time - gen_time = time.time() - start_time - except Exception as e: - if generate_at_target_size: - bt.logging.error(f"Attempt with custom dimensions failed, falling back to default dimensions. Error: {e}") - try: - # Fallback to generating an image without specifying dimensions - gen_image = self.diffuser(prompt=truncated_prompt).images[0] - gen_time = time.time() - start_time - except Exception as fallback_error: - bt.logging.error(f"Failed to generate image with default dimensions after initial failure: {fallback_error}") - raise RuntimeError(f"Both attempts to generate image failed: {fallback_error}") - else: - bt.logging.error(f"Image generation error: {e}") - raise RuntimeError(f"Failed to generate image: {e}") - - image_data = { - 'prompt': truncated_prompt, - 'image': gen_image, - 'id': image_name, - 'gen_time': gen_time - } - return image_data diff --git a/bitmind/synthetic_image_generation/utils/annotation_utils.py b/bitmind/synthetic_image_generation/utils/annotation_utils.py deleted file mode 100644 index 279009ab..00000000 --- a/bitmind/synthetic_image_generation/utils/annotation_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import bittensor as bt -import json -import os - -def ensure_save_path(path: str) -> str: - """Ensure that a directory exists; if it does not, create it.""" - if not os.path.exists(path): - os.makedirs(path) - return path - -def create_annotation_dataset_directory(base_path: str, dataset_name: str) -> str: - """Create a directory for a dataset with a safe name, replacing any invalid characters.""" - safe_name = dataset_name.replace("/", "_") - full_path = os.path.join(base_path, safe_name) - if not os.path.exists(full_path): - os.makedirs(full_path) - return full_path - - -def save_annotation(dataset_dir: str, image_id, annotation: dict, verbose: int): - """Save a text annotation to a JSON file if it doesn't already exist.""" - file_path = os.path.join(dataset_dir, f"{image_id}.json") - if os.path.exists(file_path): - if verbose > 0: - bt.logging.info(f"Annotation for {image_id} already exists - Skipping") - return -1 # Skip this image as it already has an annotation - - with open(file_path, 'w') as f: - json.dump(annotation, f, indent=4) - if verbose > 0: - bt.logging.info(f"Created {file_path}") - - return 0 - - -def compute_annotation_latency(self, processed_images: int, dataset_time: float, dataset_name: str) -> float: - if processed_images > 0: - average_latency = dataset_time / processed_images - bt.logging.info(f'Average annotation latency for {dataset_name}: {average_latency:.4f} seconds') - return average_latency - return 0.0 - - -def list_datasets(base_dir: str) -> list[str]: - """List all subdirectories in the base directory.""" - return [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))] - - -def load_annotations(base_dir: str, dataset: str) -> list[dict]: - """Load annotations from JSON files within a specified directory.""" - annotations = [] - path = os.path.join(base_dir, dataset) - for filename in os.listdir(path): - if filename.endswith(".json"): - with open(os.path.join(path, filename), 'r') as file: - data = json.load(file) - annotations.append(data) - return annotations diff --git a/bitmind/synthetic_image_generation/utils/hugging_face_utils.py b/bitmind/synthetic_image_generation/utils/hugging_face_utils.py deleted file mode 100644 index 8507fead..00000000 --- a/bitmind/synthetic_image_generation/utils/hugging_face_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import json -from datasets import load_dataset -from huggingface_hub import HfApi - -def dataset_exists_on_hf(hf_dataset_name, token): - """Check if the dataset exists on Hugging Face.""" - api = HfApi() - try: - dataset_info = api.dataset_info(hf_dataset_name, token=token) - return True - except Exception as e: - return False - -def numerical_sort(value): - return int(os.path.splitext(os.path.basename(value))[0]) - -def load_and_sort_dataset(data_dir, file_type): - # Get list of filenames in the directory with the given extension - try: - if file_type == 'image': - # List image filenames with common image extensions - valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif') - filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) - if f.lower().endswith(valid_extensions)] - elif file_type == 'json': - # List json filenames - filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) - if f.lower().endswith('.json')] - else: - raise ValueError(f"Unsupported file type: {file_type}") - - if not filenames: - raise FileNotFoundError(f"No files with the extension '{file_type}' \ - found in directory '{data_dir}'") - - # Sort filenames numerically (0, 1, 2, 3, 4). Necessary because - # HF datasets are ordered by string (0, 1, 10, 11, 12). - sorted_filenames = sorted(filenames, key=numerical_sort) - - # Load the dataset with sorted filenames - if file_type == 'image': - return load_dataset("imagefolder", data_files=sorted_filenames) - elif file_type == 'json': - return load_dataset("json", data_files=sorted_filenames) - - except Exception as e: - print(f"Error loading dataset: {e}") - return None - -def upload_to_huggingface(dataset, repo_name, token): - """Uploads the dataset dictionary to Hugging Face.""" - api = HfApi() - api.create_repo(repo_name, repo_type="dataset", private=False, token=token) - dataset.push_to_hub(repo_name) - -def slice_dataset(dataset, start_index, end_index=None): - """ - Slice the dataset according to provided start and end indices. - - Parameters: - dataset (Dataset): The dataset to be sliced. - start_index (int): The index of the first element to include in the slice. - end_index (int, optional): The index of the last element to include in the slice. If None, slices to the end of the dataset. - - Returns: - Dataset: The sliced dataset. - """ - if end_index is not None and end_index < len(dataset): - return dataset.select(range(start_index, end_index)) - else: - return dataset.select(range(start_index, len(dataset))) - -def save_as_json(df, output_dir): - os.makedirs(output_dir, exist_ok=True) # Ensure the directory exists - # Iterate through rows in dataframe - for index, row in df.iterrows(): - file_path = os.path.join(output_dir, f"{row['id']}.json") - # Convert the row to a dictionary and save it as JSON - with open(file_path, 'w', encoding='utf-8') as f: - json.dump(row.to_dict(), f, ensure_ascii=False, indent=4) diff --git a/bitmind/synthetic_image_generation/utils/stress_test.py b/bitmind/synthetic_image_generation/utils/stress_test.py deleted file mode 100644 index 70eb2e0d..00000000 --- a/bitmind/synthetic_image_generation/utils/stress_test.py +++ /dev/null @@ -1,66 +0,0 @@ -import logging -import os -import time -import time - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' -logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') - -from synthetic_image_generator import SyntheticImageGenerator -from bitmind.image_dataset import ImageDataset -from bitmind.utils.data import sample_dataset_index_name - -from bitmind.constants import DATASET_META - - -def slice_dataset(dataset, start_index, end_index=None): - """ - Slice the dataset according to provided start and end indices. - - Parameters: - dataset (Dataset): The dataset to be sliced. - start_index (int): The index of the first element to include in the slice. - end_index (int, optional): The index of the last element to include in the slice. If None, slices to the end of the dataset. - - Returns: - Dataset: The sliced dataset. - """ - if end_index is not None and end_index < len(dataset): - return dataset.select(range(start_index, end_index)) - else: - return dataset.select(range(start_index, len(dataset))) - - -def main(): - synthetic_image_generator = SyntheticImageGenerator(prompt_type='annotation', - use_random_diffuser=False, - diffuser_name='stabilityai/stable-diffusion-xl-base-1.0') - - # Load the datasets specified in DATASET_META - real_image_datasets = [ - ImageDataset(ds['path'], 'train', ds.get('name', None), ds['create_splits']) - for ds in DATASET_META['real'] - ] - DIFFUSER_NAMES = ['black-forest-labs/FLUX.1-dev'] - for model_name in DIFFUSER_NAMES: - synthetic_image_generator.diffuser_name = model_name # Set the diffuser model - print(f"Testing {model_name}") - for _ in range(11): - # Sample an image from real datasets - real_dataset_index, source_dataset = sample_dataset_index_name(real_image_datasets) - real_dataset = real_image_datasets[real_dataset_index] - images_to_caption, image_indexes = real_dataset.sample(k=1) - - start = time.time() - # Generate synthetic images from sampled real images - sample = synthetic_image_generator.generate(k=1, real_images=images_to_caption)[0] - end = time.time() - - # Logging the results - time_elapsed = end - start - print(f"Model: {model_name}, Time elapsed: {time_elapsed}") - print(sample) # You may want to store these samples differently depending on your needs. - -if __name__ == "__main__": - main() diff --git a/bitmind/utils/config.py b/bitmind/utils/config.py index baa06ec7..271aab51 100644 --- a/bitmind/utils/config.py +++ b/bitmind/utils/config.py @@ -87,13 +87,6 @@ def add_args(cls, parser): parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1) - parser.add_argument( - "--neuron.device", - type=str, - help="Device to run on.", - default=get_device(), - ) - parser.add_argument( "--neuron.epoch_length", type=int, @@ -148,19 +141,47 @@ def add_miner_args(cls, parser): """Add miner specific arguments to the parser.""" parser.add_argument( - "--neuron.detector_config", + "--neuron.image_detector_config", type=str, help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.", default="camo.yaml", ) parser.add_argument( - "--neuron.detector", + "--neuron.image_detector", type=str, help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.", default="CAMO", ) - + + parser.add_argument( + "--neuron.image_detector_device", + type=str, + help="Device to run image detection model on.", + default=get_device(), + ) + + parser.add_argument( + "--neuron.video_detector_config", + type=str, + help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.", + default="tall.yaml", + ) + + parser.add_argument( + "--neuron.video_detector", + type=str, + help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.", + default="TALL", + ) + + parser.add_argument( + "--neuron.video_detector_device", + type=str, + help="Device to run image detection model on.", + default=get_device(), + ) + parser.add_argument( "--neuron.name", type=str, @@ -200,6 +221,13 @@ def add_miner_args(cls, parser): def add_validator_args(cls, parser): """Add validator specific arguments to the parser.""" + parser.add_argument( + "--neuron.device", + type=str, + help="Device to run on.", + default=get_device(), + ) + parser.add_argument( "--neuron.prompt_type", type=str, @@ -207,6 +235,20 @@ def add_validator_args(cls, parser): default='annotation', ) + parser.add_argument( + "--neuron.clip_frames_min", + type=int, + help="Min number of frames for video challenge", + default=8, + ) + + parser.add_argument( + "--neuron.clip_frames_max", + type=int, + help="Max number of frames for video challenge", + default=24, + ) + parser.add_argument( "--neuron.name", type=str, diff --git a/bitmind/image_transforms.py b/bitmind/utils/image_transforms.py similarity index 69% rename from bitmind/image_transforms.py rename to bitmind/utils/image_transforms.py index 8642e910..e32d3c4c 100644 --- a/bitmind/image_transforms.py +++ b/bitmind/utils/image_transforms.py @@ -2,17 +2,18 @@ import random from PIL import Image import torchvision.transforms as transforms +import torchvision.transforms.functional as F import numpy as np import torch import cv2 -from bitmind.constants import TARGET_IMAGE_SIZE +from bitmind.validator.config import TARGET_IMAGE_SIZE + def center_crop(): def fn(img): m = min(img.size) return transforms.CenterCrop(m)(img) - return fn @@ -21,10 +22,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.params = None - def get_params(self, img, scale, ratio): - params = super().get_params(img, scale, ratio) - self.params = params - return params + def forward(self, img, crop_params=None): + if crop_params is None: + i, j, h, w = super().get_params(img, self.scale, self.ratio) + else: + i, j, h, w = crop_params + self.params = {'crop_params': (i, j, h, w)} + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) class RandomHorizontalFlipWithParams(transforms.RandomHorizontalFlip): @@ -32,12 +36,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.params = None - def forward(self, img): - if torch.rand(1) < self.p: - self.params = True + def forward(self, img, do_flip=False): + if do_flip or (torch.rand(1) < self.p): + self.params = {'do_flip': True} return transforms.functional.hflip(img) else: - self.params = False + self.params = {'do_flip': False} return img @@ -46,12 +50,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.params = None - def forward(self, img): - if torch.rand(1) < self.p: - self.params = True + def forward(self, img, do_flip=True): + if do_flip or (torch.rand(1) < self.p): + self.params = {'do_flip': True} return transforms.functional.vflip(img) else: - self.params = False + self.params = {'do_flip': False} return img @@ -60,9 +64,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.params = None - def forward(self, img): - angle = self.get_params(self.degrees) - self.params = angle + def forward(self, img, angle=None): + if angle is None: + angle = self.get_params(self.degrees) + self.params = {'angle': angle} return transforms.functional.rotate(img, angle) @@ -303,83 +308,107 @@ def __call__(self, tensor): class ComposeWithParams: - """Compose multiple transforms with parameter tracking.""" - def __init__(self, transforms): self.transforms = transforms self.params = {} - def __call__(self, img): + def __call__(self, input_data): transform_params = { RandomResizedCropWithParams: 'RandomResizedCrop', RandomHorizontalFlipWithParams: 'RandomHorizontalFlip', RandomVerticalFlipWithParams: 'RandomVerticalFlip', RandomRotationWithParams: 'RandomRotation' } - - for transform in self.transforms: - img = transform(img) - if type(transform) in transform_params: - self.params[transform_params[type(transform)]] = transform.params - return img + output_data = [] + list_input = True + if not isinstance(input_data, list): + input_data = [input_data] + list_input = False + + for img in input_data: + for t in self.transforms: + if type(t) in transform_params and transform_params[type(t)] in self.params: + params = self.params[transform_params[type(t)]] + img = t(img, **params) + else: + img = t(img) + if type(t) in transform_params: + self.params[transform_params[type(t)]] = t.params + output_data.append(img) + + if list_input: + return output_data + return output_data[0] # Transform configurations -base_transforms = transforms.Compose([ - ConvertToRGB(), - center_crop(), - transforms.Resize(TARGET_IMAGE_SIZE), - transforms.ToTensor() -]) - -random_aug_transforms = ComposeWithParams([ - ConvertToRGB(), - transforms.ToTensor(), - RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), - RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)), - RandomHorizontalFlipWithParams(), - RandomVerticalFlipWithParams() -]) - -ucf_transforms = transforms.Compose([ - ConvertToRGB(), - center_crop(), - transforms.Resize(TARGET_IMAGE_SIZE), - CLAHE(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) +def get_base_transforms(target_image_size=TARGET_IMAGE_SIZE): + return ComposeWithParams([ + ConvertToRGB(), + center_crop(), + transforms.Resize(target_image_size), + transforms.ToTensor() + ]) + + +def get_random_augmentations(target_image_size=TARGET_IMAGE_SIZE): + return ComposeWithParams([ + ConvertToRGB(), + transforms.ToTensor(), + RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), + RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)), + RandomHorizontalFlipWithParams(), + RandomVerticalFlipWithParams() + ]) + +def get_ucf_base_transforms(target_image_size=TARGET_IMAGE_SIZE): + return transforms.Compose([ + ConvertToRGB(), + center_crop(), + transforms.Resize(target_image_size), + CLAHE(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + +def get_tall_base_transforms(target_image_size=TARGET_IMAGE_SIZE): + return ComposeWithParams([ + transforms.Resize(target_image_size), + transforms.ToTensor() + ]) # Medium difficulty transforms with mild distortions -random_aug_transforms_medium = ComposeWithParams([ - ConvertToRGB(), - transforms.ToTensor(), - RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), - RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)), - RandomHorizontalFlipWithParams(), - RandomVerticalFlipWithParams(), - ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=1), - ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=1), - ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=1) -]) +def get_random_augmentations_medium(target_image_size=TARGET_IMAGE_SIZE): + return ComposeWithParams([ + ConvertToRGB(), + transforms.ToTensor(), + RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), + RandomResizedCropWithParams(target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0)), + RandomHorizontalFlipWithParams(), + RandomVerticalFlipWithParams(), + ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=1), + ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=1), + ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=1) + ]) # Hard difficulty transforms with more severe distortions -random_aug_transforms_hard = ComposeWithParams([ - ConvertToRGB(), - transforms.ToTensor(), - RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), - RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)), - RandomHorizontalFlipWithParams(), - RandomVerticalFlipWithParams(), - ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=2), - ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=2), - ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=2), - ApplyDeeperForensicsDistortion('GNC', level_min=0, level_max=2), - ApplyDeeperForensicsDistortion('GB', level_min=0, level_max=2) -]) - - -def apply_augmentation_by_level(image, level_probs={ +def get_random_augmentations_hard(target_image_size=TARGET_IMAGE_SIZE): + return ComposeWithParams([ + ConvertToRGB(), + transforms.ToTensor(), + RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR), + RandomResizedCropWithParams(target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0)), + RandomHorizontalFlipWithParams(), + RandomVerticalFlipWithParams(), + ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=2), + ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=2), + ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=2), + ApplyDeeperForensicsDistortion('GNC', level_min=0, level_max=2), + ApplyDeeperForensicsDistortion('GB', level_min=0, level_max=2) + ]) + + +def apply_augmentation_by_level(image, target_image_size, level_probs={ 0: 0.25, # No augmentations (base transforms) 1: 0.45, # Basic augmentations 2: 0.15, # Medium distortions @@ -423,16 +452,14 @@ def apply_augmentation_by_level(image, level_probs={ # Apply appropriate transform if level == 0: - transformed = base_transforms(image) - params = {} + tforms = get_base_transforms(target_image_size) elif level == 1: - transformed = random_aug_transforms(image) - params = random_aug_transforms.params + tforms = get_random_augmentations(target_image_size) elif level == 2: - transformed = random_aug_transforms_medium(image) - params = random_aug_transforms_medium.params + tforms = get_random_augmentations_medium(target_image_size) else: # level == 3 - transformed = random_aug_transforms_hard(image) - params = random_aug_transforms_hard.params + tforms = get_random_augmentations_hard(target_image_size) + + transformed = tforms(image) - return transformed, level, params \ No newline at end of file + return transformed, level, tforms.params diff --git a/bitmind/utils/mock.py b/bitmind/utils/mock.py index 110f6cad..bfb6639a 100644 --- a/bitmind/utils/mock.py +++ b/bitmind/utils/mock.py @@ -6,7 +6,7 @@ from typing import List from PIL import Image -from bitmind.constants import DIFFUSER_NAMES +from bitmind.validator.config import T2VIS_MODEL_NAMES as MODEL_NAMES from bitmind.validator.miner_performance_tracker import MinerPerformanceTracker @@ -43,17 +43,17 @@ def sample(self, k=1): return [self.__getitem__(i) for i in range(k)], [i for i in range(k)] -class MockSyntheticImageGenerator: - def __init__(self, prompt_type, use_random_diffuser, diffuser_name): +class MockSyntheticDataGenerator: + def __init__(self, prompt_type, use_random_t2v_model, t2v_model_name): self.prompt_type = prompt_type - self.diffuser_name = diffuser_name - self.use_random_diffuser = use_random_diffuser + self.t2v_model_name = t2v_model_name + self.use_random_t2v_model = use_random_t2v_model - def generate(self, k=1, real_images=None): - if self.use_random_diffuser: - self.load_diffuser('random') + def generate(self, k=1, real_images=None, modality='image'): + if self.use_random_t2v_model: + self.load_t2v_model('random') else: - self.load_diffuser(self.diffuser_name) + self.load_t2v_model(self.t2v_model_name) return [{ 'prompt': f'mock {self.prompt_type} prompt', @@ -61,13 +61,13 @@ def generate(self, k=1, real_images=None): 'id': i } for i in range(k)] - def load_diffuser(self, diffuser_name) -> None: + def load_diffuser(self, t2v_model_name) -> None: """ loads a huggingface diffuser model. """ - if diffuser_name == 'random': - diffuser_name = np.random.choice(DIFFUSER_NAMES, 1)[0] - self.diffuser_name = diffuser_name + if t2v_model_name == 'random': + t2v_model_name = np.random.choice(MODEL_NAMES, 1)[0] + self.t2v_model_name = t2v_model_name class MockValidator: @@ -90,7 +90,7 @@ def __init__(self, config): False) for i in range(3) ] - self.synthetic_image_generator = MockSyntheticImageGenerator( + self.synthetic_data_generator = MockSyntheticDataGenerator( prompt_type='annotation', use_random_diffuser=True, diffuser_name=None) self.total_real_images = sum([len(ds) for ds in self.real_image_datasets]) self.scores = np.zeros(self.metagraph.n, dtype=np.float32) diff --git a/bitmind/utils/video_utils.py b/bitmind/utils/video_utils.py new file mode 100644 index 00000000..ffb2a7ff --- /dev/null +++ b/bitmind/utils/video_utils.py @@ -0,0 +1,26 @@ +import torch + + +def pad_frames(x, divisible_by): + """ + Pads the tensor `x` along the frame dimension (1) until the number of frames is divisible by `divisible_by`. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_frames, channels, height, width). + divisible_by (int): The divisor to make the number of frames divisible by. + + Returns: + torch.Tensor: Padded tensor of shape (batch_size, adjusted_num_frames, channels, height, width). + """ + num_frames = x.shape[1] + frame_padding = (divisible_by - (num_frames % divisible_by)) % divisible_by + + if frame_padding > 0: + padding_shape = (x.shape[0], frame_padding, x.shape[2], x.shape[3], x.shape[4]) + x_padding = torch.zeros(padding_shape, device=x.device) # Ensure padding is on the same device + x = torch.cat((x, x_padding), dim=1) + + assert x.shape[1] % divisible_by == 0, ( + f'Frame number mismatch: got {x.shape[1]} frames, not divisible by {divisible_by}.' + ) + return x \ No newline at end of file diff --git a/bitmind/validator/__init__.py b/bitmind/validator/__init__.py index 0b7ddf1a..e69de29b 100644 --- a/bitmind/validator/__init__.py +++ b/bitmind/validator/__init__.py @@ -1,2 +0,0 @@ -from .forward import forward -from .reward import get_rewards diff --git a/bitmind/validator/cache/__init__.py b/bitmind/validator/cache/__init__.py new file mode 100644 index 00000000..8858fff1 --- /dev/null +++ b/bitmind/validator/cache/__init__.py @@ -0,0 +1,3 @@ +from .base_cache import BaseCache +from .image_cache import ImageCache +from .video_cache import VideoCache diff --git a/bitmind/validator/cache/base_cache.py b/bitmind/validator/cache/base_cache.py new file mode 100644 index 00000000..10b714bd --- /dev/null +++ b/bitmind/validator/cache/base_cache.py @@ -0,0 +1,261 @@ +from abc import ABC, abstractmethod +import asyncio +from datetime import datetime +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Union + +import bittensor as bt +import huggingface_hub as hf_hub +import numpy as np + +from .util import get_most_recent_update_time, seconds_to_str +from .download import download_files, list_hf_files + + +class BaseCache(ABC): + """ + Abstract base class for managing file caches with compressed sources. + + This class provides the basic infrastructure for maintaining both a compressed + source cache and an extracted cache, with automatic refresh intervals and + background update tasks. + """ + + def __init__( + self, + cache_dir: Union[str, Path], + file_extensions: List[str], + compressed_file_extension: str, + datasets: dict = None, + extracted_update_interval: int = 4, + compressed_update_interval: int = 12, + num_sources_per_dataset: int = 1, + max_compressed_size_gb: float = 100.0, + max_extracted_size_gb: float = 10.0, + ) -> None: + """ + Initialize the base cache infrastructure. + + Args: + cache_dir: Path to store extracted files + extracted_update_interval: Hours between extracted cache updates + compressed_update_interval: Hours between compressed cache updates + file_extensions: List of valid file extensions for this cache type + max_compressed_size_gb: Maximum size in GB for compressed cache directory + max_extracted_size_gb: Maximum size in GB for extracted cache directory + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True, parents=True) + + self.compressed_dir = self.cache_dir / 'sources' + self.compressed_dir.mkdir(exist_ok=True, parents=True) + + self.datasets = datasets + + self.extracted_update_interval = extracted_update_interval * 60 * 60 + self.compressed_update_interval = compressed_update_interval * 60 * 60 + self.num_sources_per_dataset = num_sources_per_dataset + self.file_extensions = file_extensions + self.compressed_file_extension = compressed_file_extension + self.max_compressed_size_bytes = max_compressed_size_gb * 1024 * 1024 * 1024 + self.max_extracted_size_bytes = max_extracted_size_gb * 1024 * 1024 * 1024 + + def start_updater(self): + """Start the background updater tasks for compressed and extracted caches.""" + if not self.datasets: + bt.logging.error("No datasets configured. Cannot start cache updater.") + return + + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.get_event_loop() + + # Initialize caches, blocking to ensure data are available for validator + bt.logging.info(f"Setting up cache at {self.cache_dir}") + bt.logging.info(f"Clearing incomplete sources in {self.compressed_dir}") + self._clear_incomplete_sources() + + if self._extracted_cache_empty(): + if self._compressed_cache_empty(): + bt.logging.info(f"Compressed cache {self.compressed_dir} empty; populating") + # grab 1 zip to ensure validator has available data + self._refresh_compressed_cache(n_sources_per_dataset=1, n_datasets=1) + + bt.logging.info(f"Extracted cache {self.cache_dir} empty; populating") + self._refresh_extracted_cache() + + # Start background tasks + bt.logging.info(f"Starting background tasks") + self._compressed_updater_task = self.loop.create_task( + self._run_compressed_updater() + ) + self._extracted_updater_task = self.loop.create_task( + self._run_extracted_updater() + ) + + def _get_cached_files(self) -> List[Path]: + """Get list of all extracted files in cache directory.""" + return [ + f for f in self.cache_dir.iterdir() + if f.is_file() and f.suffix.lower() in self.file_extensions + ] + + def _get_compressed_files(self) -> List[Path]: + """Get list of all compressed files in compressed directory.""" + return list(self.compressed_dir.glob(f'*{self.compressed_file_extension}')) + + def _extracted_cache_empty(self) -> bool: + """Check if extracted cache directory is empty.""" + return len(self._get_cached_files()) == 0 + + def _compressed_cache_empty(self) -> bool: + """Check if compressed cache directory is empty.""" + return len(self._get_compressed_files()) == 0 + + def _prune_compressed_cache(self) -> None: + """Check compressed cache size and remove oldest files if over limit.""" + files = self._get_compressed_files() + total_size = sum(f.stat().st_size for f in files) + bt.logging.info(f"Compressed cache size: {len(files)} files | {total_size / (1024*1024*1024):.4f} GB [{self.compressed_dir}]") + while total_size > self.max_compressed_size_bytes: + compressed_files = self._get_compressed_files() + if not compressed_files: + break + + oldest_file = min(compressed_files, key=lambda f: f.stat().st_mtime) + file_size = oldest_file.stat().st_size + + oldest_file.unlink() + total_size -= file_size + bt.logging.info(f"Removed {oldest_file.name} to stay under size limit - new cache size is {total_size / (1024*1024*1024):.4f} GB") + + def _prune_extracted_cache(self) -> None: + """Check extracted cache size and remove oldest files if over limit.""" + files = self._get_cached_files() + total_size = sum(f.stat().st_size for f in files) + bt.logging.info(f"Extracted cache size: {len(files)} files | {total_size / (1024*1024*1024):.2f} GB [{self.cache_dir}]") + while total_size > self.max_extracted_size_bytes: + extracted_files = self._get_cached_files() + if not extracted_files: + break + + oldest_file = min(extracted_files, key=lambda f: f.stat().st_mtime) + file_size = oldest_file.stat().st_size + + oldest_file.unlink() + json_file = oldest_file.with_suffix('.json') + if json_file.exists(): + json_file.unlink() + total_size -= file_size + bt.logging.info(f"Removed {oldest_file.name} to stay under size limit - new cache size is {total_size / (1024*1024*1024):.4f} GB") + + async def _run_extracted_updater(self) -> None: + """Asynchronously refresh extracted files according to update interval.""" + while True: + try: + self._prune_extracted_cache() + last_update = get_most_recent_update_time(self.cache_dir) + time_elapsed = time.time() - last_update + + if time_elapsed >= self.extracted_update_interval: + bt.logging.info(f"Refreshing cache [{self.cache_dir}]") + self._refresh_extracted_cache() + bt.logging.info(f"Cache refresh complete [{self.cache_dir}]") + + sleep_time = max(0, self.extracted_update_interval - time_elapsed) + bt.logging.info(f"Next cache refresh in {seconds_to_str(sleep_time)} [{self.compressed_dir}]") + await asyncio.sleep(sleep_time) + except Exception as e: + bt.logging.error(f"Error in extracted cache update: {e}") + await asyncio.sleep(60) + + async def _run_compressed_updater(self) -> None: + """Asynchronously refresh compressed files according to update interval.""" + while True: + try: + self._clear_incomplete_sources() + self._prune_compressed_cache() + last_update = get_most_recent_update_time(self.compressed_dir) + time_elapsed = time.time() - last_update + + if time_elapsed >= self.compressed_update_interval: + bt.logging.info(f"Refreshing cache [{self.compressed_dir}]") + self._refresh_compressed_cache() + bt.logging.info(f"Cache refresh complete [{self.cache_dir}]") + + sleep_time = max(0, self.compressed_update_interval - time_elapsed) + bt.logging.info(f"Next cache refresh in {seconds_to_str(sleep_time)} [{self.compressed_dir}]") + await asyncio.sleep(sleep_time) + except Exception as e: + bt.logging.error(f"Error in compressed cache update: {e}") + await asyncio.sleep(60) + + def _refresh_compressed_cache( + self, + n_sources_per_dataset: Optional[int] = None, + n_datasets: Optional[int] = None + ) -> None: + """ + Refresh the compressed file cache with new downloads. + """ + if n_sources_per_dataset is None: + n_sources_per_dataset = self.num_sources_per_dataset + + try: + bt.logging.info(f"{len(self._get_compressed_files())} compressed sources currently cached") + + new_files: List[Path] = [] + for dataset in self.datasets[:n_datasets]: + filenames = list_hf_files( + repo_id=dataset['path'], + extension=self.compressed_file_extension) + remote_paths = [ + f"https://huggingface.co/datasets/{dataset['path']}/resolve/main/{f}" + for f in filenames + ] + bt.logging.info(f"Downloading {n_sources_per_dataset} from {dataset['path']} to {self.compressed_dir}") + new_files += download_files( + urls=np.random.choice(remote_paths, n_sources_per_dataset), + output_dir=self.compressed_dir) + + if new_files: + bt.logging.info(f"{len(new_files)} new files added to {self.compressed_dir}") + else: + bt.logging.error(f"No new files were added to {self.compressed_dir}") + + except Exception as e: + bt.logging.error(f"Error during compressed refresh for {self.compressed_dir}: {e}") + raise + + def _refresh_extracted_cache(self, n_items_per_source: Optional[int] = None) -> None: + """Refresh the extracted cache with new selections.""" + bt.logging.info(f"{len(self._get_compressed_files())} files currently cached") + new_files = self._extract_random_items(n_items_per_source) + if new_files: + bt.logging.info(f"{len(new_files)} new files added to {self.cache_dir}") + else: + bt.logging.error(f"No new files were added to {self.cache_dir}") + + @abstractmethod + def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]: + """Remove any incomplete or corrupted source files from cache.""" + pass + + @abstractmethod + def _clear_incomplete_sources(self) -> None: + """Remove any incomplete or corrupted source files from cache.""" + pass + + @abstractmethod + def sample(self, num_samples: int) -> Optional[Dict[str, Any]]: + """Sample random items from the cache.""" + pass + + def __del__(self) -> None: + """Cleanup background tasks on deletion.""" + if hasattr(self, '_extracted_updater_task'): + self._extracted_updater_task.cancel() + if hasattr(self, '_compressed_updater_task'): + self._compressed_updater_task.cancel() diff --git a/bitmind/validator/cache/download.py b/bitmind/validator/cache/download.py new file mode 100644 index 00000000..b5d45978 --- /dev/null +++ b/bitmind/validator/cache/download.py @@ -0,0 +1,164 @@ +import requests +import os +from pathlib import Path +from requests.exceptions import RequestException +from typing import List, Union, Dict, Optional + +import bittensor as bt +import huggingface_hub as hf_hub + + +def download_files( + urls: List[str], + output_dir: Union[str, Path], + chunk_size: int = 8192 +) -> List[Path]: + """ + Downloads multiple files synchronously. + + Args: + urls: List of URLs to download + output_dir: Directory to save the files + chunk_size: Size of chunks to download at a time + + Returns: + List of successfully downloaded file paths + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + downloaded_files = [] + + for url in urls: + try: + bt.logging.info(f'Downloading {url}') + response = requests.get(url, stream=True) + if response.status_code != 200: + bt.logging.error(f'Failed to download {url}: Status {response.status_code}') + continue + + filename = os.path.basename(url) + filepath = output_dir / filename + + bt.logging.info(f'Writing to {filepath}') + with open(filepath, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: # filter out keep-alive chunks + f.write(chunk) + + downloaded_files.append(filepath) + bt.logging.info(f'Successfully downloaded {filename}') + + except Exception as e: + bt.logging.error(f'Error downloading {url}: {str(e)}') + continue + + return downloaded_files + + +def list_hf_files(repo_id, repo_type='dataset', extension=None): + files = [] + try: + files = list(hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type)) + if extension: + files = [f for f in files if f.endswith(extension)] + except Exception as e: + bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}") + return files + + +def openvid1m_err_handler( + base_zip_url: str, + output_path: Path, + part_index: int, + chunk_size: int = 8192, + timeout: int = 300 +) -> Optional[Path]: + """ + Synchronous error handler for OpenVid1M downloads that handles split files. + + Args: + base_zip_url: Base URL for the zip parts + output_path: Directory to save files + part_index: Index of the part to download + chunk_size: Size of download chunks + timeout: Download timeout in seconds + + Returns: + Path to combined file if successful, None otherwise + """ + part_urls = [ + f"{base_zip_url}{part_index}_partaa", + f"{base_zip_url}{part_index}_partab" + ] + error_log_path = output_path / "download_log.txt" + downloaded_parts = [] + + # Download each part + for part_url in part_urls: + part_file_path = output_path / Path(part_url).name + + if part_file_path.exists(): + bt.logging.warning(f"File {part_file_path} exists.") + downloaded_parts.append(part_file_path) + continue + + try: + response = requests.get(part_url, stream=True, timeout=timeout) + if response.status_code != 200: + raise RequestException( + f"HTTP {response.status_code}: {response.reason}" + ) + + with open(part_file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: # filter out keep-alive chunks + f.write(chunk) + + bt.logging.info(f"File {part_url} saved to {part_file_path}") + downloaded_parts.append(part_file_path) + + except Exception as e: + error_message = f"File {part_url} download failed: {str(e)}\n" + bt.logging.error(error_message) + with open(error_log_path, "a") as error_log_file: + error_log_file.write(error_message) + return None + + if len(downloaded_parts) == len(part_urls): + try: + combined_file = output_path / f"OpenVid_part{part_index}.zip" + combined_data = bytearray() + for part_path in downloaded_parts: + with open(part_path, 'rb') as part_file: + combined_data.extend(part_file.read()) + + with open(combined_file, 'wb') as out_file: + out_file.write(combined_data) + + for part_path in downloaded_parts: + part_path.unlink() + + bt.logging.info(f"Successfully combined parts into {combined_file}") + return combined_file + + except Exception as e: + error_message = f"Failed to combine parts for index {part_index}: {str(e)}\n" + bt.logging.error(error_message) + with open(error_log_path, "a") as error_log_file: + error_log_file.write(error_message) + return None + + return None + + """ +data_folder = output_path / "data" / "train" +data_folder.mkdir(parents=True, exist_ok=True) +data_urls = [ + "https://huggingface.co/datasets/nkp37/OpenVid-1M/resolve/main/data/train/OpenVid-1M.csv", + "https://huggingface.co/datasets/nkp37/OpenVid-1M/resolve/main/data/train/OpenVidHD.csv" +] +for data_url in data_urls: + data_path = data_folder / Path(data_url).name + command = ["wget", "-O", str(data_path), data_url] + subprocess.run(command, check=True) +""" diff --git a/bitmind/validator/cache/extract.py b/bitmind/validator/cache/extract.py new file mode 100644 index 00000000..8dbb29cf --- /dev/null +++ b/bitmind/validator/cache/extract.py @@ -0,0 +1,197 @@ +import base64 +import hashlib +import json +import logging +import mimetypes +import os +import random +import warnings +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple +from zipfile import ZipFile + +from PIL import Image +import pyarrow.parquet as pq +import bittensor as bt + + +def extract_videos_from_zip( + zip_path: Path, + dest_dir: Path, + num_videos: int, + file_extensions: Set[str] = {'.mp4', '.avi', '.mov', '.mkv', '.wmv'}, + include_checksums: bool = True +) -> List[Tuple[str, str]]: + """ + Extract random videos and their metadata from a zip file and save them to disk. +q + Args: + zip_path: Path to the zip file + dest_dir: Directory to save videos and metadata + num_videos: Number of videos to extract + file_extensions: Set of valid video file extensions + include_checksums: Whether to calculate and include file checksums in metadata + + Returns: + List of tuples containing (video_path, metadata_path) + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + + extracted_files = [] + try: + with ZipFile(zip_path) as zip_file: + video_files = [ + f for f in zip_file.namelist() + if any(f.lower().endswith(ext) for ext in file_extensions) + ] + if not video_files: + bt.logging.warning(f"No video files found in {zip_path}") + return extracted_files + + bt.logging.info(f"{len(video_files)} video files found in {zip_path}") + selected_videos = random.sample( + video_files, + min(num_videos, len(video_files)) + ) + + bt.logging.info(f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}") + for idx, video in enumerate(selected_videos): + try: + zip_basename = zip_path.name.split('.zip')[0] + original_filename = Path(video).name + base_filename = f"{zip_basename}__{idx}_{original_filename}" + + # extract video and get metadata + video_path = dest_dir / base_filename + temp_path = Path(zip_file.extract(video, path=dest_dir)) + temp_path.rename(video_path) + + video_info = zip_file.getinfo(video) + metadata = { + 'source_zip': str(zip_path), + 'original_filename': original_filename, + 'original_path_in_zip': video, + 'extraction_date': datetime.now().isoformat(), + 'file_size': os.path.getsize(video_path), + 'mime_type': mimetypes.guess_type(video_path)[0], + 'zip_metadata': { + 'compress_size': video_info.compress_size, + 'file_size': video_info.file_size, + 'compress_type': video_info.compress_type, + 'date_time': datetime.strftime( + datetime(*video_info.date_time), + '%Y-%m-%d %H:%M:%S' + ), + } + } + + if include_checksums: + with open(video_path, 'rb') as f: + file_data = f.read() + metadata['checksums'] = { + 'md5': hashlib.md5(file_data).hexdigest(), + 'sha256': hashlib.sha256(file_data).hexdigest() + } + + metadata_filename = f"{video_path.stem}.json" + metadata_path = dest_dir / metadata_filename + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + extracted_files.append((str(video_path), str(metadata_path))) + logging.info(f"Extracted {original_filename} from {zip_path}") + + except Exception as e: + bt.logging.warning(f"Error extracting {video}: {e}") + if 'temp_path' in locals() and temp_path.exists(): + temp_path.unlink() + continue + + except Exception as e: + bt.logging.warning(f"Error processing zip file {zip_path}: {e}") + + return extracted_files + + +def extract_images_from_parquet( + parquet_path: Path, + dest_dir: Path, + num_images: int, + seed: Optional[int] = None +) -> List[Tuple[str, str]]: + """ + Extract random images and their metadata from a parquet file and save them to disk. + + Args: + parquet_path: Path to the parquet file + dest_dir: Directory to save images and metadata + num_images: Number of images to extract + columns: Specific columns to include in metadata + seed: Random seed for sampling + + Returns: + List of tuples containing (image_path, metadata_path) + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + + # read parquet file, sample random image rows + table = pq.read_table(parquet_path) + df = table.to_pandas() + sample_df = df.sample(n=min(num_images, len(df)), random_state=seed) + image_col = next((col for col in sample_df.columns if 'image' in col.lower()), None) + metadata_cols = [c for c in sample_df.columns if c != image_col] + + saved_files = [] + parquet_prefix = parquet_path.stem + for idx, row in sample_df.iterrows(): + try: + img_data = row[image_col] + if isinstance(img_data, dict): + key = next((k for k in img_data if 'bytes' in k.lower() or 'image' in k.lower()), None) + img_data = img_data[key] + + try: + img = Image.open(BytesIO(img_data)) + except Exception as e: + img_data = base64.b64decode(img_data) + img = Image.open(BytesIO(img_data)) + + base_filename = f"{parquet_prefix}__image_{idx}" + image_format = img.format.lower() if img.format else 'png' + img_filename = f"{base_filename}.{image_format}" + img_path = dest_dir / img_filename + img.save(img_path) + + metadata = { + 'source_parquet': str(parquet_path), + 'original_index': str(idx), + 'image_format': image_format, + 'image_size': img.size, + 'image_mode': img.mode + } + + for col in metadata_cols: + # Convert any non-serializable types to strings + try: + json.dumps({col: row[col]}) + metadata[col] = row[col] + except (TypeError, OverflowError): + metadata[col] = str(row[col]) + + metadata_filename = f"{base_filename}.json" + metadata_path = dest_dir / metadata_filename + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + saved_files.append(str(img_path)) + + except Exception as e: + warnings.warn(f"Failed to extract/save image {idx}: {e}") + continue + + return saved_files \ No newline at end of file diff --git a/bitmind/validator/cache/image_cache.py b/bitmind/validator/cache/image_cache.py new file mode 100644 index 00000000..2d583373 --- /dev/null +++ b/bitmind/validator/cache/image_cache.py @@ -0,0 +1,137 @@ +import os +import json +import random +from pathlib import Path +from typing import Dict, List, Optional, Union, Any + +import bittensor as bt +from PIL import Image + +from .base_cache import BaseCache +from .extract import extract_images_from_parquet +from .util import is_parquet_complete + + +class ImageCache(BaseCache): + """ + A class to manage image caching from parquet files. + + This class handles the caching, updating, and sampling of images stored + in parquet files. It maintains both a compressed cache of parquet files + and an extracted cache of images ready for processing. + """ + + def __init__( + self, + cache_dir: Union[str, Path], + datasets: Optional[dict] = None, + parquet_update_interval: int = 6, + image_update_interval: int = 1, + num_parquets_per_dataset: int = 5, + num_images_per_source: int = 100, + max_compressed_size_gb: int = 100, + max_extracted_size_gb: int = 10 + ) -> None: + """ + Args: + cache_dir: Path to store extracted images + parquet_update_interval: Hours between parquet cache updates + image_update_interval: Hours between image cache updates + num_images_per_source: Number of images to extract per parquet + """ + super().__init__( + cache_dir=cache_dir, + datasets=datasets, + extracted_update_interval=image_update_interval, + compressed_update_interval=parquet_update_interval, + num_sources_per_dataset=num_parquets_per_dataset, + file_extensions=['.jpg', '.jpeg', '.png'], + compressed_file_extension='.parquet', + max_compressed_size_gb=max_compressed_size_gb, + max_extracted_size_gb=max_extracted_size_gb + ) + self.num_images_per_source = num_images_per_source + + def _clear_incomplete_sources(self) -> None: + """Remove any incomplete or corrupted parquet files.""" + for path in self._get_compressed_files(): + if path.suffix == '.parquet' and not is_parquet_complete(path): + try: + path.unlink() + bt.logging.warning(f"Removed incomplete parquet file {path}") + except Exception as e: + bt.logging.error(f"Error removing incomplete parquet {path}: {e}") + + def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]: + """ + Extract random videos from zip files in compressed directory. + + Returns: + List of paths to extracted video files. + """ + if n_items_per_source is None: + n_items_per_source = self.num_images_per_source + + extracted_files = [] + parquet_files = self._get_compressed_files() + if not parquet_files: + bt.logging.warning(f"No parquet files found in {self.compressed_dir}") + return extracted_files + + for parquet_file in parquet_files: + try: + extracted_files += extract_images_from_parquet( + parquet_file, + self.cache_dir, + n_items_per_source + ) + except Exception as e: + bt.logging.error(f"Error processing parquet file {parquet_file}: {e}") + return extracted_files + + def sample(self, remove_from_cache=False) -> Optional[Dict[str, Any]]: + """ + Sample a random image and its metadata from the cache. + + Returns: + Dictionary containing: + - image: PIL Image + - path: Path to source file + - dataset: Source dataset name + - metadata: Metadata dict + Returns None if no valid image is available. + """ + cached_files = self._get_cached_files() + if not cached_files: + bt.logging.warning("No images available in cache") + return None + + attempts = 0 + max_attempts = len(cached_files) * 2 + + while attempts < max_attempts: + attempts += 1 + image_path = random.choice(cached_files) + + try: + image = Image.open(image_path) + metadata = json.loads(image_path.with_suffix('.json').read_text()) + if remove_from_cache: + try: + os.remove(image_path) + os.remove(image_path.with_suffix('.json')) + except Exception as e: + bt.logging.warning(f"Failed to remove files for {image_path}: {e}") + return { + 'image': image, + 'path': str(image_path), + 'dataset': metadata.get('dataset', None), + 'index': metadata.get('index', None) + } + + except Exception as e: + bt.logging.warning(f"Failed to load image {image_path}: {e}") + continue + + bt.logging.warning(f"Failed to find valid image after {attempts} attempts") + return None diff --git a/bitmind/validator/cache/util.py b/bitmind/validator/cache/util.py new file mode 100644 index 00000000..d429db48 --- /dev/null +++ b/bitmind/validator/cache/util.py @@ -0,0 +1,77 @@ +from pathlib import Path +from typing import Union, Callable +from zipfile import ZipFile, BadZipFile +from enum import Enum, auto +import asyncio +import pyarrow.parquet as pq +import bittensor as bt + + +def seconds_to_str(seconds): + seconds = int(float(seconds)) + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + seconds = seconds % 60 + return f"{hours:02}:{minutes:02}:{seconds:02}" + + +def get_most_recent_update_time(directory: Path) -> float: + """Get the most recent modification time of any file in directory.""" + try: + mtimes = [f.stat().st_mtime for f in directory.iterdir()] + return max(mtimes) if mtimes else 0 + except Exception as e: + bt.logging.error(f"Error getting modification times: {e}") + return 0 + + +class FileType(Enum): + PARQUET = auto() + ZIP = auto() + + +def get_integrity_check(file_type: FileType) -> Callable[[Path], bool]: + """Returns the appropriate validation function for the file type.""" + if file_type == FileType.PARQUET: + return is_parquet_complete + elif file_type == FileType.ZIP: + return is_zip_complete + raise ValueError(f"Unsupported file type: {file_type}") + + +def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool: + """ + Args: + zip_path: Path to zip file + testzip: More thorough, less efficient + Returns: + bool: True if zip is valid, False otherwise + """ + try: + with ZipFile(zip_path) as zf: + if testzip: + zf.testzip() + else: + zf.namelist() + return True + except (BadZipFile, Exception) as e: + bt.logging.error(f"Zip file {zip_path} is invalid: {e}") + return False + + +def is_parquet_complete(path: Path) -> bool: + """ + Args: + path: Path to the parquet file + + Returns: + bool: True if file is valid, False otherwise + """ + try: + with open(path, 'rb') as f: + pq.read_metadata(f) + return True + except Exception as e: + bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}") + return False + diff --git a/bitmind/validator/cache/video_cache.py b/bitmind/validator/cache/video_cache.py new file mode 100644 index 00000000..4df85c66 --- /dev/null +++ b/bitmind/validator/cache/video_cache.py @@ -0,0 +1,212 @@ +import os +import random +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Union + +import bittensor as bt +import ffmpeg +from PIL import Image + +from .base_cache import BaseCache +from .extract import extract_videos_from_zip +from .util import is_zip_complete +from bitmind.validator.video_utils import get_video_duration + + +class VideoCache(BaseCache): + """ + A class to manage video caching and processing operations. + + This class handles the caching, updating, and sampling of video files from + compressed archives and optionally YouTube. It maintains both a compressed + cache of source files and an extracted cache of video files ready for processing. + """ + + def __init__( + self, + cache_dir: Union[str, Path], + datasets: Optional[dict] = None, + video_update_interval: int = 1, + zip_update_interval: int = 6, + num_zips_per_dataset: int = 1, + num_videos_per_zip: int = 10, + max_compressed_size_gb: int = 100, + max_extracted_size_gb: int = 10 + ) -> None: + """ + Initialize the VideoCache. + + Args: + cache_dir: Path to store extracted video files + video_update_interval: Hours between video cache updates + zip_update_interval: Hours between zip cache updates + num_videos_per_source: Number of videos to extract per source + use_youtube: Whether to include YouTube videos + """ + super().__init__( + cache_dir=cache_dir, + datasets=datasets, + extracted_update_interval=video_update_interval, + compressed_update_interval=zip_update_interval, + num_sources_per_dataset=num_zips_per_dataset, + file_extensions=['.mp4', '.avi', '.mov', '.mkv'], + compressed_file_extension='.zip', + max_compressed_size_gb=max_compressed_size_gb, + max_extracted_size_gb=max_extracted_size_gb + ) + self.num_videos_per_zip = num_videos_per_zip + + def _clear_incomplete_sources(self) -> None: + """Remove any incomplete or corrupted zip files from cache.""" + for path in self._get_compressed_files(): + if path.suffix == '.zip' and not is_zip_complete(path): + try: + path.unlink() + bt.logging.warning(f"Removed incomplete zip file {path}") + except Exception as e: + bt.logging.error(f"Error removing incomplete zip {path}: {e}") + + def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]: + """ + Extract random videos from zip files in compressed directory. + + Returns: + List of paths to extracted video files. + """ + if n_items_per_source is None: + n_items_per_source = self.num_videos_per_zip + + extracted_files = [] + zip_files = self._get_compressed_files() + if not zip_files: + bt.logging.warning(f"No zip files found in {self.compressed_dir}") + return extracted_files + + for zip_file in zip_files: + try: + extracted_files += extract_videos_from_zip( + zip_file, + self.cache_dir, + n_items_per_source) + except Exception as e: + bt.logging.error(f"Error processing zip file {zip_file}: {e}") + + return extracted_files + + def sample( + self, + num_frames: int = 6, + fps: Optional[float] = None, + min_fps: Optional[float] = None, + max_fps: Optional[float] = None, + remove_from_cache: bool = False + ) -> Optional[Dict[str, Union[List[Image.Image], str, float]]]: + """ + Sample random frames from a random video in the cache. + + Args: + num_frames: Number of consecutive frames to sample + fps: Fixed frames per second to sample. Mutually exclusive with min_fps/max_fps. + min_fps: Minimum frames per second when auto-calculating fps. Must be used with max_fps. + max_fps: Maximum frames per second when auto-calculating fps. Must be used with min_fps. + + Returns: + Dictionary containing: + - video: List of sampled video frames as PIL Images + - path: Path to source video file + - dataset: Name of source dataset + - total_duration: Total video duration in seconds + - sampled_length: Number of seconds sampled + Returns None if no videos are available or extraction fails. + """ + if fps is not None and (min_fps is not None or max_fps is not None): + raise ValueError("Cannot specify both fps and min_fps/max_fps") + if (min_fps is None) != (max_fps is None): + raise ValueError("min_fps and max_fps must be specified together") + + video_files = self._get_cached_files() + if not video_files: + bt.logging.warning("No videos available in cache") + return None + + video_path = random.choice(video_files) + if not Path(video_path).exists(): + bt.logging.error(f"Selected video {video_path} not found") + return None + + duration = get_video_duration(str(video_path)) + + # Use fixed fps if provided, otherwise calculate from range + frame_rate = fps + if frame_rate is None: + # For very short videos (< 1 second), use max_fps to capture detail + if duration <= 1.0: + frame_rate = max_fps + else: + # For longer videos, scale fps inversely with duration + # This ensures we don't span too much of longer videos + # while still capturing enough detail in shorter ones + target_duration = min(2.0, duration * 0.2) # Cap at 2 seconds or 20% of duration + frame_rate = (num_frames - 1) / target_duration + frame_rate = max(min_fps, min(frame_rate, max_fps)) + + sample_duration = (num_frames - 1) / frame_rate + start_time = random.uniform(0, max(0, duration - sample_duration)) + frames: List[Image.Image] = [] + + #bt.logging.info(f'Extracting {num_frames} frames at {frame_rate}fps starting at {start_time:.2f}s') + + for i in range(num_frames): + timestamp = start_time + (i / frame_rate) + + try: + # extract frames + out_bytes, err = ( + ffmpeg + .input(str(video_path), ss=str(timestamp)) + .filter('select', 'eq(n,0)') + .output( + 'pipe:', + vframes=1, + format='image2', + vcodec='png', + loglevel='error' # silence ffmpeg output + ) + .run(capture_stdout=True, capture_stderr=True) + ) + + if not out_bytes: + bt.logging.error(f'No data received for frame at {timestamp}s; Error: {err}') + continue + + try: + frame = Image.open(BytesIO(out_bytes)) + frame.load() # Verify image can be loaded + frames.append(frame) + bt.logging.debug(f'Successfully extracted frame at {timestamp}s') + except Exception as e: + bt.logging.error(f'Failed to process frame at {timestamp}s: {e}') + continue + + except ffmpeg.Error as e: + bt.logging.error(f'FFmpeg error at {timestamp}s: {e.stderr.decode()}') + continue + + if remove_from_cache: + try: + os.remove(video_path) + os.remove(video_path.with_suffix('.json')) + except Exception as e: + bt.logging.warning(f"Failed to remove files for {video_path}: {e}") + + bt.logging.success(f"Sampled {len(frames)} frames at {frame_rate}fps") + return { + 'video': frames, + 'fps': frame_rate, + 'num_frames': num_frames, + 'path': str(video_path), + 'dataset': str(Path(video_path).name.split('_')[0]), + 'total_duration': duration, + 'sampled_length': sample_duration + } diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py new file mode 100644 index 00000000..d5d7e670 --- /dev/null +++ b/bitmind/validator/config.py @@ -0,0 +1,236 @@ +from pathlib import Path +from typing import Dict, List, Union, Optional, Any + +import numpy as np +import torch +from diffusers import ( + StableDiffusionPipeline, + StableDiffusionXLPipeline, + FluxPipeline, + CogVideoXPipeline, + MochiPipeline, + AnimateDiffPipeline, + EulerDiscreteScheduler +) + +from .model_utils import load_annimatediff_motion_adapter + + +TARGET_IMAGE_SIZE: tuple[int, int] = (256, 256) + +MAINNET_UID = 34 +TESTNET_UID = 168 + +# Project constants +MAINNET_WANDB_PROJECT: str = 'bitmind-subnet' +TESTNET_WANDB_PROJECT: str = 'bitmind' +WANDB_ENTITY: str = 'bitmindai' + +# Cache directories +HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface' +SN34_CACHE_DIR: Path = Path.home() / '.cache' / 'sn34' +REAL_CACHE_DIR: Path = SN34_CACHE_DIR / 'real' +SYNTH_CACHE_DIR: Path = SN34_CACHE_DIR / 'synthetic' +REAL_VIDEO_CACHE_DIR: Path = REAL_CACHE_DIR / 'video' +REAL_IMAGE_CACHE_DIR: Path = REAL_CACHE_DIR / 'image' +SYNTH_VIDEO_CACHE_DIR: Path = SYNTH_CACHE_DIR / 'video' +SYNTH_IMAGE_CACHE_DIR: Path = SYNTH_CACHE_DIR / 'image' +VALIDATOR_INFO_PATH: Path = SN34_CACHE_DIR / 'validator.yaml' +SN34_CACHE_DIR.mkdir(parents=True, exist_ok=True) + +# Update intervals in hours +VIDEO_ZIP_CACHE_UPDATE_INTERVAL = 3 +IMAGE_PARQUET_CACHE_UPDATE_INTERVAL = 2 +VIDEO_CACHE_UPDATE_INTERVAL = 1 +IMAGE_CACHE_UPDATE_INTERVAL = 1 + +MAX_COMPRESSED_GB = 100 +MAX_EXTRACTED_GB = 10 + +CHALLENGE_TYPE = { + 0: 'real', + 1: 'synthetic' +} + +# Image datasets configuration +IMAGE_DATASETS: Dict[str, List[Dict[str, str]]] = { + "real": [ + {"path": "bitmind/bm-real"}, + {"path": "bitmind/open-image-v7-256"}, + {"path": "bitmind/celeb-a-hq"}, + {"path": "bitmind/ffhq-256"}, + {"path": "bitmind/MS-COCO-unique-256"}, + {"path": "bitmind/AFHQ"}, + {"path": "bitmind/lfw"}, + {"path": "bitmind/caltech-256"}, + {"path": "bitmind/caltech-101"}, + {"path": "bitmind/dtd"} + ] +} + +VIDEO_DATASETS = { + "real": [ + { + "path": "nkp37/OpenVid-1M", + "filetype": "zip" + }, + { + "path": "shangxd/imagenet-vidvrd", + "filetype": "zip" + } + ] +} + + +# Prompt generation model configurations +IMAGE_ANNOTATION_MODEL: str = "Salesforce/blip2-opt-6.7b-coco" +TEXT_MODERATION_MODEL: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" + +# Text-to-image model configurations +T2I_MODELS: Dict[str, Dict[str, Any]] = { + "stabilityai/stable-diffusion-xl-base-1.0": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + "variant": "fp16" + }, + "use_autocast": False + }, + "SG161222/RealVisXL_V4.0": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + "variant": "fp16" + } + }, + "Corcelio/mobius": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16 + } + }, + "black-forest-labs/FLUX.1-dev": { + "pipeline_cls": FluxPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.bfloat16, + }, + "generate_args": { + "guidance_scale": 2, + "num_inference_steps": {"min": 50, "max": 125}, + "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"), + "height": [512, 768], + "width": [512, 768] + }, + "enable_model_cpu_offload": False + }, + "prompthero/openjourney-v4" : { + "pipeline_cls": StableDiffusionPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + } + }, + "cagliostrolab/animagine-xl-3.1": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + } + } +} +T2I_MODEL_NAMES: List[str] = list(T2I_MODELS.keys()) + + +# Text-to-video model configurations +T2V_MODELS: Dict[str, Dict[str, Any]] = { + "genmo/mochi-1-preview": { + "pipeline_cls": MochiPipeline, + "from_pretrained_args": { + "variant": "bf16", + "torch_dtype": torch.bfloat16 + }, + "generate_args": { + "num_frames": 84 + }, + #"enable_model_cpu_offload": True, + "vae_enable_tiling": True + }, + 'THUDM/CogVideoX-5b': { + "pipeline_cls": CogVideoXPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.bfloat16 + }, + "generate_args": { + "guidance_scale": 2, + "num_videos_per_prompt": 1, + "num_inference_steps": {"min": 50, "max": 125}, + "num_frames": 48, + }, + "enable_model_cpu_offload": True, + #"enable_sequential_cpu_offload": True, + "vae_enable_slicing": True, + "vae_enable_tiling": True + }, + 'ByteDance/AnimateDiff-Lightning': { + "pipeline_cls": AnimateDiffPipeline, + "from_pretrained_args": { + "base": "emilianJR/epiCRealism", + "torch_dtype": torch.bfloat16, + "motion_adapter": load_annimatediff_motion_adapter() + }, + "generate_args": { + "guidance_scale": 2, + "num_inference_steps": {"min": 50, "max": 125}, + }, + "scheduler": { + "cls": EulerDiscreteScheduler, + "from_config_args": { + "timestep_spacing": "trailing", + "beta_schedule": "linear" + } + } + } +} +T2V_MODEL_NAMES: List[str] = list(T2V_MODELS.keys()) + +# Combined model configurations +T2VIS_MODELS: Dict[str, Dict[str, Any]] = {**T2I_MODELS, **T2V_MODELS} +T2VIS_MODEL_NAMES: List[str] = list(T2VIS_MODELS.keys()) + + +def get_modality(model_name): + if model_name in T2V_MODEL_NAMES: + return 'video' + elif model_name in T2I_MODEL_NAMES: + return 'image' + + +def select_random_t2vis_model(modality: Optional[str] = None) -> str: + """ + Select a random text-to-image or text-to-video model based on the specified + modality. + + Args: + modality: The type of model to select ('image', 'video', or 'random'). + If None or 'random', randomly chooses between image and video. + + Returns: + The name of the selected model. + + Raises: + NotImplementedError: If the specified modality is not supported. + """ + if modality is None or modality == 'random': + modality = np.random.choice(['image', 'video']) + + if modality == 'image': + return np.random.choice(T2I_MODEL_NAMES) + elif modality == 'video': + return np.random.choice(T2V_MODEL_NAMES) + else: + raise NotImplementedError(f"Unsupported modality: {modality}") diff --git a/bitmind/validator/forward.py b/bitmind/validator/forward.py index 0a3fa984..3b78b27c 100644 --- a/bitmind/validator/forward.py +++ b/bitmind/validator/forward.py @@ -1,7 +1,7 @@ # The MIT License (MIT) # Copyright © 2023 Yuma Rao # developer: dubm -# Copyright © 2023 Bitmind +# Copyright © 2023 BitMind # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -17,38 +17,19 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -from PIL import Image -from io import BytesIO -from datetime import datetime -import bittensor as bt -import pandas as pd +import random +import time + import numpy as np -import os +import pandas as pd import wandb +import bittensor as bt +from bitmind.protocol import prepare_synapse +from bitmind.utils.image_transforms import apply_augmentation_by_level from bitmind.utils.uids import get_random_uids -from bitmind.utils.data import sample_dataset_index_name -from bitmind.protocol import prepare_image_synapse +from bitmind.validator.config import CHALLENGE_TYPE, MAINNET_UID, TARGET_IMAGE_SIZE from bitmind.validator.reward import get_rewards -from bitmind.image_transforms import apply_augmentation_by_level - - -def sample_random_real_image(datasets, total_images, retries=10): - random_idx = np.random.randint(0, total_images) - source, idx = sample_real_image(datasets, random_idx) - if source[idx]['image'] is None: - if retries: - return sample_random_real_image(datasets, total_images, retries-1) - return None, None - return source, idx - - -def sample_real_image(datasets, index): - cumulative_sizes = np.cumsum([len(ds) for ds in datasets]) - source_index = np.searchsorted(cumulative_sizes - 1, index % (cumulative_sizes[-1])) - source = datasets[source_index] - valid_index = index - (cumulative_sizes[source_index - 1] if source_index > 0 else 0) - return source, valid_index async def forward(self): @@ -58,115 +39,112 @@ async def forward(self): Steps are: 1. Sample miner UIDs - 2. Get an image. 50/50 chance of: - A. REAL (label = 0): Randomly sample a real image from self.real_image_datasets - B. FAKE (label = 1): Generate a synthetic image with self.random_image_generator + 2. Sample synthetic/real image/video (50/50 chance for each choice) 3. Apply random data augmentation to the image - 4. Base64 encode the image and prepare an ImageSynapse + 4. Encode data and prepare Synapse 5. Query miner axons - 6. Log results, including image and miner responses (soon to be W&B) - 7. Compute rewards and update scores + 6. Compute rewards and update scores Args: self (:obj:`bittensor.neuron.Neuron`): The neuron object which contains all the necessary state for the validator. """ - wandb_data = {} - + challenge_metadata = {} # for bookkeeping + challenge = {} # for querying miners + + modality = 'video' if np.random.rand() > 0.5 else 'image' + label = 0 if np.random.rand() > self._fake_prob else 1 + challenge_metadata['label'] = label + challenge_metadata['modality'] = modality + + bt.logging.info(f"Sampling data from {modality} cache") + cache = self.media_cache[CHALLENGE_TYPE[label]][modality] + + if modality == 'video': + num_frames = random.randint( + self.config.neuron.clip_frames_min, + self.config.neuron.clip_frames_max) + challenge = cache.sample(num_frames, min_fps=8, max_fps=30) + + elif modality == 'image': + challenge = cache.sample() + + if challenge is None: + bt.logging.warning("Waiting for cache to populate. Challenge skipped.") + return + + # prepare metadata for logging + if modality == 'video': + video_arr = np.stack([np.array(img) for img in challenge['video']], axis=0) + challenge_metadata['video'] = wandb.Video(video_arr, fps=1) + challenge_metadata['fps'] = challenge['fps'] + challenge_metadata['num_frames'] = challenge['num_frames'] + elif modality == 'image': + challenge_metadata['image'] = wandb.Image(challenge['image']) + + # update logging dict with everything except image/video data + challenge_metadata.update({k: v for k, v in challenge.items() if k != modality}) + input_data = challenge[modality] # extract video or image + + # apply data augmentation pipeline + try: + input_data, level, data_aug_params = apply_augmentation_by_level(input_data, TARGET_IMAGE_SIZE) + except Exception as e: + level, data_aug_params = -1, {} + bt.logging.error(f"Unable to applay augmentations: {e}") + + challenge_metadata['data_aug_params'] = data_aug_params + challenge_metadata['data_aug_level'] = level + + # sample miner uids for challenge miner_uids = get_random_uids(self, k=self.config.neuron.sample_size) - bt.logging.info("Generating challenge") - if np.random.rand() > self._fake_prob: - label = 0 - source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images) - wandb_data['source_dataset'] = source_dataset.huggingface_dataset_name - wandb_data['source_image_index'] = local_index - sample = source_dataset[local_index] - - else: - label = 1 - if self.config.neuron.prompt_type == 'annotation': - retries = 10 - while retries > 0: - retries -= 1 - source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images) - source_sample = source_dataset[local_index] - source_image = source_sample['image'] - if source_image is None: - continue - - # generate captions for the real images, then synthetic images from these captions - sample = self.synthetic_image_generator.generate( - k=1, real_images=[source_sample])[0] # {'prompt': str, 'image': PIL Image ,'id': int} - - wandb_data['model'] = self.synthetic_image_generator.diffuser_name - wandb_data['source_dataset'] = source_dataset.huggingface_dataset_name - wandb_data['source_image_index'] = local_index - wandb_data['image'] = wandb.Image(sample['image']) - wandb_data['prompt'] = sample['prompt'] - if not np.any(np.isnan(sample['image'])): - break - else: - raise NotImplementedError(f'unsupported neuron.prompt_type: {self.config.neuron.prompt_type}') - - image = sample['image'] - image, level, data_aug_params = apply_augmentation_by_level(image) - - bt.logging.info(f"Querying {len(miner_uids)} miners...") axons = [self.metagraph.axons[uid] for uid in miner_uids] + challenge_metadata['miner_uids'] = list(miner_uids) + challenge_metadata['miner_hotkeys'] = list([axon.hotkey for axon in axons]) + + # prepare synapse + synapse = prepare_synapse(input_data, modality=modality) + if self.metagraph.netuid != MAINNET_UID: + synapse.testnet_label = label + + bt.logging.info(f"Sending {modality} challenge to {len(miner_uids)} miners") + start = time.time() responses = await self.dendrite( axons=axons, - synapse=prepare_image_synapse(image=image), + synapse=synapse, deserialize=True, timeout=9 ) + bt.logging.info(f"Responses received in {time.time() - start}s") + bt.logging.success(f"{CHALLENGE_TYPE[label]} {modality} challenge complete!") + bt.logging.info({k: v for k, v in challenge_metadata.items() if k not in ('miner_uids', 'miner_hotkeys')}) + bt.logging.info(f"Scoring responses") rewards, metrics = get_rewards( label=label, responses=responses, uids=miner_uids, axons=axons, - performance_tracker=self.performance_tracker) - - # Logging image source (model for synthetic, dataset for real) and verification details - source_name = wandb_data['model'] if 'model' in wandb_data else wandb_data['source_dataset'] - bt.logging.info(f'{"real" if label == 0 else "fake"} image | source: {source_name}: {sample["id"]}') - - # Logging responses and rewards - bt.logging.info(f"Received responses: {responses}") - bt.logging.info(f"Scored responses: {rewards}") - - # Update the scores based on the rewards. + challenge_modality=modality, + performance_trackers=self.performance_trackers) + self.update_scores(rewards, miner_uids) - # update logging data - wandb_data['data_aug_params'] = data_aug_params - wandb_data['label'] = label - wandb_data['miner_uids'] = list(miner_uids) - wandb_data['miner_hotkeys'] = list([axon.hotkey for axon in axons]) - wandb_data['predictions'] = responses - wandb_data['data_aug_level'] = level - wandb_data['correct'] = [ - np.round(y_hat) == y - for y_hat, y in zip(responses, [label] * len(responses)) - ] - wandb_data['rewards'] = list(rewards) - wandb_data['scores'] = list(self.scores) - - metric_names = list(metrics[0].keys()) - for metric_name in metric_names: - wandb_data[f'miner_{metric_name}'] = [m[metric_name] for m in metrics] + for metric_name in list(metrics[0].keys()): + challenge_metadata[f'miner_{metric_name}'] = [m[metric_name] for m in metrics] + challenge_metadata['predictions'] = responses + challenge_metadata['rewards'] = rewards + challenge_metadata['scores'] = list(self.scores) + + for uid, pred, reward in zip(miner_uids, responses, rewards): + if pred != -1: + bt.logging.success(f"UID: {uid} | Prediction: {pred} | Reward: {reward}") # W&B logging if enabled if not self.config.wandb.off: - wandb.log(wandb_data) + wandb.log(challenge_metadata) # ensure state is saved after each challenge self.save_miner_history() - - # Track miners who have responded - self.last_responding_miner_uids = [] - for i, pred in enumerate(responses): - # Logging specific prediction details - if pred != -1: - bt.logging.info(f'Miner uid: {miner_uids[i]} | prediction: {pred} | correct: {np.round(pred) == label} | reward: {rewards[i]}') - self.last_responding_miner_uids.append(miner_uids[i]) + if label == 1: + cache._prune_extracted_cache() diff --git a/bitmind/validator/miner_performance_tracker.py b/bitmind/validator/miner_performance_tracker.py index 8aab630c..c5cd4b39 100644 --- a/bitmind/validator/miner_performance_tracker.py +++ b/bitmind/validator/miner_performance_tracker.py @@ -72,7 +72,7 @@ def get_metrics(self, uid: int, window: int = None): precision = precision_score(labels, predictions, zero_division=0) recall = recall_score(labels, predictions, zero_division=0) f1 = f1_score(labels, predictions, zero_division=0) - mcc = matthews_corrcoef(labels, predictions) if len(np.unique(labels)) > 1 else 0.0 + mcc = max(0, matthews_corrcoef(labels, predictions)) if len(np.unique(labels)) > 1 else 0.0 auc = roc_auc_score(labels, pred_probs) if len(np.unique(labels)) > 1 else 0.0 except Exception as e: bt.logging.warning(f'Error in reward computation: {e}') @@ -112,4 +112,4 @@ def get_prediction_count(self, uid: int) -> int: """ if uid not in self.prediction_history: return 0 - return len(self.prediction_history[uid]) \ No newline at end of file + return len(self.prediction_history[uid]) diff --git a/bitmind/validator/model_utils.py b/bitmind/validator/model_utils.py new file mode 100644 index 00000000..36b90ad0 --- /dev/null +++ b/bitmind/validator/model_utils.py @@ -0,0 +1,37 @@ +import torch +from diffusers import MotionAdapter +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + + +def load_annimatediff_motion_adapter( + step: int = 4 +) -> MotionAdapter: + """ + Load a motion adapter model for AnimateDiff. + + Args: + step: The step size for the motion adapter. Options: [1, 2, 4, 8]. + repo: The HuggingFace repository to download the motion adapter from. + ckpt: The checkpoint filename + Returns: + A loaded MotionAdapter model. + + Raises: + ValueError: If step is not one of [1, 2, 4, 8]. + """ + if step not in [1, 2, 4, 8]: + raise ValueError("Step must be one of [1, 2, 4, 8]") + + device = "cuda" if torch.cuda.is_available() else "cpu" + adapter = MotionAdapter().to(device, torch.float16) + + repo = "ByteDance/AnimateDiff-Lightning" + ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" + adapter.load_state_dict( + load_file( + hf_hub_download(repo, ckpt), + device=device + ) + ) + return adapter diff --git a/bitmind/validator/reward.py b/bitmind/validator/reward.py index 0922b075..5dc313d7 100644 --- a/bitmind/validator/reward.py +++ b/bitmind/validator/reward.py @@ -1,7 +1,7 @@ # The MIT License (MIT) # Copyright © 2023 Yuma Rao -# TODO(developer): Set your name -# Copyright © 2023 +# developer: dubm +# Copyright © 2023 BitMind # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -33,7 +33,8 @@ def get_rewards( responses: List[float], uids: List[int], axons: List[bt.axon], - performance_tracker, + challenge_modality: str, + performance_trackers, ) -> np.array: """ Returns an array of rewards for the given label and miner responses. @@ -43,6 +44,7 @@ def get_rewards( - responses (List[float]): A list of responses from the miners. - uids (List[int]): List of miner UIDs. - axons (List[bt.axon]): List of miner axons. + - challenge_modality (str): video or image - performance_tracker (MinerPerformanceTracker): Tracks historical performance metrics per miner. Returns: @@ -51,24 +53,35 @@ def get_rewards( miner_rewards = [] miner_metrics = [] for axon, uid, pred_prob in zip(axons, uids, responses): - try: - miner_hotkey = axon.hotkey - if uid in performance_tracker.miner_hotkeys and performance_tracker.miner_hotkeys[uid] != miner_hotkey: - bt.logging.info(f"Miner hotkey changed for UID {uid}. Resetting performance metrics.") - performance_tracker.reset_miner_history(uid, miner_hotkey) - - performance_tracker.update(uid, pred_prob, label, miner_hotkey) - metrics_100 = performance_tracker.get_metrics(uid, window=100) - metrics_10 = performance_tracker.get_metrics(uid, window=10) - reward = 0.5 * metrics_100['mcc'] + 0.5 * metrics_10['accuracy'] - reward *= compute_penalty(pred_prob) - - miner_rewards.append(reward) - miner_metrics.append(metrics_100) - - except Exception as e: - bt.logging.error(f"Couldn't calculate reward for miner {uid}, prediction: {pred_prob}, label: {label}") - bt.logging.exception(e) - miner_rewards.append(0.0) - - return np.array(miner_rewards), miner_metrics \ No newline at end of file + miner_modality_rewards = {} + miner_modality_metrics = {} + for modality in ['image', 'video']: + tracker = performance_trackers[modality] + try: + miner_hotkey = axon.hotkey + + tracked_hotkeys = tracker.miner_hotkeys + if uid in tracked_hotkeys and tracked_hotkeys[uid] != miner_hotkey: + bt.logging.info(f"Miner hotkey changed for UID {uid}. Resetting performance metrics.") + tracker.reset_miner_history(uid, miner_hotkey) + + if modality == challenge_modality: + performance_trackers[modality].update(uid, pred_prob, label, miner_hotkey) + + metrics_100 = tracker.get_metrics(uid, window=100) + metrics_10 = tracker.get_metrics(uid, window=10) + reward = 0.5 * metrics_100['mcc'] + 0.5 * metrics_10['accuracy'] + reward *= compute_penalty(pred_prob) + miner_modality_rewards[modality] = reward + miner_modality_metrics[modality] = metrics_100 + + except Exception as e: + bt.logging.error(f"Couldn't calculate reward for miner {uid}, prediction: {pred_prob}, label: {label}") + bt.logging.exception(e) + miner_rewards.append(0.0) + + total_reward = 0.05 * miner_modality_rewards['video'] + 0.95 * miner_modality_rewards['image'] + miner_rewards.append(total_reward) + miner_metrics.append(metrics_100) + + return np.array(miner_rewards), miner_metrics diff --git a/bitmind/validator/scripts/__init__.py b/bitmind/validator/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bitmind/validator/scripts/run_cache_updater.py b/bitmind/validator/scripts/run_cache_updater.py new file mode 100644 index 00000000..cfbbc524 --- /dev/null +++ b/bitmind/validator/scripts/run_cache_updater.py @@ -0,0 +1,73 @@ +import asyncio +import argparse +import bittensor as bt +from bitmind.validator.cache.image_cache import ImageCache +from bitmind.validator.cache.video_cache import VideoCache +from bitmind.validator.scripts.util import load_validator_info, init_wandb_run +from bitmind.validator.config import ( + IMAGE_DATASETS, + VIDEO_DATASETS, + IMAGE_CACHE_UPDATE_INTERVAL, + VIDEO_CACHE_UPDATE_INTERVAL, + IMAGE_PARQUET_CACHE_UPDATE_INTERVAL, + VIDEO_ZIP_CACHE_UPDATE_INTERVAL, + REAL_VIDEO_CACHE_DIR, + REAL_IMAGE_CACHE_DIR, + MAX_COMPRESSED_GB, + MAX_EXTRACTED_GB +) + + +async def main(args): + + image_cache = ImageCache( + cache_dir=args.image_cache_dir, + datasets=IMAGE_DATASETS['real'], + parquet_update_interval=args.image_parquet_interval, + image_update_interval=args.image_interval, + num_parquets_per_dataset=5, + num_images_per_source=100, + max_extracted_size_gb=MAX_EXTRACTED_GB, + max_compressed_size_gb=MAX_COMPRESSED_GB + ) + image_cache.start_updater() + + video_cache = VideoCache( + cache_dir=args.video_cache_dir, + datasets=VIDEO_DATASETS['real'], + video_update_interval=args.video_interval, + zip_update_interval=args.video_zip_interval, + num_zips_per_dataset=2, + num_videos_per_zip=50, + max_extracted_size_gb=MAX_EXTRACTED_GB, + max_compressed_size_gb=MAX_COMPRESSED_GB + ) + video_cache.start_updater() + + while True: + bt.logging.info("Caches running...") + await asyncio.sleep(600) # Status update every 10 minutes + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--video-cache-dir', type=str, default=REAL_VIDEO_CACHE_DIR, + help='Directory to cache video data') + parser.add_argument('--image-cache-dir', type=str, default=REAL_IMAGE_CACHE_DIR, + help='Directory to cache image data') + parser.add_argument('--image-interval', type=int, default=IMAGE_CACHE_UPDATE_INTERVAL, + help='Update interval for images in hours') + parser.add_argument('--image-parquet-interval', type=int, default=IMAGE_PARQUET_CACHE_UPDATE_INTERVAL, + help='Update interval for image parquet files in hours') + parser.add_argument('--video-interval', type=int, default=VIDEO_CACHE_UPDATE_INTERVAL, + help='Update interval for videos in hours') + parser.add_argument('--video-zip-interval', type=int, default=VIDEO_ZIP_CACHE_UPDATE_INTERVAL, + help='Update interval for video zip files in hours') + args = parser.parse_args() + + init_wandb_run(run_base_name='cache-updater', **load_validator_info()) + + try: + asyncio.run(main(args)) + except KeyboardInterrupt: + bt.logging.info("Shutting down cache updaters...") diff --git a/bitmind/validator/scripts/run_data_generator.py b/bitmind/validator/scripts/run_data_generator.py new file mode 100644 index 00000000..25517a3f --- /dev/null +++ b/bitmind/validator/scripts/run_data_generator.py @@ -0,0 +1,52 @@ +import argparse +import time + +import bittensor as bt + +from bitmind.validator.scripts.util import load_validator_info, init_wandb_run +from bitmind.synthetic_data_generation import SyntheticDataGenerator +from bitmind.validator.cache import ImageCache +from bitmind.validator.config import ( + REAL_IMAGE_CACHE_DIR, + SYNTH_CACHE_DIR +) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image-cache-dir', type=str, default=REAL_IMAGE_CACHE_DIR, + help='Directory containing real images to use as reference') + parser.add_argument('--output-dir', type=str, default=SYNTH_CACHE_DIR, + help='Directory to save generated synthetic data') + parser.add_argument('--device', type=str, default='cuda', + help='Device to run generation on (cuda/cpu)') + parser.add_argument('--batch-size', type=int, default=3, + help='Number of images to generate per batch') + args = parser.parse_args() + + init_wandb_run(run_base_name='data-generator', **load_validator_info()) + + image_cache = ImageCache(args.image_cache_dir) + while True: + if image_cache._extracted_cache_empty(): + bt.logging.info("SyntheticDataGenerator waiting for real image cache to populate") + time.sleep(5) + continue + bt.logging.info("Image cache was populated! Proceeding to data generation") + break + + sdg = SyntheticDataGenerator( + prompt_type='annotation', + use_random_t2vis_model=True, + device=args.device, + image_cache=image_cache, + output_dir=args.output_dir) + + bt.logging.info("Starting standalone data generator service") + sdg.batch_generate(batch_size=1) + while True: + try: + sdg.batch_generate(batch_size=args.batch_size) + except Exception as e: + bt.logging.error(f"Error in batch generation: {str(e)}") + time.sleep(5) diff --git a/bitmind/validator/scripts/util.py b/bitmind/validator/scripts/util.py new file mode 100644 index 00000000..9beb727c --- /dev/null +++ b/bitmind/validator/scripts/util.py @@ -0,0 +1,82 @@ +import time +import yaml + +import wandb +import bittensor as bt + +import bitmind +from bitmind.validator.config import ( + WANDB_ENTITY, + TESTNET_WANDB_PROJECT, + MAINNET_WANDB_PROJECT, + MAINNET_UID, + VALIDATOR_INFO_PATH +) + +def load_validator_info(max_wait: int = 300): + start_time = time.time() + while True: + try: + with open(VALIDATOR_INFO_PATH, 'r') as f: + validator_info = yaml.safe_load(f) + bt.logging.info(f"Loaded validator info from {VALIDATOR_INFO_PATH}") + return validator_info + except FileNotFoundError: + if time.time() - start_time > max_wait: + bt.logging.error(f"Validator info not found at {VALIDATOR_INFO_PATH} after waiting 5 minutes. Exiting.") + exit(1) + bt.logging.info(f"Waiting for validator info at {VALIDATOR_INFO_PATH}") + time.sleep(3) + continue + except yaml.YAMLError: + bt.logging.error(f"Could not parse validator info at {VALIDATOR_INFO_PATH}") + validator_info = { + 'uid': 'ParseError', + 'hotkey': 'ParseError', + 'full_path': 'ParseError', + 'netuid': TESTNET_WANDB_PROJECT + } + return validator_info + + +def init_wandb_run(run_base_name: str, uid: str, hotkey: str, netuid: int, full_path: str) -> None: + """ + Initialize a Weights & Biases run for tracking the validator. + + Args: + vali_uid: The validator's uid + vali_hotkey: The validator's hotkey address + netuid: The network ID (mainnet or testnet) + vali_full_path: Validator's bittensor directory + + Returns: + None + """ + run_name = f'{run_base_name}-{uid}-{bitmind.__version__}' + + config = { + 'run_name': run_name, + 'uid': uid, + 'hotkey': hotkey, + 'version': bitmind.__version__ + } + + wandb_project = TESTNET_WANDB_PROJECT + if netuid == MAINNET_UID: + wandb_project = MAINNET_WANDB_PROJECT + + # Initialize the wandb run for the single project + bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'") + try: + return wandb.init( + name=run_name, + project=wandb_project, + entity=WANDB_ENTITY, + config=config, + dir=full_path, + reinit=True + ) + except wandb.UsageError as e: + bt.logging.warning(e) + bt.logging.warning("Did you run wandb login?") + return \ No newline at end of file diff --git a/bitmind/validator/verify_models.py b/bitmind/validator/verify_models.py index 8aa8b61e..278a0ff3 100644 --- a/bitmind/validator/verify_models.py +++ b/bitmind/validator/verify_models.py @@ -1,6 +1,6 @@ import os -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.constants import DIFFUSER_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL +from bitmind.synthetic_data_generation import SyntheticDataGenerator +from bitmind.validator.config import T2VIS_MODEL_NAMES as MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL import bittensor as bt @@ -38,10 +38,9 @@ def main(): It also initializes and loads diffusers for uncached models. """ bt.logging.info("Verifying validator model downloads....") - synthetic_image_generator = SyntheticImageGenerator( + synthetic_image_generator = SyntheticDataGenerator( prompt_type='annotation', - use_random_diffuser=True, - diffuser_name=None + use_random_t2vis_model=True ) # Check and load annotation and moderation models if not cached @@ -50,14 +49,14 @@ def main(): synthetic_image_generator.image_annotation_generator.clear_gpu() # Initialize and load diffusers if not cached - for model_name in DIFFUSER_NAMES: + for model_name in MODEL_NAMES: if not is_model_cached(model_name): - synthetic_image_generator = SyntheticImageGenerator( + synthetic_image_generator = SyntheticDataGenerator( prompt_type='annotation', - use_random_diffuser=False, - diffuser_name=model_name + use_random_t2vis_model=False, + t2vis_model_name=model_name ) - synthetic_image_generator.load_diffuser(model_name) + synthetic_image_generator.load_t2vis_model(model_name) synthetic_image_generator.clear_gpu() diff --git a/bitmind/validator/video_utils.py b/bitmind/validator/video_utils.py new file mode 100644 index 00000000..3c1e7e04 --- /dev/null +++ b/bitmind/validator/video_utils.py @@ -0,0 +1,106 @@ +import tempfile +from pathlib import Path +from typing import Optional, BinaryIO, List, Union + +import bittensor as bt +import ffmpeg +import numpy as np +from moviepy.editor import VideoFileClip +from PIL import Image + +from .cache.util import seconds_to_str + + +def video_to_pil(video_path: Union[str, Path]) -> List[Image.Image]: + """Load video file and convert it to a list of PIL images. + + Args: + video_path: Path to the input video file. + + Returns: + List of PIL Image objects representing each frame of the video. + """ + clip = VideoFileClip(str(video_path)) + frames = [Image.fromarray(np.array(frame)) for frame in clip.iter_frames()] + clip.close() + return frames + + +def clip_video( + video_path: str, + start: int, + num_seconds: int +) -> Optional[BinaryIO]: + """Extract a clip from a video file. + + Args: + video_path: Path to the input video file. + start: Start time in seconds. + num_seconds: Duration of the clip in seconds. + + Returns: + A temporary file object containing the clipped video, + or None if the operation fails. + + Raises: + ffmpeg.Error: If FFmpeg encounters an error during processing. + """ + temp_fileobj = tempfile.NamedTemporaryFile(suffix=".mp4") + try: + ( + ffmpeg + .input(video_path, ss=seconds_to_str(start), t=str(num_seconds)) + .output(temp_fileobj.name, vf='fps=1') + .overwrite_output() + .run(capture_stderr=True) + ) + return temp_fileobj + except ffmpeg.Error as e: + bt.logging.error(f"FFmpeg error: {e.stderr.decode()}") + raise + + +def get_video_duration(filename: str) -> int: + """Get the duration of a video file in seconds. + + Args: + filename: Path to the video file. + + Returns: + Duration of the video in seconds. + + Raises: + KeyError: If video stream information cannot be found. + """ + metadata = ffmpeg.probe(filename) + video_stream = next( + (stream for stream in metadata['streams'] + if stream['codec_type'] == 'video'), + None + ) + if not video_stream: + raise KeyError("No video stream found in the file") + return int(float(video_stream['duration'])) + + +def copy_audio(video_path: str) -> BinaryIO: + """Extract the audio stream from a video file. + + Args: + video_path: Path to the input video file. + + Returns: + A temporary file object containing the extracted audio stream. + + Raises: + ffmpeg.Error: If FFmpeg encounters an error during processing. + """ + temp_audiofile = tempfile.NamedTemporaryFile(suffix=".aac") + ( + ffmpeg + .input(video_path) + .output(temp_audiofile.name, vn=None, acodec='copy') + .overwrite_output() + .run(quiet=True) + ) + return temp_audiofile \ No newline at end of file diff --git a/create_video_dataset_example.sh b/create_video_dataset_example.sh new file mode 100755 index 00000000..02144562 --- /dev/null +++ b/create_video_dataset_example.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# --input_dir is a directory of mp4 files +# --frames_dir is where the extracted png frames will be stored +# --dataset_dir is where the huggingface dataset (containing paths to frames) will be stored +# once your dataset is created, you can add its local path to base_miner/config.py for training +python base_miner/datasets/create_video_dataset.py --input_dir ~/.cache/sn34/real/video \ + --frames_dir ~/.cache/sn34/train_frames \ + --dataset_dir ~/.cache/sn34/train_dataset/real_frames \ + --num_videos 500 \ + --frame_rate 5 \ + --max_frames 24 \ + --dataset_name real_frames \ + --overwrite diff --git a/docs/Incentive.md b/docs/Incentive.md index c62b71f8..faf810d3 100644 --- a/docs/Incentive.md +++ b/docs/Incentive.md @@ -9,7 +9,11 @@ This document covers the current state of SN34's incentive mechanism. ## Overview -Miners are rewarded based on the accuracy of their predictions, which is a weighted combination of the MCC of their last 100 predictions and the accuracy of their last 10. Validators keep track of miner performance using a score vector, which is updated using an exponential moving average. The weights assigned by validators determine the distribution of rewards among miners, incentivizing high-quality predictions and consistent performance. +Miner rewards are a weighted combination of their performance on both video and image challenges. + +Video and image rewards are computed separately, where each is its own weighted combination of the MCC of their last 100 predictions and the accuracy of their last 10. Validators keep track of miner performance using a score vector, which is updated using an exponential moving average. The weights assigned by validators determine the distribution of rewards among miners, incentivizing high-quality predictions and consistent performance. + +Reward implementation can be found in `bitmind/validator/rewards.py`

Incentive Mechanism @@ -21,7 +25,7 @@ Miners are rewarded based on the accuracy of their predictions, which is a weigh ## Rewards -> Miners rewards are computed based on the Matthews Correlation Coefficient (MCC) (https://en.wikipedia.org/wiki/Phi_coefficient) of (up to) their last 100 predictions, combined with the accuracy of their last 10 predictions. +> Miners rewards are computed based on the [Matthews Correlation Coefficient (MCC)](https://en.wikipedia.org/wiki/Phi_coefficient) of (up to) their last 100 predictions, combined with the accuracy of their last 10 predictions. $$ 0.5 \times MCC_{100} + 0.5 \times Accuracy_{10} @@ -47,7 +51,7 @@ A low *α* value places emphasis on a miner's historical performance, addin Weight normalization by L1 norm: -$$w = \frac{\text{Score}}{\lVert\text{Score}\rVert_1}$$ +$$w = \frac{\text{V}}{\lVert\text{V}\rVert_1}$$ ## Incentives diff --git a/docs/Mining.md b/docs/Mining.md index 8ca113c8..652d45f9 100644 --- a/docs/Mining.md +++ b/docs/Mining.md @@ -45,16 +45,15 @@ chmod +x setup_miner_env.sh *Only for training -- deployed miner instances do not require access to these datasets.* -If you intend on training a miner, you can download the our open source datasets by running: +You can optionally pre-download the training datasets by running: ```bash -python bitmind/download_data.py +python base_miner/datasets/download_data.py ``` -This step is optional. If you choose not to run it, the dataset will be downloaded automatically when you run our training scripts. - -The download location of this script is `~/.cache/huggingface` +Feel free to skip this step - datasets will be downloaded automatically when you run the training scripts. +The default list of datasets and default download location are defined in `base_miner/config.py` ## Registration @@ -89,10 +88,16 @@ First, make sure to update `validator.env` with your **wallet**, **hotkey**, and ```bash -DETECTOR=CAMO # Options: CAMO, UCF, NPR -DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs +IMAGE_DETECTOR=CAMO # Options: CAMO, UCF, NPR, None +IMAGE_DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs + # Supply a filename or relative path + +VIDEO_DETECTOR=TALL # Options: TALL, None +VIDEO_DETECTOR_CONFIG=tall.yaml # Configs live in base_miner/deepfake_detectors/configs # Supply a filename or relative path -DEVICE=cpu # Options: cpu, cuda + +IMAGE_DETECTOR_DEVICE=cpu # Options: cpu, cuda +VIDEO_DETECTOR_DEVICE=cpu # Subtensor Network Configuration: NETUID=34 # Network User ID options: 34, 168 @@ -138,7 +143,7 @@ The model with the lowest validation accuracy will be saved to `base_miner/NPR/c ### UCF ```python -cd base_miner/UCF/ && python train_detector.py +cd base_miner/DFB/ && python train_detector.py --detector [UCF, TALL] --modality [image, video] ``` The model with the lowest validation accuracy will be saved to `base_miner/UCF/logs/training//`.
diff --git a/docs/Validating.md b/docs/Validating.md index d7a08227..e443e518 100644 --- a/docs/Validating.md +++ b/docs/Validating.md @@ -38,17 +38,6 @@ chmod +x setup_validator_env.sh ./setup_validator_env.sh ``` -### Data - -You can download the necessary datasets by running: - -```bash -python bitmind/download_data.py -``` - -We recommend you do this prior to registering and running your validator. Please note the minimum storage requirements specified in `min_compute.yml`. - - ## Registration To validate on our subnet, you must have a registered hotkey. @@ -73,7 +62,6 @@ You can launch your validator with `run_neuron.py`. First, make sure to update `validator.env` with your **wallet**, **hotkey**, and **validator port**. This file was created for you during setup, and is not tracked by git. ```bash -# Subtensor Network Configuration: NETUID=34 # Network User ID options: 34, 168 SUBTENSOR_NETWORK=finney # Networks: finney, test, local SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 @@ -85,21 +73,18 @@ SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 WALLET_NAME=default WALLET_HOTKEY=default +# Note: If you're using RunPod, you must select a port >= 70000 for symmetric mapping # Validator Port Setting: VALIDATOR_AXON_PORT=8092 +VALIDATOR_PROXY_PORT=10913 +DEVICE=cuda # API Keys: WANDB_API_KEY=your_wandb_api_key_here HUGGING_FACE_TOKEN=your_hugging_face_token_here ``` -Then, log into weights and biases by running - -```bash -wandb login -``` - -and entering your API key. If you don't have an API key, please reach out to the BitMind team via Discord and we can provide one. +If you don't have a W&B API key, please reach out to the BitMind team via Discord and we can provide one. Now you're ready to run your validator! @@ -107,7 +92,21 @@ Now you're ready to run your validator! conda activate bitmind pm2 start run_neuron.py -- --validator ``` - - Auto updates are enabled by default. To disable, run with `--no-auto-updates`. - Self-healing restarts are enabled by default (every 6 hours). To disable, run with `--no-self-heal`. + +The above command will kick off 3 `pm2` processes +``` +┌────┬───────────────────────────┬─────────────┬─────────┬─────────┬──────────┬────────┬──────┬───────────┬──────────┬──────────┬──────────┬──────────┐ +│ id │ name │ namespace │ version │ mode │ pid │ uptime │ ↺ │ status │ cpu │ mem │ user │ watching │ +├────┼───────────────────────────┼─────────────┼─────────┼─────────┼──────────┼────────┼──────┼───────────┼──────────┼──────────┼──────────┼──────────┤ +│ 2 │ bitmind_data_generator │ default │ N/A │ fork │ 2759998 │ 4s │ 0 │ online │ 0% │ 464.9mb │ user │ disabled │ +│ 1 │ bitmind_validator │ default │ N/A │ fork │ 2759978 │ 5s │ 0 │ online │ 100% │ 518.5mb │ user │ disabled │ +│ 0 │ run_neuron │ default │ N/A │ fork │ 2759928 │ 9s │ 0 │ online │ 0% │ 10.3mb │ user │ disabled │ +└────┴───────────────────────────┴─────────────┴─────────┴─────────┴──────────┴────────┴──────┴───────────┴──────────┴──────────┴──────────┴──────────┘ +``` +- `run_neuron` manages self heals and auto updates +- `bitmind_validator` is the validator process, whose hotkey, port, etc. are configured in `validator.env` +- `bitmind_data_generator` runs our synthetic data generation pipeline to produce synthetic images and videos. + - These data are stored in `~/.cache/sn34` and are sampled by the `bitmind_validator` process \ No newline at end of file diff --git a/min_compute.yml b/min_compute.yml index 9623f3d3..7532a5d7 100644 --- a/min_compute.yml +++ b/min_compute.yml @@ -56,15 +56,24 @@ compute_spec: architecture: "x86_64" # Architecture type (e.g., x86_64, arm64) gpu: - required: True # Does the application require a GPU? - min_vram: 33 # Minimum GPU VRAM (GB) - recommended_vram: 48 # Recommended GPU VRAM (GB) - min_compute_capability: 8.6 # Minimum CUDA compute capability - recommended_compute_capability: 8.6 # Recommended CUDA compute capability - recommended_gpu: "NVIDIA A40" # Recommended GPU to purchase/rent - peak_fp16_tensor_tflops: # Peak FP16 tensor TFLOPS (with FP16 accumulate) - min: 149.7 - max: 299.4 + required: True # Does the application require a GPU? + min_vram: 80 # Minimum GPU VRAM (GB) + recommended_vram: 80 # Recommended GPU VRAM (GB) + min_compute_capability: 8.0 # Minimum CUDA compute capability + recommended_compute_capability: 8.0 # Recommended CUDA compute capability + recommended_gpu: "NVIDIA A100 80GB PCIE" # Recommended GPU to purchase/rent + fp64: 9.7 # TFLOPS + fp64_tensor_core: 19.5 # TFLOPS + fp32: 19.5 # TFLOPS + tf32: 156 # TFLOPS* + bfloat16_tensor_core: 312 # TFLOPS* + int8_tensor_core: 624 # TOPS* + + # See NVIDIA A100 datasheet for details: + # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/ + # nvidia-a100-datasheet-nvidia-us-2188504-web.pdf + + # *double with sparsity memory: min_ram: 32 # Minimum RAM (GB) @@ -73,8 +82,8 @@ compute_spec: ram_type: "DDR6" # RAM type (e.g., DDR4, DDR3, etc.) storage: - min_space: 300 # Minimum free storage space (GB) - recommended_space: 500 # Recommended free storage space (GB) + min_space: 500 # Minimum free storage space (GB) + recommended_space: 600 # Recommended free storage space (GB) type: "SSD" # Preferred storage type (e.g., SSD, HDD) min_iops: 1000 # Minimum I/O operations per second (if applicable) recommended_iops: 5000 # Recommended I/O operations per second diff --git a/neurons/miner.py b/neurons/miner.py index 25c1fa36..e894bb14 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -28,9 +28,10 @@ import sys import numpy as np -from base_miner import DETECTOR_REGISTRY +from base_miner.registry import DETECTOR_REGISTRY +from base_miner.deepfake_detectors import NPRImageDetector, UCFImageDetector, CAMOImageDetector, TALLVideoDetector from bitmind.base.miner import BaseMinerNeuron -from bitmind.protocol import ImageSynapse +from bitmind.protocol import ImageSynapse, VideoSynapse, decode_video_synapse from bitmind.utils.config import get_device @@ -38,145 +39,131 @@ class Miner(BaseMinerNeuron): def __init__(self, config=None): super(Miner, self).__init__(config=config) - if self.config.neuron.device == 'auto': - self.config.neuron.device = get_device() - self.load_detector() - - def load_detector(self): - self.deepfake_detector = DETECTOR_REGISTRY[self.config.neuron.detector]( - config=self.config.neuron.detector_config, - device=self.config.neuron.device + bt.logging.info("Attaching forward function to miner axon.") + self.axon.attach( + forward_fn=self.forward_image, + blacklist_fn=self.blacklist_image, + priority_fn=self.priority_image, + ).attach( + forward_fn=self.forward_video, + blacklist_fn=self.blacklist_video, + priority_fn=self.priority_video, ) - - async def forward( + bt.logging.info(f"Axon created: {self.axon}") + + bt.logging.info("Loading image detection model if configured") + self.load_image_detector() + bt.logging.info("Loading video detection model if configured") + self.load_video_detector() + + def load_image_detector(self): + if (str(self.config.neuron.image_detector).lower() == 'none' or + str(self.config.neuron.image_detector_config).lower() == 'none'): + bt.logging.warning("No image detector configuration provided, skipping.") + self.image_detector = None + return + + if self.config.neuron.image_detector_device == 'auto': + bt.logging.warning("Automatic device configuration enabled for image detector") + self.config.neuron.image_detector_device = get_device() + + self.image_detector = DETECTOR_REGISTRY[self.config.neuron.image_detector]( + config_name=self.config.neuron.image_detector_config, + device=self.config.neuron.image_detector_device + ) + bt.logging.info(f"Loaded image detection model: {self.config.neuron.image_detector}") + + def load_video_detector(self): + if (str(self.config.neuron.video_detector).lower() == 'none' or + str(self.config.neuron.video_detector_config).lower() == 'none'): + bt.logging.warning("No video detector configuration provided, skipping.") + self.video_detector = None + return + + if self.config.neuron.video_detector_device == 'auto': + bt.logging.warning("Automatic device configuration enabled for video detector") + self.config.neuron.video_detector_device = get_device() + + self.video_detector = DETECTOR_REGISTRY[self.config.neuron.video_detector]( + config_name=self.config.neuron.video_detector_config, + device=self.config.neuron.video_detector_device + ) + bt.logging.info(f"Loaded video detection model: {self.config.neuron.video_detector}") + + async def forward_image( self, synapse: ImageSynapse ) -> ImageSynapse: """ - Loads the deepfake detection model (a PyTorch binary classifier) from the path specified in --neuron.model_path. - Processes the incoming ImageSynapse and passes the image to the loaded model for classification. - The model is loaded here, rather than in __init__, so that miners may (backup) and overwrite - their model file as a means of updating their miner's predictor. + Perform inference on image Args: - synapse (ImageSynapse): The synapse object containing the list of b64 encoded images in the + synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the 'images' field. Returns: - ImageSynapse: The synapse object with the 'predictions' field populated with a list of probabilities + bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities """ - try: - image_bytes = base64.b64decode(synapse.image) - image = Image.open(io.BytesIO(image_bytes)) - - pred = self.deepfake_detector(image) - - synapse.prediction = pred - - except Exception as e: - bt.logging.error("Error performing inference") - bt.logging.error(e) - - bt.logging.info(f"PREDICTION: {synapse.prediction}") + if self.image_detector is None: + bt.logging.info("Image detection model not configured; skipping image challenge") + else: + bt.logging.info("Received image challenge!") + try: + image_bytes = base64.b64decode(synapse.image) + image = Image.open(io.BytesIO(image_bytes)) + synapse.prediction = self.image_detector(image) + except Exception as e: + bt.logging.error("Error performing inference") + bt.logging.error(e) + + bt.logging.info(f"PREDICTION = {synapse.prediction}") + label = synapse.testnet_label + if synapse.testnet_label != -1: + bt.logging.info(f"LABEL (testnet only) = {label}") return synapse - async def blacklist( - self, synapse: ImageSynapse - ) -> typing.Tuple[bool, str]: + async def forward_video( + self, synapse: VideoSynapse + ) -> VideoSynapse: """ - Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should - define the logic for blacklisting requests based on your needs and desired security parameters. - - Blacklist runs before the synapse data has been deserialized (i.e. before synapse.data is available). - The synapse is instead contructed via the headers of the request. It is important to blacklist - requests before they are deserialized to avoid wasting resources on requests that will be ignored. - + Perform inference on video Args: - synapse (ImageSynapse): A synapse object constructed from the headers of the incoming request. + synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the + 'images' field. Returns: - Tuple[bool, str]: A tuple containing a boolean indicating whether the synapse's hotkey is blacklisted, - and a string providing the reason for the decision. + bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities - This function is a security measure to prevent resource wastage on undesired requests. It should be enhanced - to include checks against the metagraph for entity registration, validator status, and sufficient stake - before deserialization of synapse data to minimize processing overhead. - - Example blacklist logic: - - Reject if the hotkey is not a registered entity within the metagraph. - - Consider blacklisting entities that are not validators or have insufficient stake. - - In practice it would be wise to blacklist requests from entities that are not validators, or do not have - enough stake. This can be checked via metagraph.S and metagraph.validator_permit. You can always attain - the uid of the sender via a metagraph.hotkeys.index( synapse.dendrite.hotkey ) call. - - Otherwise, allow the request to be processed further. - """ - if synapse.dendrite is None or synapse.dendrite.hotkey is None: - bt.logging.warning("Received a request without a dendrite or hotkey.") - return True, "Missing dendrite or hotkey" - - # TODO(developer): Define how miners should blacklist requests. - uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey) - if ( - not self.config.blacklist.allow_non_registered - and synapse.dendrite.hotkey not in self.metagraph.hotkeys - ): - # Ignore requests from un-registered entities. - bt.logging.trace( - f"Blacklisting un-registered hotkey {synapse.dendrite.hotkey}" - ) - return True, "Unrecognized hotkey" - - if self.config.blacklist.force_validator_permit: - # If the config is set to force validator permit, then we should only allow requests from validators. - if not self.metagraph.validator_permit[uid]: - bt.logging.warning( - f"Blacklisting a request from non-validator hotkey {synapse.dendrite.hotkey}" - ) - return True, "Non-validator hotkey" - - bt.logging.trace( - f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}" - ) - return False, "Hotkey recognized!" - - async def priority(self, synapse: ImageSynapse) -> float: """ - The priority function determines the order in which requests are handled. More valuable or higher-priority - requests are processed before others. You should design your own priority mechanism with care. - - This implementation assigns priority to incoming requests based on the calling entity's stake in the metagraph. + if self.video_detector is None: + bt.logging.info("Video detection model not configured; skipping video challenge") + else: + bt.logging.info("Received video challenge!") + try: + frames_tensor = decode_video_synapse(synapse) + frames_tensor = frames_tensor.to(self.config.neuron.video_detector_device) + synapse.prediction = self.video_detector(frames_tensor) + except Exception as e: + bt.logging.error("Error performing inference") + bt.logging.error(e) + + bt.logging.info(f"PREDICTION = {synapse.prediction}") + label = synapse.testnet_label + if synapse.testnet_label != -1: + bt.logging.info(f"LABEL (testnet only) = {label}") + return synapse - Args: - synapse (ImageSynapse): The synapse object that contains metadata about the incoming request. + async def blacklist_image(self, synapse: ImageSynapse) -> typing.Tuple[bool, str]: + return await self.blacklist(synapse) - Returns: - float: A priority score derived from the stake of the calling entity. + async def blacklist_video(self, synapse: VideoSynapse) -> typing.Tuple[bool, str]: + return await self.blacklist(synapse) - Miners may recieve messages from multiple entities at once. This function determines which request should be - processed first. Higher values indicate that the request should be processed first. Lower values indicate - that the request should be processed later. + async def priority_image(self, synapse: ImageSynapse) -> float: + return await self.priority(synapse) - Example priority logic: - - A higher stake results in a higher priority value. - """ - if synapse.dendrite is None or synapse.dendrite.hotkey is None: - bt.logging.warning("Received a request without a dendrite or hotkey.") - return 0.0 - - # TODO(developer): Define how miners should prioritize requests. - caller_uid = self.metagraph.hotkeys.index( - synapse.dendrite.hotkey - ) # Get the caller index. - - prirority = float( - self.metagraph.S[caller_uid] - ) # Return the stake as the priority. - bt.logging.trace( - f"Prioritizing {synapse.dendrite.hotkey} with value: ", prirority - ) - return prirority + async def priority_video(self, synapse: VideoSynapse) -> float: + return await self.priority(synapse) def save_state(self): pass diff --git a/neurons/validator.py b/neurons/validator.py index 15c38fbd..27f830b2 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -18,15 +18,28 @@ # DEALINGS IN THE SOFTWARE. import bittensor as bt +import yaml import wandb import time from neurons.validator_proxy import ValidatorProxy -from bitmind.validator import forward +from bitmind.validator.forward import forward +from bitmind.validator.cache import VideoCache, ImageCache from bitmind.base.validator import BaseValidatorNeuron -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.image_dataset import ImageDataset -from bitmind.constants import VALIDATOR_DATASET_META, WANDB_PROJECT, WANDB_ENTITY +from bitmind.validator.config import ( + MAINNET_UID, + MAINNET_WANDB_PROJECT, + TESTNET_WANDB_PROJECT, + IMAGE_DATASETS, + VIDEO_DATASETS, + WANDB_ENTITY, + REAL_VIDEO_CACHE_DIR, + REAL_IMAGE_CACHE_DIR, + SYNTH_IMAGE_CACHE_DIR, + SYNTH_VIDEO_CACHE_DIR, + VALIDATOR_INFO_PATH +) + import bitmind @@ -50,25 +63,26 @@ def __init__(self, config=None): self.last_responding_miner_uids = [] self.validator_proxy = ValidatorProxy(self) - - bt.logging.info("init_wandb()") - self.init_wandb() - - bt.logging.info("Loading real datasets") - self.real_image_datasets = [ - ImageDataset(ds['path'], 'train', ds.get('name', None)) - for ds in VALIDATOR_DATASET_META['real'] - ] - self.total_real_images = sum([ - len(ds) for ds in self.real_image_datasets - ]) - - self.synthetic_image_generator = SyntheticImageGenerator( - prompt_type='annotation', - use_random_diffuser=True, - diffuser_name=None, - device=self.config.neuron.device) + # real media caches are updated by the bitmind_cache_updater process (started by start_validator.sh) + self.real_media_cache = { + 'image': ImageCache(REAL_IMAGE_CACHE_DIR), + 'video': VideoCache(REAL_VIDEO_CACHE_DIR) + } + + # synthetic media caches are populated by the SyntheticDataGenerator process (started by start_validator.sh) + self.synthetic_media_cache = { + 'image': ImageCache(SYNTH_IMAGE_CACHE_DIR), + 'video': VideoCache(SYNTH_VIDEO_CACHE_DIR) + } + + self.media_cache = { + 'real': self.real_media_cache, + 'synthetic': self.synthetic_media_cache, + } + + self.init_wandb() + self.store_vali_info() self._fake_prob = self.config.get('fake_prob', 0.5) async def forward(self): @@ -93,12 +107,16 @@ def init_wandb(self): self.config.version = bitmind.__version__ self.config.type = self.neuron_type + wandb_project = TESTNET_WANDB_PROJECT + if self.config.netuid == MAINNET_UID: + wandb_project = MAINNET_WANDB_PROJECT + # Initialize the wandb run for the single project - print("Initializing W&B") + bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'") try: run = wandb.init( name=run_name, - project=WANDB_PROJECT, + project=wandb_project, entity=WANDB_ENTITY, config=self.config, dir=self.config.full_path, @@ -114,7 +132,23 @@ def init_wandb(self): self.config.signature = signature wandb.config.update(self.config, allow_val_change=True) - bt.logging.success(f"Started wandb run for project '{WANDB_PROJECT}'") + bt.logging.success(f"Started wandb run {run_name}") + + def store_vali_info(self): + """ + Stores the uid, hotkey and netuid of the currently running vali instance. + The SyntheticDataGenerator process reads this to name its w&b run + """ + validator_info = { + 'uid': self.uid, + 'hotkey': self.wallet.hotkey.ss58_address, + 'netuid': self.config.netuid, + 'full_path': self.config.neuron.full_path + } + with open(VALIDATOR_INFO_PATH, 'w') as f: + yaml.safe_dump(validator_info, f, indent=4) + + bt.logging.info(f"Wrote validator info to {VALIDATOR_INFO_PATH}") # The main function parses the configuration and runs the validator. @@ -124,4 +158,4 @@ def init_wandb(self): with Validator() as validator: while True: bt.logging.info(f"Validator running | uid {validator.uid} | {time.time()}") - time.sleep(5) + time.sleep(30) diff --git a/neurons/validator_proxy.py b/neurons/validator_proxy.py index 1cc9c53c..a4bc478d 100644 --- a/neurons/validator_proxy.py +++ b/neurons/validator_proxy.py @@ -20,12 +20,15 @@ import socket import base64 -from bitmind.image_transforms import base_transforms +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import get_base_transforms from bitmind.protocol import ImageSynapse, prepare_image_synapse from bitmind.utils.uids import get_random_uids from bitmind.validator.proxy import ProxyCounter import bitmind +base_transforms = get_base_transforms(TARGET_IMAGE_SIZE) + def preprocess_image(b64_image): image_bytes = base64.b64decode(b64_image) diff --git a/requirements-miner.txt b/requirements-miner.txt deleted file mode 100644 index 11da3b9e..00000000 --- a/requirements-miner.txt +++ /dev/null @@ -1,5 +0,0 @@ -tensorboardx==2.6.2.2 -dlib==19.24.6 -imutils==0.5.4 -scikit-image==0.24.0 -ultralytics==8.2.86 diff --git a/requirements-validator.txt b/requirements-validator.txt deleted file mode 100644 index 45bad1fa..00000000 --- a/requirements-validator.txt +++ /dev/null @@ -1,5 +0,0 @@ -httpx==0.27.0 -diffusers==0.30.0 -transformers==4.46.3 -sentencepiece==0.2.0 -bitsandbytes==0.43.3 diff --git a/requirements.txt b/requirements.txt index ed827814..ff3bda3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,44 @@ +# Core ML frameworks bittensor==7.4.0 torch==2.4.0 -scikit-learn==1.5.1 -loguru==0.7.2 +torchvision==0.19.0 +torchaudio==2.4.0 tensorflow==2.17.0 -opencv-python==4.10.0.84 +tf-keras==2.17.0 + +# ML utilities +scikit-learn==1.5.1 +scikit-image==0.24.0 numpy==1.26.4 -pillow==10.4.0 pandas==2.2.2 -diffusers==0.30.0 matplotlib==3.9.2 -torchvision==0.19.0 -torchaudio==2.4.0 -datasets==2.20.0 + +# Deep learning tools +transformers==4.46.3 +#git+https://github.com/huggingface/diffusers.git@6a51427b6a226591ccc40249721c486855f53e1c#egg=diffusers accelerate==0.33.0 -tensorboardx==2.6.2.2 -tf-keras==2.17.0 +bitsandbytes==0.43.3 +sentencepiece==0.2.0 +timm==1.0.11 +einops==0.8.0 +ultralytics==8.2.86 + +# Image/Video processing +opencv-python==4.10.0.84 +pillow==10.4.0 +imageio==2.35.1 +imageio-ffmpeg==0.5.1 +moviepy==1.0.3 +av==13.1.0 +ffmpeg-python==0.2.0 +pyffmpeg==2.4.2.18.1 +imutils==0.5.4 +dlib==19.24.6 + +# Data and logging +datasets==2.20.0 wandb==0.17.6 +tensorboardx==2.6.2.2 +loguru==0.7.2 +httpx==0.27.0 +yt-dlp==2024.11.4 diff --git a/run_neuron.py b/run_neuron.py index ac88215c..8105cb6d 100644 --- a/run_neuron.py +++ b/run_neuron.py @@ -8,7 +8,7 @@ import argparse # self heal restart interval -RESTART_INTERVAL_HOURS = 6 +RESTART_INTERVAL_HOURS = 3 def should_update_local(local_commit, remote_commit): diff --git a/setup_env.sh b/setup_env.sh new file mode 100755 index 00000000..18a561e3 --- /dev/null +++ b/setup_env.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +########################################### +# System Updates and Package Installation # +########################################### + +# Update system +sudo apt update -y + +# Install core dependencies +sudo apt install -y \ + python3-pip \ + nano \ + libgl1 \ + npm \ + ffmpeg \ + unzip + +# Install build dependencies +sudo apt install -y \ + build-essential \ + cmake \ + libopenblas-dev \ + liblapack-dev \ + libx11-dev \ + libgtk-3-dev + +# Install process manager +sudo npm install -g pm2@latest + +############################ +# Python Package Installation +############################ + +pip install -e . +pip install git+https://github.com/huggingface/diffusers.git@6a51427b6a226591ccc40249721c486855f53e1c + +############################ +# Environment Files Setup # +############################ + +# Create miner.env if it doesn't exist +if [ -f "miner.env" ]; then + echo "File 'miner.env' already exists. Skipping creation." +else + cat > miner.env << 'EOL' +# Default options +#-------------------- + +# Detector Configuration +IMAGE_DETECTOR=CAMO # Options: CAMO, UCF, NPR, None +IMAGE_DETECTOR_CONFIG=camo.yaml # Configs in base_miner/deepfake_detectors/configs +VIDEO_DETECTOR=TALL # Options: TALL, None +VIDEO_DETECTOR_CONFIG=tall.yaml # Configs in base_miner/deepfake_detectors/configs + +# Device Settings +IMAGE_DETECTOR_DEVICE=cpu # Options: cpu, cuda +VIDEO_DETECTOR_DEVICE=cpu + +# Subtensor Network Configuration +NETUID=34 # Network User ID options: 34, 168 +SUBTENSOR_NETWORK=finney # Networks: finney, test, local +SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 + # Endpoints: + # - wss://entrypoint-finney.opentensor.ai:443 + # - wss://test.finney.opentensor.ai:443/ + +# Wallet Configuration +WALLET_NAME=default +WALLET_HOTKEY=default + +# Miner Settings +MINER_AXON_PORT=8091 +BLACKLIST_FORCE_VALIDATOR_PERMIT=True # Force validator permit for blacklisting +EOL + echo "File 'miner.env' created." +fi + +# Create validator.env if it doesn't exist +if [ -f "validator.env" ]; then + echo "File 'validator.env' already exists. Skipping creation." +else + cat > validator.env << 'EOL' +# Default options +#-------------------- + +# Subtensor Network Configuration +NETUID=34 # Network User ID options: 34, 168 +SUBTENSOR_NETWORK=finney # Networks: finney, test, local +SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 + # Endpoints: + # - wss://entrypoint-finney.opentensor.ai:443 + # - wss://test.finney.opentensor.ai:443/ + +# Wallet Configuration +WALLET_NAME=default +WALLET_HOTKEY=default + +# Validator Settings +VALIDATOR_AXON_PORT=8092 # If using RunPod, must be >= 70000 for symmetric mapping +VALIDATOR_PROXY_PORT=10913 +DEVICE=cuda + +# API Keys +WANDB_API_KEY=your_wandb_api_key_here +HUGGING_FACE_TOKEN=your_hugging_face_token_here +EOL + echo "File 'validator.env' created." +fi + +echo "Environment setup completed successfully." \ No newline at end of file diff --git a/setup_miner_env.sh b/setup_miner_env.sh deleted file mode 100755 index 004479d2..00000000 --- a/setup_miner_env.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -# Update system and install required packages -sudo apt update -y -sudo apt install python3-pip -y -sudo apt install nano -y -sudo apt install libgl1 -y -sudo apt install npm -y -sudo npm install pm2@latest -g -sudo apt install build-essential cmake -y -sudo apt install libopenblas-dev liblapack-dev -y -sudo apt install libx11-dev libgtk-3-dev -y - -# Install Python dependencies -pip install -e . -pip install -r requirements-miner.txt - -echo "# Default options: -DETECTOR=CAMO # Options: CAMO, UCF, NPR -DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs - # Supply a filename or relative path -DEVICE=cpu # Options: cpu, cuda - -# Subtensor Network Configuration: -NETUID=34 # Network User ID options: 34, 168 -SUBTENSOR_NETWORK=finney # Networks: finney, test, local -SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 - # Endpoints: - # - wss://entrypoint-finney.opentensor.ai:443 - # - wss://test.finney.opentensor.ai:443/ - -# Wallet Configuration: -WALLET_NAME=default -WALLET_HOTKEY=default - -# Miner Settings: -MINER_AXON_PORT=8091 -BLACKLIST_FORCE_VALIDATOR_PERMIT=True # Default setting to force validator permit for blacklisting" > miner.env - diff --git a/setup_validator_env.sh b/setup_validator_env.sh deleted file mode 100755 index 225bb195..00000000 --- a/setup_validator_env.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -# Update system and install required packages -sudo apt update -y -sudo apt install python3-pip -y -sudo apt install nano -y -sudo apt install libgl1 -y -sudo apt install npm -y -sudo npm install pm2@latest -g - -# Install Python dependencies -pip install -e . -pip install -r requirements-validator.txt - -# Check if validator.env exists -if [ -f "validator.env" ]; then - echo "File 'validator.env' already exists. Skipping creation." -else - echo "# Default options: - -# Subtensor Network Configuration: -NETUID=34 # Network User ID options: 34, 168 -SUBTENSOR_NETWORK=finney # Networks: finney, test, local -SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 - # Endpoints: - # - wss://entrypoint-finney.opentensor.ai:443 - # - wss://test.finney.opentensor.ai:443/ - -# Wallet Configuration: -WALLET_NAME=default -WALLET_HOTKEY=default - -# Note: If you're using RunPod, you must select a port >= 70000 for symmetric mapping -# Validator Port Setting: -VALIDATOR_AXON_PORT=8092 -VALIDATOR_PROXY_PORT=10913 -DEVICE=cuda - -# API Keys: -WANDB_API_KEY=your_wandb_api_key_here -HUGGING_FACE_TOKEN=your_hugging_face_token_here" > validator.env - echo "File 'validator.env' created." -fi diff --git a/start_miner.sh b/start_miner.sh index 42b9c6a5..dd3c94f7 100755 --- a/start_miner.sh +++ b/start_miner.sh @@ -1,21 +1,21 @@ #!/bin/bash -# Load environment variables from .env file set -a source miner.env set +a -# Check if the process is already running if pm2 list | grep -q "bitmind_miner"; then echo "Process 'bitmind_miner' is already running. Deleting it..." pm2 delete bitmind_miner fi -# Start the process with arguments from environment variables pm2 start neurons/miner.py --name bitmind_miner -- \ - --neuron.detector $DETECTOR \ - --neuron.detector_config $DETECTOR_CONFIG \ - --neuron.device $DEVICE \ + --neuron.image_detector ${IMAGE_DETECTOR:-None} \ + --neuron.image_detector_config ${IMAGE_DETECTOR_CONFIG:-None} \ + --neuron.image_detector_device ${IMAGE_DETECTOR_DEVICE:-None} \ + --neuron.video_detector ${VIDEO_DETECTOR:-None} \ + --neuron.video_detector_config ${VIDEO_DETECTOR_CONFIG:-None} \ + --neuron.video_detector_device ${VIDEO_DETECTOR_DEVICE:-None} \ --netuid $NETUID \ --subtensor.network $SUBTENSOR_NETWORK \ --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \ diff --git a/start_validator.sh b/start_validator.sh index c1997ce6..9ca0d22f 100755 --- a/start_validator.sh +++ b/start_validator.sh @@ -1,14 +1,17 @@ #!/bin/bash -# Load environment variables from .env file +# Load environment variables from .env file & set defaults set -a source validator.env set +a -# Set default values for environment variables : ${VALIDATOR_PROXY_PORT:=10913} : ${DEVICE:=cuda} +VALIDATOR_PROCESS_NAME="bitmind_validator" +DATA_GEN_PROCESS_NAME="bitmind_data_generator" +CACHE_UPDATE_PROCESS_NAME="bitmind_cache_updater" + # Login to Weights & Biases if ! wandb login $WANDB_API_KEY; then echo "Failed to login to Weights & Biases with the provided API key." @@ -21,10 +24,10 @@ if ! huggingface-cli login --token $HUGGING_FACE_TOKEN; then exit 1 fi -# Check if the process is already running -if pm2 list | grep -q "bitmind_validator"; then - echo "Process 'bitmind_validator' is already running. Deleting it..." - pm2 delete bitmind_validator +# VALIDATOR PROCESS +if pm2 list | grep -q "$VALIDATOR_PROCESS_NAME"; then + echo "Process '$VALIDATOR_PROCESS_NAME' is already running. Deleting it..." + pm2 delete $VALIDATOR_PROCESS_NAME fi echo "Verifying access to synthetic image generation models. This may take a few minutes." @@ -33,13 +36,31 @@ if ! python3 bitmind/validator/verify_models.py; then exit 1 fi -# Start the process with arguments from environment variables -pm2 start neurons/validator.py --name bitmind_validator -- \ +echo "Starting validator process" +pm2 start neurons/validator.py --name $VALIDATOR_PROCESS_NAME -- \ --netuid $NETUID \ --subtensor.network $SUBTENSOR_NETWORK \ --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \ --wallet.name $WALLET_NAME \ --wallet.hotkey $WALLET_HOTKEY \ --axon.port $VALIDATOR_AXON_PORT \ - --proxy.port $VALIDATOR_PROXY_PORT \ - --neuron.device $DEVICE + --proxy.port $VALIDATOR_PROXY_PORT + +# REAL DATA CACHE UPDATER PROCESS +if pm2 list | grep -q "$CACHE_UPDATE_PROCESS_NAME"; then + echo "Process '$CACHE_UPDATE_PROCESS_NAME' is already running. Deleting it..." + pm2 delete $CACHE_UPDATE_PROCESS_NAME +fi + +echo "Starting real data cache updater process" +pm2 start bitmind/validator/scripts/run_cache_updater.py --name $CACHE_UPDATE_PROCESS_NAME + +# SYNTHETIC DATA GENERATOR PROCESS +if pm2 list | grep -q "$DATA_GEN_PROCESS_NAME"; then + echo "Process '$DATA_GEN_PROCESS_NAME' is already running. Deleting it..." + pm2 delete $DATA_GEN_PROCESS_NAME +fi + +echo "Starting synthetic data generation process" +pm2 start bitmind/validator/scripts/run_data_generator.py --name $DATA_GEN_PROCESS_NAME -- \ + --device $DEVICE diff --git a/tests/fixtures/image_transforms.py b/tests/fixtures/image_transforms.py index 985dce77..e4f35a71 100644 --- a/tests/fixtures/image_transforms.py +++ b/tests/fixtures/image_transforms.py @@ -1,8 +1,8 @@ from functools import partial import torchvision.transforms as transforms -from bitmind.constants import TARGET_IMAGE_SIZE -from bitmind.image_transforms import ( +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import ( center_crop, RandomResizedCropWithParams, RandomHorizontalFlipWithParams, @@ -10,8 +10,8 @@ RandomRotationWithParams, ConvertToRGB, ComposeWithParams, - base_transforms, - random_aug_transforms + get_base_transforms, + get_random_augmentations ) @@ -25,6 +25,6 @@ ] TRANSFORM_PIPELINES = [ - base_transforms, - random_aug_transforms + get_base_transforms(TARGET_IMAGE_SIZE), + get_random_augmentations(TARGET_IMAGE_SIZE) ] \ No newline at end of file diff --git a/tests/validator/test_generate_image.py b/tests/validator/test_generate_image.py index 4cc8728e..f4cd1705 100644 --- a/tests/validator/test_generate_image.py +++ b/tests/validator/test_generate_image.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, MagicMock -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.constants import DIFFUSER_NAMES +from bitmind.synthetic_data_generation.synthetic_data_generator import SyntheticDataGenerator +from bitmind.validator.config import T2I_MODEL_NAMES from PIL import Image @@ -38,7 +38,7 @@ def mock_image_annotation_generator(): yield instance -@pytest.mark.parametrize("diffuser_name", DIFFUSER_NAMES) +@pytest.mark.parametrize("diffuser_name", T2I_MODEL_NAMES) def test_generate_image_with_diffusers(mock_diffuser, mock_image_annotation_generator, diffuser_name): """ Test the image generation process using different diffusion models. @@ -64,7 +64,7 @@ def test_generate_image_with_diffusers(mock_diffuser, mock_image_annotation_gene - Validating the image generation process - Integration testing with different diffuser models """ - generator = SyntheticImageGenerator( + generator = SyntheticDataGenerator( prompt_type='annotation', use_random_diffuser=False, diffuser_name=diffuser_name diff --git a/tests/validator/test_verify_models.py b/tests/validator/test_verify_models.py index ef14cbce..8e669692 100644 --- a/tests/validator/test_verify_models.py +++ b/tests/validator/test_verify_models.py @@ -2,7 +2,7 @@ import os from unittest.mock import patch, MagicMock, call from bitmind.validator.verify_models import is_model_cached, main -from bitmind.constants import DIFFUSER_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL +from bitmind.validator.config import T2I_MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL @pytest.fixture def mock_expanduser(): @@ -88,7 +88,7 @@ def test_main(mock_is_model_cached, MockSyntheticImageGenerator): # Expected calls with varying parameters based on model type expected_calls = [ call(prompt_type='annotation', use_random_diffuser=True, diffuser_name=None), # For IMAGE_ANNOTATION_MODEL and TEXT_MODERATION_MODEL - *[call(prompt_type='annotation', use_random_diffuser=False, diffuser_name=name) for name in DIFFUSER_NAMES] # For each name in DIFFUSER_NAMES + *[call(prompt_type='annotation', use_random_diffuser=False, diffuser_name=name) for name in T2I_MODEL_NAMES] # For each name in T2I_MODEL_NAMES ] # Verify all calls to SyntheticImageGenerator with the correct parameters From cf51d56b4fb537963babafa783b540ccdc8f6458 Mon Sep 17 00:00:00 2001 From: Dylan Uys Date: Thu, 5 Dec 2024 18:46:46 -0800 Subject: [PATCH 2/3] Proxy Hotfix (#129) * new synapse prep fn * correcting arg --- neurons/validator_proxy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neurons/validator_proxy.py b/neurons/validator_proxy.py index a4bc478d..5a9d35b1 100644 --- a/neurons/validator_proxy.py +++ b/neurons/validator_proxy.py @@ -22,7 +22,7 @@ from bitmind.validator.config import TARGET_IMAGE_SIZE from bitmind.utils.image_transforms import get_base_transforms -from bitmind.protocol import ImageSynapse, prepare_image_synapse +from bitmind.protocol import ImageSynapse, prepare_synapse from bitmind.utils.uids import get_random_uids from bitmind.validator.proxy import ProxyCounter import bitmind @@ -146,7 +146,7 @@ async def forward(self, request: Request): bt.logging.info(f"[ORGANIC] Querying {len(miner_uids)} miners...") predictions = await self.dendrite( axons=[metagraph.axons[uid] for uid in miner_uids], - synapse=prepare_image_synapse(image=image), + synapse=prepare_synapse(image, modality='image'), deserialize=True, timeout=9 ) From f32fe3344a7d5f118beae07c966e45f267b74ae7 Mon Sep 17 00:00:00 2001 From: Dylan Uys Date: Thu, 5 Dec 2024 21:25:38 -0800 Subject: [PATCH 3/3] Update Validating.md --- docs/Validating.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Validating.md b/docs/Validating.md index e443e518..265c5c9f 100644 --- a/docs/Validating.md +++ b/docs/Validating.md @@ -34,8 +34,8 @@ Install the remaining necessary requirements with the following chained command. ```bash conda activate bitmind export PIP_NO_CACHE_DIR=1 -chmod +x setup_validator_env.sh -./setup_validator_env.sh +chmod +x setup_env.sh +./setup_env.sh ``` ## Registration @@ -109,4 +109,4 @@ The above command will kick off 3 `pm2` processes - `run_neuron` manages self heals and auto updates - `bitmind_validator` is the validator process, whose hotkey, port, etc. are configured in `validator.env` - `bitmind_data_generator` runs our synthetic data generation pipeline to produce synthetic images and videos. - - These data are stored in `~/.cache/sn34` and are sampled by the `bitmind_validator` process \ No newline at end of file + - These data are stored in `~/.cache/sn34` and are sampled by the `bitmind_validator` process