diff --git a/README.md b/README.md index 95a4150..76226c5 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,11 @@ func main() { task := boomerang.NewTask( "greeter", "some-unique-id", - time.Now().Add(5*time.Second), "Hello!", ) - if err := sch.Add(ctx, task); err != nil { + // Schedule task for execution every second starting from now + if err := sch.Add(ctx, task, time.Second, time.Now()); err != nil { panic(err) } diff --git a/schedule.go b/schedule.go index 6cc3898..86de676 100644 --- a/schedule.go +++ b/schedule.go @@ -22,8 +22,7 @@ var ( ) type Schedule interface { - Add(ctx context.Context, task *Task) error - Update(ctx context.Context, task *Task) error + Add(ctx context.Context, task *Task, interval time.Duration, firstExecution time.Time) error Remove(ctx context.Context, kind string, id string) error RunNow(ctx context.Context, kind string, id string) error On(ctx context.Context, kind string, handler func(ctx context.Context, task *Task)) error @@ -31,7 +30,7 @@ type Schedule interface { type TaskData struct { Interval time.Duration - Data any + Data []byte } type ScheduleImpl struct { @@ -44,16 +43,16 @@ func NewSchedule(redisClient *redis.Client) Schedule { } } -func (s *ScheduleImpl) Add(ctx context.Context, task *Task) error { +func (s *ScheduleImpl) Add(ctx context.Context, task *Task, interval time.Duration, firstExecution time.Time) error { taskData, err := json.Marshal(TaskData{ - Interval: task.Interval / time.Millisecond, + Interval: interval / time.Millisecond, Data: task.Data, }) if err != nil { return err } - nextTick := time.Now().Add(task.Interval).UnixMilli() + nextTick := firstExecution.UnixMilli() script := redis.NewScript(` local queueKey = KEYS[1] @@ -104,70 +103,6 @@ func (s *ScheduleImpl) Add(ctx context.Context, task *Task) error { } } -func (s *ScheduleImpl) Update(ctx context.Context, task *Task) error { - taskData, err := json.Marshal(TaskData{ - Interval: task.Interval / time.Millisecond, - Data: task.Data, - }) - if err != nil { - return err - } - - nextTick := time.Now().Add(task.Interval).UnixMilli() - - script := redis.NewScript(` - local queueKey = KEYS[1] - local taskDataKey = KEYS[2] - local id = ARGV[1] - local taskData = ARGV[2] - local score = ARGV[3] - - -- Check if the task exists - - local exists = redis.call("HEXISTS", taskDataKey, id) - if exists == 1 then - -- Remove the task from the sorted set and the task data from the hash set - redis.call("ZREM", queueKey, id) - redis.call("HDEL", taskDataKey, id) - else - -- Error: task does not exist - return -1 - end - - -- Update the task data in the hash set and the sorted set - - redis.call("HSET", taskDataKey, id, taskData) - redis.call("ZADD", queueKey, score, id) - - -- OK - return 0 - `) - - code, err := script.Run( - ctx, - s.redisClient, - []string{ - s.taskScheduleKey(task.Kind), - s.taskDataKey(task.Kind), - }, - task.ID, - taskData, - float64(nextTick), - ).Int() - if err != nil { - return err - } - - switch code { - case 0: - return nil - case -1: - return ErrTaskDoesNotExist - default: - return ErrUnexpectedReturnCode - } -} - func (s *ScheduleImpl) Remove(ctx context.Context, kind string, id string) error { script := redis.NewScript(` local queueKey = KEYS[1] @@ -385,21 +320,20 @@ func (s *ScheduleImpl) On(ctx context.Context, kind string, handler func(ctx con time.Sleep(time.Duration(delta) * time.Millisecond) } - taskDataRaw, ok := resSlice[3].(string) + data, ok := resSlice[3].(string) if !ok { return errors.New("unexpected type for taskDataRaw") } var taskData TaskData - if err := json.Unmarshal([]byte(taskDataRaw), &taskData); err != nil { + if err := json.Unmarshal([]byte(data), &taskData); err != nil { return err } handler(ctx, &Task{ - ID: id, - Kind: kind, - Interval: time.Duration(taskData.Interval) * time.Millisecond, - Data: taskData.Data, + ID: id, + Kind: kind, + Data: taskData.Data, }) } } diff --git a/schedule_test.go b/schedule_test.go index 7159f4a..d5d1282 100644 --- a/schedule_test.go +++ b/schedule_test.go @@ -12,10 +12,7 @@ import ( var testTask1 = NewTask( "test", "id", - 10*time.Millisecond, - map[string]any{ - "foo": "bar", - }, + []byte("test data"), ) func newSchedule(t *testing.T, ctx context.Context, db int) Schedule { @@ -38,30 +35,13 @@ func TestScheduleImpl_Add(t *testing.T) { schedule := newSchedule(t, ctx, 1) - err := schedule.Add(ctx, testTask1) + err := schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) assert.NoError(t, err) - err = schedule.Add(ctx, testTask1) + err = schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) assert.ErrorIs(t, err, ErrTaskAlreadyExists) } -func TestScheduleImpl_Update(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - schedule := newSchedule(t, ctx, 2) - - err := schedule.Update(ctx, testTask1) - assert.ErrorIs(t, err, ErrTaskDoesNotExist) - - err = schedule.Add(ctx, testTask1) - assert.NoError(t, err) - - err = schedule.Update(ctx, testTask1) - assert.NoError(t, err) -} - func TestScheduleImpl_Remove(t *testing.T) { t.Parallel() @@ -72,7 +52,7 @@ func TestScheduleImpl_Remove(t *testing.T) { err := schedule.Remove(ctx, testTask1.Kind, testTask1.ID) assert.ErrorIs(t, err, ErrTaskDoesNotExist) - err = schedule.Add(ctx, testTask1) + err = schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) assert.NoError(t, err) err = schedule.Remove(ctx, testTask1.Kind, testTask1.ID) @@ -89,7 +69,7 @@ func TestScheduleImpl_RunNow(t *testing.T) { err := schedule.RunNow(ctx, testTask1.Kind, testTask1.ID) assert.ErrorIs(t, err, ErrTaskDoesNotExist) - err = schedule.Add(ctx, testTask1) + err = schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) assert.NoError(t, err) err = schedule.RunNow(ctx, testTask1.Kind, testTask1.ID) @@ -107,7 +87,7 @@ func TestScheduleImpl_On(t *testing.T) { ctxA, cancelA := context.WithTimeout(ctx, 1*time.Second) - err := schedule.Add(ctx, testTask1) + err := schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) assert.NoError(t, err) err = schedule.On(ctxA, testTask1.Kind, func(ctx context.Context, task *Task) { @@ -125,4 +105,24 @@ func TestScheduleImpl_On(t *testing.T) { }) assert.ErrorIs(t, err, context.DeadlineExceeded) + + err = schedule.Remove(ctx, testTask1.Kind, testTask1.ID) + assert.NoError(t, err) + + // Test data unmarshalling. + + ctxC, cancelC := context.WithTimeout(ctx, 1*time.Second) + + err = schedule.Add(ctx, testTask1, 10*time.Millisecond, time.Now()) + assert.NoError(t, err) + + err = schedule.On(ctxC, testTask1.Kind, func(ctx context.Context, task *Task) { + assert.Equal(t, testTask1.Kind, task.Kind) + assert.Equal(t, testTask1.ID, task.ID) + assert.Equal(t, testTask1.Data, task.Data) + + cancelC() + }) + + assert.ErrorIs(t, err, context.Canceled) } diff --git a/task.go b/task.go index 8c40ac3..b9f5b8b 100644 --- a/task.go +++ b/task.go @@ -1,19 +1,15 @@ package boomerang -import "time" - type Task struct { - Kind string - ID string - Interval time.Duration - Data any + Kind string + ID string + Data []byte } -func NewTask(kind, id string, interval time.Duration, data any) *Task { +func NewTask(kind string, id string, data []byte) *Task { return &Task{ - Kind: kind, - ID: id, - Interval: interval, - Data: data, + Kind: kind, + ID: id, + Data: data, } }