diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 72ee262c..86af6326 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -21,6 +21,8 @@ from requests.exceptions import ConnectionError from requests.models import Response +from pyroengine.utils import box_iou, nms + from .vision import Classifier __all__ = ["Engine"] @@ -97,7 +99,7 @@ def __init__( cam_creds: Optional[Dict[str, Dict[str, str]]] = None, latitude: Optional[float] = None, longitude: Optional[float] = None, - alert_relaxation: int = 3, + nb_consecutive_frames: int = 4, frame_size: Optional[Tuple[int, int]] = None, cache_backup_period: int = 60, frame_saving_period: Optional[int] = None, @@ -127,7 +129,7 @@ def __init__( # Cache & relaxation self.frame_saving_period = frame_saving_period - self.alert_relaxation = alert_relaxation + self.nb_consecutive_frames = nb_consecutive_frames self.frame_size = frame_size self.jpeg_quality = jpeg_quality self.cache_backup_period = cache_backup_period @@ -138,11 +140,15 @@ def __init__( # Var initialization self._states: Dict[str, Dict[str, Any]] = { - "-1": {"consec": 0, "frame_count": 0, "ongoing": False}, + "-1": {"last_predictions": deque([], self.nb_consecutive_frames), "frame_count": 0, "ongoing": False}, } if isinstance(cam_creds, dict): for cam_id in cam_creds: - self._states[cam_id] = {"consec": 0, "frame_count": 0, "ongoing": False} + self._states[cam_id] = { + "last_predictions": deque([], self.nb_consecutive_frames), + "frame_count": 0, + "ongoing": False, + } # Restore pending alerts cache self._alerts: deque = deque([], cache_size) @@ -153,7 +159,7 @@ def __init__( def clear_cache(self) -> None: """Clear local cache""" - for file in self._cache.rglob("*"): + for file in self._cache.rglob("pending*"): file.unlink() def _dump_cache(self) -> None: @@ -178,6 +184,7 @@ def _dump_cache(self) -> None: "frame_path": str(self._cache.joinpath(f"pending_frame{idx}.jpg")), "cam_id": info["cam_id"], "ts": info["ts"], + "localization": info["localization"], } ) @@ -202,27 +209,49 @@ def heartbeat(self, cam_id: str) -> Response: """Updates last ping of device""" return self.api_client[cam_id].heartbeat() - def _update_states(self, conf: float, cam_key: str) -> bool: + def _update_states(self, frame: Image.Image, preds: np.array, cam_key: str) -> bool: """Updates the detection states""" - # Detection - if conf >= self.conf_thresh: - # Don't increment beyond relaxation - if not self._states[cam_key]["ongoing"] and self._states[cam_key]["consec"] < self.alert_relaxation: - self._states[cam_key]["consec"] += 1 - if self._states[cam_key]["consec"] == self.alert_relaxation: - self._states[cam_key]["ongoing"] = True + conf_th = self.conf_thresh * self.nb_consecutive_frames + # Reduce threshold once we are in alert mode to collect more data + if self._states[cam_key]["ongoing"]: + conf_th *= 0.8 + + # Get last predictions + boxes = np.zeros((0, 5)) + boxes = np.concatenate([boxes, preds]) + for _, box, _, _, _ in self._states[cam_key]["last_predictions"]: + if box.shape[0] > 0: + boxes = np.concatenate([boxes, box]) + + conf = 0 + output_predictions = np.zeros((0, 5)) + # Get the best ones + if boxes.shape[0]: + best_boxes = nms(boxes) + ious = box_iou(best_boxes[:, :4], boxes[:, :4]) + best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T]) + combine_predictions = best_boxes[best_boxes_scores > conf_th, :] + conf = np.max(best_boxes_scores) / self.nb_consecutive_frames + + # if current predictions match with combine predictions send match else send combine predcition + ious = box_iou(combine_predictions[:, :4], preds[:, :4])[0] + if np.sum(ious) > 0: + output_predictions = preds + else: + output_predictions = combine_predictions + + self._states[cam_key]["last_predictions"].append( + (frame, preds, str(json.dumps(output_predictions.tolist())), datetime.utcnow().isoformat(), False) + ) - return self._states[cam_key]["ongoing"] - # No wildfire + # update state + if conf > self.conf_thresh: + self._states[cam_key]["ongoing"] = True else: - if self._states[cam_key]["consec"] > 0: - self._states[cam_key]["consec"] -= 1 - # Consider event as finished - if self._states[cam_key]["consec"] == 0: - self._states[cam_key]["ongoing"] = False + self._states[cam_key]["ongoing"] = False - return False + return conf def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: """Computes the confidence that the image contains wildfire cues @@ -245,28 +274,29 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: # Reduce image size to save bandwidth if isinstance(self.frame_size, tuple): frame_resize = frame.resize(self.frame_size[::-1], Image.BILINEAR) + else: + frame_resize = frame if is_day_time(self._cache, frame, self.day_time_strategy): # Inference with ONNX preds = self.model(frame.convert("RGB")) - if len(preds) == 0: - conf = 0 - localization = "" - else: - conf = float(np.max(preds[:, -1])) - localization = str(json.dumps(preds.tolist())) + conf = self._update_states(frame_resize, preds, cam_key) # Log analysis result device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else "" - pred_str = "Wildfire detected" if conf >= self.conf_thresh else "No wildfire" + pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire" logging.info(f"{device_str}{pred_str} (confidence: {conf:.2%})") # Alert - - to_be_staged = self._update_states(conf, cam_key) - if to_be_staged and len(self.api_client) > 0 and isinstance(cam_id, str): + if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str): # Save the alert in cache to avoid connection issues - self._stage_alert(frame_resize, cam_id, localization) + for idx, (frame, preds, localization, ts, is_staged) in enumerate( + self._states[cam_key]["last_predictions"] + ): + if not is_staged: + self._stage_alert(frame, cam_id, ts, localization) + self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True + else: conf = 0 # return default value @@ -310,13 +340,13 @@ def _upload_frame(self, cam_id: str, media_data: bytes) -> Response: return response - def _stage_alert(self, frame: Image.Image, cam_id: str, localization: str) -> None: + def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: str) -> None: # Store information in the queue self._alerts.append( { "frame": frame, "cam_id": cam_id, - "ts": datetime.utcnow().isoformat(), + "ts": ts, "media_id": None, "alert_id": None, "localization": localization, diff --git a/pyroengine/vision.py b/pyroengine/vision.py index 53e2d2ec..85801ec0 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -74,5 +74,7 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray: if len(y) > 0: y[:, :4:2] /= self.img_size[1] y[:, 1:4:2] /= self.img_size[0] + else: + y = np.zeros((0, 5)) # normalize output return y diff --git a/src/run.py b/src/run.py index 9bdf5fb9..a2b6a1d1 100644 --- a/src/run.py +++ b/src/run.py @@ -55,7 +55,7 @@ def main(args): frame_saving_period=args.save_period // args.period, cache_folder=args.cache, backup_size=args.backup_size, - alert_relaxation=args.alert_relaxation, + nb_consecutive_frames=args.nb_consecutive_frames, frame_size=args.frame_size, cache_backup_period=args.cache_backup_period, cache_size=args.cache_size, @@ -94,10 +94,10 @@ def main(args): parser.add_argument("--jpeg_quality", type=int, default=80, help="Jpeg compression") parser.add_argument("--cache-size", type=int, default=20, help="Maximum number of alerts to save in cache") parser.add_argument( - "--alert_relaxation", + "--nb-consecutive_frames", type=int, default=3, - help="Number of consecutive positive detections required to send the first alert", + help="Number of consecutive frames to combine for prediction", ) parser.add_argument( "--cache_backup_period", type=int, default=60, help="Number of minutes between each cache backup to disk" diff --git a/tests/test_engine.py b/tests/test_engine.py index c1e3d17d..15fcb852 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -4,6 +4,7 @@ from pathlib import Path from dotenv import load_dotenv +from PIL import Image from pyroengine.engine import Engine @@ -16,7 +17,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): # Cache saving _ts = datetime.utcnow().isoformat() - engine._stage_alert(mock_wildfire_image, 0, localization="dummy") + engine._stage_alert(mock_wildfire_image, 0, datetime.utcnow().isoformat(), localization="dummy") assert len(engine._alerts) == 1 assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"] assert engine._alerts[0]["media_id"] is None @@ -32,6 +33,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): "frame_path": str(engine._cache.joinpath("pending_frame0.jpg")), "cam_id": 0, "ts": engine._alerts[0]["ts"], + "localization": "dummy", } # Overrites cache files engine._dump_cache() @@ -42,21 +44,40 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): engine.clear_cache() # inference - engine = Engine(alert_relaxation=3, cache_folder=folder) + engine = Engine(nb_consecutive_frames=4, cache_folder=folder) out = engine.predict(mock_forest_image) assert isinstance(out, float) and 0 <= out <= 1 - assert engine._states["-1"]["consec"] == 0 - out = engine.predict(mock_wildfire_image) + assert len(engine._states["-1"]["last_predictions"]) == 1 + assert engine._states["-1"]["frame_count"] == 0 + assert engine._states["-1"]["ongoing"] is False + assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image) + assert engine._states["-1"]["last_predictions"][0][1].shape[0] == 0 + assert engine._states["-1"]["last_predictions"][0][1].shape[1] == 5 + assert engine._states["-1"]["last_predictions"][0][2] == "[]" + assert engine._states["-1"]["last_predictions"][0][3] < datetime.utcnow().isoformat() + assert engine._states["-1"]["last_predictions"][0][4] is False - assert isinstance(out, float) and 0 <= out <= 1 - assert engine._states["-1"]["consec"] == 1 - # Alert relaxation - assert not engine._states["-1"]["ongoing"] out = engine.predict(mock_wildfire_image) - assert engine._states["-1"]["consec"] == 2 + assert isinstance(out, float) and 0 <= out <= 1 + assert len(engine._states["-1"]["last_predictions"]) == 2 + assert engine._states["-1"]["ongoing"] is False + assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image) + assert engine._states["-1"]["last_predictions"][1][1].shape[0] > 0 + assert engine._states["-1"]["last_predictions"][1][1].shape[1] == 5 + assert engine._states["-1"]["last_predictions"][1][2] == "[]" + assert engine._states["-1"]["last_predictions"][1][3] < datetime.utcnow().isoformat() + assert engine._states["-1"]["last_predictions"][1][4] is False + out = engine.predict(mock_wildfire_image) - assert engine._states["-1"]["consec"] == 3 - assert engine._states["-1"]["ongoing"] + assert isinstance(out, float) and 0 <= out <= 1 + assert len(engine._states["-1"]["last_predictions"]) == 3 + assert engine._states["-1"]["ongoing"] is True + assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image) + assert engine._states["-1"]["last_predictions"][2][1].shape[0] > 0 + assert engine._states["-1"]["last_predictions"][2][1].shape[1] == 5 + assert len(engine._states["-1"]["last_predictions"][-1][2].split(" ")) == 5 + assert engine._states["-1"]["last_predictions"][2][3] < datetime.utcnow().isoformat() + assert engine._states["-1"]["last_predictions"][2][4] is False def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image): @@ -76,7 +97,7 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image cam_creds=cam_creds, latitude=float(lat), longitude=float(lon), - alert_relaxation=2, + nb_consecutive_frames=4, frame_saving_period=3, cache_folder=folder, frame_size=(256, 384), @@ -90,10 +111,14 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image assert start_ts < json_respone["last_ping"] < ts # Send an alert engine.predict(mock_wildfire_image, "dummy_cam") - assert len(engine._alerts) == 0 and engine._states["dummy_cam"]["consec"] == 1 - assert engine._states["dummy_cam"]["frame_count"] == 1 + assert len(engine._states["dummy_cam"]["last_predictions"]) == 1 + assert len(engine._alerts) == 0 + assert engine._states["dummy_cam"]["ongoing"] is False + engine.predict(mock_wildfire_image, "dummy_cam") - assert engine._states["dummy_cam"]["consec"] == 2 and engine._states["dummy_cam"]["ongoing"] + assert len(engine._states["dummy_cam"]["last_predictions"]) == 2 + + assert engine._states["dummy_cam"]["ongoing"] is True assert engine._states["dummy_cam"]["frame_count"] == 2 # Check that a media and an alert have been registered assert len(engine._alerts) == 0