Skip to content

Commit

Permalink
Merge pull request #37 from dop-amin/invntt
Browse files Browse the repository at this point in the history
Add inverse NTTs for Kyber & Dilithium
  • Loading branch information
hanno-becker authored Apr 11, 2024
2 parents 866fa4a + d5c0d02 commit d7b5296
Show file tree
Hide file tree
Showing 56 changed files with 52,016 additions and 189 deletions.
126 changes: 125 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,28 @@ def core(self, slothy):
slothy.optimize_loop("layer123_start")
slothy.optimize_loop("layer4567_start")

class intt_kyber_123_4567(Example):
    """Example driving SLOTHY over the Kyber inverse NTT split into layers 1-3 and 4-7.

    The input file is ``intt_kyber_123_4567``, extended by ``_<var>`` when a
    variant is requested. The example name is the input file name followed by
    the label of the selected target.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
        # An empty variant contributes nothing to either name.
        variant_suffix = f"_{var}" if var != "" else ""
        source_name = "intt_kyber_123_4567" + variant_suffix
        example_name = f"{source_name}_{target_label_dict[target]}"

        super().__init__(
            source_name,
            example_name,
            rename=True,
            arch=arch,
            target=target,
            timeout=timeout,
        )

    def core(self, slothy):
        """Set up the configuration once, then optimize both loops with it."""
        cfg = slothy.config
        cfg.sw_pipelining.enabled = True
        cfg.inputs_are_outputs = True
        cfg.sw_pipelining.minimize_overlapping = False
        cfg.variable_size = True
        # x0..x6, the link register and the stack pointer are kept out of
        # the optimizer's hands.
        cfg.reserved_regs = [f"x{i}" for i in range(7)] + ["x30", "sp"]
        cfg.constraints.stalls_first_attempt = 64

        # Layers 4-7 are optimized before layers 1-3.
        for loop_label in ("layer4567_start", "layer123_start"):
            slothy.optimize_loop(loop_label)


class ntt_kyber_123(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
Expand Down Expand Up @@ -1030,6 +1052,39 @@ def core(self, slothy):
slothy.optimize_loop("layer45678_start")


class intt_dilithium_123_45678(Example):
    """Example driving SLOTHY over the Dilithium inverse NTT split into layers 1-3 and 4-8.

    The input file is ``intt_dilithium_123_45678``, extended by ``_<var>``
    when a variant is requested. The example name additionally carries the
    label of the selected target.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
        # Plain literal: the original used an f-string without placeholders.
        name = "intt_dilithium_123_45678"
        infile = name

        if var != "":
            name += f"_{var}"
            infile += f"_{var}"
        name += f"_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout)

    def core(self, slothy):
        """Optimize the layer 4-8 loop, then the layer 1-3 loop.

        Both passes use software pipelining and the same reserved registers;
        they differ only in the initial stall budget (40 vs. 110).
        """
        slothy.config.sw_pipelining.enabled = True
        slothy.config.sw_pipelining.minimize_overlapping = False
        slothy.config.inputs_are_outputs = True

        # First pass: layers 4-8.
        slothy.config.reserved_regs = [
            f"x{i}" for i in range(0, 7)] + ["v8", "x30", "sp"]
        slothy.config.reserved_regs += self.target_reserved
        slothy.config.constraints.stalls_first_attempt = 40
        slothy.optimize_loop("layer45678_start")

        # Second pass: layers 1-3.
        # NOTE(review): reserved_regs and inputs_are_outputs are re-assigned
        # to the same values as above. This looks redundant unless
        # optimize_loop alters the configuration -- confirm before removing.
        slothy.config.reserved_regs = [
            f"x{i}" for i in range(0, 7)] + ["v8", "x30", "sp"]
        slothy.config.reserved_regs += self.target_reserved
        slothy.config.inputs_are_outputs = True
        slothy.config.constraints.stalls_first_attempt = 110
        slothy.optimize_loop("layer123_start")




class ntt_dilithium_123(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
name = "ntt_dilithium_123"
Expand Down Expand Up @@ -1124,6 +1179,51 @@ def core(self, slothy):
slothy.optimize_loop("layer5678_start")


class intt_dilithium_1234_5678(Example):
    """Example driving SLOTHY over the Dilithium inverse NTT split into layers 1-4 and 5-8.

    The input file is ``intt_dilithium_1234_5678``, extended by ``_<var>``
    when a variant is requested. The example name additionally carries the
    label of the selected target.

    NOTE(review): the default target here is Cortex-A72, whereas
    ``intt_kyber_123_4567`` and ``intt_dilithium_123_45678`` default to
    Cortex-A55 -- confirm this is intended.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72, timeout=None):
        # Plain literal: the original used an f-string without placeholders.
        name = "intt_dilithium_1234_5678"
        infile = name

        if var != "":
            name += f"_{var}"
            infile += f"_{var}"
        name += f"_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout)

    def core(self, slothy):
        """Optimize the layer 5-8 loop, then the layer 1-4 loop.

        The first pass runs with the halving and split heuristics disabled.
        The configuration is then reset to a snapshot taken on entry, and the
        second pass runs with both heuristics enabled.
        """
        # Snapshot the incoming configuration so the second pass can start
        # from it instead of inheriting the first pass's settings.
        conf = slothy.config.copy()

        # First pass: layers 5-8, no heuristics.
        slothy.config.reserved_regs = [
            f"x{i}" for i in range(0, 6)] + ["x30", "sp"]
        slothy.config.inputs_are_outputs = True
        slothy.config.reserved_regs += self.target_reserved
        slothy.config.sw_pipelining.enabled = True
        slothy.config.sw_pipelining.minimize_overlapping = False
        slothy.config.sw_pipelining.halving_heuristic = False
        slothy.config.split_heuristic = False
        slothy.optimize_loop("layer5678_start")

        # Discard the first pass's settings.
        slothy.config = conf.copy()

        # When a timeout was given, the second pass runs with one twelfth
        # of it (integer division).
        if self.timeout is not None:
            slothy.config.timeout = self.timeout // 12

        # Second pass: layers 1-4, halving and split heuristics enabled.
        slothy.config.sw_pipelining.enabled = True
        slothy.config.sw_pipelining.minimize_overlapping = False
        slothy.config.reserved_regs = [
            f"x{i}" for i in range(0, 6)] + ["x30", "sp"]
        slothy.config.reserved_regs += self.target_reserved
        slothy.config.inputs_are_outputs = True
        slothy.config.sw_pipelining.halving_heuristic = True
        slothy.config.split_heuristic = True
        slothy.config.split_heuristic_factor = 2
        slothy.config.split_heuristic_repeat = 4
        slothy.config.split_heuristic_stepsize = 0.1
        slothy.config.constraints.stalls_first_attempt = 14
        slothy.optimize_loop("layer1234_start")


class ntt_dilithium_1234(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72):
name = "ntt_dilithium_1234"
Expand Down Expand Up @@ -1296,13 +1396,17 @@ def main():
ntt_kyber_123_4567(var="scalar_load_store"),
ntt_kyber_123_4567(var="manual_st4"),
ntt_kyber_1234_567(),
intt_kyber_123_4567(),
intt_kyber_123_4567(var="manual_ld4"),
# Cortex-A72
ntt_kyber_123_4567(target=Target_CortexA72),
ntt_kyber_123_4567(var="scalar_load", target=Target_CortexA72),
ntt_kyber_123_4567(var="scalar_store", target=Target_CortexA72),
ntt_kyber_123_4567(var="scalar_load_store", target=Target_CortexA72),
ntt_kyber_123_4567(var="manual_st4", target=Target_CortexA72),
ntt_kyber_1234_567(target=Target_CortexA72),
intt_kyber_123_4567(target=Target_CortexA72),
intt_kyber_123_4567(var="manual_ld4", target=Target_CortexA72),
# # Apple M1 Firestorm
ntt_kyber_123_4567(target=Target_AppleM1_firestorm, timeout=3600),
ntt_kyber_123_4567(var="scalar_load", target=Target_AppleM1_firestorm, timeout=3600),
Expand All @@ -1311,6 +1415,8 @@ def main():
ntt_kyber_123_4567(var="manual_st4", target=Target_AppleM1_firestorm, timeout=3600),
ntt_kyber_1234_567(target=Target_AppleM1_firestorm, timeout=300),
ntt_kyber_1234_567(var="manual_st4", target=Target_AppleM1_firestorm, timeout=300),
intt_kyber_123_4567(target=Target_AppleM1_firestorm, timeout=3600),
intt_kyber_123_4567(var="manual_ld4", target=Target_AppleM1_firestorm, timeout=3600),
# Apple M1 Icestorm
ntt_kyber_123_4567(target=Target_AppleM1_icestorm, timeout=3600),
ntt_kyber_123_4567(var="scalar_load", target=Target_AppleM1_icestorm, timeout=3600),
Expand All @@ -1319,6 +1425,8 @@ def main():
ntt_kyber_123_4567(var="manual_st4", target=Target_AppleM1_icestorm, timeout=3600),
ntt_kyber_1234_567(target=Target_AppleM1_icestorm, timeout=300),
ntt_kyber_1234_567(var="manual_st4", target=Target_AppleM1_icestorm, timeout=300),
intt_kyber_123_4567(target=Target_AppleM1_icestorm, timeout=3600),
intt_kyber_123_4567(var="manual_ld4", target=Target_AppleM1_icestorm, timeout=3600),
# Kyber InvNTT
# Cortex-M55
intt_kyber_1_23_45_67(),
Expand All @@ -1340,24 +1448,40 @@ def main():
ntt_dilithium_123_45678(var="manual_st4"),
ntt_dilithium_1234_5678(),
ntt_dilithium_1234_5678(var="manual_st4"),
intt_dilithium_123_45678(),
intt_dilithium_123_45678(var="manual_ld4"),
intt_dilithium_1234_5678(),
intt_dilithium_1234_5678(var="manual_ld4"),
# Cortex-A72
ntt_dilithium_123_45678(target=Target_CortexA72),
ntt_dilithium_123_45678(var="w_scalar", target=Target_CortexA72),
ntt_dilithium_123_45678(var="manual_st4", target=Target_CortexA72),
ntt_dilithium_1234_5678(target=Target_CortexA72),
ntt_dilithium_1234_5678(var="manual_st4", target=Target_CortexA72),
intt_dilithium_123_45678(target=Target_CortexA72),
intt_dilithium_123_45678(var="manual_ld4", target=Target_CortexA72),
intt_dilithium_1234_5678(target=Target_CortexA72),
intt_dilithium_1234_5678(var="manual_ld4", target=Target_CortexA72),
# Apple M1 Firestorm
ntt_dilithium_123_45678(target=Target_AppleM1_firestorm, timeout=3600),
ntt_dilithium_123_45678(target=Target_AppleM1_firestorm, timeout=3600),
ntt_dilithium_123_45678(var="w_scalar", target=Target_AppleM1_firestorm, timeout=3600),
ntt_dilithium_123_45678(var="manual_st4", target=Target_AppleM1_firestorm, timeout=3600),
ntt_dilithium_1234_5678(target=Target_AppleM1_firestorm, timeout=300),
ntt_dilithium_1234_5678(var="manual_st4", target=Target_AppleM1_firestorm, timeout=300),
intt_dilithium_123_45678(target=Target_AppleM1_firestorm, timeout=3600),
intt_dilithium_123_45678(var="manual_ld4", target=Target_AppleM1_firestorm, timeout=3600),
intt_dilithium_1234_5678(target=Target_AppleM1_firestorm, timeout=3600),
intt_dilithium_1234_5678(var="manual_ld4", target=Target_AppleM1_firestorm, timeout=3600),
# Apple M1 Icestorm
ntt_dilithium_123_45678(target=Target_AppleM1_icestorm, timeout=3600),
ntt_dilithium_123_45678(var="w_scalar", target=Target_AppleM1_icestorm, timeout=3600),
ntt_dilithium_123_45678(var="manual_st4", target=Target_AppleM1_icestorm, timeout=3600),
ntt_dilithium_1234_5678(target=Target_AppleM1_icestorm, timeout=300),
ntt_dilithium_1234_5678(var="manual_st4", target=Target_AppleM1_icestorm, timeout=300),
intt_dilithium_123_45678(target=Target_AppleM1_icestorm, timeout=3600),
intt_dilithium_123_45678(var="manual_ld4", target=Target_AppleM1_icestorm, timeout=3600),
intt_dilithium_1234_5678(target=Target_AppleM1_icestorm, timeout=3600),
intt_dilithium_1234_5678(var="manual_ld4", target=Target_AppleM1_icestorm, timeout=3600),
# Dilithium invNTT
# Cortex-M55
intt_dilithium_12_34_56_78(),
Expand Down
12 changes: 12 additions & 0 deletions examples/misc/gen_roots.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,13 @@ def _main():
ntt_kyber_l123.export("../naive/ntt_kyber_123_45_67_twiddles.s")
ntt_kyber_l123.export("../opt/ntt_kyber_123_45_67_twiddles.s")

# For intt_kyber_123_4567.s
intt_kyber_l123 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,3),(3,2),(5,2)],
pad=[0,3], print_label=True, widen_single_twiddles_to_words=False,
inverse=True)
intt_kyber_l123.export("../naive/aarch64/intt_kyber_123_45_67_twiddles.s")
intt_kyber_l123.export("../opt/aarch64/intt_kyber_123_45_67_twiddles.s")

ntt_kyber = NttRootGen(size=256,modulus=3329,root=17,layers=7)
ntt_kyber.export("../naive/ntt_kyber_1_23_45_67_twiddles.s")
ntt_kyber.export("../opt/ntt_kyber_1_23_45_67_twiddles.s")
Expand All @@ -428,6 +435,11 @@ def _main():
ntt_dilithium_l123.export("../naive/ntt_dilithium_123_456_78_twiddles.s")
ntt_dilithium_l123.export("../opt/ntt_dilithium_123_456_78_twiddles.s")

intt_dilithium_l123 = NttRootGen(size=256,inverse=True,bitsize=32,modulus=8380417,root=1753,layers=8,
print_label=True, pad=[0,3], iters=[(0,3),(3,3),(6,2)])
intt_dilithium_l123.export("../naive/aarch64/intt_dilithium_123_456_78_twiddles.s")
intt_dilithium_l123.export("../opt/aarch64/intt_dilithium_123_456_78_twiddles.s")

ntt_dilithium_l123 = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8,
print_label=True, pad=[0,3], iters=[(0,3),(3,3),(6,2)])
ntt_dilithium_l123.export("../naive/aarch64/ntt_dilithium_123_456_78_twiddles.s")
Expand Down
81 changes: 44 additions & 37 deletions examples/naive/aarch64/intt_dilithium_1234_5678.s
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@
.endm

.macro mulmodq dst, src, const, idx0, idx1
vqrdmulhq t2, \src, \const, \idx1
vmulq \dst, \src, \const, \idx0
vqrdmulhq \src, \src, \const, \idx1
vmls \dst, \src, modulus
vmls \dst, t2, modulus
.endm

.macro mulmod dst, src, const, const_twisted
vmul \dst, \src, \const
vqrdmulh \src, \src, \const_twisted
vmls \dst, \src, modulus
vqrdmulh t2, \src, \const_twisted
mul \dst\().4s, \src\().4s, \const\().4s
vmls \dst, t2, modulus
.endm

.macro montg_reduce a
.macro barrett_reduce_single a
srshr tmp.4S, \a\().4S, #23
vmls \a, tmp, modulus
.endm
Expand All @@ -114,12 +114,6 @@
mulmodq \b, tmp, \root, \idx0, \idx1
.endm

.macro mulmod_v dst, src, const, const_twisted
vmul \dst, \src, \const
vqrdmulh \src, \src, \const_twisted
vmls \dst, \src, modulus
.endm

.macro gs_butterfly_v a, b, root, root_twisted
vsub tmp, \a, \b
vadd \a, \a, \b
Expand Down Expand Up @@ -235,6 +229,12 @@
restore_gprs
.endm

// For comparability reasons, the output range for the coefficients of this
// invNTT code is supposed to match the implementation from PQClean on commit
// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients
// are canonically reduced. The ordering of the coefficients is canonical, also
// matching PQClean.

.data
.p2align 4
roots:
Expand Down Expand Up @@ -334,10 +334,14 @@ _intt_dilithium_1234_5678:

.p2align 2
layer5678_start:
ldr_vo data0, inp, (16*0)
ldr_vo data1, inp, (16*1)
ldr_vo data2, inp, (16*2)
ldr_vo data3, inp, (16*3)
// manual_ld4
// ldr_vo data0, inp, (16*0)
// ldr_vo data1, inp, (16*1)
// ldr_vo data2, inp, (16*2)
// ldr_vo data3, inp, (16*3)
// transpose4 data

ld4 {data0.4S, data1.4S, data2.4S, data3.4S}, [inp]

load_next_roots_78 root0, root0_tw, root1, root1_tw, root2, root2_tw, r_ptr0

Expand All @@ -356,8 +360,8 @@ layer5678_start:
gs_butterfly data0, data2, root1, 0, 1
gs_butterfly data1, data3, root1, 0, 1

montg_reduce data0
montg_reduce data1
barrett_reduce_single data0
barrett_reduce_single data1

str_vi data0, inp, (16*4)
str_vo data1, inp, (-16*4 + 1*16)
Expand Down Expand Up @@ -482,25 +486,28 @@ layer1234_start:
str_vo data14, in, (14*(512/8))
str_vo data15, in, (15*(512/8))

mul_ninv data8, data9, data10, data11, data12, data13, data14, data15, data0, data1, data2, data3, data4, data5, data6, data7

canonical_reduce data8, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data9, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data10, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data11, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data12, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data13, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data14, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data15, modulus_half, neg_modulus_half, t2, t3

str_vi data8, in, (16)
str_vo data9, in, (-16 + 1*(512/8))
str_vo data10, in, (-16 + 2*(512/8))
str_vo data11, in, (-16 + 3*(512/8))
str_vo data12, in, (-16 + 4*(512/8))
str_vo data13, in, (-16 + 5*(512/8))
str_vo data14, in, (-16 + 6*(512/8))
str_vo data15, in, (-16 + 7*(512/8))
// Scale half the coeffs by 1/n; for the other half, the scaling has
// been merged into the multiplication with the twiddle factor on the
// last layer.
mul_ninv data0, data1, data2, data3, data4, data5, data6, data7, data0, data1, data2, data3, data4, data5, data6, data7

canonical_reduce data0, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data1, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data2, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data3, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data4, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data5, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data6, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data7, modulus_half, neg_modulus_half, t2, t3

str_vi data0, in, (16)
str_vo data1, in, (-16 + 1*(512/8))
str_vo data2, in, (-16 + 2*(512/8))
str_vo data3, in, (-16 + 3*(512/8))
str_vo data4, in, (-16 + 4*(512/8))
str_vo data5, in, (-16 + 5*(512/8))
str_vo data6, in, (-16 + 6*(512/8))
str_vo data7, in, (-16 + 7*(512/8))

// layer1234_end:
subs count, count, #1
Expand Down
Loading

0 comments on commit d7b5296

Please sign in to comment.