From a97ac98bfe61a37a140a98fa2b06990b63fabb91 Mon Sep 17 00:00:00 2001 From: Sword Date: Tue, 21 Nov 2023 05:19:24 +0000 Subject: [PATCH 1/2] Support custom idp entity id --- identity_provider.go | 21 ++++++++- identity_provider_test.go | 47 ++++++++++++++++--- ...esponse_response_with_custom_entity_id.xml | 27 +++++++++++ 3 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml diff --git a/identity_provider.go b/identity_provider.go index abaaad68..135915e7 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -107,8 +107,11 @@ type IdentityProvider struct { AssertionMaker AssertionMaker SignatureMethod string ValidDuration *time.Duration + EntityIDConstructor EntityIDConstructor } +type EntityIDConstructor func() string + // Metadata returns the metadata structure for this identity provider. func (idp *IdentityProvider) Metadata() *EntityDescriptor { certStr := base64.StdEncoding.EncodeToString(idp.Certificate.Raw) @@ -121,7 +124,7 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor { } ed := &EntityDescriptor{ - EntityID: idp.MetadataURL.String(), + EntityID: idp.getEntityID(), ValidUntil: TimeNow().Add(validDuration), CacheDuration: validDuration, IDPSSODescriptors: []IDPSSODescriptor{ @@ -334,6 +337,20 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re } } +// createDefaultEntityIDConstructor creates a function to return entityID from metadataURL. +func createDefaultEntityIDConstructor(metadataURL url.URL) func() string { + return func() string { + return metadataURL.String() + } +} + +func (idp *IdentityProvider) getEntityID() string { + if idp.EntityIDConstructor == nil { + return createDefaultEntityIDConstructor(idp.MetadataURL)() + } + return idp.EntityIDConstructor() +} + // IdpAuthnRequest is used by IdentityProvider to handle a single authentication request. type IdpAuthnRequest struct { IDP *IdentityProvider @@ -1019,7 +1036,7 @@ func (req *IdpAuthnRequest) MakeResponse() error { Version: "2.0", Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: req.IDP.MetadataURL.String(), + Value: req.IDP.getEntityID(), }, Status: Status{ StatusCode: StatusCode{ diff --git a/identity_provider_test.go b/identity_provider_test.go index 9d06a4bb..92729112 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -38,11 +38,12 @@ type IdentityProviderTest struct { SPCertificate *x509.Certificate SP ServiceProvider - Key crypto.PrivateKey - Signer crypto.Signer - Certificate *x509.Certificate - SessionProvider SessionProvider - IDP IdentityProvider + Key crypto.PrivateKey + Signer crypto.Signer + Certificate *x509.Certificate + SessionProvider SessionProvider + IDP IdentityProvider + ExpectedFilename string } func mustParseURL(s string) url.URL { @@ -98,6 +99,24 @@ var applySigner = idpTestOpts{ }, } +// applyEntityIDConstructor will set the entity ID constructor for the identity provider. +func applyEntityIDConstructor(c EntityIDConstructor) idpTestOpts { + return idpTestOpts{ + apply: func(_ *testing.T, test *IdentityProviderTest) { + test.IDP.EntityIDConstructor = c + }, + } +} + +// applyExpectedFilename will set the expected filename for the identity provider. +func applyExpectedFilename(filename string) idpTestOpts { + return idpTestOpts{ + apply: func(_ *testing.T, test *IdentityProviderTest) { + test.ExpectedFilename = filename + }, + } +} + func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProviderTest { test := IdentityProviderTest{} TimeNow = func() time.Time { @@ -139,6 +158,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide }, }, } + test.ExpectedFilename = "TestIDPMakeResponse_response.xml" // apply the test options for _, opt := range opts { @@ -772,7 +792,7 @@ func testMakeResponse(t *testing.T, test *IdentityProviderTest) { doc.Indent(2) responseStr, err := doc.WriteToString() assert.Check(t, err) - golden.Assert(t, responseStr, "TestIDPMakeResponse_response.xml") + golden.Assert(t, responseStr, test.ExpectedFilename) } func TestIDPWriteResponse(t *testing.T) { @@ -1130,3 +1150,18 @@ func TestIDPHTTPCanHandleSSORequest(t *testing.T) { assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) } } + +func TestIdentityProviderCustomEntityID(t *testing.T) { + customEntityID := "https://idp.example.com/entity-id" + test := NewIdentityProviderTest( + t, + applyKey, + applyEntityIDConstructor(func() string { + return customEntityID + }), + applyExpectedFilename("TestIDPMakeResponse_response_with_custom_entity_id.xml"), + ) + + assert.Equal(t, customEntityID, test.IDP.Metadata().EntityID) + testMakeResponse(t, test) +} diff --git a/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml b/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml new file mode 100644 index 00000000..853ca0a1 --- /dev/null +++ b/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml @@ -0,0 +1,27 @@ + + https://idp.example.com/entity-id + + + + + + + + + + + 5bBiRThV9gjcTNlKa+y00Gnzkh8= + + + A9fzgSO00HntRcx32qCEVHoTR8YiisGk6tkeAbhRKzXoIOw3UE4nhoBIYPTYj5G+mMjnB/eEw84kuUSZ9mLV+EIAMQuR6ctJyO6xdxy65l+iC0IBSk65wqCb6C4IRB5OaxN/QC0yTJ8Ps2+s1WRJSLLcmQU6Xatpe25vzk+hQ+4= + + + MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ== + + + + + + + + From 52cb104ccfe5bb5efa6b9882d1e5644c4553e04a Mon Sep 17 00:00:00 2001 From: Sword Date: Tue, 21 Nov 2023 05:49:16 +0000 Subject: [PATCH 2/2] Add test for samlidp --- identity_provider.go | 5 +-- samlidp/samlidp.go | 26 ++++++------ samlidp/samlidp_test.go | 41 ++++++++++++++++--- ...tadata_response_with_custom_entity_id.html | 25 +++++++++++ 4 files changed, 76 insertions(+), 21 deletions(-) create mode 100644 samlidp/testdata/http_metadata_response_with_custom_entity_id.html diff --git a/identity_provider.go b/identity_provider.go index 135915e7..abf180c3 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -110,6 +110,7 @@ type IdentityProvider struct { EntityIDConstructor EntityIDConstructor } +// EntityIDConstructor is a function that returns the entityID for customization. type EntityIDConstructor func() string // Metadata returns the metadata structure for this identity provider. @@ -339,9 +340,7 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re // createDefaultEntityIDConstructor creates a function to return entityID from metadataURL. func createDefaultEntityIDConstructor(metadataURL url.URL) func() string { - return func() string { - return metadataURL.String() - } + return metadataURL.String } func (idp *IdentityProvider) getEntityID() string { diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index 13ca10b9..d6e6f8a7 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -18,12 +18,13 @@ import ( // Options represent the parameters to New() for creating a new IDP server type Options struct { - URL url.URL - Key crypto.PrivateKey - Signer crypto.Signer - Logger logger.Interface - Certificate *x509.Certificate - Store Store + URL url.URL + Key crypto.PrivateKey + Signer crypto.Signer + Logger logger.Interface + Certificate *x509.Certificate + Store Store + EntityIDConstructor saml.EntityIDConstructor } // Server represents an IDP server. The server provides the following URLs: @@ -59,12 +60,13 @@ func New(opts Options) (*Server, error) { s := &Server{ serviceProviders: map[string]*saml.EntityDescriptor{}, IDP: saml.IdentityProvider{ - Key: opts.Key, - Signer: opts.Signer, - Logger: logr, - Certificate: opts.Certificate, - MetadataURL: metadataURL, - SSOURL: ssoURL, + Key: opts.Key, + Signer: opts.Signer, + Logger: logr, + Certificate: opts.Certificate, + MetadataURL: metadataURL, + SSOURL: ssoURL, + EntityIDConstructor: opts.EntityIDConstructor, }, logger: logr, Store: opts.Store, diff --git a/samlidp/samlidp_test.go b/samlidp/samlidp_test.go index e5b2dafb..e1fd9dd9 100644 --- a/samlidp/samlidp_test.go +++ b/samlidp/samlidp_test.go @@ -66,6 +66,15 @@ func mustParseCertificate(pemStr []byte) *x509.Certificate { return cert } +func setupTestVariables() { + saml.TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") + return rv + } + jwt.TimeFunc = saml.TimeNow + saml.RandReader = &testRandomReader{} +} + type ServerTest struct { SPKey *rsa.PrivateKey SPCertificate *x509.Certificate @@ -79,12 +88,7 @@ type ServerTest struct { func NewServerTest(t *testing.T) *ServerTest { test := ServerTest{} - saml.TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") - return rv - } - jwt.TimeFunc = saml.TimeNow - saml.RandReader = &testRandomReader{} + setupTestVariables() test.SPKey = mustParsePrivateKey(golden.Get(t, "sp_key.pem")).(*rsa.PrivateKey) test.SPCertificate = mustParseCertificate(golden.Get(t, "sp_cert.pem")) @@ -143,3 +147,28 @@ func TestHTTPCanSSORequest(t *testing.T) { w.Body.String()) golden.Assert(t, w.Body.String(), "http_sso_response.html") } + +func TestHTTPMetadataResponseWithCustomEntityID(t *testing.T) { + setupTestVariables() + + server, err := New(Options{ + Certificate: mustParseCertificate(golden.Get(t, "idp_cert.pem")), + Key: mustParsePrivateKey(golden.Get(t, "idp_key.pem")).(*rsa.PrivateKey), + Logger: logger.DefaultLogger, + URL: url.URL{Scheme: "https", Host: "idp.example.com"}, + Store: &MemoryStore{}, + EntityIDConstructor: func() string { + return "https://idp.example.com/idp-id" + }, + }) + assert.Check(t, err) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "https://idp.example.com/metadata", nil) + server.ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + assert.Check(t, + strings.HasPrefix(w.Body.String(), " + + + + + MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ== + + + + + + + MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ== + + + + + + + + urn:oasis:names:tc:SAML:2.0:nameid-format:transient + + + + \ No newline at end of file