Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Wasm Plugin Framework #586

Merged
merged 14 commits into from
Nov 26, 2024
4 changes: 0 additions & 4 deletions go/vt/vttablet/tabletserver/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ type ActionInterface interface {
SetParams(args ActionArgs) error

GetRule() *rules.Rule

GetSkipFlag() bool

SetSkipFlag(skip bool)
}

type ActionArgs interface {
Expand Down
146 changes: 3 additions & 143 deletions go/vt/vttablet/tabletserver/action_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tabletserver
import (
"context"
"fmt"
"regexp"
"time"

"github.com/BurntSushi/toml"
Expand All @@ -20,8 +19,6 @@ type ContinueAction struct {

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *ContinueAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -55,16 +52,6 @@ type FailAction struct {

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *ContinueAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *ContinueAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

func (p *FailAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -93,21 +80,11 @@ func (p *FailAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *FailAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *FailAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type FailRetryAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *FailRetryAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -136,21 +113,11 @@ func (p *FailRetryAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *FailRetryAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *FailRetryAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type BufferAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *BufferAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -198,23 +165,13 @@ func (p *BufferAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *BufferAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *BufferAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type ConcurrencyControlAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *ConcurrencyControlActionArgs

skipFlag bool
}

type ConcurrencyControlActionArgs struct {
Expand Down Expand Up @@ -292,23 +249,13 @@ func (p *ConcurrencyControlAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *ConcurrencyControlAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *ConcurrencyControlAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type WasmPluginAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *WasmPluginActionArgs

skipFlag bool
}

type WasmPluginActionArgs struct {
Expand Down Expand Up @@ -337,20 +284,21 @@ func (args *WasmPluginActionArgs) Parse(stringParams string) (ActionArgs, error)
func (p *WasmPluginAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
controller := qre.tsv.qe.wasmPluginController

ok, module := controller.VM.GetWasmModule(p.Args.WasmBinaryName)
ok, module := controller.VM.GetWasmModule(p.GetRule().Name)
if !ok {
wasmBytes, err := controller.GetWasmBytesByBinaryName(qre.ctx, p.Args.WasmBinaryName)
if err != nil {
return &ActionExecutionResponse{Err: err}
}
module, err = controller.VM.InitWasmModule(p.Args.WasmBinaryName, wasmBytes)
module, err = controller.VM.InitWasmModule(p.GetRule().Name, wasmBytes)
if err != nil {
return &ActionExecutionResponse{Err: err}
}
}

instance, err := module.NewInstance(qre)
if err != nil {
//todo wasm: if instance is nil, we will not be able to get the it in AfterExecution. We need to handle this case
return &ActionExecutionResponse{Err: err}
}

Expand Down Expand Up @@ -399,91 +347,3 @@ func (p *WasmPluginAction) SetParams(args ActionArgs) error {
func (p *WasmPluginAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *WasmPluginAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *WasmPluginAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type SkipFilterAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *SkipFilterActionArgs

skipFlag bool
}

type SkipFilterActionArgs struct {
AllowRegexString string `toml:"skip_filter_regex"`
AllowRegex *regexp.Regexp
}

func (args *SkipFilterActionArgs) Parse(stringParams string) (ActionArgs, error) {
s := &SkipFilterActionArgs{}
if stringParams == "" {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "stringParams is empty when parsing skip filter action args")
}

userInputTOML := ConvertUserInputToTOML(stringParams)
err := toml.Unmarshal([]byte(userInputTOML), s)
if err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error when parsing skip filter action args: %v", err)
}
s.AllowRegex, err = regexp.Compile(fmt.Sprintf("^%s$", s.AllowRegexString))
if err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error when compiling skip filter action args: %v", err)
}

return s, nil
}

func (p *SkipFilterAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
findSelf := false
for _, a := range qre.matchedActionList {
if a.GetRule().Name == p.GetRule().Name {
findSelf = true
continue
}
if findSelf {
if p.Args.AllowRegex.MatchString(a.GetRule().Name) {
a.SetSkipFlag(true)
}
}
}
return &ActionExecutionResponse{Err: nil}
}

func (p *SkipFilterAction) AfterExecution(qre *QueryExecutor, reply *sqltypes.Result, err error) *ActionExecutionResponse {
return &ActionExecutionResponse{Reply: reply, Err: err}
}

func (p *SkipFilterAction) ParseParams(argsStr string) (ActionArgs, error) {
return p.Args.Parse(argsStr)
}

func (p *SkipFilterAction) SetParams(args ActionArgs) error {
skipFilterArgs, ok := args.(*SkipFilterActionArgs)
if !ok {
return fmt.Errorf("args :%v is not a valid SkipFilterActionArgs)", args)
}
p.Args = skipFilterArgs
return nil
}

func (p *SkipFilterAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *SkipFilterAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *SkipFilterAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}
2 changes: 0 additions & 2 deletions go/vt/vttablet/tabletserver/action_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ func CreateActionInstance(action rules.Action, rule *rules.Rule) (ActionInterfac
actInst, err = &ConcurrencyControlAction{Rule: rule, Action: action}, nil
case rules.QRWasmPlugin:
actInst, err = &WasmPluginAction{Rule: rule, Action: action}, nil
case rules.QRSkipFilter:
actInst, err = &SkipFilterAction{Rule: rule, Action: action}, nil
default:
log.Errorf("unknown action: %v", action)
actInst, err = nil, fmt.Errorf("unknown action: %v", action)
Expand Down
26 changes: 26 additions & 0 deletions go/vt/vttablet/tabletserver/action_stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package tabletserver

import (
"time"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/vt/servenv"
)

type ActionStats struct {
FilterBeforeExecutionTiming *servenv.TimingsWrapper
FilterAfterExecutionTiming *servenv.TimingsWrapper
FilterErrorCounts *stats.CountersWithSingleLabel
FilterQPSRates *stats.Rates
FilterWasmMemorySize *stats.GaugesWithMultiLabels
}

func NewActionStats(exporter *servenv.Exporter) *ActionStats {
stats := &ActionStats{
FilterBeforeExecutionTiming: exporter.NewTimings("FilterBeforeExecution", "Filter before execution timings", "Name"),
FilterAfterExecutionTiming: exporter.NewTimings("FilterAfterExecution", "Filter before execution timings", "Name"),
FilterErrorCounts: exporter.NewCountersWithSingleLabel("FilterErrorCounts", "filter error counts", "Name"),
FilterWasmMemorySize: exporter.NewGaugesWithMultiLabels("FilterWasmMemorySize", "Wasm memory size", []string{"Name", "BeforeOrAfter"}),
}
stats.FilterQPSRates = exporter.NewRates("FilterQps", stats.FilterBeforeExecutionTiming, 15*60/5, 5*time.Second)
return stats
}
4 changes: 4 additions & 0 deletions go/vt/vttablet/tabletserver/query_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ type QueryEngine struct {

// stats
queryCounts, queryTimes, queryErrorCounts, queryRowsAffected, queryRowsReturned *stats.CountersWithMultiLabels
// actionStats for filters
actionStats *ActionStats

// Loggers
accessCheckerLogger *logutil.ThrottledLogger
Expand Down Expand Up @@ -281,6 +283,8 @@ func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine {
qe.queryRowsReturned = env.Exporter().NewCountersWithMultiLabels("QueryRowsReturned", "query rows returned", []string{"Table", "Plan"})
qe.queryErrorCounts = env.Exporter().NewCountersWithMultiLabels("QueryErrorCounts", "query error counts", []string{"Table", "Plan"})

qe.actionStats = NewActionStats(env.Exporter())

env.Exporter().HandleFunc("/debug/ccl", qe.concurrencyController.ServeHTTP)
env.Exporter().HandleFunc("/debug/hotrows", qe.txSerializer.ServeHTTP)
env.Exporter().HandleFunc("/debug/tablet_plans", qe.handleHTTPQueryPlans)
Expand Down
22 changes: 16 additions & 6 deletions go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,16 @@ func (qre *QueryExecutor) runActionListBeforeExecution() (*sqltypes.Result, erro
return nil, nil
}
for _, a := range qre.matchedActionList {
if !a.GetSkipFlag() {
resp := a.BeforeExecution(qre)
qre.calledActionList = append(qre.calledActionList, a)
if resp.Reply != nil || resp.Err != nil {
return resp.Reply, resp.Err
}
startTime := time.Now()
// execute the filter action
resp := a.BeforeExecution(qre)
qre.tsv.qe.actionStats.FilterBeforeExecutionTiming.Add(a.GetRule().Name, time.Since(startTime))
qre.calledActionList = append(qre.calledActionList, a)
if resp.Err != nil {
qre.tsv.qe.actionStats.FilterErrorCounts.Add(a.GetRule().Name, 1)
}
if resp.Reply != nil || resp.Err != nil {
return resp.Reply, resp.Err
}
}
return nil, nil
Expand All @@ -631,7 +635,13 @@ func (qre *QueryExecutor) runActionListAfterExecution(reply *sqltypes.Result, er

for i := len(qre.calledActionList) - 1; i >= 0; i-- {
a := qre.calledActionList[i]
startTime := time.Now()
// execute the filter action
resp := a.AfterExecution(qre, newReply, newErr)
qre.tsv.qe.actionStats.FilterAfterExecutionTiming.Add(a.GetRule().Name, time.Since(startTime))
if resp.Err != nil {
qre.tsv.qe.actionStats.FilterErrorCounts.Add(a.GetRule().Name, 1)
}
newReply, newErr = resp.Reply, resp.Err
}
return newReply, newErr
Expand Down
15 changes: 12 additions & 3 deletions go/vt/vttablet/tabletserver/query_executor_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func TestQueryExecutor_runActionListBeforeExecution(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
_, err := qre.runActionListBeforeExecution()
tt.wantErr(t, err, "runActionListBeforeExecution()")
Expand Down Expand Up @@ -129,7 +132,10 @@ func TestQueryExecutor_runActionListAfterExecution(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
qr := &sqltypes.Result{}
var err error
Expand Down Expand Up @@ -162,7 +168,10 @@ func TestQueryExecutor_actions_can_be_skipped(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
qr, err := qre.runActionListBeforeExecution()
tt.wantErr(t, err, "runActionListBeforeExecution()")
Expand Down
Loading
Loading