Skip to content

Commit

Permalink
Merge pull request #6 from NillionNetwork/feature/alphas-array
Browse files Browse the repository at this point in the history
Feature/alphas array
  • Loading branch information
mathias-nillion authored May 22, 2024
2 parents ac4dc01 + 7d9b4a6 commit 4856e64
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 39 deletions.
88 changes: 65 additions & 23 deletions nada_algebra/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from dataclasses import dataclass
from typing import Callable, Union
from typing import Any, Callable, Union

import numpy as np
from nada_dsl import (
Expand All @@ -31,6 +31,34 @@ class NadaArray:

inner: np.ndarray

SUPPORTED_OPERATIONS = {
"compress",
"copy",
"cumprod",
"cumsum",
"diagonal",
"fill",
"flatten",
"item",
"itemset",
"prod",
"put",
"ravel",
"repeat",
"reshape",
"resize",
"shape",
"size",
"squeeze",
"sum",
"swapaxes",
"T",
"take",
"tolist",
"trace",
"transpose",
}

def __getitem__(self, item):
"""
Get an item from the array.
Expand All @@ -54,23 +82,10 @@ def __setitem__(self, key, value):
value: The value to set.
"""
if isinstance(value, NadaArray):
# print("NadaArray")
self.inner[key] = value.inner
else:
self.inner[key] = value

def __getattr__(self, name: str):
"""
Get an attribute from the array.
Args:
name (str): The attribute name.
Returns:
NadaArray: A new NadaArray representing the retrieved attribute.
"""
return getattr(self.inner, name)

def __add__(
self,
other: Union[
Expand Down Expand Up @@ -220,15 +235,6 @@ def dot(self, other: "NadaArray") -> "NadaArray":
"""
return NadaArray(self.inner.dot(other.inner))

def sum(self) -> Union[SecretInteger, SecretUnsignedInteger]:
"""
Compute the sum of the elements in the array.
Returns:
Union[SecretInteger, SecretUnsignedInteger]: The sum of the array elements.
"""
return NadaArray(self.inner.sum())

def hstack(self, other: "NadaArray") -> "NadaArray":
"""
Horizontally stack two NadaArray objects.
Expand Down Expand Up @@ -424,3 +430,39 @@ def random(
)
)
)

def __getattr__(self, name: str) -> Any:
"""Routes other attributes to the inner NumPy array.
Args:
name (str): Attribute name.
Raises:
AttributeError: Raised if attribute not supported.
Returns:
Any: Result of attribute.
"""
if name not in self.SUPPORTED_OPERATIONS:
raise AttributeError("NumPy method `%s` is not (currently) supported by NadaArrays." % name)

attr = getattr(self.inner, name)

if callable(attr):
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
if isinstance(result, np.ndarray):
return NadaArray(result)
return result
return wrapper

if isinstance(attr, np.ndarray):
attr = NadaArray(attr)

return attr

def __setattr__(self, name, value):
if name == 'inner':
super().__setattr__(name, value)
else:
setattr(self.inner, name, value)
95 changes: 80 additions & 15 deletions nada_algebra/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
and manipulation of arrays and party objects.
"""

from typing import Any, Iterable
from nada_dsl import (
Party,
SecretInteger,
Expand Down Expand Up @@ -30,20 +31,20 @@ def parties(num: int, prefix: str = "Party") -> list:
return [Party(name=f"{prefix}{i}") for i in range(num)]


def __from_list(lst: list, nada_type: Integer | UnsignedInteger) -> list:
def __from_numpy(arr: np.ndarray, nada_type: Integer | UnsignedInteger) -> list:
"""
Recursively convert a nested list to a list of NadaInteger objects.
Recursively convert a n-dimensional NumPy array to a nested list of NadaInteger objects.
Args:
lst (list): A nested list of integers.
arr (np.ndarray): A NumPy array of integers.
nada_type (type): The type of NadaInteger objects to create.
Returns:
list: A nested list of NadaInteger objects.
"""
if len(lst.shape) == 1:
return [nada_type(int(elem)) for elem in lst]
return [__from_list(lst[i], nada_type) for i in range(len(lst))]
if len(arr.shape) == 1:
return [nada_type(int(elem)) for elem in arr]
return [__from_numpy(arr[i], nada_type) for i in range(arr.shape[0])]


def from_list(lst: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
Expand All @@ -59,15 +60,15 @@ def from_list(lst: list, nada_type: Integer | UnsignedInteger = Integer) -> Nada
"""
if not isinstance(lst, np.ndarray):
lst = np.array(lst)
return NadaArray(np.array(__from_list(lst, nada_type)))
return NadaArray(np.array(__from_numpy(lst, nada_type)))


def ones(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
def ones(dims: Iterable[int], nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
"""
Create a cleartext NadaArray filled with ones.
Args:
dims (list): A list of integers representing the dimensions of the array.
dims (Iterable[int]): A list of integers representing the dimensions of the array.
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
Returns:
Expand All @@ -76,12 +77,28 @@ def ones(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArra
return from_list(np.ones(dims), nada_type)


def zeros(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
def ones_like(a: np.ndarray | NadaArray, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
"""
Create a cleartext NadaArray filled with one with the same shape and type as a given array.
Args:
a (np.ndarray | NadaArray): A reference array.
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
Returns:
NadaArray: The created NadaArray filled with ones.
"""
if isinstance(a, NadaArray):
a = a.inner
return from_list(np.ones_like(a), nada_type)


def zeros(dims: Iterable[int], nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
"""
Create a cleartext NadaArray filled with zeros.
Args:
dims (list): A list of integers representing the dimensions of the array.
dims (Iterable[int]): A list of integers representing the dimensions of the array.
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
Returns:
Expand All @@ -90,8 +107,56 @@ def zeros(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArr
return from_list(np.zeros(dims), nada_type)


def zeros_like(a: np.ndarray | NadaArray, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
"""
Create a cleartext NadaArray filled with zeros with the same shape and type as a given array.
Args:
a (np.ndarray | NadaArray): A reference array.
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
Returns:
NadaArray: The created NadaArray filled with zeros.
"""
if isinstance(a, NadaArray):
a = a.inner
return from_list(np.zeros_like(a), nada_type)


def alphas(dims: Iterable[int], alpha: Any) -> NadaArray:
"""
Create a NadaArray filled with a certain constant value.
Args:
dims (Iterable[int]): A list of integers representing the dimensions of the array.
alpha (Any): Some constant value.
Returns:
NadaArray: NadaArray filled with constant value.
"""
ones_array = np.ones(dims)
return NadaArray(np.frompyfunc(lambda _: alpha, 1, 1)(ones_array))


def alphas_like(a: np.ndarray | NadaArray, alpha: Any) -> NadaArray:
"""
Create a NadaArray filled with a certain constant value with the same shape and type as a given array.
Args:
a (np.ndarray | NadaArray): Reference array.
alpha (Any): Some constant value.
Returns:
NadaArray: NadaArray filled with constant value.
"""
if isinstance(a, NadaArray):
a = a.inner
ones_array = np.ones_like(a)
return NadaArray(np.frompyfunc(lambda _: alpha, 1, 1)(ones_array))


def array(
dims: list,
dims: Iterable[int],
party: Party,
prefix: str,
nada_type: (
Expand All @@ -102,7 +167,7 @@ def array(
Create a NadaArray with the specified dimensions and elements of the given type.
Args:
dims (list): A list of integers representing the dimensions of the array.
dims (Iterable[int]): A list of integers representing the dimensions of the array.
party (Party): The party object.
prefix (str): A prefix for naming the array elements.
nada_type (type, optional): The type of elements to create. Defaults to SecretInteger.
Expand All @@ -114,13 +179,13 @@ def array(


def random(
dims: list, nada_type: SecretInteger | SecretUnsignedInteger = SecretInteger
dims: Iterable[int], nada_type: SecretInteger | SecretUnsignedInteger = SecretInteger
) -> NadaArray:
"""
Create a random NadaArray with the specified dimensions.
Args:
dims (list): A list of integers representing the dimensions of the array.
dims (Iterable[int]): A list of integers representing the dimensions of the array.
nada_type (type, optional): The type of elements to create. Defaults to SecretInteger.
Returns:
Expand Down
7 changes: 7 additions & 0 deletions tests/generate_array/nada-project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name = "generate_array"
version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
prime_size = 128
23 changes: 23 additions & 0 deletions tests/generate_array/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from nada_dsl import *
import nada_algebra as na


def nada_main():
party = Party("party_0")

a = SecretInteger(Input("a", party))

ones1 = na.ones([2, 3])
ones2 = na.ones_like(ones1)

zeros1 = na.zeros([2, 3])
zeros2 = na.zeros_like(zeros1)

alphas1 = na.alphas([2, 3], alpha=a)
alphas2 = na.alphas_like(alphas1, alpha=a)

two_a = alphas1 + alphas2

out = two_a + zeros1 + zeros2 + ones1 + ones2

return out.output(party, "my_output")
5 changes: 5 additions & 0 deletions tests/generate_array/target/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# This directory is kept purposely, so that no compilation errors arise.
# Ignore everything in this directory
*
# Except this file
!.gitignore
20 changes: 20 additions & 0 deletions tests/generate_array/tests/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
program: main
inputs:
secrets:
a:
SecretInteger: "3"
public_variables: {}
expected_outputs:
my_output_0_0:
SecretInteger: "8"
my_output_0_1:
SecretInteger: "8"
my_output_0_2:
SecretInteger: "8"
my_output_1_0:
SecretInteger: "8"
my_output_1_1:
SecretInteger: "8"
my_output_1_2:
SecretInteger: "8"
2 changes: 1 addition & 1 deletion tests/sum/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def nada_main():

result = a.sum()

return result.output(parties[1], "my_output")
return [Output(result, "my_output", parties[1])]
7 changes: 7 additions & 0 deletions tests/supported_operations/nada-project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name = "supported_operations"
version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
prime_size = 128
Loading

0 comments on commit 4856e64

Please sign in to comment.