From e3a5205dc40df78fd3d181dbd27e6150c9547da0 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Wed, 7 Feb 2024 02:02:17 +0100 Subject: [PATCH] mlkem: remove -lea and adjust code --- code/jasmin/mlkem_avx2/Makefile | 2 +- .../jasmin/mlkem_avx2/extraction/jkem_avx2.ec | 4 +-- code/jasmin/mlkem_avx2/kem.jinc | 4 +-- code/jasmin/mlkem_ref/Makefile | 2 +- code/jasmin/mlkem_ref/extraction/jkem.ec | 25 ++++++++++++------- code/jasmin/mlkem_ref/kem.jinc | 4 +-- code/jasmin/mlkem_ref/poly.jinc | 14 +++++------ 7 files changed, 31 insertions(+), 24 deletions(-) diff --git a/code/jasmin/mlkem_avx2/Makefile b/code/jasmin/mlkem_avx2/Makefile index 64e42e58..46ac26af 100644 --- a/code/jasmin/mlkem_avx2/Makefile +++ b/code/jasmin/mlkem_avx2/Makefile @@ -6,7 +6,7 @@ CC ?= /usr/bin/gcc GFLAGS ?= CFLAGS := -Wall -Wextra -g -Ofast -fomit-frame-pointer -JFLAGS := -lea ${JADDFLAGS} +JFLAGS := ${JADDFLAGS} OS := $(shell uname -s) .SECONDARY: jpoly.s jpolyvec.s jfips202.s jindcpa.s jindcpa.o jkem.s diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 84eae86c..e670cafa 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -5089,8 +5089,8 @@ module M(SC:Syscall_t) = { skp); buf <- Array64.init (fun i_0 => if 0 <= i_0 < 0 + 32 then aux.[i_0-0] else buf.[i_0]); - hp <- (skp + (W64.of_int 32)); - hp <- (hp + (W64.of_int (((24 * 3) * 256) `|>>` 3))); + hp <- skp; + hp <- (hp + (W64.of_int (32 + (((24 * 3) * 256) `|>>` 3)))); aux_0 <- (32 %/ 8); i <- 0; while (i < aux_0) { diff --git a/code/jasmin/mlkem_avx2/kem.jinc b/code/jasmin/mlkem_avx2/kem.jinc index fd2b4939..940411f5 100644 --- a/code/jasmin/mlkem_avx2/kem.jinc +++ b/code/jasmin/mlkem_avx2/kem.jinc @@ -106,8 +106,8 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) buf[0:MLKEM_INDCPA_MSGBYTES] = __indcpa_dec_1(buf[0:MLKEM_INDCPA_MSGBYTES], ctp, skp); - hp = skp + 32; - hp += 24 * MLKEM_K * MLKEM_N>>3; + hp = skp; + hp += 32 + (24 * MLKEM_K * MLKEM_N>>3); /* fixme: should loads be 256-bits long? */ for i=0 to MLKEM_SYMBYTES/8 diff --git a/code/jasmin/mlkem_ref/Makefile b/code/jasmin/mlkem_ref/Makefile index 6a3c0454..86469323 100644 --- a/code/jasmin/mlkem_ref/Makefile +++ b/code/jasmin/mlkem_ref/Makefile @@ -4,7 +4,7 @@ CC ?= /usr/bin/gcc CFLAGS := -Wall -Wextra -g -O3 -fomit-frame-pointer -JFLAGS := -lea ${JADDFLAGS} +JFLAGS := ${JADDFLAGS} OS := $(shell uname -s) .SECONDARY: jpoly.s jpolyvec.s jfips203.s jindcpa.s jkem.s diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index 30cbaf5d..a2d16249 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1114,13 +1114,16 @@ module M(SC:Syscall_t) = { zeta_0 <- zetasp.[(W64.to_uint zetasctr)]; zetasctr <- (zetasctr + (W64.of_int 1)); j <- start; - cmp <- (start + len); + cmp <- start; + cmp <- (cmp + len); while ((j \ult cmp)) { - offset <- (j + len); + offset <- j; + offset <- (offset + len); s <- rp.[(W64.to_uint offset)]; t <- rp.[(W64.to_uint j)]; - m <- (s + t); + m <- s; + m <- (m + t); m <@ __barrett_reduce (m); rp.[(W64.to_uint j)] <- m; t <- (t - s); @@ -1128,7 +1131,8 @@ module M(SC:Syscall_t) = { rp.[(W64.to_uint offset)] <- t; j <- (j + (W64.of_int 1)); } - start <- (j + len); + start <- j; + start <- (start + len); } len <- (len `<<` (W8.of_int 1)); } @@ -1169,10 +1173,12 @@ module M(SC:Syscall_t) = { zetasctr <- (zetasctr + (W64.of_int 1)); zeta_0 <- zetasp.[(W64.to_uint zetasctr)]; j <- start; - cmp <- (start + len); + cmp <- start; + cmp <- (cmp + len); while ((j \ult cmp)) { - offset <- (j + len); + offset <- j; + offset <- (offset + len); t <- rp.[(W64.to_uint offset)]; t <@ __fqmul (t, zeta_0); s <- rp.[(W64.to_uint j)]; @@ -1183,7 +1189,8 @@ module M(SC:Syscall_t) = { rp.[(W64.to_uint j)] <- t; j <- (j + (W64.of_int 1)); } - start <- (j + len); + start <- j; + start <- (start + len); } len <- (len `>>` (W8.of_int 1)); } @@ -2308,8 +2315,8 @@ module M(SC:Syscall_t) = { skp); buf <- Array64.init (fun i_0 => if 0 <= i_0 < 0 + 32 then aux.[i_0-0] else buf.[i_0]); - hp <- (skp + (W64.of_int 32)); - hp <- (hp + (W64.of_int (((24 * 3) * 256) `|>>` 3))); + hp <- skp; + hp <- (hp + (W64.of_int (32 + (((24 * 3) * 256) `|>>` 3)))); aux_0 <- (32 %/ 8); i <- 0; while (i < aux_0) { diff --git a/code/jasmin/mlkem_ref/kem.jinc b/code/jasmin/mlkem_ref/kem.jinc index 8fa79c3c..4795a352 100644 --- a/code/jasmin/mlkem_ref/kem.jinc +++ b/code/jasmin/mlkem_ref/kem.jinc @@ -107,8 +107,8 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) buf[0:MLKEM_MSGBYTES] = __indcpa_dec(buf[0:MLKEM_MSGBYTES], ctp, skp); - hp = skp + 32; - hp += 24 * MLKEM_K * MLKEM_N>>3; + hp = skp; + hp += 32 + (24 * MLKEM_K * MLKEM_N>>3); for i=0 to MLKEM_SYMBYTES/8 { diff --git a/code/jasmin/mlkem_ref/poly.jinc b/code/jasmin/mlkem_ref/poly.jinc index 99c3392e..0b69a270 100644 --- a/code/jasmin/mlkem_ref/poly.jinc +++ b/code/jasmin/mlkem_ref/poly.jinc @@ -505,13 +505,13 @@ fn _poly_invntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] zetasctr += 1; j = start; - cmp = start + len; + cmp = start; cmp += len; while (j < cmp) { - offset = j + len; + offset = j; offset += len; s = rp[(int)offset]; t = rp[(int)j]; - m = s + t; + m = s; m += t; m = __barrett_reduce(m); rp[(int)j] = m; t -= s; @@ -519,7 +519,7 @@ fn _poly_invntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] rp[(int)offset] = t; j += 1; } - start = j + len; + start = j; start += len; } len <<= 1; } @@ -563,10 +563,10 @@ fn _poly_ntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] zetasctr += 1; zeta = zetasp[(int)zetasctr]; j = start; - cmp = start + len; + cmp = start; cmp += len; while (j < cmp) { - offset = j + len; + offset = j; offset += len; t = rp[(int)offset]; t = __fqmul(t, zeta); s = rp[(int)j]; @@ -577,7 +577,7 @@ fn _poly_ntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] rp[(int)j] = t; j += 1; } - start = j + len; + start = j; start += len; } len >>= 1; }