From 486d5bebcfbfebbb63bf37a3ca0c0c136796bbd5 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 13:14:50 +0100 Subject: [PATCH 1/7] data channel --- internal/datachannel/controller.go | 249 ++++++++++++++++++++ internal/datachannel/crypto.go | 362 +++++++++++++++++++++++++++++ internal/datachannel/errors.go | 29 +++ internal/datachannel/read.go | 164 +++++++++++++ internal/datachannel/service.go | 181 +++++++++++++++ internal/datachannel/state.go | 57 +++++ internal/datachannel/write.go | 169 ++++++++++++++ 7 files changed, 1211 insertions(+) create mode 100644 internal/datachannel/controller.go create mode 100644 internal/datachannel/crypto.go create mode 100644 internal/datachannel/errors.go create mode 100644 internal/datachannel/read.go create mode 100644 internal/datachannel/service.go create mode 100644 internal/datachannel/state.go create mode 100644 internal/datachannel/write.go diff --git a/internal/datachannel/controller.go b/internal/datachannel/controller.go new file mode 100644 index 00000000..a97714b8 --- /dev/null +++ b/internal/datachannel/controller.go @@ -0,0 +1,249 @@ +package datachannel + +import ( + "bytes" + "crypto/hmac" + "fmt" + "strings" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/session" +) + +// dataChannelHandler manages the data "channel". +type dataChannelHandler interface { + setupKeys(*session.DataChannelKey) error + writePacket([]byte) (*model.Packet, error) + readPacket(*model.Packet) ([]byte, error) + decodeEncryptedPayload([]byte, *dataChannelState) (*encryptedData, error) + encryptAndEncodePayload([]byte, *dataChannelState) ([]byte, error) +} + +// DataChannel represents the data "channel", that will encrypt and decrypt the tunnel payloads. +// data implements the dataHandler interface. +type DataChannel struct { + options *model.Options + sessionManager *session.Manager + state *dataChannelState + decodeFn func(model.Logger, []byte, *session.Manager, *dataChannelState) (*encryptedData, error) + encryptEncodeFn func(model.Logger, []byte, *session.Manager, *dataChannelState) ([]byte, error) + decryptFn func([]byte, *encryptedData) ([]byte, error) + log model.Logger +} + +var _ dataChannelHandler = &DataChannel{} // Ensure that we implement dataChannelHandler + +// NewDataChannelFromOptions returns a new data object, initialized with the +// options given. it also returns any error raised. +func NewDataChannelFromOptions(log model.Logger, + opt *model.Options, + sessionManager *session.Manager) (*DataChannel, error) { + runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") + runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") + runtimex.Assert(len(opt.Cipher) != 0, "need a configured cipher option") + runtimex.Assert(len(opt.Auth) != 0, "need a configured auth option") + + state := &dataChannelState{} + data := &DataChannel{ + options: opt, + sessionManager: sessionManager, + state: state, + } + + dataCipher, err := newDataCipherFromCipherSuite(opt.Cipher) + if err != nil { + return data, err + } + data.state.dataCipher = dataCipher + switch dataCipher.isAEAD() { + case true: + data.decodeFn = decodeEncryptedPayloadAEAD + data.encryptEncodeFn = encryptAndEncodePayloadAEAD + case false: + data.decodeFn = decodeEncryptedPayloadNonAEAD + data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD + } + + hmacHash, ok := newHMACFactory(strings.ToLower(opt.Auth)) + if !ok { + return data, fmt.Errorf("%w: %s", errDataChannel, fmt.Sprintf("no such mac: %v", opt.Auth)) + } + data.state.hash = hmacHash + data.decryptFn = state.dataCipher.decrypt + + log.Info(fmt.Sprintf("Cipher: %s", opt.Cipher)) + log.Info(fmt.Sprintf("Auth: %s", opt.Auth)) + + return data, nil +} + +// DecodeEncryptedPayload calls the corresponding function for AEAD or Non-AEAD decryption. +func (d *DataChannel) decodeEncryptedPayload(b []byte, dcs *dataChannelState) (*encryptedData, error) { + return d.decodeFn(d.log, b, d.sessionManager, dcs) +} + +// setSetupKeys performs the key expansion from the local and remote +// keySources, initializing the data channel state. +func (d *DataChannel) setupKeys(dck *session.DataChannelKey) error { + runtimex.Assert(dck != nil, "data channel key cannot be nil") + if !dck.Ready() { + return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready") + } + master := prf( + dck.Local().PreMaster[:], + []byte("OpenVPN master secret"), + dck.Local().R1[:], + dck.Remote().R1[:], + []byte{}, []byte{}, + 48) + + keys := prf( + master, + []byte("OpenVPN key expansion"), + dck.Local().R2[:], + dck.Remote().R2[:], + d.sessionManager.LocalSessionID(), + d.sessionManager.RemoteSessionID(), + 256) + + var keyLocal, hmacLocal, keyRemote, hmacRemote keySlot + copy(keyLocal[:], keys[0:64]) + copy(hmacLocal[:], keys[64:128]) + copy(keyRemote[:], keys[128:192]) + copy(hmacRemote[:], keys[192:256]) + + d.state.cipherKeyLocal = keyLocal + d.state.hmacKeyLocal = hmacLocal + d.state.cipherKeyRemote = keyRemote + d.state.hmacKeyRemote = hmacRemote + + log.Debugf("Cipher key local: %x", keyLocal) + log.Debugf("Cipher key remote: %x", keyRemote) + log.Debugf("Hmac key local: %x", hmacLocal) + log.Debugf("Hmac key remote: %x", hmacRemote) + + hashSize := d.state.hash().Size() + d.state.hmacLocal = hmac.New(d.state.hash, hmacLocal[:hashSize]) + d.state.hmacRemote = hmac.New(d.state.hash, hmacRemote[:hashSize]) + + log.Info("Key derivation OK") + return nil +} + +// +// write + encrypt +// + +func (d *DataChannel) writePacket(payload []byte) (*model.Packet, error) { + runtimex.Assert(d.state != nil, "data: nil state") + runtimex.Assert(d.state.dataCipher != nil, "data.state: nil dataCipher") + + var plain []byte + var err error + + switch d.state.dataCipher.isAEAD() { + case true: + plain, err = doCompress(payload, d.options.Compress) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + case false: // non-aead + localPacketID, _ := d.sessionManager.LocalDataPacketID() + plain = prependPacketID(localPacketID, payload) + + plain, err = doCompress(plain, d.options.Compress) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + } + + // encrypted adds padding, if needed, and it also includes the + // opcode/keyid and peer-id headers and, if used, any authenticated + // parts in the packet. + encrypted, err := d.encryptAndEncodePayload(plain, d.state) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + + // TODO(ainghazal): increment counter for used bytes? + // and trigger renegotiation if we're near the end of the key useful lifetime. + + packet := model.NewPacket(model.P_DATA_V2, d.sessionManager.CurrentKeyID(), encrypted) + // packet, err := d.sessionManager.NewPacket(model.P_DATA_V2, encrypted) + //if err != nil { + // return nil, fmt.Errorf("%w: %s", ErrSerialization, err) + // } + peerid := &bytes.Buffer{} + bytesx.WriteUint24(peerid, uint32(d.sessionManager.TunnelInfo().PeerID)) + packet.PeerID = model.PeerID(peerid.Bytes()) + return packet, nil +} + +// encrypt calls the corresponding function for AEAD or Non-AEAD decryption. +// Due to the particularities of the iv generation on each of the modes, encryption and encoding are +// done together in the same function. +// TODO accept state for symmetry +func (d *DataChannel) encryptAndEncodePayload(plaintext []byte, dcs *dataChannelState) ([]byte, error) { + runtimex.Assert(dcs != nil, "datachanelState is nil") + runtimex.Assert(dcs.dataCipher != nil, "dcs.dataCipher is nil") + + if len(plaintext) == 0 { + return nil, fmt.Errorf("%w: nothing to encrypt", ErrCannotEncrypt) + } + + padded, err := doPadding(plaintext, d.options.Compress, dcs.dataCipher.blockSize()) + if err != nil { + return nil, + fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + + encrypted, err := d.encryptEncodeFn(d.log, padded, d.sessionManager, d.state) + if err != nil { + return nil, + fmt.Errorf("%w: %s", ErrCannotEncrypt, err) + } + return encrypted, nil + +} + +// +// read + decrypt +// + +func (d *DataChannel) readPacket(p *model.Packet) ([]byte, error) { + if len(p.Payload) == 0 { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, "empty payload") + } + runtimex.Assert(p.IsData(), "ReadPacket expects data packet") + + plaintext, err := d.decrypt(p.Payload) + if err != nil { + return nil, err + } + + // get plaintext payload from the decrypted plaintext + return maybeDecompress(plaintext, d.state, d.options) +} + +func (d *DataChannel) decrypt(encrypted []byte) ([]byte, error) { + if d.decryptFn == nil { + return []byte{}, errInitError + } + if len(d.state.hmacKeyRemote) == 0 { + d.log.Warn("decrypt: not ready yet") + return nil, ErrCannotDecrypt + } + encryptedData, err := d.decodeEncryptedPayload(encrypted, d.state) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, err) + } + + plainText, err := d.decryptFn(d.state.cipherKeyRemote[:], encryptedData) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotDecrypt, err) + } + return plainText, nil +} diff --git a/internal/datachannel/crypto.go b/internal/datachannel/crypto.go new file mode 100644 index 00000000..30f1a7e5 --- /dev/null +++ b/internal/datachannel/crypto.go @@ -0,0 +1,362 @@ +package datachannel + +// +// Code to perform encryption, decryption and key derivation. +// + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" + "log" + + "github.com/ooni/minivpn/internal/bytesx" +) //#nosec G501,G505 + +// TODO(ainghazal,bassosimone): see if it's feasible to use stdlib +// functionality rather than using the code below. + +type ( + // cipherMode describes a cipher mode (e.g., GCM). + cipherMode string + + // cipherName is a cipher name (e.g., AES). + cipherName string +) + +const ( + // cipherModeCBC is the CBC cipher mode. + cipherModeCBC = cipherMode("cbc") + + // cipherModeGCM is the GCM cipher mode. + cipherModeGCM = cipherMode("gcm") + + // cipherNameAES is an AES-based cipher. + cipherNameAES = cipherName("aes") +) + +// encrypteData holds the different parts needed to decrypt an encrypted data +// packet. +type encryptedData struct { + iv []byte + ciphertext []byte + aead []byte +} + +// plaintextData holds the different parts needed to encrypt a plaintext +// payload (after padding). +type plaintextData struct { + iv []byte + plaintext []byte + aead []byte +} + +// dataCipher encrypts and decrypts OpenVPN data. +type dataCipher interface { + // keySizeBytes returns the key size (in bytes). + keySizeBytes() int + + // isAEAD returns whether this cipher has AEAD properties. + isAEAD() bool + + // blockSize returns the expected block size. + blockSize() uint8 + + // encrypt encripts a plaintext. + // + // Arguments: + // + // - key is the key, whose size must be consistent with the cipher; + // + // - plaintextData is the data to be encrypted; + // + // Returns the ciphertext on success and an error on failure. + encrypt([]byte, *plaintextData) ([]byte, error) + + // decrypt is the opposite operation of encrypt. It takes in input the + // ciphertext and returns the plaintext of an error. + decrypt([]byte, *encryptedData) ([]byte, error) + + // mode returns the cipherMode + cipherMode() cipherMode +} + +// dataCipherAES implements dataCipher for AES. +type dataCipherAES struct { + // ksb is the key size in bytes + ksb int + + // mode is the cipher mode + mode cipherMode +} + +var _ dataCipher = &dataCipherAES{} // Ensure we implement dataCipher + +// keySizeBytes implements dataCipher.keySizeBytes +func (a *dataCipherAES) keySizeBytes() int { + return a.ksb +} + +// isAEAD implements dataCipher.isAEAD +func (a *dataCipherAES) isAEAD() bool { + return a.mode != cipherModeCBC +} + +// blockSize implements dataCipher.BlockSize +func (a *dataCipherAES) blockSize() uint8 { + switch a.mode { + case cipherModeCBC, cipherModeGCM: + return 16 + default: + return 0 + } +} + +// decrypt implements dataCipher.decrypt. +// Since key comes from a prf derivation, we only take as many bytes as we need to match +// our key size. +func (a *dataCipherAES) decrypt(key []byte, data *encryptedData) ([]byte, error) { + // TODO(ainghazal): split this function, it's too large + if len(key) < a.keySizeBytes() { + return nil, errInvalidKeySize + } + + // they key material might be longer + k := key[:a.keySizeBytes()] + block, err := aes.NewCipher(k) + if err != nil { + return nil, err + } + switch a.mode { + case cipherModeCBC: + if len(data.iv) != block.BlockSize() { + return nil, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotDecrypt, len(data.iv)) + } + mode := cipher.NewCBCDecrypter(block, data.iv) + plaintext := make([]byte, len(data.ciphertext)) + mode.CryptBlocks(plaintext, data.ciphertext) + plaintext, err := bytesx.BytesUnpadPKCS7(plaintext, block.BlockSize()) + if err != nil { + return nil, err + } + padLen := len(data.ciphertext) - len(plaintext) + if padLen > block.BlockSize() || padLen > len(plaintext) { + // TODO(bassosimone, ainghazal): discuss the cases in which + // this set of conditions actually occurs. + // TODO(ainghazal): this assertion might actually be moved into a + // boundary assertion in the unpad fun. + return nil, errors.New("unpadding error") + } + return plaintext, nil + + case cipherModeGCM: + // standard nonce size is 12. more is surely ok, but let's stick to it. + // https://github.com/golang/go/blob/master/src/crypto/aes/aes_gcm.go#L37 + if len(data.iv) != 12 { + return nil, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotDecrypt, len(data.iv)) + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + plaintext, err := aesGCM.Open(nil, data.iv, data.ciphertext, data.aead) + if err != nil { + log.Println("gdm decryption failed:", err.Error()) + /* + log.Println("dump begins----") + log.Println("len:", len(data.ciphertext)) + log.Println("iv:", data.iv) + log.Printf("%v\n", data.ciphertext) + log.Printf("%x\n", data.ciphertext) + log.Printf("aead: %x\n", data.aead) + log.Println("dump ends------") + */ + return nil, err + } + return plaintext, nil + + default: + return nil, errUnsupportedMode + } +} + +func (a *dataCipherAES) cipherMode() cipherMode { + return a.mode +} + +// encrypt implements dataCipher.encrypt +// Since key comes from a prf derivation, we only take as many bytes as we need to match +// our key size. +func (a *dataCipherAES) encrypt(key []byte, data *plaintextData) ([]byte, error) { + if len(key) < a.keySizeBytes() { + return nil, errInvalidKeySize + } + k := key[:a.keySizeBytes()] + block, err := aes.NewCipher(k) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + switch a.mode { + case cipherModeCBC: + if len(data.iv) != blockSize { + return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotEncrypt, len(data.iv)) + } + if len(data.plaintext)%blockSize != 0 { + return []byte{}, fmt.Errorf("%w: wrong padding", ErrCannotEncrypt) + } + mode := cipher.NewCBCEncrypter(block, data.iv) + + ciphertext := make([]byte, len(data.plaintext)) + mode.CryptBlocks(ciphertext, data.plaintext) + return ciphertext, nil + + case cipherModeGCM: + if len(data.iv) != 12 { + return []byte{}, fmt.Errorf("%w: wrong size for iv: %v", ErrCannotEncrypt, len(data.iv)) + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + // In GCM mode, the IV consists of the 32-bit packet counter + // followed by data from the HMAC key. The HMAC key can be used + // as IV, since in GCM mode the HMAC key is not used for the + // HMAC. The packet counter may not roll over within a single + // TLS session. This results in a unique IV for each packet, as + // required by GCM. + ciphertext := aesGCM.Seal(nil, data.iv, data.plaintext, data.aead) + return ciphertext, nil + + default: + return nil, errUnsupportedMode + } +} + +// newDataCipherFromCipherSuite constructs a new dataCipher from the cipher suite string. +func newDataCipherFromCipherSuite(c string) (dataCipher, error) { + switch c { + case "AES-128-CBC": + return newDataCipher(cipherNameAES, 128, cipherModeCBC) + case "AES-192-CBC": + return newDataCipher(cipherNameAES, 192, cipherModeCBC) + case "AES-256-CBC": + return newDataCipher(cipherNameAES, 256, cipherModeCBC) + case "AES-128-GCM": + return newDataCipher(cipherNameAES, 128, cipherModeGCM) + case "AES-256-GCM": + return newDataCipher(cipherNameAES, 256, cipherModeGCM) + default: + return nil, errUnsupportedCipher + } +} + +// newDataCipher constructs a new dataCipher from the given name, bits, and mode. +func newDataCipher(name cipherName, bits int, mode cipherMode) (dataCipher, error) { + if bits%8 != 0 || bits > 512 || bits < 64 { + return nil, fmt.Errorf("%w: %d", errInvalidKeySize, bits) + } + switch name { + case cipherNameAES: + default: + return nil, fmt.Errorf("%w: %s", errUnsupportedCipher, name) + } + switch mode { + case cipherModeCBC, cipherModeGCM: + default: + return nil, fmt.Errorf("%w: %s", errUnsupportedMode, mode) + } + dc := &dataCipherAES{ + ksb: bits / 8, + mode: mode, + } + return dc, nil +} + +// newHMACFactory accepts a label coming from an OpenVPN auth label, and returns two +// values: a function that will return a Hash implementation, and a boolean +// indicating if the operation was successful. +func newHMACFactory(name string) (func() hash.Hash, bool) { + switch name { + case "sha1": + return sha1.New, true + case "sha256": + return sha256.New, true + case "sha512": + return sha512.New, true + default: + return nil, false + } +} + +// prf function is used to derive master and client keys +func prf(secret, label, clientSeed, serverSeed, clientSid, serverSid []byte, olen int) []byte { + seed := append(clientSeed, serverSeed...) + if len(clientSid) != 0 { + seed = append(seed, clientSid...) + } + if len(serverSid) != 0 { + seed = append(seed, serverSid...) + } + result := make([]byte, olen) + return prf10(result, secret, label, seed) +} + +// Code below is taken from crypto/tls/prf.go +// Copyright 2009 The Go Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) []byte { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + for i, b := range result2 { + result[i] ^= b + } + return result +} + +// SPDX-License-Identifier: BSD-3-Clause +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return + +} + +// SPDX-License-Identifier: BSD-3-Clause +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} diff --git a/internal/datachannel/errors.go b/internal/datachannel/errors.go new file mode 100644 index 00000000..cf35f4ac --- /dev/null +++ b/internal/datachannel/errors.go @@ -0,0 +1,29 @@ +package datachannel + +import "errors" + +var ( + errDataChannel = errors.New("datachannel error") + errDataChannelKey = errors.New("bad key") + errBadCompression = errors.New("bad compression") + errReplayAttack = errors.New("replay attack") + errBadHMAC = errors.New("bad hmac") + errInitError = errors.New("improperly initialized") + errExpiredKey = errors.New("key is expired") + + // errInvalidKeySize means that the key size is invalid. + errInvalidKeySize = errors.New("invalid key size") + + // errUnsupportedCipher indicates we don't support the desired cipher. + errUnsupportedCipher = errors.New("unsupported cipher") + + // errUnsupportedMode indicates that the mode is not uspported. + errUnsupportedMode = errors.New("unsupported mode") + + // errBadInput indicates invalid inputs to encrypt/decrypt functions. + errBadInput = errors.New("bad input") + + ErrSerialization = errors.New("cannot create packet") + ErrCannotEncrypt = errors.New("cannot encrypt") + ErrCannotDecrypt = errors.New("cannot decrypt") +) diff --git a/internal/datachannel/read.go b/internal/datachannel/read.go new file mode 100644 index 00000000..4d25be2d --- /dev/null +++ b/internal/datachannel/read.go @@ -0,0 +1,164 @@ +package datachannel + +import ( + "bytes" + "crypto/hmac" + "encoding/binary" + "errors" + "fmt" + + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/session" +) + +func decodeEncryptedPayloadAEAD(log model.Logger, buf []byte, session *session.Manager, state *dataChannelState) (*encryptedData, error) { + // P_DATA_V2 GCM data channel crypto format + // 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a... + // [ OP32 ] [seq # ] [ auth tag ] [ payload ... ] + // - means authenticated - * means encrypted * + // [ - opcode/peer-id - ] [ - packet ID - ] [ TAG ] [ * packet payload * ] + + // preconditions + + if len(buf) == 0 || len(buf) < 20 { + return nil, fmt.Errorf("too short: %d bytes", len(buf)) + } + if len(state.hmacKeyRemote) < 8 { + return nil, fmt.Errorf("bad remote hmac") + } + remoteHMAC := state.hmacKeyRemote[:8] + packet_id := buf[:4] + + headers := &bytes.Buffer{} + headers.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(headers, uint32(session.TunnelInfo().PeerID)) + headers.Write(packet_id) + + // we need to swap because decryption expects payload|tag + // but we've got tag | payload instead + payload := &bytes.Buffer{} + payload.Write(buf[20:]) // ciphertext + payload.Write(buf[4:20]) // tag + + // iv := packetID | remoteHMAC + iv := &bytes.Buffer{} + iv.Write(packet_id) + iv.Write(remoteHMAC) + + encrypted := &encryptedData{ + iv: iv.Bytes(), + ciphertext: payload.Bytes(), + aead: headers.Bytes(), + } + return encrypted, nil +} + +var errCannotDecode = errors.New("cannot decode") + +func decodeEncryptedPayloadNonAEAD(log model.Logger, buf []byte, session *session.Manager, state *dataChannelState) (*encryptedData, error) { + runtimex.Assert(state != nil, "passed nil state") + runtimex.Assert(state.dataCipher != nil, "data cipher not initialized") + + hashSize := uint8(state.hmacRemote.Size()) + blockSize := state.dataCipher.blockSize() + + minLen := hashSize + blockSize + + if len(buf) < int(minLen) { + return &encryptedData{}, fmt.Errorf("%w: too short (%d bytes)", errCannotDecode, len(buf)) + } + + receivedHMAC := buf[:hashSize] + iv := buf[hashSize : hashSize+blockSize] + cipherText := buf[hashSize+blockSize:] + + state.hmacRemote.Reset() + state.hmacRemote.Write(iv) + state.hmacRemote.Write(cipherText) + computedHMAC := state.hmacRemote.Sum(nil) + + if !hmac.Equal(computedHMAC, receivedHMAC) { + log.Warnf("expected: %x, got: %x", computedHMAC, receivedHMAC) + return &encryptedData{}, fmt.Errorf("%w: %s", ErrCannotDecrypt, errBadHMAC) + } + + encrypted := &encryptedData{ + iv: iv, + ciphertext: cipherText, + aead: []byte{}, // no AEAD data in this mode, leaving it empty to satisfy common interface + } + return encrypted, nil +} + +// maybeDecompress de-serializes the data from the payload according to the framing +// given by different compression methods. only the different no-compression +// modes are supported at the moment, so no real decompression is done. It +// returns a byte array, and an error if the operation could not be completed +// successfully. +func maybeDecompress(b []byte, st *dataChannelState, opt *model.Options) ([]byte, error) { + if st == nil || st.dataCipher == nil { + return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad state") + } + if opt == nil { + return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad options") + } + + var compr byte // compression type + var payload []byte + + // TODO(ainghazal): have two different decompress implementations + // instead of this switch + switch st.dataCipher.isAEAD() { + case true: + switch opt.Compress { + case model.CompressionStub, model.CompressionLZONo: + // these are deprecated in openvpn 2.5.x + compr = b[0] + payload = b[1:] + default: + compr = 0x00 + payload = b[:] + } + default: // non-aead + remotePacketID := model.PacketID(binary.BigEndian.Uint32(b[:4])) + lastKnownRemote, err := st.RemotePacketID() + if err != nil { + return payload, err + } + if remotePacketID <= lastKnownRemote { + return []byte{}, errReplayAttack + } + st.SetRemotePacketID(remotePacketID) + + switch opt.Compress { + case model.CompressionStub, model.CompressionLZONo: + compr = b[4] + payload = b[5:] + default: + compr = 0x00 + payload = b[4:] + } + } + + switch compr { + case 0xfb: + // compression stub swap: + // we get the last byte and replace the compression byte + // these are deprecated in openvpn 2.5.x + end := payload[len(payload)-1] + b := payload[:len(payload)-1] + payload = append([]byte{end}, b...) + case 0x00, 0xfa: + // do nothing + // 0x00 is compress-no, + // 0xfa is the old no compression or comp-lzo no case. + // http://build.openvpn.net/doxygen/comp_8h_source.html + // see: https://community.openvpn.net/openvpn/ticket/952#comment:5 + default: + errMsg := fmt.Sprintf("cannot handle compression:%x", compr) + return []byte{}, fmt.Errorf("%w:%s", errBadCompression, errMsg) + } + return payload, nil +} diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go new file mode 100644 index 00000000..1e043079 --- /dev/null +++ b/internal/datachannel/service.go @@ -0,0 +1,181 @@ +package datachannel + +// +// OpenVPN data channel +// + +import ( + "encoding/hex" + "fmt" + + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// Service is the datachannel service. Make sure you initialize +// the channels before invoking [Service.StartWorkers]. +type Service struct { + // MuxerToData moves packets up to us + MuxerToData chan *model.Packet + // DataOrControlToMuxer is a shared channel to write packets to the muxer layer below + DataOrControlToMuxer *chan *model.Packet + // TUNToData moves bytes down from the TUN layer above + TUNToData chan []byte + // DataToTUN moves bytes up from us to the TUN layer above us + DataToTUN chan []byte + // KeyReady is where the TLSState layer passes us any new keys + KeyReady chan *session.DataChannelKey +} + +// StartWorkers starts the data-channel workers. +// +// We start three workers: +// +// 1. moveUpWorker BLOCKS on dataPacketUp to read a packet coming from the muxer and +// eventually BLOCKS on tunUp to deliver it; +// +// 2. moveDownWorker BLOCKS on tunDown to read a packet and +// eventually BLOCKS on packetDown to deliver it; +// +// 3. keyWorker BLOCKS on keyUp to read an dataChannelKey and +// initializes the internal state with the resulting key; + +func (s *Service) StartWorkers( + logger model.Logger, + workersManager *workers.Manager, + sessionManager *session.Manager, + options *model.Options, +) { + dc, err := NewDataChannelFromOptions(logger, options, sessionManager) + if err != nil { + logger.Warnf("cannot initialize channel %v", err) + return + } + ws := &workersState{ + logger: logger, + muxerToData: s.MuxerToData, + dataOrControlToMuxer: *s.DataOrControlToMuxer, + tunToData: s.TUNToData, + dataToTUN: s.DataToTUN, + keyReady: s.KeyReady, + dataChannel: dc, + newKey: make(chan any), + workersManager: workersManager, + sessionManager: sessionManager, + } + workersManager.StartWorker(ws.moveUpWorker) + workersManager.StartWorker(ws.moveDownWorker) + workersManager.StartWorker(ws.keyWorker) +} + +// workersState contains the data channel state. +type workersState struct { + logger model.Logger + workersManager *workers.Manager + sessionManager *session.Manager + keyReady <-chan *session.DataChannelKey + muxerToData <-chan *model.Packet + dataOrControlToMuxer chan<- *model.Packet + dataToTUN chan<- []byte + tunToData <-chan []byte + dataChannel *DataChannel + newKey chan any +} + +// moveDownWorker moves packets down the stack. It will BLOCK on PacketDown +func (ws *workersState) moveDownWorker() { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: moveDownWorker: done") + }() + for { + select { + // wait for the key to be ready + case <-ws.newKey: + for { + select { + case data := <-ws.tunToData: + packet, err := ws.dataChannel.writePacket(data) + if err != nil { + ws.logger.Warnf("error encrypting: %v", err) + continue + } + // ws.logger.Infof("encrypted %d bytes", len(packet.Payload)) + + select { + case ws.dataOrControlToMuxer <- packet: + default: + // drop the packet if the buffer is full + case <-ws.workersManager.ShouldShutdown(): + return + } + + case <-ws.workersManager.ShouldShutdown(): + return + } + } + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} + +// moveUpWorker moves packets up the stack +func (ws *workersState) moveUpWorker() { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: moveUpWorker: done") + }() + for { + select { + case pkt := <-ws.muxerToData: + // TODO(ainghazal): factor out as handler function + decrypted, err := ws.dataChannel.readPacket(pkt) + if err != nil { + ws.logger.Warnf("error decrypting: %v", err) + continue + } + + if len(decrypted) == 16 { + // HACK - figure out what this fixed packet is. keepalive? + // "2a 18 7b f3 64 1e b4 cb 07 ed 2d 0a 98 1f c7 48" + fmt.Println(hex.Dump(decrypted)) + continue + } + + // fmt.Printf("< decrypted %v bytes\n", len(decrypted)) + ws.dataToTUN <- decrypted + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} + +// keyWorker receives notifications from key ready +func (ws *workersState) keyWorker() { + defer func() { + ws.workersManager.OnWorkerDone() + ws.workersManager.StartShutdown() + ws.logger.Debug("datachannel: worker: done") + }() + + ws.logger.Debug("datachannel: worker: started") + for { + select { + case key := <-ws.keyReady: + err := ws.dataChannel.setupKeys(key) + if err != nil { + ws.logger.Warnf("error on key derivation: %v", err) + continue + } + ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS) + ws.newKey <- true + + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} diff --git a/internal/datachannel/state.go b/internal/datachannel/state.go new file mode 100644 index 00000000..a44a22b0 --- /dev/null +++ b/internal/datachannel/state.go @@ -0,0 +1,57 @@ +package datachannel + +import ( + "hash" + "math" + "sync" + + "github.com/ooni/minivpn/internal/model" +) + +// keySlot holds the different local and remote keys. +type keySlot [64]byte + +// dataChannelState is the state of the data channel. +type dataChannelState struct { + dataCipher dataCipher + + // outgoing and incoming nomenclature is probably more adequate here. + hmacLocal hash.Hash + hmacRemote hash.Hash + cipherKeyLocal keySlot + cipherKeyRemote keySlot + hmacKeyLocal keySlot + hmacKeyRemote keySlot + /* + keyID int // not used at the moment, paving the way for key rotation. + peerID int + */ + + // TODO(ainghazal): we need to keep a local packetID too. It should be separated from the control channel. + // TODO: move this to sessionManager perhaps? + remotePacketID model.PacketID + + hash func() hash.Hash + mu sync.Mutex +} + +// SetRemotePacketID stores the passed packetID internally. +func (dcs *dataChannelState) SetRemotePacketID(id model.PacketID) { + dcs.mu.Lock() + defer dcs.mu.Unlock() + dcs.remotePacketID = model.PacketID(id) +} + +// RemotePacketID returns the last known remote packetID. It returns an error +// if the stored packet id has reached the maximum capacity of the packetID +// type. +func (dcs *dataChannelState) RemotePacketID() (model.PacketID, error) { + dcs.mu.Lock() + defer dcs.mu.Unlock() + pid := dcs.remotePacketID + if pid == math.MaxUint32 { + // we reached the max packetID, increment will overflow + return 0, errExpiredKey + } + return pid, nil +} diff --git a/internal/datachannel/write.go b/internal/datachannel/write.go new file mode 100644 index 00000000..995389cf --- /dev/null +++ b/internal/datachannel/write.go @@ -0,0 +1,169 @@ +package datachannel + +// +// Functions for encoding & writing packets +// + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" +) + +// encryptAndEncodePayloadAEAD peforms encryption and encoding of the payload in AEAD modes (i.e., AES-GCM). +// TODO(ainghazal): for testing we can pass both the state object and the encryptFn +func encryptAndEncodePayloadAEAD(log model.Logger, padded []byte, session *session.Manager, state *dataChannelState) ([]byte, error) { + // TODO(ainghazal): call Session.NewPacket() instead? + nextPacketID, err := session.LocalDataPacketID() + if err != nil { + return []byte{}, fmt.Errorf("bad packet id") + } + + // in AEAD mode, we authenticate: + // - 1 byte: opcode/key + // - 3 bytes: peer-id (we're using P_DATA_V2) + // - 4 bytes: packet-id + aead := &bytes.Buffer{} + aead.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(aead, uint32(session.TunnelInfo().PeerID)) + bytesx.WriteUint32(aead, uint32(nextPacketID)) + + // the iv is the packetID (again) concatenated with the 8 bytes of the + // key derived for local hmac (which we do not use for anything else in AEAD mode). + iv := &bytes.Buffer{} + bytesx.WriteUint32(iv, uint32(nextPacketID)) + iv.Write(state.hmacKeyLocal[:8]) + + data := &plaintextData{ + iv: iv.Bytes(), + plaintext: padded, + aead: aead.Bytes(), + } + + encryptFn := state.dataCipher.encrypt + encrypted, err := encryptFn(state.cipherKeyLocal[:], data) + if err != nil { + return []byte{}, err + } + + // some reordering, because openvpn uses tag | payload + boundary := len(encrypted) - 16 + tag := encrypted[boundary:] + ciphertext := encrypted[:boundary] + + // we now write to the output buffer + out := bytes.Buffer{} + out.Write(data.aead) // opcode|peer-id|packet_id + out.Write(tag) + out.Write(ciphertext) + return out.Bytes(), nil + +} + +// encryptAndEncodePayloadNonAEAD peforms encryption and encoding of the payload in Non-AEAD modes (i.e., AES-CBC). +func encryptAndEncodePayloadNonAEAD(log model.Logger, padded []byte, session *session.Manager, state *dataChannelState) ([]byte, error) { + // For iv generation, OpenVPN uses a nonce-based PRNG that is initially seeded with + // OpenSSL RAND_bytes function. I am assuming this is good enough for our current purposes. + blockSize := state.dataCipher.blockSize() + + iv, err := bytesx.GenRandomBytes(int(blockSize)) + if err != nil { + return nil, err + } + data := &plaintextData{ + iv: iv, + plaintext: padded, + aead: nil, + } + + encryptFn := state.dataCipher.encrypt + ciphertext, err := encryptFn(state.cipherKeyLocal[:], data) + if err != nil { + return nil, err + } + + state.hmacLocal.Reset() + state.hmacLocal.Write(iv) + state.hmacLocal.Write(ciphertext) + computedMAC := state.hmacLocal.Sum(nil) + + out := &bytes.Buffer{} + out.WriteByte(opcodeAndKeyHeader(session)) + bytesx.WriteUint24(out, uint32(session.TunnelInfo().PeerID)) + + out.Write(computedMAC) + out.Write(iv) + out.Write(ciphertext) + return out.Bytes(), nil +} + +// doCompress adds compression bytes if needed by the passed compression options. +// if the compression stub is on, it sends the first byte to the last position, +// and it adds the compression preamble, according to the spec. compression +// lzo-no also adds a preamble. It returns a byte array and an error if the +// operation could not be completed. +func doCompress(b []byte, compress model.Compression) ([]byte, error) { + switch compress { + case "stub": + // compression stub: send first byte to last + // and add 0xfb marker on the first byte. + b = append(b, b[0]) + b[0] = 0xfb + case "lzo-no": + // old "comp-lzo no" option + b = append([]byte{0xfa}, b...) + } + return b, nil +} + +var errPadding = errors.New("padding error") + +// doPadding does pkcs7 padding of the encryption payloads as +// needed. if we're using the compression stub the padding is applied without taking the +// trailing bit into account. it returns the resulting byte array, and an error +// if the operatio could not be completed. +func doPadding(b []byte, compress model.Compression, blockSize uint8) ([]byte, error) { + if len(b) == 0 { + return nil, fmt.Errorf("%w: %s", errPadding, "nothing to pad") + } + if compress == "stub" { + // if we're using the compression stub + // we need to account for a trailing byte + // that we have appended in the doCompress stage. + endByte := b[len(b)-1] + padded, err := bytesx.BytesPadPKCS7(b[:len(b)-1], int(blockSize)) + if err != nil { + return nil, err + } + padded[len(padded)-1] = endByte + return padded, nil + } + padded, err := bytesx.BytesPadPKCS7(b, int(blockSize)) + if err != nil { + return nil, err + } + return padded, nil +} + +// TODO(ainghazal): move to a different layer? +// prependPacketID returns the original buffer with the passed packetID +// concatenated at the beginning. +func prependPacketID(p model.PacketID, buf []byte) []byte { + newbuf := &bytes.Buffer{} + packetID := make([]byte, 4) + binary.BigEndian.PutUint32(packetID, uint32(p)) + newbuf.Write(packetID[:]) + newbuf.Write(buf) + return newbuf.Bytes() +} + +// opcodeAndKeyHeader returns the header byte encoding the opcode and keyID (3 upper +// and 5 lower bits, respectively) +func opcodeAndKeyHeader(session *session.Manager) byte { + return byte((byte(model.P_DATA_V2) << 3) | (byte(session.CurrentKeyID()) & 0x07)) +} From 8addd7dba773915f581ead02b556db947202f66f Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 13:28:50 +0100 Subject: [PATCH 2/7] add doc.go --- internal/datachannel/doc.go | 4 ++++ internal/datachannel/state.go | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 internal/datachannel/doc.go diff --git a/internal/datachannel/doc.go b/internal/datachannel/doc.go new file mode 100644 index 00000000..0b81c4a8 --- /dev/null +++ b/internal/datachannel/doc.go @@ -0,0 +1,4 @@ +// Package datachannel implements packet encryption and decryption over the OpenVPN Data Channel. +// Encryption Keys are derived after a successful TLS handshake, and they have a limited +// lifetime. +package datachannel diff --git a/internal/datachannel/state.go b/internal/datachannel/state.go index a44a22b0..7c3fc33e 100644 --- a/internal/datachannel/state.go +++ b/internal/datachannel/state.go @@ -22,9 +22,10 @@ type dataChannelState struct { cipherKeyRemote keySlot hmacKeyLocal keySlot hmacKeyRemote keySlot + /* - keyID int // not used at the moment, paving the way for key rotation. - peerID int + // not used at the moment, paving the way for key rotation. + keyID int */ // TODO(ainghazal): we need to keep a local packetID too. It should be separated from the control channel. From 45a5e1ae33853b345801fc91ebab6fd3d05abf80 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 13:33:11 +0100 Subject: [PATCH 3/7] white space --- internal/datachannel/service.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index 1e043079..68644277 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -18,12 +18,16 @@ import ( type Service struct { // MuxerToData moves packets up to us MuxerToData chan *model.Packet + // DataOrControlToMuxer is a shared channel to write packets to the muxer layer below DataOrControlToMuxer *chan *model.Packet + // TUNToData moves bytes down from the TUN layer above TUNToData chan []byte + // DataToTUN moves bytes up from us to the TUN layer above us DataToTUN chan []byte + // KeyReady is where the TLSState layer passes us any new keys KeyReady chan *session.DataChannelKey } @@ -146,7 +150,6 @@ func (ws *workersState) moveUpWorker() { continue } - // fmt.Printf("< decrypted %v bytes\n", len(decrypted)) ws.dataToTUN <- decrypted case <-ws.workersManager.ShouldShutdown(): return From 0af71a5843946b840d32b41aa46c6c6d4635f4fd Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Fri, 19 Jan 2024 17:25:28 +0100 Subject: [PATCH 4/7] Update internal/datachannel/service.go Co-authored-by: Simone Basso --- internal/datachannel/service.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index 68644277..f6dca853 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -44,7 +44,6 @@ type Service struct { // // 3. keyWorker BLOCKS on keyUp to read an dataChannelKey and // initializes the internal state with the resulting key; - func (s *Service) StartWorkers( logger model.Logger, workersManager *workers.Manager, From 98c08efe3f9ecf5608e5297e9b6cc34443feceef Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 17:31:06 +0100 Subject: [PATCH 5/7] remove unused code --- internal/datachannel/controller.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/datachannel/controller.go b/internal/datachannel/controller.go index a97714b8..839301a9 100644 --- a/internal/datachannel/controller.go +++ b/internal/datachannel/controller.go @@ -172,10 +172,6 @@ func (d *DataChannel) writePacket(payload []byte) (*model.Packet, error) { // and trigger renegotiation if we're near the end of the key useful lifetime. packet := model.NewPacket(model.P_DATA_V2, d.sessionManager.CurrentKeyID(), encrypted) - // packet, err := d.sessionManager.NewPacket(model.P_DATA_V2, encrypted) - //if err != nil { - // return nil, fmt.Errorf("%w: %s", ErrSerialization, err) - // } peerid := &bytes.Buffer{} bytesx.WriteUint24(peerid, uint32(d.sessionManager.TunnelInfo().PeerID)) packet.PeerID = model.PeerID(peerid.Bytes()) From 07f620adf56f13d81d597da7b4bbe65fd1963715 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 18:22:38 +0100 Subject: [PATCH 6/7] added todos --- internal/datachannel/service.go | 72 +++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index f6dca853..28bd61c4 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -7,6 +7,7 @@ package datachannel import ( "encoding/hex" "fmt" + "sync" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/session" @@ -63,13 +64,15 @@ func (s *Service) StartWorkers( dataToTUN: s.DataToTUN, keyReady: s.KeyReady, dataChannel: dc, - newKey: make(chan any), workersManager: workersManager, sessionManager: sessionManager, } + + firstKeyReady := make(chan any) + workersManager.StartWorker(ws.moveUpWorker) - workersManager.StartWorker(ws.moveDownWorker) - workersManager.StartWorker(ws.keyWorker) + workersManager.StartWorker(func() { ws.moveDownWorker(firstKeyReady) }) + workersManager.StartWorker(func() { ws.keyWorker(firstKeyReady) }) } // workersState contains the data channel state. @@ -83,45 +86,43 @@ type workersState struct { dataToTUN chan<- []byte tunToData <-chan []byte dataChannel *DataChannel - newKey chan any } // moveDownWorker moves packets down the stack. It will BLOCK on PacketDown -func (ws *workersState) moveDownWorker() { +func (ws *workersState) moveDownWorker(firstKeyReady <-chan any) { defer func() { ws.workersManager.OnWorkerDone() ws.workersManager.StartShutdown() ws.logger.Debug("datachannel: moveDownWorker: done") }() - for { - select { - // wait for the key to be ready - case <-ws.newKey: - for { - select { - case data := <-ws.tunToData: - packet, err := ws.dataChannel.writePacket(data) - if err != nil { - ws.logger.Warnf("error encrypting: %v", err) - continue - } - // ws.logger.Infof("encrypted %d bytes", len(packet.Payload)) - - select { - case ws.dataOrControlToMuxer <- packet: - default: - // drop the packet if the buffer is full - case <-ws.workersManager.ShouldShutdown(): - return - } + select { + // wait for the first key to be ready + case <-firstKeyReady: + for { + select { + case data := <-ws.tunToData: + // TODO: writePacket should get the ACTIVE KEY (verify this) + packet, err := ws.dataChannel.writePacket(data) + if err != nil { + ws.logger.Warnf("error encrypting: %v", err) + continue + } + // ws.logger.Infof("encrypted %d bytes", len(packet.Payload)) + select { + case ws.dataOrControlToMuxer <- packet: + default: + // drop the packet if the buffer is full case <-ws.workersManager.ShouldShutdown(): return } + + case <-ws.workersManager.ShouldShutdown(): + return } - case <-ws.workersManager.ShouldShutdown(): - return } + case <-ws.workersManager.ShouldShutdown(): + return } } @@ -134,6 +135,8 @@ func (ws *workersState) moveUpWorker() { }() for { select { + // TODO: opportunistically try to kill lame duck + case pkt := <-ws.muxerToData: // TODO(ainghazal): factor out as handler function decrypted, err := ws.dataChannel.readPacket(pkt) @@ -157,7 +160,7 @@ func (ws *workersState) moveUpWorker() { } // keyWorker receives notifications from key ready -func (ws *workersState) keyWorker() { +func (ws *workersState) keyWorker(firstKeyReady chan<- any) { defer func() { ws.workersManager.OnWorkerDone() ws.workersManager.StartShutdown() @@ -165,16 +168,25 @@ func (ws *workersState) keyWorker() { }() ws.logger.Debug("datachannel: worker: started") + once := &sync.Once{} + for { select { case key := <-ws.keyReady: + // TODO(keyrotation): thread safety here - need to lock. + // When we actually get to key rotation, we need to add locks. + // Use RW lock, reader locks. + err := ws.dataChannel.setupKeys(key) if err != nil { ws.logger.Warnf("error on key derivation: %v", err) continue } ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS) - ws.newKey <- true + once.Do(func() { + close(firstKeyReady) + }) + //ws.newKey <- true case <-ws.workersManager.ShouldShutdown(): return From dad889bada9f25eda7fb6bf38da9f0d35b4bd8f4 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 19 Jan 2024 18:35:40 +0100 Subject: [PATCH 7/7] x --- internal/datachannel/controller.go | 1 - internal/datachannel/crypto.go | 1 + internal/datachannel/service.go | 1 - 3 files changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/datachannel/controller.go b/internal/datachannel/controller.go index 839301a9..40b5ec8a 100644 --- a/internal/datachannel/controller.go +++ b/internal/datachannel/controller.go @@ -181,7 +181,6 @@ func (d *DataChannel) writePacket(payload []byte) (*model.Packet, error) { // encrypt calls the corresponding function for AEAD or Non-AEAD decryption. // Due to the particularities of the iv generation on each of the modes, encryption and encoding are // done together in the same function. -// TODO accept state for symmetry func (d *DataChannel) encryptAndEncodePayload(plaintext []byte, dcs *dataChannelState) ([]byte, error) { runtimex.Assert(dcs != nil, "datachanelState is nil") runtimex.Assert(dcs.dataCipher != nil, "dcs.dataCipher is nil") diff --git a/internal/datachannel/crypto.go b/internal/datachannel/crypto.go index 30f1a7e5..c35086f3 100644 --- a/internal/datachannel/crypto.go +++ b/internal/datachannel/crypto.go @@ -19,6 +19,7 @@ import ( "github.com/ooni/minivpn/internal/bytesx" ) //#nosec G501,G505 +// We know that sha1 and md5 are insecure, but we do not control the openvpn protocol. // TODO(ainghazal,bassosimone): see if it's feasible to use stdlib // functionality rather than using the code below. diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index 28bd61c4..7e263291 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -186,7 +186,6 @@ func (ws *workersState) keyWorker(firstKeyReady chan<- any) { once.Do(func() { close(firstKeyReady) }) - //ws.newKey <- true case <-ws.workersManager.ShouldShutdown(): return