Skip to content

Commit

Permalink
✨ plugin/event: add sync mode
Browse files Browse the repository at this point in the history
  • Loading branch information
rjeczalik committed Jul 1, 2024
1 parent 5c90647 commit 9046fcc
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 47 deletions.
130 changes: 83 additions & 47 deletions pkg/plugin/builtin/event/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package event // import "hookt.dev/cmd/pkg/plugin/builtin/event"
import (
"context"
"log/slog"
"maps"
"strconv"
"time"

Expand All @@ -14,20 +13,24 @@ import (
"hookt.dev/cmd/pkg/trace"
)

type step struct {
c chan proto.Message
done chan struct{}
}

type Plugin struct {
wire.Config

p *proto.P
c map[chan proto.Message]chan struct{}
mux chan proto.Message
stop chan (chan proto.Message)
p *proto.P
steps []step
mux chan proto.Message
stop chan int
}

func New(opts ...func(*Plugin)) *Plugin {
p := &Plugin{
c: make(map[chan proto.Message]chan struct{}),
mux: make(chan proto.Message),
stop: make(chan chan proto.Message),
stop: make(chan int),
}
for _, opt := range opts {
opt(p)
Expand All @@ -49,6 +52,12 @@ func (p *Plugin) Plugin(_ context.Context, q *proto.P) any {
}

func (p *Plugin) Init(ctx context.Context, job *proto.Job) error {
switch p.Config.Mode {
case "", "async", "sync":
// ok
default:
return errors.New("invalid mode %q", p.Config.Mode)
}
wire:
for _, source := range p.Config.Sources {
for _, plugin := range job.Plugins {
Expand All @@ -73,7 +82,7 @@ wire:
return errors.New("source %q not found in job plugins", source)
}

go p.process()
go p.schedule()

slog.Debug("event: init",
"config", p.Config,
Expand All @@ -82,62 +91,78 @@ wire:
return nil
}

func (p *Plugin) process() {
type indexer interface {
Index() int
}

func (p *Plugin) schedule() {
for {
select {
case c := <-p.stop:
done := p.c[c]
close(done)
delete(p.c, c)
case i := <-p.stop:
s := &p.steps[i]
close(s.done)
s.done = nil
case msg := <-p.mux:
ch := maps.Clone(p.c)
go func() {
for c, done := range ch {
select {
case <-done:
case c <- msg:
}
var steps []step
for _, step := range p.steps {
if step.done == nil {
continue
}
}()
steps = append(steps, step)
}
switch p.Config.Mode {
case "sync":
wg := Wait(msg)
go func() {
for _, step := range steps {
select {
case <-step.done:
continue
case step.c <- wg:
if wg.Wait() {
return
}
}
}
}()
case "", "async":
go func() {
for _, step := range steps {
select {
case <-step.done:
continue
case step.c <- msg:
}
}
}()
}
}
}
}

func (p *Plugin) Step(ctx context.Context) any {
c := make(chan proto.Message)
done := make(chan struct{})
p.c[c] = done
s := step{
c: make(chan proto.Message),
done: make(chan struct{}),
}
p.steps = append(p.steps, s)
it, _ := time.ParseDuration(p.Config.InactiveTimeout)
return &Step{
p: p,
c: c,
done: done,
it: nonempty(it, 1*time.Minute),
i: len(p.steps) - 1,
p: p,
it: nonempty(it, 1*time.Minute),
}
}

type Step struct {
wire.Step

p *Plugin
c chan proto.Message
done chan struct{}
it time.Duration
i int
p *Plugin
it time.Duration
}

func group(ctx context.Context, name string) context.Context {
return trace.With(ctx, "pattern-group", name)
}

func (s *Step) Run(ctx context.Context, c *check.S) error {
type indexer interface {
Index() int
}

slog.Debug("event: run",
"match", s.Match,
"pass", s.Pass,
Expand Down Expand Up @@ -170,15 +195,15 @@ func (s *Step) Run(ctx context.Context, c *check.S) error {
c.Fail()
tr.MatchTimeout(ctx)
return errors.New("step has timed out after %v", s.it)
case msg := <-s.c:
case msg := <-s.step().c:
if !inactive.Stop() {
<-inactive.C
}
inactive.Reset(s.it)

ctxt := ctx

if i, ok := msg.(indexer); ok {
if i, ok := msg.(Indexer); ok {
ctxt = trace.With(ctxt, "event-seq", strconv.Itoa(i.Index()))
}

Expand All @@ -190,6 +215,9 @@ func (s *Step) Run(ctx context.Context, c *check.S) error {
}

if !match {
if wg, ok := msg.(WaitMessage); ok {
wg.Done(false)
}
continue
}

Expand All @@ -202,11 +230,14 @@ func (s *Step) Run(ctx context.Context, c *check.S) error {
return errors.New("failure pattern matched")
}

ok, err := pass.Match(group(ctxt, "pass"), obj)
pass, err := pass.Match(group(ctxt, "pass"), obj)
if err != nil {
return errors.New("failed to match ok pattern: %w", err)
}
if ok {
if wg, ok := msg.(WaitMessage); ok {
wg.Done(pass)
}
if pass {
c.OK()
return nil
}
Expand All @@ -217,15 +248,20 @@ func (s *Step) Run(ctx context.Context, c *check.S) error {
}

func (s *Step) Stop() {
s.p.stop <- s.c
s.p.stop <- s.i
s.drain()
}

func (s *Step) step() step {
return s.p.steps[s.i]
}

func (s *Step) drain() {
for {
select {
case <-s.c:
case <-s.done:
case _ = <-s.step().c:
// drop the event
case <-s.step().done:
return
}
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/plugin/builtin/event/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package event

import (
"hookt.dev/cmd/pkg/proto"
)

type Indexer interface {
Index() int
}

type WaitMessage interface {
proto.Message
Done(bool)
Wait() bool
}

type Message struct {
proto.Message

done chan bool
}

func (m *Message) Done(ok bool) {
m.done <- ok
}

func (m *Message) Wait() bool {
return <-m.done
}

var _ proto.Message = (*Message)(nil)

func Wait(msg proto.Message) WaitMessage {
m := &Message{
Message: msg,
done: make(chan bool),
}

if idx, ok := msg.(Indexer); ok {
return struct {
WaitMessage
Indexer
}{
m,
idx,
}
}

return m
}
1 change: 1 addition & 0 deletions pkg/plugin/builtin/event/wire/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type Config struct {
Sources []string `json:"sources"`
Mode string `json:"mode,omitempty"`
Timeout string `json:"timeout,omitempty"`
InactiveTimeout string `json:"inactive_timeout,omitempty"`
}
Expand Down

0 comments on commit 9046fcc

Please sign in to comment.