Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace authn.Request for *http.Request #9

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 9 additions & 49 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package authn

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
Expand All @@ -38,7 +37,7 @@ const infoKey key = iota
// the information is automatically attached to the context using [SetInfo].
//
// Implementations must be safe to call concurrently.
type AuthFunc func(ctx context.Context, req Request) (any, error)
type AuthFunc func(ctx context.Context, req *http.Request) (any, error)

// SetInfo attaches authentication information to the context. It's often
// useful in tests.
Expand Down Expand Up @@ -71,58 +70,30 @@ func Errorf(template string, args ...any) *connect.Error {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...))
}

// Request describes a single RPC invocation.
type Request struct {
request *http.Request
}

// BasicAuth returns the username and password provided in the request's
// Authorization header, if any.
func (r Request) BasicAuth() (username string, password string, ok bool) {
return r.request.BasicAuth()
}

// Cookies parses and returns the HTTP cookies sent with the request, if any.
func (r Request) Cookies() []*http.Cookie {
return r.request.Cookies()
}

// Cookie returns the named cookie provided in the request or
// [http.ErrNoCookie] if not found. If multiple cookies match the given name,
// only one cookie will be returned.
func (r Request) Cookie(name string) (*http.Cookie, error) {
return r.request.Cookie(name)
}

// Procedure returns the RPC procedure name, in the form "/service/method". If
// the request path does not contain a procedure name, the entire path is
// returned.
func (r Request) Procedure() string {
path := strings.TrimSuffix(r.request.URL.Path, "/")
func Procedure(request *http.Request) string {
path := strings.TrimSuffix(request.URL.Path, "/")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this? Do connect RPC servers actually accept an invalid trailing slash like this? Pretty sure gRPC servers are usually strict and do not allow this.

ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return r.request.URL.Path
return request.URL.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return r.request.URL.Path
return request.URL.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return r.request.URL.Path
return request.URL.Path
}
return procedure
}

// ClientAddr returns the client address, in IP:port format.
func (r Request) ClientAddr() string {
return r.request.RemoteAddr
}

// Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect],
// [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func (r Request) Protocol() string {
ct := r.request.Header.Get("Content-Type")
func Protocol(request *http.Request) string {
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
ct := request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
Expand All @@ -133,17 +104,6 @@ func (r Request) Protocol() string {
}
}

// Header returns the HTTP request headers.
func (r Request) Header() http.Header {
return r.request.Header
}

// TLS returns the TLS connection state, if any. It may be nil if the connection
// is not using TLS.
func (r Request) TLS() *tls.ConnectionState {
return r.request.TLS
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
Expand Down Expand Up @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware {
func (m *Middleware) Wrap(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
info, err := m.auth(ctx, Request{request: request})
info, err := m.auth(ctx, request)
if err != nil {
_ = m.errW.Write(writer, request, err)
return
Expand Down
4 changes: 2 additions & 2 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func assertInfo(ctx context.Context, tb testing.TB) {
}
}

func authenticate(_ context.Context, req authn.Request) (any, error) {
parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2)
func authenticate(_ context.Context, req *http.Request) (any, error) {
parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
if len(parts) < 2 || parts[0] != "Bearer" {
err := authn.Errorf("expected Bearer authentication scheme")
err.Meta().Set("WWW-Authenticate", "Bearer")
Expand Down
6 changes: 3 additions & 3 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Example_basicAuth() {
// works similarly.

// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
authenticate := func(_ context.Context, req *http.Request) (any, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, authn.Errorf("invalid authorization")
Expand Down Expand Up @@ -95,8 +95,8 @@ func Example_basicAuth() {
func Example_mutualTLS() {
// This example shows how to use this package with mutual TLS.
// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
tls := req.TLS()
authenticate := func(_ context.Context, req *http.Request) (any, error) {
tls := req.TLS
if tls == nil {
return nil, authn.Errorf("TLS required")
}
Expand Down