diff --git a/README.md b/README.md index 266ce081..048784c8 100644 --- a/README.md +++ b/README.md @@ -73,10 +73,7 @@ type PingServer struct { pingv1connect.UnimplementedPingServiceHandler // returns errors from all methods } -func (ps *PingServer) Ping( - ctx context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (ps *PingServer) Ping(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { // connect.Request and connect.Response give you direct access to headers and // trailers. No context-based nonsense! log.Println(req.Header().Get("Some-Header")) diff --git a/client_ext_test.go b/client_ext_test.go index 40cbabed..5bf4050b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -195,9 +195,7 @@ type notModifiedPingServer struct { etag string } -func (s *notModifiedPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (s *notModifiedPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == s.etag { return nil, connect.NewNotModifiedError(http.Header{"Etag": []string{s.etag}}) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 61aeedcd..3f2493c6 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -135,8 +135,12 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatedFile.Import(file.GoImportPath) generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) + + paramNames := newParameterNames(generatedFile, file.GoImportPath, file.Services) + generateTypeAliases(generatedFile, paramNames) + for _, service := range file.Services { - generateService(generatedFile, service) + generateService(generatedFile, service, paramNames) } } @@ -219,16 +223,16 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } -func generateService(g *protogen.GeneratedFile, service *protogen.Service) { +func generateService(g *protogen.GeneratedFile, service *protogen.Service, paramNames *parameterNames) { names := newNames(service) - generateClientInterface(g, service, names) - generateClientImplementation(g, service, names) - generateServerInterface(g, service, names) + generateClientInterface(g, service, names, paramNames) + generateClientImplementation(g, service, names, paramNames) + generateServerInterface(g, service, names, paramNames) generateServerConstructor(g, service, names) - generateUnimplementedServerImplementation(g, service, names) + generateUnimplementedServerImplementation(g, service, names, paramNames) } -func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.Client, " is a client for the ", service.Desc.FullName(), " service.") if isDeprecatedService(service) { g.P("//") @@ -243,13 +247,13 @@ func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Servic method.Comments.Leading, isDeprecatedMethod(method), ) - g.P(clientSignature(g, method, false /* named */)) + g.P(clientSignature(g, method, paramNames.Get(method), false /* named */)) } g.P("}") g.P() } -func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { clientOption := connectPackage.Ident("ClientOption") // Client constructor. @@ -304,11 +308,11 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S g.P("}") g.P() for _, method := range service.Methods { - generateClientMethod(g, method, names) + generateClientMethod(g, method, names, paramNames.Get(method)) } } -func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, names names) { +func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, names names, paramNames *methodParameterNames) { receiver := names.ClientImpl isStreamingClient := method.Desc.IsStreamingClient() isStreamingServer := method.Desc.IsStreamingServer() @@ -317,7 +321,7 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P("//") deprecated(g) } - g.P("func (c *", receiver, ") ", clientSignature(g, method, true /* named */), " {") + g.P("func (c *", receiver, ") ", clientSignature(g, method, paramNames, true /* named */), " {") switch { case isStreamingClient && !isStreamingServer: @@ -333,37 +337,31 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P() } -func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, named bool) string { +func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames, named bool) string { reqName := "req" ctxName := "ctx" if !named { reqName, ctxName = "", "" } + ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context")) if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + - "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + return method.GoName + "(" + ctxName + " " + ctxType + ") *" + paramNames.ClientOutput.Name() } if method.Desc.IsStreamingClient() { // client streaming - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + - "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + return method.GoName + "(" + ctxName + " " + ctxType + ") *" + paramNames.ClientOutput.Name() } if method.Desc.IsStreamingServer() { - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + " *" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "]) " + - "(*" + g.QualifiedGoIdent(connectPackage.Ident("ServerStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ", error)" + return method.GoName + "(" + ctxName + " " + ctxType + + ", " + reqName + " *" + paramNames.ClientInput.Name() + ") " + + "(*" + paramNames.ClientOutput.Name() + ", error)" } // unary; symmetric so we can re-use server templating - return method.GoName + serverSignatureParams(g, method, named) + return method.GoName + serverSignatureParams(g, method, paramNames, named) } -func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.Server, " is an implementation of the ", service.Desc.FullName(), " service.") if isDeprecatedService(service) { g.P("//") @@ -378,7 +376,7 @@ func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Servic isDeprecatedMethod(method), ) g.AnnotateSymbol(names.Server+"."+method.GoName, protogen.Annotation{Location: method.Location}) - g.P(serverSignature(g, method)) + g.P(serverSignature(g, method, paramNames.Get(method))) } g.P("}") g.P() @@ -439,12 +437,12 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv g.P() } -func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.UnimplementedServer, " returns CodeUnimplemented from all methods.") g.P("type ", names.UnimplementedServer, " struct {}") g.P() for _, method := range service.Methods { - g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method), "{") + g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method, paramNames.Get(method)), "{") if method.Desc.IsStreamingServer() { g.P("return ", connectPackage.Ident("NewError"), "(", connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"), @@ -460,46 +458,47 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic g.P() } -func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { - return method.GoName + serverSignatureParams(g, method, false /* named */) +func generateTypeAliases(g *protogen.GeneratedFile, paramNames *parameterNames) { + if len(paramNames.Aliases) == 0 { + return + } + g.P("type (") + for _, alias := range paramNames.Aliases { + g.P(alias[0], " = ", alias[1]) + } + g.P(")") + g.P() } -func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, named bool) string { +func serverSignature(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames) string { + return method.GoName + serverSignatureParams(g, method, paramNames, false /* named */) +} + +func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames, named bool) string { ctxName := "ctx " reqName := "req " streamName := "stream " if !named { ctxName, reqName, streamName = "", "", "" } + ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context")) if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStream")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ") error" + return "(" + ctxName + ctxType + ", " + streamName + "*" + paramNames.HandlerInput.Name() + ") error" } if method.Desc.IsStreamingClient() { // client streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + - ") (*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "] ,error)" + return "(" + ctxName + ctxType + ", " + streamName + "*" + paramNames.HandlerInput.Name() + + ") (*" + paramNames.HandlerOutput.Name() + " ,error)" } if method.Desc.IsStreamingServer() { // server streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "], " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ServerStream")) + - "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ") error" + return "(" + ctxName + ctxType + ", " + reqName + "*" + paramNames.HandlerInput.Name() + ", " + + streamName + "*" + paramNames.HandlerOutput.Name() + ") error" } // unary - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "]) " + - "(*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + - g.QualifiedGoIdent(method.Output.GoIdent) + "], error)" + return "(" + ctxName + ctxType + ", " + reqName + "*" + paramNames.HandlerInput.Name() + ") " + + "(*" + paramNames.HandlerOutput.Name() + ", error)" } func procedureConstName(m *protogen.Method) string { @@ -628,3 +627,158 @@ func newNames(service *protogen.Service) names { UnimplementedServer: fmt.Sprintf("Unimplemented%sHandler", base), } } + +type parameterNames struct { + Aliases [][2]string + Methods map[protoreflect.FullName]*methodParameterNames +} + +func newParameterNames(g *protogen.GeneratedFile, baseTypes protogen.GoImportPath, services []*protogen.Service) *parameterNames { + // First, make one pass to find alias-able request and response types. We're + // trying to shorten user-visible type names, so there's no point in + // producing aliases that are just as long as the spelled-out generic types. + // + // To safely produce short aliases, we're only aliasing messages that are: + // - used as a connect.Request or connect.Response, but not both. + // - from the same protobuf package and file as the service. + // Ideally we'd allow aliases for types from different files in the same + // package, but the plugin contract doesn't allow us to inspect services in + // files other than the ones we're generating code for. + // + // Notably, we're not generating aliases for Connect's stream types: useful + // aliases for them are just as wordy as the generic types, so the extra + // indirection isn't worth it. + const ( + asRequest = 0b01 + asResponse = 0b10 + ) + aliasable := make(map[protoreflect.FullName]uint8) + for _, service := range services { + pkg := service.Desc.ParentFile().Package() + path := service.Desc.ParentFile().Path() + for _, method := range service.Methods { + if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { + continue + } + if method.Input.Desc.ParentFile().Package() == pkg && method.Input.Desc.ParentFile().Path() == path { + aliasable[method.Input.Desc.FullName()] |= asRequest + } + if method.Output.Desc.ParentFile().Package() == pkg && method.Input.Desc.ParentFile().Path() == path { + aliasable[method.Output.Desc.FullName()] |= asResponse + } + } + } + for fqn, usage := range aliasable { + if usage == asRequest&asResponse { + delete(aliasable, fqn) + } + } + // Now, make another pass to choose names. + params := ¶meterNames{Methods: make(map[protoreflect.FullName]*methodParameterNames)} + for _, service := range services { + for _, method := range service.Methods { + isStreamingClient := method.Desc.IsStreamingClient() + isStreamingServer := method.Desc.IsStreamingServer() + methodParams := &methodParameterNames{} + if isStreamingClient && isStreamingServer { + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("BidiStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + } else if isStreamingClient { + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.HandlerOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Response")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + methodParams.HandlerOutput.Alias = method.Output.GoIdent.GoName + } + } else if isStreamingServer { + methodParams.ClientInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ServerStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.HandlerOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ServerStream")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + methodParams.ClientInput.Alias = method.Input.GoIdent.GoName + methodParams.HandlerInput.Alias = methodParams.ClientInput.Alias + } + } else { + methodParams.ClientInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = methodParams.ClientInput.Generic + methodParams.HandlerOutput.Generic = methodParams.ClientOutput.Generic + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + methodParams.ClientInput.Alias = method.Input.GoIdent.GoName + methodParams.HandlerInput.Alias = methodParams.ClientInput.Alias + } + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + methodParams.ClientOutput.Alias = method.Output.GoIdent.GoName + methodParams.HandlerOutput.Alias = methodParams.ClientOutput.Alias + } + } + params.Methods[method.Desc.FullName()] = methodParams + } + } + // Finally, another pass to prepare the actual alias declarations. We need to + // deduplicate (in case the same message is used in multiple RPCs), and we'd + // like the aliases to appear in the same order as they're used in the RPC + // definitions. + for _, service := range services { + for _, method := range service.Methods { + methodParams := params.Get(method) + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + if methodParams.ClientInput.Alias != "" { + params.Aliases = append(params.Aliases, [2]string{methodParams.ClientInput.Alias, methodParams.ClientInput.Generic}) + } + if methodParams.HandlerInput.Alias != "" && methodParams.HandlerInput.Alias != methodParams.ClientInput.Alias { + params.Aliases = append(params.Aliases, [2]string{methodParams.HandlerInput.Alias, methodParams.HandlerInput.Generic}) + } + delete(aliasable, method.Input.Desc.FullName()) + } + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + if methodParams.ClientOutput.Alias != "" { + params.Aliases = append(params.Aliases, [2]string{methodParams.ClientOutput.Alias, methodParams.ClientOutput.Generic}) + } + if methodParams.HandlerOutput.Alias != "" && methodParams.HandlerOutput.Alias != methodParams.ClientOutput.Alias { + params.Aliases = append(params.Aliases, [2]string{methodParams.HandlerOutput.Alias, methodParams.HandlerOutput.Generic}) + } + delete(aliasable, method.Output.Desc.FullName()) + } + } + } + return params +} + +func (pn *parameterNames) Get(method *protogen.Method) *methodParameterNames { + return pn.Methods[method.Desc.FullName()] +} + +type methodParameterNames struct { + ClientInput aliasedTypeName + ClientOutput aliasedTypeName + HandlerInput aliasedTypeName + HandlerOutput aliasedTypeName +} + +type aliasedTypeName struct { + Generic string + Alias string +} + +func (n aliasedTypeName) Name() string { + if n.Alias != "" { + return n.Alias + } + return n.Generic +} diff --git a/connect_ext_test.go b/connect_ext_test.go index 4545071c..d75a2996 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -501,7 +501,7 @@ func TestHeaderBasic(t *testing.T) { ) pingServer := &pluggablePingServer{ - ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { assert.Equal(t, request.Header().Get(key), cval) response := connect.NewResponse(&pingv1.PingResponse{}) response.Header().Set(key, hval) @@ -529,7 +529,7 @@ func TestHeaderHost(t *testing.T) { ) pingServer := &pluggablePingServer{ - ping: func(_ context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { assert.Equal(t, request.Header().Get(key), cval) response := connect.NewResponse(&pingv1.PingResponse{}) return response, nil @@ -583,7 +583,7 @@ func TestTimeoutParsing(t *testing.T) { t.Parallel() const timeout = 10 * time.Minute pingServer := &pluggablePingServer{ - ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { deadline, ok := ctx.Deadline() assert.True(t, ok) remaining := time.Until(deadline) @@ -1597,7 +1597,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.False(t, stream.Conn().Spec().IsClient) @@ -1614,7 +1614,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream-send", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) return nil }, @@ -1631,7 +1631,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream-send-nil", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { stream.ResponseHeader().Set("foo", "bar") stream.ResponseTrailer().Set("bas", "blah") assert.Nil(t, stream.Send(nil)) @@ -1653,7 +1653,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) assert.False(t, stream.Spec().IsClient) @@ -1675,7 +1675,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream-conn", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.NotNil(t, stream.Conn().Send("not-proto")) return connect.NewResponse(&pingv1.SumResponse{}), nil }, @@ -1690,7 +1690,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream-send-msg", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil }, @@ -1711,7 +1711,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { t.Helper() mux := http.NewServeMux() pluggableServer := &pluggablePingServer{ - ping: func(_ context.Context, _ *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, _ *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return nil, connect.NewError(connectCode, errors.New("error")) }, } @@ -1993,7 +1993,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ - ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil @@ -2063,10 +2063,10 @@ func TestHandlerReturnsNilResponse(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ - ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return nil, nil //nolint: nilnil }, - sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) @@ -2353,29 +2353,26 @@ func (c failCodec) Unmarshal(data []byte, message any) error { type pluggablePingServer struct { pingv1connect.UnimplementedPingServiceHandler - ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) - sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) - countUp func(context.Context, *connect.Request[pingv1.CountUpRequest], *connect.ServerStream[pingv1.CountUpResponse]) error + ping func(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) + sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) + countUp func(context.Context, *pingv1connect.CountUpRequest, *connect.ServerStream[pingv1.CountUpResponse]) error cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error } -func (p *pluggablePingServer) Ping( - ctx context.Context, - request *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (p *pluggablePingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return p.ping(ctx, request) } func (p *pluggablePingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], -) (*connect.Response[pingv1.SumResponse], error) { +) (*pingv1connect.SumResponse, error) { return p.sum(ctx, stream) } func (p *pluggablePingServer) CountUp( ctx context.Context, - req *connect.Request[pingv1.CountUpRequest], + req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { return p.countUp(ctx, req, stream) @@ -2431,7 +2428,7 @@ type pingServer struct { checkMetadata bool } -func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2452,7 +2449,7 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi return response, nil } -func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServer) Fail(ctx context.Context, request *pingv1connect.FailRequest) (*pingv1connect.FailResponse, error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2471,7 +2468,7 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], -) (*connect.Response[pingv1.SumResponse], error) { +) (*pingv1connect.SumResponse, error) { if p.checkMetadata { if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { return nil, err @@ -2498,7 +2495,7 @@ func (p pingServer) Sum( func (p pingServer) CountUp( ctx context.Context, - request *connect.Request[pingv1.CountUpRequest], + request *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { if err := expectClientHeader(p.checkMetadata, request); err != nil { diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index 881a4b45..b1f05f84 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -34,10 +34,7 @@ type ExampleCachingPingServer struct { // Ping is idempotent and free of side effects (and the Protobuf schema // indicates this), so clients using the Connect protocol may call it with HTTP // GET requests. This implementation uses Etags to manage client-side caching. -func (*ExampleCachingPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (*ExampleCachingPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { resp := connect.NewResponse(&pingv1.PingResponse{ Number: req.Msg.Number, }) diff --git a/handler_example_test.go b/handler_example_test.go index f5c2c0da..2d0666d6 100644 --- a/handler_example_test.go +++ b/handler_example_test.go @@ -30,10 +30,7 @@ type ExamplePingServer struct { } // Ping implements pingv1connect.PingServiceHandler. -func (*ExamplePingServer) Ping( - _ context.Context, - request *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (*ExamplePingServer) Ping(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.Number, diff --git a/handler_ext_test.go b/handler_ext_test.go index ca71712e..ea3052b0 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -24,7 +24,6 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -213,6 +212,6 @@ type successPingServer struct { pingv1connect.UnimplementedPingServiceHandler } -func (successPingServer) Ping(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - return &connect.Response[pingv1.PingResponse]{}, nil +func (successPingServer) Ping(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { + return &pingv1connect.PingResponse{}, nil } diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index e0ae47ae..483e1c4b 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -51,9 +51,14 @@ const ( CollideServiceImportProcedure = "/connect.collide.v1.CollideService/Import" ) +type ( + ImportRequest = connect.Request[v1.ImportRequest] + ImportResponse = connect.Response[v1.ImportResponse] +) + // CollideServiceClient is a client for the connect.collide.v1.CollideService service. type CollideServiceClient interface { - Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) + Import(context.Context, *ImportRequest) (*ImportResponse, error) } // NewCollideServiceClient constructs a client for the connect.collide.v1.CollideService service. By @@ -80,13 +85,13 @@ type collideServiceClient struct { } // Import calls connect.collide.v1.CollideService.Import. -func (c *collideServiceClient) Import(ctx context.Context, req *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) { +func (c *collideServiceClient) Import(ctx context.Context, req *ImportRequest) (*ImportResponse, error) { return c._import.CallUnary(ctx, req) } // CollideServiceHandler is an implementation of the connect.collide.v1.CollideService service. type CollideServiceHandler interface { - Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) + Import(context.Context, *ImportRequest) (*ImportResponse, error) } // NewCollideServiceHandler builds an HTTP handler from the service implementation. It returns the @@ -113,6 +118,6 @@ func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.Handler // UnimplementedCollideServiceHandler returns CodeUnimplemented from all methods. type UnimplementedCollideServiceHandler struct{} -func (UnimplementedCollideServiceHandler) Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) { +func (UnimplementedCollideServiceHandler) Import(context.Context, *ImportRequest) (*ImportResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.collide.v1.CollideService.Import is not implemented")) } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 7a99236c..5fc78432 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -64,16 +64,25 @@ const ( PingServiceCumSumProcedure = "/connect.ping.v1.PingService/CumSum" ) +type ( + PingRequest = connect.Request[v1.PingRequest] + PingResponse = connect.Response[v1.PingResponse] + FailRequest = connect.Request[v1.FailRequest] + FailResponse = connect.Response[v1.FailResponse] + SumResponse = connect.Response[v1.SumResponse] + CountUpRequest = connect.Request[v1.CountUpRequest] +) + // PingServiceClient is a client for the connect.ping.v1.PingService service. type PingServiceClient interface { // Ping sends a ping to the server to determine if it's reachable. - Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) + Ping(context.Context, *PingRequest) (*PingResponse, error) // Fail always fails. - Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) + Fail(context.Context, *FailRequest) (*FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. Sum(context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] // CountUp returns a stream of the numbers up to the given request. - CountUp(context.Context, *connect.Request[v1.CountUpRequest]) (*connect.ServerStreamForClient[v1.CountUpResponse], error) + CountUp(context.Context, *CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) // CumSum determines the cumulative sum of all the numbers sent on the stream. CumSum(context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] } @@ -127,12 +136,12 @@ type pingServiceClient struct { } // Ping calls connect.ping.v1.PingService.Ping. -func (c *pingServiceClient) Ping(ctx context.Context, req *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) { +func (c *pingServiceClient) Ping(ctx context.Context, req *PingRequest) (*PingResponse, error) { return c.ping.CallUnary(ctx, req) } // Fail calls connect.ping.v1.PingService.Fail. -func (c *pingServiceClient) Fail(ctx context.Context, req *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) { +func (c *pingServiceClient) Fail(ctx context.Context, req *FailRequest) (*FailResponse, error) { return c.fail.CallUnary(ctx, req) } @@ -142,7 +151,7 @@ func (c *pingServiceClient) Sum(ctx context.Context) *connect.ClientStreamForCli } // CountUp calls connect.ping.v1.PingService.CountUp. -func (c *pingServiceClient) CountUp(ctx context.Context, req *connect.Request[v1.CountUpRequest]) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { +func (c *pingServiceClient) CountUp(ctx context.Context, req *CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { return c.countUp.CallServerStream(ctx, req) } @@ -154,13 +163,13 @@ func (c *pingServiceClient) CumSum(ctx context.Context) *connect.BidiStreamForCl // PingServiceHandler is an implementation of the connect.ping.v1.PingService service. type PingServiceHandler interface { // Ping sends a ping to the server to determine if it's reachable. - Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) + Ping(context.Context, *PingRequest) (*PingResponse, error) // Fail always fails. - Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) + Fail(context.Context, *FailRequest) (*FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) + Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*SumResponse, error) // CountUp returns a stream of the numbers up to the given request. - CountUp(context.Context, *connect.Request[v1.CountUpRequest], *connect.ServerStream[v1.CountUpResponse]) error + CountUp(context.Context, *CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error // CumSum determines the cumulative sum of all the numbers sent on the stream. CumSum(context.Context, *connect.BidiStream[v1.CumSumRequest, v1.CumSumResponse]) error } @@ -218,19 +227,19 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption // UnimplementedPingServiceHandler returns CodeUnimplemented from all methods. type UnimplementedPingServiceHandler struct{} -func (UnimplementedPingServiceHandler) Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) { +func (UnimplementedPingServiceHandler) Ping(context.Context, *PingRequest) (*PingResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Ping is not implemented")) } -func (UnimplementedPingServiceHandler) Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) { +func (UnimplementedPingServiceHandler) Fail(context.Context, *FailRequest) (*FailResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } -func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) { +func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*SumResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) } -func (UnimplementedPingServiceHandler) CountUp(context.Context, *connect.Request[v1.CountUpRequest], *connect.ServerStream[v1.CountUpResponse]) error { +func (UnimplementedPingServiceHandler) CountUp(context.Context, *CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error { return connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CountUp is not implemented")) } diff --git a/recover_ext_test.go b/recover_ext_test.go index 99df97c9..54b37319 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -33,16 +33,13 @@ type panicPingServer struct { panicWith any } -func (s *panicPingServer) Ping( - context.Context, - *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (s *panicPingServer) Ping(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { panic(s.panicWith) //nolint:forbidigo } func (s *panicPingServer) CountUp( _ context.Context, - _ *connect.Request[pingv1.CountUpRequest], + _ *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { if err := stream.Send(&pingv1.CountUpResponse{}); err != nil {