Skip to content

Commit

Permalink
Merge pull request #2 from picatz/improve
Browse files Browse the repository at this point in the history
Add a bit more inline documentation
  • Loading branch information
picatz authored Jan 2, 2023
2 parents d02bda4 + 75b0879 commit f43d4ad
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
32 changes: 16 additions & 16 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
return true, src, tv
}
}
// Free variables can be traversed using the value's referrers, or the
// value's parent's referrers. Each referrer is either an SSA value or
// instruction.
//
// These can be tricky because they can be used in a few different ways,
// preventing us from just checking the value's referrers in all cases.
case *ssa.FreeVar:
refs := value.Referrers()
for _, ref := range *refs {
Expand Down Expand Up @@ -318,19 +324,6 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
tainted, src, tv := checkSSAValue(path, sources, val, valueSet{})
if tainted {
return true, src, tv
// if tv.Referrers() != nil {
// for _, ref := range *tv.Referrers() {
// // if value.Name() == "input" {
// fmt.Printf("\t\t\t\t tv %T: %[1]v\n", tv)
// for _, instr := range ref.Block().Instrs {
// fmt.Printf("\t\t\t\t ref ----------------> %T: %[1]v\n", instr)
// }
// // }
// }
// }
// if tv.Name() == value.Name() {
// return true, src, tv
// }
}
}
}
Expand Down Expand Up @@ -432,34 +425,38 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
if tainted {
return true, src, tv
}

case *ssa.Slice:
// Check the sliced value.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
case *ssa.MakeInterface:
// Check the value being made into an interface.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
case *ssa.Convert:
// Check the value being converted.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
case *ssa.Extract:
// Check the value being extracted.
tainted, src, tv := checkSSAValue(path, sources, value.Tuple, visited)
if tainted {
return true, src, tv
}
case *ssa.Lookup:
// Check the string or map value
// Check the string or map value being looked up.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
// Check the index value

// Check the index value being looked up.
refs := value.Index.Referrers()
if refs != nil {
for _, ref := range *refs {
Expand Down Expand Up @@ -492,6 +489,8 @@ func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction

switch instr := i.(type) {
case *ssa.Store:
// Store instructions need to be checked for both the value being stored,
// and the address being stored to.
tainted, src, tv := checkSSAValue(path, sources, instr.Val, visited)
if tainted {
return true, src, tv
Expand All @@ -501,6 +500,7 @@ func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction
return true, src, tv
}
case *ssa.Call:
// Check the operands of the call instruction.
for _, instrValue := range instr.Operands(nil) {
if instrValue == nil {
continue
Expand Down
18 changes: 18 additions & 0 deletions sources_sinks.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ import (
"golang.org/x/tools/go/ssa"
)

// valueSet is a set of ssa.Values that can be used to track
// the values that have been visited during a traversal. This
// is used to prevent infinite recursion, and to prevent
// visiting the same value multiple times.
type valueSet map[ssa.Value]struct{}

// includes returns true if the value is in the set.
func (v valueSet) includes(sv ssa.Value) bool {
if v == nil {
return false
Expand All @@ -14,15 +19,20 @@ func (v valueSet) includes(sv ssa.Value) bool {
return ok
}

// add adds the value to the set.
func (v valueSet) add(sv ssa.Value) {
if v == nil {
v = valueSet{}
}
v[sv] = struct{}{}
}

// stringSet is a set of unique strings that express
// the types of sources and sinks that are being
// tracked.
type stringSet map[string]struct{}

// includes returns true if the string is in the set.
func (t stringSet) includes(str string) (string, bool) {
if t == nil {
return "", false
Expand All @@ -31,8 +41,12 @@ func (t stringSet) includes(str string) (string, bool) {
return str, ok
}

// Sources are the types that are considered "sources" of
// tainted data in the program.
type Sources = stringSet

// NewSources returns a new Sources set with the given
// source types.
func NewSources(sourceTypes ...string) Sources {
srcs := Sources{}

Expand All @@ -43,8 +57,12 @@ func NewSources(sourceTypes ...string) Sources {
return srcs
}

// Sinks are the types that are considered "sinks" that
// tainted data in the program may flow into.
type Sinks = stringSet

// NewSinks returns a new Sinks set with the given
// sink types.
func NewSinks(sinkTypes ...string) Sinks {
snks := Sinks{}

Expand Down

0 comments on commit f43d4ad

Please sign in to comment.