Skip to content

Commit

Permalink
EAR taint checker: traverse callers when connecting the sources and s…
Browse files Browse the repository at this point in the history
…inks (#319)
  • Loading branch information
guodongli-google authored Jul 15, 2021
1 parent 99e3a7f commit c0094bf
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 52 deletions.
4 changes: 3 additions & 1 deletion internal/pkg/earpointer/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ func analyze(ssainput *buildssa.SSA) *Partitions {
}
}
}
return vis.state.ToPartitions()
p := vis.state.ToPartitions()
p.cg = cg
return p
}

// Builds the calling context set for each function.
Expand Down
5 changes: 5 additions & 0 deletions internal/pkg/earpointer/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"log"
"sort"
"strings"

"golang.org/x/tools/go/callgraph"
)

// parentMap maps a reference to its representative (i.e. the parent
Expand Down Expand Up @@ -296,6 +298,9 @@ type Partitions struct {
// It is constructed separately using the "ConstructFieldParentMap()"
// at the final phase.
revFields map[Reference][]Reference

// The call graph used to unify callers and callees.
cg *callgraph.Graph
}

func (state *state) ToPartitions() *Partitions {
Expand Down
114 changes: 66 additions & 48 deletions internal/pkg/earpointer/taint.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package earpointer
import (
"go/types"

"golang.org/x/tools/go/callgraph"

"github.com/google/go-flow-levee/internal/pkg/config"
"github.com/google/go-flow-levee/internal/pkg/utils"

Expand All @@ -26,15 +28,17 @@ import (

// Bounded traversal of an EAR heap.
type heapTraversal struct {
heap *Partitions
callees map[*ssa.Function]bool // the functions containing the references of interest
visited ReferenceSet // the visited references during the traversal
heap *Partitions
// The reachable functions containing the references of interest.
reachableFns map[*ssa.Function]bool
// The visited references during the traversal.
visited ReferenceSet
isTaintField func(named *types.Named, index int) bool
}

func (ht *heapTraversal) isWithinCallees(ref Reference) bool {
if fn := ref.Value().Parent(); fn != nil {
return ht.callees[fn]
return ht.reachableFns[fn]
}
// Globals and Builtins have no parents.
return true
Expand All @@ -59,7 +63,8 @@ func (ht *heapTraversal) srcRefs(rep Reference, tp types.Type, result ReferenceS
ht.srcRefs(rep, tp.Underlying(), result)
return
}
result[rep] = true // the current struct object is tainted
// Mark the current struct object as tainted.
result[rep] = true
// Look for the taint fields.
for i := 0; i < tt.NumFields(); i++ {
f := tt.Field(i)
Expand Down Expand Up @@ -132,7 +137,7 @@ func (ht *heapTraversal) canReach(sink ssa.Instruction, sources []*source.Source
sinkedRefs := make(map[Reference]bool)
for _, op := range sink.Operands(nil) {
// Use a separate heapTraversal to search for the sink references.
sinkHT := &heapTraversal{heap: ht.heap, callees: ht.callees, visited: make(ReferenceSet)}
sinkHT := &heapTraversal{heap: ht.heap, reachableFns: ht.reachableFns, visited: make(ReferenceSet)}
v := *op
if isLocal(v) || isGlobal(v) {
ref := MakeLocalWithEmptyContext(v)
Expand All @@ -153,47 +158,45 @@ func (ht *heapTraversal) canReach(sink ssa.Instruction, sources []*source.Source
return nil
}

// For a function, transitively get the functions called within this function.
// Argument "depth" controls the depth of the call chain.
// For example, return {g1,g2,g3} for "func f(){ g1(); g2() }, func g1(){ g3() }".
func calleeFunctions(fn *ssa.Function, result map[*ssa.Function]bool, depth uint) {
if depth <= 0 {
// For a function, transitively get the functions reachable from this function
// according to the call graph. Both callers and callees are considered.
// Argument "depth" controls the depth of the call chain, and "result" is
// to store the set of reachable functions. For example,
// func f(){ g1(); g2() }
// func g1(){ g3() }
// for input "g1", result = {f,g1,g2,g3} if depth>1, and result = {f, g1} if depth=1.
func boundedReachableFunctions(fn *ssa.Function, cg *callgraph.Graph, depth uint, result map[*ssa.Function]bool) {
if depth <= 0 || result[fn] {
return
}
for _, b := range fn.Blocks {
for _, instr := range b.Instrs {
if call, ok := instr.(*ssa.Call); ok {
// TODO(#317): use more advanced call graph.
// skip empty, unlinked, or visited functions
if callee := call.Call.StaticCallee(); callee != nil && len(callee.Blocks) > 0 && !result[callee] {
result[callee] = true
calleeFunctions(callee, result, depth-1)
}
}
}
}
}

func boundedDepthCallees(fn *ssa.Function, depth uint) map[*ssa.Function]bool {
result := make(map[*ssa.Function]bool)
result[fn] = true
calleeFunctions(fn, result, depth)
return result
node := cg.Nodes[fn]
if node == nil {
return
}
// Visit the callees within "fn".
for _, out := range node.Out {
boundedReachableFunctions(out.Callee.Func, cg, depth-1, result)
}
// Visit the callers of "fn".
for _, in := range node.In {
boundedReachableFunctions(in.Caller.Func, cg, depth-1, result)
}
}

// Obtain the references which are aliases of a taint source, with field sensitivity.
// Argument "heap" is an immutable EAR heap containing alias information;
// "callees" is used to bound the searching of source references in the heap.
// "reachable" is used to bound the searching of source references in the heap.
func srcAliasRefs(src *source.Source, isTaintField func(named *types.Named, index int) bool,
heap *Partitions, callees map[*ssa.Function]bool) ReferenceSet {
heap *Partitions, reachable map[*ssa.Function]bool) ReferenceSet {

val, ok := src.Node.(ssa.Value)
if !ok {
return nil
}
rep := heap.Representative(MakeLocalWithEmptyContext(val))
refs := make(ReferenceSet)
ht := &heapTraversal{heap: heap, callees: callees, visited: make(ReferenceSet), isTaintField: isTaintField}
ht := &heapTraversal{heap: heap, reachableFns: reachable, visited: make(ReferenceSet), isTaintField: isTaintField}
ht.srcRefs(rep, val.Type(), refs)
return refs
}
Expand All @@ -206,33 +209,49 @@ type SourceSinkTrace struct {

// Look for <source, sink> pairs by examining the heap alias information.
func SourcesToSinks(funcSources source.ResultType, isTaintField func(named *types.Named, index int) bool,
heap *Partitions, conf *config.Config) []*SourceSinkTrace {
heap *Partitions, conf *config.Config) map[ssa.Instruction]*SourceSinkTrace {

var traces []*SourceSinkTrace
// A map from a callsite to its possible callees.
calleeMap := mapCallees(heap.cg)
traces := make(map[ssa.Instruction]*SourceSinkTrace)
for fn, sources := range funcSources {
// Transitively get the set of functions called within "fn".
// Transitively get the set of functions reachable from "fn".
// This set is used to narrow down the set of references needed to be
// considered during EAR heap traversal. It can also help reducing the
// false positives and boosting the performance.
callees := boundedDepthCallees(fn, conf.EARTaintCallSpan)
// For example,
// func f1(){ f2(); g4() }
// func f2(){ g1(); g2() }
// func g1(){ g3() }
// g1 can reach f2 through the caller, and then f1 similarly,
// and then g4 through the callee.
// g1's full reachable set is {f1,f2,g1,g2,g3,g4}.
reachable := make(map[*ssa.Function]bool)
boundedReachableFunctions(fn, heap.cg, conf.EARTaintCallSpan, reachable)
// Start from the set of taint sources.
srcRefs := make(map[*source.Source]ReferenceSet)
for _, s := range sources {
srcRefs[s] = srcAliasRefs(s, isTaintField, heap, callees)
srcRefs[s] = srcAliasRefs(s, isTaintField, heap, reachable)
}
// Traverse all the callee functions (not just the ones with sink sources)
ht := &heapTraversal{heap: heap, callees: callees, visited: make(ReferenceSet)}
for member := range callees {
// Traverse all the reachable functions (not just the ones with sink sources)
// in search for connected sinks.
ht := &heapTraversal{heap: heap, reachableFns: reachable, visited: make(ReferenceSet)}
for member := range reachable {
for _, b := range member.Blocks {
for _, instr := range b.Instrs {
switch v := instr.(type) {
case *ssa.Call:
callees := calleeMap[&v.Call]
sink := instr
// TODO(#317): use more advanced call graph.
callee := v.Call.StaticCallee()
if callee != nil && conf.IsSink(utils.DecomposeFunction(callee)) {
if src := ht.canReach(sink, sources, srcRefs); src != nil {
traces = append(traces, &SourceSinkTrace{Src: src, Sink: sink})
break
for _, callee := range callees {
if conf.IsSink(utils.DecomposeFunction(callee)) {
if src := ht.canReach(sink, sources, srcRefs); src != nil {
// If a previous source has been found, be in favor of the source within the same
// function. This can be extended to be in favor of the source closest to the sink.
if _, ok := traces[instr]; !ok || src.Node.Parent() == sink.Parent() {
traces[sink] = &SourceSinkTrace{Src: src, Sink: sink}
}
}
}
}
case *ssa.Panic:
Expand All @@ -241,8 +260,7 @@ func SourcesToSinks(funcSources source.ResultType, isTaintField func(named *type
}
sink := instr
if src := ht.canReach(sink, sources, srcRefs); src != nil {
traces = append(traces, &SourceSinkTrace{Src: src, Sink: sink})
break
traces[sink] = &SourceSinkTrace{Src: src, Sink: sink}
}

}
Expand Down
9 changes: 9 additions & 0 deletions internal/pkg/levee/levee_ear_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,12 @@ func TestLeveeEAR(t *testing.T) {
// analysistest.Run(t, dataDir, Analyzer, "./src/levee_analysistest/example/tests/structlit") // TODO: NP have been fixed?
analysistest.Run(t, dataDir, Analyzer, "./src/levee_analysistest/example/tests/typealias")
}

func TestLeveeEARInter(t *testing.T) {
dataDir := analysistest.TestData()
if err := Analyzer.Flags.Set("config", dataDir+"/test-ear-config.yaml"); err != nil {
t.Error(err)
}

analysistest.Run(t, dataDir, Analyzer, "./src/levee_analysistest/ear/tests/...")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package call

import (
"levee_analysistest/example/core"
)

func sinkf2(i interface{}) {
core.Sink(i) // want "a source has reached a sink"
}

func createSource2() interface{} {
return &core.Source{}
}

func identity(arg interface{}) interface{} {
return arg
}

// Test the case where:
// (1) the sink function doesn't embed the source function; and
// (2) the source function doesn't embed the sink function.
func TestSrcSinkInDifferentCallees() {
// The source is created in a callee.
s := createSource2()
i := identity(s)
// The sink is within another callee.
sinkf2(i)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package call

import (
"levee_analysistest/example/core"
)

func sinkf1(i interface{}) {
core.Sink(i) // want "a source has reached a sink"
}

func f1(s string) {
sinkf1(s)
}

func createData() string {
src := &core.Source{}
return src.Data
}

// Test the case where the source is introduced in callees.
func TestSrcInCallees() {
s := createData()
core.Sink(s) // want "a source has reached a sink"
f1(s)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ import (
"levee_analysistest/example/core"
)

type sourceBuilder struct {
type SourceBuilder struct {
sourcePtr *core.Source
sourceVal core.Source
}

func (b *sourceBuilder) buildP() {
func (b *SourceBuilder) buildP() {
core.Sinkf("Building cluster %v", b.sourcePtr) // want "a source has reached a sink"
core.Sinkf("Building cluster %v", b.sourceVal) // want "a source has reached a sink"
}

func (b sourceBuilder) buildV() {
func (b SourceBuilder) buildV() {
core.Sinkf("Building cluster %v", b.sourcePtr) // want "a source has reached a sink"
core.Sinkf("Building cluster %v", b.sourceVal) // want "a source has reached a sink"
}

0 comments on commit c0094bf

Please sign in to comment.