Skip to content

Commit 76e7a13

Browse files
committed
some docstrings
1 parent 4b4e6bb commit 76e7a13

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

src/kyber_py/modules/modules.py

+24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ def __init__(self):
88
self.matrix = MatrixKyber
99

1010
def decode_vector(self, input_bytes, k, d, is_ntt=False):
11+
"""
12+
Decode bytes into a a vector of polynomial elements.
13+
14+
Each element is assumed to be encoded as a polynomial with ``d``-bit
15+
coefficients (hence a polynomial is encoded into ``256 * d`` bits).
16+
17+
A vector of length ``k`` then has ``256 * d * k`` bits.
18+
"""
1119
# Ensure the input bytes are the correct length to create k elements with
1220
# d bits used for each coefficient
1321
if self.ring.n * d * k != len(input_bytes) * 8:
@@ -32,28 +40,44 @@ def __init__(self, parent, matrix_data, transpose=False):
3240
super().__init__(parent, matrix_data, transpose=transpose)
3341

3442
def encode(self, d):
43+
"""
44+
Encode every element of a matrix into bytes and concatenate
45+
"""
3546
output = b""
3647
for row in self._data:
3748
for ele in row:
3849
output += ele.encode(d)
3950
return output
4051

4152
def compress(self, d):
53+
"""
54+
Compress every element of the matrix to have at most ``d`` bits
55+
"""
4256
for row in self._data:
4357
for ele in row:
4458
ele.compress(d)
4559
return self
4660

4761
def decompress(self, d):
62+
"""
63+
Perform (lossy) decompression of the polynomial assuming it has been
64+
compressed to have at most ``d`` bits.
65+
"""
4866
for row in self._data:
4967
for ele in row:
5068
ele.decompress(d)
5169
return self
5270

5371
def to_ntt(self):
72+
"""
73+
Convert every element of the matrix into NTT form
74+
"""
5475
data = [[x.to_ntt() for x in row] for row in self._data]
5576
return self.parent(data, transpose=self._transpose)
5677

5778
def from_ntt(self):
79+
"""
80+
Convert every element of the matrix from NTT form
81+
"""
5882
data = [[x.from_ntt() for x in row] for row in self._data]
5983
return self.parent(data, transpose=self._transpose)

src/kyber_py/modules/modules_generic.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
class Module:
22
def __init__(self, ring):
3+
"""
4+
Initialise a module over the ring ``ring``.
5+
"""
36
self.ring = ring
47
self.matrix = Matrix
58

69
def random_element(self, m, n):
10+
"""
11+
Generate a random element of the module of dimension m x n
12+
13+
:param int m: the number of rows in the matrix
14+
:param int m: the number of columns in tge matrix
15+
:return: an element of the module with dimension `m times n`
16+
"""
717
elements = [
818
[self.ring.random_element() for _ in range(n)] for _ in range(m)
919
]
@@ -47,7 +57,10 @@ def __call__(self, matrix_elements, transpose=False):
4757

4858
def vector(self, elements):
4959
"""
50-
Construct a vector with the given elements
60+
Construct a vector given a list of elements of the module's ring
61+
62+
:param list: a list of elements of the ring
63+
:return: a vector of the module
5164
"""
5265
return self.matrix(self, [elements], transpose=True)
5366

@@ -64,6 +77,9 @@ def dim(self):
6477
"""
6578
Return the dimensions of the matrix with m rows
6679
and n columns
80+
81+
:return: the dimension of the matrix ``(m, n)``
82+
:rtype: tuple(int, int)
6783
"""
6884
if not self._transpose:
6985
return len(self._data), len(self._data[0])
@@ -78,13 +94,13 @@ def _check_dimensions(self):
7894

7995
def transpose(self):
8096
"""
81-
Swap rows and columns of self
97+
Return a matrix with the rows and columns of swapped
8298
"""
8399
return self.parent(self._data, not self._transpose)
84100

85101
def transpose_self(self):
86102
"""
87-
Transpose in place
103+
Swap the rows and columns of the matrix in place
88104
"""
89105
self._transpose = not self._transpose
90106
return
@@ -193,7 +209,7 @@ def __matmul__(self, other):
193209

194210
def dot(self, other):
195211
"""
196-
Inner product
212+
Compute the inner product of two vectors
197213
"""
198214
if not isinstance(other, type(self)):
199215
raise TypeError("Can only perform dot product with other matrices")

src/kyber_py/polynomials/polynomials.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _decompress_ele(self, x, d):
157157
def compress(self, d):
158158
"""
159159
Compress the polynomial by compressing each coefficient
160+
160161
NOTE: This is lossy compression
161162
"""
162163
self.coeffs = [self._compress_ele(c, d) for c in self.coeffs]
@@ -165,6 +166,7 @@ def compress(self, d):
165166
def decompress(self, d):
166167
"""
167168
Decompress the polynomial by decompressing each coefficient
169+
168170
NOTE: This as compression is lossy, we have
169171
x' = decompress(compress(x)), which x' != x, but is
170172
close in magnitude.
@@ -198,6 +200,9 @@ def to_ntt(self):
198200
return self.parent(coeffs, is_ntt=True)
199201

200202
def from_ntt(self):
203+
"""
204+
Not supported, raises a ``TypeError``
205+
"""
201206
raise TypeError(f"Polynomial not in the NTT domain: {type(self) = }")
202207

203208

@@ -207,6 +212,9 @@ def __init__(self, parent, coefficients):
207212
self.coeffs = self._parse_coefficients(coefficients)
208213

209214
def to_ntt(self):
215+
"""
216+
Not supported, raises a ``TypeError``
217+
"""
210218
raise TypeError(
211219
f"Polynomial is already in the NTT domain: {type(self) = }"
212220
)
@@ -249,6 +257,10 @@ def _ntt_base_multiplication(a0, a1, b0, b1, zeta):
249257
return r0, r1
250258

251259
def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
260+
"""
261+
Given the coefficients of two polynomials compute the coefficients of
262+
their product
263+
"""
252264
new_coeffs = []
253265
zetas = self.parent.ntt_zetas
254266
for i in range(64):
@@ -272,7 +284,6 @@ def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
272284
def _ntt_multiplication(self, other):
273285
"""
274286
Number Theoretic Transform multiplication.
275-
Only implemented (currently) for n = 256
276287
"""
277288
new_coeffs = self._ntt_coefficient_multiplication(
278289
self.coeffs, other.coeffs

src/kyber_py/polynomials/polynomials_generic.py

+7
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@ def __init__(self, q, n):
1414
self.element = Polynomial
1515

1616
def gen(self):
17+
"""
18+
Return the generator `x` of the polynomial ring
19+
"""
1720
return self([0, 1])
1821

1922
def random_element(self):
23+
"""
24+
Compute a random element of the polynomial ring with coefficients in the
25+
canonical range: ``[0, q-1]``
26+
"""
2027
coefficients = [random.randint(0, self.q - 1) for _ in range(self.n)]
2128
return self(coefficients)
2229

0 commit comments

Comments
 (0)