From e46b035c8c427ac795a6a8a6e81165db80ae07b3 Mon Sep 17 00:00:00 2001
From: Breakthrough
Date: Mon, 18 Nov 2024 22:04:17 -0500
Subject: [PATCH] [detectors] Implement Koala-36M

Add `KoalaDetector` and `detect-koala` command. #441
---
Review notes (text between the `---` line and the diffstat is ignored by `git am`):
 - The patch is laid out one patch line per text line again; whitespace-only context lines
   were re-inserted where the hunk headers require them.
 - requirements_headless.txt keeps its trailing newline, so that hunk is a pure insertion.
 - koala_detector.py: `min_scene_len=None` is accepted, short inputs no longer raise in
   `post_process`, and the cut loop no longer rebinds its own loop variable.
 - The `index` blob ids of requirements_headless.txt and koala_detector.py still refer to the
   content before these review changes.

 dist/requirements_windows.txt           |   1 +
 requirements.txt                        |   1 +
 requirements_headless.txt               |   1 +
 scenedetect/_cli/__init__.py            |  14 +++
 scenedetect/detectors/__init__.py       |   1 +
 scenedetect/detectors/koala_detector.py | 133 ++++++++++++++++++++++++
 tests/test_detectors.py                 |   2 +
 7 files changed, 153 insertions(+)
 create mode 100644 scenedetect/detectors/koala_detector.py

diff --git a/dist/requirements_windows.txt b/dist/requirements_windows.txt
index a5debd55..1623bee4 100644
--- a/dist/requirements_windows.txt
+++ b/dist/requirements_windows.txt
@@ -2,6 +2,7 @@
 av==13.1.0
 click>=8.0
 opencv-python-headless==4.10.0.84
+scikit-image==0.24.0
 
 imageio-ffmpeg
 moviepy
diff --git a/requirements.txt b/requirements.txt
index 2c45e1a5..7ddebe6c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ opencv-python
 platformdirs
 pytest>=7.0
 tqdm
+scikit-image
diff --git a/requirements_headless.txt b/requirements_headless.txt
index 4dfedd38..2ff962b0 100644
--- a/requirements_headless.txt
+++ b/requirements_headless.txt
@@ -7,4 +7,5 @@ numpy
 opencv-python-headless
 platformdirs
 pytest>=7.0
+scikit-image
 tqdm
diff --git a/scenedetect/_cli/__init__.py b/scenedetect/_cli/__init__.py
index c26b6263..9e536a78 100644
--- a/scenedetect/_cli/__init__.py
+++ b/scenedetect/_cli/__init__.py
@@ -42,6 +42,7 @@
     ContentDetector,
     HashDetector,
     HistogramDetector,
+    KoalaDetector,
     ThresholdDetector,
 )
 from scenedetect.platform import get_cv2_imwrite_params, get_system_version_info
@@ -1577,3 +1578,16 @@ def save_qp_command(
 scenedetect.add_command(list_scenes_command)
 scenedetect.add_command(save_images_command)
 scenedetect.add_command(split_video_command)
+
+
+@click.command("detect-koala", cls=Command, help="""WIP""")
+@click.pass_context
+def detect_koala_command(
+    ctx: click.Context,
+):
+    ctx = ctx.obj
+    assert isinstance(ctx, CliContext)
+    ctx.add_detector(KoalaDetector, {"min_scene_len": None})
+
+
+scenedetect.add_command(detect_koala_command)
diff --git a/scenedetect/detectors/__init__.py b/scenedetect/detectors/__init__.py
index a87a5689..0856bc3c 100644
--- a/scenedetect/detectors/__init__.py
+++ b/scenedetect/detectors/__init__.py
@@ -40,6 +40,7 @@
 from scenedetect.detectors.adaptive_detector import AdaptiveDetector
 from scenedetect.detectors.hash_detector import HashDetector
 from scenedetect.detectors.histogram_detector import HistogramDetector
+from scenedetect.detectors.koala_detector import KoalaDetector
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 #                                                                             #
diff --git a/scenedetect/detectors/koala_detector.py b/scenedetect/detectors/koala_detector.py
new file mode 100644
index 00000000..8cd956a8
--- /dev/null
+++ b/scenedetect/detectors/koala_detector.py
@@ -0,0 +1,133 @@
+#
+# PySceneDetect: Python-Based Video Scene Detector
+# -------------------------------------------------------------------
+#   [  Site:    https://scenedetect.com                           ]
+#   [  Docs:    https://scenedetect.com/docs/                     ]
+#   [  Github:  https://github.com/Breakthrough/PySceneDetect/    ]
+#
+# Copyright (C) 2014-2024 Brandon Castellano .
+# PySceneDetect is licensed under the BSD 3-Clause License; see the
+# included LICENSE file, or visit one of the above pages for details.
+#
+""":class:`KoalaDetector` uses the detection method described by Koala-36M.
+See https://koala36m.github.io/ for details.
+
+TODO: Cite correctly.
+
+This detector is available from the command-line as the `detect-koala` command.
+"""
+
+import typing as ty
+
+import cv2
+import numpy as np
+from skimage.metrics import structural_similarity
+
+from scenedetect.scene_detector import SceneDetector
+
+
+class KoalaDetector(SceneDetector):
+    """Detects cuts by scoring how similar each frame is to the one before it.
+
+    Two measures are combined into one score per pair of consecutive frames: the correlation
+    of their colour histograms and the structural similarity of their edge-enhanced grayscale
+    images. Scores are only analysed once the whole video has been processed, so every cut is
+    reported by :meth:`post_process`.
+    """
+
+    # Number of preceding scores a candidate transition is compared against. A scene must
+    # also span more than this many scores to be reported.
+    _WINDOW_SIZE = 8
+
+    def __init__(self, min_scene_len: ty.Optional[int] = None):
+        """
+        Arguments:
+            min_scene_len: Minimum number of frames between two reported cuts. `None` is
+                treated as 0 (no minimum).
+        """
+        super().__init__()
+        self._start_frame_num: ty.Optional[int] = None
+        self._min_scene_len: int = min_scene_len if min_scene_len else 0
+        self._last_histogram: ty.Optional[np.ndarray] = None
+        self._last_edges: ty.Optional[np.ndarray] = None
+        self._scores: ty.List[float] = []
+
+    def process_frame(self, frame_num: int, frame_img: np.ndarray) -> ty.List[int]:
+        """Score `frame_img` against the previously processed frame.
+
+        Arguments:
+            frame_num: Frame number of `frame_img`.
+            frame_img: Frame to process. It is converted with `cv2.COLOR_BGR2GRAY`, so a
+                3-channel BGR image is expected.
+
+        Returns:
+            Always an empty list; cuts are only known after :meth:`post_process`.
+        """
+        frame_img = cv2.resize(frame_img, (256, 256))
+        # One 254-bin histogram per channel over intensities 1..254, so fully black (0) and
+        # fully white (255) pixels do not contribute.
+        histogram = np.asarray(
+            [cv2.calcHist([c], [0], None, [254], [1, 255]) for c in cv2.split(frame_img)]
+        )
+        frame_gray = cv2.resize(cv2.cvtColor(frame_img, cv2.COLOR_BGR2GRAY), (128, 128))
+        # Overlay the Canny edge map on the grayscale image (edge pixels become 255).
+        edges = np.maximum(frame_gray, cv2.Canny(frame_gray, 100, 200))
+        if self._start_frame_num is None:
+            # Very first frame: there is nothing to compare against yet.
+            self._start_frame_num = frame_num
+        else:
+            delta_histogram = cv2.compareHist(self._last_histogram, histogram, cv2.HISTCMP_CORREL)
+            delta_edges = structural_similarity(self._last_edges, edges, data_range=255)
+            # Weighted sum of both similarities; negative scores always mark a transition.
+            score = 4.61480465 * delta_histogram + 3.75211168 * delta_edges - 5.485968377115124
+            self._scores.append(score)
+        self._last_histogram = histogram
+        self._last_edges = edges
+        return []
+
+    def post_process(self, frame_num: int) -> ty.List[int]:
+        """Analyse every score gathered by :meth:`process_frame` and return the cuts found.
+
+        Arguments:
+            frame_num: Unused; accepted for compatibility with the `SceneDetector` interface.
+
+        Returns:
+            Frame numbers at which a new scene starts, in ascending order.
+        """
+        window_size = self._WINDOW_SIZE
+        # Work on a local array so `self._scores` stays a list that can still be appended to.
+        scores = np.asarray(self._scores, dtype=float)
+        num_scores = len(scores)
+        # Smooth with a 3-tap moving average, keeping both end points as-is. The "valid"
+        # convolution needs at least 3 scores, so shorter inputs are left unsmoothed.
+        convolution = scores.copy()
+        if num_scores >= 3:
+            convolution[1:-1] = np.convolve(scores, np.array([1, 1, 1]) / 3.0, mode="valid")
+        # One flag per score, plus a trailing sentinel which closes the final scene.
+        cut_found = np.zeros(num_scores + 1, bool)
+        cut_found[-1] = True
+        for index in range(num_scores):
+            if scores[index] < 0 or index < window_size:
+                cut_found[index] = True
+                continue
+            if convolution[index] < 0.75:
+                # Compare against the preceding window with its extremes trimmed off.
+                window = np.sort(convolution[index - window_size : index])
+                window = window[int(window_size * 0.2) : int(window_size * 0.8)]
+                if convolution[index] < window.mean() - 3 * max(0.2, window.std()):
+                    cut_found[index] = True
+        cuts = []
+        scene_start = 0
+        last_filtered_cut = self._start_frame_num
+        for index in range(window_size, len(cut_found)):
+            if not cut_found[index]:
+                continue
+            # The first scene starts with the video itself, which is not a cut. `index` must
+            # stay the loop position so `scene_start` below advances past this transition.
+            if scene_start > 0 and (index - scene_start) > window_size:
+                cut = self._start_frame_num + scene_start
+                if (cut - last_filtered_cut) >= self._min_scene_len:
+                    cuts.append(cut)
+                    last_filtered_cut = cut
+            scene_start = index + 1
+        return cuts
diff --git a/tests/test_detectors.py b/tests/test_detectors.py
index 109872be..c8b19002 100644
--- a/tests/test_detectors.py
+++ b/tests/test_detectors.py
@@ -29,6 +29,7 @@
     ContentDetector,
     HashDetector,
     HistogramDetector,
+    KoalaDetector,
     ThresholdDetector,
 )
 
@@ -37,6 +38,7 @@
     ContentDetector,
     HashDetector,
     HistogramDetector,
+    KoalaDetector,
 )
 
 ALL_DETECTORS: ty.Tuple[ty.Type[SceneDetector]] = (*FAST_CUT_DETECTORS, ThresholdDetector)