diff --git a/graphql/client.go b/graphql/client.go index 526395b8..852617c9 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -242,7 +242,10 @@ func (c *client) MakeRequest(ctx context.Context, req *Request, resp *Response) if err != nil { respBody = []byte(fmt.Sprintf("", err)) } - return fmt.Errorf("returned error %v: %s", httpResp.Status, respBody) + return &HTTPError{ + StatusCode: httpResp.StatusCode, + Body: string(respBody), + } } err = json.NewDecoder(httpResp.Body).Decode(resp) diff --git a/graphql/client_test.go b/graphql/client_test.go new file mode 100644 index 00000000..5eef22c3 --- /dev/null +++ b/graphql/client_test.go @@ -0,0 +1,97 @@ +package graphql + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMakeRequest_HTTPError(t *testing.T) { + testCases := []struct { + name string + serverResponseBody string + expectedErrorBody string + serverResponseCode int + expectedStatusCode int + }{ + { + name: "400 Bad Request", + serverResponseBody: "Bad Request", + expectedErrorBody: "Bad Request", + serverResponseCode: http.StatusBadRequest, + expectedStatusCode: http.StatusBadRequest, + }, + { + name: "429 Too Many Requests", + serverResponseBody: "Rate limit exceeded", + expectedErrorBody: "Rate limit exceeded", + serverResponseCode: http.StatusTooManyRequests, + expectedStatusCode: http.StatusTooManyRequests, + }, + { + name: "500 Internal Server Error", + serverResponseBody: "Internal Server Error", + expectedErrorBody: "Internal Server Error", + serverResponseCode: http.StatusInternalServerError, + expectedStatusCode: http.StatusInternalServerError, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.serverResponseCode) + _, err := w.Write([]byte(tc.serverResponseBody)) + if err != nil { + t.Fatalf("Failed to write response: %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL, server.Client()) + req := &Request{ + Query: "query { test }", + } + resp := &Response{} + + err := client.MakeRequest(context.Background(), req, resp) + + assert.Error(t, err) + var httpErr *HTTPError + assert.True(t, errors.As(err, &httpErr), "Error should be of type *HTTPError") + assert.Equal(t, tc.expectedStatusCode, httpErr.StatusCode) + assert.Equal(t, tc.expectedErrorBody, httpErr.Body) + }) + } +} + +func TestMakeRequest_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{ + "test": "success", + }, + }) + if err != nil { + t.Fatalf("Failed to encode response: %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL, server.Client()) + req := &Request{ + Query: "query { test }", + } + resp := &Response{} + + err := client.MakeRequest(context.Background(), req, resp) + + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{"test": "success"}, resp.Data) +} diff --git a/graphql/errors.go b/graphql/errors.go new file mode 100644 index 00000000..72ef3f82 --- /dev/null +++ b/graphql/errors.go @@ -0,0 +1,14 @@ +package graphql + +import "fmt" + +// HTTPError represents an HTTP error with status code and response body. +type HTTPError struct { + Body string + StatusCode int +} + +// Error implements the error interface for HTTPError. +func (e *HTTPError) Error() string { + return fmt.Sprintf("returned error %v: %s", e.StatusCode, e.Body) +} diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 301ade2b..30124724 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -7,6 +7,7 @@ package integration import ( "context" + "errors" "fmt" "net/http" "testing" @@ -14,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser/v2/gqlerror" "github.com/Khan/genqlient/graphql" "github.com/Khan/genqlient/internal/integration/server" @@ -113,6 +115,15 @@ func TestServerError(t *testing.T) { // response -- and indeed in this case it should even have another field // (which didn't err) set. assert.Error(t, err) + t.Logf("Full error: %+v", err) + var gqlErrors gqlerror.List + if !assert.True(t, errors.As(err, &gqlErrors), "Error should be of type gqlerror.List") { + t.Logf("Actual error type: %T", err) + t.Logf("Error message: %v", err) + } else { + assert.Len(t, gqlErrors, 1, "Expected one GraphQL error") + assert.Equal(t, "oh no", gqlErrors[0].Message) + } assert.NotNil(t, resp) assert.Equal(t, "1", resp.Me.Id) } @@ -130,6 +141,8 @@ func TestNetworkError(t *testing.T) { // return resp.Me.Id, err // without a bunch of extra ceremony. assert.Error(t, err) + var gqlErrors gqlerror.List + assert.False(t, errors.As(err, &gqlErrors), "Error should not be of type gqlerror.List for network errors") assert.NotNil(t, resp) assert.Equal(t, new(failingQueryResponse), resp) }