Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace authn.Request for *http.Request #9

Merged
merged 11 commits into from
Sep 30, 2024
35 changes: 35 additions & 0 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"fmt"
"net/http"
"strings"

"connectrpc.com/connect"
)
Expand Down Expand Up @@ -69,6 +70,40 @@ func Errorf(template string, args ...any) *connect.Error {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...))
}

// InferProtocol returns the inferred RPC protocol. It is one of
// [connect.ProtocolConnect], [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func InferProtocol(request *http.Request) string {
ct := request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
case strings.HasPrefix(ct, "application/grpc"):
return connect.ProtocolGRPC
default:
return connect.ProtocolConnect
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no value in returning "unknown" (or empty string, etc) when the request doesn't look like any of these? Since this is middleware, it seems highly likely it could be used with a mux that has both connect and non-connect routes, so I think we do need better classification here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now return "", false.

}
}

// InferProcedure returns the inferred RPC procedure. It is of the form
// "/service/method". If the request path does not contain a procedure name, the
// entire path is returned.
func InferProcedure(request *http.Request) string {
path := strings.TrimSuffix(request.URL.Path, "/")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this? Do connect RPC servers actually accept an invalid trailing slash like this? Pretty sure gRPC servers are usually strict and do not allow this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed allowing trailing suffix.

ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return request.URL.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return request.URL.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return request.URL.Path
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same remark as in connect-go PR. This would incorrectly allow "//foo".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a testcase.

return procedure
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
Expand Down
57 changes: 57 additions & 0 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"testing"

"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -105,3 +106,59 @@
}
return hero, nil
}

func TestInferProcedures(t *testing.T) {
t.Parallel()
testProcedures := [][2]string{
{"/empty.v1/GetEmpty", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/prefix/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/", "/"},
{"/invalid/", "/invalid/"},
}
for _, tt := range testProcedures {
req := httptest.NewRequest(http.MethodPost, tt[0], strings.NewReader("{}"))
assert.Equal(t, tt[1], authn.InferProcedure(req))
}
}

func TestInferProtocol(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
method string
wantProtocol string
}{{
name: "connect",
contentType: "application/json",
wantProtocol: connect.ProtocolConnect,
}, {
name: "connectSubPath",
contentType: "application/connect+json",
wantProtocol: connect.ProtocolConnect,
}, {
name: "grpc",
contentType: "application/grpc+proto",
wantProtocol: connect.ProtocolGRPC,
}, {
name: "grpcWeb",
contentType: "application/grpc-web",
wantProtocol: connect.ProtocolGRPCWeb,
}, {
name: "grpcWeb",
contentType: "application/grpc-web+json",
wantProtocol: connect.ProtocolGRPCWeb,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/service/Method", strings.NewReader("{}"))
if tt.contentType != "" {

Check failure on line 158 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
req.Header.Set("Content-Type", tt.contentType)

Check failure on line 159 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
}
assert.Equal(t, tt.wantProtocol, authn.InferProtocol(req))

Check failure on line 161 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
})
}
}
Loading