Skip to content

Commit c514b17

Browse files
committed
Add custom requeste tracker
1 parent ac4632d commit c514b17

File tree

2 files changed

+141
-6
lines changed

2 files changed

+141
-6
lines changed

pkg/sp/sp.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ func NewServiceProvider(cert, key string, metadata interface{}, root *url.URL, m
9494
// set SHA256 as the signature method
9595
mw.ServiceProvider.SignatureMethod = dsig.RSASHA256SignatureMethod
9696

97+
// use custom request tracker
98+
tracker := DefaultRequestTracker(opts, &mw.ServiceProvider)
99+
mw.RequestTracker = tracker
100+
97101
// set up custom session provider
98102
if err := setSessionProvider(root, mw); err != nil {
99103
return nil, fmt.Errorf("session provider error: %w", err)
@@ -325,12 +329,6 @@ func (s *ServiceProvider) doAuthFlow(w http.ResponseWriter, r *http.Request) {
325329
// transfer headers to response
326330
for header, v := range rr.Result().Header {
327331
for _, item := range v {
328-
if header == "Set-Cookie" {
329-
// add Domain to cookie if not set
330-
if !strings.Contains(item, "Domain=") {
331-
item = item + "; Domain=" + s.mw.Session.(samlsp.CookieSessionProvider).Domain
332-
}
333-
}
334332
w.Header().Add(header, item)
335333
}
336334
}

pkg/sp/tracker.go

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package sp
2+
3+
import (
4+
"encoding/base64"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"strings"
9+
"time"
10+
11+
"github.com/crewjam/saml"
12+
"github.com/crewjam/saml/samlsp"
13+
)
14+
15+
func DefaultRequestTracker(opts samlsp.Options, serviceProvider *saml.ServiceProvider) CookieRequestTracker {
16+
return CookieRequestTracker{
17+
ServiceProvider: serviceProvider,
18+
NamePrefix: "saml_",
19+
Codec: samlsp.DefaultTrackedRequestCodec(opts),
20+
MaxAge: saml.MaxIssueDelay,
21+
RelayStateFunc: opts.RelayStateFunc,
22+
SameSite: opts.CookieSameSite,
23+
CookieDomain: getDomain(&serviceProvider.AcsURL),
24+
}
25+
}
26+
27+
// CookieRequestTracker tracks requests by setting a uniquely named
28+
// cookie for each request.
29+
//
30+
// This implementation is idenitical to samlsp.CookieRequestTracker apart
31+
// from the addition of setting the CookieDomain for the tracker cookie.
32+
type CookieRequestTracker struct {
33+
ServiceProvider *saml.ServiceProvider
34+
NamePrefix string
35+
Codec samlsp.TrackedRequestCodec
36+
MaxAge time.Duration
37+
RelayStateFunc func(w http.ResponseWriter, r *http.Request) string
38+
SameSite http.SameSite
39+
CookieDomain string
40+
}
41+
42+
// TrackRequest starts tracking the SAML request with the given ID. It returns an
43+
// `index` that should be used as the RelayState in the SAMl request flow.
44+
func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) {
45+
trackedRequest := samlsp.TrackedRequest{
46+
Index: base64.RawURLEncoding.EncodeToString(randomBytes(42)),
47+
SAMLRequestID: samlRequestID,
48+
URI: r.URL.String(),
49+
}
50+
51+
if t.RelayStateFunc != nil {
52+
relayState := t.RelayStateFunc(w, r)
53+
if relayState != "" {
54+
trackedRequest.Index = relayState
55+
}
56+
}
57+
58+
signedTrackedRequest, err := t.Codec.Encode(trackedRequest)
59+
if err != nil {
60+
return "", err
61+
}
62+
63+
http.SetCookie(w, &http.Cookie{
64+
Name: t.NamePrefix + trackedRequest.Index,
65+
Value: signedTrackedRequest,
66+
MaxAge: int(t.MaxAge.Seconds()),
67+
HttpOnly: true,
68+
SameSite: t.SameSite,
69+
Secure: t.ServiceProvider.AcsURL.Scheme == "https",
70+
Path: t.ServiceProvider.AcsURL.Path,
71+
Domain: t.CookieDomain,
72+
})
73+
74+
return trackedRequest.Index, nil
75+
}
76+
77+
// StopTrackingRequest stops tracking the SAML request given by index, which is a string
78+
// previously returned from TrackRequest
79+
func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
80+
cookie, err := r.Cookie(t.NamePrefix + index)
81+
if err != nil {
82+
return err
83+
}
84+
cookie.Value = ""
85+
cookie.Domain = t.CookieDomain
86+
cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
87+
http.SetCookie(w, cookie)
88+
return nil
89+
}
90+
91+
// GetTrackedRequests returns all the pending tracked requests
92+
func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
93+
rv := []samlsp.TrackedRequest{}
94+
for _, cookie := range r.Cookies() {
95+
if !strings.HasPrefix(cookie.Name, t.NamePrefix) {
96+
continue
97+
}
98+
99+
trackedRequest, err := t.Codec.Decode(cookie.Value)
100+
if err != nil {
101+
continue
102+
}
103+
index := strings.TrimPrefix(cookie.Name, t.NamePrefix)
104+
if index != trackedRequest.Index {
105+
continue
106+
}
107+
108+
rv = append(rv, *trackedRequest)
109+
}
110+
return rv
111+
}
112+
113+
// GetTrackedRequest returns a pending tracked request.
114+
func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
115+
cookie, err := r.Cookie(t.NamePrefix + index)
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
trackedRequest, err := t.Codec.Decode(cookie.Value)
121+
if err != nil {
122+
return nil, err
123+
}
124+
if trackedRequest.Index != index {
125+
return nil, fmt.Errorf("expected index %q, got %q", index, trackedRequest.Index)
126+
}
127+
return trackedRequest, nil
128+
}
129+
130+
func randomBytes(n int) []byte {
131+
rv := make([]byte, n)
132+
133+
if _, err := io.ReadFull(saml.RandReader, rv); err != nil {
134+
panic(err)
135+
}
136+
return rv
137+
}

0 commit comments

Comments
 (0)