Skip to content

Commit

Permalink
localization -> bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronan committed Jul 22, 2024
1 parent ea97769 commit f7b9b59
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dynamic = ["version"]
dependencies = [
"ultralytics==8.2.50",
"opencv-python",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@5da3d23d38cb78a4a4e15cf2f9f83bf2da7cdaee#egg=pyroclient&subdirectory=client",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@f809a399bf8928e93da8e95056e811217f6c2a17#egg=pyroclient&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
Expand Down
22 changes: 10 additions & 12 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _dump_cache(self) -> None:
"cam_id": info["cam_id"],
"pose_id": info["pose_id"],
"ts": info["ts"],
"localization": info["localization"],
"bboxes": info["bboxes"],
}
)

Expand All @@ -188,7 +188,7 @@ def _load_cache(self) -> None:
"frame": frame,
"cam_id": entry["cam_id"],
"pose_id": entry["pose_id"],
"localization": entry["localization"],
"bboxes": entry["bboxes"],
"ts": entry["ts"],
}
)
Expand Down Expand Up @@ -283,12 +283,10 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt
# Alert
if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
for idx, (frame, preds, localization, ts, is_staged) in enumerate(
self._states[cam_key]["last_predictions"]
):
for idx, (frame, preds, bboxes, ts, is_staged) in enumerate(self._states[cam_key]["last_predictions"]):
if not is_staged:
self._stage_alert(frame, cam_id, pose_id, ts, localization)
self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True
self._stage_alert(frame, cam_id, pose_id, ts, bboxes)
self._states[cam_key]["last_predictions"][idx] = frame, preds, bboxes, ts, True

# Check if it's time to backup pending alerts
ts = datetime.now(timezone.utc)
Expand All @@ -299,7 +297,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt
return float(conf)

def _stage_alert(
self, frame: Image.Image, cam_id: Optional[str], pose_id: Optional[int], ts: int, localization: list
self, frame: Image.Image, cam_id: Optional[str], pose_id: Optional[int], ts: int, bboxes: list
) -> None:
# Store information in the queue

Expand All @@ -309,7 +307,7 @@ def _stage_alert(
"cam_id": cam_id,
"pose_id": pose_id,
"ts": ts,
"localization": localization,
"bboxes": bboxes,
}
)

Expand All @@ -331,9 +329,9 @@ def _process_alerts(self, cameras: List[ReolinkCamera]) -> None:
for camera in cameras:
if camera.ip_address == cam_id:
azimuth = camera.cam_azimuths[pose_id - 1] if pose_id is not None else camera.cam_azimuths[0]
localization = self._alerts[0]["localization"]
response = self.api_client[cam_id].create_detection(stream.getvalue(), azimuth, localization)
logging.info(f"Azimuth : {azimuth} , localization : {localization}")
bboxes = self._alerts[0]["bboxes"]
response = self.api_client[cam_id].create_detection(stream.getvalue(), azimuth, bboxes)
logging.info(f"Azimuth : {azimuth} , bboxes : {bboxes}")
break

# Force a KeyError if the request failed
Expand Down
6 changes: 3 additions & 3 deletions src/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = "^3.8"
pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "5da3d23d38cb78a4a4e15cf2f9f83bf2da7cdaee", subdirectory = "client" }
pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "f809a399bf8928e93da8e95056e811217f6c2a17", subdirectory = "client" }
pyroengine = "^0.2.0"
python-dotenv = ">=0.15.0"
ultralytics = "8.2.50"
2 changes: 1 addition & 1 deletion src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ psutil==6.0.0 ; python_version >= "3.8" and python_version < "4.0"
py-cpuinfo==9.0.0 ; python_version >= "3.8" and python_version < "4.0"
pyparsing==3.1.2 ; python_version >= "3.8" and python_version < "4.0"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.8" and python_version < "4"
pyroclient @ git+https://github.com/pyronear/pyro-api.git@5da3d23d38cb78a4a4e15cf2f9f83bf2da7cdaee#subdirectory=client ; python_version >= "3.8" and python_version < "4"
pyroclient @ git+https://github.com/pyronear/pyro-api.git@f809a399bf8928e93da8e95056e811217f6c2a17#subdirectory=client ; python_version >= "3.8" and python_version < "4"
pyroengine==0.2.0 ; python_version >= "3.8" and python_version < "4"
python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0"
python-dotenv==1.0.1 ; python_version >= "3.8" and python_version < "4.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def test_capture_images_ptz(system_controller_ptz):

assert queue.qsize() == 2
cam_id, frame = await queue.get() # Use timeout to wait for the item
assert cam_id == "192.168.1.1_1"
assert cam_id == "192.168.1.1"
assert isinstance(frame, Image.Image)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# Cache saving
_ts = datetime.now().isoformat()
engine._stage_alert(mock_wildfire_image, 0, None, datetime.now().isoformat(), localization="dummy")
engine._stage_alert(mock_wildfire_image, 0, None, datetime.now().isoformat(), bboxes="dummy")
assert len(engine._alerts) == 1
assert engine._alerts[0]["ts"] < datetime.now().isoformat() and _ts < engine._alerts[0]["ts"]

Expand All @@ -33,7 +33,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
"cam_id": 0,
"pose_id": None,
"ts": engine._alerts[0]["ts"],
"localization": "dummy",
"bboxes": "dummy",
}
# Overrites cache files
engine._dump_cache()
Expand Down

0 comments on commit f7b9b59

Please sign in to comment.