From c8bd546a4f8bc8bcb329c7010d97628dd5cbb3ce Mon Sep 17 00:00:00 2001 From: Caleb Brown Date: Thu, 28 Apr 2022 08:02:59 +1000 Subject: [PATCH] Place code to be shared with signal collection into libraries. (#117) --- cmd/enumerate_github/main.go | 22 ++--- internal/logflag/level.go | 32 +++++++ internal/logflag/level_test.go | 78 ++++++++++++++++ internal/outfile/outfile.go | 86 ++++++++++++++++++ internal/outfile/outfile_test.go | 147 +++++++++++++++++++++++++++++++ 5 files changed, 349 insertions(+), 16 deletions(-) create mode 100644 internal/logflag/level.go create mode 100644 internal/logflag/level_test.go create mode 100644 internal/outfile/outfile.go create mode 100644 internal/outfile/outfile_test.go diff --git a/cmd/enumerate_github/main.go b/cmd/enumerate_github/main.go index ed25e362b..56ebc6dd2 100644 --- a/cmd/enumerate_github/main.go +++ b/cmd/enumerate_github/main.go @@ -12,6 +12,8 @@ import ( "time" "github.com/ossf/criticality_score/cmd/enumerate_github/githubsearch" + "github.com/ossf/criticality_score/internal/logflag" + "github.com/ossf/criticality_score/internal/outfile" "github.com/ossf/scorecard/v4/clients/githubrepo/roundtripper" sclog "github.com/ossf/scorecard/v4/log" "github.com/shurcooL/githubv4" @@ -29,8 +31,6 @@ var ( // epochDate is the earliest date for which GitHub has data. epochDate = time.Date(2008, 1, 1, 0, 0, 0, 0, time.UTC) - forceFlag = flag.Bool("force", false, "overwrites FILE if it already exists and -append is not set.") - appendFlag = flag.Bool("append", false, "appends to FILE if it already exists.") minStarsFlag = flag.Int("min-stars", 10, "only enumerates repositories with this or more of stars.") starOverlapFlag = flag.Int("star-overlap", 5, "the number of stars to overlap between queries.") requireMinStarsFlag = flag.Bool("require-min-stars", false, "abort if -min-stars can't be reached during enumeration.") @@ -38,7 +38,7 @@ var ( workersFlag = flag.Int("workers", 1, "the total number of concurrent workers to use.") startDateFlag = dateFlag(epochDate) endDateFlag = dateFlag(time.Now().UTC().Truncate(oneDay)) - logFlag = logLevelFlag(defaultLogLevel) + logFlag = logflag.Level(defaultLogLevel) ) // dateFlag implements the flag.Value interface to simplify the input and validation of @@ -87,6 +87,7 @@ func init() { flag.Var(&startDateFlag, "start", "the start `date` to enumerate back to. Must be at or after 2008-01-01.") flag.Var(&endDateFlag, "end", "the end `date` to enumerate from.") flag.Var(&logFlag, "log", "set the `level` of logging.") + outfile.DefineFlags(flag.CommandLine, "force", "append", "FILE") flag.Usage = func() { cmdName := path.Base(os.Args[0]) w := flag.CommandLine.Output() @@ -162,21 +163,10 @@ func main() { // Print a helpful message indicating the configuration we're using. logger.WithFields(log.Fields{ "filename": outFilename, - "force": *forceFlag, - "append": *appendFlag, }).Info("Preparing output file") - // Open the output file based on the flags - // TODO: support '-' to use os.Stdout. - var out *os.File - var err error - if *appendFlag { - out, err = os.OpenFile(outFilename, os.O_WRONLY|os.O_SYNC|os.O_CREATE|os.O_APPEND, 0666) - } else if *forceFlag { - out, err = os.OpenFile(outFilename, os.O_WRONLY|os.O_SYNC|os.O_CREATE|os.O_TRUNC, 0666) - } else { - out, err = os.OpenFile(outFilename, os.O_WRONLY|os.O_SYNC|os.O_CREATE|os.O_EXCL, 0666) - } + // Open the output file + out, err := outfile.Open(outFilename) if err != nil { // File failed to open logger.WithFields(log.Fields{ diff --git a/internal/logflag/level.go b/internal/logflag/level.go new file mode 100644 index 000000000..7014e69e9 --- /dev/null +++ b/internal/logflag/level.go @@ -0,0 +1,32 @@ +// Package logflag is a simple helper library that generalizes the logic for +// parsing command line flags for configuring the logging behavior. +package logflag + +import log "github.com/sirupsen/logrus" + +// Level implements the flag.Value interface to simplify the input and validation +// of the current logrus log level. +// +// var logLevel = logflag.Level(logrus.InfoLevel) +// flag.Var(&logLevel, "log", "set the `level` of logging.") +type Level log.Level + +// Set implements the flag.Value interface. +func (l *Level) Set(value string) error { + level, err := log.ParseLevel(string(value)) + if err != nil { + return err + } + *l = Level(level) + return nil +} + +// String implements the flag.Value interface. +func (l Level) String() string { + return log.Level(l).String() +} + +// Level returns either the default log level, or the value set on the command line. +func (l Level) Level() log.Level { + return log.Level(l) +} diff --git a/internal/logflag/level_test.go b/internal/logflag/level_test.go new file mode 100644 index 000000000..73f2605b6 --- /dev/null +++ b/internal/logflag/level_test.go @@ -0,0 +1,78 @@ +package logflag_test + +import ( + "flag" + "testing" + + "github.com/ossf/criticality_score/internal/logflag" + "github.com/sirupsen/logrus" +) + +func TestDefault(t *testing.T) { + level := logflag.Level(logrus.ErrorLevel) + if l := level.Level(); l != logrus.ErrorLevel { + t.Fatalf("Level() == %v, want %v", l, logrus.ErrorLevel) + } +} + +func TestSet(t *testing.T) { + level := logflag.Level(logrus.InfoLevel) + err := level.Set("error") + if err != nil { + t.Fatalf("Set() == %v, want nil", err) + } + if l := level.Level(); l != logrus.ErrorLevel { + t.Fatalf("Level() == %v, want %v", l, logrus.ErrorLevel) + } +} + +func TestSetError(t *testing.T) { + level := logflag.Level(logrus.InfoLevel) + err := level.Set("hello,world") + if err == nil { + t.Fatalf("Set() == nil, want an error") + } +} + +func TestString(t *testing.T) { + level := logflag.Level(logrus.DebugLevel) + if s := level.String(); s != logrus.DebugLevel.String() { + t.Fatalf("String() == %v, want %v", s, logrus.DebugLevel.String()) + } +} + +func TestFlagUnset(t *testing.T) { + fs := flag.NewFlagSet("", flag.ContinueOnError) + level := logflag.Level(logrus.InfoLevel) + fs.Var(&level, "level", "usage") + err := fs.Parse([]string{"arg"}) + if err != nil { + t.Fatalf("Parse() == %v, want nil", err) + } + if l := level.Level(); l != logrus.InfoLevel { + t.Fatalf("Level() == %v, want %v", l, logrus.InfoLevel) + } +} + +func TestFlagSet(t *testing.T) { + fs := flag.NewFlagSet("", flag.ContinueOnError) + level := logflag.Level(logrus.InfoLevel) + fs.Var(&level, "level", "usage") + err := fs.Parse([]string{"-level=fatal", "arg"}) + if err != nil { + t.Fatalf("Parse() == %v, want nil", err) + } + if l := level.Level(); l != logrus.FatalLevel { + t.Fatalf("Level() == %v, want %v", l, logrus.FatalLevel) + } +} + +func TestFlagSetError(t *testing.T) { + fs := flag.NewFlagSet("", flag.ContinueOnError) + level := logflag.Level(logrus.InfoLevel) + fs.Var(&level, "level", "usage") + err := fs.Parse([]string{"-level=foobar", "arg"}) + if err == nil { + t.Fatalf("Parse() == nil, want an error") + } +} diff --git a/internal/outfile/outfile.go b/internal/outfile/outfile.go new file mode 100644 index 000000000..a4be6bc35 --- /dev/null +++ b/internal/outfile/outfile.go @@ -0,0 +1,86 @@ +package outfile + +import ( + "flag" + "fmt" + "os" +) + +// fileOpener wraps a method for opening files. +// +// This allows tests to fake the behavior of os.OpenFile() to avoid hitting +// the filesystem. +type fileOpener interface { + Open(string, int, os.FileMode) (*os.File, error) +} + +// fileOpenerFunc allows a function to implement the openFileWrapper interface. +// +// This is convenient for wrapping os.OpenFile(). +type fileOpenerFunc func(string, int, os.FileMode) (*os.File, error) + +func (f fileOpenerFunc) Open(filename string, flags int, perm os.FileMode) (*os.File, error) { + return f(filename, flags, perm) +} + +type Opener struct { + force bool + append bool + fileOpener fileOpener + Perm os.FileMode + StdoutName string +} + +// CreateOpener creates an Opener and defines the sepecified flags forceFlag and appendFlag. +func CreateOpener(fs *flag.FlagSet, forceFlag string, appendFlag string, fileHelpName string) *Opener { + o := &Opener{ + Perm: 0666, + StdoutName: "-", + fileOpener: fileOpenerFunc(os.OpenFile), + } + fs.BoolVar(&(o.force), forceFlag, false, fmt.Sprintf("overwrites %s if it already exists and -%s is not set.", fileHelpName, appendFlag)) + fs.BoolVar(&(o.append), appendFlag, false, fmt.Sprintf("appends to %s if it already exists.", fileHelpName)) + return o +} + +func (o *Opener) openInternal(filename string, extraFlags int) (*os.File, error) { + return o.fileOpener.Open(filename, os.O_WRONLY|os.O_SYNC|os.O_CREATE|extraFlags, o.Perm) +} + +// Open opens and returns a file for output with the given filename. +// +// If filename is equal to o.StdoutName, os.Stdout will be used. +// If filename does not exist, it will be created with the mode set in o.Perm. +// If filename does exist, the behavior of this function will depend on the +// flags: +// - if appendFlag is set on the command line the existing file will be +// appended to. +// - if forceFlag is set on the command line the existing file will be +// truncated. +// - if neither forceFlag nor appendFlag are set an error will be +// returned. +func (o *Opener) Open(filename string) (f *os.File, err error) { + if o.StdoutName != "" && filename == o.StdoutName { + f = os.Stdout + } else if o.append { + f, err = o.openInternal(filename, os.O_APPEND) + } else if o.force { + f, err = o.openInternal(filename, os.O_TRUNC) + } else { + f, err = o.openInternal(filename, os.O_EXCL) + } + return +} + +var defaultOpener *Opener + +// DefineFlags is a wrapper around CreateOpener for updating a default instance +// of Opener. +func DefineFlags(fs *flag.FlagSet, forceFlag string, appendFlag string, fileHelpName string) { + defaultOpener = CreateOpener(fs, forceFlag, appendFlag, fileHelpName) +} + +// Open is a wrapper around Opener.Open for the default instance of Opener. +func Open(filename string) (*os.File, error) { + return defaultOpener.Open(filename) +} diff --git a/internal/outfile/outfile_test.go b/internal/outfile/outfile_test.go new file mode 100644 index 000000000..f91f4382d --- /dev/null +++ b/internal/outfile/outfile_test.go @@ -0,0 +1,147 @@ +package outfile + +import ( + "errors" + "flag" + "os" + "testing" +) + +type openCall struct { + filename string + flags int + perm os.FileMode +} + +type testOpener struct { + flag *flag.FlagSet + openErr error + lastOpen *openCall + opener *Opener +} + +func newTestOpener() *testOpener { + o := &testOpener{} + o.flag = flag.NewFlagSet("", flag.ContinueOnError) + o.opener = CreateOpener(o.flag, "force", "append", "FILE") + o.opener.Perm = 0567 + o.opener.StdoutName = "-stdout-" + o.opener.fileOpener = fileOpenerFunc(func(filename string, flags int, perm os.FileMode) (*os.File, error) { + o.lastOpen = &openCall{ + filename: filename, + flags: flags, + perm: perm, + } + if o.openErr != nil { + return nil, o.openErr + } else { + return &os.File{}, nil + } + }) + return o +} + +func TestForceFlagDefined(t *testing.T) { + o := newTestOpener() + f := o.flag.Lookup("force") + if f == nil { + t.Fatal("Lookup() == nil, wanted a flag.") + } +} + +func TestAppendFlagDefined(t *testing.T) { + o := newTestOpener() + f := o.flag.Lookup("append") + if f == nil { + t.Fatal("Lookup() == nil, wanted a flag.") + } +} + +func TestOpenStdout(t *testing.T) { + o := newTestOpener() + f, err := o.opener.Open("-stdout-") + if err != nil { + t.Fatalf("Open() == %v, want nil", err) + } + if f != os.Stdout { + n := "nil" + if f != nil { + n = f.Name() + } + t.Fatalf("Open() == %s, want %v", n, os.Stdout.Name()) + } +} + +func TestOpenFlagTest(t *testing.T) { + tests := []struct { + name string + args []string + openErr error + expectedFlag int + }{ + { + name: "no args", + args: []string{}, + expectedFlag: os.O_EXCL, + }, + { + name: "append only flag", + args: []string{"-append"}, + expectedFlag: os.O_APPEND, + }, + { + name: "force only flag", + args: []string{"-force"}, + expectedFlag: os.O_TRUNC, + }, + { + name: "both flags", + args: []string{"-force", "-append"}, + expectedFlag: os.O_APPEND, + }, + } + + // Test success responses + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + o := newTestOpener() + o.flag.Parse(test.args) + f, err := o.opener.Open("path/to/file") + if err != nil { + t.Fatalf("Open() == %v, want nil", err) + } + if f == nil { + t.Fatal("Open() == nil, want a file") + } + assertLastOpen(t, o, "path/to/file", test.expectedFlag, 0567) + }) + } + + // Test error responses + for _, test := range tests { + t.Run(test.name+" error", func(t *testing.T) { + o := newTestOpener() + o.flag.Parse(test.args) + o.openErr = errors.New("test error") + _, err := o.opener.Open("path/to/file") + if err == nil { + t.Fatalf("Open() is nil, want %v", o.openErr) + } + }) + } +} + +func assertLastOpen(t *testing.T, o *testOpener, filename string, requireFlags int, perm os.FileMode) { + if o.lastOpen == nil { + t.Fatalf("Open(...) not called, want call to Open(...)") + } + if o.lastOpen.filename != filename { + t.Fatalf("Open(%v, _, _) called, want Open(%v, _, _)", o.lastOpen.filename, filename) + } + if o.lastOpen.flags&requireFlags != requireFlags { + t.Fatalf("Open(_, %v, _) called, want Open(_, %v, _)", o.lastOpen.flags&requireFlags, requireFlags) + } + if o.lastOpen.perm != perm { + t.Fatalf("Open(_, _, %v) called, want Open(_, _, %v)", o.lastOpen.perm, perm) + } +}