diff --git a/domain/user/service/service.go b/domain/user/service/service.go index 98e8fbe8aa8..f6788976670 100644 --- a/domain/user/service/service.go +++ b/domain/user/service/service.go @@ -79,6 +79,11 @@ type State interface { // If no user is found for the supplied UUID an error is returned that // satisfies usererrors.NotFound. DisableUserAuthentication(context.Context, user.UUID) error + + // UpdateLastLogin will update the last login time for the user. + // If no user is found for the supplied UUID an error is returned that + // satisfies usererrors.NotFound. + UpdateLastLogin(context.Context, user.UUID) error } // Service provides the API for working with users. @@ -404,6 +409,22 @@ func (s *Service) DisableUserAuthentication(ctx context.Context, uuid user.UUID) return nil } +// UpdateLastLogin will update the last login time for the user. +// +// The following error types are possible from this function: +// - usererrors.UUIDNotValid: When the UUID supplied is not valid. +// - usererrors.NotFound: If no user by the given UUID exists. +func (s *Service) UpdateLastLogin(ctx context.Context, uuid user.UUID) error { + if err := uuid.Validate(); err != nil { + return errors.Annotatef(usererrors.UUIDNotValid, "%q", uuid) + } + + if err := s.st.UpdateLastLogin(ctx, uuid); err != nil { + return errors.Annotatef(err, "updating last login for user with uuid %q", uuid) + } + return nil +} + // generateActivationKey is responsible for generating a new activation key that // can be used for supplying to a user. func generateActivationKey() ([]byte, error) { diff --git a/domain/user/service/service_test.go b/domain/user/service/service_test.go index 310c6bddf61..6bed630db07 100644 --- a/domain/user/service/service_test.go +++ b/domain/user/service/service_test.go @@ -317,6 +317,20 @@ func (s *serviceSuite) setMockState(c *gc.C) map[user.UUID]stateUser { return nil }).AnyTimes() + // Implement the contract defined by UpdateLastLogin + s.state.EXPECT().UpdateLastLogin( + gomock.Any(), gomock.Any(), + ).DoAndReturn(func( + _ context.Context, + uuid user.UUID) error { + usr, exists := mockState[uuid] + if !exists || usr.removed { + return usererrors.NotFound + } + usr.lastLogin = time.Now() + mockState[uuid] = usr + return nil + }).AnyTimes() return mockState } @@ -1284,3 +1298,23 @@ func (s *serviceSuite) TestUsernameValidation(c *gc.C) { } } } + +// TestUpdateLastLogin tests the happy path for UpdateLastLogin. +func (s *serviceSuite) TestUpdateLastLogin(c *gc.C) { + defer s.setupMocks(c).Finish() + mockState := s.setMockState(c) + now := time.Now() + uuid, err := user.NewUUID() + c.Assert(err, jc.ErrorIsNil) + + mockState[uuid] = stateUser{ + name: "username", + lastLogin: now, + } + + err = s.service().UpdateLastLogin(context.Background(), uuid) + c.Assert(err, jc.ErrorIsNil) + + userState := mockState[uuid] + c.Assert(userState.lastLogin, gc.NotNil) +} diff --git a/domain/user/service/state_mock_test.go b/domain/user/service/state_mock_test.go index 82a99606ee3..32de47e9d9d 100644 --- a/domain/user/service/state_mock_test.go +++ b/domain/user/service/state_mock_test.go @@ -196,3 +196,17 @@ func (mr *MockStateMockRecorder) SetPasswordHash(arg0, arg1, arg2, arg3 any) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPasswordHash", reflect.TypeOf((*MockState)(nil).SetPasswordHash), arg0, arg1, arg2, arg3) } + +// UpdateLastLogin mocks base method. +func (m *MockState) UpdateLastLogin(arg0 context.Context, arg1 user.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLastLogin", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateLastLogin indicates an expected call of UpdateLastLogin. +func (mr *MockStateMockRecorder) UpdateLastLogin(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLastLogin", reflect.TypeOf((*MockState)(nil).UpdateLastLogin), arg0, arg1) +} diff --git a/domain/user/state/state.go b/domain/user/state/state.go index d10e7e4c603..5165f921b88 100644 --- a/domain/user/state/state.go +++ b/domain/user/state/state.go @@ -155,7 +155,7 @@ WHERE user.uuid = $M.uuid selectGetUserStmt, err := sqlair.Prepare(getUserQuery, User{}, sqlair.M{}) if err != nil { - return errors.Annotate(err, "preparing select getUserWithAuthInfo query") + return errors.Annotate(err, "preparing select getUser query") } var result User @@ -197,7 +197,7 @@ WHERE user.name = $M.name AND removed = false selectGetUserByNameStmt, err := sqlair.Prepare(getUserByNameQuery, User{}, sqlair.M{}) if err != nil { - return errors.Annotate(err, "preparing select getUserWithAuthInfoByName query") + return errors.Annotate(err, "preparing select getUserByName query") } var result User @@ -427,6 +427,45 @@ func AddUserWithPassword( return errors.Trace(setPasswordHash(ctx, tx, uuid, passwordHash, salt)) } +// UpdateLastLogin updates the last login time for the user with the supplied +// uuid. If the user does not exist an error that satisfies +// usererrors.NotFound will be returned. +func (st *State) UpdateLastLogin(ctx context.Context, uuid user.UUID) error { + db, err := st.DB() + if err != nil { + return errors.Annotate(err, "getting DB access") + } + + updateLastLoginQuery := ` +UPDATE user_authentication +SET last_login = $M.last_login +WHERE user_uuid = $M.uuid +` + + updateLastLoginStmt, err := sqlair.Prepare(updateLastLoginQuery, sqlair.M{}) + if err != nil { + return errors.Annotate(err, "preparing update updateLastLogin query") + } + + return db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { + removed, err := st.isRemoved(ctx, tx, uuid) + if err != nil { + return errors.Annotatef(err, "getting user with uuid %q", uuid) + } + if removed { + return errors.Annotatef(usererrors.NotFound, "%q", uuid) + } + + query := tx.Query(ctx, updateLastLoginStmt, sqlair.M{"uuid": uuid.String(), "last_login": time.Now()}) + err = query.Run() + if err != nil { + return errors.Annotatef(err, "updating last login for user with uuid %q", uuid) + } + + return nil + }) +} + // addUser adds a new user to the database. If the user already exists an error // that satisfies usererrors.AlreadyExists will be returned. If the creator does // not exist an error that satisfies usererrors.UserCreatorUUIDNotFound will be diff --git a/domain/user/state/state_test.go b/domain/user/state/state_test.go index a05f8505f65..6b510ffdb36 100644 --- a/domain/user/state/state_test.go +++ b/domain/user/state/state_test.go @@ -200,6 +200,7 @@ func (s *stateSuite) TestGetUser(c *gc.C) { c.Assert(err, jc.ErrorIsNil) err = st.AddUserWithPasswordHash(context.Background(), adminUUID, adminUser, adminUUID, "passwordHash", salt) + c.Assert(err, jc.ErrorIsNil) // Get the user. u, err := st.GetUser(context.Background(), adminUUID) @@ -237,6 +238,7 @@ func (s *stateSuite) TestGetRemovedUser(c *gc.C) { c.Assert(err, jc.ErrorIsNil) err = st.AddUserWithPasswordHash(context.Background(), userToRemoveUUID, userToRemove, adminUUID, "passwordHash", salt) + c.Assert(err, jc.ErrorIsNil) // Remove userToRemove. err = st.RemoveUser(context.Background(), userToRemoveUUID) @@ -360,6 +362,7 @@ func (s *stateSuite) TestGetUserByNameMultipleUsers(c *gc.C) { c.Assert(err, jc.ErrorIsNil) err = st.AddUserWithPasswordHash(context.Background(), admin2UUID, admin2User, admin2UUID, "passwordHash", salt) + c.Assert(err, jc.ErrorIsNil) // Get the user. u, err := st.GetUserByName(context.Background(), "admin") @@ -862,3 +865,44 @@ WHERE user_uuid = ? c.Assert(disabled, gc.Equals, false) } + +// TestUpdateLastLogin asserts that we can update the last login time for a +// user. +func (s *stateSuite) TestUpdateLastLogin(c *gc.C) { + st := NewState(s.TxnRunnerFactory()) + + // Add admin user with activation key. + adminUUID, err := user.NewUUID() + c.Assert(err, jc.ErrorIsNil) + adminUser := user.User{ + Name: "admin", + DisplayName: "admin", + } + + salt, err := auth.NewSalt() + c.Assert(err, jc.ErrorIsNil) + + // Add user with password hash. + err = st.AddUserWithPasswordHash(context.Background(), adminUUID, adminUser, adminUUID, "passwordHash", salt) + c.Assert(err, jc.ErrorIsNil) + + // Update last login. + err = st.UpdateLastLogin(context.Background(), adminUUID) + c.Assert(err, jc.ErrorIsNil) + + // Check that the last login was updated correctly. + db := s.DB() + + row := db.QueryRow(` +SELECT last_login +FROM user_authentication +WHERE user_uuid = ? + `, adminUUID) + c.Assert(row.Err(), jc.ErrorIsNil) + + var lastLogin time.Time + err = row.Scan(&lastLogin) + c.Assert(err, jc.ErrorIsNil) + + c.Assert(lastLogin, gc.NotNil) +}