Skip to content

Commit

Permalink
keccak: BMI1 optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Dec 25, 2024
1 parent f4ac4bf commit a1dc90b
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 97 deletions.
1 change: 1 addition & 0 deletions benchmarks/bench_h_keccak.nim
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,6 @@ when isMainModule:
let iters = int(target_cycles div (s.int64 * worst_cycles_per_bytes))
benchKeccak256_constantine(msg, $s & "B", iters)
benchSHA3_256_openssl(msg, $s & "B", iters)
echo "----"

main()
210 changes: 123 additions & 87 deletions constantine/hashes/h_keccak.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import
constantine/platforms/[abstractions, views],
./keccak/keccak_generic

when UseASM_X86_32:
import ./keccak/keccak_x86_bmi1

# Keccak, the hash function underlying SHA3
# --------------------------------------------------------------------------------
#
Expand Down Expand Up @@ -125,8 +128,110 @@ func init*(ctx: var KeccakContext) {.inline.} =
## Initialize or reinitialize a Keccak context
ctx.reset()

# debug
import constantine/serialization/codecs
template genAbsorb(isaFeatures: untyped) =
func `absorb _ isaFeatures`*(ctx: var KeccakContext, message: openArray[byte]) =
## Absorb a message in the Keccak sponge state
##
## Security note: the tail of your message might be stored
## in an internal buffer.
## if sensitive content is used, ensure that
## `ctx.finish(...)` and `ctx.clear()` are called as soon as possible.
## Additionally ensure that the message(s) passed were stored
## in memory considered secure for your threat model.

var pos = int ctx.absorb_offset # offset in Keccak state
var cur = 0 # offset in message
var bytesLeft = message.len

# We follow the "absorb-permute-squeeze" approach
# originally defined by the Keccak team.
# It is compatible with SHA-3 hash spec.
# See https://eprint.iacr.org/2022/1340.pdf
#
# There are no transition/permutation between squeezing -> absorbing
# And within this `absorb` function
# the state pos == ctx.rate()
# is always followed by a permute and setting `pos = 0`

if (pos mod ctx.rate()) != 0 and pos+bytesLeft >= ctx.rate():
# Previous partial update, fill the state and do one permutation
let free = ctx.rate() - pos
ctx.H.`xorInPartial _ isaFeatures`(pos, message.toOpenArray(0, free-1))
ctx.H.`permute _ isaFeatures`(NumRounds = 24)
pos = 0
cur = free
bytesLeft -= free

if bytesLeft >= ctx.rate():
# Process multiple blocks
let numBlocks = bytesLeft div ctx.rate()
ctx.H.`hashMessageBlocks _ isaFeatures`(message.asUnchecked() +% cur, numBlocks)
cur += numBlocks * ctx.rate()
bytesLeft -= numBlocks * ctx.rate()

if bytesLeft != 0:
# Store the tail in buffer
ctx.H.`xorInPartial _ isaFeatures`(pos, message.toOpenArray(cur, cur+bytesLeft-1))

# Epilogue
ctx.absorb_offset = int32(pos+bytesLeft)
# Signal that the next squeeze transition needs a permute
ctx.squeeze_offset = int32 ctx.rate()

genAbsorb(generic)
when UseASM_X86_32:
genAbsorb(x86_bmi1)

template genSqueeze(isaFeatures: untyped) =
func `squeeze _ isaFeatures`*(ctx: var KeccakContext, digest: var openArray[byte]) =
var pos = ctx.squeeze_offset # offset in Keccak state
var cur = 0 # offset in message
var bytesLeft = digest.len

if pos == ctx.rate():
# Transition from absorbing to squeezing
# This state can only come from `absorb` function
# as within `squeeze`, pos == ctx.rate() is always followed
# by a permute and pos = 0
ctx.H.pad(ctx.absorb_offset, ctx.delimiter, ctx.rate())
ctx.H.`permute _ isaFeatures`(NumRounds = 24)
pos = 0
ctx.absorb_offset = 0

if (pos mod ctx.rate()) != 0 and pos+bytesLeft >= ctx.rate():
# Previous partial squeeze, fill up to rate and do one permutation
let free = ctx.rate() - pos
ctx.H.`copyOutPartial _ isaFeatures`(hByteOffset = pos, digest.toOpenArray(0, free-1))
ctx.H.`permute _ isaFeatures`(NumRounds = 24)
pos = 0
ctx.absorb_offset = 0
cur = free
bytesLeft -= free

if bytesLeft >= ctx.rate():
# Process multiple blocks
let numBlocks = bytesLeft div ctx.rate()
ctx.H.`squeezeDigestBlocks _ isaFeatures`(digest.asUnchecked() +% cur, numBlocks)
ctx.absorb_offset = 0
cur += numBlocks * ctx.rate()
bytesLeft -= numBlocks * ctx.rate()

if bytesLeft != 0:
# Output the tail
ctx.H.`copyOutPartial _ isaFeatures`(hByteOffset = pos, digest.toOpenArray(cur, bytesLeft-1))

# Epilogue
ctx.squeeze_offset = int32 bytesLeft
# We don't signal absorb_offset to permute the state if called next
# as per
# - original keccak spec that uses "absorb-permute-squeeze" protocol
# - https://eprint.iacr.org/2022/1340.pdf
# - https://eprint.iacr.org/2023/522.pdf
# https://hackmd.io/@7dpNYqjKQGeYC7wMlPxHtQ/ByIbpfX9c#2-SAFE-definition

genSqueeze(generic)
when UseASM_X86_32:
genSqueeze(x86_bmi1)

func absorb*(ctx: var KeccakContext, message: openArray[byte]) =
## Absorb a message in the Keccak sponge state
Expand All @@ -137,91 +242,22 @@ func absorb*(ctx: var KeccakContext, message: openArray[byte]) =
## `ctx.finish(...)` and `ctx.clear()` are called as soon as possible.
## Additionally ensure that the message(s) passed were stored
## in memory considered secure for your threat model.

var pos = int ctx.absorb_offset # offset in Keccak state
var cur = 0 # offset in message
var bytesLeft = message.len

# We follow the "absorb-permute-squeeze" approach
# originally defined by the Keccak team.
# It is compatible with SHA-3 hash spec.
# See https://eprint.iacr.org/2022/1340.pdf
#
# There are no transition/permutation between squeezing -> absorbing
# And within this `absorb` function
# the state pos == ctx.rate()
# is always followed by a permute and setting `pos = 0`

if (pos mod ctx.rate()) != 0 and pos+bytesLeft >= ctx.rate():
# Previous partial update, fill the state and do one permutation
let free = ctx.rate() - pos
ctx.H.xorInPartial(pos, message.toOpenArray(0, free-1))
ctx.H.permute_generic(NumRounds = 24)
pos = 0
cur = free
bytesLeft -= free

if bytesLeft >= ctx.rate():
# Process multiple blocks
let numBlocks = bytesLeft div ctx.rate()
ctx.H.hashMessageBlocks_generic(message.asUnchecked() +% cur, numBlocks)
cur += numBlocks * ctx.rate()
bytesLeft -= numBlocks * ctx.rate()

if bytesLeft != 0:
# Store the tail in buffer
ctx.H.xorInPartial(pos, message.toOpenArray(cur, cur+bytesLeft-1))

# Epilogue
ctx.absorb_offset = int32(pos+bytesLeft)
# Signal that the next squeeze transition needs a permute
ctx.squeeze_offset = int32 ctx.rate()

func squeeze*(ctx: var KeccakContext, digest: var openArray[byte]) =
var pos = ctx.squeeze_offset # offset in Keccak state
var cur = 0 # offset in message
var bytesLeft = digest.len

if pos == ctx.rate():
# Transition from absorbing to squeezing
# This state can only come from `absorb` function
# as within `squeeze`, pos == ctx.rate() is always followed
# by a permute and pos = 0
ctx.H.pad(ctx.absorb_offset, ctx.delimiter, ctx.rate())
ctx.H.permute_generic(NumRounds = 24)
pos = 0
ctx.absorb_offset = 0

if (pos mod ctx.rate()) != 0 and pos+bytesLeft >= ctx.rate():
# Previous partial squeeze, fill up to rate and do one permutation
let free = ctx.rate() - pos
ctx.H.copyOutPartial(hByteOffset = pos, digest.toOpenArray(0, free-1))
ctx.H.permute_generic(NumRounds = 24)
pos = 0
ctx.absorb_offset = 0
cur = free
bytesLeft -= free

if bytesLeft >= ctx.rate():
# Process multiple blocks
let numBlocks = bytesLeft div ctx.rate()
ctx.H.squeezeDigestBlocks_generic(digest.asUnchecked() +% cur, numBlocks)
ctx.absorb_offset = 0
cur += numBlocks * ctx.rate()
bytesLeft -= numBlocks * ctx.rate()

if bytesLeft != 0:
# Output the tail
ctx.H.copyOutPartial(hByteOffset = pos, digest.toOpenArray(cur, bytesLeft-1))

# Epilogue
ctx.squeeze_offset = int32 bytesLeft
# We don't signal absorb_offset to permute the state if called next
# as per
# - original keccak spec that uses "absorb-permute-squeeze" protocol
# - https://eprint.iacr.org/2022/1340.pdf
# - https://eprint.iacr.org/2023/522.pdf
# https://hackmd.io/@7dpNYqjKQGeYC7wMlPxHtQ/ByIbpfX9c#2-SAFE-definition
when UseASM_X86_32:
if ({.noSideEffect.}: hasBmi1()):
ctx.absorb_x86_bmi1(message)
else:
ctx.absorb_generic(message)
else:
ctx.absorb_generic(message)

func squeeze*(ctx: var KeccakContext, message: var openArray[byte]) =
when UseASM_X86_32:
if ({.noSideEffect.}: hasBmi1()):
ctx.squeeze_x86_bmi1(message)
else:
ctx.squeeze_generic(message)
else:
ctx.squeeze_generic(message)

func update*(ctx: var KeccakContext, message: openArray[byte]) =
## Append a message to a Keccak context
Expand Down
65 changes: 55 additions & 10 deletions constantine/hashes/keccak/keccak_generic.nim
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,12 @@ func genRho(): array[5*5, int] =
func rotl(x: uint64, k: static int): uint64 {.inline.} =
return (x shl k) or (x shr (64 - k))

func permute_generic*(A: var KeccakState, NumRounds: static int) =
func permute_impl*(A: var KeccakState, NumRounds: static int) {.inline.} =
## Implementation of Keccak permutation
## Tagged inline so it's copied in:
## - keccak_generic.nim
## - keccak_x86_bmi1.nim
## and uses CPU features such as SIMD or andnot instructions
# We use algorithm 4 in https://keccak.team/files/Keccak-implementation-3.2.pdf
const Rho = genRho()

Expand Down Expand Up @@ -219,6 +224,9 @@ func permute_generic*(A: var KeccakState, NumRounds: static int) =
# ι step: break symmetries
A[0, 0] = A[0, 0] xor KRC[i+j]

func permute_generic*(A: var KeccakState, NumRounds: static int) =
permuteImpl(A, NumRounds)

template `^=`(accum: var SomeInteger, b: SomeInteger) =
accum = accum xor b

Expand All @@ -235,7 +243,7 @@ func xorInSingle(H: var KeccakState, hByteOffset: int, val: byte) {.inline.} =
let lane = uint64(val) shl slot # All bits but the one set in `val` are 0, and 0 is neutral element of xor
H.state[hByteOffset shr 3] ^= lane

func xorInBlock_generic(H: var KeccakState, msg: array[200 - 2*32, byte]) {.inline.} =
func xorInBlock(H: var KeccakState, msg: array[200 - 2*32, byte]) {.inline.} =
## Add new data into the Keccak state
# This can benefit from vectorized instructions
for i in 0 ..< msg.len div 8:
Expand Down Expand Up @@ -275,7 +283,7 @@ func copyOutPartialWord(
dst[i] = toByte(lane)
lane = lane shr sizeof(T)

func xorInPartial*(H: var KeccakState, hByteOffset: int, msg: openArray[byte]) =
func xorInPartial_impl*(H: var KeccakState, hByteOffset: int, msg: openArray[byte]) {.inline.} =
## Add multiple bytes to the state
## The hByteOffset+length MUST be less than the state length.
debug: doAssert hByteOffset + msg.len <= sizeof(H.state)
Expand Down Expand Up @@ -317,7 +325,12 @@ func xorInPartial*(H: var KeccakState, hByteOffset: int, msg: openArray[byte]) =
# Store the tail in buffer
H.xorInPartialWord(pos, msg.toOpenArray(cur, cur+bytesLeft-1))

func copyOutPartial*(
func xorInPartial_generic*(H: var KeccakState, hByteOffset: int, msg: openArray[byte]) =
## Add multiple bytes to the state
## The hByteOffset+length MUST be less than the state length.
xorInPartial_impl(H, hByteOffset, msg)

func copyOutPartial_impl*(
H: KeccakState,
hByteOffset: int,
dst: var openArray[byte]) {.inline.} =
Expand Down Expand Up @@ -364,15 +377,25 @@ func copyOutPartial*(
# Store the tail in buffer
H.copyOutPartialWord(pos, dst.toOpenArray(cur, cur+bytesLeft-1))

func copyOutPartial_generic*(
H: KeccakState,
hByteOffset: int,
dst: var openArray[byte]) =
## Read data from the Keccak state
## and write it into `dst`
## starting from the state byte offset `hByteOffset`
## hByteOffset + dst length MUST be less than the Keccak rate
copyOutPartial_impl(H, hByteOffset, dst)

func pad*(H: var KeccakState, hByteOffset: int, delim: static byte, rate: static int) {.inline.} =
debug: doAssert hByteOffset < rate
H.xorInSingle(hByteOffset, delim)
H.xorInSingle(hByteOffset = rate-1, 0x80)

func hashMessageBlocks_generic*(
func hashMessageBlocks_impl*(
H: var KeccakState,
message: ptr UncheckedArray[byte],
numBlocks: int) =
numBlocks: int) {.inline.} =
## Hash a message block by block
## Keccak block size is the rate: 64
## The state MUST be absorb ready
Expand All @@ -384,11 +407,22 @@ func hashMessageBlocks_generic*(
const numRounds = 24 # TODO: auto derive number of rounds
for _ in 0 ..< numBlocks:
let msg = cast[ptr array[rate, byte]](message)
H.xorInBlock_generic(msg[])
H.permute_generic(numRounds)
H.xorInBlock(msg[])
H.permute_impl(numRounds)
message +%= rate

func squeezeDigestBlocks_generic*(
func hashMessageBlocks_generic*(
H: var KeccakState,
message: ptr UncheckedArray[byte],
numBlocks: int) =
## Hash a message block by block
## Keccak block size is the rate: 64
## The state MUST be absorb ready
## i.e. previous operation cannot be a squeeze
## a permutation is needed in-between
hashMessageBlocks_impl(H, message, numBlocks)

func squeezeDigestBlocks_impl*(
H: var KeccakState,
digest: ptr UncheckedArray[byte],
numBlocks: int) =
Expand All @@ -404,4 +438,15 @@ func squeezeDigestBlocks_generic*(
let msg = cast[ptr array[rate, byte]](digest)
H.copyOutWords(msg[])
H.permute_generic(numRounds)
digest +%= rate
digest +%= rate

func squeezeDigestBlocks_generic*(
H: var KeccakState,
digest: ptr UncheckedArray[byte],
numBlocks: int) =
## Squeeze a digest block by block
## Keccak block digest is the rate: 64
## The state MUST be squeeze ready
## i.e. previous operation cannot be an absorb
## a permutation is needed in-between
squeezeDigestBlocks_impl(H, digest, numBlocks)
Loading

0 comments on commit a1dc90b

Please sign in to comment.