Skip to content

Commit

Permalink
Remove dead tail code from (non-SHA3) AES-GCM AArch64 kernel (#1639)
Browse files Browse the repository at this point in the history
On AArch64 systems without support for EOR3, assembly kernels
`aes_gcm_enc_kernel` and `aes_gcm_dec_kernel` from `aesv8-gcm-armv8.pl`
are used for the bulk of AES-GCM processing. These kernels have
dedicated tail code for handling inputs whose size is not a multiple of
the block size (16 bytes).

However, the unique call-sites for `aes_gcm_enc_kernel` and
`aes_gcm_dec_kernel` in `gcm.c` only invoke them with data of size a
multiple of 16 bytes: see the masking
[here](https://github.com/aws/aws-lc/blob/98735a2f6723ba984a18b2f79e05173a61e0f869/crypto/fipsmodule/modes/gcm.c#L154)
and
[here](https://github.com/aws/aws-lc/blob/98735a2f6723ba984a18b2f79e05173a61e0f869/crypto/fipsmodule/modes/gcm.c#L191).
This renders the tail code in `aesv8-gcm-armv8.pl` dead.

Simply removing the truncation to 16-byte aligned data in `gcm.c` --
that is, attempting to let `aes_gcm_{dec,enc}_kernel` process the entire
data -- causes tests to fail. It is not clear to me why that is, and
in particular the tail code could be faulty. OpenSSL seems to behave
similarly and call the AArch64 AES-GCM kernels for block-sized data
only.

This PR removes the dead tail code from the non-SHA3 AES-GCM kernels
`aes_gcm_enc_kernel` and `aes_gcm_dec_kernel`. In the first commit, the
code is annotated to explain the effect of the tail code in case of
block-aligned data. In the second commit, the tail code is removed.

It seems that a similar change can be made for the AES-GCM kernels
leveraging SHA3 instructions, but this is not attempted here.
  • Loading branch information
hanno-becker authored Jul 8, 2024
1 parent 240ad03 commit 00fcba4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 204 deletions.
67 changes: 16 additions & 51 deletions crypto/fipsmodule/modes/asm/aesv8-gcm-armv8.pl
Original file line number Diff line number Diff line change
Expand Up @@ -777,22 +777,22 @@
fmov $ctr_t0d, $input_l0 // AES block 4k+4 - mov low
fmov $ctr_t0.d[1], $input_h0 // AES block 4k+4 - mov high
eor $res1b, $ctr_t0b, $ctr0b // AES block 4k+4 - result
b.gt .Lenc_blocks_more_than_3
b.gt .Lenc_blocks_4_remaining
cmp $main_end_input_ptr, #32
mov $ctr3b, $ctr2b
movi $acc_l.8b, #0
movi $acc_h.8b, #0
sub $rctr32w, $rctr32w, #1
mov $ctr2b, $ctr1b
movi $acc_m.8b, #0
b.gt .Lenc_blocks_more_than_2
b.gt .Lenc_blocks_3_remaining
mov $ctr3b, $ctr1b
sub $rctr32w, $rctr32w, #1
cmp $main_end_input_ptr, #16
b.gt .Lenc_blocks_more_than_1
b.gt .Lenc_blocks_2_remaining
sub $rctr32w, $rctr32w, #1
b .Lenc_blocks_less_than_1
.Lenc_blocks_more_than_3: // blocks left > 3
b .Lenc_blocks_1_remaining
.Lenc_blocks_4_remaining: // blocks left = 4
st1 { $res1b}, [$output_ptr], #16 // AES final-3 block - store result
ldp $input_l0, $input_h0, [$input_ptr], #16 // AES final-2 block - load input low & high
rev64 $res0b, $res1b // GHASH final-3 block
Expand All @@ -809,7 +809,7 @@
pmull2 $acc_h.1q, $res0.2d, $h4.2d // GHASH final-3 block - high
pmull $acc_m.1q, $rk4v.1d, $acc_m.1d // GHASH final-3 block - mid
eor $res1b, $res1b, $ctr1b // AES final-2 block - result
.Lenc_blocks_more_than_2: // blocks left > 2
.Lenc_blocks_3_remaining: // blocks left = 3
st1 { $res1b}, [$output_ptr], #16 // AES final-2 block - store result
ldp $input_l0, $input_h0, [$input_ptr], #16 // AES final-1 block - load input low & high
rev64 $res0b, $res1b // GHASH final-2 block
Expand All @@ -828,7 +828,7 @@
pmull $rk4v.1q, $rk4v.1d, $h34k.1d // GHASH final-2 block - mid
eor $acc_lb, $acc_lb, $rk3 // GHASH final-2 block - low
eor $acc_mb, $acc_mb, $rk4v.16b // GHASH final-2 block - mid
.Lenc_blocks_more_than_1: // blocks left > 1
.Lenc_blocks_2_remaining: // blocks left = 2
st1 { $res1b}, [$output_ptr], #16 // AES final-1 block - store result
rev64 $res0b, $res1b // GHASH final-1 block
ldp $input_l0, $input_h0, [$input_ptr], #16 // AES final block - load input low & high
Expand All @@ -848,24 +848,9 @@
eor $res1b, $res1b, $ctr3b // AES final block - result
eor $acc_mb, $acc_mb, $rk4v.16b // GHASH final-1 block - mid
eor $acc_lb, $acc_lb, $rk3 // GHASH final-1 block - low
.Lenc_blocks_less_than_1: // blocks left <= 1
and $bit_length, $bit_length, #127 // bit_length %= 128
mvn $rkN_l, xzr // rkN_l = 0xffffffffffffffff
sub $bit_length, $bit_length, #128 // bit_length -= 128
neg $bit_length, $bit_length // bit_length = 128 - #bits in input (in range [1,128])
ld1 { $rk0}, [$output_ptr] // load existing bytes where the possibly partial last block is to be stored
mvn $rkN_h, xzr // rkN_h = 0xffffffffffffffff
and $bit_length, $bit_length, #127 // bit_length %= 128
lsr $rkN_h, $rkN_h, $bit_length // rkN_h is mask for top 64b of last block
cmp $bit_length, #64
csel $input_l0, $rkN_l, $rkN_h, lt
csel $input_h0, $rkN_h, xzr, lt
fmov $ctr0d, $input_l0 // ctr0b is mask for last block
fmov $ctr0.d[1], $input_h0
and $res1b, $res1b, $ctr0b // possibly partial last block has zeroes in highest bits
.Lenc_blocks_1_remaining: // blocks_left = 1
rev64 $res0b, $res1b // GHASH final block
eor $res0b, $res0b, $t0.16b // feed in partial tag
bif $res1b, $rk0, $ctr0b // insert existing bytes in top end of result before storing
pmull2 $rk2q1, $res0.2d, $h1.2d // GHASH final block - high
mov $t0d, $res0.d[1] // GHASH final block - mid
rev $ctr32w, $rctr32w
Expand Down Expand Up @@ -1405,22 +1390,22 @@
cmp $main_end_input_ptr, #48
eor $output_l0, $output_l0, $rkN_l // AES block 4k+4 - round N low
eor $output_h0, $output_h0, $rkN_h // AES block 4k+4 - round N high
b.gt .Ldec_blocks_more_than_3
b.gt .Ldec_blocks_4_remaining
sub $rctr32w, $rctr32w, #1
mov $ctr3b, $ctr2b
movi $acc_m.8b, #0
movi $acc_l.8b, #0
cmp $main_end_input_ptr, #32
movi $acc_h.8b, #0
mov $ctr2b, $ctr1b
b.gt .Ldec_blocks_more_than_2
b.gt .Ldec_blocks_3_remaining
sub $rctr32w, $rctr32w, #1
mov $ctr3b, $ctr1b
cmp $main_end_input_ptr, #16
b.gt .Ldec_blocks_more_than_1
b.gt .Ldec_blocks_2_remaining
sub $rctr32w, $rctr32w, #1
b .Ldec_blocks_less_than_1
.Ldec_blocks_more_than_3: // blocks left > 3
b .Ldec_blocks_1_remaining
.Ldec_blocks_4_remaining: // blocks left = 4
rev64 $res0b, $res1b // GHASH final-3 block
ld1 { $res1b}, [$input_ptr], #16 // AES final-2 block - load ciphertext
stp $output_l0, $output_h0, [$output_ptr], #16 // AES final-3 block - store result
Expand All @@ -1437,7 +1422,7 @@
eor $output_l0, $output_l0, $rkN_l // AES final-2 block - round N low
pmull $acc_l.1q, $res0.1d, $h4.1d // GHASH final-3 block - low
eor $output_h0, $output_h0, $rkN_h // AES final-2 block - round N high
.Ldec_blocks_more_than_2: // blocks left > 2
.Ldec_blocks_3_remaining: // blocks left = 3
rev64 $res0b, $res1b // GHASH final-2 block
ld1 { $res1b}, [$input_ptr], #16 // AES final-1 block - load ciphertext
eor $res0b, $res0b, $t0.16b // feed in partial tag
Expand All @@ -1456,7 +1441,7 @@
eor $output_l0, $output_l0, $rkN_l // AES final-1 block - round N low
eor $acc_mb, $acc_mb, $rk4v.16b // GHASH final-2 block - mid
eor $output_h0, $output_h0, $rkN_h // AES final-1 block - round N high
.Ldec_blocks_more_than_1: // blocks left > 1
.Ldec_blocks_2_remaining: // blocks left = 2
stp $output_l0, $output_h0, [$output_ptr], #16 // AES final-1 block - store result
rev64 $res0b, $res1b // GHASH final-1 block
ld1 { $res1b}, [$input_ptr], #16 // AES final block - load ciphertext
Expand All @@ -1476,28 +1461,8 @@
eor $acc_hb, $acc_hb, $rk2 // GHASH final-1 block - high
eor $acc_mb, $acc_mb, $rk4v.16b // GHASH final-1 block - mid
eor $output_h0, $output_h0, $rkN_h // AES final block - round N high
.Ldec_blocks_less_than_1: // blocks left <= 1
and $bit_length, $bit_length, #127 // bit_length %= 128
mvn $rkN_h, xzr // rkN_h = 0xffffffffffffffff
sub $bit_length, $bit_length, #128 // bit_length -= 128
mvn $rkN_l, xzr // rkN_l = 0xffffffffffffffff
ldp $end_input_ptr, $main_end_input_ptr, [$output_ptr] // load existing bytes we need to not overwrite
neg $bit_length, $bit_length // bit_length = 128 - #bits in input (in range [1,128])
and $bit_length, $bit_length, #127 // bit_length %= 128
lsr $rkN_h, $rkN_h, $bit_length // rkN_h is mask for top 64b of last block
cmp $bit_length, #64
csel $ctr32x, $rkN_l, $rkN_h, lt
csel $ctr96_b64x, $rkN_h, xzr, lt
fmov $ctr0d, $ctr32x // ctr0b is mask for last block
and $output_l0, $output_l0, $ctr32x
mov $ctr0.d[1], $ctr96_b64x
bic $end_input_ptr, $end_input_ptr, $ctr32x // mask out low existing bytes
.Ldec_blocks_1_remaining: // blocks_left = 1
rev $ctr32w, $rctr32w
bic $main_end_input_ptr, $main_end_input_ptr, $ctr96_b64x // mask out high existing bytes
orr $output_l0, $output_l0, $end_input_ptr
and $output_h0, $output_h0, $ctr96_b64x
orr $output_h0, $output_h0, $main_end_input_ptr
and $res1b, $res1b, $ctr0b // possibly partial last block has zeroes in highest bits
rev64 $res0b, $res1b // GHASH final block
eor $res0b, $res0b, $t0.16b // feed in partial tag
pmull $rk3q1, $res0.1d, $h1.1d // GHASH final block - low
Expand Down
67 changes: 16 additions & 51 deletions generated-src/ios-aarch64/crypto/fipsmodule/aesv8-gcm-armv8.S
Original file line number Diff line number Diff line change
Expand Up @@ -657,22 +657,22 @@ Lenc_tail: // TAIL
fmov d4, x6 // AES block 4k+4 - mov low
fmov v4.d[1], x7 // AES block 4k+4 - mov high
eor v5.16b, v4.16b, v0.16b // AES block 4k+4 - result
b.gt Lenc_blocks_more_than_3
b.gt Lenc_blocks_4_remaining
cmp x5, #32
mov v3.16b, v2.16b
movi v11.8b, #0
movi v9.8b, #0
sub w12, w12, #1
mov v2.16b, v1.16b
movi v10.8b, #0
b.gt Lenc_blocks_more_than_2
b.gt Lenc_blocks_3_remaining
mov v3.16b, v1.16b
sub w12, w12, #1
cmp x5, #16
b.gt Lenc_blocks_more_than_1
b.gt Lenc_blocks_2_remaining
sub w12, w12, #1
b Lenc_blocks_less_than_1
Lenc_blocks_more_than_3: // blocks left > 3
b Lenc_blocks_1_remaining
Lenc_blocks_4_remaining: // blocks left = 4
st1 { v5.16b}, [x2], #16 // AES final-3 block - store result
ldp x6, x7, [x0], #16 // AES final-2 block - load input low & high
rev64 v4.16b, v5.16b // GHASH final-3 block
Expand All @@ -689,7 +689,7 @@ Lenc_blocks_more_than_3: // blocks left > 3
pmull2 v9.1q, v4.2d, v15.2d // GHASH final-3 block - high
pmull v10.1q, v22.1d, v10.1d // GHASH final-3 block - mid
eor v5.16b, v5.16b, v1.16b // AES final-2 block - result
Lenc_blocks_more_than_2: // blocks left > 2
Lenc_blocks_3_remaining: // blocks left = 3
st1 { v5.16b}, [x2], #16 // AES final-2 block - store result
ldp x6, x7, [x0], #16 // AES final-1 block - load input low & high
rev64 v4.16b, v5.16b // GHASH final-2 block
Expand All @@ -708,7 +708,7 @@ Lenc_blocks_more_than_2: // blocks left > 2
pmull v22.1q, v22.1d, v17.1d // GHASH final-2 block - mid
eor v11.16b, v11.16b, v21.16b // GHASH final-2 block - low
eor v10.16b, v10.16b, v22.16b // GHASH final-2 block - mid
Lenc_blocks_more_than_1: // blocks left > 1
Lenc_blocks_2_remaining: // blocks left = 2
st1 { v5.16b}, [x2], #16 // AES final-1 block - store result
rev64 v4.16b, v5.16b // GHASH final-1 block
ldp x6, x7, [x0], #16 // AES final block - load input low & high
Expand All @@ -728,24 +728,9 @@ Lenc_blocks_more_than_1: // blocks left > 1
eor v5.16b, v5.16b, v3.16b // AES final block - result
eor v10.16b, v10.16b, v22.16b // GHASH final-1 block - mid
eor v11.16b, v11.16b, v21.16b // GHASH final-1 block - low
Lenc_blocks_less_than_1: // blocks left <= 1
and x1, x1, #127 // bit_length %= 128
mvn x13, xzr // rkN_l = 0xffffffffffffffff
sub x1, x1, #128 // bit_length -= 128
neg x1, x1 // bit_length = 128 - #bits in input (in range [1,128])
ld1 { v18.16b}, [x2] // load existing bytes where the possibly partial last block is to be stored
mvn x14, xzr // rkN_h = 0xffffffffffffffff
and x1, x1, #127 // bit_length %= 128
lsr x14, x14, x1 // rkN_h is mask for top 64b of last block
cmp x1, #64
csel x6, x13, x14, lt
csel x7, x14, xzr, lt
fmov d0, x6 // ctr0b is mask for last block
fmov v0.d[1], x7
and v5.16b, v5.16b, v0.16b // possibly partial last block has zeroes in highest bits
Lenc_blocks_1_remaining: // blocks_left = 1
rev64 v4.16b, v5.16b // GHASH final block
eor v4.16b, v4.16b, v8.16b // feed in partial tag
bif v5.16b, v18.16b, v0.16b // insert existing bytes in top end of result before storing
pmull2 v20.1q, v4.2d, v12.2d // GHASH final block - high
mov d8, v4.d[1] // GHASH final block - mid
rev w9, w12
Expand Down Expand Up @@ -1426,22 +1411,22 @@ Ldec_tail: // TAIL
cmp x5, #48
eor x6, x6, x13 // AES block 4k+4 - round N low
eor x7, x7, x14 // AES block 4k+4 - round N high
b.gt Ldec_blocks_more_than_3
b.gt Ldec_blocks_4_remaining
sub w12, w12, #1
mov v3.16b, v2.16b
movi v10.8b, #0
movi v11.8b, #0
cmp x5, #32
movi v9.8b, #0
mov v2.16b, v1.16b
b.gt Ldec_blocks_more_than_2
b.gt Ldec_blocks_3_remaining
sub w12, w12, #1
mov v3.16b, v1.16b
cmp x5, #16
b.gt Ldec_blocks_more_than_1
b.gt Ldec_blocks_2_remaining
sub w12, w12, #1
b Ldec_blocks_less_than_1
Ldec_blocks_more_than_3: // blocks left > 3
b Ldec_blocks_1_remaining
Ldec_blocks_4_remaining: // blocks left = 4
rev64 v4.16b, v5.16b // GHASH final-3 block
ld1 { v5.16b}, [x0], #16 // AES final-2 block - load ciphertext
stp x6, x7, [x2], #16 // AES final-3 block - store result
Expand All @@ -1458,7 +1443,7 @@ Ldec_blocks_more_than_3: // blocks left > 3
eor x6, x6, x13 // AES final-2 block - round N low
pmull v11.1q, v4.1d, v15.1d // GHASH final-3 block - low
eor x7, x7, x14 // AES final-2 block - round N high
Ldec_blocks_more_than_2: // blocks left > 2
Ldec_blocks_3_remaining: // blocks left = 3
rev64 v4.16b, v5.16b // GHASH final-2 block
ld1 { v5.16b}, [x0], #16 // AES final-1 block - load ciphertext
eor v4.16b, v4.16b, v8.16b // feed in partial tag
Expand All @@ -1477,7 +1462,7 @@ Ldec_blocks_more_than_2: // blocks left > 2
eor x6, x6, x13 // AES final-1 block - round N low
eor v10.16b, v10.16b, v22.16b // GHASH final-2 block - mid
eor x7, x7, x14 // AES final-1 block - round N high
Ldec_blocks_more_than_1: // blocks left > 1
Ldec_blocks_2_remaining: // blocks left = 2
stp x6, x7, [x2], #16 // AES final-1 block - store result
rev64 v4.16b, v5.16b // GHASH final-1 block
ld1 { v5.16b}, [x0], #16 // AES final block - load ciphertext
Expand All @@ -1497,28 +1482,8 @@ Ldec_blocks_more_than_1: // blocks left > 1
eor v9.16b, v9.16b, v20.16b // GHASH final-1 block - high
eor v10.16b, v10.16b, v22.16b // GHASH final-1 block - mid
eor x7, x7, x14 // AES final block - round N high
Ldec_blocks_less_than_1: // blocks left <= 1
and x1, x1, #127 // bit_length %= 128
mvn x14, xzr // rkN_h = 0xffffffffffffffff
sub x1, x1, #128 // bit_length -= 128
mvn x13, xzr // rkN_l = 0xffffffffffffffff
ldp x4, x5, [x2] // load existing bytes we need to not overwrite
neg x1, x1 // bit_length = 128 - #bits in input (in range [1,128])
and x1, x1, #127 // bit_length %= 128
lsr x14, x14, x1 // rkN_h is mask for top 64b of last block
cmp x1, #64
csel x9, x13, x14, lt
csel x10, x14, xzr, lt
fmov d0, x9 // ctr0b is mask for last block
and x6, x6, x9
mov v0.d[1], x10
bic x4, x4, x9 // mask out low existing bytes
Ldec_blocks_1_remaining: // blocks_left = 1
rev w9, w12
bic x5, x5, x10 // mask out high existing bytes
orr x6, x6, x4
and x7, x7, x10
orr x7, x7, x5
and v5.16b, v5.16b, v0.16b // possibly partial last block has zeroes in highest bits
rev64 v4.16b, v5.16b // GHASH final block
eor v4.16b, v4.16b, v8.16b // feed in partial tag
pmull v21.1q, v4.1d, v12.1d // GHASH final block - low
Expand Down
Loading

0 comments on commit 00fcba4

Please sign in to comment.