From c8f7184828ee4fd4ea4a9edf4e09f4287fcf889a Mon Sep 17 00:00:00 2001 From: Davis Goodin Date: Fri, 5 Jan 2024 14:09:47 -0800 Subject: [PATCH] Add NewGCMTLS13 (#51) --- cng/aes.go | 63 +++++++++++++++++++++++--- cng/aes_test.go | 115 +++++++++++++++++++++++++++++------------------- 2 files changed, 126 insertions(+), 52 deletions(-) diff --git a/cng/aes.go b/cng/aes.go index bf68ece..e399eba 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -106,7 +106,7 @@ func (c *aesCipher) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { if tagSize != gcmTagSize { return cipher.NewGCMWithTagSize(&noGCM{c}, tagSize) } - return newGCM(c.key, false) + return newGCM(c.key, cipherGCMTLSNone) } // NewGCMTLS returns a GCM cipher specific to TLS @@ -116,7 +116,17 @@ func NewGCMTLS(c cipher.Block) (cipher.AEAD, error) { } func (c *aesCipher) NewGCMTLS() (cipher.AEAD, error) { - return newGCM(c.key, true) + return newGCM(c.key, cipherGCMTLS12) +} + +// NewGCMTLS13 returns a GCM cipher specific to TLS 1.3 and should not be used +// for non-TLS purposes. +func NewGCMTLS13(c cipher.Block) (cipher.AEAD, error) { + return c.(*aesCipher).NewGCMTLS13() +} + +func (c *aesCipher) NewGCMTLS13() (cipher.AEAD, error) { + return newGCM(c.key, cipherGCMTLS13) } type cbcCipher struct { @@ -197,17 +207,32 @@ const ( gcmTlsFixedNonceSize = 4 ) +type cipherGCMTLS uint8 + +const ( + cipherGCMTLSNone cipherGCMTLS = iota + cipherGCMTLS12 + cipherGCMTLS13 +) + type aesGCM struct { - kh bcrypt.KEY_HANDLE - tls bool + kh bcrypt.KEY_HANDLE + tls cipherGCMTLS + // minNextNonce is the minimum value that the next nonce can be, enforced by + // all TLS modes. minNextNonce uint64 + // mask is the nonce mask used in TLS 1.3 mode. + mask uint64 + // maskInitialized is true if mask has been initialized. This happens during + // the first Seal. The initialized mask may be 0. Used by TLS 1.3 mode. + maskInitialized bool } func (g *aesGCM) finalize() { bcrypt.DestroyKey(g.kh) } -func newGCM(key []byte, tls bool) (*aesGCM, error) { +func newGCM(key []byte, tls cipherGCMTLS) (*aesGCM, error) { kh, err := newCipherHandle(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM, key) if err != nil { return nil, err @@ -235,15 +260,39 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { if len(dst)+len(plaintext)+gcmTagSize < len(dst) { panic("cipher: message too large for buffer") } - if g.tls { + if g.tls != cipherGCMTLSNone { if len(additionalData) != gcmTlsAddSize { panic("cipher: incorrect additional data length given to GCM TLS") } + counter := bigUint64(nonce[gcmTlsFixedNonceSize:]) + if g.tls == cipherGCMTLS13 { + // In TLS 1.3, the counter in the nonce has a mask and requires + // further decoding. + if !g.maskInitialized { + // According to TLS 1.3 nonce construction details at + // https://tools.ietf.org/html/rfc8446#section-5.3: + // + // the first record transmitted under a particular traffic + // key MUST use sequence number 0. + // + // The padded sequence number is XORed with [a mask]. + // + // The resulting quantity (of length iv_length) is used as + // the per-record nonce. + // + // We need to convert from the given nonce to sequence numbers + // to keep track of minNextNonce and enforce the counter + // maximum. On the first call, we know counter^mask is 0^mask, + // so we can simply store it as the mask. + g.mask = counter + g.maskInitialized = true + } + counter ^= g.mask + } // BoringCrypto enforces strictly monotonically increasing explicit nonces // and to fail after 2^64 - 1 keys as per FIPS 140-2 IG A.5, // but BCrypt does not perform this check, so it is implemented here. const maxUint64 = 1<<64 - 1 - counter := bigUint64(nonce[gcmTlsFixedNonceSize:]) if counter == maxUint64 { panic("cipher: nonce counter must be less than 2^64 - 1") } diff --git a/cng/aes_test.go b/cng/aes_test.go index 4052ab9..f8f4fec 100644 --- a/cng/aes_test.go +++ b/cng/aes_test.go @@ -71,51 +71,76 @@ func TestSealAndOpen(t *testing.T) { } func TestSealAndOpenTLS(t *testing.T) { - ci, err := NewAESCipher(key) - if err != nil { - t.Fatal(err) - } - gcm, err := NewGCMTLS(ci) - if err != nil { - t.Fatal(err) - } - nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9} - nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10} - nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255} - plainText := []byte{0x01, 0x02, 0x03} - additionalData := make([]byte, 13) - additionalData[11] = byte(len(plainText) >> 8) - additionalData[12] = byte(len(plainText)) - sealed := gcm.Seal(nil, nonce[:], plainText, additionalData) - assertPanic(t, func() { - gcm.Seal(nil, nonce[:], plainText, additionalData) - }) - sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData) - gcm.Seal(nil, nonce10[:], plainText, additionalData) - assertPanic(t, func() { - gcm.Seal(nil, nonce9[:], plainText, additionalData) - }) - assertPanic(t, func() { - gcm.Seal(nil, nonceMax[:], plainText, additionalData) - }) - if bytes.Equal(sealed, sealed1) { - t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1) - } - decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData) - if err != nil { - t.Error(err) - } - decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData) - if err != nil { - t.Error(err) - } - if !bytes.Equal(decrypted, plainText) { - t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) - } - if !bytes.Equal(decrypted, decrypted1) { - t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1) + tests := []struct { + name string + new func(c cipher.Block) (cipher.AEAD, error) + mask func(n *[12]byte) + }{ + {"1.2", NewGCMTLS, nil}, + {"1.3", NewGCMTLS13, nil}, + {"1.3_masked", NewGCMTLS13, func(n *[12]byte) { + // Arbitrary mask in the high bits. + n[9] ^= 0x42 + // Mask the very first bit. This makes sure that if Seal doesn't + // handle the mask, the counter appears to go backwards and panics + // when it shouldn't. + n[11] ^= 0x1 + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ci, err := NewAESCipher(key) + if err != nil { + t.Fatal(err) + } + gcm, err := tt.new(ci) + if err != nil { + t.Fatal(err) + } + nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9} + nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10} + nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255} + if tt.mask != nil { + for _, m := range []*[12]byte{&nonce, &nonce1, &nonce9, &nonce10, &nonceMax} { + tt.mask(m) + } + } + plainText := []byte{0x01, 0x02, 0x03} + additionalData := make([]byte, 13) + additionalData[11] = byte(len(plainText) >> 8) + additionalData[12] = byte(len(plainText)) + sealed := gcm.Seal(nil, nonce[:], plainText, additionalData) + assertPanic(t, func() { + gcm.Seal(nil, nonce[:], plainText, additionalData) + }) + sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData) + gcm.Seal(nil, nonce10[:], plainText, additionalData) + assertPanic(t, func() { + gcm.Seal(nil, nonce9[:], plainText, additionalData) + }) + assertPanic(t, func() { + gcm.Seal(nil, nonceMax[:], plainText, additionalData) + }) + if bytes.Equal(sealed, sealed1) { + t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1) + } + decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData) + if err != nil { + t.Error(err) + } + decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData) + if err != nil { + t.Error(err) + } + if !bytes.Equal(decrypted, plainText) { + t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) + } + if !bytes.Equal(decrypted, decrypted1) { + t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1) + } + }) } }