Skip to content

Commit

Permalink
Add BearerToken with allowlist example (#12)
Browse files Browse the repository at this point in the history
Creates a new example based on bearer tokens, a common authentication
requirement, with a new helper method BearerToken to assist. This is
used to provide the example. A new method Echo is added to the example
ping service to showcase conditionally applying authentication based on
the Procedure name.

Signed-off-by: Edward McFarlane <[email protected]>
  • Loading branch information
emcfarlane authored Oct 2, 2024
1 parent 135799f commit 0d60722
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 55 deletions.
5 changes: 4 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,7 @@ issues:
exclude-rules:
# We need to init a global in-mem HTTP server for testable examples.
- path: examples_test.go
linters: [gocritic, gochecknoglobals, gosec, exhaustruct]
linters: [gocritic, gochecknoglobals, gosec]
# Allow more lenient rules in example code for brevity.
- path: examples_test.go
linters: [exhaustruct, nilnil, varnamelen]
12 changes: 12 additions & 0 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ func InferProcedure(url *url.URL) (string, bool) {
return procedure, true
}

// BearerToken returns the bearer token provided in the request's Authorization
// header, if any.
func BearerToken(request *http.Request) (string, bool) {
const prefix = "Bearer "
auth := request.Header.Get("Authorization")
// Case insensitive prefix match. See RFC 9110 Section 11.1.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", false
}
return auth[len(prefix):], true
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
Expand Down
17 changes: 13 additions & 4 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ func assertInfo(ctx context.Context, tb testing.TB) {
}

func authenticate(_ context.Context, req *http.Request) (any, error) {
parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
if len(parts) < 2 || parts[0] != "Bearer" {
token, ok := authn.BearerToken(req)
if !ok {
err := authn.Errorf("expected Bearer authentication scheme")
err.Meta().Set("WWW-Authenticate", "Bearer")
return nil, err
}
if tok := parts[1]; tok != passphrase {
return nil, authn.Errorf("%q is not the magic passphrase", tok)
if token != passphrase {
return nil, authn.Errorf("%q is not the magic passphrase", token)
}
return hero, nil
}
Expand Down Expand Up @@ -256,3 +256,12 @@ func TestInferProtocol(t *testing.T) {
})
}
}

func TestBearerTokenCaseInsensitive(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/service/Method", nil)
req.Header.Set("Authorization", "bearer "+passphrase)
token, ok := authn.BearerToken(req)
assert.True(t, ok)
assert.Equal(t, passphrase, token)
}
130 changes: 123 additions & 7 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
Expand Down Expand Up @@ -62,7 +63,7 @@ func Example_basicAuth() {

// Next, we build our Connect handler.
mux := http.NewServeMux()
service := &pingv1connect.UnimplementedPingServiceHandler{}
service := &pingHandler{}
mux.Handle(pingv1connect.NewPingServiceHandler(service))

// Finally, we wrap the handler with our middleware and start our server.
Expand All @@ -78,20 +79,108 @@ func Example_basicAuth() {
"Basic "+base64.StdEncoding.EncodeToString([]byte("Aladdin:open-sesame")),
)
_, err := client.Ping(context.Background(), req)

// We're using the UnimplementedPingServiceHandler stub, so authenticated
// clients should receive an error with CodeUnimplemented.
if connect.CodeOf(err) == connect.CodeUnimplemented {
fmt.Println("client received response")
} else {
if err != nil {
fmt.Printf("unexpected error: %v\n", err)
return
}
fmt.Println("client received response")

// Output:
// authenticated request from Aladdin
// client received response
}

func Example_bearerToken() {
// This example shows how to use this package with bearer token authentication.
// Any header-based authentication (including cookies and HTTP basic auth)
// works similarly.

// We'll use a simple allow list to demonstrate how to add authorization logic
// conditionally based on the request's procedure.
allowList := map[string]struct{}{
// Procedure constants are available in the generated code.
pingv1connect.PingServicePingProcedure: {},
}
// And a simple token-to-user map to demonstrate how to authenticate
// requests based on a bearer token.
tokenToUser := map[string]string{
"open-sesame": "Aladdin",
}

// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req *http.Request) (any, error) {
// Infer the procedure from the request URL.
procedure, _ := authn.InferProcedure(req.URL)
// Extract the bearer token from the Authorization header.
token, ok := authn.BearerToken(req)
if !ok {
// We'll allow unauthenticated access to the ping procedure.
if _, ok := allowList[procedure]; ok {
fmt.Println("no authentication required for", procedure)
return nil, nil // no authentication required
}
fmt.Println("authentication required for", procedure)
err := authn.Errorf("invalid authorization")
err.Meta().Set("WWW-Authenticate", "Bearer")
return nil, err
}
user, ok := tokenToUser[token]
if !ok {
return nil, authn.Errorf("invalid token")
}
// The request is authenticated!
fmt.Println("authenticated request from", user, "for", procedure)
return user, nil
}
middleware := authn.NewMiddleware(authenticate)

// Next, we build our Connect handler.
mux := http.NewServeMux()
service := &pingHandler{}
mux.Handle(pingv1connect.NewPingServiceHandler(service))

// Finally, we wrap the handler with our middleware and start our server.
handler := middleware.Wrap(mux)
server := httptest.NewServer(handler)
defer server.Close()

// Create an unauthenticated call to the ping procedure.
client := pingv1connect.NewPingServiceClient(http.DefaultClient, server.URL)
if _, err := client.Ping(context.Background(), connect.NewRequest(
&pingv1.PingRequest{Text: "hello"},
)); err != nil {
fmt.Printf("unexpected error: %v\n", err)
return
}
fmt.Println("client received response")

// Create an unauthenticated call to the echo procedure.
if _, err := client.Echo(context.Background(), connect.NewRequest(
&pingv1.EchoRequest{Text: "hello"},
)); connect.CodeOf(err) != connect.CodeUnauthenticated {
fmt.Printf("unexpected error: %v\n", err)
return
}
fmt.Println("client unauthorized")

// Create an authenticated call to the echo procedure.
req := connect.NewRequest(&pingv1.EchoRequest{Text: "hello"})
req.Header().Set("Authorization", "Bearer open-sesame")
if _, err := client.Echo(context.Background(), req); err != nil {
fmt.Printf("unexpected error: %v\n", err)
return
}
fmt.Println("client received response")

// Output:
// no authentication required for /authn.ping.v1.PingService/Ping
// client received response
// authentication required for /authn.ping.v1.PingService/Echo
// client unauthorized
// authenticated request from Aladdin for /authn.ping.v1.PingService/Echo
// client received response
}

func Example_mutualTLS() {
// This example shows how to use this package with mutual TLS.
// First, we define our authentication logic and use it to build middleware.
Expand Down Expand Up @@ -273,3 +362,30 @@ func equal(left, right string) bool {
// Using subtle prevents some timing attacks.
return subtle.ConstantTimeCompare([]byte(left), []byte(right)) == 1
}

type pingHandler struct {
pingv1connect.UnimplementedPingServiceHandler
}

func (pingHandler) Ping(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
return connect.NewResponse(&pingv1.PingResponse{Text: req.Msg.Text}), nil
}

func (pingHandler) Echo(_ context.Context, req *connect.Request[pingv1.EchoRequest]) (*connect.Response[pingv1.EchoResponse], error) {
return connect.NewResponse(&pingv1.EchoResponse{Text: req.Msg.Text}), nil
}

func (pingHandler) PingStream(_ context.Context, stream *connect.BidiStream[pingv1.PingStreamRequest, pingv1.PingStreamResponse]) error {
for {
req, err := stream.Receive()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
if err := stream.Send(&pingv1.PingStreamResponse{Text: req.Text}); err != nil {
return err
}
}
}
Loading

0 comments on commit 0d60722

Please sign in to comment.