|
| 1 | +package sp |
| 2 | + |
| 3 | +import ( |
| 4 | + "encoding/base64" |
| 5 | + "fmt" |
| 6 | + "io" |
| 7 | + "net/http" |
| 8 | + "strings" |
| 9 | + "time" |
| 10 | + |
| 11 | + "github.com/crewjam/saml" |
| 12 | + "github.com/crewjam/saml/samlsp" |
| 13 | +) |
| 14 | + |
| 15 | +func DefaultRequestTracker(opts samlsp.Options, serviceProvider *saml.ServiceProvider) CookieRequestTracker { |
| 16 | + return CookieRequestTracker{ |
| 17 | + ServiceProvider: serviceProvider, |
| 18 | + NamePrefix: "saml_", |
| 19 | + Codec: samlsp.DefaultTrackedRequestCodec(opts), |
| 20 | + MaxAge: saml.MaxIssueDelay, |
| 21 | + RelayStateFunc: opts.RelayStateFunc, |
| 22 | + SameSite: opts.CookieSameSite, |
| 23 | + CookieDomain: getDomain(&serviceProvider.AcsURL), |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +// CookieRequestTracker tracks requests by setting a uniquely named |
| 28 | +// cookie for each request. |
| 29 | +// |
| 30 | +// This implementation is idenitical to samlsp.CookieRequestTracker apart |
| 31 | +// from the addition of setting the CookieDomain for the tracker cookie. |
| 32 | +type CookieRequestTracker struct { |
| 33 | + ServiceProvider *saml.ServiceProvider |
| 34 | + NamePrefix string |
| 35 | + Codec samlsp.TrackedRequestCodec |
| 36 | + MaxAge time.Duration |
| 37 | + RelayStateFunc func(w http.ResponseWriter, r *http.Request) string |
| 38 | + SameSite http.SameSite |
| 39 | + CookieDomain string |
| 40 | +} |
| 41 | + |
| 42 | +// TrackRequest starts tracking the SAML request with the given ID. It returns an |
| 43 | +// `index` that should be used as the RelayState in the SAMl request flow. |
| 44 | +func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) { |
| 45 | + trackedRequest := samlsp.TrackedRequest{ |
| 46 | + Index: base64.RawURLEncoding.EncodeToString(randomBytes(42)), |
| 47 | + SAMLRequestID: samlRequestID, |
| 48 | + URI: r.URL.String(), |
| 49 | + } |
| 50 | + |
| 51 | + if t.RelayStateFunc != nil { |
| 52 | + relayState := t.RelayStateFunc(w, r) |
| 53 | + if relayState != "" { |
| 54 | + trackedRequest.Index = relayState |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + signedTrackedRequest, err := t.Codec.Encode(trackedRequest) |
| 59 | + if err != nil { |
| 60 | + return "", err |
| 61 | + } |
| 62 | + |
| 63 | + http.SetCookie(w, &http.Cookie{ |
| 64 | + Name: t.NamePrefix + trackedRequest.Index, |
| 65 | + Value: signedTrackedRequest, |
| 66 | + MaxAge: int(t.MaxAge.Seconds()), |
| 67 | + HttpOnly: true, |
| 68 | + SameSite: t.SameSite, |
| 69 | + Secure: t.ServiceProvider.AcsURL.Scheme == "https", |
| 70 | + Path: t.ServiceProvider.AcsURL.Path, |
| 71 | + Domain: t.CookieDomain, |
| 72 | + }) |
| 73 | + |
| 74 | + return trackedRequest.Index, nil |
| 75 | +} |
| 76 | + |
| 77 | +// StopTrackingRequest stops tracking the SAML request given by index, which is a string |
| 78 | +// previously returned from TrackRequest |
| 79 | +func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error { |
| 80 | + cookie, err := r.Cookie(t.NamePrefix + index) |
| 81 | + if err != nil { |
| 82 | + return err |
| 83 | + } |
| 84 | + cookie.Value = "" |
| 85 | + cookie.Domain = t.CookieDomain |
| 86 | + cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} |
| 87 | + http.SetCookie(w, cookie) |
| 88 | + return nil |
| 89 | +} |
| 90 | + |
| 91 | +// GetTrackedRequests returns all the pending tracked requests |
| 92 | +func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest { |
| 93 | + rv := []samlsp.TrackedRequest{} |
| 94 | + for _, cookie := range r.Cookies() { |
| 95 | + if !strings.HasPrefix(cookie.Name, t.NamePrefix) { |
| 96 | + continue |
| 97 | + } |
| 98 | + |
| 99 | + trackedRequest, err := t.Codec.Decode(cookie.Value) |
| 100 | + if err != nil { |
| 101 | + continue |
| 102 | + } |
| 103 | + index := strings.TrimPrefix(cookie.Name, t.NamePrefix) |
| 104 | + if index != trackedRequest.Index { |
| 105 | + continue |
| 106 | + } |
| 107 | + |
| 108 | + rv = append(rv, *trackedRequest) |
| 109 | + } |
| 110 | + return rv |
| 111 | +} |
| 112 | + |
| 113 | +// GetTrackedRequest returns a pending tracked request. |
| 114 | +func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) { |
| 115 | + cookie, err := r.Cookie(t.NamePrefix + index) |
| 116 | + if err != nil { |
| 117 | + return nil, err |
| 118 | + } |
| 119 | + |
| 120 | + trackedRequest, err := t.Codec.Decode(cookie.Value) |
| 121 | + if err != nil { |
| 122 | + return nil, err |
| 123 | + } |
| 124 | + if trackedRequest.Index != index { |
| 125 | + return nil, fmt.Errorf("expected index %q, got %q", index, trackedRequest.Index) |
| 126 | + } |
| 127 | + return trackedRequest, nil |
| 128 | +} |
| 129 | + |
| 130 | +func randomBytes(n int) []byte { |
| 131 | + rv := make([]byte, n) |
| 132 | + |
| 133 | + if _, err := io.ReadFull(saml.RandReader, rv); err != nil { |
| 134 | + panic(err) |
| 135 | + } |
| 136 | + return rv |
| 137 | +} |
0 commit comments