From 102a7824fdd8b70b9bdc501740e9bff86d71cf32 Mon Sep 17 00:00:00 2001
From: Lance Ivy <lance@cainlevy.net>
Date: Tue, 23 Mar 2021 12:05:39 -0700
Subject: [PATCH] migrate to fix case sensitive usernames

---
 app/data/mock/account_store.go            | 14 ++++++----
 app/data/postgres/migrations.go           |  9 ++++++
 app/data/sqlite3/migrations.go            | 34 +++++++++++++++++++++++
 app/data/testers/account_store_testers.go | 12 ++++++++
 4 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/app/data/mock/account_store.go b/app/data/mock/account_store.go
index ce489fe2c0..9785d9907a 100644
--- a/app/data/mock/account_store.go
+++ b/app/data/mock/account_store.go
@@ -2,6 +2,7 @@ package mock
 
 import (
 	"fmt"
+	"strings"
 	"time"
 
 	"github.com/keratin/authn-server/app/models"
@@ -42,7 +43,7 @@ func (s *accountStore) Find(id int) (*models.Account, error) {
 }
 
 func (s *accountStore) FindByUsername(u string) (*models.Account, error) {
-	id := s.idByUsername[u]
+	id := s.idByUsername[strings.ToLower(u)]
 	if id == 0 {
 		return nil, nil
 	}
@@ -60,7 +61,7 @@ func (s *accountStore) FindByOauthAccount(provider string, providerID string) (*
 }
 
 func (s *accountStore) Create(u string, p []byte) (*models.Account, error) {
-	if s.idByUsername[u] != 0 {
+	if s.idByUsername[strings.ToLower(u)] != 0 {
 		return nil, Error{ErrNotUnique}
 	}
 
@@ -74,7 +75,7 @@ func (s *accountStore) Create(u string, p []byte) (*models.Account, error) {
 		UpdatedAt:         now,
 	}
 	s.accountsByID[acc.ID] = &acc
-	s.idByUsername[acc.Username] = acc.ID
+	s.idByUsername[strings.ToLower(acc.Username)] = acc.ID
 	return dupAccount(acc), nil
 }
 
@@ -114,7 +115,7 @@ func (s *accountStore) Archive(id int) (bool, error) {
 		return false, nil
 	}
 
-	delete(s.idByUsername, account.Username)
+	delete(s.idByUsername, strings.ToLower(account.Username))
 	now := time.Now()
 	account.Username = ""
 	account.Password = []byte("")
@@ -176,18 +177,19 @@ func (s *accountStore) SetPassword(id int, p []byte) (bool, error) {
 }
 
 func (s *accountStore) UpdateUsername(id int, u string) (bool, error) {
+	uNormalized := strings.ToLower(u)
 	account := s.accountsByID[id]
 	if account == nil {
 		return false, nil
 	}
 
-	if s.idByUsername[u] != 0 && s.idByUsername[u] != id {
+	if s.idByUsername[uNormalized] != 0 && s.idByUsername[uNormalized] != id {
 		return false, Error{ErrNotUnique}
 	}
 
 	account.Username = u
 	account.UpdatedAt = time.Now()
-	s.idByUsername[u] = account.ID
+	s.idByUsername[uNormalized] = account.ID
 	return true, nil
 }
 
diff --git a/app/data/postgres/migrations.go b/app/data/postgres/migrations.go
index e41986800d..e321764a78 100644
--- a/app/data/postgres/migrations.go
+++ b/app/data/postgres/migrations.go
@@ -11,6 +11,7 @@ func MigrateDB(db *sqlx.DB) error {
 		migrateAccounts,
 		createOauthAccounts,
 		createAccountLastLoginAtField,
+		caseInsensitiveUsername,
 	}
 	for _, m := range migrations {
 		if err := m(db); err != nil {
@@ -59,3 +60,11 @@ func createAccountLastLoginAtField(db *sqlx.DB) error {
     `)
 	return err
 }
+
+func caseInsensitiveUsername(db *sqlx.DB) error {
+	_, err := db.Exec(`
+        CREATE EXTENSION IF NOT EXISTS citext;
+        ALTER TABLE accounts ALTER COLUMN username TYPE CITEXT;
+    `)
+	return err
+}
diff --git a/app/data/sqlite3/migrations.go b/app/data/sqlite3/migrations.go
index 4a9ebcdea0..0626804cc2 100644
--- a/app/data/sqlite3/migrations.go
+++ b/app/data/sqlite3/migrations.go
@@ -13,6 +13,7 @@ func MigrateDB(db *sqlx.DB) error {
 		createBlobs,
 		createOauthAccounts,
 		createAccountLastLoginAtField,
+		caseInsensitiveUsername,
 	}
 	for _, m := range migrations {
 		if err := m(db); err != nil {
@@ -90,3 +91,36 @@ func createAccountLastLoginAtField(db *sqlx.DB) error {
     `)
 	return err
 }
+
+// caseInsensitiveUsername will migrate the accounts table to use COLLATE NOCASE on username.
+// this will fail if the current accounts table has existing usernames that are equal after
+// the operation.
+func caseInsensitiveUsername(db *sqlx.DB) error {
+	_, err := db.Exec(`
+        BEGIN TRANSACTION;
+
+        ALTER TABLE accounts RENAME TO accounts_old;
+
+        CREATE TABLE accounts (
+            id INTEGER PRIMARY KEY,
+            username TEXT NOT NULL COLLATE NOCASE CONSTRAINT uniq UNIQUE,
+            password TEXT NOT NULL,
+            locked BOOLEAN NOT NULL,
+            require_new_password BOOLEAN NOT NULL,
+            password_changed_at DATETIME NOT NULL,
+            created_at DATETIME NOT NULL,
+            updated_at DATETIME NOT NULL,
+            deleted_at DATETIME,
+            last_login_at DATETIME
+        );
+
+        INSERT INTO accounts(id, username, password, locked, require_new_password, password_changed_at, created_at, updated_at, deleted_at, last_login_at)
+        SELECT id, username, password, locked, require_new_password, password_changed_at, created_at, updated_at, deleted_at, last_login_at
+        FROM accounts_old;
+
+        DROP TABLE accounts_old;
+
+        COMMIT;
+    `)
+	return err
+}
diff --git a/app/data/testers/account_store_testers.go b/app/data/testers/account_store_testers.go
index 43500a6935..9c520b949f 100644
--- a/app/data/testers/account_store_testers.go
+++ b/app/data/testers/account_store_testers.go
@@ -53,6 +53,14 @@ func testCreate(t *testing.T, store data.AccountStore) {
 		t.Errorf("expected uniqueness error, got %T %v", err, err)
 	}
 
+	account, err = store.Create("AUTHN@KERATIN.TECH", []byte("password"))
+	if account != nil {
+		assert.NotEqual(t, nil, account)
+	}
+	if !data.IsUniquenessError(err) {
+		t.Errorf("expected uniqueness error, got %T %v", err, err)
+	}
+
 	// Assert that db connections are released to pool
 	assert.Equal(t, 1, getOpenConnectionCount(store))
 }
@@ -69,6 +77,10 @@ func testFindByUsername(t *testing.T, store data.AccountStore) {
 	assert.NoError(t, err)
 	assert.NotNil(t, account)
 
+	account, err = store.FindByUsername("AUTHN@KERATIN.TECH")
+	assert.NoError(t, err)
+	assert.NotNil(t, account)
+
 	// Assert that db connections are released to pool
 	assert.Equal(t, 1, getOpenConnectionCount(store))
 }