Skip to content

Commit

Permalink
Simplify hash implementation (#237)
Browse files Browse the repository at this point in the history
* simplify hash implementations

* fix openssl 3

* fix AZL 3
  • Loading branch information
qmuntal authored Dec 20, 2024
1 parent 7b07994 commit 6e4e578
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 651 deletions.
134 changes: 72 additions & 62 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,8 @@ func hashFuncHash(fn func() hash.Hash) (h hash.Hash, err error) {

// hashToMD converts a hash.Hash implementation from this package to a GO_EVP_MD_PTR.
func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
var ch crypto.Hash
switch h.(type) {
case *sha1Hash, *sha1Marshal:
ch = crypto.SHA1
case *sha224Hash, *sha224Marshal:
ch = crypto.SHA224
case *sha256Hash, *sha256Marshal:
ch = crypto.SHA256
case *sha384Hash, *sha384Marshal:
ch = crypto.SHA384
case *sha512Hash, *sha512Marshal:
ch = crypto.SHA512
case *sha3_224Hash:
ch = crypto.SHA3_224
case *sha3_256Hash:
ch = crypto.SHA3_256
case *sha3_384Hash:
ch = crypto.SHA3_384
case *sha3_512Hash:
ch = crypto.SHA3_512
}
if ch != 0 {
return cryptoHashToMD(ch)
if h, ok := h.(*evpHash); ok {
return h.alg.md
}
return nil
}
Expand All @@ -89,78 +68,109 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) {
return md, nil
}

// cryptoHashToMD converts a crypto.Hash to a EVP_MD.
func cryptoHashToMD(ch crypto.Hash) C.GO_EVP_MD_PTR {
type hashAlgorithm struct {
md C.GO_EVP_MD_PTR
ch crypto.Hash
size int
blockSize int
marshallable bool
magic string
marshalledSize int
}

// loadHash converts a crypto.Hash to a EVP_MD.
func loadHash(ch crypto.Hash) *hashAlgorithm {
if v, ok := cacheMD.Load(ch); ok {
return v.(C.GO_EVP_MD_PTR)
return v.(*hashAlgorithm)
}
var md C.GO_EVP_MD_PTR

var hash hashAlgorithm
switch ch {
case crypto.RIPEMD160:
md = C.go_openssl_EVP_ripemd160()
hash.md = C.go_openssl_EVP_ripemd160()
case crypto.MD4:
md = C.go_openssl_EVP_md4()
hash.md = C.go_openssl_EVP_md4()
case crypto.MD5:
md = C.go_openssl_EVP_md5()
hash.md = C.go_openssl_EVP_md5()
hash.magic = md5Magic
hash.marshalledSize = md5MarshaledSize
case crypto.MD5SHA1:
if vMajor == 1 && vMinor == 0 {
md = C.go_openssl_EVP_md5_sha1_backport()
hash.md = C.go_openssl_EVP_md5_sha1_backport()
} else {
md = C.go_openssl_EVP_md5_sha1()
hash.md = C.go_openssl_EVP_md5_sha1()
}
case crypto.SHA1:
md = C.go_openssl_EVP_sha1()
hash.md = C.go_openssl_EVP_sha1()
hash.magic = sha1Magic
hash.marshalledSize = sha1MarshaledSize
case crypto.SHA224:
md = C.go_openssl_EVP_sha224()
hash.md = C.go_openssl_EVP_sha224()
hash.magic = magic224
hash.marshalledSize = marshaledSize256
case crypto.SHA256:
md = C.go_openssl_EVP_sha256()
hash.md = C.go_openssl_EVP_sha256()
hash.magic = magic256
hash.marshalledSize = marshaledSize256
case crypto.SHA384:
md = C.go_openssl_EVP_sha384()
hash.md = C.go_openssl_EVP_sha384()
hash.magic = magic384
hash.marshalledSize = marshaledSize512
case crypto.SHA512:
md = C.go_openssl_EVP_sha512()
hash.md = C.go_openssl_EVP_sha512()
hash.magic = magic512
hash.marshalledSize = marshaledSize512
case crypto.SHA512_224:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha512_224()
hash.md = C.go_openssl_EVP_sha512_224()
hash.magic = magic512_224
hash.marshalledSize = marshaledSize512
}
case crypto.SHA512_256:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha512_256()
hash.md = C.go_openssl_EVP_sha512_256()
hash.magic = magic512_256
hash.marshalledSize = marshaledSize512
}
case crypto.SHA3_224:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_224()
hash.md = C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_256()
hash.md = C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_384()
hash.md = C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_512()
hash.md = C.go_openssl_EVP_sha3_512()
}
}
if md == nil {
cacheMD.Store(ch, nil)
if hash.md == nil {
cacheMD.Store(ch, (*hashAlgorithm)(nil))
return nil
}
hash.ch = ch
hash.size = int(C.go_openssl_EVP_MD_get_size(hash.md))
hash.blockSize = int(C.go_openssl_EVP_MD_get_block_size(hash.md))
if vMajor == 3 {
// On OpenSSL 3, directly operating on a EVP_MD object
// not created by EVP_MD_fetch has negative performance
// implications, as digest operations will have
// to fetch it on every call. Better to just fetch it once here.
md1 := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil)
md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(hash.md), nil)
// Don't overwrite md in case it can't be fetched, as the md may still be used
// outside of EVP_MD_CTX, for example to sign and verify RSA signatures.
if md1 != nil {
md = md1
if md != nil {
hash.md = md
}
}
cacheMD.Store(ch, md)
return md
hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md)
cacheMD.Store(ch, &hash)
return &hash
}

// generateEVPPKey generates a new EVP_PKEY with the given id and properties.
Expand Down Expand Up @@ -302,11 +312,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
}
}
case C.GO_RSA_PKCS1_PSS_PADDING:
md := cryptoHashToMD(ch)
if md == nil {
alg := loadHash(ch)
if alg == nil {
return nil, errors.New("crypto/rsa: unsupported hash function")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
}
// setPadding must happen after setting EVP_PKEY_CTRL_MD.
Expand All @@ -322,11 +332,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
case C.GO_RSA_PKCS1_PADDING:
if ch != 0 {
// We support unhashed messages.
md := cryptoHashToMD(ch)
if md == nil {
alg := loadHash(ch)
if alg == nil {
return nil, errors.New("crypto/rsa: unsupported hash function")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
}
if err := setPadding(); err != nil {
Expand Down Expand Up @@ -441,8 +451,8 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash,
}

func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) {
md := cryptoHashToMD(h)
if md == nil {
alg := loadHash(h)
if alg == nil {
return nil, errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
}
var out []byte
Expand All @@ -453,7 +463,7 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
}
defer C.go_openssl_EVP_MD_CTX_free(ctx)
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
return C.go_openssl_EVP_DigestSignInit(ctx, nil, md, nil, key)
return C.go_openssl_EVP_DigestSignInit(ctx, nil, alg.md, nil, key)
}) != 1 {
return nil, newOpenSSLError("EVP_DigestSignInit failed")
}
Expand All @@ -473,8 +483,8 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
}

func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
md := cryptoHashToMD(h)
if md == nil {
alg := loadHash(h)
if alg == nil {
return errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
}
ctx := C.go_openssl_EVP_MD_CTX_new()
Expand All @@ -483,7 +493,7 @@ func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
}
defer C.go_openssl_EVP_MD_CTX_free(ctx)
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, md, nil, key)
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, alg.md, nil, key)
}) != 1 {
return newOpenSSLError("EVP_DigestVerifyInit failed")
}
Expand Down
Loading

0 comments on commit 6e4e578

Please sign in to comment.