Skip to content

Commit

Permalink
Merge pull request #80 from NorskHelsenett/feature/add-metrics-for-us…
Browse files Browse the repository at this point in the history
…erauth

Feature/add metrics for userauth
  • Loading branch information
havardelnan authored Apr 24, 2024
2 parents 9bb1c36 + 530b5b6 commit f5765b0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 12 deletions.
15 changes: 15 additions & 0 deletions pkg/auth/authtools/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,23 @@ package authtools
import (
"fmt"
"strings"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)

var (
UserLookupHistogram *prometheus.HistogramVec
ServerConnectionHistogram *prometheus.HistogramVec
ServerReconnectCounter *prometheus.CounterVec
)

func init() {
ServerConnectionHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{Name: "auth_server_connection_duration_seconds", Help: "Duration of server connection in seconds"}, []string{"provider", "domain", "host", "port", "status"})
UserLookupHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{Name: "auth_user_lookup_duration_seconds", Help: "Duration of user lookup in seconds"}, []string{"provider", "domain", "host", "status"})
ServerReconnectCounter = promauto.NewCounterVec(prometheus.CounterOpts{Name: "auth_server_reconnects_total", Help: "Total number of server reconnects"}, []string{"provider", "domain"})
}

func SplitUserId(userId string) (string, string, error) {
parts := strings.Split(userId, "@")
if len(parts) != 2 {
Expand Down
15 changes: 13 additions & 2 deletions pkg/auth/userauth/activedirectory/ad.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package activedirectory

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand All @@ -21,7 +22,8 @@ import (
var DefaultTimeout = 10 * time.Second

type AdConfig struct {
Domain string `json:"domain"`
Domain string `json:"domain"`
server string
BindUser string `json:"bindUser"`
BindPassword string `json:"bindPassword"`
BaseDN string `json:"basedn"`
Expand Down Expand Up @@ -90,6 +92,7 @@ func (l *AdClient) Connect() error {

for _, ldapserver := range l.config.Servers {
rlog.Infof("Trying server %s for domain %s.", ldapserver.Host, l.config.Domain)
connectionStart := time.Now()
if l.config.Certificate != nil {
caCert := l.config.Certificate
caCertPool := x509.NewCertPool()
Expand All @@ -109,14 +112,18 @@ func (l *AdClient) Connect() error {

if err != nil {
rlog.Error("an error occurred connecting to LDAP-host.", err, rlog.Any("Host", ldapserver.Host), rlog.Any("Port", ldapserver.Port))
authtools.ServerConnectionHistogram.WithLabelValues("ad", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "500").Observe(time.Since(connectionStart).Seconds())
continue
}

err = client.Bind(l.config.BindUser, l.config.BindPassword)
if err != nil {
rlog.Error("an error occurred authenticating to LDAP-host.", err, rlog.Any("Host", ldapserver.Host), rlog.Any("Port", ldapserver.Port), rlog.Any("BindUser", l.config.BindUser))
authtools.ServerConnectionHistogram.WithLabelValues("ad", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "401").Observe(time.Since(connectionStart).Seconds())
} else {
rlog.Infof("Connected to server server %s for domain %s.", ldapserver.Host, l.config.Domain)
l.config.server = ldapserver.Host
authtools.ServerConnectionHistogram.WithLabelValues("ad", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "200").Observe(time.Since(connectionStart).Seconds())
break
}
}
Expand Down Expand Up @@ -152,7 +159,7 @@ func (l *AdClient) search(basedn, filter string, attributes []string) (*ldap.Sea
return nil, fmt.Errorf("could not fetch search entries")
}

func (l *AdClient) GetUser(userId string) (*identitymodels.User, error) {
func (l *AdClient) GetUser(ctx context.Context, userId string) (*identitymodels.User, error) {

userpart, domainpart, err := authtools.SplitUserId(userId)
if err != nil {
Expand All @@ -163,17 +170,21 @@ func (l *AdClient) GetUser(userId string) (*identitymodels.User, error) {

if l.connection.IsClosing() {
rlog.Debug("Reconnecting to Active Directory")
authtools.ServerReconnectCounter.WithLabelValues("ad", l.config.Domain).Inc()
err := l.Connect()
if err != nil {
return nil, err
}
}

queryStart := time.Now()
result, err := l.search(l.config.BaseDN, filter, attributes)

if err != nil {
authtools.UserLookupHistogram.WithLabelValues("ad", l.config.Domain, l.config.server, "500").Observe(time.Since(queryStart).Seconds())
return nil, err
}
authtools.UserLookupHistogram.WithLabelValues("ad", l.config.Domain, l.config.server, "200").Observe(time.Since(queryStart).Seconds())
var userEntry *ldap.Entry
if result != nil && len(result.Entries) == 1 {
for _, entry := range result.Entries {
Expand Down
19 changes: 16 additions & 3 deletions pkg/auth/userauth/ldaps/openldap.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldaps

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand All @@ -18,7 +19,8 @@ import (
var DefaultTimeout = 10 * time.Second

type LdapConfig struct {
Domain string `json:"domain"`
Domain string `json:"domain"`
server string
BindUser string `json:"bindUser"`
BindPassword string `json:"bindPassword"`
BaseDN string `json:"basedn"`
Expand Down Expand Up @@ -56,6 +58,7 @@ func (l *LdapsClient) Connect() error {
if err != nil {
return fmt.Errorf("failed to parse default ldaps port")
}
connectionStart := time.Now()
if ldapserver.Port == ldapsport {
caCert := l.config.Certificate
caCertPool := x509.NewCertPool()
Expand All @@ -75,15 +78,19 @@ func (l *LdapsClient) Connect() error {

if err != nil {
rlog.Error("an error occurred connecting to LDAP-host.", err, rlog.Any("Host", ldapserver.Host), rlog.Any("Port", ldapserver.Port))
authtools.ServerConnectionHistogram.WithLabelValues("openldap", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "500").Observe(time.Since(connectionStart).Seconds())
}

err = client.Bind(l.config.BindUser, l.config.BindPassword)
if err != nil {
rlog.Error("an error occurred authenticating to LDAP-host.", err, rlog.Any("Host", ldapserver.Host), rlog.Any("Port", ldapserver.Port), rlog.Any("BindUser", l.config.BindUser))
authtools.ServerConnectionHistogram.WithLabelValues("openldap", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "401").Observe(time.Since(connectionStart).Seconds())
} else {
rlog.Infof("Connected to server server %s for domain %s.", ldapserver.Host, l.config.Domain)
l.config.server = ldapserver.Host
authtools.ServerConnectionHistogram.WithLabelValues("openldap", l.config.Domain, ldapserver.Host, strconv.Itoa(ldapserver.Port), "200").Observe(time.Since(connectionStart).Seconds())
break
}

}

if client == nil {
Expand Down Expand Up @@ -117,7 +124,7 @@ func (l *LdapsClient) search(basedn, filter string, attributes []string) (*ldap.
return nil, fmt.Errorf("could not fetch search entries")
}

func (l *LdapsClient) GetUser(userId string) (*identitymodels.User, error) {
func (l *LdapsClient) GetUser(ctx context.Context, userId string) (*identitymodels.User, error) {

_, domainpart, err := authtools.SplitUserId(userId)
if err != nil {
Expand All @@ -131,17 +138,23 @@ func (l *LdapsClient) GetUser(userId string) (*identitymodels.User, error) {

if l.connection.IsClosing() {
rlog.Debug("Reconnecting to LDAP")
authtools.ServerReconnectCounter.WithLabelValues("openldap", l.config.Domain).Inc()
err := l.Connect()
if err != nil {
return nil, err
}
}

queryStart := time.Now()
result, err := l.search(l.config.BaseDN, filter, attributes)

if err != nil {
authtools.UserLookupHistogram.WithLabelValues("openldap", l.config.Domain, l.config.server, "500").Observe(time.Since(queryStart).Seconds())

return nil, err
}
authtools.UserLookupHistogram.WithLabelValues("openldap", l.config.Domain, l.config.server, "200").Observe(time.Since(queryStart).Seconds())

var userEntry *ldap.Entry
if result != nil && len(result.Entries) == 1 {
for _, entry := range result.Entries {
Expand Down
16 changes: 12 additions & 4 deletions pkg/auth/userauth/msgraph/msgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package msgraph
import (
"context"
"fmt"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/NorskHelsenett/ror/pkg/auth/authtools"
Expand All @@ -18,6 +19,8 @@ import (
graphusers "github.com/microsoftgraph/msgraph-sdk-go/users"
)

var ApiEndpoint = "https://graph.microsoft.com/.default"

type MsGraphConfig struct {
Domain string `json:"domain"`
TenantID string `json:"tenantId"`
Expand Down Expand Up @@ -47,18 +50,21 @@ func NewMsGraphClient(config MsGraphConfig, cacheHelper kvcachehelper.CacheInter
} else {
client.GroupCache = memorycache.NewKvCache()
}

connectionStart := time.Now()
cred, err := azidentity.NewClientSecretCredential(client.config.TenantID, client.config.ClientID, client.config.ClientSecret, nil)
if err != nil {
return nil, err
}

conn, err := msgraphsdk.NewGraphServiceClientWithCredentials(
cred, []string{"https://graph.microsoft.com/.default"},
cred, []string{ApiEndpoint},
)

if err != nil {
authtools.ServerConnectionHistogram.WithLabelValues("msgraph", config.Domain, ApiEndpoint, "443", "500").Observe(time.Since(connectionStart).Seconds())
return nil, err
}
authtools.ServerConnectionHistogram.WithLabelValues("msgraph", config.Domain, ApiEndpoint, "443", "200").Observe(time.Since(connectionStart).Seconds())
rlog.Infof("Connected to msgraph api for domain %s.", config.Domain)
client.Client = conn
return client, nil
Expand All @@ -67,15 +73,15 @@ func NewMsGraphClient(config MsGraphConfig, cacheHelper kvcachehelper.CacheInter
// GetUsersWithGroups gets a user and the name of the groups the user is a member of
// TODO: Implement isExpired
// TODO: Implement isDisabled...
func (g *MsGraphClient) GetUser(userId string) (*identitymodels.User, error) {
func (g *MsGraphClient) GetUser(ctx context.Context, userId string) (*identitymodels.User, error) {
var ret *identitymodels.User
var groupnames []string = []string{}
var user models.Userable

groupsChan := make(chan []string)
userChan := make(chan models.Userable)
errorChan := make(chan error)

queryStart := time.Now()
go g.getUser(userId, userChan, errorChan)
go g.getGroups(userId, groupsChan, errorChan)

Expand All @@ -90,11 +96,13 @@ func (g *MsGraphClient) GetUser(userId string) (*identitymodels.User, error) {
case returneUser := <-userChan:
user = returneUser
case err := <-errorChan:
authtools.UserLookupHistogram.WithLabelValues("msgraph", g.config.Domain, ApiEndpoint, "500").Observe(time.Since(queryStart).Seconds())
return nil, err
}
}

addDomainpartToGroups(&groupnames, userId)
authtools.UserLookupHistogram.WithLabelValues("msgraph", g.config.Domain, ApiEndpoint, "200").Observe(time.Since(queryStart).Seconds())

ret = &identitymodels.User{
Email: *user.GetUserPrincipalName(),
Expand Down
7 changes: 4 additions & 3 deletions pkg/auth/userauth/userauth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package userauth

import (
"context"
"encoding/json"
"fmt"
"strconv"
Expand All @@ -26,15 +27,15 @@ type DomainResolverConfig struct {
}

type DomainResolverInterface interface {
GetUser(userId string) (*identitymodels.User, error)
GetUser(ctx context.Context, userId string) (*identitymodels.User, error)
CheckHealth() []newhealth.Check
}

type DomainResolvers struct {
resolvers map[string]DomainResolverInterface
}

func (d DomainResolvers) GetUser(userId string) (*identitymodels.User, error) {
func (d DomainResolvers) GetUser(ctx context.Context, userId string) (*identitymodels.User, error) {
_, domain, err := authtools.SplitUserId(userId)
if err != nil {
return nil, err
Expand All @@ -45,7 +46,7 @@ func (d DomainResolvers) GetUser(userId string) (*identitymodels.User, error) {
}

if domainResolver, ok := d.resolvers[domain]; ok {
return domainResolver.GetUser(userId)
return domainResolver.GetUser(ctx, userId)
}
return nil, fmt.Errorf("no domain resolver found for domain: %s", domain)
}
Expand Down

0 comments on commit f5765b0

Please sign in to comment.