diff --git a/internal/reliabletransport/common_test.go b/internal/reliabletransport/common_test.go new file mode 100644 index 00000000..ca6492e1 --- /dev/null +++ b/internal/reliabletransport/common_test.go @@ -0,0 +1,48 @@ +package reliabletransport + +import ( + "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// +// Common utilities for tests in this package. +// + +// initManagers initializes a workers manager and a session manager. +func initManagers() (*workers.Manager, *session.Manager) { + w := workers.NewManager(log.Log) + s, err := session.NewManager(log.Log) + if err != nil { + panic(err) + } + return w, s +} + +// newRandomSessionID returns a random session ID to initialize mock sessions. +func newRandomSessionID() model.SessionID { + b, err := bytesx.GenRandomBytes(8) + if err != nil { + panic(err) + } + return model.SessionID(b) +} + +func ackSetFromInts(s []int) *ackSet { + acks := make([]model.PacketID, 0) + for _, i := range s { + acks = append(acks, model.PacketID(i)) + } + return newACKSet(acks...) +} + +func ackSetFromRange(start, total int) *ackSet { + acks := make([]model.PacketID, 0) + for i := 0; i < total; i++ { + acks = append(acks, model.PacketID(start+i)) + } + return newACKSet(acks...) +} diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index 88664cf7..8fc9f5ed 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -1,6 +1,7 @@ package reliabletransport import ( + "slices" "testing" "time" @@ -16,6 +17,7 @@ func TestReliable_ACK(t *testing.T) { type args struct { inputSequence []string + start int wantacks int } @@ -38,9 +40,38 @@ func TestReliable_ACK(t *testing.T) { "[9] CONTROL_V1 +1ms", "[10] CONTROL_V1 +1ms", }, + start: 1, wantacks: 10, }, }, + { + name: "five ordered packets with offset", + args: args{ + inputSequence: []string{ + "[100] CONTROL_V1 +1ms", + "[101] CONTROL_V1 +1ms", + "[102] CONTROL_V1 +1ms", + "[103] CONTROL_V1 +1ms", + "[104] CONTROL_V1 +1ms", + }, + start: 100, + wantacks: 5, + }, + }, + { + name: "five reversed packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -48,7 +79,9 @@ func TestReliable_ACK(t *testing.T) { // just to properly initialize it, we don't care about these s.ControlToReliable = make(chan *model.Packet) - reliableToControl := make(chan *model.Packet) + // this one up to control/tls also needs to be buffered because otherwise + // we'll block on the receiver when delivering up. + reliableToControl := make(chan *model.Packet, 1024) s.ReliableToControl = &reliableToControl // the only two channels we're going to be testing on this test @@ -61,10 +94,6 @@ func TestReliable_ACK(t *testing.T) { workers, session := initManagers() - // this is our session (local to us) - localSessionID := session.LocalSessionID() - remoteSessionID := session.RemoteSessionID() - t0 := time.Now() // let the workers pump up the jam! @@ -72,19 +101,28 @@ func TestReliable_ACK(t *testing.T) { writer := vpntest.NewPacketWriter(dataIn) - // TODO -- need to create a session - writer.LocalSessionID = model.SessionID(remoteSessionID) - writer.RemoteSessionID = model.SessionID(localSessionID) + // initialize a mock session ID for our peer + peerSessionID := newRandomSessionID() + + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = peerSessionID + session.SetRemoteSessionID(peerSessionID) go writer.WriteSequence(tt.args.inputSequence) reader := vpntest.NewPacketReader(dataOut) witness := vpntest.NewWitness(reader) - if ok := witness.VerifyACKs(tt.args.wantacks, t0); !ok { - //log.Debug(witness.Log()) - got := witness.NumberOfACKs() - t.Errorf("Reordering: got = %v, want %v", got, tt.args.wantacks) + if ok := witness.VerifyNumberOfACKs(tt.args.start, tt.args.wantacks, t0); !ok { + got := len(witness.Log().ACKs()) + t.Errorf("TestACK: got = %v, want %v", got, tt.args.wantacks) + } + gotAckSet := ackSetFromInts(witness.Log().ACKs()).sorted() + wantAckSet := ackSetFromRange(tt.args.start, tt.args.wantacks).sorted() + + if !slices.Equal(gotAckSet, wantAckSet) { + t.Errorf("TestACK: got = %v, want %v", gotAckSet, wantAckSet) + } }) } diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index c5d9b2ea..b4e94823 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -109,10 +109,6 @@ func TestReliable_Reordering_UP(t *testing.T) { workers, session := initManagers() - // this is our session (local to us) - localSessionID := session.LocalSessionID() - remoteSessionID := session.RemoteSessionID() - t0 := time.Now() // let the workers pump up the jam! @@ -120,9 +116,8 @@ func TestReliable_Reordering_UP(t *testing.T) { writer := vpntest.NewPacketWriter(dataIn) - // TODO -- need to create a session - writer.LocalSessionID = model.SessionID(remoteSessionID) - writer.RemoteSessionID = model.SessionID(localSessionID) + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = newRandomSessionID() go writer.WriteSequence(tt.args.inputSequence) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 53578c43..c9013e65 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -237,7 +237,7 @@ type ackSet struct { m map[model.PacketID]bool } -// NewACKSet creates a new empty ACK set. +// newACKSet creates a new empty ACK set. func newACKSet(ids ...model.PacketID) *ackSet { m := make(map[model.PacketID]bool) for _, id := range ids { diff --git a/internal/reliabletransport/tests.go b/internal/reliabletransport/tests.go deleted file mode 100644 index 61472042..00000000 --- a/internal/reliabletransport/tests.go +++ /dev/null @@ -1,21 +0,0 @@ -package reliabletransport - -import ( - "github.com/apex/log" - "github.com/ooni/minivpn/internal/session" - "github.com/ooni/minivpn/internal/workers" -) - -// -// Common utilities for tests in this package. -// - -// initManagers initializes a workers manager and a session manager -func initManagers() (*workers.Manager, *session.Manager) { - w := workers.NewManager(log.Log) - s, err := session.NewManager(log.Log) - if err != nil { - panic(err) - } - return w, s -} diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index a8814416..499009d8 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -1,6 +1,7 @@ package vpntest import ( + "fmt" "slices" "time" @@ -41,6 +42,7 @@ func (pw *PacketWriter) WriteSequence(seq []string) { ID: model.PacketID(testPkt.ID), } pw.ch <- p + fmt.Println("<< wrote", p.ID) time.Sleep(testPkt.IAT) } } @@ -75,9 +77,9 @@ func (l PacketLog) IDSequence() []int { return ids } -// acks filters the log and returns an array of ids that have been acked +// ACKs filters the log and returns an array of ids that have been acked // either as ack packets or as part of the ack array of an outgoing packet. -func (l PacketLog) acks() []int { +func (l PacketLog) ACKs() []int { acks := []int{} for _, p := range l { for _, ack := range p.ACKs { @@ -120,7 +122,7 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { for { // have we read enough acks to call it a day? - if len(PacketLog(pr.log).acks()) >= total { + if len(PacketLog(pr.log).ACKs()) >= total { break } // no, so let's keep reading until the test runner kills us @@ -149,14 +151,14 @@ func (w *Witness) Log() PacketLog { } // VerifyACKs tells the underlying reader to wait for a given number of acks, -// and then checks that we have an ack sequence without holes. -func (w *Witness) VerifyACKs(total int, t time.Time) bool { +// returns true if we have the same number of acks. +func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { w.reader.WaitForNumberOfACKs(total, t) - // TODO: compare the range here, no holes - // TODO: probl. need start idx - return true + return len(w.Log().ACKs()) == total } +/* func (w *Witness) NumberOfACKs() int { - return len(w.reader.Log().acks()) + return len(w.Log().ACKs()) } +*/