Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(pipeline): move pipeline #7

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
module github.com/go-kratos/exp

go 1.18

require (
Casper-Mars marked this conversation as resolved.
Show resolved Hide resolved
github.com/go-kratos/kratos/v2 v2.2.0 // indirect
github.com/go-playground/form/v4 v4.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
59 changes: 59 additions & 0 deletions internal/time/time.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package time
Casper-Mars marked this conversation as resolved.
Show resolved Hide resolved

import (
"context"
"database/sql/driver"
"strconv"
xtime "time"
)

// Time is a Unix timestamp in whole seconds. It implements the Scan/Value
// pair used by database/sql so it can be mapped to MySQL timestamp columns.
type Time int64

// Scan stores src into jt.
//
// Supported source types:
//   - time.Time: its Unix seconds are stored
//   - int64:     stored as-is (seconds)
//   - string / []byte: parsed as a base-10 integer number of seconds
//
// Any other source (including nil) leaves jt untouched and returns nil.
// If the text cannot be parsed the error from strconv.ParseInt is returned
// and jt keeps its previous value.
func (jt *Time) Scan(src interface{}) (err error) {
	var text string
	switch sc := src.(type) {
	case xtime.Time:
		*jt = Time(sc.Unix())
		return nil
	case int64:
		*jt = Time(sc)
		return nil
	case string:
		text = sc
	case []byte:
		// Drivers commonly hand textual columns over as raw bytes.
		text = string(sc)
	default:
		return nil
	}
	var secs int64
	if secs, err = strconv.ParseInt(text, 10, 64); err != nil {
		// Do not clobber the receiver with a bogus value on failure.
		return err
	}
	*jt = Time(secs)
	return nil
}

// Value returns jt as a time.Time (seconds precision) for the SQL driver.
// The returned error is always nil.
func (jt Time) Value() (driver.Value, error) {
	return jt.Time(), nil
}

// Time converts jt to a time.Time with zero nanoseconds.
func (jt Time) Time() xtime.Time {
	return xtime.Unix(int64(jt), 0)
}

// Duration wraps time.Duration so that textual values such as "1s" or
// "500ms" (for example from a TOML file) can be decoded into it.
type Duration xtime.Duration

// UnmarshalText parses text with time.ParseDuration and stores the result
// in d. When parsing fails d is left untouched and the error is returned.
func (d *Duration) UnmarshalText(text []byte) error {
	parsed, err := xtime.ParseDuration(string(text))
	if err != nil {
		return err
	}
	*d = Duration(parsed)
	return nil
}

// Shrink reconciles d with the deadline already carried by c.
//
// If c has a deadline that expires sooner than d, the remaining time is
// returned together with c itself and a no-op cancel function. Otherwise a
// child context with timeout d is created and returned with d and its
// cancel function.
func (d Duration) Shrink(c context.Context) (Duration, context.Context, context.CancelFunc) {
	budget := xtime.Duration(d)
	deadline, bounded := c.Deadline()
	if bounded {
		remaining := xtime.Until(deadline)
		if remaining < budget {
			// The parent already expires earlier; hand its timeout on.
			return Duration(remaining), c, func() {}
		}
	}
	derived, cancel := context.WithTimeout(c, budget)
	return d, derived, cancel
}
60 changes: 60 additions & 0 deletions internal/time/time_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package time

import (
"context"
"testing"
"time"
)

// TestShrink checks Shrink on a context without a deadline: the returned
// timeout must equal the duration (1s) and the derived context must carry a
// deadline between 500ms and 1s away.
func TestShrink(t *testing.T) {
	var d Duration
	if err := d.UnmarshalText([]byte("1s")); err != nil {
		t.Fatalf("TestShrink: d.UnmarshalText failed!err:=%v", err)
	}
	to, ctx, cancel := d.Shrink(context.Background())
	defer cancel()
	if got := time.Duration(to); got != time.Second {
		t.Fatalf("new timeout must be equal 1 second")
	}
	deadline, ok := ctx.Deadline()
	switch {
	case !ok, time.Until(deadline) > time.Second, time.Until(deadline) < time.Millisecond*500:
		t.Fatalf("ctx deadline must be less than 1s and greater than 500ms")
	}
}

// TestShrinkWithTimeout checks Shrink when the parent context's timeout (2s)
// is longer than the duration (1s): the returned timeout must stay at 1s and
// the returned context must carry a deadline between 500ms and 1s away.
func TestShrinkWithTimeout(t *testing.T) {
	var d Duration
	if err := d.UnmarshalText([]byte("1s")); err != nil {
		t.Fatalf("TestShrinkWithTimeout: d.UnmarshalText failed!err:=%v", err)
	}
	// Separate names keep the parent's and the derived context's cancel
	// functions apart instead of reusing one variable for both.
	parent, cancelParent := context.WithTimeout(context.Background(), time.Second*2)
	defer cancelParent()
	to, ctx, cancel := d.Shrink(parent)
	defer cancel()
	if time.Duration(to) != time.Second {
		t.Fatal("new timeout must be equal 1 second")
	}
	deadline, ok := ctx.Deadline()
	if !ok {
		t.Fatal("ctx deadline must be less than 1s and greater than 500ms")
	}
	// Sample the remaining time once so both bounds see the same instant.
	if left := time.Until(deadline); left > time.Second || left < time.Millisecond*500 {
		t.Fatalf("ctx deadline must be less than 1s and greater than 500ms, got %v", left)
	}
}

// TestShrinkWithDeadline checks Shrink when the parent context's timeout
// (500ms) is shorter than the duration (1s): the returned timeout must be
// below 500ms and the returned context must carry a deadline between 200ms
// and 500ms away.
func TestShrinkWithDeadline(t *testing.T) {
	var d Duration
	if err := d.UnmarshalText([]byte("1s")); err != nil {
		t.Fatalf("TestShrinkWithDeadline: d.UnmarshalText failed!err:=%v", err)
	}
	// Separate names keep the parent's and the derived context's cancel
	// functions apart instead of reusing one variable for both.
	parent, cancelParent := context.WithTimeout(context.Background(), time.Millisecond*500)
	defer cancelParent()
	to, ctx, cancel := d.Shrink(parent)
	defer cancel()
	if time.Duration(to) >= time.Millisecond*500 {
		t.Fatal("new timeout must be less than 500 ms")
	}
	deadline, ok := ctx.Deadline()
	if !ok {
		t.Fatal("ctx deadline must be less than 500ms and greater than 200ms")
	}
	// Sample the remaining time once so both bounds see the same instant.
	if left := time.Until(deadline); left > time.Millisecond*500 || left < time.Millisecond*200 {
		t.Fatalf("ctx deadline must be less than 500ms and greater than 200ms, got %v", left)
	}
}
3 changes: 3 additions & 0 deletions pipeline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# pkg/sync/pipeline

提供内存批量聚合工具
219 changes: 219 additions & 0 deletions pipeline/pipeline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package pipeline

import (
"context"
"errors"
"github.com/go-kratos/kratos/v2/metadata"
"github.com/go-kratos/kratos/v2/metrics"
"strconv"
"sync"
"time"

xtime "github.com/go-kratos/exp/internal/time"
)

// ErrFull is returned by Add when the selected channel cannot accept the
// message without blocking; the value is not enqueued in that case.
var ErrFull = errors.New("channel full")

// mirrorKey is the server-context metadata key used to mark mirror traffic.
// add routes a message to the mirror channels when this key has a non-empty
// value, and mergeProc sets it to "1" on the context it passes to Do for
// mirror batches.
const mirrorKey = "mirror"

// message is the unit sent over the pipeline channels: one value together
// with the key it is grouped under when a batch is built. A nil *message is
// used by Close as the stop signal for a worker.
type message[T any] struct {
// key is the grouping key the value is appended under in the batch map.
key string
// value is the payload supplied by the caller of Add/SyncAdd.
value T
}

// Pipeline collects values added by key into per-worker in-memory batches
// and hands each batch to Do. Do and Split must be set before Start is
// called; Start panics otherwise.
type Pipeline[T any] struct {
// Do is invoked by a worker with the batch it has accumulated, grouped by
// key. index is the worker (channel) index.
Do func(c context.Context, index int, values map[string][]T)
// Split maps a key to an integer; the result modulo Config.Worker selects
// the channel the value is sent to.
Split func(key string) int
// chans holds one buffered channel per worker for regular traffic.
chans []chan *message[T]
// mirrorChans holds one buffered channel per worker for mirror traffic.
mirrorChans []chan *message[T]
// config is the (defaulted) configuration passed to NewPipeline.
config *Config
// wait tracks the running mergeProc goroutines; Close waits on it.
wait sync.WaitGroup
// name is the first label value used when reporting metrics.
name string
// metricCount, when non-nil, is increased by the number of values
// processed in each flush.
metricCount metrics.Counter
// metricChanLen, when non-nil, is set to the channel length after each
// flush.
metricChanLen metrics.Gauge
}

// Config holds the tunables of a Pipeline. Zero or negative numeric fields
// and an empty Name are replaced with defaults by fix (called from
// NewPipeline).
// NOTE(review): the "marked this conversation as resolved" / "Show resolved"
// lines inside the body look like page-extraction residue, not Go source.
type Config struct {
Casper-Mars marked this conversation as resolved.
Show resolved Hide resolved
// MaxSize is the number of buffered values at which a worker flushes its
// batch to Do (default 1000).
MaxSize int
// Interval is the timer period after which a worker flushes whatever it
// has buffered (default 1s).
Interval xtime.Duration
Casper-Mars marked this conversation as resolved.
Show resolved Hide resolved
// Buffer is the capacity of each worker channel (default 1000).
Buffer int
// Worker is the number of channels, and therefore worker goroutines, per
// channel group (default 10).
Worker int
// Name is the first label value used for metrics (default "anonymous").
Name string
// MetricCount is an optional counter that receives the number of values
// processed per flush.
MetricCount metrics.Counter
// MetricChanLen is an optional gauge that receives the channel length
// observed after each flush.
MetricChanLen metrics.Gauge
}

// fix fills in defaults for every field that was left unset (non-positive
// numbers, empty name): MaxSize 1000, Interval 1s, Buffer 1000, Worker 10,
// Name "anonymous". It modifies c in place.
func (c *Config) fix() {
	// orDefault replaces a non-positive int field with its fallback.
	orDefault := func(field *int, fallback int) {
		if *field <= 0 {
			*field = fallback
		}
	}
	orDefault(&c.MaxSize, 1000)
	orDefault(&c.Buffer, 1000)
	orDefault(&c.Worker, 10)
	if c.Interval <= 0 {
		c.Interval = xtime.Duration(time.Second)
	}
	if c.Name == "" {
		c.Name = "anonymous"
	}
}

// NewPipeline builds a Pipeline from config. A nil config is treated as an
// empty one; missing settings are defaulted in place via Config.fix. The
// returned pipeline owns Worker regular channels and Worker mirror channels,
// each with capacity Buffer. Do and Split still have to be assigned by the
// caller before Start.
func NewPipeline[T any](config *Config) (res *Pipeline[T]) {
	if config == nil {
		config = new(Config)
	}
	config.fix()

	workers := config.Worker
	res = &Pipeline[T]{
		chans:         make([]chan *message[T], workers),
		mirrorChans:   make([]chan *message[T], workers),
		config:        config,
		name:          config.Name,
		metricCount:   config.MetricCount,
		metricChanLen: config.MetricChanLen,
	}
	for idx := range res.chans {
		res.chans[idx] = make(chan *message[T], config.Buffer)
		res.mirrorChans[idx] = make(chan *message[T], config.Buffer)
	}
	return res
}

// Start launches one mergeProc goroutine for every regular channel and one
// for every mirror channel, registering all of them on the wait group.
// It panics if Do or Split has not been set.
func (p *Pipeline[T]) Start() {
	switch {
	case p.Do == nil:
		panic("pipeline: do func is nil")
	case p.Split == nil:
		panic("pipeline: split func is nil")
	}
	p.wait.Add(len(p.chans) + len(p.mirrorChans))
	for idx := range p.chans {
		go p.mergeProc(false, idx, p.chans[idx])
	}
	for idx := range p.mirrorChans {
		go p.mergeProc(true, idx, p.mirrorChans[idx])
	}
}

// SyncAdd enqueues value under key on the channel chosen by Split, blocking
// until there is room. If c is done first, c.Err() is returned and the value
// is not enqueued.
func (p *Pipeline[T]) SyncAdd(c context.Context, key string, value T) (err error) {
	target, msg := p.add(c, key, value)
	select {
	case target <- msg:
		return nil
	case <-c.Done():
		return c.Err()
	}
}

// Add enqueues value under key on the channel chosen by Split without
// blocking. It returns ErrFull when the channel cannot take the message
// right now.
func (p *Pipeline[T]) Add(c context.Context, key string, value T) (err error) {
	target, msg := p.add(c, key, value)
	select {
	case target <- msg:
		return nil
	default:
		return ErrFull
	}
}

// add selects the destination channel for key and wraps value in a message.
//
// The shard is Split(key) modulo Config.Worker. The mirror channel group is
// used when c carries server metadata with a non-empty mirrorKey value;
// otherwise the regular group is used. Nothing is sent here; the caller
// performs the channel send.
func (p *Pipeline[T]) add(c context.Context, key string, value T) (ch chan *message[T], m *message[T]) {
	shard := p.Split(key) % p.config.Worker
	if shard < 0 {
		// Go's % keeps the sign of the dividend, so a negative Split result
		// would otherwise index out of range and panic.
		shard += p.config.Worker
	}
	serverContext, b := metadata.FromServerContext(c)
	if b && serverContext.Get(mirrorKey) != "" {
		ch = p.mirrorChans[shard]
	} else {
		ch = p.chans[shard]
	}
	m = &message[T]{key: key, value: value}
	return
}

// Close sends a nil stop message on every regular channel and then on every
// mirror channel, and waits for all worker goroutines to finish. The sends
// block while a channel is full. The returned error is always nil.
func (p *Pipeline[T]) Close() (err error) {
	for _, group := range [][]chan *message[T]{p.chans, p.mirrorChans} {
		for _, ch := range group {
			ch <- nil
		}
	}
	p.wait.Wait()
	return nil
}

// mergeProc is the worker loop for one channel.
//
// It accumulates incoming messages into a map keyed by message key and
// flushes the batch to Do when one of the following happens:
//   - the number of buffered values reaches Config.MaxSize,
//   - the flush timer fires,
//   - a nil message (stop signal sent by Close) is received.
//
// After every flush the optional metrics are reported, labelled with the
// pipeline name (prefixed with "mirror_" for mirror workers) and the worker
// index. The loop returns after handling the stop signal.
//
// mirror tells whether this worker serves the mirror channel group; in that
// case the context passed to Do carries mirrorKey="1" in its server metadata.
func (p *Pipeline[T]) mergeProc(mirror bool, index int, ch <-chan *message[T]) {
	defer p.wait.Done()
	var (
		m        *message[T]
		vals     = make(map[string][]T, p.config.MaxSize)
		closed   bool
		count    int
		timeout  bool
		interval = p.config.Interval
		name     = p.name
	)
	// The metric label is fixed for the lifetime of the worker, so a mirror
	// worker reports under "mirror_<name>" even when a flush finds an empty
	// batch.
	if mirror {
		name = "mirror_" + name
	}
	// Stagger the first flush of workers with index > 0 so that they do not
	// all fire at the same moment; later flushes use the full interval.
	if index > 0 {
		interval = xtime.Duration(int64(index) * (int64(p.config.Interval) / int64(p.config.Worker)))
	}
	timer := time.NewTimer(time.Duration(interval))
	defer timer.Stop()
	for {
		select {
		case m = <-ch:
			if m == nil {
				// Stop signal: leave the select, flush what is left, return.
				closed = true
				break
			}
			count++
			vals[m.key] = append(vals[m.key], m.value)
			if count >= p.config.MaxSize {
				// Batch is full: leave the select and flush.
				break
			}
			continue
		case <-timer.C:
			timeout = true
		}
		process := count
		if len(vals) > 0 {
			ctx := context.Background()
			if mirror {
				ctx = metadata.NewServerContext(ctx, metadata.Metadata{mirrorKey: "1"})
			}
			p.Do(ctx, index, vals)
			vals = make(map[string][]T, p.config.MaxSize)
			count = 0
		}
		if p.metricChanLen != nil {
			p.metricChanLen.With(name, strconv.Itoa(index)).Set(float64(len(ch)))
		}
		if p.metricCount != nil {
			p.metricCount.With(name, strconv.Itoa(index)).Add(float64(process))
		}
		if closed {
			return
		}
		// If the timer expired but its value was not consumed by the select
		// above, drain it so Reset does not leave a stale tick behind.
		if !timer.Stop() && !timeout {
			<-timer.C
		}
		// Clear the flag on every round; otherwise it stays true after the
		// first timer fire and the drain above is skipped from then on.
		timeout = false
		timer.Reset(time.Duration(p.config.Interval))
	}
}
Loading