diff --git a/internal/datachannel/controller.go b/internal/datachannel/controller.go new file mode 100644 index 00000000..40b5ec8a --- /dev/null +++ b/internal/datachannel/controller.go @@ -0,0 +1,244 @@ +package datachannel + +import ( + "bytes" + "crypto/hmac" + "fmt" + "strings" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/session" +) + +// dataChannelHandler manages the data "channel". +type dataChannelHandler interface { + setupKeys(*session.DataChannelKey) error + writePacket([]byte) (*model.Packet, error) + readPacket(*model.Packet) ([]byte, error) + decodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error) + encryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error) +} + +// DataChannel represents the data "channel", that will encrypt and decrypt the tunnel payloads. +// data implements the dataHandler interface. +type DataChannel struct { + options *model.Options + sessionManager *session.Manager + state *dataChannelState + decodeFn func(model.Logger, []byte, *session.Manager, *dataChannelState) (*encryptedData, error) + encryptEncodeFn func(model.Logger, []byte, *session.Manager, *dataChannelState) ([]byte, error) + decryptFn func([]byte, *encryptedData) ([]byte, error) + log model.Logger +} + +var _ dataChannelHandler = &DataChannel{} // Ensure that we implement dataChannelHandler + +// NewDataChannelFromOptions returns a new data object, initialized with the +// options given. it also returns any error raised. +func NewDataChannelFromOptions(log model.Logger, + opt *model.Options, + sessionManager *session.Manager) (*DataChannel, error) { + runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") + runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") + runtimex.Assert(len(opt.Cipher) != 0, "need a configured cipher option") + runtimex.Assert(len(opt.Auth) != 0, "need a configured auth option") + + state := &dataChannelState{} + data := &DataChannel{ + options: opt, + sessionManager: sessionManager, + state: state, + } + + dataCipher, err := newDataCipherFromCipherSuite(opt.Cipher) + if err != nil { + return data, err + } + data.state.dataCipher = dataCipher + switch dataCipher.isAEAD() { + case true: + data.decodeFn = decodeEncryptedPayloadAEAD + data.encryptEncodeFn = encryptAndEncodePayloadAEAD + case false: + data.decodeFn = decodeEncryptedPayloadNonAEAD + data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD + } + + hmacHash, ok := newHMACFactory(strings.ToLower(opt.Auth)) + if !ok { + return data, fmt.Errorf("%w: %s", errDataChannel, fmt.Sprintf("no such mac: %v", opt.Auth)) + } + data.state.hash = hmacHash + data.decryptFn = state.dataCipher.decrypt + + log.Info(fmt.Sprintf("Cipher: %s", opt.Cipher)) + log.Info(fmt.Sprintf("Auth: %s", opt.Auth)) + + return data, nil +} + +// DecodeEncryptedPayload calls the corresponding function for AEAD or Non-AEAD decryption. +func (d *DataChannel) decodeEncryptedPayload(b []byte, dcs *dataChannelState) (*encryptedData, error) { + return d.decodeFn(d.log, b, d.sessionManager, dcs) +} + +// setSetupKeys performs the key expansion from the local and remote +// keySources, initializing the data channel state. +func (d *DataChannel) setupKeys(dck *session.DataChannelKey) error { + runtimex.Assert(dck != nil, "data channel key cannot be nil") + if !dck.Ready() { + return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready") + } + master := prf( + dck.Local().PreMaster[:], + []byte("OpenVPN master secret"), + dck.Local().R1[:], + dck.Remote().R1[:], + []byte{}, []byte{}, + 48) + + keys := prf( + master, + []byte("OpenVPN key expansion"), + dck.Local().R2[:], + dck.Remote().R2[:], + d.sessionManager.LocalSessionID(), + d.sessionManager.RemoteSessionID(), + 256) + + var keyLocal, hmacLocal, keyRemote, hmacRemote keySlot + copy(keyLocal[:], keys[0:64]) + copy(hmacLocal[:], keys[64:128]) + copy(keyRemote[:], keys[128:192]) + copy(hmacRemote[:], keys[192:256]) + + d.state.cipherKeyLocal = keyLocal + d.state.hmacKeyLocal = hmacLocal + d.state.cipherKeyRemote = keyRemote + d.state.hmacKeyRemote = hmacRemote + + log.Debugf("Cipher key local: %x", keyLocal) + log.Debugf("Cipher key remote: %x", keyRemote) + log.Debugf("Hmac key local: %x", hmacLocal) + log.Debugf("Hmac key remote: %x", hmacRemote) + + hashSize := d.state.hash().Size() + d.state.hmacLocal = hmac.New(d.state.hash, hmacLocal[:hashSize]) + d.state.hmacRemote = hmac.New(d.state.hash, hmacRemote[:hashSize]) + + log.Info("Key derivation OK") + return nil +} + +// +// write + encrypt +// + +func (d *DataChannel) writePacket(payload []byte) (*model.Packet, error) { + runtimex.Assert(d.state != nil, "data: nil state") + runtimex.Assert(d.state.dataCipher != nil, "data.state: nil dataCipher") + + var plain []byte + var err error + + switch d.state.dataCipher.isAEAD() { + case true: + plain, err = doCompress(payload, d.options.Compress) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + case false: // non-aead + localPacketID, _ := d.sessionManager.LocalDataPacketID() + plain = prependPacketID(localPacketID, payload) + + plain, err = doCompress(plain, d.options.Compress) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + } + + // encrypted adds padding, if needed, and it also includes the + // opcode/keyid and peer-id headers and, if used, any authenticated + // parts in the packet. + encrypted, err := d.encryptAndEncodePayload(plain, d.state) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + + // TODO(ainghazal): increment counter for used bytes? + // and trigger renegotiation if we're near the end of the key useful lifetime. + + packet := model.NewPacket(model.P_DATA_V2, d.sessionManager.CurrentKeyID(), encrypted) + peerid := &bytes.Buffer{} + bytesx.WriteUint24(peerid, uint32(d.sessionManager.TunnelInfo().PeerID)) + packet.PeerID = model.PeerID(peerid.Bytes()) + return packet, nil +} + +// encrypt calls the corresponding function for AEAD or Non-AEAD decryption. +// Due to the particularities of the iv generation on each of the modes, encryption and encoding are +// done together in the same function. +func (d *DataChannel) encryptAndEncodePayload(plaintext []byte, dcs *dataChannelState) ([]byte, error) { + runtimex.Assert(dcs != nil, "datachanelState is nil") + runtimex.Assert(dcs.dataCipher != nil, "dcs.dataCipher is nil") + + if len(plaintext) == 0 { + return nil, fmt.Errorf("%w: nothing to encrypt", ErrCannotEncrypt) + } + + padded, err := doPadding(plaintext, d.options.Compress, dcs.dataCipher.blockSize()) + if err != nil { + return nil, + fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + + encrypted, err := d.encryptEncodeFn(d.log, padded, d.sessionManager, d.state) + if err != nil { + return nil, + fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + return encrypted, nil + +} + +// +// read + decrypt +// + +func (d *DataChannel) readPacket(p *model.Packet) ([]byte, error) { + if len(p.Payload) == 0 { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, "empty payload") + } + runtimex.Assert(p.IsData(), "ReadPacket expects data packet") + + plaintext, err := d.decrypt(p.Payload) + if err != nil { + return nil, err + } + + // get plaintext payload from the decrypted plaintext + return maybeDecompress(plaintext, d.state, d.options) +} + +func (d *DataChannel) decrypt(encrypted []byte) ([]byte, error) { + if d.decryptFn == nil { + return []byte{}, errInitError + } + if len(d.state.hmacKeyRemote) == 0 { + d.log.Warn("decrypt: not ready yet") + return nil, ErrCannotDecrypt + } + encryptedData, err := d.decodeEncryptedPayload(encrypted, d.state) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, err) + } + + plainText, err := d.decryptFn(d.state.cipherKeyRemote[:], encryptedData) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, err) + } + return plainText, nil +} diff --git a/internal/datachannel/crypto.go b/internal/datachannel/crypto.go new file mode 100644 index 00000000..c35086f3 --- /dev/null +++ b/internal/datachannel/crypto.go @@ -0,0 +1,363 @@ +package datachannel + +// +// Code to perform encryption, decryption and key derivation. +// + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" + "log" + + "github.com/ooni/minivpn/internal/bytesx" +) //#nosec G501,G505 +// We know that sha1 and md5 are insecure, but we do not control the openvpn protocol. + +// TODO(ainghazal,bassosimone): see if it's feasible to use stdlib +// functionality rather than using the code below. + +type ( + // cipherMode describes a cipher mode (e.g., GCM). + cipherMode string + + // cipherName is a cipher name (e.g., AES). + cipherName string +) + +const ( + // cipherModeCBC is the CBC cipher mode. + cipherModeCBC = cipherMode("cbc") + + // cipherModeGCM is the GCM cipher mode. + cipherModeGCM = cipherMode("gcm") + + // cipherNameAES is an AES-based cipher. + cipherNameAES = cipherName("aes") +) + +// encrypteData holds the different parts needed to decrypt an encrypted data +// packet. +type encryptedData struct { + iv []byte + ciphertext []byte + aead []byte +} + +// plaintextData holds the different parts needed to encrypt a plaintext +// payload (after padding). +type plaintextData struct { + iv []byte + plaintext []byte + aead []byte +} + +// dataCipher encrypts and decrypts OpenVPN data. +type dataCipher interface { + // keySizeBytes returns the key size (in bytes). + keySizeBytes() int + + // isAEAD returns whether this cipher has AEAD properties. + isAEAD() bool + + // blockSize returns the expected block size. + blockSize() uint8 + + // encrypt encripts a plaintext. + // + // Arguments: + // + // - key is the key, whose size must be consistent with the cipher; + // + // - plaintextData is the data to be encrypted; + // + // Returns the ciphertext on success and an error on failure. + encrypt([]byte, *plaintextData) ([]byte, error) + + // decrypt is the opposite operation of encrypt. It takes in input the + // ciphertext and returns the plaintext of an error. + decrypt([]byte, *encryptedData) ([]byte, error) + + // mode returns the cipherMode + cipherMode() cipherMode +} + +// dataCipherAES implements dataCipher for AES. +type dataCipherAES struct { + // ksb is the key size in bytes + ksb int + + // mode is the cipher mode + mode cipherMode +} + +var _ dataCipher = &dataCipherAES{} // Ensure we implement dataCipher + +// keySizeBytes implements dataCipher.keySizeBytes +func (a *dataCipherAES) keySizeBytes() int { + return a.ksb +} + +// isAEAD implements dataCipher.isAEAD +func (a *dataCipherAES) isAEAD() bool { + return a.mode != cipherModeCBC +} + +// blockSize implements dataCipher.BlockSize +func (a *dataCipherAES) blockSize() uint8 { + switch a.mode { + case cipherModeCBC, cipherModeGCM: + return 16 + default: + return 0 + } +} + +// decrypt implements dataCipher.decrypt. +// Since key comes from a prf derivation, we only take as many bytes as we need to match +// our key size. +func (a *dataCipherAES) decrypt(key []byte, data *encryptedData) ([]byte, error) { + // TODO(ainghazal): split this function, it's too large + if len(key) < a.keySizeBytes() { + return nil, errInvalidKeySize + } + + // they key material might be longer + k := key[:a.keySizeBytes()] + block, err := aes.NewCipher(k) + if err != nil { + return nil, err + } + switch a.mode { + case cipherModeCBC: + if len(data.iv) != block.BlockSize() { + return nil, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotDecrypt, len(data.iv)) + } + mode := cipher.NewCBCDecrypter(block, data.iv) + plaintext := make([]byte, len(data.ciphertext)) + mode.CryptBlocks(plaintext, data.ciphertext) + plaintext, err := bytesx.BytesUnpadPKCS7(plaintext, block.BlockSize()) + if err != nil { + return nil, err + } + padLen := len(data.ciphertext) - len(plaintext) + if padLen > block.BlockSize() || padLen > len(plaintext) { + // TODO(bassosimone, ainghazal): discuss the cases in which + // this set of conditions actually occurs. + // TODO(ainghazal): this assertion might actually be moved into a + // boundary assertion in the unpad fun. + return nil, errors.New("unpadding error") + } + return plaintext, nil + + case cipherModeGCM: + // standard nonce size is 12. more is surely ok, but let's stick to it. + // https://github.com/golang/go/blob/master/src/crypto/aes/aes_gcm.go#L37 + if len(data.iv) != 12 { + return nil, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotDecrypt, len(data.iv)) + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + plaintext, err := aesGCM.Open(nil, data.iv, data.ciphertext, data.aead) + if err != nil { + log.Println("gdm decryption failed:", err.Error()) + /* + log.Println("dump begins----") + log.Println("len:", len(data.ciphertext)) + log.Println("iv:", data.iv) + log.Printf("%v\n", data.ciphertext) + log.Printf("%x\n", data.ciphertext) + log.Printf("aead: %x\n", data.aead) + log.Println("dump ends------") + */ + return nil, err + } + return plaintext, nil + + default: + return nil, errUnsupportedMode + } +} + +func (a *dataCipherAES) cipherMode() cipherMode { + return a.mode +} + +// encrypt implements dataCipher.encrypt +// Since key comes from a prf derivation, we only take as many bytes as we need to match +// our key size. +func (a *dataCipherAES) encrypt(key []byte, data *plaintextData) ([]byte, error) { + if len(key) < a.keySizeBytes() { + return nil, errInvalidKeySize + } + k := key[:a.keySizeBytes()] + block, err := aes.NewCipher(k) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + switch a.mode { + case cipherModeCBC: + if len(data.iv) != blockSize { + return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotEncrypt, len(data.iv)) + } + if len(data.plaintext)%blockSize != 0 { + return []byte{}, fmt.Errorf("%w: wrong padding", ErrCannotEncrypt) + } + mode := cipher.NewCBCEncrypter(block, data.iv) + + ciphertext := make([]byte, len(data.plaintext)) + mode.CryptBlocks(ciphertext, data.plaintext) + return ciphertext, nil + + case cipherModeGCM: + if len(data.iv) != 12 { + return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotEncrypt, len(data.iv)) + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + // In GCM mode, the IV consists of the 32-bit packet counter + // followed by data from the HMAC key. The HMAC key can be used + // as IV, since in GCM mode the HMAC key is not used for the + // HMAC. The packet counter may not roll over within a single + // TLS session. This results in a unique IV for each packet, as + // required by GCM. + ciphertext := aesGCM.Seal(nil, data.iv, data.plaintext, data.aead) + return ciphertext, nil + + default: + return nil, errUnsupportedMode + } +} + +// newDataCipherFromCipherSuite constructs a new dataCipher from the cipher suite string. +func newDataCipherFromCipherSuite(c string) (dataCipher, error) { + switch c { + case "AES-128-CBC": + return newDataCipher(cipherNameAES, 128, cipherModeCBC) + case "AES-192-CBC": + return newDataCipher(cipherNameAES, 192, cipherModeCBC) + case "AES-256-CBC": + return newDataCipher(cipherNameAES, 256, cipherModeCBC) + case "AES-128-GCM": + return newDataCipher(cipherNameAES, 128, cipherModeGCM) + case "AES-256-GCM": + return newDataCipher(cipherNameAES, 256, cipherModeGCM) + default: + return nil, errUnsupportedCipher + } +} + +// newDataCipher constructs a new dataCipher from the given name, bits, and mode. +func newDataCipher(name cipherName, bits int, mode cipherMode) (dataCipher, error) { + if bits%8 != 0 || bits > 512 || bits < 64 { + return nil, fmt.Errorf("%w: %d", errInvalidKeySize, bits) + } + switch name { + case cipherNameAES: + default: + return nil, fmt.Errorf("%w: %s", errUnsupportedCipher, name) + } + switch mode { + case cipherModeCBC, cipherModeGCM: + default: + return nil, fmt.Errorf("%w: %s", errUnsupportedMode, mode) + } + dc := &dataCipherAES{ + ksb: bits / 8, + mode: mode, + } + return dc, nil +} + +// newHMACFactory accepts a label coming from an OpenVPN auth label, and returns two +// values: a function that will return a Hash implementation, and a boolean +// indicating if the operation was successful. +func newHMACFactory(name string) (func() hash.Hash, bool) { + switch name { + case "sha1": + return sha1.New, true + case "sha256": + return sha256.New, true + case "sha512": + return sha512.New, true + default: + return nil, false + } +} + +// prf function is used to derive master and client keys +func prf(secret, label, clientSeed, serverSeed, clientSid, serverSid []byte, olen int) []byte { + seed := append(clientSeed, serverSeed...) + if len(clientSid) != 0 { + seed = append(seed, clientSid...) + } + if len(serverSid) != 0 { + seed = append(seed, serverSid...) + } + result := make([]byte, olen) + return prf10(result, secret, label, seed) +} + +// Code below is taken from crypto/tls/prf.go +// Copyright 2009 The Go Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) []byte { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + for i, b := range result2 { + result[i] ^= b + } + return result +} + +// SPDX-License-Identifier: BSD-3-Clause +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return + +} + +// SPDX-License-Identifier: BSD-3-Clause +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} diff --git a/internal/datachannel/doc.go b/internal/datachannel/doc.go new file mode 100644 index 00000000..0b81c4a8 --- /dev/null +++ b/internal/datachannel/doc.go @@ -0,0 +1,4 @@ +// Package datachannel implements packet encryption and decryption over the OpenVPN Data Channel. +// Encryption Keys are derived after a successful TLS handshake, and they have a limited +// lifetime. +package datachannel diff --git a/internal/datachannel/errors.go b/internal/datachannel/errors.go new file mode 100644 index 00000000..cf35f4ac --- /dev/null +++ b/internal/datachannel/errors.go @@ -0,0 +1,29 @@ +package datachannel + +import "errors" + +var ( + errDataChannel = errors.New("datachannel error") + errDataChannelKey = errors.New("bad key") + errBadCompression = errors.New("bad compression") + errReplayAttack = errors.New("replay attack") + errBadHMAC = errors.New("bad hmac") + errInitError = errors.New("improperly initialized") + errExpiredKey = errors.New("key is expired") + + // errInvalidKeySize means that the key size is invalid. + errInvalidKeySize = errors.New("invalid key size") + + // errUnsupportedCipher indicates we don't support the desired cipher. + errUnsupportedCipher = errors.New("unsupported cipher") + + // errUnsupportedMode indicates that the mode is not uspported. + errUnsupportedMode = errors.New("unsupported mode") + + // errBadInput indicates invalid inputs to encrypt/decrypt functions. + errBadInput = errors.New("bad input") + + ErrSerialization = errors.New("cannot create packet") + ErrCannotEncrypt = errors.New("cannot encrypt") + ErrCannotDecrypt = errors.New("cannot decrypt") +) diff --git a/internal/datachannel/read.go b/internal/datachannel/read.go new file mode 100644 index 00000000..4d25be2d --- /dev/null +++ b/internal/datachannel/read.go @@ -0,0 +1,164 @@ +package datachannel + +import ( + "bytes" + "crypto/hmac" + "encoding/binary" + "errors" + "fmt" + + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/session" +) + +func decodeEncryptedPayloadAEAD(log model.Logger, buf []byte, session *session.Manager, state *dataChannelState) (*encryptedData, error) { + // P_DATA_V2 GCM data channel crypto format + // 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a... + // [ OP32 ] [seq # ] [ auth tag ] [ payload ... ] + // - means authenticated - * means encrypted * + // [ - opcode/peer-id - ] [ - packet ID - ] [ TAG ] [ * packet payload * ] + + // preconditions + + if len(buf) == 0 || len(buf) < 20 { + return nil, fmt.Errorf("too short: %d bytes", len(buf)) + } + if len(state.hmacKeyRemote) < 8 { + return nil, fmt.Errorf("bad remote hmac") + } + remoteHMAC := state.hmacKeyRemote[:8] + packet_id := buf[:4] + + headers := &bytes.Buffer{} + headers.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(headers, uint32(session.TunnelInfo().PeerID)) + headers.Write(packet_id) + + // we need to swap because decryption expects payload|tag + // but we've got tag | payload instead + payload := &bytes.Buffer{} + payload.Write(buf[20:]) // ciphertext + payload.Write(buf[4:20]) // tag + + // iv := packetID | remoteHMAC + iv := &bytes.Buffer{} + iv.Write(packet_id) + iv.Write(remoteHMAC) + + encrypted := &encryptedData{ + iv: iv.Bytes(), + ciphertext: payload.Bytes(), + aead: headers.Bytes(), + } + return encrypted, nil +} + +var errCannotDecode = errors.New("cannot decode") + +func decodeEncryptedPayloadNonAEAD(log model.Logger, buf []byte, session *session.Manager, state *dataChannelState) (*encryptedData, error) { + runtimex.Assert(state != nil, "passed nil state") + runtimex.Assert(state.dataCipher != nil, "data cipher not initialized") + + hashSize := uint8(state.hmacRemote.Size()) + blockSize := state.dataCipher.blockSize() + + minLen := hashSize + blockSize + + if len(buf) < int(minLen) { + return &encryptedData{}, fmt.Errorf("%w: too short (%d bytes)", errCannotDecode, len(buf)) + } + + receivedHMAC := buf[:hashSize] + iv := buf[hashSize : hashSize+blockSize] + cipherText := buf[hashSize+blockSize:] + + state.hmacRemote.Reset() + state.hmacRemote.Write(iv) + state.hmacRemote.Write(cipherText) + computedHMAC := state.hmacRemote.Sum(nil) + + if !hmac.Equal(computedHMAC, receivedHMAC) { + log.Warnf("expected: %x, got: %x", computedHMAC, receivedHMAC) + return &encryptedData{}, fmt.Errorf("%w: %s", ErrCannotDecrypt, errBadHMAC) + } + + encrypted := &encryptedData{ + iv: iv, + ciphertext: cipherText, + aead: []byte{}, // no AEAD data in this mode, leaving it empty to satisfy common interface + } + return encrypted, nil +} + +// maybeDecompress de-serializes the data from the payload according to the framing +// given by different compression methods. only the different no-compression +// modes are supported at the moment, so no real decompression is done. It +// returns a byte array, and an error if the operation could not be completed +// successfully. +func maybeDecompress(b []byte, st *dataChannelState, opt *model.Options) ([]byte, error) { + if st == nil || st.dataCipher == nil { + return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad state") + } + if opt == nil { + return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad options") + } + + var compr byte // compression type + var payload []byte + + // TODO(ainghazal): have two different decompress implementations + // instead of this switch + switch st.dataCipher.isAEAD() { + case true: + switch opt.Compress { + case model.CompressionStub, model.CompressionLZONo: + // these are deprecated in openvpn 2.5.x + compr = b[0] + payload = b[1:] + default: + compr = 0x00 + payload = b[:] + } + default: // non-aead + remotePacketID := model.PacketID(binary.BigEndian.Uint32(b[:4])) + lastKnownRemote, err := st.RemotePacketID() + if err != nil { + return payload, err + } + if remotePacketID <= lastKnownRemote { + return []byte{}, errReplayAttack + } + st.SetRemotePacketID(remotePacketID) + + switch opt.Compress { + case model.CompressionStub, model.CompressionLZONo: + compr = b[4] + payload = b[5:] + default: + compr = 0x00 + payload = b[4:] + } + } + + switch compr { + case 0xfb: + // compression stub swap: + // we get the last byte and replace the compression byte + // these are deprecated in openvpn 2.5.x + end := payload[len(payload)-1] + b := payload[:len(payload)-1] + payload = append([]byte{end}, b...) + case 0x00, 0xfa: + // do nothing + // 0x00 is compress-no, + // 0xfa is the old no compression or comp-lzo no case. + // http://build.openvpn.net/doxygen/comp_8h_source.html + // see: https://community.openvpn.net/openvpn/ticket/952#comment:5 + default: + errMsg := fmt.Sprintf("cannot handle compression:%x", compr) + return []byte{}, fmt.Errorf("%w:%s", errBadCompression, errMsg) + } + return payload, nil +} diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go new file mode 100644 index 00000000..7e263291 --- /dev/null +++ b/internal/datachannel/service.go @@ -0,0 +1,194 @@ +package datachannel + +// +// OpenVPN data channel +// + +import ( + "encoding/hex" + "fmt" + "sync" + + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// Service is the datachannel service. Make sure you initialize +// the channels before invoking [Service.StartWorkers]. +type Service struct { + // MuxerToData moves packets up to us + MuxerToData chan *model.Packet + + // DataOrControlToMuxer is a shared channel to write packets to the muxer layer below + DataOrControlToMuxer *chan *model.Packet + + // TUNToData moves bytes down from the TUN layer above + TUNToData chan []byte + + // DataToTUN moves bytes up from us to the TUN layer above us + DataToTUN chan []byte + + // KeyReady is where the TLSState layer passes us any new keys + KeyReady chan *session.DataChannelKey +} + +// StartWorkers starts the data-channel workers. +// +// We start three workers: +// +// 1. moveUpWorker BLOCKS on dataPacketUp to read a packet coming from the muxer and +// eventually BLOCKS on tunUp to deliver it; +// +// 2. moveDownWorker BLOCKS on tunDown to read a packet and +// eventually BLOCKS on packetDown to deliver it; +// +// 3. keyWorker BLOCKS on keyUp to read an dataChannelKey and +// initializes the internal state with the resulting key; +func (s *Service) StartWorkers( + logger model.Logger, + workersManager *workers.Manager, + sessionManager *session.Manager, + options *model.Options, +) { + dc, err := NewDataChannelFromOptions(logger, options, sessionManager) + if err != nil { + logger.Warnf("cannot initialize channel %v", err) + return + } + ws := &workersState{ + logger: logger, + muxerToData: s.MuxerToData, + dataOrControlToMuxer: *s.DataOrControlToMuxer, + tunToData: s.TUNToData, + dataToTUN: s.DataToTUN, + keyReady: s.KeyReady, + dataChannel: dc, + workersManager: workersManager, + sessionManager: sessionManager, + } + + firstKeyReady := make(chan any) + + workersManager.StartWorker(ws.moveUpWorker) + workersManager.StartWorker(func() { ws.moveDownWorker(firstKeyReady) }) + workersManager.StartWorker(func() { ws.keyWorker(firstKeyReady) }) +} + +// workersState contains the data channel state. +type workersState struct { + logger model.Logger + workersManager *workers.Manager + sessionManager *session.Manager + keyReady <-chan *session.DataChannelKey + muxerToData <-chan *model.Packet + dataOrControlToMuxer chan<- *model.Packet + dataToTUN chan<- []byte + tunToData <-chan []byte + dataChannel *DataChannel +} + +// moveDownWorker moves packets down the stack. It will BLOCK on PacketDown +func (ws *workersState) moveDownWorker(firstKeyReady <-chan any) { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: moveDownWorker: done") + }() + select { + // wait for the first key to be ready + case <-firstKeyReady: + for { + select { + case data := <-ws.tunToData: + // TODO: writePacket should get the ACTIVE KEY (verify this) + packet, err := ws.dataChannel.writePacket(data) + if err != nil { + ws.logger.Warnf("error encrypting: %v", err) + continue + } + // ws.logger.Infof("encrypted %d bytes", len(packet.Payload)) + + select { + case ws.dataOrControlToMuxer <- packet: + default: + // drop the packet if the buffer is full + case <-ws.workersManager.ShouldShutdown(): + return + } + + case <-ws.workersManager.ShouldShutdown(): + return + } + } + case <-ws.workersManager.ShouldShutdown(): + return + } +} + +// moveUpWorker moves packets up the stack +func (ws *workersState) moveUpWorker() { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: moveUpWorker: done") + }() + for { + select { + // TODO: opportunistically try to kill lame duck + + case pkt := <-ws.muxerToData: + // TODO(ainghazal): factor out as handler function + decrypted, err := ws.dataChannel.readPacket(pkt) + if err != nil { + ws.logger.Warnf("error decrypting: %v", err) + continue + } + + if len(decrypted) == 16 { + // HACK - figure out what this fixed packet is. keepalive? + // "2a 18 7b f3 64 1e b4 cb 07 ed 2d 0a 98 1f c7 48" + fmt.Println(hex.Dump(decrypted)) + continue + } + + ws.dataToTUN <- decrypted + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} + +// keyWorker receives notifications from key ready +func (ws *workersState) keyWorker(firstKeyReady chan<- any) { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: worker: done") + }() + + ws.logger.Debug("datachannel: worker: started") + once := &sync.Once{} + + for { + select { + case key := <-ws.keyReady: + // TODO(keyrotation): thread safety here - need to lock. + // When we actually get to key rotation, we need to add locks. + // Use RW lock, reader locks. + + err := ws.dataChannel.setupKeys(key) + if err != nil { + ws.logger.Warnf("error on key derivation: %v", err) + continue + } + ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS) + once.Do(func() { + close(firstKeyReady) + }) + + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} diff --git a/internal/datachannel/state.go b/internal/datachannel/state.go new file mode 100644 index 00000000..7c3fc33e --- /dev/null +++ b/internal/datachannel/state.go @@ -0,0 +1,58 @@ +package datachannel + +import ( + "hash" + "math" + "sync" + + "github.com/ooni/minivpn/internal/model" +) + +// keySlot holds the different local and remote keys. +type keySlot [64]byte + +// dataChannelState is the state of the data channel. +type dataChannelState struct { + dataCipher dataCipher + + // outgoing and incoming nomenclature is probably more adequate here. + hmacLocal hash.Hash + hmacRemote hash.Hash + cipherKeyLocal keySlot + cipherKeyRemote keySlot + hmacKeyLocal keySlot + hmacKeyRemote keySlot + + /* + // not used at the moment, paving the way for key rotation. + keyID int + */ + + // TODO(ainghazal): we need to keep a local packetID too. It should be separated from the control channel. + // TODO: move this to sessionManager perhaps? + remotePacketID model.PacketID + + hash func() hash.Hash + mu sync.Mutex +} + +// SetRemotePacketID stores the passed packetID internally. +func (dcs *dataChannelState) SetRemotePacketID(id model.PacketID) { + dcs.mu.Lock() + defer dcs.mu.Unlock() + dcs.remotePacketID = model.PacketID(id) +} + +// RemotePacketID returns the last known remote packetID. It returns an error +// if the stored packet id has reached the maximum capacity of the packetID +// type. +func (dcs *dataChannelState) RemotePacketID() (model.PacketID, error) { + dcs.mu.Lock() + defer dcs.mu.Unlock() + pid := dcs.remotePacketID + if pid == math.MaxUint32 { + // we reached the max packetID, increment will overflow + return 0, errExpiredKey + } + return pid, nil +} diff --git a/internal/datachannel/write.go b/internal/datachannel/write.go new file mode 100644 index 00000000..995389cf --- /dev/null +++ b/internal/datachannel/write.go @@ -0,0 +1,169 @@ +package datachannel + +// +// Functions for encoding & writing packets +// + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" +) + +// encryptAndEncodePayloadAEAD peforms encryption and encoding of the payload in AEAD modes (i.e., AES-GCM). +// TODO(ainghazal): for testing we can pass both the state object and the encryptFn +func encryptAndEncodePayloadAEAD(log model.Logger, padded []byte, session *session.Manager, state *dataChannelState) ([]byte, error) { + // TODO(ainghazal): call Session.NewPacket() instead? + nextPacketID, err := session.LocalDataPacketID() + if err != nil { + return []byte{}, fmt.Errorf("bad packet id") + } + + // in AEAD mode, we authenticate: + // - 1 byte: opcode/key + // - 3 bytes: peer-id (we're using P_DATA_V2) + // - 4 bytes: packet-id + aead := &bytes.Buffer{} + aead.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(aead, uint32(session.TunnelInfo().PeerID)) + bytesx.WriteUint32(aead, uint32(nextPacketID)) + + // the iv is the packetID (again) concatenated with the 8 bytes of the + // key derived for local hmac (which we do not use for anything else in AEAD mode). + iv := &bytes.Buffer{} + bytesx.WriteUint32(iv, uint32(nextPacketID)) + iv.Write(state.hmacKeyLocal[:8]) + + data := &plaintextData{ + iv: iv.Bytes(), + plaintext: padded, + aead: aead.Bytes(), + } + + encryptFn := state.dataCipher.encrypt + encrypted, err := encryptFn(state.cipherKeyLocal[:], data) + if err != nil { + return []byte{}, err + } + + // some reordering, because openvpn uses tag | payload + boundary := len(encrypted) - 16 + tag := encrypted[boundary:] + ciphertext := encrypted[:boundary] + + // we now write to the output buffer + out := bytes.Buffer{} + out.Write(data.aead) // opcode|peer-id|packet_id + out.Write(tag) + out.Write(ciphertext) + return out.Bytes(), nil + +} + +// encryptAndEncodePayloadNonAEAD peforms encryption and encoding of the payload in Non-AEAD modes (i.e., AES-CBC). +func encryptAndEncodePayloadNonAEAD(log model.Logger, padded []byte, session *session.Manager, state *dataChannelState) ([]byte, error) { + // For iv generation, OpenVPN uses a nonce-based PRNG that is initially seeded with + // OpenSSL RAND_bytes function. I am assuming this is good enough for our current purposes. + blockSize := state.dataCipher.blockSize() + + iv, err := bytesx.GenRandomBytes(int(blockSize)) + if err != nil { + return nil, err + } + data := &plaintextData{ + iv: iv, + plaintext: padded, + aead: nil, + } + + encryptFn := state.dataCipher.encrypt + ciphertext, err := encryptFn(state.cipherKeyLocal[:], data) + if err != nil { + return nil, err + } + + state.hmacLocal.Reset() + state.hmacLocal.Write(iv) + state.hmacLocal.Write(ciphertext) + computedMAC := state.hmacLocal.Sum(nil) + + out := &bytes.Buffer{} + out.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(out, uint32(session.TunnelInfo().PeerID)) + + out.Write(computedMAC) + out.Write(iv) + out.Write(ciphertext) + return out.Bytes(), nil +} + +// doCompress adds compression bytes if needed by the passed compression options. +// if the compression stub is on, it sends the first byte to the last position, +// and it adds the compression preamble, according to the spec. compression +// lzo-no also adds a preamble. It returns a byte array and an error if the +// operation could not be completed. +func doCompress(b []byte, compress model.Compression) ([]byte, error) { + switch compress { + case "stub": + // compression stub: send first byte to last + // and add 0xfb marker on the first byte. + b = append(b, b[0]) + b[0] = 0xfb + case "lzo-no": + // old "comp-lzo no" option + b = append([]byte{0xfa}, b...) + } + return b, nil +} + +var errPadding = errors.New("padding error") + +// doPadding does pkcs7 padding of the encryption payloads as +// needed. if we're using the compression stub the padding is applied without taking the +// trailing bit into account. it returns the resulting byte array, and an error +// if the operatio could not be completed. +func doPadding(b []byte, compress model.Compression, blockSize uint8) ([]byte, error) { + if len(b) == 0 { + return nil, fmt.Errorf("%w: %s", errPadding, "nothing to pad") + } + if compress == "stub" { + // if we're using the compression stub + // we need to account for a trailing byte + // that we have appended in the doCompress stage. + endByte := b[len(b)-1] + padded, err := bytesx.BytesPadPKCS7(b[:len(b)-1], int(blockSize)) + if err != nil { + return nil, err + } + padded[len(padded)-1] = endByte + return padded, nil + } + padded, err := bytesx.BytesPadPKCS7(b, int(blockSize)) + if err != nil { + return nil, err + } + return padded, nil +} + +// TODO(ainghazal): move to a different layer? +// prependPacketID returns the original buffer with the passed packetID +// concatenated at the beginning. +func prependPacketID(p model.PacketID, buf []byte) []byte { + newbuf := &bytes.Buffer{} + packetID := make([]byte, 4) + binary.BigEndian.PutUint32(packetID, uint32(p)) + newbuf.Write(packetID[:]) + newbuf.Write(buf) + return newbuf.Bytes() +} + +// opcodeAndKeyHeader returns the header byte encoding the opcode and keyID (3 upper +// and 5 lower bits, respectively) +func opcodeAndKeyHeader(session *session.Manager) byte { + return byte((byte(model.P_DATA_V2) << 3) | (byte(session.CurrentKeyID()) & 0x07)) +}