From eb79a331b177839868d86aa27a5afe0bdb15828a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 14 Jun 2024 10:19:49 -0400 Subject: [PATCH] Drop authn.Request for *http.Request --- authn.go | 58 ++++++++---------------------------------------- authn_test.go | 4 ++-- examples_test.go | 6 ++--- 3 files changed, 14 insertions(+), 54 deletions(-) diff --git a/authn.go b/authn.go index 0b25476..12d5745 100644 --- a/authn.go +++ b/authn.go @@ -17,7 +17,6 @@ package authn import ( "context" - "crypto/tls" "fmt" "net/http" "strings" @@ -38,7 +37,7 @@ const infoKey key = iota // the information is automatically attached to the context using [SetInfo]. // // Implementations must be safe to call concurrently. -type AuthFunc func(ctx context.Context, req Request) (any, error) +type AuthFunc func(ctx context.Context, req *http.Request) (any, error) // SetInfo attaches authentication information to the context. It's often // useful in tests. @@ -71,58 +70,30 @@ func Errorf(template string, args ...any) *connect.Error { return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...)) } -// Request describes a single RPC invocation. -type Request struct { - Request *http.Request -} - -// BasicAuth returns the username and password provided in the request's -// Authorization header, if any. -func (r Request) BasicAuth() (username string, password string, ok bool) { - return r.Request.BasicAuth() -} - -// Cookies parses and returns the HTTP cookies sent with the request, if any. -func (r Request) Cookies() []*http.Cookie { - return r.Request.Cookies() -} - -// Cookie returns the named cookie provided in the request or -// [http.ErrNoCookie] if not found. If multiple cookies match the given name, -// only one cookie will be returned. -func (r Request) Cookie(name string) (*http.Cookie, error) { - return r.Request.Cookie(name) -} - // Procedure returns the RPC procedure name, in the form "/service/method". If // the request path does not contain a procedure name, the entire path is // returned. -func (r Request) Procedure() string { - path := strings.TrimSuffix(r.Request.URL.Path, "/") +func Procedure(request *http.Request) string { + path := strings.TrimSuffix(request.URL.Path, "/") ultimate := strings.LastIndex(path, "/") if ultimate < 0 { - return r.Request.URL.Path + return request.URL.Path } penultimate := strings.LastIndex(path[:ultimate], "/") if penultimate < 0 { - return r.Request.URL.Path + return request.URL.Path } procedure := path[penultimate:] if len(procedure) < 4 { // two slashes + service + method - return r.Request.URL.Path + return request.URL.Path } return procedure } -// ClientAddr returns the client address, in IP:port format. -func (r Request) ClientAddr() string { - return r.Request.RemoteAddr -} - // Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect], // [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb]. -func (r Request) Protocol() string { - ct := r.Request.Header.Get("Content-Type") +func Protocol(request *http.Request) string { + ct := request.Header.Get("Content-Type") switch { case strings.HasPrefix(ct, "application/grpc-web"): return connect.ProtocolGRPCWeb @@ -133,17 +104,6 @@ func (r Request) Protocol() string { } } -// Header returns the HTTP request headers. -func (r Request) Header() http.Header { - return r.Request.Header -} - -// TLS returns the TLS connection state, if any. It may be nil if the connection -// is not using TLS. -func (r Request) TLS() *tls.ConnectionState { - return r.Request.TLS -} - // Middleware is server-side HTTP middleware that authenticates RPC requests. // In addition to rejecting unauthenticated requests, it can optionally attach // arbitrary information about the authenticated identity to the context. @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware { func (m *Middleware) Wrap(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx := request.Context() - info, err := m.auth(ctx, Request{Request: request}) + info, err := m.auth(ctx, request) if err != nil { _ = m.errW.Write(writer, request, err) return diff --git a/authn_test.go b/authn_test.go index 5a22589..ab12d00 100644 --- a/authn_test.go +++ b/authn_test.go @@ -93,8 +93,8 @@ func assertInfo(ctx context.Context, tb testing.TB) { } } -func authenticate(_ context.Context, req authn.Request) (any, error) { - parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2) +func authenticate(_ context.Context, req *http.Request) (any, error) { + parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2) if len(parts) < 2 || parts[0] != "Bearer" { err := authn.Errorf("expected Bearer authentication scheme") err.Meta().Set("WWW-Authenticate", "Bearer") diff --git a/examples_test.go b/examples_test.go index a25e140..d068d04 100644 --- a/examples_test.go +++ b/examples_test.go @@ -44,7 +44,7 @@ func Example_basicAuth() { // works similarly. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { + authenticate := func(_ context.Context, req *http.Request) (any, error) { username, password, ok := req.BasicAuth() if !ok { return nil, authn.Errorf("invalid authorization") @@ -95,8 +95,8 @@ func Example_basicAuth() { func Example_mutualTLS() { // This example shows how to use this package with mutual TLS. // First, we define our authentication logic and use it to build middleware. - authenticate := func(_ context.Context, req authn.Request) (any, error) { - tls := req.TLS() + authenticate := func(_ context.Context, req *http.Request) (any, error) { + tls := req.TLS if tls == nil { return nil, authn.Errorf("TLS required") }