-
Notifications
You must be signed in to change notification settings - Fork 0
/
shed.go
181 lines (150 loc) · 5.31 KB
/
shed.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Package shed implements client and server middleware to propagate and
// respect client timeouts between services.
package shed
import (
"context"
"net/http"
"strconv"
"time"
)
const (
	// Header is the request header used to propagate a client's timeout
	// across the network. Its value is expected to be a base-10 int64 holding
	// the number of milliseconds the client is still willing to wait for a
	// response.
	Header = "X-Client-Timeout-Ms"
)

// roundTripper is the http.RoundTripper built by RoundTripper.
type roundTripper struct {
	// next is the wrapped transport that actually performs the request.
	next http.RoundTripper

	// maxTimeout caps the propagated timeout, in milliseconds. Values <= 0
	// disable the cap.
	maxTimeout int64

	// until reports the time remaining before the given deadline.
	until func(time.Time) time.Duration
}

// RoundTrip sets the `X-Client-Timeout-Ms` header to the number of
// milliseconds left until the request's context deadline (capped by any
// configured maximum) and forwards the request to the wrapped transport. No
// header is sent when there is no positive time left to propagate.
//
// As required by the http.RoundTripper contract, the caller's request is
// never modified: the header is written to a clone. Set (rather than Add)
// guarantees the upstream sees exactly one, freshly computed value, even if
// the caller reuses the same *http.Request for retries.
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	millisecondsLeft := rt.millisecondsLeft(r.Context())
	if millisecondsLeft <= 0 {
		return rt.next.RoundTrip(r)
	}

	out := r.Clone(r.Context())
	if out.Header == nil {
		// A hand-built request may have a nil header map; writing to it panics.
		out.Header = make(http.Header)
	}
	out.Header.Set(Header, strconv.FormatInt(millisecondsLeft, 10))
	return rt.next.RoundTrip(out)
}

// millisecondsLeft returns the timeout, in whole milliseconds, to advertise
// for a request carrying ctx. This is the time remaining until ctx's deadline,
// lowered to rt.maxTimeout when that cap is positive and smaller. When ctx has
// no deadline, it is rt.maxTimeout alone. A result <= 0 means there is nothing
// to propagate (deadline already passed, or no deadline and no cap).
func (rt *roundTripper) millisecondsLeft(ctx context.Context) int64 {
	var left int64
	deadline, ok := ctx.Deadline()
	if ok {
		left = int64(rt.until(deadline) / time.Millisecond)
	}
	if rt.maxTimeout > 0 && (!ok || rt.maxTimeout < left) {
		left = rt.maxTimeout
	}
	return left
}
// Client returns a new *http.Client with the same settings as c, except that
// requests are sent through RoundTripper(transport, opts...). Here transport
// is c.Transport, or http.DefaultTransport when c.Transport is nil. c itself
// is left unchanged.
func Client(c *http.Client, opts ...RoundTripperOpt) *http.Client {
	base := http.DefaultTransport
	if c.Transport != nil {
		base = c.Transport
	}

	wrapped := *c
	wrapped.Transport = RoundTripper(base, opts...)
	return &wrapped
}
// RoundTripperOpt customises the transport middleware built by RoundTripper
// (and, through it, by Client).
type RoundTripperOpt func(*roundTripper)

// WithUntilFunc replaces the function used to compute the time remaining
// before a context deadline. When this option is not given, time.Until is
// used.
func WithUntilFunc(until func(time.Time) time.Duration) RoundTripperOpt {
	return func(t *roundTripper) { t.until = until }
}

// WithMaxTimeout caps the advertised X-Client-Timeout-Ms at d, truncated to
// whole milliseconds. The cap is sent even when the request context has no
// deadline. When a deadline exists, the cap replaces the context-derived value
// only if the cap is smaller. A d shorter than one millisecond (including zero
// or negative values) leaves the cap disabled.
//
// This is intended for clients that enforce timeouts outside the request
// context, e.g. http.Transport's ResponseHeaderTimeout.
func WithMaxTimeout(d time.Duration) RoundTripperOpt {
	capMs := int64(d / time.Millisecond)
	return func(t *roundTripper) { t.maxTimeout = capMs }
}

// RoundTripper wraps n so that outgoing requests carry the time remaining
// before their context deadline in the `X-Client-Timeout-Ms` header. Options
// are applied in order, so a later option overrides an earlier one.
func RoundTripper(n http.RoundTripper, opts ...RoundTripperOpt) http.RoundTripper {
	t := &roundTripper{until: time.Until, next: n}
	for i := range opts {
		opts[i](t)
	}
	return t
}
// propagateMiddleware is the http.Handler built by PropagateMiddleware.
type propagateMiddleware struct {
	// next is the wrapped handler.
	next http.Handler

	// delta is subtracted from the client's advertised timeout before the
	// result is applied to the request context.
	delta func(r *http.Request) time.Duration
}

// ServeHTTP applies the client's `X-Client-Timeout-Ms` value, minus the
// configured delta, as a deadline on the request context before calling the
// wrapped handler.
//
// A missing, malformed or non-positive header value leaves the context
// untouched. If the delta consumes the whole budget, the derived context is
// already expired when the wrapped handler runs. Values too large to represent
// as a time.Duration are clamped rather than overflowing into a negative
// (instantly expired) timeout.
func (h *propagateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Largest millisecond count whose conversion to time.Duration does not
	// overflow int64 nanoseconds. The header is client-controlled, so it must
	// be bounded before multiplying.
	const maxMillis = int64(time.Duration(1<<63-1) / time.Millisecond)

	value, err := strconv.ParseInt(r.Header.Get(Header), 10, 64)
	if err == nil && value > 0 {
		if value > maxMillis {
			value = maxMillis
		}
		timeout := time.Duration(value)*time.Millisecond - h.delta(r)
		ctx, cancel := context.WithTimeout(r.Context(), timeout)
		defer cancel()
		r = r.WithContext(ctx)
	}
	h.next.ServeHTTP(w, r)
}
// PropagateMiddlewareOpt customises the handler built by PropagateMiddleware.
type PropagateMiddlewareOpt func(*propagateMiddleware)

// WithDelta adjusts the deadline applied by PropagateMiddleware, for example to
// account for time spent on the network or waiting in server queues.
//
// The duration that f returns for each request is subtracted from the
// `X-Client-Timeout-Ms` value before the deadline is set.
func WithDelta(f func(*http.Request) time.Duration) PropagateMiddlewareOpt {
	return func(p *propagateMiddleware) { p.delta = f }
}

// PropagateMiddleware wraps n so that requests whose client propagated its
// timeout via the `X-Client-Timeout-Ms` header receive a matching context
// deadline. Unless WithDelta is given, the header value is used unadjusted.
func PropagateMiddleware(n http.Handler, opts ...PropagateMiddlewareOpt) http.Handler {
	p := &propagateMiddleware{
		next:  n,
		delta: func(*http.Request) time.Duration { return 0 },
	}
	for i := range opts {
		opts[i](p)
	}
	return p
}
// defaultTimeoutMiddleware is the http.Handler built by
// DefaultTimeoutMiddleware.
type defaultTimeoutMiddleware struct {
	// n is the wrapped handler.
	n http.Handler

	// f returns the per-request timeout; a non-positive result means "none".
	f func(*http.Request) time.Duration
}

// ServeHTTP calls the wrapped handler. When m.f returns a positive timeout
// for r, the request context is first bounded by that timeout.
func (m *defaultTimeoutMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	timeout := m.f(r)
	if timeout <= 0 {
		m.n.ServeHTTP(w, r)
		return
	}

	ctx, cancel := context.WithTimeout(r.Context(), timeout)
	defer cancel()
	m.n.ServeHTTP(w, r.WithContext(ctx))
}

// DefaultTimeoutMiddleware wraps n so that each request's context is bounded
// by timeout(r). The deadline is derived from the incoming request context, so
// it can only shorten an existing deadline, never extend it.
//
// The timeout function allows dynamic, per-request upper bounds. When it
// returns a duration that is not strictly greater than 0, the request is
// passed through without a new deadline.
func DefaultTimeoutMiddleware(n http.Handler, timeout func(*http.Request) time.Duration) http.Handler {
	return &defaultTimeoutMiddleware{f: timeout, n: n}
}