From 5f491b5a81a249ca5e6722391d7a2241f0df9b09 Mon Sep 17 00:00:00 2001 From: Branden J Brown Date: Fri, 19 Apr 2024 12:18:23 -0500 Subject: [PATCH] twitch: return ErrNeedRefresh on 401 response --- twitch/twitch.go | 6 +++- twitch/twitch_test.go | 76 +++++++++++++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 26 deletions(-) diff --git a/twitch/twitch.go b/twitch/twitch.go index 01e012f..9bd0e68 100644 --- a/twitch/twitch.go +++ b/twitch/twitch.go @@ -34,7 +34,11 @@ func reqjson[Resp any](ctx context.Context, client Client, method, url string, b return fmt.Errorf("couldn't read response: %w", err) } resp.Body.Close() - if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusOK: // do nothing + case http.StatusUnauthorized: + return fmt.Errorf("request failed: %s (%w)", b, ErrNeedRefresh) + default: return fmt.Errorf("request failed: %s (%s)", b, resp.Status) } r := struct { diff --git a/twitch/twitch_test.go b/twitch/twitch_test.go index ae9f076..c9e3c5f 100644 --- a/twitch/twitch_test.go +++ b/twitch/twitch_test.go @@ -28,29 +28,55 @@ func (r *reqspy) RoundTrip(req *http.Request) (*http.Response, error) { } func TestReqJSON(t *testing.T) { - spy := &reqspy{ - respond: &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"data":1}`)), - }, - } - cl := Client{ - HTTP: &http.Client{ - Transport: spy, - }, - Token: &oauth2.Token{ - AccessToken: "bocchi", - }, - } - var u int - err := reqjson(context.Background(), cl, "GET", "https://bocchi.rocks/bocchi", nil, &u) - if err != nil { - t.Errorf("failed to request: %v", err) - } - if u != 1 { - t.Errorf("didn't get the result: want 1, got %d", u) - } - if spy.got.URL.String() != "https://bocchi.rocks/bocchi" { - t.Errorf("request went to the wrong place: want https://bocchi.rocks/bocchi, got %v", spy.got.URL) - } + t.Run("ok", func(t *testing.T) { + spy := &reqspy{ + respond: &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"data":1}`)), + }, + } + cl := Client{ + HTTP: &http.Client{ + Transport: spy, + }, + Token: &oauth2.Token{ + AccessToken: "bocchi", + }, + } + var u int + err := reqjson(context.Background(), cl, "GET", "https://bocchi.rocks/bocchi", nil, &u) + if err != nil { + t.Errorf("failed to request: %v", err) + } + if u != 1 { + t.Errorf("didn't get the result: want 1, got %d", u) + } + if got := spy.got.URL.String(); got != "https://bocchi.rocks/bocchi" { + t.Errorf(`request went to the wrong place: want "https://bocchi.rocks/bocchi", got %q`, got) + } + if got := spy.got.Header.Get("Authorization"); got != "Bearer bocchi" { + t.Errorf(`wrong authorization: want "Bearer bocchi", got %q`, got) + } + }) + t.Run("expired", func(t *testing.T) { + spy := &reqspy{ + respond: &http.Response{ + StatusCode: 401, + Body: io.NopCloser(strings.NewReader(`{"data":1}`)), + }, + } + cl := Client{ + HTTP: &http.Client{ + Transport: spy, + }, + Token: &oauth2.Token{ + AccessToken: "bocchi", + }, + } + var u int + err := reqjson(context.Background(), cl, "GET", "https://bocchi.rocks/bocchi", nil, &u) + if !errors.Is(err, ErrNeedRefresh) { + t.Errorf("unauthorized request didn't return ErrNeedRefresh error") + } + }) }