From 8f89a79d5d232ed9568622ac7758072dc9b79f67 Mon Sep 17 00:00:00 2001 From: Travis DePrato Date: Thu, 30 Nov 2023 14:39:59 -0800 Subject: [PATCH] fix: Make av auth status output less confusing --- cmd/av/auth_status.go | 76 +++++++++++++++++++++++++++++++--------- cmd/av/main.go | 4 ++- internal/avgql/viewer.go | 2 +- internal/gh/error.go | 11 ++++++ internal/gh/viewer.go | 19 ++++++++++ 5 files changed, 93 insertions(+), 19 deletions(-) create mode 100644 internal/gh/error.go create mode 100644 internal/gh/viewer.go diff --git a/cmd/av/auth_status.go b/cmd/av/auth_status.go index 82bec386..e15c4c53 100644 --- a/cmd/av/auth_status.go +++ b/cmd/av/auth_status.go @@ -5,6 +5,10 @@ import ( "fmt" "os" + "emperror.dev/errors" + "github.com/aviator-co/av/internal/actions" + "github.com/aviator-co/av/internal/gh" + "github.com/aviator-co/av/internal/avgql" "github.com/aviator-co/av/internal/utils/colors" "github.com/spf13/cobra" @@ -16,26 +20,64 @@ var authStatusCmd = &cobra.Command{ SilenceUsage: true, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - client, err := avgql.NewClient() - if err != nil { - return err - } - - var query struct { - avgql.ViewerSubquery + exitCode := 0 + if err := checkAviatorAuthStatus(); err != nil { + _, _ = fmt.Fprintln(os.Stderr, colors.Failure(err.Error())) + exitCode = 1 } - - if err := client.Query(context.Background(), &query, nil); err != nil { - return err + if err := checkGitHubAuthStatus(); err != nil { + _, _ = fmt.Fprintln(os.Stderr, colors.Failure(err.Error())) + exitCode = 1 } - if err := query.CheckViewer(); err != nil { - return err + if exitCode != 0 { + return actions.ErrExitSilently{ExitCode: exitCode} } - - _, _ = fmt.Fprint(os.Stderr, - "Logged in as ", colors.UserInput(query.Viewer.FullName), - " (", colors.UserInput(query.Viewer.Email), ").\n", - ) return nil }, } + +func checkAviatorAuthStatus() error { + avClient, err := avgql.NewClient() + if err != nil { + return err + } + + var query struct{ avgql.ViewerSubquery } + if err := avClient.Query(context.Background(), &query, nil); err != nil { + return err + } + if err := query.CheckViewer(); err != nil { + return err + } + + _, _ = fmt.Fprint(os.Stderr, + "Logged in to Aviator as ", colors.UserInput(query.Viewer.FullName), + " (", colors.UserInput(query.Viewer.Email), ").\n", + ) + return nil +} + +func checkGitHubAuthStatus() error { + ghClient, err := getGitHubClient() + if err != nil { + return err + } + + viewer, err := ghClient.Viewer(context.Background()) + if err != nil { + // GitHub API returns 401 Unauthorized if the token is invalid or + // expired. + if gh.IsHTTPUnauthorized(err) { + return errors.New( + "You are not logged in to GitHub. Please verify that your API token is correct.", + ) + } + return errors.Wrap(err, "Failed to query GitHub") + } + + _, _ = fmt.Fprint(os.Stderr, + "Logged in to GitHub as ", colors.UserInput(viewer.Name), + " (", colors.UserInput(viewer.Login), ").\n", + ) + return nil +} diff --git a/cmd/av/main.go b/cmd/av/main.go index 2a0b81ee..4cf6dbc8 100644 --- a/cmd/av/main.go +++ b/cmd/av/main.go @@ -168,10 +168,12 @@ func discoverGitHubAPIToken() string { return "" } +var errNoGitHubToken = errors.New("No GitHub token is set (do you need to configure one?).") + func getGitHubClient() (*gh.Client, error) { token := discoverGitHubAPIToken() if token == "" { - return nil, errors.New("github token must be set") + return nil, errNoGitHubToken } var err error once.Do(func() { diff --git a/internal/avgql/viewer.go b/internal/avgql/viewer.go index 7702b443..fbbb0670 100644 --- a/internal/avgql/viewer.go +++ b/internal/avgql/viewer.go @@ -16,7 +16,7 @@ type ViewerSubquery struct { } var ErrNotAuthenticated = errors.New( - "You are not logged in. Please verify that your API token is correct.", + "You are not logged in to Aviator. Please verify that your API token is correct.", ) // CheckViewer checks whether or not the viewer is authenticated. diff --git a/internal/gh/error.go b/internal/gh/error.go new file mode 100644 index 00000000..3dd97faf --- /dev/null +++ b/internal/gh/error.go @@ -0,0 +1,11 @@ +package gh + +import "strings" + +// IsHTTPUnauthorized returns true if the given error is an HTTP 401 Unauthorized error. +func IsHTTPUnauthorized(err error) bool { + // This is a bit fragile because it relies on the error message from the + // GraphQL package. It doesn't export proper error types so we have to check + // the string. + return strings.Contains(err.Error(), "status code: 401") +} diff --git a/internal/gh/viewer.go b/internal/gh/viewer.go new file mode 100644 index 00000000..08369cad --- /dev/null +++ b/internal/gh/viewer.go @@ -0,0 +1,19 @@ +package gh + +import "context" + +type Viewer struct { + Name string `graphql:"name"` + Login string `graphql:"login"` +} + +func (c *Client) Viewer(ctx context.Context) (*Viewer, error) { + var query struct { + Viewer Viewer `graphql:"viewer"` + } + err := c.query(ctx, &query, nil) + if err != nil { + return nil, err + } + return &query.Viewer, nil +}