From b3a3e19b651e8cf0ee6822100346c4b0f8491bd8 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 1 Feb 2024 19:12:06 -0500 Subject: [PATCH 1/4] Remove authn.Request Replaces authn.Request with *http.Request for better interop. Signed-off-by: Edward McFarlane --- authn.go | 58 ++++++++---------------------------------------- authn_test.go | 4 ++-- examples_test.go | 6 ++--- 3 files changed, 14 insertions(+), 54 deletions(-) diff --git a/authn.go b/authn.go index 5adfab4..7e9976e 100644 --- a/authn.go +++ b/authn.go @@ -17,7 +17,6 @@ package authn import ( "context" - "crypto/tls" "fmt" "net/http" "strings" @@ -38,7 +37,7 @@ const infoKey key = iota // the information is automatically attached to the context using [SetInfo]. // // Implementations must be safe to call concurrently. -type AuthFunc func(ctx context.Context, req Request) (any, error) +type AuthFunc func(ctx context.Context, req *http.Request) (any, error) // SetInfo attaches authentication information to the context. It's often // useful in tests. @@ -71,58 +70,30 @@ func Errorf(template string, args ...any) *connect.Error { return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...)) } -// Request describes a single RPC invocation. -type Request struct { - request *http.Request -} - -// BasicAuth returns the username and password provided in the request's -// Authorization header, if any. -func (r Request) BasicAuth() (username string, password string, ok bool) { - return r.request.BasicAuth() -} - -// Cookies parses and returns the HTTP cookies sent with the request, if any. -func (r Request) Cookies() []*http.Cookie { - return r.request.Cookies() -} - -// Cookie returns the named cookie provided in the request or -// [http.ErrNoCookie] if not found. If multiple cookies match the given name, -// only one cookie will be returned. -func (r Request) Cookie(name string) (*http.Cookie, error) { - return r.request.Cookie(name) -} - // Procedure returns the RPC procedure name, in the form "/service/method". If // the request path does not contain a procedure name, the entire path is // returned. -func (r Request) Procedure() string { - path := strings.TrimSuffix(r.request.URL.Path, "/") +func Procedure(request *http.Request) string { + path := strings.TrimSuffix(request.URL.Path, "/") ultimate := strings.LastIndex(path, "/") if ultimate < 0 { - return r.request.URL.Path + return request.URL.Path } penultimate := strings.LastIndex(path[:ultimate], "/") if penultimate < 0 { - return r.request.URL.Path + return request.URL.Path } procedure := path[penultimate:] if len(procedure) < 4 { // two slashes + service + method - return r.request.URL.Path + return request.URL.Path } return procedure } -// ClientAddr returns the client address, in IP:port format. -func (r Request) ClientAddr() string { - return r.request.RemoteAddr -} - // Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect], // [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb]. -func (r Request) Protocol() string { - ct := r.request.Header.Get("Content-Type") +func Protocol(request *http.Request) string { + ct := request.Header.Get("Content-Type") switch { case strings.HasPrefix(ct, "application/grpc-web"): return connect.ProtocolGRPCWeb @@ -133,17 +104,6 @@ func (r Request) Protocol() string { } } -// Header returns the HTTP request headers. -func (r Request) Header() http.Header { - return r.request.Header -} - -// TLS returns the TLS connection state, if any. It may be nil if the connection -// is not using TLS. -func (r Request) TLS() *tls.ConnectionState { - return r.request.TLS -} - // Middleware is server-side HTTP middleware that authenticates RPC requests. // In addition to rejecting unauthenticated requests, it can optionally attach // arbitrary information about the authenticated identity to the context. @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware { func (m *Middleware) Wrap(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx := request.Context() - info, err := m.auth(ctx, Request{request: request}) + info, err := m.auth(ctx, request) if err != nil { _ = m.errW.Write(writer, request, err) return diff --git a/authn_test.go b/authn_test.go index 0c428b7..86ed4a9 100644 --- a/authn_test.go +++ b/authn_test.go @@ -93,8 +93,8 @@ func assertInfo(ctx context.Context, tb testing.TB) { } } -func authenticate(_ context.Context, req authn.Request) (any, error) { - parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2) +func authenticate(_ context.Context, req *http.Request) (any, error) { + parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2) if len(parts) < 2 || parts[0] != "Bearer" { err := authn.Errorf("expected Bearer authentication scheme") err.Meta().Set("WWW-Authenticate", "Bearer") diff --git a/examples_test.go b/examples_test.go index 2c0365c..2044c1b 100644 --- a/examples_test.go +++ b/examples_test.go @@ -44,7 +44,7 @@ func Example_basicAuth() { // works similarly. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { + authenticate := func(_ context.Context, req *http.Request) (any, error) { username, password, ok := req.BasicAuth() if !ok { return nil, authn.Errorf("invalid authorization") @@ -95,8 +95,8 @@ func Example_basicAuth() { func Example_mutualTLS() { // This example shows how to use this package with mutual TLS. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { - tls := req.TLS() + authenticate := func(_ context.Context, req *http.Request) (any, error) { + tls := req.TLS if tls == nil { return nil, authn.Errorf("TLS required") } From 70132910e0643396996bca30c21fb074d564252a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 14 Jun 2024 11:46:20 -0400 Subject: [PATCH 2/4] Update README example Signed-off-by: Edward McFarlane --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bf45960..a757496 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ import ( "connectrpc.com/authn/internal/gen/authn/ping/v1/pingv1connect" ) -func authenticate(_ context.Context, req authn.Request) (any, error) { +func authenticate(_ context.Context, req *http.Request) (any, error) { username, password, ok := req.BasicAuth() if !ok { return nil, authn.Errorf("invalid authorization") From 478629fa19287cab2f4b0591060c0a9bc2c74a73 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 25 Jun 2024 14:20:16 -0400 Subject: [PATCH 3/4] Remove protocol type and procedure funcs Signed-off-by: Edward McFarlane --- authn.go | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/authn.go b/authn.go index 7e9976e..f3ae5c3 100644 --- a/authn.go +++ b/authn.go @@ -19,7 +19,6 @@ import ( "context" "fmt" "net/http" - "strings" "connectrpc.com/connect" ) @@ -70,40 +69,6 @@ func Errorf(template string, args ...any) *connect.Error { return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...)) } -// Procedure returns the RPC procedure name, in the form "/service/method". If -// the request path does not contain a procedure name, the entire path is -// returned. -func Procedure(request *http.Request) string { - path := strings.TrimSuffix(request.URL.Path, "/") - ultimate := strings.LastIndex(path, "/") - if ultimate < 0 { - return request.URL.Path - } - penultimate := strings.LastIndex(path[:ultimate], "/") - if penultimate < 0 { - return request.URL.Path - } - procedure := path[penultimate:] - if len(procedure) < 4 { // two slashes + service + method - return request.URL.Path - } - return procedure -} - -// Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect], -// [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb]. -func Protocol(request *http.Request) string { - ct := request.Header.Get("Content-Type") - switch { - case strings.HasPrefix(ct, "application/grpc-web"): - return connect.ProtocolGRPCWeb - case strings.HasPrefix(ct, "application/grpc"): - return connect.ProtocolGRPC - default: - return connect.ProtocolConnect - } -} - // Middleware is server-side HTTP middleware that authenticates RPC requests. // In addition to rejecting unauthenticated requests, it can optionally attach // arbitrary information about the authenticated identity to the context. From a8cdbbd4cad148e4a74f74f92b87f0c0268cb39d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 3 Sep 2024 18:35:05 -0400 Subject: [PATCH 4/4] Add inferred methods Signed-off-by: Edward McFarlane --- authn.go | 35 +++++++++++++++++++++++++++++++ authn_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/authn.go b/authn.go index f3ae5c3..1ee1071 100644 --- a/authn.go +++ b/authn.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "net/http" + "strings" "connectrpc.com/connect" ) @@ -69,6 +70,40 @@ func Errorf(template string, args ...any) *connect.Error { return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...)) } +// InferProtocol returns the inferred RPC protocol. It is one of +// [connect.ProtocolConnect], [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb]. +func InferProtocol(request *http.Request) string { + ct := request.Header.Get("Content-Type") + switch { + case strings.HasPrefix(ct, "application/grpc-web"): + return connect.ProtocolGRPCWeb + case strings.HasPrefix(ct, "application/grpc"): + return connect.ProtocolGRPC + default: + return connect.ProtocolConnect + } +} + +// InferProcedure returns the inferred RPC procedure. It is of the form +// "/service/method". If the request path does not contain a procedure name, the +// entire path is returned. +func InferProcedure(request *http.Request) string { + path := strings.TrimSuffix(request.URL.Path, "/") + ultimate := strings.LastIndex(path, "/") + if ultimate < 0 { + return request.URL.Path + } + penultimate := strings.LastIndex(path[:ultimate], "/") + if penultimate < 0 { + return request.URL.Path + } + procedure := path[penultimate:] + if len(procedure) < 4 { // two slashes + service + method + return request.URL.Path + } + return procedure +} + // Middleware is server-side HTTP middleware that authenticates RPC requests. // In addition to rejecting unauthenticated requests, it can optionally attach // arbitrary information about the authenticated identity to the context. diff --git a/authn_test.go b/authn_test.go index 86ed4a9..1ba4842 100644 --- a/authn_test.go +++ b/authn_test.go @@ -23,6 +23,7 @@ import ( "testing" "connectrpc.com/authn" + "connectrpc.com/connect" "github.com/stretchr/testify/assert" ) @@ -105,3 +106,59 @@ func authenticate(_ context.Context, req *http.Request) (any, error) { } return hero, nil } + +func TestInferProcedures(t *testing.T) { + t.Parallel() + testProcedures := [][2]string{ + {"/empty.v1/GetEmpty", "/empty.v1/GetEmpty"}, + {"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"}, + {"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"}, + {"/prefix/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"}, + {"/", "/"}, + {"/invalid/", "/invalid/"}, + } + for _, tt := range testProcedures { + req := httptest.NewRequest(http.MethodPost, tt[0], strings.NewReader("{}")) + assert.Equal(t, tt[1], authn.InferProcedure(req)) + } +} + +func TestInferProtocol(t *testing.T) { + t.Parallel() + tests := []struct { + name string + contentType string + method string + wantProtocol string + }{{ + name: "connect", + contentType: "application/json", + wantProtocol: connect.ProtocolConnect, + }, { + name: "connectSubPath", + contentType: "application/connect+json", + wantProtocol: connect.ProtocolConnect, + }, { + name: "grpc", + contentType: "application/grpc+proto", + wantProtocol: connect.ProtocolGRPC, + }, { + name: "grpcWeb", + contentType: "application/grpc-web", + wantProtocol: connect.ProtocolGRPCWeb, + }, { + name: "grpcWeb", + contentType: "application/grpc-web+json", + wantProtocol: connect.ProtocolGRPCWeb, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/service/Method", strings.NewReader("{}")) + if tt.contentType != "" { + req.Header.Set("Content-Type", tt.contentType) + } + assert.Equal(t, tt.wantProtocol, authn.InferProtocol(req)) + }) + } +}