Skip to content

Commit

Permalink
New alerte strategy (#164)
Browse files Browse the repository at this point in the history
* new alert system

* fix unitests
  • Loading branch information
MateoLostanlen authored Jul 31, 2023
1 parent d5a4d8c commit 6501eec
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 52 deletions.
98 changes: 64 additions & 34 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from requests.exceptions import ConnectionError
from requests.models import Response

from pyroengine.utils import box_iou, nms

from .vision import Classifier

__all__ = ["Engine"]
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
alert_relaxation: int = 3,
nb_consecutive_frames: int = 4,
frame_size: Optional[Tuple[int, int]] = None,
cache_backup_period: int = 60,
frame_saving_period: Optional[int] = None,
Expand Down Expand Up @@ -127,7 +129,7 @@ def __init__(

# Cache & relaxation
self.frame_saving_period = frame_saving_period
self.alert_relaxation = alert_relaxation
self.nb_consecutive_frames = nb_consecutive_frames
self.frame_size = frame_size
self.jpeg_quality = jpeg_quality
self.cache_backup_period = cache_backup_period
Expand All @@ -138,11 +140,15 @@ def __init__(

# Var initialization
self._states: Dict[str, Dict[str, Any]] = {
"-1": {"consec": 0, "frame_count": 0, "ongoing": False},
"-1": {"last_predictions": deque([], self.nb_consecutive_frames), "frame_count": 0, "ongoing": False},
}
if isinstance(cam_creds, dict):
for cam_id in cam_creds:
self._states[cam_id] = {"consec": 0, "frame_count": 0, "ongoing": False}
self._states[cam_id] = {
"last_predictions": deque([], self.nb_consecutive_frames),
"frame_count": 0,
"ongoing": False,
}

# Restore pending alerts cache
self._alerts: deque = deque([], cache_size)
Expand All @@ -153,7 +159,7 @@ def __init__(

def clear_cache(self) -> None:
"""Clear local cache"""
for file in self._cache.rglob("*"):
for file in self._cache.rglob("pending*"):
file.unlink()

def _dump_cache(self) -> None:
Expand All @@ -178,6 +184,7 @@ def _dump_cache(self) -> None:
"frame_path": str(self._cache.joinpath(f"pending_frame{idx}.jpg")),
"cam_id": info["cam_id"],
"ts": info["ts"],
"localization": info["localization"],
}
)

Expand All @@ -202,27 +209,49 @@ def heartbeat(self, cam_id: str) -> Response:
"""Updates last ping of device"""
return self.api_client[cam_id].heartbeat()

def _update_states(self, conf: float, cam_key: str) -> bool:
def _update_states(self, frame: Image.Image, preds: np.array, cam_key: str) -> bool:
"""Updates the detection states"""
# Detection
if conf >= self.conf_thresh:
# Don't increment beyond relaxation
if not self._states[cam_key]["ongoing"] and self._states[cam_key]["consec"] < self.alert_relaxation:
self._states[cam_key]["consec"] += 1

if self._states[cam_key]["consec"] == self.alert_relaxation:
self._states[cam_key]["ongoing"] = True
conf_th = self.conf_thresh * self.nb_consecutive_frames
# Reduce threshold once we are in alert mode to collect more data
if self._states[cam_key]["ongoing"]:
conf_th *= 0.8

# Get last predictions
boxes = np.zeros((0, 5))
boxes = np.concatenate([boxes, preds])
for _, box, _, _, _ in self._states[cam_key]["last_predictions"]:
if box.shape[0] > 0:
boxes = np.concatenate([boxes, box])

conf = 0
output_predictions = np.zeros((0, 5))
# Get the best ones
if boxes.shape[0]:
best_boxes = nms(boxes)
ious = box_iou(best_boxes[:, :4], boxes[:, :4])
best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T])
combine_predictions = best_boxes[best_boxes_scores > conf_th, :]
conf = np.max(best_boxes_scores) / self.nb_consecutive_frames

# if current predictions match with combine predictions send match else send combine predcition
ious = box_iou(combine_predictions[:, :4], preds[:, :4])[0]
if np.sum(ious) > 0:
output_predictions = preds
else:
output_predictions = combine_predictions

self._states[cam_key]["last_predictions"].append(
(frame, preds, str(json.dumps(output_predictions.tolist())), datetime.utcnow().isoformat(), False)
)

return self._states[cam_key]["ongoing"]
# No wildfire
# update state
if conf > self.conf_thresh:
self._states[cam_key]["ongoing"] = True
else:
if self._states[cam_key]["consec"] > 0:
self._states[cam_key]["consec"] -= 1
# Consider event as finished
if self._states[cam_key]["consec"] == 0:
self._states[cam_key]["ongoing"] = False
self._states[cam_key]["ongoing"] = False

return False
return conf

def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
"""Computes the confidence that the image contains wildfire cues
Expand All @@ -245,28 +274,29 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
# Reduce image size to save bandwidth
if isinstance(self.frame_size, tuple):
frame_resize = frame.resize(self.frame_size[::-1], Image.BILINEAR)
else:
frame_resize = frame

if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
preds = self.model(frame.convert("RGB"))
if len(preds) == 0:
conf = 0
localization = ""
else:
conf = float(np.max(preds[:, -1]))
localization = str(json.dumps(preds.tolist()))
conf = self._update_states(frame_resize, preds, cam_key)

# Log analysis result
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if conf >= self.conf_thresh else "No wildfire"
pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

# Alert

to_be_staged = self._update_states(conf, cam_key)
if to_be_staged and len(self.api_client) > 0 and isinstance(cam_id, str):
if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
self._stage_alert(frame_resize, cam_id, localization)
for idx, (frame, preds, localization, ts, is_staged) in enumerate(
self._states[cam_key]["last_predictions"]
):
if not is_staged:
self._stage_alert(frame, cam_id, ts, localization)
self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True

else:
conf = 0 # return default value

Expand Down Expand Up @@ -310,13 +340,13 @@ def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:

return response

def _stage_alert(self, frame: Image.Image, cam_id: str, localization: str) -> None:
def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: str) -> None:
# Store information in the queue
self._alerts.append(
{
"frame": frame,
"cam_id": cam_id,
"ts": datetime.utcnow().isoformat(),
"ts": ts,
"media_id": None,
"alert_id": None,
"localization": localization,
Expand Down
2 changes: 2 additions & 0 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,7 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray:
if len(y) > 0:
y[:, :4:2] /= self.img_size[1]
y[:, 1:4:2] /= self.img_size[0]
else:
y = np.zeros((0, 5)) # normalize output

return y
6 changes: 3 additions & 3 deletions src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main(args):
frame_saving_period=args.save_period // args.period,
cache_folder=args.cache,
backup_size=args.backup_size,
alert_relaxation=args.alert_relaxation,
nb_consecutive_frames=args.nb_consecutive_frames,
frame_size=args.frame_size,
cache_backup_period=args.cache_backup_period,
cache_size=args.cache_size,
Expand Down Expand Up @@ -94,10 +94,10 @@ def main(args):
parser.add_argument("--jpeg_quality", type=int, default=80, help="Jpeg compression")
parser.add_argument("--cache-size", type=int, default=20, help="Maximum number of alerts to save in cache")
parser.add_argument(
"--alert_relaxation",
"--nb-consecutive_frames",
type=int,
default=3,
help="Number of consecutive positive detections required to send the first alert",
help="Number of consecutive frames to combine for prediction",
)
parser.add_argument(
"--cache_backup_period", type=int, default=60, help="Number of minutes between each cache backup to disk"
Expand Down
55 changes: 40 additions & 15 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

from dotenv import load_dotenv
from PIL import Image

from pyroengine.engine import Engine

Expand All @@ -16,7 +17,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# Cache saving
_ts = datetime.utcnow().isoformat()
engine._stage_alert(mock_wildfire_image, 0, localization="dummy")
engine._stage_alert(mock_wildfire_image, 0, datetime.utcnow().isoformat(), localization="dummy")
assert len(engine._alerts) == 1
assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"]
assert engine._alerts[0]["media_id"] is None
Expand All @@ -32,6 +33,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
"frame_path": str(engine._cache.joinpath("pending_frame0.jpg")),
"cam_id": 0,
"ts": engine._alerts[0]["ts"],
"localization": "dummy",
}
# Overrites cache files
engine._dump_cache()
Expand All @@ -42,21 +44,40 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
engine.clear_cache()

# inference
engine = Engine(alert_relaxation=3, cache_folder=folder)
engine = Engine(nb_consecutive_frames=4, cache_folder=folder)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 0
out = engine.predict(mock_wildfire_image)
assert len(engine._states["-1"]["last_predictions"]) == 1
assert engine._states["-1"]["frame_count"] == 0
assert engine._states["-1"]["ongoing"] is False
assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image)
assert engine._states["-1"]["last_predictions"][0][1].shape[0] == 0
assert engine._states["-1"]["last_predictions"][0][1].shape[1] == 5
assert engine._states["-1"]["last_predictions"][0][2] == "[]"
assert engine._states["-1"]["last_predictions"][0][3] < datetime.utcnow().isoformat()
assert engine._states["-1"]["last_predictions"][0][4] is False

assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 1
# Alert relaxation
assert not engine._states["-1"]["ongoing"]
out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 2
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 2
assert engine._states["-1"]["ongoing"] is False
assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image)
assert engine._states["-1"]["last_predictions"][1][1].shape[0] > 0
assert engine._states["-1"]["last_predictions"][1][1].shape[1] == 5
assert engine._states["-1"]["last_predictions"][1][2] == "[]"
assert engine._states["-1"]["last_predictions"][1][3] < datetime.utcnow().isoformat()
assert engine._states["-1"]["last_predictions"][1][4] is False

out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 3
assert engine._states["-1"]["ongoing"]
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 3
assert engine._states["-1"]["ongoing"] is True
assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image)
assert engine._states["-1"]["last_predictions"][2][1].shape[0] > 0
assert engine._states["-1"]["last_predictions"][2][1].shape[1] == 5
assert len(engine._states["-1"]["last_predictions"][-1][2].split(" ")) == 5
assert engine._states["-1"]["last_predictions"][2][3] < datetime.utcnow().isoformat()
assert engine._states["-1"]["last_predictions"][2][4] is False


def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image):
Expand All @@ -76,7 +97,7 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image
cam_creds=cam_creds,
latitude=float(lat),
longitude=float(lon),
alert_relaxation=2,
nb_consecutive_frames=4,
frame_saving_period=3,
cache_folder=folder,
frame_size=(256, 384),
Expand All @@ -90,10 +111,14 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image
assert start_ts < json_respone["last_ping"] < ts
# Send an alert
engine.predict(mock_wildfire_image, "dummy_cam")
assert len(engine._alerts) == 0 and engine._states["dummy_cam"]["consec"] == 1
assert engine._states["dummy_cam"]["frame_count"] == 1
assert len(engine._states["dummy_cam"]["last_predictions"]) == 1
assert len(engine._alerts) == 0
assert engine._states["dummy_cam"]["ongoing"] is False

engine.predict(mock_wildfire_image, "dummy_cam")
assert engine._states["dummy_cam"]["consec"] == 2 and engine._states["dummy_cam"]["ongoing"]
assert len(engine._states["dummy_cam"]["last_predictions"]) == 2

assert engine._states["dummy_cam"]["ongoing"] is True
assert engine._states["dummy_cam"]["frame_count"] == 2
# Check that a media and an alert have been registered
assert len(engine._alerts) == 0
Expand Down

0 comments on commit 6501eec

Please sign in to comment.