diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 391d7d14..d6b985d1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -8,11 +8,10 @@ on: jobs: golangci: - name: Run golangci-lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 with: version: v1.54.2 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 60f3f8f0..e6e91273 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,20 +7,13 @@ on: branches: [ 'main' ] jobs: tests: - name: Run tests runs-on: ubuntu-latest strategy: matrix: - go: [ '1.17.x', '1.18.x', '1.19.x'] + go: [ '1.19.x', '1.20.x', '1.21.x'] steps: - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - - name: Go version - run: go version - - name: Run Go tests - run: | - go test -v ./... + - run: go test -v ./... diff --git a/.golangci.yml b/.golangci.yml index f93ef23b..23f37cbf 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -8,7 +8,6 @@ linters: enable: - bodyclose # checks whether HTTP response body is closed successfully [fast: false, auto-fix: false] - - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] - errcheck # Inspects source code for security problems [fast: true, auto-fix: false] - gocritic # The most opinionated Go source code linter [fast: true, auto-fix: false] - gocyclo # Computes and checks the cyclomatic complexity of functions [fast: true, auto-fix: false] @@ -36,6 +35,7 @@ linters: - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] - lll # Reports long lines [fast: true, auto-fix: false] + - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] linters-settings: goimports: local-prefixes: github.com/crewjam/saml diff --git a/go.mod b/go.mod index 745c5c2c..493d8b9f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crewjam/saml -go 1.16 +go 1.19 require ( github.com/beevik/etree v1.1.0 @@ -10,10 +10,19 @@ require ( github.com/google/go-cmp v0.5.9 github.com/kr/pretty v0.3.1 github.com/mattermost/xml-roundtrip-validator v0.1.0 - github.com/pkg/errors v0.9.1 // indirect github.com/russellhaering/goxmldsig v1.3.0 github.com/stretchr/testify v1.8.1 github.com/zenazn/goji v1.0.1 golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed gotest.tools v2.2.0+incompatible ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 7ab71ea2..3f89dafa 100644 --- a/go.sum +++ b/go.sum @@ -50,13 +50,6 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed h1:YoWVYYAfvQ4ddHv3OKmIvX7NCAhFGTj62VP2l2kfBbA= golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/identity_provider.go b/identity_provider.go index 49109196..1c56e65c 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -9,7 +9,6 @@ import ( "encoding/xml" "fmt" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -368,7 +367,7 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques if err != nil { return nil, fmt.Errorf("cannot decode request: %s", err) } - req.RequestBuffer, err = ioutil.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest))) + req.RequestBuffer, err = io.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest))) if err != nil { return nil, fmt.Errorf("cannot decompress request: %s", err) } diff --git a/identity_provider_go116_test.go b/identity_provider_go116_test.go deleted file mode 100644 index 6d4a0a53..00000000 --- a/identity_provider_go116_test.go +++ /dev/null @@ -1,57 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package saml - -import ( - "bytes" - "compress/flate" - "encoding/base64" - "io/ioutil" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "gotest.tools/assert" - is "gotest.tools/assert/cmp" -) - -func TestIDPHTTPCanHandleSSORequest(t *testing.T) { - test := NewIdentityProviderTest(t, applyKey) - w := httptest.NewRecorder() - - const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` - - r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ - "SAMLRequest="+validRequest, nil) - test.IDP.Handler().ServeHTTP(w, r) - assert.Check(t, is.Equal(http.StatusOK, w.Code)) - - // rejects requests that are invalid - w = httptest.NewRecorder() - r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ - "SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil) - test.IDP.Handler().ServeHTTP(w, r) - assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) - - // rejects requests that contain malformed XML - { - a, _ := url.QueryUnescape(validRequest) - b, _ := base64.StdEncoding.DecodeString(a) - c, _ := ioutil.ReadAll(flate.NewReader(bytes.NewReader(b))) - d := bytes.Replace(c, []byte("]]"), 1) - f := bytes.Buffer{} - e, _ := flate.NewWriter(&f, flate.DefaultCompression) - _, err := e.Write(d) - assert.Check(t, err) - err = e.Close() - assert.Check(t, err) - g := base64.StdEncoding.EncodeToString(f.Bytes()) - invalidRequest := url.QueryEscape(g) - - w = httptest.NewRecorder() - r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ - "SAMLRequest="+invalidRequest, nil) - test.IDP.Handler().ServeHTTP(w, r) - assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) - } -} diff --git a/identity_provider_test.go b/identity_provider_test.go index e4a0fd59..1d068268 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "encoding/xml" "fmt" + "io" "math/rand" "net/http" "net/http/httptest" @@ -1091,3 +1092,44 @@ func TestIDPRejectDecompressionBomb(t *testing.T) { _, err = NewIdpAuthnRequest(&test.IDP, r) assert.Error(t, err, "cannot decompress request: flate: uncompress limit exceeded (10485760 bytes)") } + +func TestIDPHTTPCanHandleSSORequest(t *testing.T) { + test := NewIdentityProviderTest(t, applyKey) + w := httptest.NewRecorder() + + const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` + + r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ + "SAMLRequest="+validRequest, nil) + test.IDP.Handler().ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + + // rejects requests that are invalid + w = httptest.NewRecorder() + r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ + "SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil) + test.IDP.Handler().ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) + + // rejects requests that contain malformed XML + { + a, _ := url.QueryUnescape(validRequest) + b, _ := base64.StdEncoding.DecodeString(a) + c, _ := io.ReadAll(flate.NewReader(bytes.NewReader(b))) + d := bytes.Replace(c, []byte("]]"), 1) + f := bytes.Buffer{} + e, _ := flate.NewWriter(&f, flate.DefaultCompression) + _, err := e.Write(d) + assert.Check(t, err) + err = e.Close() + assert.Check(t, err) + g := base64.StdEncoding.EncodeToString(f.Bytes()) + invalidRequest := url.QueryEscape(g) + + w = httptest.NewRecorder() + r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ + "SAMLRequest="+invalidRequest, nil) + test.IDP.Handler().ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) + } +} diff --git a/samlidp/util.go b/samlidp/util.go index d624be4f..2cb3c162 100644 --- a/samlidp/util.go +++ b/samlidp/util.go @@ -5,7 +5,6 @@ import ( "encoding/xml" "errors" "io" - "io/ioutil" xrv "github.com/mattermost/xml-roundtrip-validator" @@ -22,7 +21,7 @@ func randomBytes(n int) []byte { func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) { var data []byte - if data, err = ioutil.ReadAll(r); err != nil { + if data, err = io.ReadAll(r); err != nil { return nil, err } diff --git a/samlidp/util_go116_test.go b/samlidp/util_go116_test.go deleted file mode 100644 index 16f54740..00000000 --- a/samlidp/util_go116_test.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package samlidp - -import ( - "strings" - "testing" - - "gotest.tools/assert" - is "gotest.tools/assert/cmp" -) - -func TestGetSPMetadata(t *testing.T) { - good := "" + - "\n" + - "" - _, err := getSPMetadata(strings.NewReader(good)) - assert.Check(t, err) - - bad := "" + - "\n" + - "" - _, err = getSPMetadata(strings.NewReader(bad)) - assert.Check(t, is.Error(err, "validator: in token starting at 1:1: roundtrip error: expected {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ :attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}, observed {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}")) -} diff --git a/samlidp/util_go117_test.go b/samlidp/util_test.go similarity index 68% rename from samlidp/util_go117_test.go rename to samlidp/util_test.go index 2c273bc1..23eccbf8 100644 --- a/samlidp/util_go117_test.go +++ b/samlidp/util_test.go @@ -8,12 +8,11 @@ import ( "testing" "gotest.tools/assert" - is "gotest.tools/assert/cmp" ) func TestGetSPMetadata(t *testing.T) { good := "" + - "\n" + + "\n" + "" _, err := getSPMetadata(strings.NewReader(good)) assert.Check(t, err) @@ -22,5 +21,5 @@ func TestGetSPMetadata(t *testing.T) { "]]>\n" + "" _, err = getSPMetadata(strings.NewReader(bad)) - assert.Check(t, is.Error(err, "XML syntax error on line 1: unescaped ]]> not in CDATA section")) + assert.Check(t, err != nil) } diff --git a/samlsp/fetch_metadata.go b/samlsp/fetch_metadata.go index 4d92503e..ede3c6b3 100644 --- a/samlsp/fetch_metadata.go +++ b/samlsp/fetch_metadata.go @@ -5,7 +5,7 @@ import ( "context" "encoding/xml" "errors" - "io/ioutil" + "io" "net/http" "net/url" @@ -72,7 +72,7 @@ func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url return nil, httperr.Response(*resp) } - data, err := ioutil.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { return nil, err } diff --git a/samlsp/fetch_metadata_go116_test.go b/samlsp/fetch_metadata_go116_test.go deleted file mode 100644 index 91c3aa69..00000000 --- a/samlsp/fetch_metadata_go116_test.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package samlsp - -import ( - "bytes" - "context" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "gotest.tools/assert" - is "gotest.tools/assert/cmp" -) - -func TestFetchMetadataRejectsInvalid(t *testing.T) { - test := NewMiddlewareTest(t) - test.IDPMetadata = bytes.Replace(test.IDPMetadata, - []byte("]]")) - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Check(t, is.Equal("/metadata", r.URL.String())) - _, err := w.Write(test.IDPMetadata) - assert.Check(t, err) - })) - - fmt.Println(testServer.URL + "/metadata") - u, _ := url.Parse(testServer.URL + "/metadata") - md, err := FetchMetadata(context.Background(), testServer.Client(), *u) - assert.Check(t, is.Error(err, "expected element in name space urn:oasis:names:tc:SAML:2.0:metadata but have no name space")) - assert.Check(t, is.Nil(md)) -} diff --git a/samlsp/fetch_metadata_test.go b/samlsp/fetch_metadata_test.go index bd90dd8a..c0295e13 100644 --- a/samlsp/fetch_metadata_test.go +++ b/samlsp/fetch_metadata_test.go @@ -1,6 +1,7 @@ package samlsp import ( + "bytes" "context" "fmt" "net/http" @@ -27,3 +28,21 @@ func TestFetchMetadata(t *testing.T) { assert.Check(t, err) assert.Check(t, is.Equal("https://idp.testshib.org/idp/shibboleth", md.EntityID)) } + +func TestFetchMetadataRejectsInvalid(t *testing.T) { + test := NewMiddlewareTest(t) + test.IDPMetadata = bytes.ReplaceAll(test.IDPMetadata, + []byte("]]")) + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Check(t, is.Equal("/metadata", r.URL.String())) + _, err := w.Write(test.IDPMetadata) + assert.Check(t, err) + })) + + fmt.Println(testServer.URL + "/metadata") + u, _ := url.Parse(testServer.URL + "/metadata") + md, err := FetchMetadata(context.Background(), testServer.Client(), *u) + assert.Check(t, err != nil) + assert.Check(t, is.Nil(md)) +} diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 801aad08..fdb05b20 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -8,7 +8,7 @@ import ( "encoding/base64" "encoding/json" "encoding/xml" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -141,7 +141,7 @@ func TestMiddlewareFourOhFour(t *testing.T) { resp := httptest.NewRecorder() test.Middleware.ServeHTTP(resp, req) assert.Check(t, is.Equal(http.StatusNotFound, resp.Code)) - respBuf, _ := ioutil.ReadAll(resp.Body) + respBuf, _ := io.ReadAll(resp.Body) assert.Check(t, is.Equal("404 page not found\n", string(respBuf))) } @@ -516,7 +516,7 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) { // the ACS handles DOES NOT reveal detailed error information in the // HTTP response. assert.Check(t, is.Equal(http.StatusForbidden, resp.Code)) - respBody, _ := ioutil.ReadAll(resp.Body) + respBody, _ := io.ReadAll(resp.Body) assert.Check(t, is.Equal("Forbidden\n", string(respBody))) assert.Check(t, is.Equal("", resp.Header().Get("Location"))) assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie"))) diff --git a/samlsp/samlsp_test.go b/samlsp/samlsp_test.go index 886098bb..75fa6894 100644 --- a/samlsp/samlsp_test.go +++ b/samlsp/samlsp_test.go @@ -6,7 +6,7 @@ import ( "crypto" "crypto/x509" "encoding/pem" - "io/ioutil" + "io" "net/http" "net/url" "testing" @@ -61,7 +61,7 @@ func TestCanParseTestshibMetadata(t *testing.T) { Header: http.Header{}, Request: req, StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader(responseBody)), + Body: io.NopCloser(bytes.NewReader(responseBody)), }, nil }), } diff --git a/service_provider.go b/service_provider.go index ad21321e..9dc88ca1 100644 --- a/service_provider.go +++ b/service_provider.go @@ -12,7 +12,7 @@ import ( "errors" "fmt" "html/template" - "io/ioutil" + "io" "net/http" "net/url" "regexp" @@ -642,7 +642,7 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: HTTP status %d (%s)", response.StatusCode, response.Status) return nil, retErr } - responseBody, err := ioutil.ReadAll(response.Body) + responseBody, err := io.ReadAll(response.Body) if err != nil { retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err) return nil, retErr @@ -1537,7 +1537,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(query url.Values) erro } retErr.Response = string(rawResponseBuf) - gr, err := ioutil.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawResponseBuf))) + gr, err := io.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawResponseBuf))) if err != nil { retErr.PrivateErr = err return retErr diff --git a/service_provider_go116_test.go b/service_provider_go116_test.go deleted file mode 100644 index 77395e01..00000000 --- a/service_provider_go116_test.go +++ /dev/null @@ -1,136 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package saml - -import ( - "encoding/base64" - "encoding/xml" - "net/http" - "net/url" - "strings" - "testing" - "time" - - dsig "github.com/russellhaering/goxmldsig" - "gotest.tools/assert" - is "gotest.tools/assert/cmp" - "gotest.tools/golden" -) - -func TestSPRejectsMalformedResponse(t *testing.T) { - test := NewServiceProviderTest(t) - // An actual response from google - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Tue Jan 5 16:55:39 UTC 2016") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - SamlResponse := golden.Get(t, "TestSPRejectsMalformedResponse_response") - test.IDPMetadata = golden.Get(t, "TestSPRejectsMalformedResponse_IDPMetadata") - - s := ServiceProvider{ - Key: test.Key, - Certificate: test.Certificate, - MetadataURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/metadata"), - AcsURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/acs"), - IDPMetadata: &EntityDescriptor{}, - } - err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) - assert.Check(t, err) - - // this is a valid response - { - req := http.Request{PostForm: url.Values{}} - req.PostForm.Set("SAMLResponse", string(SamlResponse)) - assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) - assert.Check(t, err) - assert.Check(t, is.Equal("ross@octolabs.io", assertion.Subject.NameID.Value)) - } - - // this is a valid response but with a comment injected - { - x, _ := base64.StdEncoding.DecodeString(string(SamlResponse)) - y := strings.Replace(string(x), "World!"))) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot unmarshal response: expected element type but have ")) - - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"wrongRequestID"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`InResponseTo` does not match any of the possible request IDs (expected [wrongRequestID])")) - - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Nov 30 20:57:09 UTC 2016") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "response IssueInstant expired at 2015-12-01 01:57:51.375 +0000 UTC")) - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - - s.IDPMetadata.EntityID = "http://snakeoil.com" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "response Issuer does not match the IDP metadata (expected \"http://snakeoil.com\")")) - s.IDPMetadata.EntityID = "https://idp.testshib.org/idp/shibboleth" - - oldSpStatusSuccess := StatusSuccess - StatusSuccess = "not:the:success:value" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "urn:oasis:names:tc:SAML:2.0:status:Success")) - StatusSuccess = oldSpStatusSuccess - - s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "invalid" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot validate signature on Response: cannot parse certificate: illegal base64 data at input byte 4")) - - s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "aW52YWxpZA==" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot validate signature on Response: asn1: structure error: tags don't match (16 vs {class:1 tag:9 length:110 isCompound:true}) {optional:false explicit:false application:false private:false defaultValue: tag: stringType:0 timeType:0 set:false omitEmpty:false} certificate @2")) -} diff --git a/service_provider_go117_test.go b/service_provider_go117_test.go deleted file mode 100644 index 3d4a5835..00000000 --- a/service_provider_go117_test.go +++ /dev/null @@ -1,136 +0,0 @@ -//go:build go1.17 -// +build go1.17 - -package saml - -import ( - "encoding/base64" - "encoding/xml" - "net/http" - "net/url" - "strings" - "testing" - "time" - - dsig "github.com/russellhaering/goxmldsig" - "gotest.tools/assert" - is "gotest.tools/assert/cmp" - "gotest.tools/golden" -) - -func TestSPRejectsMalformedResponse(t *testing.T) { - test := NewServiceProviderTest(t) - // An actual response from google - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Tue Jan 5 16:55:39 UTC 2016") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - SamlResponse := golden.Get(t, "TestSPRejectsMalformedResponse_response") - test.IDPMetadata = golden.Get(t, "TestSPRejectsMalformedResponse_IDPMetadata") - - s := ServiceProvider{ - Key: test.Key, - Certificate: test.Certificate, - MetadataURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/metadata"), - AcsURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/acs"), - IDPMetadata: &EntityDescriptor{}, - } - err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) - assert.Check(t, err) - - // this is a valid response - { - req := http.Request{PostForm: url.Values{}} - req.PostForm.Set("SAMLResponse", string(SamlResponse)) - assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) - assert.Check(t, err) - assert.Check(t, is.Equal("ross@octolabs.io", assertion.Subject.NameID.Value)) - } - - // this is a valid response but with a comment injected - { - x, _ := base64.StdEncoding.DecodeString(string(SamlResponse)) - y := strings.Replace(string(x), "World!"))) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot unmarshal response: expected element type but have ")) - - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"wrongRequestID"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`InResponseTo` does not match any of the possible request IDs (expected [wrongRequestID])")) - - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Nov 30 20:57:09 UTC 2016") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "response IssueInstant expired at 2015-12-01 01:57:51.375 +0000 UTC")) - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - - s.IDPMetadata.EntityID = "http://snakeoil.com" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "response Issuer does not match the IDP metadata (expected \"http://snakeoil.com\")")) - s.IDPMetadata.EntityID = "https://idp.testshib.org/idp/shibboleth" - - oldSpStatusSuccess := StatusSuccess - StatusSuccess = "not:the:success:value" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "urn:oasis:names:tc:SAML:2.0:status:Success")) - StatusSuccess = oldSpStatusSuccess - - s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "invalid" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot validate signature on Assertion: cannot parse certificate: illegal base64 data at input byte 4")) - - s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "aW52YWxpZA==" - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "cannot validate signature on Assertion: x509: malformed certificate")) -} diff --git a/service_provider_test.go b/service_provider_test.go index c1f2d80b..1d145871 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -37,7 +37,7 @@ type ServiceProviderTest struct { // x1 := "lJJBj9MwEIX%2FSuR7Y4%2FJRisriVS2Qqq0QNUAB27GmbYWiV08E6D%2FHqeA6AnKdfz85nvPbtYzn8Iev8xIXHyfxkCtmFMw0ZInE%2ByEZNiZfv362ehSmXOKHF0cRbEmwsQ%2BhqcYaJ4w9Zi%2Beofv98%2BtODGfyUgJD3UNVVWV4Zji59JHSXYatbSORLHJO32wi8efG344l5wP6OQ%2FlTEdl4HMWw9%2BRLlgaLnHwSd0LPv%2BrSi2m1b4YaWU0qpStXpUVjmFoEBDBTU8ggUHmIVEM24DsQ3cCq3gYQV6peCdAvMCjIaPotj9ivfSh8GHYytE8QETXQlzfNE1V5d0T1X2d0GieBXTZPnv8mWScxyuUoOBPV9E968iJ2Q7WLaN%2FAnWNW%2Byz3azi6N3l%2F980XGM354SWsZWcJpRdPcDc7KBfMZu5C1B18jbL9b9CAAA%2F%2F8%3D" // x2, _ := url.QueryUnescape(x1) // x3, _ := base64.StdEncoding.DecodeString(x2) -// x4, _ := ioutil.ReadAll(flate.NewReader(bytes.NewReader(x3))) +// x4, _ := io.ReadAll(flate.NewReader(bytes.NewReader(x3))) // fmt.Printf("%s\n", x4) type testRandomReader struct { @@ -1877,3 +1877,119 @@ func TestResponseWithDefaultNamespace(t *testing.T) { assert.NilError(t, err) } + +func TestSPRejectsMalformedResponse(t *testing.T) { + test := NewServiceProviderTest(t) + // An actual response from google + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Tue Jan 5 16:55:39 UTC 2016") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + SamlResponse := golden.Get(t, "TestSPRejectsMalformedResponse_response") + test.IDPMetadata = golden.Get(t, "TestSPRejectsMalformedResponse_IDPMetadata") + + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/metadata"), + AcsURL: mustParseURL("https://29ee6d2e.ngrok.io/saml/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + // this is a valid response + { + req := http.Request{PostForm: url.Values{}} + req.PostForm.Set("SAMLResponse", string(SamlResponse)) + assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) + assert.Check(t, err) + assert.Check(t, is.Equal("ross@octolabs.io", assertion.Subject.NameID.Value)) + } + + // this is a valid response but with a comment injected + { + x, _ := base64.StdEncoding.DecodeString(string(SamlResponse)) + y := strings.Replace(string(x), "World!"))) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "cannot unmarshal response: expected element type but have ")) + + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"wrongRequestID"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "`InResponseTo` does not match any of the possible request IDs (expected [wrongRequestID])")) + + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Nov 30 20:57:09 UTC 2016") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "response IssueInstant expired at 2015-12-01 01:57:51.375 +0000 UTC")) + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + + s.IDPMetadata.EntityID = "http://snakeoil.com" + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "response Issuer does not match the IDP metadata (expected \"http://snakeoil.com\")")) + s.IDPMetadata.EntityID = "https://idp.testshib.org/idp/shibboleth" + + oldSpStatusSuccess := StatusSuccess + StatusSuccess = "not:the:success:value" + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "urn:oasis:names:tc:SAML:2.0:status:Success")) + StatusSuccess = oldSpStatusSuccess + + s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "invalid" + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "cannot validate signature on Assertion: cannot parse certificate: illegal base64 data at input byte 4")) + + s.IDPMetadata.IDPSSODescriptors[0].KeyDescriptors[0].KeyInfo.X509Data.X509Certificates[0].Data = "aW52YWxpZA==" + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, + "cannot validate signature on Assertion: x509: malformed certificate")) +} diff --git a/testsaml/parse.go b/testsaml/parse.go index 63e3dbc5..aa585878 100644 --- a/testsaml/parse.go +++ b/testsaml/parse.go @@ -6,7 +6,7 @@ import ( "compress/flate" "encoding/base64" "fmt" - "io/ioutil" + "io" "net/url" ) @@ -16,7 +16,7 @@ func ParseRedirectRequest(u *url.URL) ([]byte, error) { if err != nil { return nil, fmt.Errorf("cannot decode request: %s", err) } - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest))) + buf, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest))) if err != nil { return nil, fmt.Errorf("cannot decompress request: %s", err) } @@ -29,7 +29,7 @@ func ParseRedirectResponse(u *url.URL) ([]byte, error) { if err != nil { return nil, fmt.Errorf("cannot decode response: %s", err) } - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedResponse))) + buf, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressedResponse))) if err != nil { return nil, fmt.Errorf("cannot decompress response: %s", err) } diff --git a/xmlenc/fuzz_test.go b/xmlenc/fuzz_test.go index 2d83e49c..da3af637 100644 --- a/xmlenc/fuzz_test.go +++ b/xmlenc/fuzz_test.go @@ -4,14 +4,13 @@ package xmlenc import ( - "io/ioutil" "testing" "strings" ) func TestPastFuzzingFailures(t *testing.T) { - entries, err := ioutil.ReadDir("crashers") + entries, err := io.ReadDir("crashers") if err != nil { t.Errorf("%s", err) return @@ -24,7 +23,7 @@ func TestPastFuzzingFailures(t *testing.T) { continue } t.Logf("%s", entry.Name()) - data, err := ioutil.ReadFile("crashers/" + entry.Name()) + data, err := io.ReadFile("crashers/" + entry.Name()) if err != nil { t.Errorf("%s: %s", entry.Name(), err) return diff --git a/xmlenc/xmlenc_test.go b/xmlenc/xmlenc_test.go index cad9b90f..560e515c 100644 --- a/xmlenc/xmlenc_test.go +++ b/xmlenc/xmlenc_test.go @@ -1,8 +1,8 @@ package xmlenc import ( - "io/ioutil" "math/rand" + "os" "testing" "github.com/beevik/etree" @@ -13,7 +13,7 @@ import ( func TestDataAES128(t *testing.T) { t.Run("CBC", func(t *testing.T) { RandReader = rand.New(rand.NewSource(0)) //nolint:gosec // deterministic random numbers for tests - plaintext, err := ioutil.ReadFile("testdata/encrypt-data-aes128-cbc.data") + plaintext, err := os.ReadFile("testdata/encrypt-data-aes128-cbc.data") assert.Check(t, err) var ciphertext string