Skip to content

Commit

Permalink
refactor: Move access control logic to middleware
Browse files Browse the repository at this point in the history
* Fix passing Grafana team ID to middleware

* Pass query params to final handler using context

* Adapt tests and update fixtures

Signed-off-by: Mahendra Paipuri <[email protected]>
  • Loading branch information
mahendrapaipuri committed Apr 16, 2024
1 parent a78b96f commit a262f34
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 119 deletions.
3 changes: 2 additions & 1 deletion pkg/lb/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (lb *CEEMSLoadBalancer) Main() error {
).Default("").String()
configFile = lb.App.Flag(
"config.file",
"Config file containing backend server web URLs.",
"Configuration file path.",
).Default("").String()
maxProcs = lb.App.Flag(
"runtime.gomaxprocs", "The target number of CPUs Go will run on (GOMAXPROCS)",
Expand Down Expand Up @@ -136,6 +136,7 @@ func (lb *CEEMSLoadBalancer) Main() error {
AdminUsers: config.AdminUsers,
Manager: manager,
Grafana: grafanaClient,
GrafanaTeamID: config.Grafana.AdminTeamID,
}

// Create frontend instance
Expand Down
6 changes: 3 additions & 3 deletions pkg/lb/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ func TestCEEMSLBMainSuccess(t *testing.T) {
func TestCEEMSLBMainFail(t *testing.T) {
// Remove test related args and add a dummy arg
os.Args = []string{os.Args[0]}
a := CEEMSLoadBalancer{
appName: mockCEEMSLBAppName,
App: mockCEEMSLBApp,
a, err := NewCEEMSLoadBalancer()
if err != nil {
t.Fatal(err)
}

// Start Main
Expand Down
63 changes: 20 additions & 43 deletions pkg/lb/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
package frontend

import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"net/http"
"time"

Expand All @@ -22,6 +20,15 @@ import (
// RetryContextKey is the key used to set context value for retry
type RetryContextKey struct{}

// QueryParamsContextKey is the key used to set context value for query params
type QueryParamsContextKey struct{}

// QueryParams is the context value
type QueryParams struct {
uuids []string
queryPeriod time.Duration
}

// LoadBalancer is the interface to implement
type LoadBalancer interface {
Serve(http.ResponseWriter, *http.Request)
Expand All @@ -39,6 +46,7 @@ type Config struct {
AdminUsers []string
Manager serverpool.Manager
Grafana *grafana.Grafana
GrafanaTeamID string
}

// loadBalancer struct
Expand Down Expand Up @@ -73,10 +81,11 @@ func NewLoadBalancer(c *Config) (LoadBalancer, error) {
manager: c.Manager,
db: db,
amw: authenticationMiddleware{
logger: c.Logger,
adminUsers: c.AdminUsers,
grafana: c.Grafana,
db: db,
logger: c.Logger,
adminUsers: c.AdminUsers,
grafana: c.Grafana,
db: db,
grafanaTeamID: c.GrafanaTeamID,
},
}, nil
}
Expand Down Expand Up @@ -113,49 +122,17 @@ func (lb *loadBalancer) Shutdown(ctx context.Context) error {
// Serve serves the request using a backend TSDB server from the pool
func (lb *loadBalancer) Serve(w http.ResponseWriter, r *http.Request) {
var queryPeriod time.Duration
var body []byte
var err error

// Make a new request and add newReader to that request body
newReq := r.Clone(r.Context())

// If request has no body go to proxy directly
if r.Body == nil {
goto proxy
}

// If failed to read body, skip verification and go to request proxy
body, err = io.ReadAll(r.Body)
if err != nil {
level.Error(lb.logger).Log("msg", "Failed to read request body", "err", err)
goto proxy
}

// clone body to existing request and new request
r.Body = io.NopCloser(bytes.NewReader(body))
newReq.Body = io.NopCloser(bytes.NewReader(body))

// Get form values
if err := newReq.ParseForm(); err != nil {
level.Error(lb.logger).Log("msg", "Could not parse request body", "err", err)
goto proxy
}

// If not userUnits, forbid request
// if !lb.userUnits(newReq) {
// http.Error(w, "Bad request", http.StatusBadRequest)
// return
// }
// Retrieve query params from context
queryParams := r.Context().Value(QueryParamsContextKey{})

// Get query period and target server will depend on this
if startTime, err := parseTimeParam(newReq, "start", time.Now().UTC()); err != nil {
level.Error(lb.logger).Log("msg", "Could not parse start query param", "err", err)
// Check if queryParams is nil which could happen in edge cases
if queryParams == nil {
queryPeriod = time.Duration(0 * time.Second)
} else {
queryPeriod = time.Now().UTC().Sub(startTime)
queryPeriod = queryParams.(*QueryParams).queryPeriod
}

proxy:
// Choose target based on query Period
target := lb.manager.Target(queryPeriod)
if target != nil {
Expand Down
45 changes: 23 additions & 22 deletions pkg/lb/frontend/frontend_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package frontend

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/http/httputil"
Expand Down Expand Up @@ -70,48 +70,49 @@ func TestNewFrontend(t *testing.T) {

tests := []struct {
name string
req string
header bool
start int64
code int
response bool
}{
{
name: "pass with query",
req: "/test?query=foo{uuid=\"1479765|1481510\"}",
header: true,
name: "query with params in ctx",
start: time.Now().UTC().Unix(),
code: 200,
response: true,
},
{
name: "pass with start and query params",
req: fmt.Sprintf(
"/test?query=foo{uuid=\"123|345\"}&start=%d",
time.Now().UTC().Add(-time.Duration(29*24*time.Hour)).Unix(),
),
header: false,
name: "query with no params in ctx",
code: 200,
response: true,
},
{
name: "no target with start more than retention period",
req: fmt.Sprintf(
"/test?query=foo{uuid=\"123|345\"}&start=%d",
time.Now().UTC().Add(-time.Duration(31*24*time.Hour)).Unix(),
),
header: false,
name: "query with params in ctx and start more than retention period",
start: time.Now().UTC().Add(-time.Duration(31 * 24 * time.Hour)).Unix(),
code: 503,
response: false,
},
}

for _, test := range tests {
request := httptest.NewRequest("GET", test.req, nil)
if test.header {
request.Header.Set("X-Grafana-User", "usr1")
request := httptest.NewRequest("GET", "/test", nil)

// Add uuids and start to context
var newReq *http.Request
if test.start > 0 {
period := time.Duration((time.Now().UTC().Unix() - test.start)) * time.Second
newReq = request.WithContext(
context.WithValue(
request.Context(), QueryParamsContextKey{},
&QueryParams{queryPeriod: period},
),
)
} else {
newReq = request
}

responseRecorder := httptest.NewRecorder()

http.HandlerFunc(lb.Serve).ServeHTTP(responseRecorder, request)
http.HandlerFunc(lb.Serve).ServeHTTP(responseRecorder, newReq)

if responseRecorder.Code != test.code {
t.Errorf("%s: expected status %d, got %d", test.name, test.code, responseRecorder.Code)
Expand Down
75 changes: 66 additions & 9 deletions pkg/lb/frontend/helpers.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package frontend

import (
"bytes"
"context"
"fmt"
"io"
"math"
"net"
"net/http"
"net/url"
"reflect"
"slices"
"strconv"
"strings"
"time"

"github.com/go-kit/log"
Expand Down Expand Up @@ -67,15 +71,68 @@ func Monitor(ctx context.Context, manager serverpool.Manager, logger log.Logger)
}
}

// // Returns query period based on start time of query
// func parseQueryPeriod(r *http.Request) time.Duration {
// // Parse start query string in request
// start, err := parseTimeParam(r, "start", MinTime)
// if err != nil {
// return time.Duration(0 * time.Second)
// }
// return time.Now().UTC().Sub(start)
// }
// Set query params into request's context and return new request
func setQueryParams(r *http.Request, queryParams *QueryParams) *http.Request {
return r.WithContext(context.WithValue(r.Context(), QueryParamsContextKey{}, queryParams))
}

// Parse query in the request after cloning it and add query params to context
func parseQueryParams(r *http.Request, logger log.Logger) *http.Request {
var body []byte
var uuids []string
var queryPeriod time.Duration
var err error

// Make a new request and add newReader to that request body
clonedReq := r.Clone(r.Context())

// If request has no body go to proxy directly
if r.Body == nil {
return setQueryParams(r, &QueryParams{uuids, queryPeriod})
}

// If failed to read body, skip verification and go to request proxy
if body, err = io.ReadAll(r.Body); err != nil {
level.Error(logger).Log("msg", "Failed to read request body", "err", err)
return setQueryParams(r, &QueryParams{uuids, queryPeriod})
}

// clone body to existing request and new request
r.Body = io.NopCloser(bytes.NewReader(body))
clonedReq.Body = io.NopCloser(bytes.NewReader(body))

// Get form values
if err = clonedReq.ParseForm(); err != nil {
level.Error(logger).Log("msg", "Could not parse request body", "err", err)
return setQueryParams(r, &QueryParams{uuids, queryPeriod})
}

// Parse TSDB's query in request query params
if val := clonedReq.FormValue("query"); val != "" {
matches := regexpUUID.FindAllStringSubmatch(val, -1)
for _, match := range matches {
if len(match) > 1 {
for _, uuid := range strings.Split(match[1], "|") {
// Ignore empty strings
if strings.TrimSpace(uuid) != "" && !slices.Contains(uuids, uuid) {
uuids = append(uuids, uuid)
}
}
}
}
}

// Parse TSDB's start query in request query params
if startTime, err := parseTimeParam(clonedReq, "start", time.Now().UTC()); err != nil {
level.Error(logger).Log("msg", "Could not parse start query param", "err", err)
queryPeriod = time.Duration(0 * time.Second)
} else {
queryPeriod = time.Now().UTC().Sub(startTime)
}

// Set query params to request's context
return setQueryParams(r, &QueryParams{uuids, queryPeriod})
}

// Parse time parameter in request
func parseTimeParam(r *http.Request, paramName string, defaultValue time.Time) (time.Time, error) {
Expand Down
Loading

0 comments on commit a262f34

Please sign in to comment.