Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SSO login via multiple AssertionConsumerServiceURLs. #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions build_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ import (

const issueInstantFormat = "2006-01-02T15:04:05Z"

func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool) (*etree.Document, error) {
func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool, assertionConsumerServiceURL ...string) (*etree.Document, error) {
authnRequest := &etree.Element{
Space: "samlp",
Tag: "AuthnRequest",
}

// When login via multiple ACS urls are supported the AuthnRequest will specify
// the ACS url explicitly. If none is specified, use the ACS url from the SP.
var acsUrl string
if len(assertionConsumerServiceURL) > 0 {
acsUrl = assertionConsumerServiceURL[0]
} else {
acsUrl = sp.AssertionConsumerServiceURL
}

authnRequest.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
authnRequest.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")

Expand All @@ -29,7 +38,7 @@ func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool) (*etree.Docume
authnRequest.CreateAttr("ID", "_"+arId.String())
authnRequest.CreateAttr("Version", "2.0")
authnRequest.CreateAttr("ProtocolBinding", "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST")
authnRequest.CreateAttr("AssertionConsumerServiceURL", sp.AssertionConsumerServiceURL)
authnRequest.CreateAttr("AssertionConsumerServiceURL", acsUrl)
authnRequest.CreateAttr("IssueInstant", sp.Clock.Now().UTC().Format(issueInstantFormat))
authnRequest.CreateAttr("Destination", sp.IdentityProviderSSOURL)

Expand Down Expand Up @@ -72,12 +81,12 @@ func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool) (*etree.Docume
return doc, nil
}

func (sp *SAMLServiceProvider) BuildAuthRequestDocument() (*etree.Document, error) {
return sp.buildAuthnRequest(true)
func (sp *SAMLServiceProvider) BuildAuthRequestDocument(acsUrl ...string) (*etree.Document, error) {
return sp.buildAuthnRequest(true, acsUrl...)
}

func (sp *SAMLServiceProvider) BuildAuthRequestDocumentNoSig() (*etree.Document, error) {
return sp.buildAuthnRequest(false)
func (sp *SAMLServiceProvider) BuildAuthRequestDocumentNoSig(acsUrl ...string) (*etree.Document, error) {
return sp.buildAuthnRequest(false, acsUrl...)
}

// SignAuthnRequest takes a document, builds a signature, creates another document
Expand Down
25 changes: 25 additions & 0 deletions build_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,28 @@ func TestRequestedAuthnContextIncluded(t *testing.T) {
require.Equal(t, el.Tag, "AuthnContextClassRef")
require.Equal(t, el.Text(), AuthnContextPasswordProtectedTransport)
}

func TestBuildAuthRequestDocumentWithCustomAcsUrl(t *testing.T) {
spURL := "https://sp.test"
sp := SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
MultiAssertionConsumerServiceURLs: []string{"https://sp.test1", "https://sp.test2", "https://sp.test3"},
AudienceURI: spURL,
IdentityProviderIssuer: spURL,
IdentityProviderSSOURL: "https://idp.test/saml/sso",
SignAuthnRequests: false,
}
// Case where ACS url is specified explicitly in the BuildAuthRequestDocument
doc, err := sp.BuildAuthRequestDocument("https://sp.test2")
require.NoError(t, err)
el := doc.FindElement("samlp:AuthnRequest")
require.Equal(t, el.SelectAttrValue("AssertionConsumerServiceURL", ""), "https://sp.test2")

// Case where no ACS url is specified in the BuildAuthRequestDocument.
// The AssertionConsumerServiceURL is supposed to be used in this case.
doc, err = sp.BuildAuthRequestDocument()
require.NoError(t, err)
el = doc.FindElement("samlp:AuthnRequest")
require.Equal(t, el.SelectAttrValue("AssertionConsumerServiceURL", ""), "https://sp.test")

}
19 changes: 16 additions & 3 deletions decode_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,22 @@ func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext {
// validateResponseAttributes validates a SAML Response's tag and attributes. It does
// not inspect child elements of the Response at all.
func (sp *SAMLServiceProvider) validateResponseAttributes(response *types.Response) error {
if response.Destination != "" && response.Destination != sp.AssertionConsumerServiceURL {
destMatched := false
if len(sp.MultiAssertionConsumerServiceURLs) <= 1 {
if response.Destination == "" || response.Destination == sp.AssertionConsumerServiceURL {
destMatched = true
}
} else {
// Multiple ACS urls configured. Match the destination with any one of them.
for _, configuredAcsUrl := range sp.MultiAssertionConsumerServiceURLs {
if response.Destination == "" || response.Destination == configuredAcsUrl {
destMatched = true
break
}
}
}

if !destMatched {
return ErrInvalidValue{
Key: DestinationAttr,
Expected: sp.AssertionConsumerServiceURL,
Expand Down Expand Up @@ -464,5 +479,3 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponseRedirect(encodedResp
return decodedResponse, nil
}
*/


148 changes: 148 additions & 0 deletions decode_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package saml2
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"github.com/russellhaering/gosaml2/types"
"io/ioutil"
"testing"
"time"
Expand Down Expand Up @@ -118,3 +120,149 @@ func TestCompressedResponse(t *testing.T) {
_, err = sp.RetrieveAssertionInfo(string(bs))
require.NoError(t, err, "Assertion info should be retrieved with no error")
}

func TestValidateResponseAttributesForMultiAcsUrls(t *testing.T) {
spURL := "myhost.test.com"
sp := SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
MultiAssertionConsumerServiceURLs: []string{"https://myhost-kube1-node1.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node2.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node3.test.com:443/sp/ACS.saml2"},
AudienceURI: spURL,
SignAuthnRequests: false,
}

bs, err := ioutil.ReadFile("./providertests/testdata/oktaenc_response_multi_acs.b64")
require.NoError(t, err, "couldn't read the response")

raw, err := base64.StdEncoding.DecodeString(string(bs))
require.NoError(t, err, "Couldn't decode encoded response.")

// Parse the raw response
_, el, err := parseResponse(raw)
if err != nil {
require.NoError(t, err, "Couldn't parse the response.")
}

decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
require.NoError(t, err, "Couldn't unmarshall the response.")

// Good case, when destination in the response matches one of the ACS urls configured.
err = sp.validateResponseAttributes(decodedResponse)
require.NoError(t, err, "Couldn't validate the saml response attributes.")

sp = SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
MultiAssertionConsumerServiceURLs: []string{"https://myhost-kube1-node0.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node2.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node3.test.com:443/sp/ACS.saml2"},
AudienceURI: spURL,
SignAuthnRequests: false,
}
// Response does not contain one of the ACS urls. Expect destination mismatch error.
err = sp.validateResponseAttributes(decodedResponse)
require.Error(t, err)
require.Contains(t, err.Error(), "Unrecognized Destination value")

}

func TestValidateResponseAttributes(t *testing.T) {
spURL := "https://myhost-kube1-node1.test.com:443/sp/ACS.saml2"
sp := SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
AudienceURI: spURL,
SignAuthnRequests: false,
}

bs, err := ioutil.ReadFile("./providertests/testdata/oktaenc_response_multi_acs.b64")
require.NoError(t, err, "couldn't read the response")

raw, err := base64.StdEncoding.DecodeString(string(bs))
require.NoError(t, err, "Couldn't decode encoded response.")

// Parse the raw response
_, el, err := parseResponse(raw)
if err != nil {
require.NoError(t, err, "Couldn't parse the response.")
}

decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
require.NoError(t, err, "Couldn't unmarshall the response.")

// Good case, when destination in the response matches the ACS urls configured.
err = sp.validateResponseAttributes(decodedResponse)
require.NoError(t, err, "Couldn't validate the saml response attributes.")

sp = SAMLServiceProvider{
AssertionConsumerServiceURL: "https://nomatch.test.com:443/sp/ACS.saml2",
AudienceURI: spURL,
SignAuthnRequests: false,
}
// Response does not contain the ACS urls. Expect destination mismatch error.
err = sp.validateResponseAttributes(decodedResponse)
require.Error(t, err)
require.Contains(t, err.Error(), "Unrecognized Destination value")

}

func TestValidateSubjectConfirmationDataRecipient(t *testing.T) {
spURL := "https://myhost-kube1-node1.test.com:443/sp/ACS.saml2"
sp := SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
AudienceURI: spURL,
SignAuthnRequests: false,
IdentityProviderIssuer: "http://www.okta.com/exk5lexwyipqCztUz5d7",
Clock: dsig.NewFakeClockAt(time.Date(2022, 9, 13, 20, 30, 00, 00, time.UTC)),
}

bs, err := ioutil.ReadFile("./providertests/testdata/oktaenc_response_multi_acs.b64")
require.NoError(t, err, "couldn't read the response")

raw, err := base64.StdEncoding.DecodeString(string(bs))
require.NoError(t, err, "Couldn't decode encoded response.")

// Parse the raw response
_, el, err := parseResponse(raw)
if err != nil {
require.NoError(t, err, "Couldn't parse the response.")
}

decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
require.NoError(t, err, "Couldn't unmarshall the response.")

// Good case, when recipient in the response matches the ACS urls configured.
err = sp.Validate(decodedResponse)
require.NoError(t, err, "Couldn't validate the saml response.")

}

func TestValidateSubjectConfirmationDataRecipientForMultiAcsUrls(t *testing.T) {
spURL := "myhost.test.com"
sp := SAMLServiceProvider{
AssertionConsumerServiceURL: spURL,
MultiAssertionConsumerServiceURLs: []string{"https://myhost-kube1-node1.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node2.test.com:443/sp/ACS.saml2", "https://myhost-kube1-node3.test.com:443/sp/ACS.saml2"},
AudienceURI: spURL,
SignAuthnRequests: false,
Clock: dsig.NewFakeClockAt(time.Date(2022, 9, 13, 20, 30, 00, 00, time.UTC)),
}

bs, err := ioutil.ReadFile("./providertests/testdata/oktaenc_response_multi_acs.b64")
require.NoError(t, err, "couldn't read the response")

raw, err := base64.StdEncoding.DecodeString(string(bs))
require.NoError(t, err, "Couldn't decode encoded response.")

// Parse the raw response
_, el, err := parseResponse(raw)
if err != nil {
require.NoError(t, err, "Couldn't parse the response.")
}

decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
require.NoError(t, err, "Couldn't unmarshall the response.")

// Good case, when recipient in the response matches one of the ACS urls configured.
err = sp.Validate(decodedResponse)
require.NoError(t, err, "Couldn't validate the saml response.")

}
Loading