diff --git a/tool/instrument/inst_file.go b/tool/instrument/inst_file.go index cc793737..e7875857 100644 --- a/tool/instrument/inst_file.go +++ b/tool/instrument/inst_file.go @@ -2,12 +2,12 @@ package instrument import ( "fmt" - "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/shared" "log" "path/filepath" "strings" "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/resource" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/shared" "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/util" ) diff --git a/tool/instrument/inst_func.go b/tool/instrument/inst_func.go index bdb59d1f..34d2512c 100644 --- a/tool/instrument/inst_func.go +++ b/tool/instrument/inst_func.go @@ -2,20 +2,16 @@ package instrument import ( "fmt" - "go/token" "log" "path/filepath" "regexp" "sort" "strings" - "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/util" - + "github.com/alibaba/opentelemetry-go-auto-instrumentation/api" "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/resource" - "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/shared" - - "github.com/alibaba/opentelemetry-go-auto-instrumentation/api" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/util" "github.com/dave/dst" ) @@ -81,7 +77,7 @@ func findJumpPoint(jumpIf *dst.IfStmt) *dst.BlockStmt { // Insert trampoline jump within the else block elseBlock := jumpIf.Else.(*dst.BlockStmt) if len(elseBlock.List) > 1 { - // One trampoline jump already exists, recursively find the last one + // One trampoline jump already exists, recursively find last one ifStmt, ok := elseBlock.List[len(elseBlock.List)-1].(*dst.IfStmt) util.Assert(ok, "unexpected statement in trampoline-jump-if") return findJumpPoint(ifStmt) @@ -153,28 +149,24 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl } return clone }()) - trampolineJump := &dst.IfStmt{ - Init: &dst.AssignStmt{ - Lhs: shared.Exprs( - shared.Ident(TrampolineCallContextName+varSuffix), - shared.Ident(TrampolineSkipName+varSuffix), - ), - Tok: token.DEFINE, - Rhs: shared.Exprs(onEnterCall), - }, - Cond: shared.Ident(TrampolineSkipName + varSuffix), - Body: shared.BlockStmts( - shared.ExprStmt(onExitCall), - shared.ReturnStmt(retVals), + tjumpInit := shared.AssignStmts( + shared.Exprs( + shared.Ident(TrampolineCallContextName+varSuffix), + shared.Ident(TrampolineSkipName+varSuffix), ), - Else: shared.Block( - shared.DeferStmt(onExitCall), - ), - } + shared.Exprs(onEnterCall), + ) + tjumpCond := shared.Ident(TrampolineSkipName + varSuffix) + tjumpBody := shared.BlockStmts( + shared.ExprStmt(onExitCall), + shared.ReturnStmt(retVals), + ) + tjumpElse := shared.Block(shared.DeferStmt(onExitCall)) + tjump := shared.IfStmt(tjumpInit, tjumpCond, tjumpBody, tjumpElse) // Add this trampoline-jump-if as optimization candidates rp.trampolineJumps = append(rp.trampolineJumps, &TJump{ target: funcDecl, - ifStmt: trampolineJump, + ifStmt: tjump, rule: t, }) @@ -189,22 +181,22 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl // } /* NO_NEWWLINE_PLACEHOLDER */ // NEW_LINE { // then block - callExpr := trampolineJump.Body.List[0] + callExpr := tjump.Body.List[0] callExpr.Decorations().Start.Append(TrampolineNoNewlinePlaceholder) callExpr.Decorations().End.Append(TrampolineSemicolonPlaceholder) - retStmt := trampolineJump.Body.List[1] + retStmt := tjump.Body.List[1] retStmt.Decorations().End.Append(TrampolineNoNewlinePlaceholder) } { // else block - deferStmt := trampolineJump.Else.(*dst.BlockStmt).List[0] + deferStmt := tjump.Else.(*dst.BlockStmt).List[0] deferStmt.Decorations().Start.Append(TrampolineNoNewlinePlaceholder) deferStmt.Decorations().End.Append(TrampolineSemicolonPlaceholder) - trampolineJump.Else.Decorations().End.Append(TrampolineNoNewlinePlaceholder) - trampolineJump.Decs.If.Append(TrampolineJumpIfDesc) // Anchor label + tjump.Else.Decorations().End.Append(TrampolineNoNewlinePlaceholder) + tjump.Decs.If.Append(TrampolineJumpIfDesc) // Anchor label } - // Find if there is already a trampoline-jump-if, if so, insert new trampoline - // jump within the else block, otherwise prepend to block body + // Find if there is already a trampoline-jump-if, insert new tjump if so, + // otherwise prepend to block body found := false if len(funcDecl.Body.List) > 0 { firstStmt := funcDecl.Body.List[0] @@ -212,7 +204,7 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl point := findJumpPoint(ifStmt) if point != nil { point.List = append(point.List, shared.EmptyStmt()) - point.List = append(point.List, trampolineJump) + point.List = append(point.List, tjump) found = true } } @@ -221,8 +213,8 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl // Outmost trampoline-jump-if may follow by user code right after else // block, replacing the trailing newline mandatorily breaks the code, // we need to insert extra new line to make replacement possible - trampolineJump.Decorations().After = dst.EmptyLine - funcDecl.Body.List = append([]dst.Stmt{trampolineJump}, funcDecl.Body.List...) + tjump.Decorations().After = dst.EmptyLine + funcDecl.Body.List = append([]dst.Stmt{tjump}, funcDecl.Body.List...) } // Generate corresponding trampoline code diff --git a/tool/instrument/trampoline.go b/tool/instrument/trampoline.go index f80d9715..24ee677d 100644 --- a/tool/instrument/trampoline.go +++ b/tool/instrument/trampoline.go @@ -25,7 +25,6 @@ import ( // so-called "Trampoline Jump" snippet is inserted at start of raw func, it is // guaranteed to be generated within one line to avoid confusing debugging, as // its name suggests, it jumps to the trampoline function from raw function. - const ( TrampolineSetParamName = "SetParam" TrampolineGetParamName = "GetParam" @@ -405,10 +404,10 @@ func setValue(field string, idx int, typ dst.Expr) *dst.CaseClause { if shared.IsInterfaceType(typ) { assign = shared.AssignStmt(ie, val) } - caseClause := &dst.CaseClause{ - List: shared.Exprs(shared.IntLit(idx)), - Body: shared.Stmts(assign), - } + caseClause := shared.SwitchCase( + shared.Exprs(shared.IntLit(idx)), + shared.Stmts(assign), + ) return caseClause } @@ -424,10 +423,10 @@ func getValue(field string, idx int, typ dst.Expr) *dst.CaseClause { if shared.IsInterfaceType(typ) { ret = shared.ReturnStmt(shared.Exprs(ie)) } - caseClause := &dst.CaseClause{ - List: shared.Exprs(shared.IntLit(idx)), - Body: shared.Stmts(ret), - } + caseClause := shared.SwitchCase( + shared.Exprs(shared.IntLit(idx)), + shared.Stmts(ret), + ) return caseClause } diff --git a/tool/shared/ast.go b/tool/shared/ast.go index 9b77dde4..8a866b81 100644 --- a/tool/shared/ast.go +++ b/tool/shared/ast.go @@ -136,12 +136,21 @@ func ArrayType(elem dst.Expr) *dst.ArrayType { return &dst.ArrayType{Elt: elem} } +func IfStmt(init dst.Stmt, cond dst.Expr, body, elseBody *dst.BlockStmt) *dst.IfStmt { + return &dst.IfStmt{ + Init: dst.Clone(init).(dst.Stmt), + Cond: dst.Clone(cond).(dst.Expr), + Body: dst.Clone(body).(*dst.BlockStmt), + Else: dst.Clone(elseBody).(*dst.BlockStmt), + } +} + func EmptyStmt() *dst.EmptyStmt { return &dst.EmptyStmt{} } func ExprStmt(expr dst.Expr) *dst.ExprStmt { - return &dst.ExprStmt{X: expr} + return &dst.ExprStmt{X: dst.Clone(expr).(dst.Expr)} } func DeferStmt(call *dst.CallExpr) *dst.DeferStmt { @@ -160,6 +169,21 @@ func AssignStmt(lhs, rhs dst.Expr) *dst.AssignStmt { } } +func AssignStmts(lhs, rhs []dst.Expr) *dst.AssignStmt { + return &dst.AssignStmt{ + Lhs: lhs, + Tok: token.ASSIGN, + Rhs: rhs, + } +} + +func SwitchCase(list []dst.Expr, stmts []dst.Stmt) *dst.CaseClause { + return &dst.CaseClause{ + List: list, + Body: stmts, + } +} + func AddStructField(decl dst.Decl, name string, typ string) { gen, ok := decl.(*dst.GenDecl) if !ok {