Skip to content

Commit

Permalink
Generate aliases for connect.Request/Response
Browse files Browse the repository at this point in the history
Reduce Connect's generics-induced wordiness by generating type aliases
for `connect.Request` and `connect.Response`.

For an actual net reduction in wordiness, we can't generate long
identifiers. That reduces our ability to manage name collisions, so we
only generate aliases for messages that are declared in the same file
and used exclusively as requests or responses (but not both). Notably,
we don't attempt to generate aliases for the stream types - they end up
even wordier than the generic types, and they end up very confusingly
named.
  • Loading branch information
akshayjshah committed Aug 9, 2023
1 parent f281476 commit df304e9
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 119 deletions.
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ type PingServer struct {
pingv1connect.UnimplementedPingServiceHandler // returns errors from all methods
}

func (ps *PingServer) Ping(
ctx context.Context,
req *connect.Request[pingv1.PingRequest],
) (*connect.Response[pingv1.PingResponse], error) {
func (ps *PingServer) Ping(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
// connect.Request and connect.Response give you direct access to headers and
// trailers. No context-based nonsense!
log.Println(req.Header().Get("Some-Header"))
Expand Down
4 changes: 1 addition & 3 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ type notModifiedPingServer struct {
etag string
}

func (s *notModifiedPingServer) Ping(
_ context.Context,
req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
func (s *notModifiedPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == s.etag {
return nil, connect.NewNotModifiedError(http.Header{"Etag": []string{s.etag}})
}
Expand Down
258 changes: 206 additions & 52 deletions cmd/protoc-gen-connect-go/main.go

Large diffs are not rendered by default.

49 changes: 23 additions & 26 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func TestHeaderBasic(t *testing.T) {
)

pingServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
assert.Equal(t, request.Header().Get(key), cval)
response := connect.NewResponse(&pingv1.PingResponse{})
response.Header().Set(key, hval)
Expand Down Expand Up @@ -529,7 +529,7 @@ func TestHeaderHost(t *testing.T) {
)

pingServer := &pluggablePingServer{
ping: func(_ context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
assert.Equal(t, request.Header().Get(key), cval)
response := connect.NewResponse(&pingv1.PingResponse{})
return response, nil
Expand Down Expand Up @@ -583,7 +583,7 @@ func TestTimeoutParsing(t *testing.T) {
t.Parallel()
const timeout = 10 * time.Minute
pingServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
deadline, ok := ctx.Deadline()
assert.True(t, ok)
remaining := time.Until(deadline)
Expand Down Expand Up @@ -1597,7 +1597,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("server-stream", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error {
assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer)
assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure)
assert.False(t, stream.Conn().Spec().IsClient)
Expand All @@ -1614,7 +1614,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("server-stream-send", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error {
assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1}))
return nil
},
Expand All @@ -1631,7 +1631,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("server-stream-send-nil", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error {
stream.ResponseHeader().Set("foo", "bar")
stream.ResponseTrailer().Set("bas", "blah")
assert.Nil(t, stream.Send(nil))
Expand All @@ -1653,7 +1653,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("client-stream", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) {
assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient)
assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure)
assert.False(t, stream.Spec().IsClient)
Expand All @@ -1675,7 +1675,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("client-stream-conn", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) {
assert.NotNil(t, stream.Conn().Send("not-proto"))
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
Expand All @@ -1690,7 +1690,7 @@ func TestStreamForServer(t *testing.T) {
t.Run("client-stream-send-msg", func(t *testing.T) {
t.Parallel()
client, server := newPingServer(&pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) {
assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2}))
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
Expand All @@ -1711,7 +1711,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) {
t.Helper()
mux := http.NewServeMux()
pluggableServer := &pluggablePingServer{
ping: func(_ context.Context, _ *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(_ context.Context, _ *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
return nil, connect.NewError(connectCode, errors.New("error"))
},
}
Expand Down Expand Up @@ -1993,7 +1993,7 @@ func TestAllowCustomUserAgent(t *testing.T) {
const customAgent = "custom"
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{
ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
agent := req.Header().Get("User-Agent")
assert.Equal(t, agent, customAgent)
return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil
Expand Down Expand Up @@ -2063,10 +2063,10 @@ func TestHandlerReturnsNilResponse(t *testing.T) {

mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{
ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
ping: func(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
return nil, nil //nolint: nilnil
},
sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) {
return nil, nil //nolint: nilnil
},
}, connect.WithRecover(recoverPanic)))
Expand Down Expand Up @@ -2353,29 +2353,26 @@ func (c failCodec) Unmarshal(data []byte, message any) error {
type pluggablePingServer struct {
pingv1connect.UnimplementedPingServiceHandler

ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error)
sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error)
countUp func(context.Context, *connect.Request[pingv1.CountUpRequest], *connect.ServerStream[pingv1.CountUpResponse]) error
ping func(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error)
sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error)
countUp func(context.Context, *pingv1connect.CountUpRequest, *connect.ServerStream[pingv1.CountUpResponse]) error
cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error
}

func (p *pluggablePingServer) Ping(
ctx context.Context,
request *connect.Request[pingv1.PingRequest],
) (*connect.Response[pingv1.PingResponse], error) {
func (p *pluggablePingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
return p.ping(ctx, request)
}

func (p *pluggablePingServer) Sum(
ctx context.Context,
stream *connect.ClientStream[pingv1.SumRequest],
) (*connect.Response[pingv1.SumResponse], error) {
) (*pingv1connect.SumResponse, error) {
return p.sum(ctx, stream)
}

func (p *pluggablePingServer) CountUp(
ctx context.Context,
req *connect.Request[pingv1.CountUpRequest],
req *pingv1connect.CountUpRequest,
stream *connect.ServerStream[pingv1.CountUpResponse],
) error {
return p.countUp(ctx, req, stream)
Expand Down Expand Up @@ -2431,7 +2428,7 @@ type pingServer struct {
checkMetadata bool
}

func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
func (p pingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
Expand All @@ -2452,7 +2449,7 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi
return response, nil
}

func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) {
func (p pingServer) Fail(ctx context.Context, request *pingv1connect.FailRequest) (*pingv1connect.FailResponse, error) {
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
Expand All @@ -2471,7 +2468,7 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa
func (p pingServer) Sum(
ctx context.Context,
stream *connect.ClientStream[pingv1.SumRequest],
) (*connect.Response[pingv1.SumResponse], error) {
) (*pingv1connect.SumResponse, error) {
if p.checkMetadata {
if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil {
return nil, err
Expand All @@ -2498,7 +2495,7 @@ func (p pingServer) Sum(

func (p pingServer) CountUp(
ctx context.Context,
request *connect.Request[pingv1.CountUpRequest],
request *pingv1connect.CountUpRequest,
stream *connect.ServerStream[pingv1.CountUpResponse],
) error {
if err := expectClientHeader(p.checkMetadata, request); err != nil {
Expand Down
5 changes: 1 addition & 4 deletions error_not_modified_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ type ExampleCachingPingServer struct {
// Ping is idempotent and free of side effects (and the Protobuf schema
// indicates this), so clients using the Connect protocol may call it with HTTP
// GET requests. This implementation uses Etags to manage client-side caching.
func (*ExampleCachingPingServer) Ping(
_ context.Context,
req *connect.Request[pingv1.PingRequest],
) (*connect.Response[pingv1.PingResponse], error) {
func (*ExampleCachingPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
resp := connect.NewResponse(&pingv1.PingResponse{
Number: req.Msg.Number,
})
Expand Down
5 changes: 1 addition & 4 deletions handler_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ type ExamplePingServer struct {
}

// Ping implements pingv1connect.PingServiceHandler.
func (*ExamplePingServer) Ping(
_ context.Context,
request *connect.Request[pingv1.PingRequest],
) (*connect.Response[pingv1.PingResponse], error) {
func (*ExamplePingServer) Ping(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
return connect.NewResponse(
&pingv1.PingResponse{
Number: request.Msg.Number,
Expand Down
5 changes: 2 additions & 3 deletions handler_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (

connect "connectrpc.com/connect"
"connectrpc.com/connect/internal/assert"
pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1"
"connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect"
)

Expand Down Expand Up @@ -213,6 +212,6 @@ type successPingServer struct {
pingv1connect.UnimplementedPingServiceHandler
}

func (successPingServer) Ping(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
return &connect.Response[pingv1.PingResponse]{}, nil
func (successPingServer) Ping(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) {
return &pingv1connect.PingResponse{}, nil
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit df304e9

Please sign in to comment.