From 8b27eb10c8637872c6c0dabe51d0b51445fbd1e9 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Tue, 21 Jan 2025 17:34:35 +0000 Subject: [PATCH] Move extracting of trace points to rust --- Cargo.lock | 7 ++++++ Cargo.toml | 1 + deebot_client/map.py | 29 +++------------------- deebot_client/rs.pyi | 23 ++++++++++++++++++ src/lib.rs | 57 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 874922a1..88f34974 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cfg-if" version = "1.0.0" @@ -25,6 +31,7 @@ name = "deebot_client" version = "0.0.0" dependencies = [ "base64", + "byteorder", "pyo3", "rust-lzma", ] diff --git a/Cargo.toml b/Cargo.toml index 2d927b60..56fc6cbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,5 +11,6 @@ crate-type = ["cdylib"] [dependencies] base64 = "0.22.1" +byteorder = "1.5.0" pyo3 = "0.23.3" rust-lzma = "0.6.0" diff --git a/deebot_client/map.py b/deebot_client/map.py index 3dc9dad0..77a4413f 100644 --- a/deebot_client/map.py +++ b/deebot_client/map.py @@ -10,7 +10,6 @@ from decimal import Decimal from io import BytesIO import itertools -import struct from typing import TYPE_CHECKING, Any, Final import zlib @@ -35,9 +34,7 @@ from .exceptions import MapError from .logging_filter import get_logger from .models import Room -from .rs import ( - decompress_7z_base64_data, -) +from .rs import TracePoint, decompress_7z_base64_data, extract_trace_points from .util import ( OnChangedDict, OnChangedList, @@ -150,13 +147,6 @@ def flatten(self) -> tuple[float, float]: return (self.x, self.y) -@dataclasses.dataclass(frozen=True) -class TracePoint(Point): - """Trace point.""" - - connected: bool - - @dataclasses.dataclass class BackgroundImage: """Background image.""" @@ -394,21 +384,8 @@ async def on_map_subset(event: MapSubsetEvent) -> None: # ---------------------------- METHODS ---------------------------- def _update_trace_points(self, data: str) -> None: - _LOGGER.debug("[_update_trace_points] Begin") - trace_points = decompress_7z_base64_data(data) - - for i in range(0, len(trace_points), 5): - position_x, position_y = struct.unpack("> 7 & 1 == 0 - - self._map_data.trace_values.append( - TracePoint(position_x, position_y, connected) - ) - - _LOGGER.debug("[_update_trace_points] finish") + _LOGGER.debug("[_update_trace_points]: %s", data) + self._map_data.trace_values.extend(extract_trace_points(data)) def _draw_map_pieces(self, image: Image.Image) -> None: _LOGGER.debug("[_draw_map_pieces] Draw") diff --git a/deebot_client/rs.pyi b/deebot_client/rs.pyi index b43362f0..be87ecb9 100644 --- a/deebot_client/rs.pyi +++ b/deebot_client/rs.pyi @@ -1,2 +1,25 @@ +from typing import Self + def decompress_7z_base64_data(value: str) -> bytes: """Decompress base64 decoded 7z compressed string.""" + +class TracePoint: + """Trace point.""" + + def __new__(cls, x: float, y: float, connected: bool) -> Self: + """Create a new trace point.""" + + @property + def x(self) -> float: + """X coordinate.""" + + @property + def y(self) -> float: + """Y coordinate.""" + + @property + def connected(self) -> float: + """If the point is connected.""" + +def extract_trace_points(value: str) -> list[TracePoint]: + """Extract trace points from 7z compressed data string.""" diff --git a/src/lib.rs b/src/lib.rs index 62c2f617..56e728c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ use std::error::Error; use base64::{engine::general_purpose, Engine as _}; +use byteorder::{LittleEndian, ReadBytesExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use std::io::Cursor; fn _decompress_7z_base64_data(value: String) -> Result, Box> { let mut bytes = general_purpose::STANDARD.decode(value)?; @@ -21,9 +23,64 @@ fn decompress_7z_base64_data(value: String) -> Result, PyErr> { Ok(_decompress_7z_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))?) } +/// Trace point +#[pyclass] +struct TracePoint { + #[pyo3(get)] + x: i16, + + #[pyo3(get)] + y: i16, + + #[pyo3(get)] + connected: bool, +} + +#[pymethods] +impl TracePoint { + #[new] + fn new(x: i16, y: i16, connected: bool) -> Self { + TracePoint { x, y, connected } + } +} + +fn process_trace_points(trace_points: &[u8]) -> Result, Box> { + let mut trace_values = Vec::new(); + for i in (0..trace_points.len()).step_by(5) { + if i + 4 >= trace_points.len() { + break; // Avoid out-of-bounds slice + } + + // Read position_x and position_y + let mut cursor = Cursor::new(&trace_points[i..i + 4]); + let x = cursor.read_i16::()?; + let y = cursor.read_i16::()?; + + // Extract point_data + let point_data = trace_points[i + 4]; + + // Determine connection status + let connected = (point_data >> 7 & 1) == 0; + + // Append the TracePoint to trace_values + trace_values.push(TracePoint { x, y, connected }); + } + Ok(trace_values) +} + +#[pyfunction] +/// Extract trace points from 7z compressed data string. +fn extract_trace_points(value: String) -> Result, PyErr> { + let decompressed_data = decompress_7z_base64_data(value)?; + Ok(process_trace_points(&decompressed_data) + .map_err(|err| PyValueError::new_err(err.to_string()))?) +} + /// Deebot client written in Rust #[pymodule] fn rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(decompress_7z_base64_data, m)?)?; + m.add_function(wrap_pyfunction!(extract_trace_points, m)?)?; + m.add_class::()?; Ok(()) }