forked from connectrpc/authn-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
authn.go
189 lines (167 loc) · 6.05 KB
/
authn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
// Copyright 2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package authn provides authentication middleware for [connect].
package authn
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
"connectrpc.com/connect"
)
// key is an unexported type for this package's context keys, so values stored
// here cannot collide with keys defined by any other package.
type key int

// infoKey is the context key under which authentication information is stored.
const infoKey = key(0)
// An AuthFunc authenticates an RPC. The function must return an error if the
// request cannot be authenticated. The error is typically produced with
// [Errorf], but any error will do.
//
// If the request is successfully authenticated, the authentication function
// may return some information about the authenticated caller (or nil). If
// non-nil, the information is automatically attached to the context using
// [SetInfo].
//
// [Middleware] calls the AuthFunc with the incoming HTTP request's context
// before the wrapped handler runs. When it returns a non-nil error, that error
// is written to the client and the wrapped handler is never invoked.
//
// Implementations must be safe to call concurrently.
type AuthFunc func(ctx context.Context, req Request) (any, error)
// SetInfo returns a copy of ctx that carries the given authentication
// information, which can later be retrieved with [GetInfo]. It is handy in
// tests for simulating an authenticated caller.
//
// A nil info leaves the context alone, and ctx itself is returned. Only an
// untyped nil counts here: a nil pointer stored in info is a non-nil interface
// value and is attached like any other value.
//
// [AuthFunc] implementations do not need to call SetInfo themselves. Any
// information they return is attached by [Middleware].
func SetInfo(ctx context.Context, info any) context.Context {
	if info != nil {
		return context.WithValue(ctx, infoKey, info)
	}
	return ctx
}
// GetInfo returns the authentication information attached to ctx by [SetInfo]
// or by [Middleware]. It returns nil when no information is present.
func GetInfo(ctx context.Context) any {
	info := ctx.Value(infoKey)
	return info
}
// WithoutInfo returns a derived context in which [GetInfo] reports no
// authentication information, even if ctx or one of its ancestors carries some.
// The parent context itself is not modified.
func WithoutInfo(ctx context.Context) context.Context {
	// Shadow any inherited value with an explicit nil entry under the same key.
	var none any
	return context.WithValue(ctx, infoKey, none)
}
// Errorf builds an authentication failure: it formats template and args with
// fmt.Errorf (so %w wraps its operand as usual) and wraps the result in a
// [connect.Error] whose code is [connect.CodeUnauthenticated].
func Errorf(template string, args ...any) *connect.Error {
	cause := fmt.Errorf(template, args...)
	return connect.NewError(connect.CodeUnauthenticated, cause)
}
// Request describes a single RPC invocation.
//
// It wraps the incoming *http.Request and exposes request metadata (URL path,
// headers, cookies, remote address, TLS state) to an [AuthFunc]. None of its
// methods read the request body.
type Request struct {
	// request is the underlying HTTP request, supplied by [Middleware].
	request *http.Request
}
// BasicAuth reports the credentials carried in the request's Authorization
// header when it uses the HTTP Basic scheme. ok is false when the header is
// absent or does not hold valid Basic credentials.
func (r Request) BasicAuth() (username string, password string, ok bool) {
	username, password, ok = r.request.BasicAuth()
	return username, password, ok
}
// Cookies parses the request's Cookie headers and returns every cookie found.
// The result has length zero when the client sent no cookies.
func (r Request) Cookies() []*http.Cookie {
	httpReq := r.request
	return httpReq.Cookies()
}
// Cookie looks up the cookie with the given name. If the request carries no
// such cookie, it returns a nil cookie and [http.ErrNoCookie]. When several
// cookies share the name, only one of them is returned.
func (r Request) Cookie(name string) (*http.Cookie, error) {
	cookie, err := r.request.Cookie(name)
	if err != nil {
		return nil, err
	}
	return cookie, nil
}
// Procedure returns the RPC procedure name, in the form "/service/method",
// built from the last two segments of the URL path. A single trailing slash is
// ignored. If the request path does not end in two non-empty segments, it
// contains no procedure name, and the entire path is returned unchanged.
func (r Request) Procedure() string {
	fullPath := r.request.URL.Path
	path := strings.TrimSuffix(fullPath, "/")
	ultimate := strings.LastIndex(path, "/")
	if ultimate < 0 {
		return fullPath
	}
	penultimate := strings.LastIndex(path[:ultimate], "/")
	if penultimate < 0 {
		return fullPath
	}
	// Both the service (between the two slashes) and the method (after the
	// last slash) must be non-empty. A total-length check alone would accept
	// paths such as "/svc//" (empty method) or "/prefix//method" (empty service).
	if penultimate == ultimate-1 || ultimate == len(path)-1 {
		return fullPath
	}
	return path[penultimate:]
}
// ClientAddr returns the client address recorded in the underlying request's
// RemoteAddr field. Requests served by net/http's server carry an "IP:port"
// address there.
func (r Request) ClientAddr() string {
	addr := r.request.RemoteAddr
	return addr
}
// Protocol returns the RPC protocol, inferred from the request's Content-Type
// prefix. It is one of [connect.ProtocolConnect], [connect.ProtocolGRPC], or
// [connect.ProtocolGRPCWeb]. Any other Content-Type, including a missing one,
// is reported as [connect.ProtocolConnect].
func (r Request) Protocol() string {
	contentType := r.request.Header.Get("Content-Type")
	// gRPC-Web content types also begin with "application/grpc", so they must
	// be checked before plain gRPC.
	if strings.HasPrefix(contentType, "application/grpc-web") {
		return connect.ProtocolGRPCWeb
	}
	if strings.HasPrefix(contentType, "application/grpc") {
		return connect.ProtocolGRPC
	}
	return connect.ProtocolConnect
}
// Header returns the HTTP request headers. The map is the one held by the
// underlying request, not a copy, so changes made through it affect that
// request.
func (r Request) Header() http.Header {
	headers := r.request.Header
	return headers
}
// TLS returns the TLS connection state recorded on the underlying request. It
// is nil when the connection is not using TLS.
func (r Request) TLS() *tls.ConnectionState {
	state := r.request.TLS
	return state
}
// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
//
// Middleware operates at a lower level than Connect interceptors, so the
// server doesn't decompress and unmarshal the request until the caller has
// been authenticated.
//
// Use [NewMiddleware] to construct one. A zero Middleware has a nil auth
// function, so its wrapped handlers would panic when calling it.
type Middleware struct {
	// auth authenticates every request before it reaches the wrapped handler.
	auth AuthFunc
	// errW writes the error returned by auth back to the client. It is built
	// from the handler options passed to NewMiddleware.
	errW *connect.ErrorWriter
}
// NewMiddleware builds HTTP middleware around the supplied authentication
// function. When authentication succeeds, any information it returns is
// attached to the request context. Subsequent HTTP middleware, all RPC
// interceptors, and application code can then read it with [GetInfo].
//
// In order to properly marshal errors, applications must pass NewMiddleware
// the same handler options used when constructing Connect handlers.
func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware {
	errorWriter := connect.NewErrorWriter(opts...)
	return &Middleware{errW: errorWriter, auth: auth}
}
// Wrap returns an HTTP handler that authenticates requests before forwarding
// them to handler.
//
// For every request, the middleware's AuthFunc is called with the request's
// context.
//   - If it returns an error, the error is written to the client with the
//     middleware's ErrorWriter, and handler is not invoked.
//   - Otherwise handler is invoked. When the AuthFunc returned non-nil
//     information, the request passed to handler carries a context that
//     includes it (see [GetInfo]).
func (m *Middleware) Wrap(handler http.Handler) http.Handler {
	authenticate := func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()
		info, err := m.auth(ctx, Request{request: req})
		if err != nil {
			// The request is rejected whether or not the error response can be
			// written, and there is no caller to report a write failure to, so
			// the ErrorWriter's result is deliberately discarded.
			_ = m.errW.Write(w, req, err)
			return
		}
		if info == nil {
			// Nothing to attach: forward the original request and skip the
			// shallow copy that WithContext would make.
			handler.ServeHTTP(w, req)
			return
		}
		handler.ServeHTTP(w, req.WithContext(SetInfo(ctx, info)))
	}
	return http.HandlerFunc(authenticate)
}