diff --git a/.github/workflows/build-refactor.yml b/.github/workflows/build-refactor.yml new file mode 100644 index 00000000..a822bf03 --- /dev/null +++ b/.github/workflows/build-refactor.yml @@ -0,0 +1,57 @@ +name: build-refactor +# this action is covering internal/ tree with go1.21 + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + short-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Run short tests + run: go test --short -cover ./internal/... + + gosec: + runs-on: ubuntu-latest + env: + GO111MODULE: on + steps: + - name: Checkout Source + uses: actions/checkout@v2 + - name: Run Gosec security scanner + uses: securego/gosec@master + with: + args: '-no-fail ./...' + + coverage-threshold: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Ensure coverage threshold + run: make test-coverage-threshold-refactor + + integration: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: run integration tests + run: go test -v ./tests/integration + diff --git a/.gitignore b/.gitignore index 6b0e6c86..0d839243 100644 --- a/.gitignore +++ b/.gitignore @@ -8,8 +8,6 @@ *.swo *.pem *.ovpn +/*.out data/* measurements/* -coverage.out -coverage-ping.out -cov-threshold.out diff --git a/Makefile b/Makefile index eaad788c..881d2423 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ TARGET ?= "1.1.1.1" COUNT ?= 5 TIMEOUT ?= 10 LOCAL_TARGET := $(shell ip -4 addr show docker0 | grep 'inet ' | awk '{print $$2}' | cut -f 1 -d /) -COVERAGE_THRESHOLD := 88 +COVERAGE_THRESHOLD := 80 FLAGS=-ldflags="-w -s -buildid=none -linkmode=external" -buildmode=pie -buildvcs=false build: @@ -33,10 +33,17 @@ test: test-coverage: go test -coverprofile=coverage.out ./vpn +test-coverage-refactor: + go test -coverprofile=coverage.out ./internal/... + test-coverage-threshold: go test --short -coverprofile=cov-threshold.out ./vpn ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} +test-coverage-threshold-refactor: + go test --short -coverprofile=cov-threshold-refactor.out ./internal/... + ./scripts/go-coverage-check.sh cov-threshold-refactor.out ${COVERAGE_THRESHOLD} + test-short: go test -race -short -v ./... diff --git a/internal/model/packet.go b/internal/model/packet.go index 210e0f60..f1256a4b 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -32,6 +32,33 @@ const ( P_DATA_V2 // 9 ) +// NewOpcodeFromString returns an opcode from a string representation, and an error if it cannot parse the opcode +// representation. The zero return value is invalid and always coupled with a non-nil error. +func NewOpcodeFromString(s string) (Opcode, error) { + switch s { + case "CONTROL_HARD_RESET_CLIENT_V1": + return P_CONTROL_HARD_RESET_CLIENT_V1, nil + case "CONTROL_HARD_RESET_SERVER_V1": + return P_CONTROL_HARD_RESET_SERVER_V1, nil + case "CONTROL_SOFT_RESET_V1": + return P_CONTROL_SOFT_RESET_V1, nil + case "CONTROL_V1": + return P_CONTROL_V1, nil + case "ACK_V1": + return P_ACK_V1, nil + case "DATA_V1": + return P_DATA_V1, nil + case "CONTROL_HARD_RESET_CLIENT_V2": + return P_CONTROL_HARD_RESET_CLIENT_V2, nil + case "CONTROL_HARD_RESET_SERVER_V2": + return P_CONTROL_HARD_RESET_SERVER_V2, nil + case "DATA_V2": + return P_DATA_V2, nil + default: + return 0, errors.New("unknown opcode") + } +} + // String returns the opcode string representation func (op Opcode) String() string { switch op { diff --git a/internal/reliabletransport/common_test.go b/internal/reliabletransport/common_test.go new file mode 100644 index 00000000..63581f48 --- /dev/null +++ b/internal/reliabletransport/common_test.go @@ -0,0 +1,47 @@ +package reliabletransport + +import ( + "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// +// Common utilities for tests in this package. +// + +// initManagers initializes a workers manager and a session manager. +func initManagers() (*workers.Manager, *session.Manager) { + w := workers.NewManager(log.Log) + s, err := session.NewManager(log.Log) + runtimex.PanicOnError(err, "cannot create session manager") + return w, s +} + +// newRandomSessionID returns a random session ID to initialize mock sessions. +func newRandomSessionID() model.SessionID { + b, err := bytesx.GenRandomBytes(8) + if err != nil { + panic(err) + } + return model.SessionID(b) +} + +func ackSetFromInts(s []int) *ackSet { + acks := make([]model.PacketID, 0) + for _, i := range s { + acks = append(acks, model.PacketID(i)) + } + return newACKSet(acks...) +} + +func ackSetFromRange(start, total int) *ackSet { + acks := make([]model.PacketID, 0) + for i := 0; i < total; i++ { + acks = append(acks, model.PacketID(start+i)) + } + return newACKSet(acks...) +} diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 2e6a7828..4065f54b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -3,7 +3,6 @@ package reliabletransport import ( "bytes" "fmt" - "log" "sort" "github.com/ooni/minivpn/internal/model" @@ -42,7 +41,7 @@ func (ws *workersState) moveUpWorker() { // We should be able to deterministically test how this affects the state machine. // drop a packet that is not for our session - if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { + if !bytes.Equal(packet.RemoteSessionID[:], ws.sessionManager.LocalSessionID()) { ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, @@ -64,6 +63,7 @@ func (ws *workersState) moveUpWorker() { if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it + ws.logger.Debugf("Dropping packet: %v", packet.ID) continue } diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go new file mode 100644 index 00000000..fe48048b --- /dev/null +++ b/internal/reliabletransport/reliable_ack_test.go @@ -0,0 +1,164 @@ +package reliabletransport + +import ( + "slices" + "testing" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/vpntest" +) + +// test that everything that is received from below is eventually ACKed to the sender. +func TestReliable_ACK(t *testing.T) { + + log.SetLevel(log.DebugLevel) + + type args struct { + inputSequence []string + start int + wantacks int + } + + tests := []struct { + name string + args args + }{ + { + name: "ten ordered packets in", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[5] CONTROL_V1 +1ms", + "[6] CONTROL_V1 +1ms", + "[7] CONTROL_V1 +1ms", + "[8] CONTROL_V1 +1ms", + "[9] CONTROL_V1 +1ms", + "[10] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 10, + }, + }, + { + name: "five ordered packets with offset", + args: args{ + inputSequence: []string{ + "[100] CONTROL_V1 +1ms", + "[101] CONTROL_V1 +1ms", + "[102] CONTROL_V1 +1ms", + "[103] CONTROL_V1 +1ms", + "[104] CONTROL_V1 +1ms", + }, + start: 100, + wantacks: 5, + }, + }, + { + name: "five reversed packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, + { + name: "ten unordered packets with duplicates", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[5] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, + /* + { + name: "a burst of packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, + */ + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{} + + // just to properly initialize it, we don't care about these + s.ControlToReliable = make(chan *model.Packet) + // this one up to control/tls also needs to be buffered because otherwise + // we'll block on the receiver when delivering up. + reliableToControl := make(chan *model.Packet, 1024) + s.ReliableToControl = &reliableToControl + + // the only two channels we're going to be testing on this test + // we want to buffer enough to be safe writing to them. + dataIn := make(chan *model.Packet, 1024) + dataOut := make(chan *model.Packet, 1024) + + s.MuxerToReliable = dataIn // up + s.DataOrControlToMuxer = &dataOut // down + + workers, session := initManagers() + + t0 := time.Now() + + // let the workers pump up the jam! + s.StartWorkers(log.Log, workers, session) + + writer := vpntest.NewPacketWriter(dataIn) + + // initialize a mock session ID for our peer + peerSessionID := newRandomSessionID() + + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = peerSessionID + session.SetRemoteSessionID(peerSessionID) + + go writer.WriteSequence(tt.args.inputSequence) + + reader := vpntest.NewPacketReader(dataOut) + witness := vpntest.NewWitness(reader) + + if ok := witness.VerifyNumberOfACKs(tt.args.start, tt.args.wantacks, t0); !ok { + got := len(witness.Log().ACKs()) + t.Errorf("TestACK: got = %v, want %v", got, tt.args.wantacks) + } + gotAckSet := ackSetFromInts(witness.Log().ACKs()).sorted() + wantAckSet := ackSetFromRange(tt.args.start, tt.args.wantacks).sorted() + + if !slices.Equal(gotAckSet, wantAckSet) { + t.Errorf("TestACK: got = %v, want %v", gotAckSet, wantAckSet) + + } + }) + } +} diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go new file mode 100644 index 00000000..b4e94823 --- /dev/null +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -0,0 +1,131 @@ +package reliabletransport + +import ( + "testing" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/vpntest" +) + +// test that we're able to reorder (towards TLS) whatever is received (from the muxer). +func TestReliable_Reordering_UP(t *testing.T) { + + log.SetLevel(log.DebugLevel) + + type args struct { + inputSequence []string + outputSequence []int + } + + tests := []struct { + name string + args args + }{ + { + name: "test with a well-ordered input sequence", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + { + name: "test reordering for input sequence", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + { + name: "test reordering for input sequence, longer waits", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +50ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + { + name: "test reordering for input sequence, with duplicates", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + { + name: "reordering with acks interspersed", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +2ms", + "[0] ACK_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[0] ACK_V1 +1ms", + "[1] CONTROL_V1 +2ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{} + + // just to properly initialize it, we don't care about these + s.ControlToReliable = make(chan *model.Packet) + dataToMuxer := make(chan *model.Packet) + s.DataOrControlToMuxer = &dataToMuxer + + // the only two channels we're going to be testing on this test + // we want to buffer enough to be safe writing to them. + dataIn := make(chan *model.Packet, 1024) + dataOut := make(chan *model.Packet, 1024) + + s.MuxerToReliable = dataIn + s.ReliableToControl = &dataOut + + workers, session := initManagers() + + t0 := time.Now() + + // let the workers pump up the jam! + s.StartWorkers(log.Log, workers, session) + + writer := vpntest.NewPacketWriter(dataIn) + + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = newRandomSessionID() + + go writer.WriteSequence(tt.args.inputSequence) + + reader := vpntest.NewPacketReader(dataOut) + if ok := reader.WaitForSequence(tt.args.outputSequence, t0); !ok { + got := reader.Log().IDSequence() + t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) + } + }) + } +} diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 74ae03ca..989e81f9 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -105,6 +105,7 @@ func (ws *workersState) blockOnTryingToSend(sender *reliableSender, ticker *time ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) if err != nil { ws.logger.Warnf("moveDownWorker: tryToSend: cannot create ack: %v", err.Error()) + return } ACK.Log(ws.logger, model.DirectionOutgoing) select { @@ -195,7 +196,7 @@ func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) { // shouldRescheduleAfterACK checks whether we need to wakeup after receiving an ACK. // TODO: change this depending on the handshake state -------------------------- func (r *reliableSender) shouldWakeupAfterACK(t time.Time) (bool, time.Duration) { - if r.pendingACKsToSend.Len() == 0 { + if r.pendingACKsToSend.Len() <= 0 { return false, time.Minute } // for two or more ACKs pending, we want to send right now. @@ -235,7 +236,7 @@ type ackSet struct { m map[model.PacketID]bool } -// NewACKSet creates a new empty ACK set. +// newACKSet creates a new empty ACK set. func newACKSet(ids ...model.PacketID) *ackSet { m := make(map[model.PacketID]bool) for _, id := range ids { diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go index 3197e686..b2a1bd75 100644 --- a/internal/reliabletransport/service.go +++ b/internal/reliabletransport/service.go @@ -39,7 +39,7 @@ func (s *Service) StartWorkers( logger: logger, // incomingSeen is a buffered channel to avoid losing packets if we're busy // processing in the sender goroutine. - incomingSeen: make(chan incomingPacketSeen, 20), + incomingSeen: make(chan incomingPacketSeen, 100), dataOrControlToMuxer: *s.DataOrControlToMuxer, controlToReliable: s.ControlToReliable, muxerToReliable: s.MuxerToReliable, diff --git a/internal/runtimex/runtimex.go b/internal/runtimex/runtimex.go index 5e135484..2403c825 100644 --- a/internal/runtimex/runtimex.go +++ b/internal/runtimex/runtimex.go @@ -1,6 +1,8 @@ // Package runtimex contains [runtime] extensions. package runtimex +import "fmt" + // PanicIfFalse calls panic with the given message if the given statement is false. func PanicIfFalse(stmt bool, message interface{}) { if !stmt { @@ -17,3 +19,11 @@ func PanicIfTrue(stmt bool, message interface{}) { // Assert calls panic with the given message if the given statement is false. var Assert = PanicIfFalse + +// PanicOnError calls panic() if err is not nil. The type passed to panic +// is an error type wrapping the original error. +func PanicOnError(err error, message string) { + if err != nil { + panic(fmt.Errorf("%s: %w", message, err)) + } +} diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go new file mode 100644 index 00000000..923bfc84 --- /dev/null +++ b/internal/vpntest/packetio.go @@ -0,0 +1,172 @@ +package vpntest + +import ( + "slices" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" +) + +// PacketWriter writes packets into a channel. +type PacketWriter struct { + // A channel where to write packets to. + ch chan<- *model.Packet + + // LocalSessionID is needed to produce incoming packets that pass sanity checks. + LocalSessionID model.SessionID + + // RemoteSessionID is needed to produce ACKs. + RemoteSessionID model.SessionID +} + +// NewPacketWriter creates a new PacketWriter. +func NewPacketWriter(ch chan<- *model.Packet) *PacketWriter { + return &PacketWriter{ch: ch} +} + +// WriteSequence writes the passed packet sequence (in their string representation) +// to the configured channel. It will wait the specified interval between one packet and the next. +func (pw *PacketWriter) WriteSequence(seq []string) { + for _, testStr := range seq { + testPkt, err := NewTestPacketFromString(testStr) + if err != nil { + panic("PacketWriter: error reading test sequence:" + err.Error()) + } + + p := &model.Packet{ + Opcode: testPkt.Opcode, + RemoteSessionID: pw.RemoteSessionID, + LocalSessionID: pw.LocalSessionID, + ID: model.PacketID(testPkt.ID), + } + pw.ch <- p + time.Sleep(testPkt.IAT) + } +} + +// LoggedPacket is a trace of a received packet. +type LoggedPacket struct { + ID int + Opcode model.Opcode + ACKs []model.PacketID + At time.Duration +} + +// newLoggedPacket returns a pointer to LoggedPacket from a real packet and a origin of time. +func newLoggedPacket(p *model.Packet, origin time.Time) *LoggedPacket { + return &LoggedPacket{ + ID: int(p.ID), + Opcode: p.Opcode, + ACKs: p.ACKs, + At: time.Since(origin), + } +} + +// PacketLog is a sequence of LoggedPacket. +type PacketLog []*LoggedPacket + +// IDSequence returns a sequence of int from the logged packets. +func (l PacketLog) IDSequence() []int { + ids := make([]int, 0) + for _, p := range l { + ids = append(ids, int(p.ID)) + } + return ids +} + +// ACKs filters the log and returns an array of unique ids that have been acked +// either as ack packets or as part of the ack array of an outgoing packet. +func (l PacketLog) ACKs() []int { + acks := []int{} + for _, p := range l { + for _, ack := range p.ACKs { + a := int(ack) + if !contains(acks, a) { + acks = append(acks, a) + } + } + } + return acks +} + +// PacketReader reads packets from a channel. +type PacketReader struct { + ch <-chan *model.Packet + log []*LoggedPacket +} + +// NewPacketReader creates a new PacketReader. +func NewPacketReader(ch <-chan *model.Packet) *PacketReader { + logged := make([]*LoggedPacket, 0) + return &PacketReader{ch: ch, log: logged} +} + +// WaitForSequence loops reading from the internal channel until the logged +// sequence matches the len of the expected sequence; it returns +// true if the obtained packet ID sequence matches the expected one. +func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { + for { + // have we read enough packets to call it a day? + if len(pr.log) >= len(seq) { + break + } + // no, so let's keep reading until the test runner kills us + pkt := <-pr.ch + pr.log = append(pr.log, newLoggedPacket(pkt, start)) + log.Debugf("got packet: %v", pkt.ID) + } + // TODO(ainghazal): move the comparison to witness, leave only wait here + return slices.Equal(seq, PacketLog(pr.log).IDSequence()) +} + +func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { + for { + // have we read enough acks to call it a day? + if len(PacketLog(pr.log).ACKs()) >= total { + break + } + // no, so let's keep reading until the test runner kills us + pkt := <-pr.ch + pr.log = append(pr.log, newLoggedPacket(pkt, start)) + log.Debugf("got packet: %v", pkt.ID) + } +} + +// Log returns the log of the received packets. +func (pr *PacketReader) Log() PacketLog { + return PacketLog(pr.log) +} + +// A Witness checks for different conditions over a reader +type Witness struct { + reader *PacketReader +} + +// NewWitness constructs a Witness from a [PacketReader]. +func NewWitness(r *PacketReader) *Witness { + return &Witness{r} +} + +// Log returns the packet log from the internal reader this witness uses. +func (w *Witness) Log() PacketLog { + return w.reader.Log() +} + +// VerifyNumberOfACKs tells the underlying reader to wait for a given number of acks, +// returns true if we have the same number of acks. +func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { + w.reader.WaitForNumberOfACKs(total, t) + return len(w.Log().ACKs()) == total +} + +// contains check if the element is in the slice. this is expensive, but it's only +// for tests and the alternative is to make ackSet public. +func contains(slice []int, target int) bool { + for _, item := range slice { + if item == target { + return true + } + } + return false +} diff --git a/internal/vpntest/packetio_test.go b/internal/vpntest/packetio_test.go new file mode 100644 index 00000000..3a312c7f --- /dev/null +++ b/internal/vpntest/packetio_test.go @@ -0,0 +1,81 @@ +package vpntest + +import ( + "testing" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +func TestPacketReaderWriter(t *testing.T) { + type args struct { + input []string + output []int + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "simple input, simple output", + args: args{ + input: []string{ + "[1] CONTROL_V1 +0ms", + "[2] CONTROL_V1 +0ms", + "[3] CONTROL_V1 +0ms", + }, + output: []int{1, 2, 3}, + }, + want: true, + }, + { + name: "reverse in, reverse out", + args: args{ + input: []string{ + "[3] CONTROL_V1 +0ms", + "[2] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + }, + output: []int{3, 2, 1}, + }, + want: true, + }, + { + name: "holes in, holes out", + args: args{ + input: []string{ + "[0] CONTROL_V1 +0ms", + "[10] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + "[20] CONTROL_V1 +0ms", + }, + output: []int{0, 10, 1, 20}, + }, + want: true, + }, + { + name: "mismatch returns false", + args: args{ + input: []string{ + "[0] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + }, + output: []int{1, 0}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan *model.Packet) + writer := NewPacketWriter(ch) + go writer.WriteSequence(tt.args.input) + reader := NewPacketReader(ch) + if ok := reader.WaitForSequence(tt.args.output, time.Now()); ok != tt.want { + got := reader.Log().IDSequence() + t.Errorf("PacketReader.WaitForSequence() = %v, want %v", got, tt.args.output) + } + }) + } +} diff --git a/internal/vpntest/vpntest.go b/internal/vpntest/vpntest.go new file mode 100644 index 00000000..bdf67fce --- /dev/null +++ b/internal/vpntest/vpntest.go @@ -0,0 +1,88 @@ +// Package vpntest provides utilities for minivpn testing. +package vpntest + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +// TestPacket is used to simulate incoming packets over the network. The goal is to be able to +// have a compact representation of a sequence of packets, their type, and extra properties like +// inter-arrival time. +type TestPacket struct { + // Opcode is the OpenVPN packet opcode. + Opcode model.Opcode + + // ID is the packet sequence + ID int + + // ACKs is the ack array in this packet + ACKs []int + + // IAT is the inter-arrival time until the next packet is received. + IAT time.Duration +} + +// the test packet string is in the form: +// "[ID] OPCODE +42ms" +func NewTestPacketFromString(s string) (*TestPacket, error) { + parts := strings.Split(s, " +") + + // Extracting id, opcode and ack parts + head := strings.Split(parts[0], " ") + if len(head) < 2 || len(head) > 3 { + return nil, fmt.Errorf("invalid format for ID-op-acks: %s", parts[0]) + } + + id, err := strconv.Atoi(strings.Trim(head[0], "[]")) + if err != nil { + return nil, fmt.Errorf("failed to parse id: %v", err) + } + + opcode, err := model.NewOpcodeFromString(head[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse opcode: %v", err) + } + + acks := []int{} + + if len(head) == 3 { + acks, err = parseACKs(strings.Trim(head[2], "()")) + fmt.Println("acks:", acks) + if err != nil { + return nil, fmt.Errorf("failed to parse opcode: %v", err) + } + } + + // Parsing duration part + iatStr := parts[1] + iat, err := time.ParseDuration(iatStr) + if err != nil { + return nil, fmt.Errorf("failed to parse duration: %v", err) + } + + return &TestPacket{ID: id, Opcode: opcode, ACKs: acks, IAT: iat}, nil +} + +var errBadACK = errors.New("wrong ack string") + +func parseACKs(s string) ([]int, error) { + acks := []int{} + h := strings.Split(s, "ack:") + if len(h) != 2 { + return acks, errBadACK + } + values := strings.Split(h[1], ",") + for _, v := range values { + n, err := strconv.Atoi(v) + if err == nil { + acks = append(acks, n) + } + } + return acks, nil +} diff --git a/internal/vpntest/vpntest_test.go b/internal/vpntest/vpntest_test.go new file mode 100644 index 00000000..f606bd08 --- /dev/null +++ b/internal/vpntest/vpntest_test.go @@ -0,0 +1,68 @@ +// Package vpntest provides utilities for minivpn testing. +package vpntest + +import ( + "reflect" + "testing" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +func TestNewTestPacketFromString(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want *TestPacket + wantErr bool + }{ + { + name: "parse a correct testpacket string", + args: args{"[1] CONTROL_V1 +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + ACKs: []int{}, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + { + name: "parse a testpacket with acks", + args: args{"[1] CONTROL_V1 (ack:0,1) +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + ACKs: []int{0, 1}, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + { + name: "empty acks part", + args: args{"[1] CONTROL_V1 (ack:) +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + ACKs: []int{}, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewTestPacketFromString(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("NewTestPacketFromString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTestPacketFromString() = %v, want %v", got, tt.want) + } + }) + } +}