diff --git a/ecies.go b/ecies.go index a27c7a8..9739658 100644 --- a/ecies.go +++ b/ecies.go @@ -2,15 +2,19 @@ package eciesgo import ( "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" "fmt" "math/big" ) +type Config struct { + symmetricAlgorithm string + symmetricNonceLength int +} + +var DEFAULT_CONFIG = Config{symmetricAlgorithm: "aes-256-gcm", symmetricNonceLength: 16} + // Encrypt encrypts a passed message with a receiver public key, returns ciphertext or encryption error -func Encrypt(pubkey *PublicKey, msg []byte) ([]byte, error) { +func EncryptConf(pubkey *PublicKey, msg []byte, config Config) ([]byte, error) { var ct bytes.Buffer // Generate ephemeral key @@ -27,38 +31,23 @@ func Encrypt(pubkey *PublicKey, msg []byte) ([]byte, error) { return nil, err } - // AES encryption - block, err := aes.NewCipher(ss) + // Symmetrical encryption + ciphertext, err := EncryptSymm(ss, msg, config) if err != nil { - return nil, fmt.Errorf("cannot create new aes block: %w", err) - } - - nonce := make([]byte, 16) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("cannot read random bytes for nonce: %w", err) - } - - ct.Write(nonce) - - aesgcm, err := cipher.NewGCMWithNonceSize(block, 16) - if err != nil { - return nil, fmt.Errorf("cannot create aes gcm: %w", err) + return nil, err } - ciphertext := aesgcm.Seal(nil, nonce, msg, nil) - - tag := ciphertext[len(ciphertext)-aesgcm.NonceSize():] - ct.Write(tag) - ciphertext = ciphertext[:len(ciphertext)-len(tag)] ct.Write(ciphertext) - return ct.Bytes(), nil } +func Encrypt(pubkey *PublicKey, msg []byte) ([]byte, error) { + return EncryptConf(pubkey, msg, DEFAULT_CONFIG) +} + // Decrypt decrypts a passed message with a receiver private key, returns plaintext or decryption error -func Decrypt(privkey *PrivateKey, msg []byte) ([]byte, error) { - // Message cannot be less than length of public key (65) + nonce (16) + tag (16) - if len(msg) <= (1 + 32 + 32 + 16 + 16) { +func DecryptConf(privkey *PrivateKey, msg []byte, config Config) ([]byte, error) { + if len(msg) <= (1 + 32 + 32) { return nil, fmt.Errorf("invalid length of message") } @@ -69,36 +58,24 @@ func Decrypt(privkey *PrivateKey, msg []byte) ([]byte, error) { Y: new(big.Int).SetBytes(msg[33:65]), } - // Shift message - msg = msg[65:] - // Derive shared secret ss, err := ethPubkey.Decapsulate(privkey) if err != nil { return nil, err } - // AES decryption part - nonce := msg[:16] - tag := msg[16:32] - - // Create Golang-accepted ciphertext - ciphertext := bytes.Join([][]byte{msg[32:], tag}, nil) - - block, err := aes.NewCipher(ss) - if err != nil { - return nil, fmt.Errorf("cannot create new aes block: %w", err) - } - - gcm, err := cipher.NewGCMWithNonceSize(block, 16) - if err != nil { - return nil, fmt.Errorf("cannot create gcm cipher: %w", err) - } + // Shift message + msg = msg[65:] - plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + // Symmetrical decryption + plaintext, err := DecryptSymm(ss, msg, config) if err != nil { - return nil, fmt.Errorf("cannot decrypt ciphertext: %w", err) + return nil, err } return plaintext, nil } + +func Decrypt(privkey *PrivateKey, msg []byte) ([]byte, error) { + return DecryptConf(privkey, msg, DEFAULT_CONFIG) +} diff --git a/ecies_test.go b/ecies_test.go index c0842b9..950f0e6 100644 --- a/ecies_test.go +++ b/ecies_test.go @@ -56,15 +56,15 @@ func BenchmarkDecrypt(b *testing.B) { } } -func TestEncryptAndDecrypt(t *testing.T) { +func testEncryptAndDecryptParameters(conf Config, t *testing.T) { privkey := NewPrivateKeyFromBytes(testingReceiverPrivkey) - ciphertext, err := Encrypt(privkey.PublicKey, []byte(testingMessage)) + ciphertext, err := EncryptConf(privkey.PublicKey, []byte(testingMessage), conf) if !assert.NoError(t, err) { return } - plaintext, err := Decrypt(privkey, ciphertext) + plaintext, err := DecryptConf(privkey, ciphertext, conf) if !assert.NoError(t, err) { return } @@ -72,6 +72,13 @@ func TestEncryptAndDecrypt(t *testing.T) { assert.Equal(t, testingMessage, string(plaintext)) } +func TestEncryptAndDecrypt(t *testing.T) { + testEncryptAndDecryptParameters(DEFAULT_CONFIG, t) + testEncryptAndDecryptParameters(Config{symmetricAlgorithm: "aes-256-gcm", symmetricNonceLength: 12}, t) + testEncryptAndDecryptParameters(Config{symmetricAlgorithm: "aes-256-gcm", symmetricNonceLength: 16}, t) + testEncryptAndDecryptParameters(Config{symmetricAlgorithm: "xchacha20"}, t) +} + func TestPublicKeyDecompression(t *testing.T) { // Generate public key privkey, err := GenerateKey() diff --git a/go.mod b/go.mod index 10ca110..dab43d3 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 github.com/ethereum/go-ethereum v1.13.10 github.com/stretchr/testify v1.8.4 - golang.org/x/crypto v0.18.0 + golang.org/x/crypto v0.19.0 ) go 1.13 diff --git a/go.sum b/go.sum index 2a1ea8e..44f1a2c 100644 --- a/go.sum +++ b/go.sum @@ -726,8 +726,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -945,7 +945,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -959,7 +960,7 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/privatekey.go b/privatekey.go index a2db9a2..3d73f1e 100644 --- a/privatekey.go +++ b/privatekey.go @@ -127,9 +127,5 @@ func (k *PrivateKey) ECDH(pub *PublicKey) ([]byte, error) { // Equals compares two private keys with constant time (to resist timing attacks) func (k *PrivateKey) Equals(priv *PrivateKey) bool { - if subtle.ConstantTimeCompare(k.D.Bytes(), priv.D.Bytes()) == 1 { - return true - } - - return false + return subtle.ConstantTimeCompare(k.D.Bytes(), priv.D.Bytes()) == 1 } diff --git a/symm.go b/symm.go new file mode 100644 index 0000000..e1f1dd5 --- /dev/null +++ b/symm.go @@ -0,0 +1,90 @@ +package eciesgo + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" +) + +func generateSymmCipher(key []byte, conf Config) (cipher.AEAD, error) { + var err error + var aead cipher.AEAD + + switch conf.symmetricAlgorithm { + case "aes-256-gcm": + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cannot create new AES block: %w", err) + } + + aead, err = cipher.NewGCMWithNonceSize(block, conf.symmetricNonceLength) + if err != nil { + return nil, fmt.Errorf("cannot create AES GCM: %w", err) + } + case "xchacha20": + aead, err = chacha20poly1305.NewX(key) + if err != nil { + return nil, fmt.Errorf("cannot create XChaCha20: %w", err) + } + default: + return nil, fmt.Errorf("unknown cipher: %s", conf.symmetricAlgorithm) + } + + return aead, nil +} + +func EncryptSymm(key []byte, msg []byte, conf Config) ([]byte, error) { + var ct bytes.Buffer + + aead, err := generateSymmCipher(key, conf) + if err != nil { + return nil, err + } + + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("cannot read random bytes for nonce: %w", err) + } + + ct.Write(nonce) + + ciphertext := aead.Seal(nil, nonce, msg, nil) + + tag := ciphertext[len(ciphertext)-aead.Overhead():] + ct.Write(tag) + ciphertext = ciphertext[:len(ciphertext)-len(tag)] + ct.Write(ciphertext) + + return ct.Bytes(), nil +} + +func DecryptSymm(key []byte, msg []byte, conf Config) ([]byte, error) { + aead, err := generateSymmCipher(key, conf) + if err != nil { + return nil, err + } + + // Message cannot be less than length of public key (65) + nonce + tag (16) + if len(msg) <= (aead.NonceSize() + aead.Overhead()) { + return nil, fmt.Errorf("invalid length of message") + } + + // Symmetrical decryption part + nonce := msg[:aead.NonceSize()] + tag := msg[aead.NonceSize() : aead.NonceSize()+aead.Overhead()] + msg = msg[aead.NonceSize()+aead.Overhead():] + + // Create Golang-accepted ciphertext + ciphertext := bytes.Join([][]byte{msg, tag}, nil) + + plaintext, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("cannot decrypt ciphertext: %v", err) + } + + return plaintext, nil +}