Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NTTs: remove ldr/str macros that are no longer needed #36

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def core(self, slothy):


class ntt_kyber_1234_567(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72, timeout=None):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
name = "ntt_kyber_1234_567"
infile = name

Expand Down Expand Up @@ -724,7 +724,7 @@ def core(self, slothy):


class ntt_kyber_1234(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
name = "ntt_kyber_1234"
infile = "ntt_kyber_1234_567"

Expand All @@ -749,7 +749,7 @@ def core(self, slothy):


class ntt_kyber_567(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72, timeout=None):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
name = "ntt_kyber_567"
infile = "ntt_kyber_1234_567"

Expand Down Expand Up @@ -1136,7 +1136,7 @@ def core(self, slothy):


class ntt_dilithium_1234_5678(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72, timeout=None):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
name = f"ntt_dilithium_1234_5678"
infile = name

Expand Down Expand Up @@ -1226,7 +1226,7 @@ def core(self, slothy):


class ntt_dilithium_1234(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
name = "ntt_dilithium_1234"
infile = "ntt_dilithium_1234_5678"

Expand All @@ -1250,7 +1250,7 @@ def core(self, slothy):


class ntt_dilithium_5678(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA72):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
name = "ntt_dilithium_5678"
infile = "ntt_dilithium_1234_5678"

Expand Down Expand Up @@ -1399,6 +1399,7 @@ def main():
ntt_kyber_1234_567(),
intt_kyber_123_4567(),
intt_kyber_123_4567(var="manual_ld4"),

# Cortex-A72
ntt_kyber_123_4567(target=Target_CortexA72),
ntt_kyber_123_4567(var="scalar_load", target=Target_CortexA72),
Expand All @@ -1408,6 +1409,7 @@ def main():
ntt_kyber_1234_567(target=Target_CortexA72),
intt_kyber_123_4567(target=Target_CortexA72),
intt_kyber_123_4567(var="manual_ld4", target=Target_CortexA72),

# # Apple M1 Firestorm
ntt_kyber_123_4567(target=Target_AppleM1_firestorm, timeout=3600),
ntt_kyber_123_4567(var="scalar_load", target=Target_AppleM1_firestorm, timeout=3600),
Expand Down Expand Up @@ -1453,6 +1455,7 @@ def main():
intt_dilithium_123_45678(var="manual_ld4"),
intt_dilithium_1234_5678(),
intt_dilithium_1234_5678(var="manual_ld4"),

# Cortex-A72
ntt_dilithium_123_45678(target=Target_CortexA72),
ntt_dilithium_123_45678(var="w_scalar", target=Target_CortexA72),
Expand All @@ -1463,6 +1466,7 @@ def main():
intt_dilithium_123_45678(var="manual_ld4", target=Target_CortexA72),
intt_dilithium_1234_5678(target=Target_CortexA72),
intt_dilithium_1234_5678(var="manual_ld4", target=Target_CortexA72),

# Apple M1 Firestorm
ntt_dilithium_123_45678(target=Target_AppleM1_firestorm, timeout=3600),
ntt_dilithium_123_45678(var="w_scalar", target=Target_AppleM1_firestorm, timeout=3600),
Expand Down
122 changes: 52 additions & 70 deletions examples/naive/aarch64/intt_dilithium_1234_5678.s
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,6 @@
// Eventually, NeLight should include a proper parser for AArch64,
// but for initial investigations, the below is enough.

.macro ldr_vo vec, base, offset
ldr qform_\vec, [\base, #\offset]
.endm
.macro ldr_vi vec, base, inc
ldr qform_\vec, [\base], #\inc
.endm
.macro str_vo vec, base, offset
str qform_\vec, [\base, #\offset]
.endm
.macro str_vi vec, base, inc
str qform_\vec, [\base], #\inc
.endm
.macro vsub d,a,b
sub \d\().4s, \a\().4s, \b\().4s
.endm
Expand Down Expand Up @@ -132,31 +120,31 @@
.endm

.macro load_roots_1234 r_ptr
ldr_vi root0, \r_ptr, (8*16)
ldr_vo root1, \r_ptr, (-8*16 + 1*16)
ldr_vo root2, \r_ptr, (-8*16 + 2*16)
ldr_vo root3, \r_ptr, (-8*16 + 3*16)
ldr_vo root4, \r_ptr, (-8*16 + 4*16)
ldr_vo root5, \r_ptr, (-8*16 + 5*16)
ldr_vo root6, \r_ptr, (-8*16 + 6*16)
ldr_vo root7, \r_ptr, (-8*16 + 7*16)
ldr qform_root0, [\r_ptr], #(8*16)
ldr qform_root1, [\r_ptr, #(-8*16 + 1*16)]
ldr qform_root2, [\r_ptr, #(-8*16 + 2*16)]
ldr qform_root3, [\r_ptr, #(-8*16 + 3*16)]
ldr qform_root4, [\r_ptr, #(-8*16 + 4*16)]
ldr qform_root5, [\r_ptr, #(-8*16 + 5*16)]
ldr qform_root6, [\r_ptr, #(-8*16 + 6*16)]
ldr qform_root7, [\r_ptr, #(-8*16 + 7*16)]
.endm

.macro load_next_roots_56 root0, r_ptr0
ldr_vi \root0, \r_ptr0, 16
ldr qform_\root0, [\r_ptr0], #16
.endm

.macro load_next_roots_6 root0, r_ptr0
ldr_vi \root0, \r_ptr0, 8
ldr qform_\root0, [\r_ptr0], #8
.endm

.macro load_next_roots_78 root0, root0_tw, root1, root1_tw, root2, root2_tw, r_ptr1
ldr_vi \root0, \r_ptr1, (6*16)
ldr_vo \root0_tw, \r_ptr1, (-6*16 + 1*16)
ldr_vo \root1, \r_ptr1, (-6*16 + 2*16)
ldr_vo \root1_tw, \r_ptr1, (-6*16 + 3*16)
ldr_vo \root2, \r_ptr1, (-6*16 + 4*16)
ldr_vo \root2_tw, \r_ptr1, (-6*16 + 5*16)
ldr qform_\root0, [\r_ptr1], #(6*16)
ldr qform_\root0_tw, [\r_ptr1, #(-6*16 + 1*16)]
ldr qform_\root1, [\r_ptr1, #(-6*16 + 2*16)]
ldr qform_\root1_tw, [\r_ptr1, #(-6*16 + 3*16)]
ldr qform_\root2, [\r_ptr1, #(-6*16 + 4*16)]
ldr qform_\root2_tw, [\r_ptr1, #(-6*16 + 5*16)]
.endm

.macro transpose4 data
Expand Down Expand Up @@ -334,12 +322,6 @@ _intt_dilithium_1234_5678:

.p2align 2
layer5678_start:
// manual_ld4
// ldr_vo data0, inp, (16*0)
// ldr_vo data1, inp, (16*1)
// ldr_vo data2, inp, (16*2)
// ldr_vo data3, inp, (16*3)
// transpose4 data

ld4 {data0.4S, data1.4S, data2.4S, data3.4S}, [inp]

Expand All @@ -363,10 +345,10 @@ layer5678_start:
barrett_reduce_single data0
barrett_reduce_single data1

str_vi data0, inp, (16*4)
str_vo data1, inp, (-16*4 + 1*16)
str_vo data2, inp, (-16*4 + 2*16)
str_vo data3, inp, (-16*4 + 3*16)
str qform_data0, [inp], #(16*4)
str qform_data1, [inp, #(-16*4 + 1*16)]
str qform_data2, [inp, #(-16*4 + 2*16)]
str qform_data3, [inp, #(-16*4 + 3*16)]
// layer5678_end:
subs count, count, #1
cbnz count, layer5678_start
Expand Down Expand Up @@ -411,22 +393,22 @@ layer5678_start:

.p2align 2
layer1234_start:
ldr_vo data0, in, 0
ldr_vo data1, in, (1*(512/8))
ldr_vo data2, in, (2*(512/8))
ldr_vo data3, in, (3*(512/8))
ldr_vo data4, in, (4*(512/8))
ldr_vo data5, in, (5*(512/8))
ldr_vo data6, in, (6*(512/8))
ldr_vo data7, in, (7*(512/8))
ldr_vo data8, in, (8*(512/8))
ldr_vo data9, in, (9*(512/8))
ldr_vo data10, in, (10*(512/8))
ldr_vo data11, in, (11*(512/8))
ldr_vo data12, in, (12*(512/8))
ldr_vo data13, in, (13*(512/8))
ldr_vo data14, in, (14*(512/8))
ldr_vo data15, in, (15*(512/8))
ldr qform_data0, [in]
ldr qform_data1, [in, #(1*(512/8))]
ldr qform_data2, [in, #(2*(512/8))]
ldr qform_data3, [in, #(3*(512/8))]
ldr qform_data4, [in, #(4*(512/8))]
ldr qform_data5, [in, #(5*(512/8))]
ldr qform_data6, [in, #(6*(512/8))]
ldr qform_data7, [in, #(7*(512/8))]
ldr qform_data8, [in, #(8*(512/8))]
ldr qform_data9, [in, #(9*(512/8))]
ldr qform_data10, [in, #(10*(512/8))]
ldr qform_data11, [in, #(11*(512/8))]
ldr qform_data12, [in, #(12*(512/8))]
ldr qform_data13, [in, #(13*(512/8))]
ldr qform_data14, [in, #(14*(512/8))]
ldr qform_data15, [in, #(15*(512/8))]

// layer4
gs_butterfly data0, data1, root3, 2, 3
Expand Down Expand Up @@ -477,14 +459,14 @@ layer1234_start:
canonical_reduce data14, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data15, modulus_half, neg_modulus_half, t2, t3

str_vo data8, in, (8*(512/8))
str_vo data9, in, (9*(512/8))
str_vo data10, in, (10*(512/8))
str_vo data11, in, (11*(512/8))
str_vo data12, in, (12*(512/8))
str_vo data13, in, (13*(512/8))
str_vo data14, in, (14*(512/8))
str_vo data15, in, (15*(512/8))
str qform_data8, [in, #(8*(512/8))]
str qform_data9, [in, #(9*(512/8))]
str qform_data10, [in, #(10*(512/8))]
str qform_data11, [in, #(11*(512/8))]
str qform_data12, [in, #(12*(512/8))]
str qform_data13, [in, #(13*(512/8))]
str qform_data14, [in, #(14*(512/8))]
str qform_data15, [in, #(15*(512/8))]

// Scale half the coeffs by 1/n; for the other half, the scaling has
// been merged into the multiplication with the twiddle factor on the
Expand All @@ -500,14 +482,14 @@ layer1234_start:
canonical_reduce data6, modulus_half, neg_modulus_half, t2, t3
canonical_reduce data7, modulus_half, neg_modulus_half, t2, t3

str_vi data0, in, (16)
str_vo data1, in, (-16 + 1*(512/8))
str_vo data2, in, (-16 + 2*(512/8))
str_vo data3, in, (-16 + 3*(512/8))
str_vo data4, in, (-16 + 4*(512/8))
str_vo data5, in, (-16 + 5*(512/8))
str_vo data6, in, (-16 + 6*(512/8))
str_vo data7, in, (-16 + 7*(512/8))
str qform_data0, [in], #(16)
str qform_data1, [in, #(-16 + 1*(512/8))]
str qform_data2, [in, #(-16 + 2*(512/8))]
str qform_data3, [in, #(-16 + 3*(512/8))]
str qform_data4, [in, #(-16 + 4*(512/8))]
str qform_data5, [in, #(-16 + 5*(512/8))]
str qform_data6, [in, #(-16 + 6*(512/8))]
str qform_data7, [in, #(-16 + 7*(512/8))]

// layer1234_end:
subs count, count, #1
Expand Down
Loading
Loading