diff --git a/builtin_functions.go b/builtin_functions.go index 556f2d08..5eb996d6 100644 --- a/builtin_functions.go +++ b/builtin_functions.go @@ -142,21 +142,27 @@ func funcLoadBase(ctx *Context, this *VMValue, params []*VMValue, isRaw bool) *V } // computed 回调 - if ctx.Config.HookFuncValueLoadOverwriteBeforeComputed != nil { - val = ctx.Config.HookFuncValueLoadOverwriteBeforeComputed(ctx, name, val) - } - - if !isRaw && val.TypeId == VMTypeComputedValue { - val = val.ComputedExecute(ctx, nil) - if ctx.Error != nil { - return nil + doCompute := func(val *VMValue) *VMValue { + if !isRaw { + if val.TypeId == VMTypeComputedValue { + val = val.ComputedExecute(ctx, nil) + if ctx.Error != nil { + return nil + } + } } + return val } if ctx.Config.HookFuncValueLoadOverwrite != nil { - val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, &BufferSpan{}) + val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, doCompute, &BufferSpan{}) + } else { + val = doCompute(val) } + if ctx.Error != nil { + return nil + } return val } diff --git a/rollvm.go b/rollvm.go index 7ae7d245..d06aef06 100644 --- a/rollvm.go +++ b/rollvm.go @@ -731,38 +731,46 @@ func (ctx *Context) evaluate() { return } - // computed 回调 - if ctx.Config.HookFuncValueLoadOverwriteBeforeComputed != nil { - val = ctx.Config.HookFuncValueLoadOverwriteBeforeComputed(ctx, name, val) - } - // 计算真实结果 isRaw := typeLoadNameRaw == code.T - if !isRaw && val.TypeId == VMTypeComputedValue { - detail := &details[len(details)-1] - val = val.ComputedExecute(ctx, detail) - if ctx.Error != nil { - return + doCompute := func(val *VMValue) *VMValue { + if !isRaw && val.TypeId == VMTypeComputedValue { + if withDetail { + detail := &details[len(details)-1] + val = val.ComputedExecute(ctx, detail) + } else { + val = val.ComputedExecute(ctx, &BufferSpan{}) + } + if ctx.Error != nil { + return nil + } } - } - // 追加计算结果到detail - if withDetail { - detail := &details[len(details)-1] - detail.Ret = val + // 追加计算结果到detail + if withDetail { + detail := &details[len(details)-1] + detail.Ret = val + } + return val } if ctx.Config.HookFuncValueLoadOverwrite != nil { if len(details) > 0 { oldRet := details[len(details)-1].Ret - val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, &details[len(details)-1]) + val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, doCompute, &details[len(details)-1]) if oldRet == details[len(details)-1].Ret { // 如果ret发生变化才修改,顺便修改detail中的结果为最终结果 details[len(details)-1].Ret = val } } else { - val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, &BufferSpan{}) + val = ctx.Config.HookFuncValueLoadOverwrite(ctx, name, val, doCompute, &BufferSpan{}) } + } else { + val = doCompute(val) + } + + if ctx.Error != nil { + return } stackPush(val) diff --git a/rollvm_callback_test.go b/rollvm_callback_test.go index b3fdcbf0..6fd27ead 100644 --- a/rollvm_callback_test.go +++ b/rollvm_callback_test.go @@ -27,7 +27,11 @@ func TestGlobalValueLoadOverwrite(t *testing.T) { func TestHookFuncValueLoadOverwrite(t *testing.T) { vm := NewVM() - vm.Config.HookFuncValueLoadOverwrite = func(ctx *Context, name string, curVal *VMValue, detail *BufferSpan) *VMValue { + vm.Config.HookFuncValueLoadOverwrite = func(ctx *Context, name string, curVal *VMValue, doCompute func(v *VMValue) *VMValue, detail *BufferSpan) *VMValue { + doCompute(curVal) + if ctx.Error != nil { + return nil + } return ni(123) } diff --git a/types.go b/types.go index c388721b..5a607b52 100644 --- a/types.go +++ b/types.go @@ -82,10 +82,8 @@ type RollConfig struct { HookFuncValueStore func(ctx *Context, name string, v *VMValue) (overwrite *VMValue, solved bool) // 如果overwrite不为nil,将结束值加载并使用overwrite值。如果为nil,将以newName为key进行加载 HookFuncValueLoad func(ctx *Context, name string) (newName string, overwrite *VMValue) - // 读取后回调(返回值将覆盖之前读到的值。如果之前未读取到值curVal将为nil,这个回调处于computed计算之前) - HookFuncValueLoadOverwriteBeforeComputed func(ctx *Context, name string, curVal *VMValue) *VMValue - // 读取后回调(返回值将覆盖之前读到的值。如果之前未读取到值curVal将为nil) - HookFuncValueLoadOverwrite func(ctx *Context, name string, curVal *VMValue, detail *BufferSpan) *VMValue + // 读取后回调(返回值将覆盖之前读到的值。如果之前未读取到值curVal将为nil),用户需要在里面调用doCompute保证结果正确 + HookFuncValueLoadOverwrite func(ctx *Context, name string, curVal *VMValue, doCompute func(curVal *VMValue) *VMValue, detail *BufferSpan) *VMValue // st回调,注意val和extra都经过clone,可以放心储存 CallbackSt func(_type string, name string, val *VMValue, extra *VMValue, op string, detail string) // st回调