Skip to content

Commit

Permalink
fix: add special case for MySQL
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Aug 10, 2023
1 parent bf7b3ef commit 2b3718d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 5 deletions.
6 changes: 4 additions & 2 deletions oauth2/fosite_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func setupRegistries(t *testing.T) {
}

func TestManagers(t *testing.T) {
ctx := context.TODO()
ctx := context.Background()
tests := []struct {
name string
enableSessionEncrypted bool
Expand All @@ -67,7 +67,9 @@ func TestManagers(t *testing.T) {
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tc := tc
t.Run("suite="+tc.name, func(t *testing.T) {
t.Parallel()
setupRegistries(t)

require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers.
Expand Down
1 change: 0 additions & 1 deletion persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func CompareWithFixture(t *testing.T, actual interface{}, prefix string, id stri
}

func TestMigrations(t *testing.T) {
//pop.Debug = true
connections := make(map[string]*pop.Connection, 1)

if testing.Short() {
Expand Down
35 changes: 35 additions & 0 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,11 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession")
defer otelx.End(span, &err)

if p.Connection(ctx).Dialect.Name() == "mysql" {
// MySQL does not support RETURNING.
return p.mySQLDeleteLoginSession(ctx, id)
}

var session flow.LoginSession

err = p.Connection(ctx).RawQuery(
Expand All @@ -477,6 +482,36 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS
return &session, nil
}

func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (*flow.LoginSession, error) {
var session flow.LoginSession

err := p.Connection(ctx).Transaction(func(tx *pop.Connection) error {
err := tx.RawQuery(`
SELECT * FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
id,
p.NetworkID(ctx),
).First(&session)
if err != nil {
return err
}

return p.Connection(ctx).RawQuery(`
DELETE FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
id,
p.NetworkID(ctx),
).Exec()
})

if err != nil {
return nil, sqlcon.HandleError(err)
}

return &session, nil

}

func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests")
defer span.End()
Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() {
require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls))

// Expects the login session to be confirmed in the correct context.
require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember))
require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember))
actual := &flow.LoginSession{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID))
exp, _ := json.Marshal(ls)
Expand All @@ -199,7 +199,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() {

// Can't find the login session in the wrong context.
require.ErrorIs(t,
r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember),
r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember),
x.ErrNotFound,
)
})
Expand Down

0 comments on commit 2b3718d

Please sign in to comment.