diff --git a/client.go b/client.go index e160ce8..1a37553 100644 --- a/client.go +++ b/client.go @@ -57,7 +57,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, body io.Rea req.Header.Set("Notion-Version", apiVersion) req.Header.Set("User-Agent", "go-notion/"+clientVersion) - if method == http.MethodPost || method == http.MethodPatch { + if body != nil { req.Header.Set("Content-Type", "application/json") } @@ -92,12 +92,14 @@ func (c *Client) FindDatabaseByID(ctx context.Context, id string) (db Database, // QueryDatabase returns database contents, with optional filters, sorts and pagination. // See: https://developers.notion.com/reference/post-database-query -func (c *Client) QueryDatabase(ctx context.Context, id string, query DatabaseQuery) (result DatabaseQueryResponse, err error) { +func (c *Client) QueryDatabase(ctx context.Context, id string, query *DatabaseQuery) (result DatabaseQueryResponse, err error) { body := &bytes.Buffer{} - err = json.NewEncoder(body).Encode(query) - if err != nil { - return DatabaseQueryResponse{}, fmt.Errorf("notion: failed to encode filter to JSON: %w", err) + if query != nil { + err = json.NewEncoder(body).Encode(query) + if err != nil { + return DatabaseQueryResponse{}, fmt.Errorf("notion: failed to encode filter to JSON: %w", err) + } } req, err := c.newRequest(ctx, http.MethodPost, fmt.Sprintf("/databases/%v/query", id), body) @@ -112,7 +114,7 @@ func (c *Client) QueryDatabase(ctx context.Context, id string, query DatabaseQue defer res.Body.Close() if res.StatusCode != http.StatusOK { - return DatabaseQueryResponse{}, fmt.Errorf("notion: failed to find database: %w", parseErrorResponse(res)) + return DatabaseQueryResponse{}, fmt.Errorf("notion: failed to query database: %w", parseErrorResponse(res)) } err = json.NewDecoder(res.Body).Decode(&result) diff --git a/client_test.go b/client_test.go index ccdef88..786f1f2 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package notion_test import ( "context" + "encoding/json" "errors" "io" "io/ioutil" @@ -388,3 +389,251 @@ func TestFindDatabaseByID(t *testing.T) { }) } } + +func TestQueryDatabase(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query *notion.DatabaseQuery + respBody func(r *http.Request) io.Reader + respStatusCode int + expPostBody map[string]interface{} + expResponse notion.DatabaseQueryResponse + expError error + }{ + { + name: "with query, successful response", + query: ¬ion.DatabaseQuery{ + Filter: notion.DatabaseQueryFilter{ + Property: "Name", + Text: ¬ion.TextDatabaseQueryFilter{ + Contains: "foobar", + }, + }, + Sorts: []notion.DatabaseQuerySort{ + { + Property: "Name", + Timestamp: notion.SortTimeStampCreatedTime, + Direction: notion.SortDirAsc, + }, + { + Property: "Date", + Timestamp: notion.SortTimeStampLastEditedTime, + Direction: notion.SortDirDesc, + }, + }, + }, + respBody: func(_ *http.Request) io.Reader { + return strings.NewReader( + `{ + "object": "list", + "results": [ + { + "object": "page", + "id": "7c6b1c95-de50-45ca-94e6-af1d9fd295ab", + "created_time": "2021-05-18T17:50:22.371Z", + "last_edited_time": "2021-05-18T17:50:22.371Z", + "parent": { + "type": "database_id", + "database_id": "39ddfc9d-33c9-404c-89cf-79f01c42dd0c" + }, + "archived": false, + "properties": { + "Date": { + "id": "Q]uT", + "type": "date", + "date": { + "start": "2021-05-18T12:49:00.000-05:00", + "end": null + } + }, + "Name": { + "id": "title", + "type": "title", + "title": [ + { + "type": "text", + "text": { + "content": "Foobar", + "link": null + }, + "annotations": { + "bold": false, + "italic": false, + "strikethrough": false, + "underline": false, + "code": false, + "color": "default" + }, + "plain_text": "Foobar", + "href": null + } + ] + } + } + } + ], + "next_cursor": "A^hd", + "has_more": true + }`, + ) + }, + respStatusCode: http.StatusOK, + expPostBody: map[string]interface{}{ + "filter": map[string]interface{}{ + "property": "Name", + "text": map[string]interface{}{ + "contains": "foobar", + }, + }, + "sorts": []interface{}{ + map[string]interface{}{ + "property": "Name", + "timestamp": "created_time", + "direction": "ascending", + }, + map[string]interface{}{ + "property": "Date", + "timestamp": "last_edited_time", + "direction": "descending", + }, + }, + }, + expResponse: notion.DatabaseQueryResponse{ + Results: []notion.Page{ + { + ID: "7c6b1c95-de50-45ca-94e6-af1d9fd295ab", + CreatedTime: mustParseTime(time.RFC3339Nano, "2021-05-18T17:50:22.371Z"), + LastEditedTime: mustParseTime(time.RFC3339Nano, "2021-05-18T17:50:22.371Z"), + Parent: notion.PageParent{ + Type: notion.ParentTypeDatabase, + DatabaseID: notion.StringPtr("39ddfc9d-33c9-404c-89cf-79f01c42dd0c"), + }, + Archived: false, + Properties: notion.DatabasePageProperties{ + "Date": notion.DatabasePageProperty{ + ID: "Q]uT", + Type: notion.DBPropTypeDate, + Date: ¬ion.Date{ + Start: mustParseTime(time.RFC3339Nano, "2021-05-18T12:49:00.000-05:00"), + }, + }, + "Name": notion.DatabasePageProperty{ + ID: "title", + Type: notion.DBPropTypeTitle, + Title: []notion.RichText{ + { + Type: notion.RichTextTypeText, + Text: ¬ion.Text{ + Content: "Foobar", + }, + PlainText: "Foobar", + Annotations: ¬ion.Annotations{ + Color: notion.ColorDefault, + }, + }, + }, + }, + }, + }, + }, + HasMore: true, + NextCursor: notion.StringPtr("A^hd"), + }, + expError: nil, + }, + { + name: "without query, successful response", + query: nil, + respBody: func(_ *http.Request) io.Reader { + return strings.NewReader( + `{ + "object": "list", + "results": [], + "next_cursor": null, + "has_more": false + }`, + ) + }, + respStatusCode: http.StatusOK, + expPostBody: nil, + expResponse: notion.DatabaseQueryResponse{ + Results: []notion.Page{}, + HasMore: false, + NextCursor: nil, + }, + expError: nil, + }, + { + name: "error response", + respBody: func(_ *http.Request) io.Reader { + return strings.NewReader( + `{ + "object": "error", + "status": 400, + "code": "validation_error", + "message": "foobar" + }`, + ) + }, + respStatusCode: http.StatusBadRequest, + expResponse: notion.DatabaseQueryResponse{}, + expError: errors.New("notion: failed to query database: foobar (code: validation_error, status: 400)"), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + httpClient := &http.Client{ + Transport: &mockRoundtripper{fn: func(r *http.Request) (*http.Response, error) { + postBody := make(map[string]interface{}) + + err := json.NewDecoder(r.Body).Decode(&postBody) + if err != nil && err != io.EOF { + t.Fatal(err) + } + + if len(tt.expPostBody) == 0 && len(postBody) != 0 { + t.Errorf("unexpected post body: %+v", postBody) + } + + if len(tt.expPostBody) != 0 && len(postBody) == 0 { + t.Errorf("post body not equal (expected %+v, got: nil)", tt.expPostBody) + } + + if len(tt.expPostBody) != 0 && len(postBody) != 0 { + if diff := cmp.Diff(tt.expPostBody, postBody); diff != "" { + t.Errorf("post body not equal (-exp, +got):\n%v", diff) + } + } + + return &http.Response{ + StatusCode: tt.respStatusCode, + Status: http.StatusText(tt.respStatusCode), + Body: ioutil.NopCloser(tt.respBody(r)), + }, nil + }}, + } + client := notion.NewClient("secret-api-key", notion.WithHTTPClient(httpClient)) + resp, err := client.QueryDatabase(context.Background(), "00000000-0000-0000-0000-000000000000", tt.query) + + if tt.expError == nil && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.expError != nil && err == nil { + t.Fatalf("error not equal (expected: %v, got: nil)", tt.expError) + } + if tt.expError != nil && err != nil && tt.expError.Error() != err.Error() { + t.Fatalf("error not equal (expected: %v, got: %v)", tt.expError, err) + } + + if diff := cmp.Diff(tt.expResponse, resp); diff != "" { + t.Fatalf("response not equal (-exp, +got):\n%v", diff) + } + }) + } +}