From b107a349ba494523ccb648371f544e7513bf2d3f Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Fri, 5 Jan 2024 12:10:44 -0500 Subject: [PATCH 1/5] Add `callgraphutil.WriteDOT` --- callgraphutil/dot.go | 41 ++++++++ callgraphutil/dot_test.go | 196 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) create mode 100644 callgraphutil/dot.go create mode 100644 callgraphutil/dot_test.go diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go new file mode 100644 index 0000000..6c2413e --- /dev/null +++ b/callgraphutil/dot.go @@ -0,0 +1,41 @@ +package callgraphutil + +import ( + "bufio" + "fmt" + "io" + + "golang.org/x/tools/go/callgraph" +) + +// WriteDOT writes the given callgraph.Graph to the given io.Writer in the +// DOT format, which can be used to generate a visual representation of the +// call graph using Graphviz. +func WriteDOT(w io.Writer, g *callgraph.Graph) error { + b := bufio.NewWriter(w) + defer b.Flush() + + b.WriteString("digraph callgraph {\n") + b.WriteString("\tgraph [fontname=\"Helvetica\"];\n") + b.WriteString("\tnode [fontname=\"Helvetica\"];\n") + b.WriteString("\tedge [fontname=\"Helvetica\"];\n") + + edges := []*callgraph.Edge{} + + // Write nodes. + for _, n := range g.Nodes { + b.WriteString(fmt.Sprintf("\t%q [label=%q];\n", fmt.Sprintf("%d", n.ID), n.Func)) + + // Add edges + edges = append(edges, n.Out...) + } + + // Write edges. + for _, e := range edges { + b.WriteString(fmt.Sprintf("\t%q -> %q [label=%q];\n", fmt.Sprintf("%d", e.Caller.ID), fmt.Sprintf("%d", e.Callee.ID), e.Site)) + } + + b.WriteString("}\n") + + return nil +} diff --git a/callgraphutil/dot_test.go b/callgraphutil/dot_test.go new file mode 100644 index 0000000..8cea70c --- /dev/null +++ b/callgraphutil/dot_test.go @@ -0,0 +1,196 @@ +package callgraphutil_test + +import ( + "bytes" + "context" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "testing" + + "github.com/go-git/go-git/v5" + "github.com/picatz/taint/callgraphutil" + "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" +) + +func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) { + // Get the owner and repo part of the URL. + ownerAndRepo := ownerName + "/" + repoName + + // Get the directory path. + dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo) + + // Check if the directory exists. + _, err := os.Stat(dir) + if err == nil { + // If the directory exists, we'll assume it's a valid repository, + // and return the directory. Open the directory to + repo, err := git.PlainOpen(dir) + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + // Get the repository's HEAD. + head, err := repo.Head() + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + return dir, head.Hash().String(), nil + } + + // Clone the repository. + repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{ + URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo), + Depth: 1, + Tags: git.NoTags, + SingleBranch: true, + }) + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + // Get the repository's HEAD. + head, err := repo.Head() + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + return dir, head.Hash().String(), nil +} + +func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) { + loadMode := + packages.NeedName | + packages.NeedDeps | + packages.NeedFiles | + packages.NeedModule | + packages.NeedTypes | + packages.NeedImports | + packages.NeedSyntax | + packages.NeedTypesInfo + // packages.NeedTypesSizes | + // packages.NeedCompiledGoFiles | + // packages.NeedExportFile | + // packages.NeedEmbedPatterns + + // parseMode := parser.ParseComments + parseMode := parser.SkipObjectResolution + + // patterns := []string{dir} + patterns := []string{pattern} + // patterns := []string{"all"} + + pkgs, err := packages.Load(&packages.Config{ + Mode: loadMode, + Context: ctx, + Env: os.Environ(), + Dir: dir, + Tests: false, + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parseMode) + }, + }, patterns...) + if err != nil { + return nil, err + } + + return pkgs, nil + +} + +func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) { + ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug + + // Analyze the package. + ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode) + + ssaProg.Build() + + for _, pkg := range ssaPkgs { + pkg.Build() + } + + mainPkgs := ssautil.MainPackages(ssaPkgs) + + mainFn = mainPkgs[0].Members["main"].(*ssa.Function) + + for _, pkg := range ssaPkgs { + for _, fn := range pkg.Members { + if fn.Object() == nil { + continue + } + + if fn.Object().Name() == "_" { + continue + } + + pkgFn := pkg.Func(fn.Object().Name()) + if pkgFn == nil { + continue + } + + var addAnons func(f *ssa.Function) + addAnons = func(f *ssa.Function) { + srcFns = append(srcFns, f) + for _, anon := range f.AnonFuncs { + addAnons(anon) + } + } + addAnons(pkgFn) + } + } + + if mainFn == nil { + err = fmt.Errorf("failed to find main function") + return + } + + return +} + +func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) { + cg, err := callgraphutil.NewGraph(mainFn, srcFns...) + if err != nil { + return nil, fmt.Errorf("failed to create new callgraph: %w", err) + } + + return cg, nil +} + +func TestWriteDOT(t *testing.T) { + repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint") + if err != nil { + t.Fatal(err) + } + + pkgs, err := loadPackages(context.Background(), repo, "./...") + if err != nil { + t.Fatal(err) + } + + mainFn, srcFns, err := loadSSA(context.Background(), pkgs) + if err != nil { + t.Fatal(err) + } + + cg, err := loadCallGraph(context.Background(), mainFn, srcFns) + if err != nil { + t.Fatal(err) + } + + output := &bytes.Buffer{} + + err = callgraphutil.WriteDOT(output, cg) + if err != nil { + t.Fatal(err) + } + + fmt.Println(output.String()) +} From b763da4f2605fbc7829a3f2a2df045a54e86f543 Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Fri, 5 Jan 2024 16:11:46 -0500 Subject: [PATCH 2/5] Normalize box-shaped nodes that don't overlap --- callgraphutil/dot.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go index 6c2413e..b750a2b 100644 --- a/callgraphutil/dot.go +++ b/callgraphutil/dot.go @@ -16,8 +16,8 @@ func WriteDOT(w io.Writer, g *callgraph.Graph) error { defer b.Flush() b.WriteString("digraph callgraph {\n") - b.WriteString("\tgraph [fontname=\"Helvetica\"];\n") - b.WriteString("\tnode [fontname=\"Helvetica\"];\n") + b.WriteString("\tgraph [fontname=\"Helvetica\", overlap=false normalize=true];\n") + b.WriteString("\tnode [fontname=\"Helvetica\" shape=box];\n") b.WriteString("\tedge [fontname=\"Helvetica\"];\n") edges := []*callgraph.Edge{} From 8e88f9eca66fe116de24f921ffcd65cd90bb1efe Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Fri, 5 Jan 2024 16:13:07 -0500 Subject: [PATCH 3/5] Write `root` node if it exists --- callgraphutil/dot.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go index b750a2b..53d613f 100644 --- a/callgraphutil/dot.go +++ b/callgraphutil/dot.go @@ -22,6 +22,11 @@ func WriteDOT(w io.Writer, g *callgraph.Graph) error { edges := []*callgraph.Edge{} + // Check if root node exists, if so, write it. + if g.Root != nil { + b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID)) + } + // Write nodes. for _, n := range g.Nodes { b.WriteString(fmt.Sprintf("\t%q [label=%q];\n", fmt.Sprintf("%d", n.ID), n.Func)) From 37d7842cfe9a42a5cb8e76b32c0b150bf1786121 Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Fri, 5 Jan 2024 16:13:30 -0500 Subject: [PATCH 4/5] Use integer for node IDs instead of quoted string --- callgraphutil/dot.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go index 53d613f..7f9316e 100644 --- a/callgraphutil/dot.go +++ b/callgraphutil/dot.go @@ -29,7 +29,7 @@ func WriteDOT(w io.Writer, g *callgraph.Graph) error { // Write nodes. for _, n := range g.Nodes { - b.WriteString(fmt.Sprintf("\t%q [label=%q];\n", fmt.Sprintf("%d", n.ID), n.Func)) + b.WriteString(fmt.Sprintf("\t%d [label=%q];\n", n.ID, n.Func)) // Add edges edges = append(edges, n.Out...) @@ -37,7 +37,7 @@ func WriteDOT(w io.Writer, g *callgraph.Graph) error { // Write edges. for _, e := range edges { - b.WriteString(fmt.Sprintf("\t%q -> %q [label=%q];\n", fmt.Sprintf("%d", e.Caller.ID), fmt.Sprintf("%d", e.Callee.ID), e.Site)) + b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID)) } b.WriteString("}\n") From 1d4aa7f3a8d299c59df39117cac2c8110602e56f Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Fri, 5 Jan 2024 18:09:12 -0500 Subject: [PATCH 5/5] Use cluster subgraphs to scope package functions This doesn't work for all Graphviz layout engines, but it's nice when it does. --- callgraphutil/dot.go | 46 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go index 7f9316e..07b8609 100644 --- a/callgraphutil/dot.go +++ b/callgraphutil/dot.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "io" + "strings" "golang.org/x/tools/go/callgraph" ) @@ -22,19 +23,60 @@ func WriteDOT(w io.Writer, g *callgraph.Graph) error { edges := []*callgraph.Edge{} + nodesByPkg := map[string][]*callgraph.Node{} + + addPkgNode := func(n *callgraph.Node) { + // TODO: fix this so there's not so many "shared" functions? + // + // It is a bit of a hack, but it works for now. + var pkgPath string + if n.Func.Pkg != nil { + pkgPath = n.Func.Pkg.Pkg.Path() + } else { + pkgPath = "shared" + } + + // Check if the package already exists. + if _, ok := nodesByPkg[pkgPath]; !ok { + // If not, create it. + nodesByPkg[pkgPath] = []*callgraph.Node{} + } + nodesByPkg[pkgPath] = append(nodesByPkg[pkgPath], n) + } + // Check if root node exists, if so, write it. if g.Root != nil { b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID)) } - // Write nodes. + // Process nodes and edges. for _, n := range g.Nodes { - b.WriteString(fmt.Sprintf("\t%d [label=%q];\n", n.ID, n.Func)) + // Add node to map of nodes by package. + addPkgNode(n) // Add edges edges = append(edges, n.Out...) } + // Write nodes by package. + for pkg, nodes := range nodesByPkg { + // Make the pkg name sugraph cluster friendly (remove dots, dashes, and slashes). + clusterName := strings.Replace(pkg, ".", "_", -1) + clusterName = strings.Replace(clusterName, "/", "_", -1) + clusterName = strings.Replace(clusterName, "-", "_", -1) + + // NOTE: even if we're using a subgraph cluster, it may not be + // respected by all Graphviz layout engines. For example, the + // "dot" engine will respect the cluster, but the "sfdp" engine + // will not. + b.WriteString(fmt.Sprintf("\tsubgraph cluster_%s {\n", clusterName)) + b.WriteString(fmt.Sprintf("\t\tlabel=%q;\n", pkg)) + for _, n := range nodes { + b.WriteString(fmt.Sprintf("\t\t%d [label=%q];\n", n.ID, n.Func)) + } + b.WriteString("\t}\n") + } + // Write edges. for _, e := range edges { b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID))