Skip to content

Commit

Permalink
Merge pull request #20 from picatz/callgraph-anon-funcs
Browse files Browse the repository at this point in the history
Ensure that anonymous functions are constructed with `taint` command
  • Loading branch information
picatz authored Jan 2, 2024
2 parents 934e80a + 036aed2 commit e38f7ac
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 13 deletions.
20 changes: 10 additions & 10 deletions callgraph/callgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ func (e Edge) String() string {
return fmt.Sprintf("%s → %s", e.Caller, e.Callee)
}

// Description returns a human-readable description of the edge.
func (e Edge) Description() string {
var prefix string
switch e.Site.(type) {
Expand Down Expand Up @@ -447,14 +448,15 @@ func (p Path) String() string {

type Paths []Path

func PathSearch(start *Node, isEnd func(*Node) bool) Path {
func PathSearch(start *Node, isMatch func(*Node) bool) Path {
stack := make(Path, 0, 32)
seen := make(map[*Node]bool)
var search func(n *Node) Path
search = func(n *Node) Path {
if !seen[n] {
// debug("searching: %v\n", n)
seen[n] = true
if isEnd(n) {
if isMatch(n) {
return stack
}
for _, e := range n.Out {
Expand All @@ -470,33 +472,31 @@ func PathSearch(start *Node, isEnd func(*Node) bool) Path {
return search(start)
}

func PathsSearch(start *Node, isEnd func(*Node) bool) Paths {
func PathsSearch(start *Node, isMatch func(*Node) bool) Paths {
paths := Paths{}

stack := make(Path, 0, 32)
seen := make(map[*Node]bool)
var search func(n *Node)
search = func(n *Node) {
// debug("searching: %v\n", n)
if !seen[n] {
seen[n] = true
if isEnd(n) {
if isMatch(n) {
paths = append(paths, stack)

stack = make(Path, 0, 32)
seen = make(map[*Node]bool)
return
}
for _, e := range n.Out {
if e.Caller.Func.Name() != "main" {
stack = append(stack, e) // push
}
// debug("\tout: %v\n", e)
stack = append(stack, e) // push
search(e.Callee)
if len(stack) == 0 {
continue
}
if e.Caller.Func.Name() != "main" {
stack = stack[:len(stack)-1] // pop
}
stack = stack[:len(stack)-1] // pop
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions cmd/taint/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,19 @@ var builtinCommandLoad = &command{
continue
}

pngFn := pkg.Func(fn.Object().Name())
if pngFn == nil {
pkgFn := pkg.Func(fn.Object().Name())
if pkgFn == nil {
continue
}

srcFns = append(srcFns, pngFn)
var addAnons func(f *ssa.Function)
addAnons = func(f *ssa.Function) {
srcFns = append(srcFns, f)
for _, anon := range f.AnonFuncs {
addAnons(anon)
}
}
addAnons(pkgFn)
}
}

Expand Down
124 changes: 124 additions & 0 deletions cmd/taint/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package main_test

import (
"context"
"go/ast"
"go/parser"
"go/token"
"os"
"testing"

"github.com/picatz/taint/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)

func TestLoadAndSearch(t *testing.T) {
loadMode :=
packages.NeedName |
packages.NeedDeps |
packages.NeedFiles |
packages.NeedCompiledGoFiles |
packages.NeedModule |
packages.NeedTypes |
packages.NeedImports |
packages.NeedSyntax |
packages.NeedTypesInfo
// packages.NeedTypesSizes |
// packages.NeedExportFile |
// packages.NeedEmbedPatterns

// parseMode := parser.ParseComments
parseMode := parser.SkipObjectResolution

// patterns := []string{dir}
patterns := []string{"./..."}
// patterns := []string{"all"}

pkgs, err := packages.Load(&packages.Config{
Mode: loadMode,
Context: context.Background(),
Env: os.Environ(),
Dir: "./example",
Tests: false,
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
return parser.ParseFile(fset, filename, src, parseMode)
},
}, patterns...)
if err != nil {
t.Fatal(err)
}

ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug

// Analyze the package.
ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode)

ssaProg.Build()

for _, pkg := range ssaPkgs {
pkg.Build()
}

mainPkgs := ssautil.MainPackages(ssaPkgs)

mainFn := mainPkgs[0].Members["main"].(*ssa.Function)

var srcFns []*ssa.Function

for _, pkg := range ssaPkgs {
for _, fn := range pkg.Members {
if fn.Object() == nil {
continue
}

if fn.Object().Name() == "_" {
continue
}

pkgFn := pkg.Func(fn.Object().Name())
if pkgFn == nil {
continue
}

var addAnons func(f *ssa.Function)
addAnons = func(f *ssa.Function) {
srcFns = append(srcFns, f)
for _, anon := range f.AnonFuncs {
addAnons(anon)
}
}
addAnons(pkgFn)
}
}

if mainFn == nil {
t.Fatal("main function not found")
}

cg, err := callgraph.New(mainFn, srcFns...)
if err != nil {
t.Fatal(err)
}

t.Log(cg)

// path := callgraph.PathSearchCallTo(cg.Root, "(*database/sql.DB).Query")

// if path == nil {
// t.Fatal("no path found")
// }

// t.Log(path)

paths := callgraph.PathsSearchCallTo(cg.Root, "(*database/sql.DB).Query")

if len(paths) == 0 {
t.Fatal("no paths found")
}

for _, path := range paths {
t.Log(path)
}
}

0 comments on commit e38f7ac

Please sign in to comment.