Skip to content

Commit

Permalink
chore: more tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
lsjostro committed Feb 14, 2024
1 parent 137f644 commit b9cb4e8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
47 changes: 39 additions & 8 deletions authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.23.1"
"go.opentelemetry.io/otel/trace"
"golang.org/x/net/http2"
rpcstatus "google.golang.org/genproto/googleapis/rpc/status"

Expand Down Expand Up @@ -87,6 +90,9 @@ func (s *Service) Name() string {
}

func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequest]) (*connect.Response[auth.CheckResponse], error) {
ctx, span := tracer.Start(ctx, "Check")
defer span.End()

httpReq := req.Msg.GetAttributes().GetRequest().GetHttp()
reqHeaders := httpReq.GetHeaders()
var resp *auth.CheckResponse
Expand All @@ -108,14 +114,31 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
}
if provider == nil {
slog.Error("no header matches any provider")
span.RecordError(errors.New("no header matches any auth provider"))
span.SetStatus(codes.Error, "no header matches any auth provider")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Unauthorized, nil, nil, "no header matches any auth provider")), nil
}

span.AddEvent("provider",
trace.WithAttributes(
attribute.String("issuer_url", provider.IssuerURL),
attribute.String("client_id", provider.ClientID),
attribute.String("callback_uri", provider.CallbackURI),
attribute.String("cookie_name_prefix", provider.CookieNamePrefix),
attribute.Bool("opa_enabled", provider.OPAEnabled),
attribute.Bool("allow_auth_header", provider.AllowAuthHeader),
attribute.Bool("secure_cookie", provider.SecureCookie),
attribute.StringSlice("scopes", provider.Scopes),
attribute.String("header_match_name", provider.HeaderMatch.Name),
),
)

if provider.OPAEnabled && s.authClient != nil {
slog.Debug("OPA is enabled, sending request to OPA for authorization")
opaResp, err := s.authClient.Check(ctx, req)
if err != nil {
slog.Error("OPA check failed", slog.String("err", err.Error()))
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if opaResp.Msg.GetStatus().GetCode() == int32(rpc.PERMISSION_DENIED) {
Expand All @@ -131,6 +154,8 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
resp, err := s.authProcess(ctx, httpReq, provider)
if err != nil {
slog.Error("authProccess failed", slog.String("err", err.Error()))
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
resp = s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())
}

Expand All @@ -156,24 +181,29 @@ func (s *Service) authProcess(ctx context.Context, req *auth.AttributeContext_Ht
slog.Debug("session data not found in cookie, creating new")
headers, err := s.newSession(ctx, requestedURL, sessionCookieName, provider)
if err != nil {
span.RecordError(err)
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}
// set downstream headers and redirect to Idp
return s.authResponse(false, envoy_type.StatusCode_Found, headers, nil, "redirect to Idp"), nil
}

slog.Debug("session data found in cookie", slog.String("session_id", sessionId), slog.String("url", requestedURL))
span.SetAttributes(
attribute.String("session_id", sessionId),
attribute.String("requested_url", requestedURL),
)
if span.IsRecording() {
span.SetAttributes(
semconv.URLFull(requestedURL),
semconv.SourceAddress(req.GetHeaders()["x-forwarded-for"]),
semconv.SessionID(sessionId),
)
}

// If the request is for the callback URI, then we need to exchange the code for tokens
if strings.HasPrefix(requestedURL, provider.CallbackURI+"?") && sessionData.AccessToken == "" {
err := s.retriveTokens(ctx, provider, sessionData, requestedURL, sessionCookieName, sessionId)
if err != nil {
span.RecordError(err)
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}
// set downstream headers and redirect client to requested URL from session cookie
Expand All @@ -187,7 +217,8 @@ func (s *Service) authProcess(ctx context.Context, req *auth.AttributeContext_Ht
slog.Warn("couldn't validating tokens", slog.String("err", err.Error()))
headers, err := s.newSession(ctx, requestedURL, sessionCookieName, provider)
if err != nil {
span.RecordError(err)
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return s.authResponse(false, envoy_type.StatusCode_Found, headers, nil, "redirect to Idp"), nil
Expand Down
36 changes: 33 additions & 3 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

// Create auth provicer interface
Expand Down Expand Up @@ -92,10 +94,10 @@ func (o *OIDCProvider) VerifyTokens(ctx context.Context, accessToken, idToken st
return false, err
}
}
span.SetAttributes(
span.AddEvent("log", trace.WithAttributes(
attribute.String("issuer", t.GetIssuer()),
attribute.String("expire", t.GetExpiration().String()),
attribute.Bool("has_expired", expired),
attribute.Bool("has_expired", expired)),
)
return expired, nil
}
Expand All @@ -105,7 +107,7 @@ func (o *OIDCProvider) VerifyTokens(ctx context.Context, accessToken, idToken st
func (o *OIDCProvider) RetriveTokens(ctx context.Context, code, codeVerifier string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
ctx, span := tracer.Start(ctx, "RetriveTokens")
defer span.End()
slog.Debug("retriving tokens", slog.String("authorization_code", code), slog.String("code_verifier", codeVerifier))

var opts []rp.CodeExchangeOpt

if o.isPKCE {
Expand All @@ -115,26 +117,54 @@ func (o *OIDCProvider) RetriveTokens(ctx context.Context, code, codeVerifier str

tokens, err := rp.CodeExchange[*oidc.IDTokenClaims](ctx, code, o.provider, opts...)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
slog.Error("retriving token", slog.String("err", err.Error()))
return nil, err
}

if !tokens.Valid() {
span.RecordError(errors.New("RetriveTokens: invalid token"))
span.SetStatus(codes.Error, "RetriveTokens: invalid token")
return nil, errors.New("RetriveTokens: invalid token")
}

span.AddEvent("log",
trace.WithAttributes(
attribute.String("issuer", tokens.IDTokenClaims.GetIssuer()),
attribute.Bool("is_pkce", o.isPKCE),
attribute.String("expire", tokens.IDTokenClaims.GetExpiration().String()),
attribute.String("subject", tokens.IDTokenClaims.GetSubject()),
),
)

return tokens, nil
}

// RefreshTokens refreshes the tokens and returns them
// clientAssertion is the client assertion jwt (tokens.AccessToken)
func (o *OIDCProvider) RefreshTokens(ctx context.Context, refreshToken, clientAssertion string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
ctx, span := tracer.Start(ctx, "RefreshTokens")
defer span.End()

tokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](ctx, o.provider, refreshToken, clientAssertion, oidc.ClientAssertionTypeJWTAssertion)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if !tokens.Valid() {
span.RecordError(errors.New("RefreshTokens: invalid token"))
span.SetStatus(codes.Error, "RefreshTokens: invalid token")
return nil, errors.New("RefreshTokens: invalid token")
}
span.AddEvent("log",
trace.WithAttributes(
attribute.String("issuer", tokens.IDTokenClaims.GetIssuer()),
attribute.Bool("is_pkce", o.isPKCE),
attribute.String("expire", tokens.IDTokenClaims.GetExpiration().String()),
attribute.String("subject", tokens.IDTokenClaims.GetSubject()),
),
)
return tokens, nil
}
9 changes: 9 additions & 0 deletions session/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/nacl/secretbox"
"google.golang.org/protobuf/proto"

Expand All @@ -23,12 +25,16 @@ func EncryptSession(ctx context.Context, key [32]byte, sessionData *pb.SessionDa

message, err := proto.Marshal(sessionData)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}

var nonce [24]byte
_, err = rand.Read(nonce[:])
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}

Expand All @@ -40,6 +46,7 @@ func EncryptSession(ctx context.Context, key [32]byte, sessionData *pb.SessionDa
func DecryptSession(ctx context.Context, key [32]byte, box []byte) (*pb.SessionData, error) {
_, span := tracer.Start(ctx, "DecryptSession")
defer span.End()

if len(box) < 24 {
return nil, errInvalid
}
Expand All @@ -53,6 +60,8 @@ func DecryptSession(ctx context.Context, key [32]byte, box []byte) (*pb.SessionD

sessionData := &pb.SessionData{}
if err := proto.Unmarshal(message, sessionData); err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return sessionData, nil
Expand Down

0 comments on commit b9cb4e8

Please sign in to comment.