diff --git a/protobuf/source_reflection.go b/protobuf/source_reflection.go index 5e71092..9ae0e4a 100644 --- a/protobuf/source_reflection.go +++ b/protobuf/source_reflection.go @@ -3,8 +3,8 @@ package protobuf import ( "context" "fmt" - "strconv" "strings" + "sync" "time" "github.com/jhump/protoreflect/desc" @@ -19,6 +19,8 @@ import ( "google.golang.org/grpc/resolver/manual" ) +var resolverMutex sync.Mutex + // ReflectionArgs are args for constructing a DescriptorProvider that reaches out to a reflection server. type ReflectionArgs struct { Caller string @@ -32,8 +34,6 @@ type ReflectionArgs struct { // NewDescriptorProviderReflection returns a DescriptorProvider that reaches // out to a reflection server to access file descriptors. func NewDescriptorProviderReflection(args ReflectionArgs) (DescriptorProvider, error) { - r, deregisterScheme := GenerateAndRegisterManualResolver() - defer deregisterScheme() peers := make([]resolver.Address, len(args.Peers)) for i, p := range args.Peers { if strings.Contains(p, "://") { @@ -41,13 +41,17 @@ func NewDescriptorProviderReflection(args ReflectionArgs) (DescriptorProvider, e } peers[i] = resolver.Address{Addr: p, Type: resolver.Backend} } - r.InitialState(resolver.State{Addresses: peers}) + resolverMutex.Lock() + r := GetOrGenerateAndRegisterManualResolver(args.Service, peers) conn, err := grpc.DialContext(context.Background(), r.Scheme()+":///", // minimal target to dial registered host:port pairs grpc.WithTimeout(args.Timeout), grpc.WithBlock(), grpc.WithInsecure()) + + resolverMutex.Unlock() + if err != nil { return nil, fmt.Errorf("could not reach reflection server: %s", err) } @@ -127,9 +131,21 @@ func wrapReflectionError(err error) error { return fmt.Errorf("error in protobuf reflection: %v", err) } -func GenerateAndRegisterManualResolver() (*manual.Resolver, func()) { - scheme := strconv.FormatInt(time.Now().UnixNano(), 36) +func GetOrGenerateAndRegisterManualResolver(service string, peers []resolver.Address) *manual.Resolver { + scheme := "dest-" + service + newState := resolver.State{Addresses: peers} + + rb := resolver.Get(scheme) + if rb != nil { + if r, ok := rb.(*manual.Resolver); ok { + r.InitialState(newState) + return r + } + } + r := manual.NewBuilderWithScheme(scheme) resolver.Register(r) - return r, func() { resolver.UnregisterForTesting(scheme) } + r.InitialState(newState) + + return r } diff --git a/protobuf/source_reflection_test.go b/protobuf/source_reflection_test.go index 141be5d..59d6c33 100644 --- a/protobuf/source_reflection_test.go +++ b/protobuf/source_reflection_test.go @@ -196,6 +196,44 @@ func TestReflectionRoutingHeaders(t *testing.T) { } } +func TestResolverAlreadyExists(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + ln2, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + require.NoError(t, err) + + defer ln.Close() + defer ln2.Close() + + s := grpc.NewServer() + reflection.Register(s) + + go s.Serve(ln) + go s.Serve(ln2) + + // Ensure that all streams are closed by the end of the test. + defer s.GracefulStop() + + provider, err := NewDescriptorProviderReflection(ReflectionArgs{ + Timeout: time.Second, + Peers: []string{ln.Addr().String()}, + Service: "test", + }) + require.NoError(t, err, "failed to create reflection provider") + _, err = provider.FindService("grpc.reflection.v1alpha.ServerReflection") + assert.NoError(t, err, "unexpected error") + + provider, err = NewDescriptorProviderReflection(ReflectionArgs{ + Timeout: time.Second, + Peers: []string{ln2.Addr().String()}, + Service: "test", + }) + require.NoError(t, err, "failed to create reflection provider") + _, err = provider.FindService("grpc.reflection.v1alpha.ServerReflection") + assert.NoError(t, err, "unexpected error") + +} + func TestE2eErrors(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -212,6 +250,7 @@ func TestE2eErrors(t *testing.T) { source, err := NewDescriptorProviderReflection(ReflectionArgs{ Timeout: time.Second, Peers: []string{ln.Addr().String()}, + Service: "TestE2eErrors", }) require.NoError(t, err) defer source.Close()