Skip to content

Commit

Permalink
TUN-8817: Increase close session channel by one since there are two w…
Browse files Browse the repository at this point in the history
…riters

When closing a session, there are two possible signals that will occur,
one from the outside, indicating that the session is idle and needs to
be closed, and the internal error condition that will be unblocked
with a net.ErrClosed when the connection underneath is closed. Both of
these routines write to the session's closeChan.

Once the reader for the closeChan reads one value, it will immediately
return. This means that the channel is a one-shot and one of the two
writers will get stuck unless the size of the channel is increased to
accomodate for the second write to the channel.

With the channel size increased to two, the second writer (whichever
loses the race to write) will now be unblocked to end their go routine
and return.

Closes TUN-8817
  • Loading branch information
DevinCarr committed Dec 17, 2024
1 parent 1859d74 commit bc9c5d2
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 92 deletions.
5 changes: 4 additions & 1 deletion quic/v3/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ func NewSession(
log *zerolog.Logger,
) Session {
logger := log.With().Str(logFlowID, id.String()).Logger()
closeChan := make(chan error, 1)
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
// write to the channel without blocking since there is only ever one value read from the closeChan by the
// waitForCloseCondition.
closeChan := make(chan error, 2)
session := &session{
id: id,
closeAfterIdle: closeAfterIdle,
Expand Down
167 changes: 76 additions & 91 deletions quic/v3/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package v3_test
import (
"context"
"errors"
"io"
"net"
"net/netip"
"slices"
"sync/atomic"
"testing"
"time"

"github.com/fortytw2/leaktest"
"github.com/rs/zerolog"

v3 "github.com/cloudflare/cloudflared/quic/v3"
Expand All @@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) {

func testSessionWrite(t *testing.T, payload []byte) {
log := zerolog.Nop()
origin := newTestOrigin(makePayload(1280))
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
// Start origin server read
serverRead := make(chan []byte, 1)
go func() {
read := make([]byte, 1500)
server.Read(read[:])
serverRead <- read
}()
// Create session and write to origin
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
n, err := session.Write(payload)
defer session.Close()
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("unable to write the whole payload")
}
if !slices.Equal(payload, origin.write[:len(payload)]) {

read := <-serverRead
if !slices.Equal(payload, read[:len(payload)]) {
t.Fatal("payload provided from origin and read value are not the same")
}
}

func TestSessionWrite_Max(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280)
testSessionWrite(t, payload)
}

func TestSessionWrite_Min(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0)
testSessionWrite(t, payload)
}

func TestSessionServe_OriginMax(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280)
testSessionServe_Origin(t, payload)
}

func TestSessionServe_OriginMin(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0)
testSessionServe_Origin(t, payload)
}

func testSessionServe_Origin(t *testing.T, payload []byte) {
log := zerolog.Nop()
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
eyeball := newMockEyeball()
origin := newTestOrigin(payload)
session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
defer session.Close()

ctx, cancel := context.WithCancelCause(context.Background())
Expand All @@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
done <- session.Serve(ctx)
}()

// Write from the origin server
_, err := server.Write(payload)
if err != nil {
t.Fatal(err)
}

select {
case data := <-eyeball.recvData:
// check received data matches provided from origin
expectedData := makePayload(1500)
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:17+len(payload)], data) {
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
t.Fatal("expected datagram did not equal expected")
}
cancel(expectedContextCanceled)
Expand All @@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
t.Fatal(err)
}

err := <-done
err = <-done
if !errors.Is(err, context.Canceled) {
t.Fatal(err)
}
Expand All @@ -105,18 +131,27 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
}

func TestSessionServe_OriginTooLarge(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
eyeball := newMockEyeball()
payload := makePayload(1281)
origin := newTestOrigin(payload)
session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
defer session.Close()

done := make(chan error)
go func() {
done <- session.Serve(context.Background())
}()

// Attempt to write a payload too large from the origin
_, err := server.Write(payload)
if err != nil {
t.Fatal(err)
}

select {
case data := <-eyeball.recvData:
// we never expect a read to make it here because the origin provided a payload that is too large
Expand All @@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
}

func TestSessionServe_Migrate(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
eyeball := newMockEyeball()
pipe1, pipe2 := net.Pipe()
Expand Down Expand Up @@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) {
}

func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
eyeball := newMockEyeball()
pipe1, pipe2 := net.Pipe()
Expand Down Expand Up @@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
}

func TestSessionClose_Multiple(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
origin := newTestOrigin(makePayload(128))
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
err := session.Close()
if err != nil {
t.Fatal(err)
}
if !origin.closed.Load() {
t.Fatal("origin wasn't closed")
b := [1500]byte{}
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("origin server connection should be closed: %s", err)
}
// Reset the closed status to make sure it isn't closed again
origin.closed.Store(false)
// subsequent closes shouldn't call close again or cause any errors
err = session.Close()
if err != nil {
t.Fatal(err)
}
if origin.closed.Load() {
t.Fatal("origin was incorrectly closed twice")
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("origin server connection should still be closed: %s", err)
}
}

func TestSessionServe_IdleTimeout(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
closeAfterIdle := 2 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
err := session.Serve(context.Background())
if !errors.Is(err, v3.SessionIdleErr{}) {
t.Fatal(err)
}
// session should be closed
if !origin.closed {
b := [1500]byte{}
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("session should be closed after Serve returns")
}
// closing a session again should not return an error
Expand All @@ -288,20 +334,24 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
}

func TestSessionServe_ParentContextCanceled(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
// Make idle time and idle timeout longer than closeAfterIdle
origin := newTestIdleOrigin(10 * time.Second)
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
closeAfterIdle := 10 * time.Second

session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := session.Serve(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal(err)
}
// session should be closed
if !origin.closed {
b := [1500]byte{}
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("session should be closed after Serve returns")
}
// closing a session again should not return an error
Expand All @@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
}

func TestSessionServe_ReadErrors(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
origin := newTestErrOrigin(net.ErrClosed, nil)
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
Expand All @@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) {
}
}

type testOrigin struct {
// bytes from Write
write []byte
// bytes provided to Read
read []byte
readOnce atomic.Bool
closed atomic.Bool
}

func newTestOrigin(payload []byte) testOrigin {
return testOrigin{
read: payload,
}
}

func (o *testOrigin) Read(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
if o.readOnce.Load() {
// We only want to provide one read so all other reads will be blocked
time.Sleep(10 * time.Second)
}
o.readOnce.Store(true)
return copy(p, o.read), nil
}

func (o *testOrigin) Write(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
o.write = make([]byte, len(p))
copy(o.write, p)
return len(p), nil
}

func (o *testOrigin) Close() error {
o.closed.Store(true)
return nil
}

type testIdleOrigin struct {
duration time.Duration
closed bool
}

func newTestIdleOrigin(d time.Duration) testIdleOrigin {
return testIdleOrigin{
duration: d,
}
}

func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
time.Sleep(o.duration)
return -1, nil
}

func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
return 0, nil
}

func (o *testIdleOrigin) Close() error {
o.closed = true
return nil
}

type testErrOrigin struct {
readErr error
writeErr error
Expand Down

0 comments on commit bc9c5d2

Please sign in to comment.