Skip to content

Commit

Permalink
Refactor kubernetes bearer token authentication
Browse files Browse the repository at this point in the history
Instead of doing retries on 401 errors, use a mechanism from client-go
which simply reloads the token periodically in the background.

Also, don't stop logging errors after the first 401. These errors, if
present, need to be addressed by the cluster operator, so we should make
them more prominent.
  • Loading branch information
swiatekm committed Feb 14, 2025
1 parent 0f487df commit 5952aa4
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 242 deletions.
54 changes: 13 additions & 41 deletions metricbeat/helper/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ package helper
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"

"github.com/elastic/elastic-agent-libs/transport/httpcommon"
"github.com/elastic/elastic-agent-libs/useragent"

"k8s.io/client-go/transport"

"github.com/elastic/beats/v7/libbeat/version"
"github.com/elastic/beats/v7/metricbeat/helper/dialer"
"github.com/elastic/beats/v7/metricbeat/mb"
Expand Down Expand Up @@ -69,14 +71,6 @@ func NewHTTPFromConfig(config Config, hostData mb.HostData) (*HTTP, error) {
headers.Set(k, v)
}

if config.BearerTokenFile != "" {
header, err := getAuthHeaderFromToken(config.BearerTokenFile)
if err != nil {
return nil, err
}
headers.Set("Authorization", header)
}

// Ensure backward compatibility
builder := hostData.Transport
if builder == nil {
Expand All @@ -97,6 +91,15 @@ func NewHTTPFromConfig(config Config, hostData mb.HostData) (*HTTP, error) {
return nil, err
}

// Apply the token refreshing roundtripper. We can't do this in a transport option because we need to handle the
// error it can return at creation
if config.BearerTokenFile != "" {
client.Transport, err = transport.NewBearerAuthWithRefreshRoundTripper("", config.BearerTokenFile, client.Transport)
}
if err != nil {
return nil, err
}

return &HTTP{
hostData: hostData,
bearerFile: config.BearerTokenFile,
Expand All @@ -118,7 +121,7 @@ func (h *HTTP) FetchResponse() (*http.Response, error) {
reader = bytes.NewReader(h.body)
}

req, err := http.NewRequest(h.method, h.uri, reader)
req, err := http.NewRequestWithContext(context.Background(), h.method, h.uri, reader) // TODO: get context from caller
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
Expand Down Expand Up @@ -212,34 +215,3 @@ func (h *HTTP) FetchJSON() (map[string]interface{}, error) {

return data, nil
}

// RefreshAuthorizationHeader re-reads the bearer token file configured on h
// and stores the resulting value in the "Authorization" header.
//
// It returns true when the header was set. It returns false with a nil error
// when no bearer token file is configured, and false with the read error when
// the token file cannot be loaded.
func (h *HTTP) RefreshAuthorizationHeader() (bool, error) {
	// Nothing to refresh when no token file is configured.
	if h.bearerFile == "" {
		return false, nil
	}

	value, err := getAuthHeaderFromToken(h.bearerFile)
	if err != nil {
		return false, err
	}

	h.headers.Set("Authorization", value)
	return true, nil
}

// getAuthHeaderFromToken loads the bearer token stored in the file at path and
// returns it formatted as an Authorization header value ("Bearer <token>").
//
// At most one trailing newline is removed from the file content. An empty file
// yields an empty string and a nil error. A failure to read the file is
// returned wrapped, together with an empty string.
func getAuthHeaderFromToken(path string) (string, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("reading bearer token file: %w", err)
	}

	// An empty file produces no header value at all (not even the prefix).
	if len(raw) == 0 {
		return "", nil
	}

	// Drop a single trailing newline, if present.
	if last := len(raw) - 1; raw[last] == '\n' {
		raw = raw[:last]
	}

	return "Bearer " + string(raw), nil
}
88 changes: 21 additions & 67 deletions metricbeat/helper/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,51 +38,6 @@ import (
"github.com/elastic/beats/v7/metricbeat/mb/parse"
)

// TestGetAuthHeaderFromToken writes each case's content to a temporary file and
// checks that getAuthHeaderFromToken returns the expected header value.
func TestGetAuthHeaderFromToken(t *testing.T) {
	cases := []struct {
		Name, Content, Expected string
	}{
		{
			Name:     "Test a token is read",
			Content:  "testtoken",
			Expected: "Bearer testtoken",
		},
		{
			Name:     "Test a token is trimmed",
			Content:  "testtoken\n",
			Expected: "Bearer testtoken",
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			tokenFile, err := os.CreateTemp("", "token")
			if err != nil {
				t.Fatal(err)
			}
			// Remove the temporary token file once this subtest finishes.
			defer os.Remove(tokenFile.Name())

			_, err = tokenFile.Write([]byte(tc.Content))
			if err != nil {
				t.Fatal(err)
			}
			err = tokenFile.Close()
			if err != nil {
				t.Fatal(err)
			}

			got, err := getAuthHeaderFromToken(tokenFile.Name())
			assert.NoError(t, err)
			assert.Equal(t, tc.Expected, got)
		})
	}
}

// TestGetAuthHeaderFromTokenNoFile checks that a path which cannot be read
// produces an error together with an empty header value.
func TestGetAuthHeaderFromTokenNoFile(t *testing.T) {
	const missingPath = "nonexistingfile"

	got, err := getAuthHeaderFromToken(missingPath)

	assert.Equal(t, "", got)
	assert.Error(t, err)
}

func TestTimeout(t *testing.T) {
c := make(chan struct{})
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -301,35 +256,34 @@ func TestRefreshAuthorizationHeader(t *testing.T) {
bearerFileName := "token"
bearerFilePath := filepath.Join(path, bearerFileName)

getAuth := func(helper *HTTP) string {
for k, v := range helper.headers {
if k == "Authorization" {
return v[0]
}
}
return ""
}
var authToken string

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

firstToken := "token-1"
err := os.WriteFile(bearerFilePath, []byte(firstToken), 0644)
assert.NoError(t, err)

helper := &HTTP{bearerFile: bearerFilePath, headers: make(http.Header)}
updated, err := helper.RefreshAuthorizationHeader()
assert.NoError(t, err)
assert.True(t, updated)
expected := fmt.Sprintf("Bearer %s", firstToken)
assert.Equal(t, expected, getAuth(helper))
cfg := defaultConfig()
cfg.BearerTokenFile = bearerFilePath
hostData := mb.HostData{
URI: ts.URL,
SanitizedURI: ts.URL,
}

secondToken := "token-2"
err = os.WriteFile(bearerFilePath, []byte(secondToken), 0644)
assert.NoError(t, err)
h, err := NewHTTPFromConfig(cfg, hostData)
require.NoError(t, err)

updated, err = helper.RefreshAuthorizationHeader()
assert.NoError(t, err)
assert.True(t, updated)
expected = fmt.Sprintf("Bearer %s", secondToken)
assert.Equal(t, expected, getAuth(helper))
res, err := h.FetchResponse()
require.NoError(t, err)
res.Body.Close()

assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, authToken, firstToken)
}

func checkTimeout(t *testing.T, h *HTTP) {
Expand Down
4 changes: 2 additions & 2 deletions metricbeat/include/list_init.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 10 additions & 37 deletions metricbeat/module/kubernetes/apiserver/metricset.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ package apiserver

import (
"fmt"
"net/http"
"strings"
"time"

"github.com/elastic/beats/v7/metricbeat/helper"
"github.com/elastic/beats/v7/metricbeat/helper/prometheus"
"github.com/elastic/beats/v7/metricbeat/mb"
k8smod "github.com/elastic/beats/v7/metricbeat/module/kubernetes"
Expand All @@ -34,7 +30,6 @@ import (
// Metricset for apiserver is a prometheus based metricset
type Metricset struct {
mb.BaseMetricSet
http *helper.HTTP
prometheusClient prometheus.Prometheus
prometheusMappings *prometheus.MetricsMapping
clusterMeta mapstr.M
Expand All @@ -54,13 +49,8 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
return nil, fmt.Errorf("must be child of kubernetes module")
}

http, err := pc.GetHttp()
if err != nil {
return nil, fmt.Errorf("the http connection is not valid")
}
ms := &Metricset{
BaseMetricSet: base,
http: http,
prometheusClient: pc,
prometheusMappings: mapping,
clusterMeta: util.AddClusterECSMeta(base),
Expand All @@ -73,36 +63,19 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
// Fetch gathers information from the apiserver and reports events with this information.
func (m *Metricset) Fetch(reporter mb.ReporterV2) error {
events, err := m.prometheusClient.GetProcessedMetrics(m.prometheusMappings)
errorString := fmt.Sprintf("%s", err)
errorUnauthorisedMsg := fmt.Sprintf("unexpected status code %d", http.StatusUnauthorized)
if err != nil && strings.Contains(errorString, errorUnauthorisedMsg) {
count := 2 // We retry twice to refresh the Authorisation token in case of http.StatusUnauthorize = 401 Error
for count > 0 {
if _, errAuth := m.http.RefreshAuthorizationHeader(); errAuth == nil {
events, err = m.prometheusClient.GetProcessedMetrics(m.prometheusMappings)
}
if err != nil {
time.Sleep(m.mod.Config().Period)
count--
} else {
break
}
}
}
// We need to check for err again in case error is not 401 or RefreshAuthorizationHeader has failed
if err != nil {
return fmt.Errorf("error getting metrics: %w", err)
} else {
for _, e := range events {
event := mb.TransformMapStrToEvent("kubernetes", e, nil)
if len(m.clusterMeta) != 0 {
event.RootFields.DeepUpdate(m.clusterMeta)
}
isOpen := reporter.Event(event)
if !isOpen {
return nil
}
}
for _, e := range events {
event := mb.TransformMapStrToEvent("kubernetes", e, nil)
if len(m.clusterMeta) != 0 {
event.RootFields.DeepUpdate(m.clusterMeta)
}
isOpen := reporter.Event(event)
if !isOpen {
return nil
}
return nil
}
return nil
}
49 changes: 11 additions & 38 deletions metricbeat/module/kubernetes/controllermanager/controllermanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ package controllermanager

import (
"fmt"
"net/http"
"strings"
"time"

"github.com/elastic/beats/v7/metricbeat/helper"
"github.com/elastic/beats/v7/metricbeat/helper/prometheus"
"github.com/elastic/beats/v7/metricbeat/mb"
k8smod "github.com/elastic/beats/v7/metricbeat/module/kubernetes"
Expand Down Expand Up @@ -79,7 +75,6 @@ func init() {
// MetricSet implements the mb.PushMetricSet interface, and therefore does not rely on polling.
type MetricSet struct {
mb.BaseMetricSet
http *helper.HTTP
prometheusClient prometheus.Prometheus
prometheusMappings *prometheus.MetricsMapping
clusterMeta mapstr.M
Expand All @@ -100,13 +95,8 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
return nil, fmt.Errorf("must be child of kubernetes module")
}

http, err := pc.GetHttp()
if err != nil {
return nil, fmt.Errorf("the http connection is not valid")
}
ms := &MetricSet{
BaseMetricSet: base,
http: http,
prometheusClient: pc,
prometheusMappings: mapping,
clusterMeta: util.AddClusterECSMeta(base),
Expand All @@ -118,37 +108,20 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
// Fetch gathers information from the apiserver and reports events with this information.
func (m *MetricSet) Fetch(reporter mb.ReporterV2) error {
events, err := m.prometheusClient.GetProcessedMetrics(m.prometheusMappings)
errorString := fmt.Sprintf("%s", err)
errorUnauthorisedMsg := fmt.Sprintf("unexpected status code %d", http.StatusUnauthorized)
if err != nil && strings.Contains(errorString, errorUnauthorisedMsg) {
count := 2 // We retry twice to refresh the Authorisation token in case of http.StatusUnauthorize = 401 Error
for count > 0 {
if _, errAuth := m.http.RefreshAuthorizationHeader(); errAuth == nil {
events, err = m.prometheusClient.GetProcessedMetrics(m.prometheusMappings)
}
if err != nil {
time.Sleep(m.mod.Config().Period)
count--
} else {
break
}
}
}
// We need to check for err again in case error is not 401 or RefreshAuthorizationHeader has failed
if err != nil {
return fmt.Errorf("error getting metrics: %w", err)
} else {
for _, e := range events {
event := mb.TransformMapStrToEvent("kubernetes", e, nil)
if len(m.clusterMeta) != 0 {
event.RootFields.DeepUpdate(m.clusterMeta)
}
isOpen := reporter.Event(event)
if !isOpen {
return nil
}
}

for _, e := range events {
event := mb.TransformMapStrToEvent("kubernetes", e, nil)
if len(m.clusterMeta) != 0 {
event.RootFields.DeepUpdate(m.clusterMeta)
}
isOpen := reporter.Event(event)
if !isOpen {
return nil
}

return nil
}
return nil
}
Loading

0 comments on commit 5952aa4

Please sign in to comment.