Skip to content

Commit

Permalink
Add support for numpy array attributes (#7)
Browse files Browse the repository at this point in the history
Add support for numpy array attributes
  • Loading branch information
mathias-nillion authored May 22, 2024
1 parent 294fd32 commit 75b76dd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
15 changes: 14 additions & 1 deletion nada_algebra/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,23 @@ class NadaArray:
inner: np.ndarray

SUPPORTED_OPERATIONS = {
"base",
"compress",
"copy",
"cumprod",
"cumsum",
"data",
"dtype",
"diagonal",
"fill",
"flags",
"flat",
"flatten",
"item",
"itemset",
"itemsize",
"nbytes",
"ndim",
"diagonal",
"fill",
"flatten",
Expand All @@ -49,6 +62,7 @@ class NadaArray:
"resize",
"shape",
"size",
"strides",
"squeeze",
"sum",
"swapaxes",
Expand Down Expand Up @@ -82,7 +96,6 @@ def __setitem__(self, key, value):
value: The value to set.
"""
if isinstance(value, NadaArray):
# print("NadaArray")
self.inner[key] = value.inner
else:
self.inner[key] = value
Expand Down
9 changes: 9 additions & 0 deletions tests/nada-tests/src/supported_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ def nada_main():

a = na.array([3, 3], parties[0], "A")

assert isinstance(a.data, memoryview)
assert a.dtype == NadaType
assert a.flags["WRITEABLE"]
assert isinstance(na.NadaArray(a.flat), na.NadaArray)
assert a.itemsize == 8
assert a.nbytes == 72
assert a.ndim == 2
assert a.strides == (24, 8)

try:
a.argsort()
raise Exception("Unsopported operation `argsort` occurred")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"vstack",
"reveal",
"matrix_multiplication",
"generate_array",
"supported_operations",
"get_item",
"get_attr",
"set_item",
Expand Down

0 comments on commit 75b76dd

Please sign in to comment.