diff --git a/contextcheck.go b/contextcheck.go index 97832ad..123df74 100644 --- a/contextcheck.go +++ b/contextcheck.go @@ -68,6 +68,10 @@ var ( type resInfo struct { Valid bool Funcs []string + + // reuse for doc + ReqCtx bool + Skip bool } type ctxFact map[string]resInfo @@ -238,24 +242,45 @@ func (r *runner) checkIsEntry(f *ssa.Function) entryType { return EntryWithCtx } + reqctx, skip := r.docFlag(f) + // check is `func handler(w http.ResponseWriter, r *http.Request) {}` - if r.checkIsHttpHandler(f) { + // or use '// @contextcheck(req_has_ctx)' + if r.checkIsHttpHandler(f, reqctx) { return EntryWithHttpHandler } - if r.skipByNolint(f) { + if skip { return EntryNone } return EntryNormal } +func (r *runner) docFlag(f *ssa.Function) (reqctx, skip bool) { + key := "doc:" + f.RelString(nil) + res, ok := r.getValue(key, f) + if ok { + return res.ReqCtx, res.Skip + } + + for _, v := range r.getDocFromFunc(f) { + if len(nolintRe.FindString(v.Text)) > 0 && strings.Contains(v.Text, "contextcheck") { + res.Skip = true + } else if strings.HasPrefix(v.Text, "// @contextcheck(req_has_ctx)") { + res.ReqCtx = true + } + } + r.currentFact[key] = res + return res.ReqCtx, res.Skip +} + var nolintRe = regexp.MustCompile(`^//\s?nolint:`) -func (r *runner) skipByNolint(f *ssa.Function) bool { +func (r *runner) getDocFromFunc(f *ssa.Function) []*ast.Comment { file := analysisutil.File(r.pass, f.Pos()) if file == nil { - return false + return nil } // only support FuncDecl comment @@ -267,15 +292,9 @@ func (r *runner) skipByNolint(f *ssa.Function) bool { } } if fd == nil || fd.Doc == nil || len(fd.Doc.List) == 0 { - return false - } - - for _, v := range fd.Doc.List { - if len(nolintRe.FindString(v.Text)) > 0 && strings.Contains(v.Text, "contextcheck") { - return true - } + return nil } - return false + return fd.Doc.List } func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) { @@ -307,7 +326,16 @@ func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) { return } -func (r *runner) checkIsHttpHandler(f *ssa.Function) bool { +func (r *runner) checkIsHttpHandler(f *ssa.Function, reqctx bool) bool { + if reqctx { + tuple := f.Signature.Params() + for i := 0; i < tuple.Len(); i++ { + if r.isHttpReqType(tuple.At(i).Type()) { + return true + } + } + } + // must has no result if f.Signature.Results().Len() > 0 { return false diff --git a/testdata/src/a/a.go b/testdata/src/a/a.go index 874cd52..f8f5dbb 100644 --- a/testdata/src/a/a.go +++ b/testdata/src/a/a.go @@ -114,6 +114,11 @@ func f14(w http.ResponseWriter, r *http.Request, err error) { f8(r.Context(), w, r) } +// @contextcheck(req_has_ctx) +func f15(w http.ResponseWriter, r *http.Request, err error) { + f8(r.Context(), w, r) +} + func f11() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { f8(r.Context(), w, r) @@ -125,6 +130,7 @@ func f11() { f10(true, w, r) // want "Function `f10` should pass the context parameter" f14(w, r, nil) + f15(w, r, nil) }) }