Skip to content

Commit

Permalink
add support for numpy arrays of atomic types
Browse files Browse the repository at this point in the history
  • Loading branch information
WatcherBox committed Feb 26, 2024
1 parent 2a78a49 commit 35310d1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 4 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ packages = find:
install_requires =
PyYAML>=5.3.1
pyserial>=3.4
numpy>=1.19.4

[options.entry_points]
console_scripts =
Expand Down
21 changes: 19 additions & 2 deletions simple_rpc/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from typing import Any, BinaryIO
from struct import calcsize, pack, unpack

Expand Down Expand Up @@ -43,9 +44,12 @@ def _write_basic(
if basic_type == 's':
stream.write(value + b'\0')
return

elif isinstance(value, np.ndarray):
stream.write(value.tobytes())
return

full_type = (endianness + basic_type).encode('utf-8')

stream.write(pack(full_type, cast(basic_type)(value)))


Expand Down Expand Up @@ -82,9 +86,15 @@ def read(
return [
read(stream, endianness, size_t, item) for _ in range(length)
for item in obj_type]

if isinstance(obj_type, tuple):
return tuple(
read(stream, endianness, size_t, item) for item in obj_type)

if isinstance(obj_type, np.ndarray):
length = _read_basic(stream, endianness, size_t)
return np.frombuffer(
stream.read(length * obj_type.itemsize), obj_type.dtype)
return _read_basic(stream, endianness, obj_type)


Expand All @@ -104,14 +114,21 @@ def write(
:arg obj: Object of type {obj_type}.
"""
if isinstance(obj_type, list):
# print(f" size_t: {size_t}, len:{len(obj) // len(obj_type)}")
_write_basic(stream, endianness, size_t, len(obj) // len(obj_type))
if isinstance(obj_type, list) or isinstance(obj_type, tuple):
if isinstance(obj_type, np.ndarray):
# print(f"writing array: {size_t}, {obj.size}, {obj.dtype}, obj_tpye: {obj_type}")
_write_basic(stream, endianness, size_t, obj.size)
_write_basic(stream, endianness, obj_type, obj)
elif isinstance(obj_type, list) or isinstance(obj_type, tuple):
for item_type, item in zip(obj_type * len(obj), obj):
write(stream, endianness, size_t, item_type, item)
else:
_write_basic(stream, endianness, obj_type, obj)




def until(
condition: callable, f: callable, *args: Any, **kwargs: Any) -> None:
"""Call {f(*args, **kwargs)} until {condition} is true.
Expand Down
36 changes: 35 additions & 1 deletion simple_rpc/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import numpy as np

from typing import Any, BinaryIO

from .io import cast, read_byte_string

dtype_map = {
'b': np.int8,
'B': np.uint8,
'h': np.int16,
'H': np.uint16,
'i': np.int32,
'I': np.uint32,
'l': np.int32,
'L': np.uint32,
'q': np.int64,
'Q': np.uint64,
'f': np.float32,
'd': np.float64,
'?': np.bool_,
'c': np.byte # Note: 'c' in struct is a single byte; for strings, consider np.bytes_ or np.chararray.
}

def _parse_type(type_str: bytes) -> Any:
"""Parse a type definition string.
Expand All @@ -18,7 +36,12 @@ def _construct_type(tokens: tuple):
obj_type.append(_construct_type(tokens))
elif token == b'(':
obj_type.append(tuple(_construct_type(tokens)))
elif token in (b')', b']'):
elif token == b'{':
t = _construct_type(tokens)
assert len(t) == 1, 'only atomic types allowed in np arrays'
dtype = _get_dtype(t[0])
obj_type.append(np.ndarray(dtype=dtype, shape=(1, 1)))
elif token in (b')', b']', b'}'):
break
else:
obj_type.append(token.decode())
Expand All @@ -33,6 +56,15 @@ def _construct_type(tokens: tuple):
return ''
return obj_type[0]

def _get_dtype(type_str: bytes) -> Any:
"""Get the NumPy data type of a type definition string.
:arg type_str: Type definition string.
:returns: NumPy data type.
"""
return dtype_map.get(type_str, np.byte)


def _type_name(obj_type: Any) -> str:
"""Python type name of a C object type.
Expand All @@ -41,6 +73,8 @@ def _type_name(obj_type: Any) -> str:
:returns: Python type name.
"""
if isinstance(obj_type, np.ndarray):
return '{' + ', '.join([_type_name(item) for item in obj_type]) + '}'
if not obj_type:
return ''
if isinstance(obj_type, list):
Expand Down
5 changes: 4 additions & 1 deletion simple_rpc/simple_rpc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

import numpy as np

from functools import wraps
from time import sleep
from types import MethodType
Expand Down Expand Up @@ -184,7 +187,7 @@ def call_method(self: object, name: str, *args: Any) -> Any:
self._write(parameter['fmt'], args[index])

# Read return value (if any).
if method['return']['fmt']:
if method['return']['fmt'] or isinstance(method['return']['fmt'], np.ndarray):
return self._read(method['return']['fmt'])
return None

Expand Down

0 comments on commit 35310d1

Please sign in to comment.