From 64601784c5857b95e1380770da6aa639b5ad6947 Mon Sep 17 00:00:00 2001 From: Jonas Hungershausen Date: Wed, 17 Jan 2024 13:01:24 +0100 Subject: [PATCH] feat: order sessions by created_at --- courier/message.go | 8 ++--- ...8000000_sessions_created_at_index.down.sql | 1 + ...0_sessions_created_at_index.mysql.down.sql | 1 + ...628000000_sessions_created_at_index.up.sql | 1 + persistence/sql/persister_session.go | 11 ++++--- session/handler.go | 2 +- session/handler_test.go | 21 ++++++++++++-- session/session.go | 12 ++++++-- session/test/persistence.go | 29 ++++++++++++++----- x/pagination.go | 3 ++ 10 files changed, 67 insertions(+), 22 deletions(-) create mode 100644 persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.down.sql create mode 100644 persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.mysql.down.sql create mode 100644 persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.up.sql diff --git a/courier/message.go b/courier/message.go index ef39514aee93..2cc8ec5a1e32 100644 --- a/courier/message.go +++ b/courier/message.go @@ -13,6 +13,7 @@ import ( "github.com/ory/herodot" "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/x" "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlxx" "github.com/ory/x/stringsx" @@ -115,9 +116,6 @@ const ( messageTypeSMSText = "sms" ) -// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs -const dbFormat = "2006-01-02 15:04:05.99999" - func ToMessageType(str string) (MessageType, error) { switch s := stringsx.SwitchExact(str); { case s.AddCase(messageTypeEmailText): @@ -211,14 +209,14 @@ type Message struct { func (m Message) PageToken() keysetpagination.PageToken { return keysetpagination.MapPageToken{ "id": m.ID.String(), - "created_at": m.CreatedAt.Format(dbFormat), + "created_at": m.CreatedAt.Format(x.MapPaginationDateFormat), } } func (m Message) DefaultPageToken() keysetpagination.PageToken { return keysetpagination.MapPageToken{ "id": uuid.Nil.String(), - "created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(dbFormat), + "created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat), } } diff --git a/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.down.sql b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.down.sql new file mode 100644 index 000000000000..5f937e301376 --- /dev/null +++ b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.down.sql @@ -0,0 +1 @@ +DROP INDEX sessions_nid_created_at_id_idx; \ No newline at end of file diff --git a/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.mysql.down.sql b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.mysql.down.sql new file mode 100644 index 000000000000..50d8926a5f31 --- /dev/null +++ b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.mysql.down.sql @@ -0,0 +1 @@ +DROP INDEX sessions_nid_created_at_id_idx ON sessions; \ No newline at end of file diff --git a/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.up.sql b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.up.sql new file mode 100644 index 000000000000..c2b2f9e080bf --- /dev/null +++ b/persistence/sql/migrations/sql/20240119094628000000_sessions_created_at_index.up.sql @@ -0,0 +1 @@ +CREATE INDEX sessions_nid_created_at_id_idx ON sessions (nid, created_at DESC, id ASC); \ No newline at end of file diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index f8f6b13e8a9e..7cbd968f50a2 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -25,10 +25,12 @@ import ( var _ session.Persister = new(Persister) -const SessionDeviceUserAgentMaxLength = 512 -const SessionDeviceLocationMaxLength = 512 -const paginationMaxItemsSize = 1000 -const paginationDefaultItemsSize = 250 +const ( + SessionDeviceUserAgentMaxLength = 512 + SessionDeviceLocationMaxLength = 512 + paginationMaxItemsSize = 1000 + paginationDefaultItemsSize = 250 +) func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables session.Expandables) (_ *session.Session, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetSession") @@ -73,6 +75,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultSize(paginationDefaultItemsSize)) paginatorOpts = append(paginatorOpts, keysetpagination.WithMaxSize(paginationMaxItemsSize)) paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultToken(new(session.Session).DefaultPageToken())) + paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(paginatorOpts...) if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { diff --git a/session/handler.go b/session/handler.go index fd3b8ac1deea..4f23881c652c 100644 --- a/session/handler.go +++ b/session/handler.go @@ -383,7 +383,7 @@ func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps h } // Parse request pagination parameters - opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewStringPageToken) + opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewMapPageToken) if err != nil { h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError("could not parse parameter page_size")) return diff --git a/session/handler_test.go b/session/handler_test.go index b08fb4d1ecb5..dce2f7b05116 100644 --- a/session/handler_test.go +++ b/session/handler_test.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "sort" "strconv" "strings" @@ -18,6 +19,7 @@ import ( "time" "github.com/go-faker/faker/v4" + "github.com/peterhellberg/link" "github.com/tidwall/gjson" "github.com/ory/kratos/identity" @@ -26,6 +28,7 @@ import ( "github.com/pkg/errors" "github.com/ory/kratos/corpx" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlcon" "github.com/julienschmidt/httprouter" @@ -557,13 +560,27 @@ func TestHandlerAdminSessionManagement(t *testing.T) { require.Equal(t, ts.URL+"/sessions/whoami", res.Header.Get("Location")) }) + assertPageToken := func(t *testing.T, id, linkHeader string) { + t.Helper() + + g := link.Parse(linkHeader) + require.Len(t, g, 1) + u, err := url.Parse(g["first"].URI) + require.NoError(t, err) + pt, err := keysetpagination.NewMapPageToken(u.Query().Get("page_token")) + require.NoError(t, err) + mpt := pt.(keysetpagination.MapPageToken) + assert.Equal(t, id, mpt["id"]) + } + t.Run("list sessions", func(t *testing.T) { req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/", nil) res, err := client.Do(req) require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, "1", res.Header.Get("X-Total-Count")) - assert.Equal(t, "; rel=\"first\"", res.Header.Get("Link")) + + assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link")) var sessions []Session require.NoError(t, json.NewDecoder(res.Body).Decode(&sessions)) @@ -611,7 +628,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, "1", res.Header.Get("X-Total-Count")) - assert.Equal(t, "; rel=\"first\"", res.Header.Get("Link")) + assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link")) body := ioutilx.MustReadAll(res.Body) assert.Equal(t, s.ID.String(), gjson.GetBytes(body, "0.id").String()) diff --git a/session/session.go b/session/session.go index d11a05e3bf05..84f64ceec0d6 100644 --- a/session/session.go +++ b/session/session.go @@ -153,11 +153,17 @@ type Session struct { } func (s Session) PageToken() keysetpagination.PageToken { - return keysetpagination.StringPageToken(s.ID.String()) + return keysetpagination.MapPageToken{ + "id": s.ID.String(), + "created_at": s.CreatedAt.Format(x.MapPaginationDateFormat), + } } -func (s Session) DefaultPageToken() keysetpagination.PageToken { - return keysetpagination.StringPageToken(uuid.Nil.String()) +func (m Session) DefaultPageToken() keysetpagination.PageToken { + return keysetpagination.MapPageToken{ + "id": uuid.Nil.String(), + "created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat), + } } func (s Session) TableName(ctx context.Context) string { diff --git a/session/test/persistence.go b/session/test/persistence.go index fb6a7c469830..727cc744bdce 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/pointerx" @@ -30,7 +32,8 @@ import ( func TestPersister(ctx context.Context, conf *config.Config, p interface { persistence.Persister -}) func(t *testing.T) { +}, +) func(t *testing.T) { return func(t *testing.T) { _, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) @@ -149,6 +152,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { seedSessionIDs := make([]uuid.UUID, 5) seedSessionsList := make([]session.Session, 5) + start := time.Now() for j := range seedSessionsList { require.NoError(t, faker.FakeData(&seedSessionsList[j])) seedSessionsList[j].Identity = &identity1 @@ -165,9 +169,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { seedSessionsList[j].Devices = []session.Device{ device, } + pop.SetNowFunc(func() time.Time { + return start.Add(time.Duration(j) * time.Minute) + }) require.NoError(t, l.UpsertSession(ctx, &seedSessionsList[j])) seedSessionIDs[j] = seedSessionsList[j].ID } + pop.SetNowFunc(time.Now) identity2Session.Identity = &identity2 identity2Session.Active = true @@ -288,7 +296,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { require.Equal(t, len(tc.expected), len(actual)) require.Equal(t, int64(len(tc.expected)), total) assert.Equal(t, true, nextPage.IsLast()) - assert.Equal(t, uuid.Nil.String(), nextPage.Token().Encode()) + + mapPageToken := nextPage.Token().Parse("") + assert.Equal(t, uuid.Nil.String(), mapPageToken["id"]) + assert.Equal(t, 250, nextPage.Size()) for _, es := range tc.expected { found := false @@ -312,7 +323,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { require.Equal(t, 6, len(actual)) require.Equal(t, int64(6), total) assert.Equal(t, true, page.IsLast()) - assert.Equal(t, uuid.Nil.String(), page.Token().Encode()) + mapPageToken := page.Token().Parse("") + assert.Equal(t, uuid.Nil.String(), mapPageToken["id"]) assert.Equal(t, 250, page.Size()) }) @@ -325,21 +337,24 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { assert.Len(t, firstPageItems, 3) assert.Equal(t, false, page1.IsLast()) - assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), page1.Token().Encode()) + mapPageToken := page1.Token().Parse("") + assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), mapPageToken["id"]) assert.Equal(t, 3, page1.Size()) // Validate secondPageItems page secondPageItems, total, page2, err := l.ListSessions(ctx, nil, page1.ToOptions(), session.ExpandEverything) require.NoError(t, err) + require.Equal(t, int64(6), total) + assert.Len(t, secondPageItems, 3) acutalIDs := make([]uuid.UUID, 0) for _, s := range append(firstPageItems, secondPageItems...) { acutalIDs = append(acutalIDs, s.ID) } - assert.ElementsMatch(t, append(seedSessionIDs, identity2Session.ID), acutalIDs) + expect := append(seedSessionIDs, identity2Session.ID) + require.Len(t, acutalIDs, len(expect)) + assert.ElementsMatch(t, expect, acutalIDs) - require.Equal(t, int64(6), total) - assert.Len(t, secondPageItems, 3) assert.True(t, page2.IsLast()) assert.Equal(t, 3, page2.Size()) }) diff --git a/x/pagination.go b/x/pagination.go index 7e31cd8f7d2e..5d633f19dc58 100644 --- a/x/pagination.go +++ b/x/pagination.go @@ -13,6 +13,9 @@ import ( "github.com/ory/x/pagination/pagepagination" ) +// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs +const MapPaginationDateFormat = "2006-01-02 15:04:05.99999" + // ParsePagination parses limit and page from *http.Request with given limits and defaults. func ParsePagination(r *http.Request) (page, itemsPerPage int) { return migrationpagination.NewDefaultPaginator().ParsePagination(r)