Skip to content

Commit

Permalink
fix: corrected inflate and deflate logic for xml, and corresponding t…
Browse files Browse the repository at this point in the history
…ests with some refactoring for SAML sessions (#92)

* refactor: remove separate functions for error reasons

* fix: corrected inflate and deflate logic, and corresponding tests

* fix: corrected inflate and deflate logic, and corresponding tests

* fix: corrected inflate and deflate logic, and corresponding tests

* fix: switch over status codes to errors

* fix: review changes

* fix: review changes
  • Loading branch information
stebenz authored Dec 11, 2024
1 parent 51c410d commit 95c785b
Show file tree
Hide file tree
Showing 15 changed files with 255 additions and 390 deletions.
8 changes: 6 additions & 2 deletions pkg/provider/attribute_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat)
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.TimeFormat, p.Expiration)
return nil
},
func() {
Expand All @@ -139,7 +139,11 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
// create enveloped signature
checkerInstance.WithLogicStep(
func() error {
return createPostSignature(r.Context(), response, p)
cert, key, err := getResponseCert(r.Context(), p.storage)
if err != nil {
return err
}
return createPostSignature(response, key, cert, p.conf.SignatureAlgorithm)
},
func() {
http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError)
Expand Down
8 changes: 5 additions & 3 deletions pkg/provider/identityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ type IdentityProvider struct {
metadataEndpoint *Endpoint
endpoints *Endpoints

timeFormat string
TimeFormat string
Expiration time.Duration
}

type Endpoints struct {
Expand All @@ -90,7 +91,8 @@ func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storag
postTemplate: conf.PostTemplate,
logoutTemplate: conf.LogoutTemplate,
endpoints: endpointConfigToEndpoints(conf.Endpoints),
timeFormat: DefaultTimeFormat,
TimeFormat: DefaultTimeFormat,
Expiration: DefaultExpiration,
}

if conf.PostTemplate == nil {
Expand Down Expand Up @@ -160,7 +162,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto
return nil, nil, err
}

metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat)
metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.TimeFormat)
return metadata, aaMetadata, nil
}

Expand Down
59 changes: 34 additions & 25 deletions pkg/provider/login.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package provider

import (
"context"
"fmt"
"net/http"

"github.com/zitadel/logging"

"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)

func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Request) {
Expand All @@ -16,7 +20,6 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
Issuer: p.GetEntityID(r.Context()),
}

ctx := r.Context()
if err := r.ParseForm(); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
Expand All @@ -34,52 +37,58 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID)
if err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat))
response.sendBackResponse(r, w, p.errorResponse(response, StatusCodeRequestDenied, fmt.Errorf("failed to get request: %w", err).Error()))
return
}
response.RequestID = authRequest.GetAuthRequestID()
response.RelayState = authRequest.GetRelayState()
response.ProtocolBinding = authRequest.GetBindingType()
response.AcsUrl = authRequest.GetAccessConsumerServiceURL()

if !authRequest.Done() {
entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
return
}
response.Audience = entityID

entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
samlResponse, err := p.loginResponse(r.Context(), authRequest, response)
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat))
return
}
response.Audience = entityID

response.sendBackResponse(r, w, samlResponse)
return
}

func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
if !authRequest.Done() {
logging.Error(StatusCodeAuthNFailed)
return nil, fmt.Errorf(StatusCodeAuthNFailed)
}

attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
return
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat)
cert, key, err := getResponseCert(ctx, p.storage)
if err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

switch response.ProtocolBinding {
case PostBinding:
if err := createPostSignature(r.Context(), samlResponse, p); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
case RedirectBinding:
if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration)
if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeResponder)
}
return samlResponse, nil
}

response.sendBackResponse(r, w, samlResponse)
return
func (p *IdentityProvider) errorResponse(response *Response, reason string, description string) *samlp.ResponseType {
return response.makeFailedResponse(reason, description, p.TimeFormat)
}
33 changes: 19 additions & 14 deletions pkg/provider/login_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package provider

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang/mock/gomock"
Expand All @@ -23,9 +23,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
Done bool
}
type res struct {
code int
err bool
state string
code int
err bool
state string
inflate bool
b64 bool
}
type sp struct {
appID string
Expand Down Expand Up @@ -235,7 +237,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
ID: "test",
AuthRequestID: "test",
Binding: RedirectBinding,
AcsURL: "url",
AcsURL: "https://sp.example.com",
RelayState: "relaystate",
UserID: "userid",
Done: false,
Expand All @@ -247,9 +249,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
},
},
res{
code: 500,
state: "",
err: false,
code: 302,
state: StatusCodeAuthNFailed,
err: false,
inflate: true,
b64: true,
}},
}

Expand Down Expand Up @@ -297,14 +301,15 @@ func TestSSO_loginHandleFunc(t *testing.T) {
defer func() {
_ = res.Body.Close()
}()
response, err := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.res.code {
t.Errorf("ssoHandleFunc() code got = %v, want %v", res.StatusCode, tt.res)
return
}

// currently only checked for redirect binding
if tt.res.state != "" {
if err := parseForState(string(response), tt.res.state); err != nil {
responseURL, err := url.Parse(res.Header.Get("Location"))
if err != nil {
t.Errorf("error while parsing url")
}

if err := parseForState(tt.res.inflate, tt.res.b64, responseURL.Query().Get("SAMLResponse"), tt.res.state); err != nil {
t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state)
return
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -69,10 +69,10 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
checkIfRequestTimeIsStillValid(
func() string { return logoutRequest.IssueInstant },
func() string { return logoutRequest.NotOnOrAfter },
p.timeFormat,
p.TimeFormat,
),
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -106,7 +106,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque

response.sendBackLogoutResponse(
w,
response.makeSuccessfulLogoutResponse(p.timeFormat),
response.makeSuccessfulLogoutResponse(p.TimeFormat),
)
logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text))
}
Expand Down
39 changes: 6 additions & 33 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,55 +55,28 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam
}
}

func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeUnsupportedlLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestUnsupported,
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makePartialLogoutResponse(
func (r *LogoutResponse) makeFailedLogoutResponse(
reason string,
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodePartialLogout,
reason,
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeDeniedLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestDenied,
message,
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
)
}
Expand Down
14 changes: 5 additions & 9 deletions pkg/provider/post.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package provider

import (
"context"
"crypto/rsa"
"encoding/base64"
"reflect"

Expand Down Expand Up @@ -63,16 +63,12 @@ func verifyPostSignature(
}

func createPostSignature(
ctx context.Context,
samlResponse *samlp.ResponseType,
idp *IdentityProvider,
key *rsa.PrivateKey,
cert []byte,
signatureAlgorithm string,
) error {
cert, key, err := getResponseCert(ctx, idp.storage)
if err != nil {
return err
}

signer, err := signature.GetSigner(cert, key, idp.conf.SignatureAlgorithm)
signer, err := signature.GetSigner(cert, key, signatureAlgorithm)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 95c785b

Please sign in to comment.