-
Notifications
You must be signed in to change notification settings - Fork 3
/
common_test.go
94 lines (77 loc) · 2.31 KB
/
common_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package quotacontrol_test
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"time"

	"github.com/0xsequence/quotacontrol"
	"github.com/0xsequence/quotacontrol/middleware"
	"github.com/0xsequence/quotacontrol/proto"
	"github.com/goware/cachestore/redis"
	"github.com/goware/logger"
)
// newConfig returns the quota-control configuration shared by the tests:
// quota control, Redis and the rate limiter are all switched on, and
// UpdateFreq is set to one minute. Every other field keeps its zero value.
func newConfig() quotacontrol.Config {
	var cfg quotacontrol.Config
	cfg.Enabled = true
	cfg.UpdateFreq = quotacontrol.Duration{time.Minute}
	cfg.Redis = redis.Config{Enabled: true}
	cfg.RateLimiter = middleware.RLConfig{Enabled: true}
	return cfg
}
// newQuotaClient builds a quotacontrol.Client for service using cfg.
// The client gets a debug-level logger carrying the attribute
// client="client", and nil is passed as the last argument to NewClient.
func newQuotaClient(cfg quotacontrol.Config, service proto.Service) *quotacontrol.Client {
	// Named clientLogger so the local does not shadow the imported logger package.
	clientLogger := logger.
		NewLogger(logger.LogLevel_DEBUG).
		With(slog.String("client", "client"))
	return quotacontrol.NewClient(clientLogger, service, cfg, nil)
}
// hitCounter is an http.Handler that counts every request it serves and
// always replies 200 OK. The count is read and updated atomically, so the
// handler may be hit from several goroutines.
type hitCounter int64

// GetValue returns how many requests have been served so far.
func (c *hitCounter) GetValue() int64 {
	counter := (*int64)(c)
	return atomic.LoadInt64(counter)
}

// ServeHTTP records one more hit and answers with an empty 200 OK.
func (c *hitCounter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	counter := (*int64)(c)
	atomic.AddInt64(counter, 1)
	w.WriteHeader(http.StatusOK)
}
// spendingCounter is an http.Handler that counts only the requests whose
// context satisfies middleware.HasSpending, i.e. requests on which quota
// control actually ran. Every request is answered with 200 OK regardless.
type spendingCounter int64

// GetValue returns how many requests have been counted so far.
func (c *spendingCounter) GetValue() int64 {
	counter := (*int64)(c)
	return atomic.LoadInt64(counter)
}

// ServeHTTP bumps the counter when spending was recorded on the request
// context, then replies with an empty 200 OK.
func (c *spendingCounter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	spent := middleware.HasSpending(r.Context())
	if spent {
		atomic.AddInt64((*int64)(c), 1)
	}
	w.WriteHeader(http.StatusOK)
}
// executeRequest sends a POST request for path through handler and reports
// whether it succeeded.
//
// The request always carries X-Real-IP: 127.0.0.1. When accessKey is
// non-empty it is sent in the middleware.HeaderAccessKey header, and when
// jwt is non-empty it is sent as a Bearer Authorization header. ctx becomes
// the request's context.
//
// For a response status in [200, 400) it returns true, the response headers
// and a nil error. For any other status it returns false, the response
// headers and the response body decoded as a proto.WebRPCError. If the body
// cannot be decoded as one, the returned error reports the status, the raw
// body and the decode failure instead.
func executeRequest(ctx context.Context, handler http.Handler, path, accessKey, jwt string) (bool, http.Header, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, nil)
	if err != nil {
		return false, nil, err
	}
	req.Header.Set("X-Real-IP", "127.0.0.1")
	if accessKey != "" {
		req.Header.Set(middleware.HeaderAccessKey, accessKey)
	}
	if jwt != "" {
		req.Header.Set("Authorization", "Bearer "+jwt)
	}

	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if status := rr.Code; status < http.StatusOK || status >= http.StatusBadRequest {
		var rpcErr proto.WebRPCError
		if err := json.Unmarshal(rr.Body.Bytes(), &rpcErr); err != nil {
			// The body is not a webrpc error payload (e.g. plain text);
			// surface the status and raw body rather than an empty WebRPCError.
			return false, rr.Header(), fmt.Errorf("unexpected status %d (body %q): %w", status, rr.Body.String(), err)
		}
		return false, rr.Header(), rpcErr
	}
	return true, rr.Header(), nil
}
// addCost is a test middleware whose value is passed to middleware.AddCost
// on every request's context before the request reaches the next handler.
type addCost int64

// Middleware wraps h so that each request is forwarded with a context
// produced by middleware.AddCost(r.Context(), int64(i)).
func (i addCost) Middleware(h http.Handler) http.Handler {
	cost := int64(i)
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		costCtx := middleware.AddCost(r.Context(), cost)
		h.ServeHTTP(w, r.WithContext(costCtx))
	}
	return http.HandlerFunc(wrapped)
}