Skip to content

Commit

Permalink
rocket server doesn't stop serving when listener is closed
Browse files Browse the repository at this point in the history
Summary:
fixes a bug where rocket server didn't stop serving when the listener was closed
See Issue report in https://fb.workplace.com/groups/codegophers/permalink/26891400743815186/

I forgot to return an error when the listener.Accept returned an error
This fixes the issue and adds two tests, one for header server and one for rocket server.

Differential Revision: D63021827

fbshipit-source-id: 1c143ee0ae54db426e34a97c1f2e5bd18b2c6339
  • Loading branch information
Walter Schulze authored and facebook-github-bot committed Sep 19, 2024
1 parent 4126978 commit 92c1da5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
91 changes: 91 additions & 0 deletions third-party/thrift/src/thrift/lib/go/thrift/header_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"context"
"net"
"testing"
"time"
)

type headerServerTestProcessor struct {
requests chan<- *MyTestStruct
}

func (t *headerServerTestProcessor) GetProcessorFunction(name string) ProcessorFunction {
if name == "test" {
return &headerServerTestProcessorFunction{&testProcessorFunction{}, t.requests}
}
return nil
}

type headerServerTestProcessorFunction struct {
ProcessorFunction
requests chan<- *MyTestStruct
}

func (p *headerServerTestProcessorFunction) RunContext(ctx context.Context, reqStruct Struct) (WritableStruct, ApplicationException) {
if p.requests != nil {
p.requests <- reqStruct.(*MyTestStruct)
}
return reqStruct, nil
}

// Test that header server stops serving if listener is closed.
func TestHeaderServerCloseListener(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error)
defer close(errChan)
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
server := NewSimpleServer(&headerServerTestProcessor{}, listener, TransportIDHeader)
go func() {
errChan <- server.ServeContext(ctx)
}()
addr := listener.Addr()
conn, err := net.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
proto, err := NewHeaderProtocol(conn)
if err != nil {
t.Fatalf("could not create client protocol: %s", err)
}
client := NewSerialChannel(proto)
req := &MyTestStruct{
St: "hello",
}
resp := &MyTestStruct{}
if err := client.Call(context.Background(), "test", req, resp); err != nil {
t.Fatalf("could not complete call: %v", err)
}
if resp.St != "hello" {
t.Fatalf("expected response to be a hello, got %s", resp.St)
}
listener.Close()
select {
case <-errChan:
break
case <-time.After(3 * time.Second):
t.Fatalf("listener did not close")
}
cancel()
}
45 changes: 45 additions & 0 deletions third-party/thrift/src/thrift/lib/go/thrift/rocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"net"
"testing"
"time"
)

type rocketServerTestProcessor struct {
Expand Down Expand Up @@ -126,3 +127,47 @@ func TestRocketServerOneWay(t *testing.T) {
cancel()
<-errChan
}

// Test that rocket server stops serving if listener is closed.
func TestRocketServerCloseListener(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error)
// defer close(errChan)
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
server := NewSimpleServer(&rocketServerTestProcessor{}, listener, TransportIDRocket)
go func() {
errChan <- server.ServeContext(ctx)
}()
addr := listener.Addr()
conn, err := net.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
proto, err := newRocketClient(conn, ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatalf("could not create client protocol: %s", err)
}
client := NewSerialChannel(proto)
req := &MyTestStruct{
St: "hello",
}
resp := &MyTestStruct{}
if err := client.Call(context.Background(), "test", req, resp); err != nil {
t.Fatalf("could not complete call: %v", err)
}
if resp.St != "hello" {
t.Fatalf("expected response to be a hello, got %s", resp.St)
}
listener.Close()
select {
case <-errChan:
break
case <-time.After(3 * time.Second):
t.Fatalf("listener did not close")
}
cancel()
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (r *rocketServerTransport) acceptLoop(ctx context.Context) error {
case <-ctx.Done():
return nil
default:
err = fmt.Errorf("accept next conn failed: %w", err)
return fmt.Errorf("listener.Accept failed in rocketServerTransport.acceptLoop: %w", err)
}
}
if conn == nil {
Expand Down

0 comments on commit 92c1da5

Please sign in to comment.