Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
feature/internal/grpc: retry: vendor go-grpc-middleware testing/testp…
Browse files Browse the repository at this point in the history
…b package
  • Loading branch information
ggilmore committed Jul 31, 2024
1 parent b2e550c commit ded7187
Show file tree
Hide file tree
Showing 9 changed files with 1,758 additions and 0 deletions.
39 changes: 39 additions & 0 deletions internal/grpc/retry/testpb/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")

go_library(
name = "testpb",
srcs = [
"interceptor_suite.go",
"pingservice.go",
"test.manual_validator.pb.go",
"test.pb.go",
"test_grpc.pb.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/retry/testpb",
visibility = ["//:__subpackages__"],
deps = [
"@com_github_stretchr_testify//require",
"@com_github_stretchr_testify//suite",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//runtime/protoimpl",
],
)

go_test(
name = "testpb_test",
srcs = ["pingservice_test.go"],
embed = [":testpb"],
deps = [
"@com_github_stretchr_testify//require",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
],
)
233 changes: 233 additions & 0 deletions internal/grpc/retry/testpb/interceptor_suite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package testpb

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"flag"
"math/big"
"net"
"sync"
"time"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)

var (
flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls")

certPEM []byte
keyPEM []byte
)

// InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client.
type InterceptorTestSuite struct {
suite.Suite

TestService TestServiceServer
ServerOpts []grpc.ServerOption
ClientOpts []grpc.DialOption

serverAddr string
ServerListener net.Listener
Server *grpc.Server
clientConn *grpc.ClientConn
Client TestServiceClient

restartServerWithDelayedStart chan time.Duration
serverRunning chan bool

cancels []context.CancelFunc
}

func (s *InterceptorTestSuite) SetupSuite() {
s.restartServerWithDelayedStart = make(chan time.Duration)
s.serverRunning = make(chan bool)

s.serverAddr = "127.0.0.1:0"
var err error
certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"}) // CI:LOCALHOST_OK
require.NoError(s.T(), err, "unable to generate test certificate/key")

go func() {
for {
var err error
s.ServerListener, err = net.Listen("tcp", s.serverAddr)
s.serverAddr = s.ServerListener.Addr().String()
require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
if *flagTls {
cert, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(s.T(), err, "unable to load test TLS certificate")
creds := credentials.NewServerTLSFromCert(&cert)
s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds))
}
// This is the point where we hook up the interceptor.
s.Server = grpc.NewServer(s.ServerOpts...)
// Create a service if the instantiator hasn't provided one.
if s.TestService == nil {
s.TestService = &TestPingService{}
}
RegisterTestServiceServer(s.Server, s.TestService)

w := sync.WaitGroup{}
w.Add(1)
go func() {
_ = s.Server.Serve(s.ServerListener)
w.Done()
}()
if s.Client == nil {
s.Client = s.NewClient(s.ClientOpts...)
}

s.serverRunning <- true

d := <-s.restartServerWithDelayedStart
s.Server.Stop()
time.Sleep(d)
w.Wait()
}
}()

select {
case <-s.serverRunning:
case <-time.After(2 * time.Second):
s.T().Fatal("server failed to start before deadline")
}
}

func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool {
s.restartServerWithDelayedStart <- delayedStart
time.Sleep(10 * time.Millisecond)
return s.serverRunning
}

func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) TestServiceClient {
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
newDialOpts := append(dialOpts, grpc.WithBlock())
var err error
if *flagTls {
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(certPEM) {
s.T().Fatal("failed to append certificate")
}
creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp}) // CI:LOCALHOST_OK
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds))
} else {
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
s.clientConn, err = grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...)
require.NoError(s.T(), err, "must not error on client Dial")
return NewTestServiceClient(s.clientConn)
}

func (s *InterceptorTestSuite) ServerAddr() string {
return s.serverAddr
}

type ctxTestNumber struct{}

var (
ctxTestNumberKey = &ctxTestNumber{}
zero = 0
)

func ExtractCtxTestNumber(ctx context.Context) *int {
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok {
return v
}
return &zero
}

// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// newCtx := newContext(ctx, log, opts)
newCtx := ctx
resp, err := handler(newCtx, req)
return resp, err
}
}

func (s *InterceptorTestSuite) SimpleCtx() context.Context {
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
ctx = context.WithValue(ctx, ctxTestNumberKey, 1)
s.cancels = append(s.cancels, cancel)
return ctx
}

func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context {
ctx, cancel := context.WithDeadline(context.TODO(), deadline)
s.cancels = append(s.cancels, cancel)
return ctx
}

func (s *InterceptorTestSuite) TearDownSuite() {
time.Sleep(10 * time.Millisecond)
if s.ServerListener != nil {
s.Server.GracefulStop()
s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr())
_ = s.ServerListener.Close()
}
if s.clientConn != nil {
_ = s.clientConn.Close()
}
for _, c := range s.cancels {
c()
}
}

// generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255
// with minor modifications.
func generateCertAndKey(san []string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "example.com",
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: san,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
if err != nil {
return nil, nil, err
}
certOut := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
keyOut := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})

return certOut, keyOut, nil
}
82 changes: 82 additions & 0 deletions internal/grpc/retry/testpb/pingservice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

/*
Package `grpc_testing` provides helper functions for testing validators in this package.
*/

package testpb

import (
"context"
"io"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
// ListResponseCount is the expected number of responses to PingList
ListResponseCount = 100
)

var TestServiceFullName = _TestService_serviceDesc.ServiceName

// Interface implementation assert.
var _ TestServiceServer = &TestPingService{}

type TestPingService struct {
UnimplementedTestServiceServer
}

func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*PingEmptyResponse, error) {
return &PingEmptyResponse{}, nil
}

func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) {
// Modify the ctx value to verify the logger sees the value updated from the initial value
n := ExtractCtxTestNumber(ctx)
if n != nil {
*n = 42
}
// Send user trailers and headers.
return &PingResponse{Value: ping.Value, Counter: 0}, nil
}

func (s *TestPingService) PingError(_ context.Context, ping *PingErrorRequest) (*PingErrorResponse, error) {
code := codes.Code(ping.ErrorCodeReturned)
return nil, status.Error(code, "Userspace error")
}

func (s *TestPingService) PingList(ping *PingListRequest, stream TestService_PingListServer) error {
if ping.ErrorCodeReturned != 0 {
return status.Error(codes.Code(ping.ErrorCodeReturned), "foobar")
}

// Send user trailers and headers.
for i := 0; i < ListResponseCount; i++ {
if err := stream.Send(&PingListResponse{Value: ping.Value, Counter: int32(i)}); err != nil {
return err
}
}
return nil
}

func (s *TestPingService) PingStream(stream TestService_PingStreamServer) error {
count := 0
for {
ping, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return err
}
if err := stream.Send(&PingStreamResponse{Value: ping.Value, Counter: int32(count)}); err != nil {
return err
}

count += 1
}
return nil
}
Loading

0 comments on commit ded7187

Please sign in to comment.