From 3d8d9532779dbde76c4330f01f35e131db0278be Mon Sep 17 00:00:00 2001 From: Mohamed Mokhtar Date: Fri, 1 Apr 2022 17:30:16 +0200 Subject: [PATCH] Added signature/verification for pub/priv --- cry/crypto.go | 143 ++++++++++++++++++++++++++++++- cry/pss.go | 231 ++++++++++++++++++++++++++++++++++++++++++++++++++ rsam.go | 26 +++--- rsam_test.go | 2 +- 4 files changed, 387 insertions(+), 15 deletions(-) create mode 100644 cry/pss.go diff --git a/cry/crypto.go b/cry/crypto.go index a63bd10..61a687e 100644 --- a/cry/crypto.go +++ b/cry/crypto.go @@ -88,6 +88,141 @@ func SignByPublic(rand io.Reader, pub *rsa.PublicKey, hash crypto.Hash, hashed [ return c.FillBytes(em), nil } +func SignPSSByPublic(rand io.Reader, pub *rsa.PublicKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { + if opts != nil && opts.Hash != 0 { + hash = opts.Hash + } + + saltLength := opts.saltLength() + switch saltLength { + case PSSSaltLengthAuto: + saltLength = pub.Size() - 2 - hash.Size() + case PSSSaltLengthEqualsHash: + saltLength = hash.Size() + } + + salt := make([]byte, saltLength) + if rand != nil { + if _, err := io.ReadFull(rand, salt); err != nil { + return nil, err + } + } + return signPSSWithSaltByPublic(rand, pub, hash, digest, salt) +} + +// signPSSWithSaltByPublic calculates the signature of hashed using PSS with specified salt. +// Note that hashed must be the result of hashing the input message using the +// given hash function. salt is a random sequence of bytes whose length will be +// later used to verify the signature. +func signPSSWithSaltByPublic(rand io.Reader, pub *rsa.PublicKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { + emBits := pub.N.BitLen() - 1 + em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) + if err != nil { + return nil, err + } + m := new(big.Int).SetBytes(em) + c, err := decryptAndCheckByPublic(rand, pub, m) + if err != nil { + return nil, err + } + s := make([]byte, pub.Size()) + return c.FillBytes(s), nil +} + +// signPSSWithSaltByPublic calculates the signature of hashed using PSS with specified salt. +// Note that hashed must be the result of hashing the input message using the +// given hash function. salt is a random sequence of bytes whose length will be +// later used to verify the signature. +func signPSSWithSalt(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { + emBits := priv.N.BitLen() - 1 + em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) + if err != nil { + return nil, err + } + m := new(big.Int).SetBytes(em) + c, err := decryptAndCheck(rand, priv, m) + if err != nil { + return nil, err + } + s := make([]byte, priv.Size()) + return c.FillBytes(s), nil +} + +func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { + // See RFC 8017, Section 9.1.1. + + hLen := hash.Size() + sLen := len(salt) + emLen := (emBits + 7) / 8 + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + + if len(mHash) != hLen { + return nil, errors.New("crypto/rsa: input must be hashed with given hash") + } + + // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. + + if emLen < hLen+sLen+2 { + return nil, errors.New("crypto/rsa: key size too small for PSS signature") + } + + em := make([]byte, emLen) + psLen := emLen - sLen - hLen - 2 + db := em[:psLen+1+sLen] + h := em[psLen+1+sLen : emLen-1] + + // 4. Generate a random octet string salt of length sLen; if sLen = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + // + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length hLen. + + var prefix [8]byte + + hash.Write(prefix[:]) + hash.Write(mHash) + hash.Write(salt) + + h = hash.Sum(h[:0]) + hash.Reset() + + // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + + db[psLen] = 0x01 + copy(db[psLen+1:], salt) + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + + mgf1XOR(db, hash, h) + + // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in + // maskedDB to zero. + + db[0] &= 0xff >> (8*emLen - emBits) + + // 12. Let EM = maskedDB || H || 0xbc. + em[emLen-1] = 0xbc + + // 13. Output EM. + return em, nil +} + func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte, err error) { // Special case: crypto.Hash(0) is used to indicate that the data is // signed directly. @@ -160,10 +295,10 @@ func EncryptOAEPM(hash hash.Hash, random io.Reader, priv *rsa.PrivateKey, msg [] db[len(db)-len(msg)-1] = 1 copy(db[len(db)-len(msg):], msg) if random != nil { - _, err := io.ReadFull(random, seed) - if err != nil { - return nil, err - } + _, err := io.ReadFull(random, seed) + if err != nil { + return nil, err + } } mgf1XOR(db, hash, seed) diff --git a/cry/pss.go b/cry/pss.go new file mode 100644 index 0000000..45f186a --- /dev/null +++ b/cry/pss.go @@ -0,0 +1,231 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cry + +// This file implements the RSASSA-PSS signature scheme according to RFC 8017. + +import ( + "bytes" + "crypto" + "crypto/rsa" + "errors" + "hash" + "io" + "math/big" +) + +// Per RFC 8017, Section 9.1 +// +// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc +// +// where +// +// DB = PS || 0x01 || salt +// +// and PS can be empty so +// +// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 +// + +func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { + // See RFC 8017, Section 9.1.2. + + hLen := hash.Size() + if sLen == PSSSaltLengthEqualsHash { + sLen = hLen + } + emLen := (emBits + 7) / 8 + if emLen != len(em) { + return errors.New("rsa: internal error: inconsistent length") + } + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if hLen != len(mHash) { + return rsa.ErrVerification + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + if emLen < hLen+sLen+2 { + return rsa.ErrVerification + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[emLen-1] != 0xbc { + return rsa.ErrVerification + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + db := em[:emLen-hLen-1] + h := em[emLen-hLen-1 : emLen-1] + + // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + var bitMask byte = 0xff >> (8*emLen - emBits) + if em[0] & ^bitMask != 0 { + return rsa.ErrVerification + } + + // 7. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 8. Let DB = maskedDB \xor dbMask. + mgf1XOR(db, hash, h) + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= bitMask + + // If we don't know the salt length, look for the 0x01 delimiter. + if sLen == PSSSaltLengthAuto { + psLen := bytes.IndexByte(db, 0x01) + if psLen < 0 { + return rsa.ErrVerification + } + sLen = len(db) - psLen - 1 + } + + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + psLen := emLen - hLen - sLen - 2 + for _, e := range db[:psLen] { + if e != 0x00 { + return rsa.ErrVerification + } + } + if db[psLen] != 0x01 { + return rsa.ErrVerification + } + + // 11. Let salt be the last sLen octets of DB. + salt := db[len(db)-sLen:] + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + var prefix [8]byte + hash.Write(prefix[:]) + hash.Write(mHash) + hash.Write(salt) + + h0 := hash.Sum(nil) + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if !bytes.Equal(h0, h) { // TODO: constant time? + return rsa.ErrVerification + } + return nil +} + +const ( + // PSSSaltLengthAuto causes the salt in a PSS signature to be as large + // as possible when signing, and to be auto-detected when verifying. + PSSSaltLengthAuto = 0 + // PSSSaltLengthEqualsHash causes the salt length to equal the length + // of the hash used in the signature. + PSSSaltLengthEqualsHash = -1 +) + +// PSSOptions contains options for creating and verifying PSS signatures. +type PSSOptions struct { + // SaltLength controls the length of the salt used in the PSS + // signature. It can either be a number of bytes, or one of the special + // PSSSaltLength constants. + SaltLength int + + // Hash is the hash function used to generate the message digest. If not + // zero, it overrides the hash function passed to SignPSS. It's required + // when using PrivateKey.Sign. + Hash crypto.Hash +} + +// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts. +func (opts *PSSOptions) HashFunc() crypto.Hash { + return opts.Hash +} + +func (opts *PSSOptions) saltLength() int { + if opts == nil { + return PSSSaltLengthAuto + } + return opts.SaltLength +} + +// SignPSS calculates the signature of digest using PSS. +// +// digest must be the result of hashing the input message using the given hash +// function. The opts argument may be nil, in which case sensible defaults are +// used. If opts.Hash is set, it overrides hash. +func SignPSS(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { + if opts != nil && opts.Hash != 0 { + hash = opts.Hash + } + + saltLength := opts.saltLength() + switch saltLength { + case PSSSaltLengthAuto: + saltLength = priv.Size() - 2 - hash.Size() + case PSSSaltLengthEqualsHash: + saltLength = hash.Size() + } + + salt := make([]byte, saltLength) + if _, err := io.ReadFull(rand, salt); err != nil { + return nil, err + } + return signPSSWithSalt(rand, priv, hash, digest, salt) +} + +// VerifyPSS verifies a PSS signature. +// +// A valid signature is indicated by returning a nil error. digest must be the +// result of hashing the input message using the given hash function. The opts +// argument may be nil, in which case sensible defaults are used. opts.Hash is +// ignored. +func VerifyPSS(pub *rsa.PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error { + if len(sig) != pub.Size() { + return rsa.ErrVerification + } + s := new(big.Int).SetBytes(sig) + m := encrypt(new(big.Int), pub, s) + emBits := pub.N.BitLen() - 1 + emLen := (emBits + 7) / 8 + if m.BitLen() > emLen*8 { + return rsa.ErrVerification + } + em := m.FillBytes(make([]byte, emLen)) + return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) +} + +// VerifyPSS verifies a PSS signature. +// +// A valid signature is indicated by returning a nil error. digest must be the +// result of hashing the input message using the given hash function. The opts +// argument may be nil, in which case sensible defaults are used. opts.Hash is +// ignored. +func VerifyPSSByPrivate(priv *rsa.PrivateKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error { + if len(sig) != priv.Size() { + return rsa.ErrVerification + } + s := new(big.Int).SetBytes(sig) + m := encryptByPrivateKey(new(big.Int), priv, s) + emBits := priv.N.BitLen() - 1 + emLen := (emBits + 7) / 8 + if m.BitLen() > emLen*8 { + return rsa.ErrVerification + } + em := m.FillBytes(make([]byte, emLen)) + return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) +} diff --git a/rsam.go b/rsam.go index 616397f..690271c 100644 --- a/rsam.go +++ b/rsam.go @@ -160,14 +160,14 @@ func EncryptWithPrivateKey(msg []byte, priv *rsa.PrivateKey, hash hash.Hash) ([] // Decrypts data with private key func DecryptWithPublicKey(ciphertext []byte, pub *rsa.PublicKey, hash hash.Hash) ([]byte, error) { - var plaintext,m []byte + var plaintext, m []byte chunkSize := pub.Size() var err error for i := 0; i < len(ciphertext); i += chunkSize { if i+chunkSize > len(ciphertext) { m, err = cry.DecryptOAEPM(hash, nil, pub, ciphertext[i:], nil) } else { - m, err = cry.DecryptOAEPM(hash, nil, pub, ciphertext[i : i+chunkSize], nil) + m, err = cry.DecryptOAEPM(hash, nil, pub, ciphertext[i:i+chunkSize], nil) } if err != nil { return nil, err @@ -178,10 +178,13 @@ func DecryptWithPublicKey(ciphertext []byte, pub *rsa.PublicKey, hash hash.Hash) } // Decrypts data with private key -func SignWithPrivateKey(msg []byte, priv *rsa.PrivateKey) ([]byte, error) { - hash := crypto.Hash(crypto.SHA512) +func SignWithPrivateKey(msg []byte, priv *rsa.PrivateKey, hashOpt ...crypto.Hash) ([]byte, error) { + hash := crypto.Hash(crypto.SHA256) + if len(hashOpt) > 0 { + hash = crypto.Hash(hashOpt[0]) + } hashed := sha256.Sum256(msg) - signature, err := rsa.SignPKCS1v15(rand.Reader, priv, hash, hashed[:]) + signature, err := rsa.SignPSS(rand.Reader, priv, hash, hashed[:], nil) if err != nil { return nil, err } @@ -189,24 +192,27 @@ func SignWithPrivateKey(msg []byte, priv *rsa.PrivateKey) ([]byte, error) { } // Decrypts data with private key -func VerifyWithPublicKey(msg []byte, signature []byte, pub *rsa.PublicKey) error { - hash := crypto.Hash(crypto.SHA512) +func VerifyWithPublicKey(msg []byte, signature []byte, pub *rsa.PublicKey, hashOpt ...crypto.Hash) error { + hash := crypto.Hash(crypto.SHA256) + if len(hashOpt) > 0 { + hash = crypto.Hash(hashOpt[0]) + } hashed := sha256.Sum256(msg) - return rsa.VerifyPKCS1v15(pub, hash, hashed[:], signature) + return rsa.VerifyPSS(pub, hash, hashed[:], signature, nil) } // Decrypts data with private key func VerifyWithPrivateKey(msg []byte, signature []byte, priv *rsa.PrivateKey) error { hash := crypto.Hash(crypto.SHA256) hashed := sha256.Sum256(msg) - return cry.VerifyByPrivate(priv, hash, hashed[:], signature) + return cry.VerifyPSSByPrivate(priv, hash, hashed[:], signature, nil) } // Encrypts data with public key func SignWithPublicKey(msg []byte, pub *rsa.PublicKey) ([]byte, error) { hash := crypto.Hash(crypto.SHA256) hashed := sha256.Sum256(msg) - return cry.SignByPublic(rand.Reader, pub, hash, hashed[:]) + return cry.SignPSSByPublic(nil, pub, hash, hashed[:], nil) } // Returns base64 encoded key from file diff --git a/rsam_test.go b/rsam_test.go index 3086255..93ade3e 100644 --- a/rsam_test.go +++ b/rsam_test.go @@ -126,4 +126,4 @@ MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD1/618vsRezcClpPzD4fbusyH6 3UiM/ZT7qzNtCIvyqW0xSzDH0t1/dmjQH4mOaiel2jQT9fQnSBKmGuG5GZ2ueDRg qrLZd3resJix9V3Q9tvSxevPsvOiSlV3Z8QXTMTMSQBA1DNpSDsi8inupoJSUUiZ +8k8d3eMJ2gBIgZ5qQIDAQAB ------END PUBLIC KEY-----` \ No newline at end of file +-----END PUBLIC KEY-----`