diff --git a/check.go b/check.go index 2f5de53..d60c374 100644 --- a/check.go +++ b/check.go @@ -246,6 +246,12 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va return true, src, tv } } + // Free variables can be traversed using the value's referrers, or the + // value's parent's referrers. Each referrer is either an SSA value or + // instruction. + // + // These can be tricky because they can be used in a few different ways, + // preventing us from just checking the value's referrers in all cases. case *ssa.FreeVar: refs := value.Referrers() for _, ref := range *refs { @@ -318,19 +324,6 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va tainted, src, tv := checkSSAValue(path, sources, val, valueSet{}) if tainted { return true, src, tv - // if tv.Referrers() != nil { - // for _, ref := range *tv.Referrers() { - // // if value.Name() == "input" { - // fmt.Printf("\t\t\t\t tv %T: %[1]v\n", tv) - // for _, instr := range ref.Block().Instrs { - // fmt.Printf("\t\t\t\t ref ----------------> %T: %[1]v\n", instr) - // } - // // } - // } - // } - // if tv.Name() == value.Name() { - // return true, src, tv - // } } } } @@ -432,34 +425,38 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va if tainted { return true, src, tv } - case *ssa.Slice: + // Check the sliced value. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) if tainted { return true, src, tv } case *ssa.MakeInterface: + // Check the value being made into an interface. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) if tainted { return true, src, tv } case *ssa.Convert: + // Check the value being converted. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) if tainted { return true, src, tv } case *ssa.Extract: + // Check the value being extracted. tainted, src, tv := checkSSAValue(path, sources, value.Tuple, visited) if tainted { return true, src, tv } case *ssa.Lookup: - // Check the string or map value + // Check the string or map value being looked up. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) if tainted { return true, src, tv } - // Check the index value + + // Check the index value being looked up. refs := value.Index.Referrers() if refs != nil { for _, ref := range *refs { @@ -492,6 +489,8 @@ func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction switch instr := i.(type) { case *ssa.Store: + // Store instructions need to be checked for both the value being stored, + // and the address being stored to. tainted, src, tv := checkSSAValue(path, sources, instr.Val, visited) if tainted { return true, src, tv @@ -501,6 +500,7 @@ func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction return true, src, tv } case *ssa.Call: + // Check the operands of the call instruction. for _, instrValue := range instr.Operands(nil) { if instrValue == nil { continue diff --git a/sources_sinks.go b/sources_sinks.go index 0dbcd70..7f3f78c 100644 --- a/sources_sinks.go +++ b/sources_sinks.go @@ -4,8 +4,13 @@ import ( "golang.org/x/tools/go/ssa" ) +// valueSet is a set of ssa.Values that can be used to track +// the values that have been visited during a traversal. This +// is used to prevent infinite recursion, and to prevent +// visiting the same value multiple times. type valueSet map[ssa.Value]struct{} +// includes returns true if the value is in the set. func (v valueSet) includes(sv ssa.Value) bool { if v == nil { return false @@ -14,6 +19,7 @@ func (v valueSet) includes(sv ssa.Value) bool { return ok } +// add adds the value to the set. func (v valueSet) add(sv ssa.Value) { if v == nil { v = valueSet{} @@ -21,8 +27,12 @@ func (v valueSet) add(sv ssa.Value) { v[sv] = struct{}{} } +// stringSet is a set of unique strings that express +// the types of sources and sinks that are being +// tracked. type stringSet map[string]struct{} +// includes returns true if the string is in the set. func (t stringSet) includes(str string) (string, bool) { if t == nil { return "", false @@ -31,8 +41,12 @@ func (t stringSet) includes(str string) (string, bool) { return str, ok } +// Sources are the types that are considered "sources" of +// tainted data in the program. type Sources = stringSet +// NewSources returns a new Sources set with the given +// source types. func NewSources(sourceTypes ...string) Sources { srcs := Sources{} @@ -43,8 +57,12 @@ func NewSources(sourceTypes ...string) Sources { return srcs } +// Sinks are the types that are considered "sinks" that +// tainted data in the program may flow into. type Sinks = stringSet +// NewSinks returns a new Sinks set with the given +// sink types. func NewSinks(sinkTypes ...string) Sinks { snks := Sinks{}