Skip to content

Commit

Permalink
fix: Implement upsert (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
codelite7 authored Nov 22, 2022
1 parent b5b4dbc commit f87ed3f
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 139 deletions.
112 changes: 50 additions & 62 deletions pkg/cockroachdb_store/cockroachdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@ type CockroachdbStore struct {
config *gorm.Config
}

func (c *CockroachdbStore) DeleteTaskDefinitionsByMetadata(metadataQuery interface{}) error {
err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
// query for task definitions that aren't completed, whose next fire time is less than the limit
return tx.Where(metadataQuery).Delete(&models.TaskDefinition{}).Error
})
if err != nil {
logging.Log.WithError(err).Error("error deleting task definitions by metadata query")
}
return err
}

func (c *CockroachdbStore) GetTaskDefinitions(ids []*uuid.UUID) ([]pkg.TaskDefinition, error) {
definitions := []models.TaskDefinition{}
err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
// query for task definitions that aren't completed, whose next fire time is less than the limit
return tx.Find(definitions, ids).Error
})
if err != nil {
return nil, err
}
return models.ToTaskDefinitions(definitions)
}

func (c *CockroachdbStore) DeleteTaskDefinitions(ids []*uuid.UUID) error {
return crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
// query for task definitions that aren't completed, whose next fire time is less than the limit
return tx.Delete([]models.TaskDefinition{}, ids).Error
})
}

func (c *CockroachdbStore) GetTaskDefinitionsToSchedule(limit time.Time) ([]pkg.TaskDefinition, error) {
limit = limit.UTC()
taskDefinitionModels := []models.TaskDefinition{}
Expand All @@ -51,12 +81,13 @@ func (c *CockroachdbStore) GetTaskDefinitionsToSchedule(limit time.Time) ([]pkg.
func (c *CockroachdbStore) MarkTaskInstanceComplete(taskInstance pkg.TaskInstance) error {
completedAt := time.Now().UTC()
return crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
// if the parent task definition is not recurring, this marks it as completed in a single query
err := tx.Model(&models.TaskDefinition{}).Where("id = ? and recurring = false", taskInstance.TaskDefinition.Id).Update("completed_at", completedAt).Error
if err != nil {
logging.Log.WithError(err).Error("error marking task definition complete")
return err
}
err = tx.Model(&models.TaskInstance{}).Where("id = ?", taskInstance.Id).Update("completed_at", completedAt).Error
err = tx.Omit("TaskDefinition").Model(&models.TaskInstance{}).Where("id = ?", taskInstance.Id).Update("completed_at", completedAt).Error
if err != nil {
logging.Log.WithError(err).Error("error marking task instance complete")
}
Expand Down Expand Up @@ -129,10 +160,14 @@ func (c *CockroachdbStore) ListTaskInstances(offset, limit int) ([]pkg.TaskInsta
return models.ToTaskInstances(taskInstanceModels)
}

func (c *CockroachdbStore) ListTaskDefinitions(offset, limit int) ([]pkg.TaskDefinition, error) {
func (c *CockroachdbStore) ListTaskDefinitions(offset, limit int, metadataQuery interface{}) ([]pkg.TaskDefinition, error) {
taskDefinitionModels := []models.TaskDefinition{}
err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Preload(clause.Associations).Order("created_at").Offset(offset).Limit(limit).Find(&taskDefinitionModels).Error
tx = tx.Preload(clause.Associations).Order("created_at").Offset(offset).Limit(limit)
if metadataQuery != nil {
tx = tx.Where(metadataQuery)
}
return tx.Find(&taskDefinitionModels).Error
})
if err != nil {
logging.Log.WithError(err).Error("error scheduling task with cockroachdb store")
Expand All @@ -152,21 +187,6 @@ func (c *CockroachdbStore) GetTaskDefinition(id *uuid.UUID) (pkg.TaskDefinition,
return taskDefinitionModel.ToTaskDefinition()
}

func (c *CockroachdbStore) UpdateTaskDefinition(taskDefinition pkg.TaskDefinition) error {
taskDefinitionModel, err := models.GetTaskDefinitionModelFromTaskDefinition(taskDefinition)
taskDefinitionModel.TaskInstances = nil
if err != nil {
return err
}
err = crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Save(&taskDefinitionModel).Error
})
if err != nil {
logging.Log.WithError(err).Error("error scheduling task with cockroachdb store")
}
return err
}

func (c *CockroachdbStore) DeleteTaskDefinition(id *uuid.UUID) error {
err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Delete(models.TaskDefinition{Id: id}).Error
Expand Down Expand Up @@ -198,40 +218,23 @@ func (c *CockroachdbStore) GetTaskInstancesToRun(limit time.Time) ([]pkg.TaskIns
return taskInstances, nil
}

func (c *CockroachdbStore) CreateTaskInstance(taskInstance pkg.TaskInstance) error {
func (c *CockroachdbStore) UpsertTaskInstance(taskInstance pkg.TaskInstance) error {
logging.Log.Info("upserting task instance")
taskInstanceModel, err := models.GetTaskInstanceModelFromTaskInstance(taskInstance)
if taskInstanceModel.Id == nil {
id := uuid.New()
taskInstanceModel.Id = &id
}
if err != nil {
return err
}
taskInstanceModel.TaskDefinition = nil
err = crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Create(&taskInstanceModel).Error
return tx.Omit("TaskDefinition").Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(&taskInstanceModel).Error
})
if err != nil {
logging.Log.WithError(err).Error("error creating task instance")
}
return err
}

func (c *CockroachdbStore) UpdateTaskInstance(taskInstance pkg.TaskInstance) error {
taskInstanceModel, err := models.GetTaskInstanceModelFromTaskInstance(taskInstance)
if err != nil {
return err
}
taskInstanceModel.TaskDefinition = nil
err = crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Save(&taskInstanceModel).Error
})
if err != nil {
logging.Log.WithError(err).Error("error scheduling task with cockroachdb store")
}
return err
}

func (c *CockroachdbStore) DeleteTaskInstance(id *uuid.UUID) error {
err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Delete(models.TaskInstance{Id: id}).Error
Expand Down Expand Up @@ -277,35 +280,20 @@ func (c *CockroachdbStore) Initialize() (err error) {
return nil
}

func (c *CockroachdbStore) CreateTaskDefinition(taskDefinition pkg.TaskDefinition) error {
func (c *CockroachdbStore) UpsertTaskDefinition(taskDefinition pkg.TaskDefinition) error {
logging.Log.Info("upserting task definition")
taskDefinitionModel, err := models.GetTaskDefinitionModelFromTaskDefinition(taskDefinition)
taskDefinitionModel.TaskInstances = nil
if err != nil {
return err
}
err = crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
return tx.Create(&taskDefinitionModel).Error
return tx.Omit("TaskInstances").Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(&taskDefinitionModel).Error
})
if err != nil {
logging.Log.WithError(err).Error("error scheduling task with cockroachdb store")
logging.Log.WithError(err).Error("error upserting task with cockroachdb store")
}
return err
}

//func (c *CockroachdbStore) GetUpcomingTasks(limit time.Time) ([]pkg.TaskDefinition, error) {
// models := []models.TaskDefinition{}
// err := crdbgorm.ExecuteTx(context.Background(), c.db, nil, func(tx *gorm.DB) error {
// return tx.Preload(clause.Associations).Where("next_fire_time <= ? and (in_progress = false or age(last_fire_time) >= expire_after_interval)", limit.Format(time.RFC3339)).Order("next_fire_time").Find(&models).Error
// })
// if err != nil {
// logging.Log.WithError(err).Error("error scheduling task with cockroachdb store")
// }
// tasks := []pkg.TaskDefinition{}
// for _, model := range models {
// task, err := model.ToTaskDefinition()
// if err != nil {
// return nil, err
// }
// tasks = append(tasks, task)
// }
// return tasks, err
//}
3 changes: 2 additions & 1 deletion pkg/cockroachdb_store/migrations/000001_baseline.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ create table task_definitions
last_fire_time timestamptz,
next_fire_time timestamptz,
completed_at timestamptz,
recurring bool
recurring bool,
INVERTED INDEX metadata_idx (metadata)
);

create table task_instances
Expand Down
14 changes: 14 additions & 0 deletions pkg/cockroachdb_store/models/task_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package models

import (
"encoding/json"
"github.com/catalystsquad/app-utils-go/logging"
"github.com/catalystsquad/go-scheduler/pkg"
"github.com/dariubs/gorm-jsonb"
"github.com/google/uuid"
"gorm.io/gorm"
"time"
)

Expand All @@ -25,6 +27,18 @@ type TaskDefinition struct {
Recurring bool
}

var nilUuidString = uuid.Nil.String()

func (t *TaskDefinition) BeforeCreate(tx *gorm.DB) error {
// comparing to uuid.Nil directly doesn't work as expected and skips this condition when it shouldn't, hence the string comparison
if t.Id == nil || t.Id.String() == nilUuidString {
id := uuid.New()
t.Id = &id
logging.Log.Info("set new id on task definition during create")
}
return nil
}

func (t TaskDefinition) ToTaskDefinition() (pkg.TaskDefinition, error) {
var task pkg.TaskDefinition
taskModelJsonBytes, err := json.Marshal(t)
Expand Down
13 changes: 13 additions & 0 deletions pkg/cockroachdb_store/models/task_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package models

import (
"encoding/json"
"github.com/catalystsquad/app-utils-go/logging"
"github.com/catalystsquad/go-scheduler/pkg"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"time"
)

Expand All @@ -19,6 +22,16 @@ type TaskInstance struct {
TaskDefinition *TaskDefinition `json:"task_definition"`
}

func (t *TaskInstance) BeforeCreate(tx *gorm.DB) error {
// comparing to uuid.Nil directly doesn't work as expected and skips this condition when it shouldn't, hence the string comparison
if t.Id == nil || t.Id.String() == nilUuidString {
id := uuid.New()
t.Id = &id
logging.Log.WithFields(logrus.Fields{"task_definition_id": t.TaskDefinition.Id}).Info("set new id on task instance during create")
}
return nil
}

func (t TaskInstance) ToTaskInstance() (pkg.TaskInstance, error) {
var taskInstance pkg.TaskInstance
taskInstanceModelJsonBytes, err := json.Marshal(t)
Expand Down
33 changes: 20 additions & 13 deletions pkg/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Scheduler struct {
shutdown chan bool
}

func NewScheduler(scheduleWindow, runnerWindow, cleanupWindow time.Duration, handler func(task TaskInstance) error, store StoreInterface) (*Scheduler, error) {
func NewScheduler(scheduleWindow, runnerWindow, cleanupWindow time.Duration, handler func(taskInstance TaskInstance) error, store StoreInterface) (*Scheduler, error) {
scheduler := &Scheduler{
ScheduleWindow: &scheduleWindow,
RunnerWindow: &runnerWindow,
Expand All @@ -37,7 +37,7 @@ func NewScheduler(scheduleWindow, runnerWindow, cleanupWindow time.Duration, han
return scheduler, err
}

func (s *Scheduler) CreateTaskDefinition(task TaskDefinition) error {
func (s *Scheduler) UpsertTaskDefinition(task TaskDefinition) error {
err := validateTask(task)
if err != nil {
return err
Expand All @@ -51,16 +51,15 @@ func (s *Scheduler) CreateTaskDefinition(task TaskDefinition) error {
id := uuid.New()
task.Id = &id
}
return s.store.CreateTaskDefinition(task)
return s.store.UpsertTaskDefinition(task)
}

func (s *Scheduler) UpdateTaskDefinition(task TaskDefinition) error {
err := validateTask(task)
if err != nil {
return err
}
task.Recurring = task.GetTrigger().IsRecurring()
return s.store.UpdateTaskDefinition(task)
func (s *Scheduler) GetTaskDefinitions(ids []*uuid.UUID) ([]TaskDefinition, error) {
return s.store.GetTaskDefinitions(ids)
}

func (s *Scheduler) ListTaskDefinitions(skip, limit int, metadataQuery interface{}) ([]TaskDefinition, error) {
return s.store.ListTaskDefinitions(skip, limit, metadataQuery)
}

func (s *Scheduler) DeleteTaskDefinition(id *uuid.UUID) error {
Expand All @@ -70,6 +69,14 @@ func (s *Scheduler) DeleteTaskDefinition(id *uuid.UUID) error {
return s.store.DeleteTaskDefinition(id)
}

func (s *Scheduler) DeleteTaskDefinitions(ids []*uuid.UUID) error {
return s.store.DeleteTaskDefinitions(ids)
}

func (s *Scheduler) DeleteTaskDefinitionsByMetadataQuery(metadataQuery interface{}) error {
return s.store.DeleteTaskDefinitionsByMetadata(metadataQuery)
}

func (s *Scheduler) Run() {
s.run = true
// start task instance scheduler, task instance runner, and task instance cleanup, in background
Expand Down Expand Up @@ -116,7 +123,7 @@ func (s *Scheduler) createTaskInstance(taskDefinition TaskDefinition) error {
ExecuteAt: executeAt,
TaskDefinition: taskDefinition,
}
err := s.store.CreateTaskInstance(taskInstance)
err := s.store.UpsertTaskInstance(taskInstance)
if err != nil {
logging.Log.WithError(err).Error("error creating task instance")
return err
Expand All @@ -128,7 +135,7 @@ func (s *Scheduler) createTaskInstance(taskDefinition TaskDefinition) error {
} else {
taskDefinition.NextFireTime = nil
}
err = s.store.UpdateTaskDefinition(taskDefinition)
err = s.store.UpsertTaskDefinition(taskDefinition)
if err != nil {
logging.Log.WithError(err).WithFields(logrus.Fields{"id": taskDefinition.Id}).Error("error setting task definition next execution time")
}
Expand Down Expand Up @@ -184,7 +191,7 @@ func (s *Scheduler) markTaskInstanceInProgress(taskInstance TaskInstance) error
expiresAt := startedAt.Add(taskInstance.TaskDefinition.ExpireAfter)
taskInstance.StartedAt = &startedAt
taskInstance.ExpiresAt = &expiresAt
err := s.store.UpdateTaskInstance(taskInstance)
err := s.store.UpsertTaskInstance(taskInstance)
if err != nil {
logging.Log.WithError(err).WithFields(logrus.Fields{"task_instance_id": taskInstance.Id, "task_definition_id": taskInstance.TaskDefinition.Id}).Error("error setting task instance started_at")
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/store_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ import (

type StoreInterface interface {
Initialize() error
CreateTaskDefinition(taskDefinition TaskDefinition) error
ListTaskDefinitions(offset, limit int) ([]TaskDefinition, error)
UpsertTaskDefinition(definition TaskDefinition) error
ListTaskDefinitions(offset, limit int, metadataQuery interface{}) ([]TaskDefinition, error)
GetTaskDefinition(id *uuid.UUID) (TaskDefinition, error)
UpdateTaskDefinition(taskDefinition TaskDefinition) error
GetTaskDefinitions(ids []*uuid.UUID) ([]TaskDefinition, error)
DeleteTaskDefinition(id *uuid.UUID) error
CreateTaskInstance(taskInstance TaskInstance) error
DeleteTaskDefinitions(ids []*uuid.UUID) error
DeleteTaskDefinitionsByMetadata(metadataQuery interface{}) error
UpsertTaskInstance(taskInstance TaskInstance) error
GetTaskInstance(id *uuid.UUID) (TaskInstance, error)
ListTaskInstances(offset, limit int) ([]TaskInstance, error)
UpdateTaskInstance(taskInstance TaskInstance) error
DeleteTaskInstance(id *uuid.UUID) error
GetTaskDefinitionsToSchedule(limit time.Time) ([]TaskDefinition, error)
GetTaskInstancesToRun(limit time.Time) ([]TaskInstance, error)
Expand Down
22 changes: 22 additions & 0 deletions test/cockroachdb_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/catalystsquad/app-utils-go/logging"
"github.com/catalystsquad/go-scheduler/pkg"
"github.com/catalystsquad/go-scheduler/pkg/cockroachdb_store"
"github.com/google/uuid"
"github.com/orlangure/gnomock"
"github.com/orlangure/gnomock/preset/cockroachdb"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -84,3 +85,24 @@ func (s *CockroachdbStoreSuite) TestCockroachdbStoreLongRunningTaskNotExpired()
func (s *CockroachdbStoreSuite) TestCockroachdbStoreCronTriggerHappyPath() {
TestCronTriggerHappyPath(s.T(), cockroachdbStore)
}

func (s *CockroachdbStoreSuite) TestListWithMetadataQuery() {
id := uuid.New().String()
metadata := map[string]interface{}{"user_id": id}
metadataQuery := fmt.Sprintf(`metadata @> '{"user_id": "%s"}'`, id)
TestListWithMetadataQuery(s.T(), cockroachdbStore, metadata, metadataQuery)
}

func (s *CockroachdbStoreSuite) TestDeleteWithMetadataQuery() {
id := uuid.New().String()
metadata := map[string]interface{}{"user_id": id}
metadataQuery := fmt.Sprintf(`metadata @> '{"user_id": "%s"}'`, id)
TestDeleteWithMetadataQuery(s.T(), cockroachdbStore, metadata, metadataQuery)
}

func (s *CockroachdbStoreSuite) TestSingleTaskDefinitionCreatedForCronTasks() {
TestCronTriggerHappyPath(s.T(), cockroachdbStore)
definitions, err := cockroachdbStore.ListTaskDefinitions(0, 1000, nil)
require.NoError(s.T(), err)
require.Len(s.T(), definitions, 1)
}
Loading

0 comments on commit f87ed3f

Please sign in to comment.