This repository has been archived by the owner on Sep 30, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature/internal/grpc: retry: vendor go-grpc-middleware testing/testp…
…b package
- Loading branch information
Showing
9 changed files
with
1,758 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
load("@io_bazel_rules_go//go:def.bzl", "go_library") | ||
load("//dev:go_defs.bzl", "go_test") | ||
|
||
go_library( | ||
name = "testpb", | ||
srcs = [ | ||
"interceptor_suite.go", | ||
"pingservice.go", | ||
"test.manual_validator.pb.go", | ||
"test.pb.go", | ||
"test_grpc.pb.go", | ||
], | ||
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/retry/testpb", | ||
visibility = ["//:__subpackages__"], | ||
deps = [ | ||
"@com_github_stretchr_testify//require", | ||
"@com_github_stretchr_testify//suite", | ||
"@org_golang_google_grpc//:go_default_library", | ||
"@org_golang_google_grpc//codes", | ||
"@org_golang_google_grpc//credentials", | ||
"@org_golang_google_grpc//credentials/insecure", | ||
"@org_golang_google_grpc//status", | ||
"@org_golang_google_protobuf//reflect/protoreflect", | ||
"@org_golang_google_protobuf//runtime/protoimpl", | ||
], | ||
) | ||
|
||
go_test( | ||
name = "testpb_test", | ||
srcs = ["pingservice_test.go"], | ||
embed = [":testpb"], | ||
deps = [ | ||
"@com_github_stretchr_testify//require", | ||
"@org_golang_google_grpc//:go_default_library", | ||
"@org_golang_google_grpc//codes", | ||
"@org_golang_google_grpc//credentials/insecure", | ||
"@org_golang_google_grpc//status", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
// Copyright (c) The go-grpc-middleware Authors. | ||
// Licensed under the Apache License 2.0. | ||
|
||
package testpb | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"flag" | ||
"math/big" | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/stretchr/testify/suite" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/credentials/insecure" | ||
) | ||
|
||
var ( | ||
flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls") | ||
|
||
certPEM []byte | ||
keyPEM []byte | ||
) | ||
|
||
// InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client. | ||
type InterceptorTestSuite struct { | ||
suite.Suite | ||
|
||
TestService TestServiceServer | ||
ServerOpts []grpc.ServerOption | ||
ClientOpts []grpc.DialOption | ||
|
||
serverAddr string | ||
ServerListener net.Listener | ||
Server *grpc.Server | ||
clientConn *grpc.ClientConn | ||
Client TestServiceClient | ||
|
||
restartServerWithDelayedStart chan time.Duration | ||
serverRunning chan bool | ||
|
||
cancels []context.CancelFunc | ||
} | ||
|
||
func (s *InterceptorTestSuite) SetupSuite() { | ||
s.restartServerWithDelayedStart = make(chan time.Duration) | ||
s.serverRunning = make(chan bool) | ||
|
||
s.serverAddr = "127.0.0.1:0" | ||
var err error | ||
certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"}) // CI:LOCALHOST_OK | ||
require.NoError(s.T(), err, "unable to generate test certificate/key") | ||
|
||
go func() { | ||
for { | ||
var err error | ||
s.ServerListener, err = net.Listen("tcp", s.serverAddr) | ||
s.serverAddr = s.ServerListener.Addr().String() | ||
require.NoError(s.T(), err, "must be able to allocate a port for serverListener") | ||
if *flagTls { | ||
cert, err := tls.X509KeyPair(certPEM, keyPEM) | ||
require.NoError(s.T(), err, "unable to load test TLS certificate") | ||
creds := credentials.NewServerTLSFromCert(&cert) | ||
s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds)) | ||
} | ||
// This is the point where we hook up the interceptor. | ||
s.Server = grpc.NewServer(s.ServerOpts...) | ||
// Create a service if the instantiator hasn't provided one. | ||
if s.TestService == nil { | ||
s.TestService = &TestPingService{} | ||
} | ||
RegisterTestServiceServer(s.Server, s.TestService) | ||
|
||
w := sync.WaitGroup{} | ||
w.Add(1) | ||
go func() { | ||
_ = s.Server.Serve(s.ServerListener) | ||
w.Done() | ||
}() | ||
if s.Client == nil { | ||
s.Client = s.NewClient(s.ClientOpts...) | ||
} | ||
|
||
s.serverRunning <- true | ||
|
||
d := <-s.restartServerWithDelayedStart | ||
s.Server.Stop() | ||
time.Sleep(d) | ||
w.Wait() | ||
} | ||
}() | ||
|
||
select { | ||
case <-s.serverRunning: | ||
case <-time.After(2 * time.Second): | ||
s.T().Fatal("server failed to start before deadline") | ||
} | ||
} | ||
|
||
func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool { | ||
s.restartServerWithDelayedStart <- delayedStart | ||
time.Sleep(10 * time.Millisecond) | ||
return s.serverRunning | ||
} | ||
|
||
func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) TestServiceClient { | ||
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it. | ||
newDialOpts := append(dialOpts, grpc.WithBlock()) | ||
var err error | ||
if *flagTls { | ||
cp := x509.NewCertPool() | ||
if !cp.AppendCertsFromPEM(certPEM) { | ||
s.T().Fatal("failed to append certificate") | ||
} | ||
creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp}) // CI:LOCALHOST_OK | ||
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds)) | ||
} else { | ||
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
defer cancel() | ||
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it. | ||
s.clientConn, err = grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...) | ||
require.NoError(s.T(), err, "must not error on client Dial") | ||
return NewTestServiceClient(s.clientConn) | ||
} | ||
|
||
func (s *InterceptorTestSuite) ServerAddr() string { | ||
return s.serverAddr | ||
} | ||
|
||
type ctxTestNumber struct{} | ||
|
||
var ( | ||
ctxTestNumberKey = &ctxTestNumber{} | ||
zero = 0 | ||
) | ||
|
||
func ExtractCtxTestNumber(ctx context.Context) *int { | ||
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok { | ||
return v | ||
} | ||
return &zero | ||
} | ||
|
||
// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging. | ||
func UnaryServerInterceptor() grpc.UnaryServerInterceptor { | ||
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { | ||
// newCtx := newContext(ctx, log, opts) | ||
newCtx := ctx | ||
resp, err := handler(newCtx, req) | ||
return resp, err | ||
} | ||
} | ||
|
||
func (s *InterceptorTestSuite) SimpleCtx() context.Context { | ||
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) | ||
ctx = context.WithValue(ctx, ctxTestNumberKey, 1) | ||
s.cancels = append(s.cancels, cancel) | ||
return ctx | ||
} | ||
|
||
func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context { | ||
ctx, cancel := context.WithDeadline(context.TODO(), deadline) | ||
s.cancels = append(s.cancels, cancel) | ||
return ctx | ||
} | ||
|
||
func (s *InterceptorTestSuite) TearDownSuite() { | ||
time.Sleep(10 * time.Millisecond) | ||
if s.ServerListener != nil { | ||
s.Server.GracefulStop() | ||
s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr()) | ||
_ = s.ServerListener.Close() | ||
} | ||
if s.clientConn != nil { | ||
_ = s.clientConn.Close() | ||
} | ||
for _, c := range s.cancels { | ||
c() | ||
} | ||
} | ||
|
||
// generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255 | ||
// with minor modifications. | ||
func generateCertAndKey(san []string) ([]byte, []byte, error) { | ||
priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
notBefore := time.Now() | ||
notAfter := notBefore.Add(time.Hour) | ||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) | ||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
template := x509.Certificate{ | ||
SerialNumber: serialNumber, | ||
Subject: pkix.Name{ | ||
CommonName: "example.com", | ||
}, | ||
NotBefore: notBefore, | ||
NotAfter: notAfter, | ||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, | ||
BasicConstraintsValid: true, | ||
DNSNames: san, | ||
} | ||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
certOut := pem.EncodeToMemory(&pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: derBytes, | ||
}) | ||
keyOut := pem.EncodeToMemory(&pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(priv), | ||
}) | ||
|
||
return certOut, keyOut, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (c) The go-grpc-middleware Authors. | ||
// Licensed under the Apache License 2.0. | ||
|
||
/* | ||
Package `grpc_testing` provides helper functions for testing validators in this package. | ||
*/ | ||
|
||
package testpb | ||
|
||
import ( | ||
"context" | ||
"io" | ||
|
||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
const ( | ||
// ListResponseCount is the expected number of responses to PingList | ||
ListResponseCount = 100 | ||
) | ||
|
||
var TestServiceFullName = _TestService_serviceDesc.ServiceName | ||
|
||
// Interface implementation assert. | ||
var _ TestServiceServer = &TestPingService{} | ||
|
||
type TestPingService struct { | ||
UnimplementedTestServiceServer | ||
} | ||
|
||
func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*PingEmptyResponse, error) { | ||
return &PingEmptyResponse{}, nil | ||
} | ||
|
||
func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) { | ||
// Modify the ctx value to verify the logger sees the value updated from the initial value | ||
n := ExtractCtxTestNumber(ctx) | ||
if n != nil { | ||
*n = 42 | ||
} | ||
// Send user trailers and headers. | ||
return &PingResponse{Value: ping.Value, Counter: 0}, nil | ||
} | ||
|
||
func (s *TestPingService) PingError(_ context.Context, ping *PingErrorRequest) (*PingErrorResponse, error) { | ||
code := codes.Code(ping.ErrorCodeReturned) | ||
return nil, status.Error(code, "Userspace error") | ||
} | ||
|
||
func (s *TestPingService) PingList(ping *PingListRequest, stream TestService_PingListServer) error { | ||
if ping.ErrorCodeReturned != 0 { | ||
return status.Error(codes.Code(ping.ErrorCodeReturned), "foobar") | ||
} | ||
|
||
// Send user trailers and headers. | ||
for i := 0; i < ListResponseCount; i++ { | ||
if err := stream.Send(&PingListResponse{Value: ping.Value, Counter: int32(i)}); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (s *TestPingService) PingStream(stream TestService_PingStreamServer) error { | ||
count := 0 | ||
for { | ||
ping, err := stream.Recv() | ||
if err == io.EOF { | ||
break | ||
} | ||
if err != nil { | ||
return err | ||
} | ||
if err := stream.Send(&PingStreamResponse{Value: ping.Value, Counter: int32(count)}); err != nil { | ||
return err | ||
} | ||
|
||
count += 1 | ||
} | ||
return nil | ||
} |
Oops, something went wrong.