Skip to content

Commit 0d1a0e4

Browse files
committed
Device token api endpoint (#1)
* Added /device/token handler with associated business logic and storage tests. * Use crypto rand for user code Signed-off-by: justin-slowik <[email protected]>
1 parent 6d343e0 commit 0d1a0e4

File tree

10 files changed

+161
-47
lines changed

10 files changed

+161
-47
lines changed

server/handlers.go

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8-
"io"
98
"net/http"
109
"net/url"
1110
"path"
@@ -1438,9 +1437,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
14381437
case http.MethodPost:
14391438
err := r.ParseForm()
14401439
if err != nil {
1441-
message := "Could not parse Device Request body"
1442-
s.logger.Errorf("%s : %v", message, err)
1443-
respondWithError(w, message, err)
1440+
s.logger.Errorf("Could not parse Device Request body: %v", err)
1441+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
14441442
return
14451443
}
14461444

@@ -1454,7 +1452,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
14541452
deviceCode := storage.NewDeviceCode()
14551453

14561454
//make user code
1457-
userCode := storage.NewUserCode()
1455+
userCode, err := storage.NewUserCode()
1456+
if err != nil {
1457+
s.logger.Errorf("Error generating user code: %v", err)
1458+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
1459+
}
14581460

14591461
//make a pkce verification code
14601462
pkceCode := storage.NewID()
@@ -1473,24 +1475,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
14731475
}
14741476

14751477
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
1476-
message := fmt.Sprintf("Failed to store device request %v", err)
1477-
s.logger.Errorf(message)
1478-
respondWithError(w, message, err)
1478+
s.logger.Errorf("Failed to store device request; %v", err)
1479+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
14791480
return
14801481
}
14811482

14821483
//Store the device token
14831484
deviceToken := storage.DeviceToken{
14841485
DeviceCode: deviceCode,
1485-
Status: "pending",
1486-
Token: "",
1486+
Status: deviceTokenPending,
14871487
Expiry: expireTime,
14881488
}
14891489

14901490
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
1491-
message := fmt.Sprintf("Failed to store device token %v", err)
1492-
s.logger.Errorf(message)
1493-
respondWithError(w, message, err)
1491+
s.logger.Errorf("Failed to store device token %v", err)
1492+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
14941493
return
14951494
}
14961495

@@ -1507,20 +1506,53 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
15071506
enc.Encode(code)
15081507

15091508
default:
1510-
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
1509+
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
1510+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
15111511
}
15121512
}
15131513

1514-
func respondWithError(w io.Writer, errorMessage string, err error) {
1515-
resp := struct {
1516-
Error string `json:"error"`
1517-
ErrorMessage string `json:"message"`
1518-
}{
1519-
Error: err.Error(),
1520-
ErrorMessage: errorMessage,
1521-
}
1514+
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
1515+
w.Header().Set("Content-Type", "application/json")
1516+
switch r.Method {
1517+
case http.MethodPost:
1518+
err := r.ParseForm()
1519+
if err != nil {
1520+
message := "Could not parse Device Token Request body"
1521+
s.logger.Warnf("%s : %v", message, err)
1522+
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
1523+
return
1524+
}
1525+
1526+
deviceCode := r.Form.Get("device_code")
1527+
if deviceCode == "" {
1528+
message := "No device code received"
1529+
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
1530+
return
1531+
}
1532+
1533+
grantType := r.PostFormValue("grant_type")
1534+
if grantType != grantTypeDeviceCode {
1535+
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
1536+
return
1537+
}
1538+
1539+
//Grab the device token from the db
1540+
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
1541+
if err != nil || s.now().After(deviceToken.Expiry) {
1542+
if err != storage.ErrNotFound {
1543+
s.logger.Errorf("failed to get device code: %v", err)
1544+
}
1545+
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
1546+
return
1547+
}
15221548

1523-
enc := json.NewEncoder(w)
1524-
enc.SetIndent("", " ")
1525-
enc.Encode(resp)
1549+
switch deviceToken.Status {
1550+
case deviceTokenPending:
1551+
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
1552+
case deviceTokenComplete:
1553+
w.Write([]byte(deviceToken.Token))
1554+
}
1555+
default:
1556+
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
1557+
}
15261558
}

server/oauth2.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ const (
122122
grantTypeAuthorizationCode = "authorization_code"
123123
grantTypeRefreshToken = "refresh_token"
124124
grantTypePassword = "password"
125+
grantTypeDeviceCode = "device_code"
125126
)
126127

127128
const (
@@ -130,6 +131,11 @@ const (
130131
responseTypeIDToken = "id_token" // ID Token in url fragment
131132
)
132133

134+
const (
135+
deviceTokenPending = "authorization_pending"
136+
deviceTokenComplete = "complete"
137+
)
138+
133139
func parseScopes(scopes []string) connector.Scopes {
134140
var s connector.Scopes
135141
for _, scope := range scopes {

server/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
303303
handleFunc("/auth", s.handleAuthorization)
304304
handleFunc("/auth/{connector}", s.handleConnectorLogin)
305305
handleFunc("/device/code", s.handleDeviceCode)
306+
handleFunc("/device/token", s.handleDeviceToken)
306307
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
307308
// Strip the X-Remote-* headers to prevent security issues on
308309
// misconfigured authproxy connector setups.

storage/conformance/conformance.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,13 @@ func testGC(t *testing.T, s storage.Storage) {
837837
t.Errorf("expected storage.ErrNotFound, got %v", err)
838838
}
839839

840+
userCode, err := storage.NewUserCode()
841+
if err != nil {
842+
t.Errorf("Unexpected Error: %v", err)
843+
}
844+
840845
d := storage.DeviceRequest{
841-
UserCode: storage.NewUserCode(),
846+
UserCode: userCode,
842847
DeviceCode: storage.NewID(),
843848
ClientID: "client1",
844849
Scopes: []string{"openid", "email"},
@@ -896,22 +901,21 @@ func testGC(t *testing.T, s storage.Storage) {
896901
t.Errorf("expected no device token garbage collection results, got %#v", result)
897902
}
898903
}
899-
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
900-
// t.Errorf("expected to be able to get auth request after GC: %v", err)
901-
//}
904+
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
905+
t.Errorf("expected to be able to get device token after GC: %v", err)
906+
}
902907
}
903908
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
904909
t.Errorf("garbage collection failed: %v", err)
905910
} else if r.DeviceTokens != 1 {
906911
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
907912
}
908913

909-
//TODO add this code back once Getters are written for device tokens
910-
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
911-
// t.Errorf("expected device request to be GC'd")
912-
//} else if err != storage.ErrNotFound {
913-
// t.Errorf("expected storage.ErrNotFound, got %v", err)
914-
//}
914+
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
915+
t.Errorf("expected device token to be GC'd")
916+
} else if err != storage.ErrNotFound {
917+
t.Errorf("expected storage.ErrNotFound, got %v", err)
918+
}
915919
}
916920

917921
// testTimezones tests that backends either fully support timezones or
@@ -961,8 +965,12 @@ func testTimezones(t *testing.T, s storage.Storage) {
961965
}
962966

963967
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
968+
userCode, err := storage.NewUserCode()
969+
if err != nil {
970+
panic(err)
971+
}
964972
d1 := storage.DeviceRequest{
965-
UserCode: storage.NewUserCode(),
973+
UserCode: userCode,
966974
DeviceCode: storage.NewID(),
967975
ClientID: "client1",
968976
Scopes: []string{"openid", "email"},
@@ -975,7 +983,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
975983
}
976984

977985
// Attempt to create same DeviceRequest twice.
978-
err := s.CreateDeviceRequest(d1)
986+
err = s.CreateDeviceRequest(d1)
979987
mustBeErrAlreadyExists(t, "device request", err)
980988

981989
//No manual deletes for device requests, will be handled by garbage collection routines

storage/etcd/etcd.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
591591
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
592592
}
593593

594+
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
595+
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
596+
defer cancel()
597+
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
598+
return t, err
599+
}
600+
594601
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
595602
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
596603
if err != nil {

storage/kubernetes/storage.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
641641
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
642642
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
643643
}
644+
645+
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
646+
var token DeviceToken
647+
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
648+
return storage.DeviceToken{}, err
649+
}
650+
return toStorageDeviceToken(token), nil
651+
}

storage/kubernetes/types.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
739739
}
740740
return req
741741
}
742+
743+
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
744+
return storage.DeviceToken{
745+
DeviceCode: t.ObjectMeta.Name,
746+
Status: t.Status,
747+
Token: t.Token,
748+
Expiry: t.Expiry,
749+
}
750+
}

storage/memory/memory.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
503503
})
504504
return
505505
}
506+
507+
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
508+
s.tx(func() {
509+
var ok bool
510+
if t, ok = s.deviceTokens[deviceCode]; !ok {
511+
err = storage.ErrNotFound
512+
return
513+
}
514+
})
515+
return
516+
}

storage/sql/crud.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
922922
}
923923
return nil
924924
}
925+
926+
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
927+
return getDeviceToken(c, deviceCode)
928+
}
929+
930+
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
931+
err = q.QueryRow(`
932+
select
933+
status, token, expiry
934+
from device_token where device_code = $1;
935+
`, deviceCode).Scan(
936+
&a.Status, &a.Token, &a.Expiry,
937+
)
938+
if err != nil {
939+
if err == sql.ErrNoRows {
940+
return a, storage.ErrNotFound
941+
}
942+
return a, fmt.Errorf("select device token: %v", err)
943+
}
944+
a.DeviceCode = deviceCode
945+
return a, nil
946+
}

storage/storage.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"encoding/base32"
66
"errors"
77
"io"
8-
mrand "math/rand"
8+
"math/big"
99
"strings"
1010
"time"
1111

@@ -25,6 +25,9 @@ var (
2525
// TODO(ericchiang): refactor ID creation onto the storage.
2626
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
2727

28+
//Valid characters for user codes
29+
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
30+
2831
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
2932
func NewDeviceCode() string {
3033
return newSecureID(32)
@@ -79,6 +82,7 @@ type Storage interface {
7982
GetPassword(email string) (Password, error)
8083
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
8184
GetConnector(id string) (Connector, error)
85+
GetDeviceToken(deviceCode string) (DeviceToken, error)
8286

8387
ListClients() ([]Client, error)
8488
ListRefreshTokens() ([]RefreshToken, error)
@@ -357,18 +361,24 @@ type Keys struct {
357361
NextRotation time.Time
358362
}
359363

360-
func NewUserCode() string {
361-
mrand.Seed(time.Now().UnixNano())
362-
return randomString(4) + "-" + randomString(4)
364+
// NewUserCode returns a randomized 8 character user code for the device flow.
365+
// No vowels are included to prevent accidental generation of words
366+
func NewUserCode() (string, error) {
367+
code, err := randomString(8)
368+
if err != nil {
369+
return "", err
370+
}
371+
return code[:4] + "-" + code[4:], nil
363372
}
364373

365-
func randomString(n int) string {
366-
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
367-
b := make([]rune, n)
368-
for i := range b {
369-
b[i] = letter[mrand.Intn(len(letter))]
374+
func randomString(n int) (string, error) {
375+
v := big.NewInt(int64(len(validUserCharacters)))
376+
bytes := make([]byte, n)
377+
for i := 0; i < n; i++ {
378+
c, _ := rand.Int(rand.Reader, v)
379+
bytes[i] = validUserCharacters[c.Int64()]
370380
}
371-
return string(b)
381+
return string(bytes), nil
372382
}
373383

374384
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user

0 commit comments

Comments
 (0)