From b3a3e19b651e8cf0ee6822100346c4b0f8491bd8 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 1 Feb 2024 19:12:06 -0500 Subject: [PATCH] Remove authn.Request Replaces authn.Request with *http.Request for better interop. Signed-off-by: Edward McFarlane --- authn.go | 58 ++++++++---------------------------------------- authn_test.go | 4 ++-- examples_test.go | 6 ++--- 3 files changed, 14 insertions(+), 54 deletions(-) diff --git a/authn.go b/authn.go index 5adfab4..7e9976e 100644 --- a/authn.go +++ b/authn.go @@ -17,7 +17,6 @@ package authn import ( "context" - "crypto/tls" "fmt" "net/http" "strings" @@ -38,7 +37,7 @@ const infoKey key = iota // the information is automatically attached to the context using [SetInfo]. // // Implementations must be safe to call concurrently. -type AuthFunc func(ctx context.Context, req Request) (any, error) +type AuthFunc func(ctx context.Context, req *http.Request) (any, error) // SetInfo attaches authentication information to the context. It's often // useful in tests. @@ -71,58 +70,30 @@ func Errorf(template string, args ...any) *connect.Error { return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...)) } -// Request describes a single RPC invocation. -type Request struct { - request *http.Request -} - -// BasicAuth returns the username and password provided in the request's -// Authorization header, if any. -func (r Request) BasicAuth() (username string, password string, ok bool) { - return r.request.BasicAuth() -} - -// Cookies parses and returns the HTTP cookies sent with the request, if any. -func (r Request) Cookies() []*http.Cookie { - return r.request.Cookies() -} - -// Cookie returns the named cookie provided in the request or -// [http.ErrNoCookie] if not found. If multiple cookies match the given name, -// only one cookie will be returned. -func (r Request) Cookie(name string) (*http.Cookie, error) { - return r.request.Cookie(name) -} - // Procedure returns the RPC procedure name, in the form "/service/method". If // the request path does not contain a procedure name, the entire path is // returned. -func (r Request) Procedure() string { - path := strings.TrimSuffix(r.request.URL.Path, "/") +func Procedure(request *http.Request) string { + path := strings.TrimSuffix(request.URL.Path, "/") ultimate := strings.LastIndex(path, "/") if ultimate < 0 { - return r.request.URL.Path + return request.URL.Path } penultimate := strings.LastIndex(path[:ultimate], "/") if penultimate < 0 { - return r.request.URL.Path + return request.URL.Path } procedure := path[penultimate:] if len(procedure) < 4 { // two slashes + service + method - return r.request.URL.Path + return request.URL.Path } return procedure } -// ClientAddr returns the client address, in IP:port format. -func (r Request) ClientAddr() string { - return r.request.RemoteAddr -} - // Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect], // [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb]. -func (r Request) Protocol() string { - ct := r.request.Header.Get("Content-Type") +func Protocol(request *http.Request) string { + ct := request.Header.Get("Content-Type") switch { case strings.HasPrefix(ct, "application/grpc-web"): return connect.ProtocolGRPCWeb @@ -133,17 +104,6 @@ func (r Request) Protocol() string { } } -// Header returns the HTTP request headers. -func (r Request) Header() http.Header { - return r.request.Header -} - -// TLS returns the TLS connection state, if any. It may be nil if the connection -// is not using TLS. -func (r Request) TLS() *tls.ConnectionState { - return r.request.TLS -} - // Middleware is server-side HTTP middleware that authenticates RPC requests. // In addition to rejecting unauthenticated requests, it can optionally attach // arbitrary information about the authenticated identity to the context. @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware { func (m *Middleware) Wrap(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx := request.Context() - info, err := m.auth(ctx, Request{request: request}) + info, err := m.auth(ctx, request) if err != nil { _ = m.errW.Write(writer, request, err) return diff --git a/authn_test.go b/authn_test.go index 0c428b7..86ed4a9 100644 --- a/authn_test.go +++ b/authn_test.go @@ -93,8 +93,8 @@ func assertInfo(ctx context.Context, tb testing.TB) { } } -func authenticate(_ context.Context, req authn.Request) (any, error) { - parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2) +func authenticate(_ context.Context, req *http.Request) (any, error) { + parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2) if len(parts) < 2 || parts[0] != "Bearer" { err := authn.Errorf("expected Bearer authentication scheme") err.Meta().Set("WWW-Authenticate", "Bearer") diff --git a/examples_test.go b/examples_test.go index 2c0365c..2044c1b 100644 --- a/examples_test.go +++ b/examples_test.go @@ -44,7 +44,7 @@ func Example_basicAuth() { // works similarly. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { + authenticate := func(_ context.Context, req *http.Request) (any, error) { username, password, ok := req.BasicAuth() if !ok { return nil, authn.Errorf("invalid authorization") @@ -95,8 +95,8 @@ func Example_basicAuth() { func Example_mutualTLS() { // This example shows how to use this package with mutual TLS. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { - tls := req.TLS() + authenticate := func(_ context.Context, req *http.Request) (any, error) { + tls := req.TLS if tls == nil { return nil, authn.Errorf("TLS required") }