From 102a7824fdd8b70b9bdc501740e9bff86d71cf32 Mon Sep 17 00:00:00 2001 From: Lance Ivy <lance@cainlevy.net> Date: Tue, 23 Mar 2021 12:05:39 -0700 Subject: [PATCH] migrate to fix case sensitive usernames --- app/data/mock/account_store.go | 14 ++++++---- app/data/postgres/migrations.go | 9 ++++++ app/data/sqlite3/migrations.go | 34 +++++++++++++++++++++++ app/data/testers/account_store_testers.go | 12 ++++++++ 4 files changed, 63 insertions(+), 6 deletions(-) diff --git a/app/data/mock/account_store.go b/app/data/mock/account_store.go index ce489fe2c0..9785d9907a 100644 --- a/app/data/mock/account_store.go +++ b/app/data/mock/account_store.go @@ -2,6 +2,7 @@ package mock import ( "fmt" + "strings" "time" "github.com/keratin/authn-server/app/models" @@ -42,7 +43,7 @@ func (s *accountStore) Find(id int) (*models.Account, error) { } func (s *accountStore) FindByUsername(u string) (*models.Account, error) { - id := s.idByUsername[u] + id := s.idByUsername[strings.ToLower(u)] if id == 0 { return nil, nil } @@ -60,7 +61,7 @@ func (s *accountStore) FindByOauthAccount(provider string, providerID string) (* } func (s *accountStore) Create(u string, p []byte) (*models.Account, error) { - if s.idByUsername[u] != 0 { + if s.idByUsername[strings.ToLower(u)] != 0 { return nil, Error{ErrNotUnique} } @@ -74,7 +75,7 @@ func (s *accountStore) Create(u string, p []byte) (*models.Account, error) { UpdatedAt: now, } s.accountsByID[acc.ID] = &acc - s.idByUsername[acc.Username] = acc.ID + s.idByUsername[strings.ToLower(acc.Username)] = acc.ID return dupAccount(acc), nil } @@ -114,7 +115,7 @@ func (s *accountStore) Archive(id int) (bool, error) { return false, nil } - delete(s.idByUsername, account.Username) + delete(s.idByUsername, strings.ToLower(account.Username)) now := time.Now() account.Username = "" account.Password = []byte("") @@ -176,18 +177,19 @@ func (s *accountStore) SetPassword(id int, p []byte) (bool, error) { } func (s *accountStore) UpdateUsername(id int, u string) (bool, error) { + uNormalized := strings.ToLower(u) account := s.accountsByID[id] if account == nil { return false, nil } - if s.idByUsername[u] != 0 && s.idByUsername[u] != id { + if s.idByUsername[uNormalized] != 0 && s.idByUsername[uNormalized] != id { return false, Error{ErrNotUnique} } account.Username = u account.UpdatedAt = time.Now() - s.idByUsername[u] = account.ID + s.idByUsername[uNormalized] = account.ID return true, nil } diff --git a/app/data/postgres/migrations.go b/app/data/postgres/migrations.go index e41986800d..e321764a78 100644 --- a/app/data/postgres/migrations.go +++ b/app/data/postgres/migrations.go @@ -11,6 +11,7 @@ func MigrateDB(db *sqlx.DB) error { migrateAccounts, createOauthAccounts, createAccountLastLoginAtField, + caseInsensitiveUsername, } for _, m := range migrations { if err := m(db); err != nil { @@ -59,3 +60,11 @@ func createAccountLastLoginAtField(db *sqlx.DB) error { `) return err } + +func caseInsensitiveUsername(db *sqlx.DB) error { + _, err := db.Exec(` + CREATE EXTENSION IF NOT EXISTS citext; + ALTER TABLE accounts ALTER COLUMN username TYPE CITEXT; + `) + return err +} diff --git a/app/data/sqlite3/migrations.go b/app/data/sqlite3/migrations.go index 4a9ebcdea0..0626804cc2 100644 --- a/app/data/sqlite3/migrations.go +++ b/app/data/sqlite3/migrations.go @@ -13,6 +13,7 @@ func MigrateDB(db *sqlx.DB) error { createBlobs, createOauthAccounts, createAccountLastLoginAtField, + caseInsensitiveUsername, } for _, m := range migrations { if err := m(db); err != nil { @@ -90,3 +91,36 @@ func createAccountLastLoginAtField(db *sqlx.DB) error { `) return err } + +// caseInsensitiveUsername will migrate the accounts table to use COLLATE NOCASE on username. +// this will fail if the current accounts table has existing usernames that are equal after +// the operation. +func caseInsensitiveUsername(db *sqlx.DB) error { + _, err := db.Exec(` + BEGIN TRANSACTION; + + ALTER TABLE accounts RENAME TO accounts_old; + + CREATE TABLE accounts ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL COLLATE NOCASE CONSTRAINT uniq UNIQUE, + password TEXT NOT NULL, + locked BOOLEAN NOT NULL, + require_new_password BOOLEAN NOT NULL, + password_changed_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + deleted_at DATETIME, + last_login_at DATETIME + ); + + INSERT INTO accounts(id, username, password, locked, require_new_password, password_changed_at, created_at, updated_at, deleted_at, last_login_at) + SELECT id, username, password, locked, require_new_password, password_changed_at, created_at, updated_at, deleted_at, last_login_at + FROM accounts_old; + + DROP TABLE accounts_old; + + COMMIT; + `) + return err +} diff --git a/app/data/testers/account_store_testers.go b/app/data/testers/account_store_testers.go index 43500a6935..9c520b949f 100644 --- a/app/data/testers/account_store_testers.go +++ b/app/data/testers/account_store_testers.go @@ -53,6 +53,14 @@ func testCreate(t *testing.T, store data.AccountStore) { t.Errorf("expected uniqueness error, got %T %v", err, err) } + account, err = store.Create("AUTHN@KERATIN.TECH", []byte("password")) + if account != nil { + assert.NotEqual(t, nil, account) + } + if !data.IsUniquenessError(err) { + t.Errorf("expected uniqueness error, got %T %v", err, err) + } + // Assert that db connections are released to pool assert.Equal(t, 1, getOpenConnectionCount(store)) } @@ -69,6 +77,10 @@ func testFindByUsername(t *testing.T, store data.AccountStore) { assert.NoError(t, err) assert.NotNil(t, account) + account, err = store.FindByUsername("AUTHN@KERATIN.TECH") + assert.NoError(t, err) + assert.NotNil(t, account) + // Assert that db connections are released to pool assert.Equal(t, 1, getOpenConnectionCount(store)) }