Skip to content

Commit

Permalink
Move extracting of trace points to rust
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus committed Jan 21, 2025
1 parent 94ecd72 commit 8b27eb1
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 26 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ crate-type = ["cdylib"]

[dependencies]
base64 = "0.22.1"
byteorder = "1.5.0"
pyo3 = "0.23.3"
rust-lzma = "0.6.0"
29 changes: 3 additions & 26 deletions deebot_client/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from decimal import Decimal
from io import BytesIO
import itertools
import struct
from typing import TYPE_CHECKING, Any, Final
import zlib

Expand All @@ -35,9 +34,7 @@
from .exceptions import MapError
from .logging_filter import get_logger
from .models import Room
from .rs import (
decompress_7z_base64_data,
)
from .rs import TracePoint, decompress_7z_base64_data, extract_trace_points
from .util import (
OnChangedDict,
OnChangedList,
Expand Down Expand Up @@ -150,13 +147,6 @@ def flatten(self) -> tuple[float, float]:
return (self.x, self.y)


@dataclasses.dataclass(frozen=True)
class TracePoint(Point):
"""Trace point."""

connected: bool


@dataclasses.dataclass
class BackgroundImage:
"""Background image."""
Expand Down Expand Up @@ -394,21 +384,8 @@ async def on_map_subset(event: MapSubsetEvent) -> None:
# ---------------------------- METHODS ----------------------------

def _update_trace_points(self, data: str) -> None:
_LOGGER.debug("[_update_trace_points] Begin")
trace_points = decompress_7z_base64_data(data)

for i in range(0, len(trace_points), 5):
position_x, position_y = struct.unpack("<hh", trace_points[i : i + 4])

point_data = trace_points[i + 4]

connected = point_data >> 7 & 1 == 0

self._map_data.trace_values.append(
TracePoint(position_x, position_y, connected)
)

_LOGGER.debug("[_update_trace_points] finish")
_LOGGER.debug("[_update_trace_points]: %s", data)
self._map_data.trace_values.extend(extract_trace_points(data))

def _draw_map_pieces(self, image: Image.Image) -> None:
_LOGGER.debug("[_draw_map_pieces] Draw")
Expand Down
23 changes: 23 additions & 0 deletions deebot_client/rs.pyi
Original file line number Diff line number Diff line change
@@ -1,2 +1,25 @@
from typing import Self

def decompress_7z_base64_data(value: str) -> bytes:
"""Decompress base64 decoded 7z compressed string."""

class TracePoint:
"""Trace point."""

def __new__(cls, x: float, y: float, connected: bool) -> Self:
"""Create a new trace point."""

@property
def x(self) -> float:
"""X coordinate."""

@property
def y(self) -> float:
"""Y coordinate."""

@property
def connected(self) -> float:
"""If the point is connected."""

def extract_trace_points(value: str) -> list[TracePoint]:
"""Extract trace points from 7z compressed data string."""
57 changes: 57 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::error::Error;

use base64::{engine::general_purpose, Engine as _};
use byteorder::{LittleEndian, ReadBytesExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::io::Cursor;

fn _decompress_7z_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error>> {
let mut bytes = general_purpose::STANDARD.decode(value)?;
Expand All @@ -21,9 +23,64 @@ fn decompress_7z_base64_data(value: String) -> Result<Vec<u8>, PyErr> {
Ok(_decompress_7z_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))?)
}

/// Trace point
#[pyclass]
struct TracePoint {
#[pyo3(get)]
x: i16,

#[pyo3(get)]
y: i16,

#[pyo3(get)]
connected: bool,
}

#[pymethods]
impl TracePoint {
#[new]
fn new(x: i16, y: i16, connected: bool) -> Self {
TracePoint { x, y, connected }
}
}

fn process_trace_points(trace_points: &[u8]) -> Result<Vec<TracePoint>, Box<dyn Error>> {
let mut trace_values = Vec::new();
for i in (0..trace_points.len()).step_by(5) {
if i + 4 >= trace_points.len() {
break; // Avoid out-of-bounds slice
}

// Read position_x and position_y
let mut cursor = Cursor::new(&trace_points[i..i + 4]);
let x = cursor.read_i16::<LittleEndian>()?;
let y = cursor.read_i16::<LittleEndian>()?;

// Extract point_data
let point_data = trace_points[i + 4];

// Determine connection status
let connected = (point_data >> 7 & 1) == 0;

// Append the TracePoint to trace_values
trace_values.push(TracePoint { x, y, connected });
}
Ok(trace_values)
}

#[pyfunction]
/// Extract trace points from 7z compressed data string.
fn extract_trace_points(value: String) -> Result<Vec<TracePoint>, PyErr> {
let decompressed_data = decompress_7z_base64_data(value)?;
Ok(process_trace_points(&decompressed_data)
.map_err(|err| PyValueError::new_err(err.to_string()))?)
}

/// Deebot client written in Rust
#[pymodule]
fn rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(decompress_7z_base64_data, m)?)?;
m.add_function(wrap_pyfunction!(extract_trace_points, m)?)?;
m.add_class::<TracePoint>()?;
Ok(())
}

0 comments on commit 8b27eb1

Please sign in to comment.