Skip to content

Commit

Permalink
Merge pull request #32843 from vespa-engine/arnej/add-query-postfile
Browse files Browse the repository at this point in the history
add the --postFile option to "vespa query"
  • Loading branch information
arnej27959 authored Nov 13, 2024
2 parents 3b06700 + e207edd commit ece9a08
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 31 deletions.
122 changes: 91 additions & 31 deletions client/go/internal/cli/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
package cmd

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"

Expand All @@ -22,14 +24,17 @@ import (
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)

type queryOptions struct {
printCurl bool
queryTimeoutSecs int
waitSecs int
format string
postFile string
headers []string
}

func newQueryCmd(cli *CLI) *cobra.Command {
var (
printCurl bool
queryTimeoutSecs int
waitSecs int
format string
headers []string
)
opts := queryOptions{}
cmd := &cobra.Command{
Use: "query query-parameters",
Short: "Issue a query to Vespa",
Expand All @@ -43,32 +48,45 @@ can be set by the syntax [parameter-name]=[value].`,
// TODO: Support referencing a query json file
DisableAutoGenTag: true,
SilenceUsage: true,
Args: cobra.MinimumNArgs(1),
Args: cobra.MinimumNArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
waiter := cli.waiter(time.Duration(waitSecs)*time.Second, cmd)
return query(cli, args, queryTimeoutSecs, printCurl, format, headers, waiter)
if len(args) == 0 && opts.postFile == "" {
return fmt.Errorf("requires at least 1 arg")
}
waiter := cli.waiter(time.Duration(opts.waitSecs)*time.Second, cmd)
return query(cli, args, &opts, waiter)
},
}
cmd.Flags().BoolVarP(&printCurl, "verbose", "v", false, "Print the equivalent curl command for the query")
cmd.Flags().StringVarP(&format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)")
cmd.Flags().StringSliceVarP(&headers, "header", "", nil, "Add a header to the HTTP request, on the format 'Header: Value'. This can be specified multiple times")
cmd.Flags().IntVarP(&queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds")
cli.bindWaitFlag(cmd, 0, &waitSecs)
cmd.Flags().BoolVarP(&opts.printCurl, "verbose", "v", false, "Print the equivalent curl command for the query")
cmd.Flags().StringVarP(&opts.postFile, "file", "", "", "Read query parameters from the given JSON file and send a POST request, with overrides from arguments")
cmd.Flags().StringVarP(&opts.format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)")
cmd.Flags().StringSliceVarP(&opts.headers, "header", "", nil, "Add a header to the HTTP request, on the format 'Header: Value'. This can be specified multiple times")
cmd.Flags().IntVarP(&opts.queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds")
cli.bindWaitFlag(cmd, 0, &opts.waitSecs)
return cmd
}

func printCurl(stderr io.Writer, url string, service *vespa.Service) error {
cmd, err := curl.RawArgs(url)
func printCurl(stderr io.Writer, req *http.Request, postFile string, service *vespa.Service) error {
cmd, err := curl.RawArgs(req.URL.String())
if err != nil {
return err
}
cmd.Method = req.Method
if postFile != "" {
cmd.WithBodyFile(postFile)
}
for k, vl := range req.Header {
for _, v := range vl {
cmd.Header(k, v)
}
}
cmd.Certificate = service.TLSOptions.CertificateFile
cmd.PrivateKey = service.TLSOptions.PrivateKeyFile
_, err = io.WriteString(stderr, cmd.String()+"\n")
return err
}

func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format string, headers []string, waiter *Waiter) error {
func query(cli *CLI, arguments []string, opts *queryOptions, waiter *Waiter) error {
target, err := cli.target(targetOptions{})
if err != nil {
return err
Expand All @@ -77,12 +95,12 @@ func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format stri
if err != nil {
return err
}
switch format {
switch opts.format {
case "plain", "human":
default:
return fmt.Errorf("invalid format: %s", format)
return fmt.Errorf("invalid format: %s", opts.format)
}
url, _ := url.Parse(service.BaseURL + "/search/")
url, _ := url.Parse(strings.TrimSuffix(service.BaseURL, "/") + "/search/")
urlQuery := url.Query()
for i := range len(arguments) {
key, value := splitArg(arguments[i])
Expand All @@ -91,31 +109,44 @@ func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format stri
queryTimeout := urlQuery.Get("timeout")
if queryTimeout == "" {
// No timeout set by user, use the timeout option
queryTimeout = fmt.Sprintf("%ds", timeoutSecs)
queryTimeout = fmt.Sprintf("%ds", opts.queryTimeoutSecs)
urlQuery.Set("timeout", queryTimeout)
}
url.RawQuery = urlQuery.Encode()
deadline, err := time.ParseDuration(queryTimeout)
if err != nil {
return fmt.Errorf("invalid query timeout: %w", err)
}
if curl {
if err := printCurl(cli.Stderr, url.String(), service); err != nil {
return err
}
}
header, err := httputil.ParseHeader(headers)
header, err := httputil.ParseHeader(opts.headers)
if err != nil {
return err
}
response, err := service.Do(&http.Request{Header: header, URL: url}, deadline+time.Second) // Slightly longer than query timeout
hReq := &http.Request{Header: header, URL: url}
if opts.postFile != "" {
json, err := getJsonFrom(opts.postFile, urlQuery)
if err != nil {
return fmt.Errorf("bad JSON in postFile '%s': %w", opts.postFile, err)
}
header.Set("Content-Type", "application/json")
hReq.Method = "POST"
hReq.Body = io.NopCloser(bytes.NewBuffer(bytes.Clone(json)))
if err != nil {
return fmt.Errorf("bad postFile '%s': %w", opts.postFile, err)
}
}
url.RawQuery = urlQuery.Encode()
if opts.printCurl {
if err := printCurl(cli.Stderr, hReq, opts.postFile, service); err != nil {
return err
}
}
response, err := service.Do(hReq, deadline+time.Second) // Slightly longer than query timeout
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer response.Body.Close()

if response.StatusCode == 200 {
if err := printResponse(response.Body, response.Header.Get("Content-Type"), format, cli); err != nil {
if err := printResponse(response.Body, response.Header.Get("Content-Type"), opts.format, cli); err != nil {
return err
}
} else if response.StatusCode/100 == 4 {
Expand Down Expand Up @@ -207,3 +238,32 @@ func splitArg(argument string) (string, string) {
}
return parts[0], parts[1]
}

func getJsonFrom(fn string, query url.Values) ([]byte, error) {
parsed := make(map[string]any)
f, err := os.Open(fn)
if err != nil {
return nil, err
}
body, err := io.ReadAll(f)
if err != nil {
return nil, err
}
err = json.Unmarshal(body, &parsed)
if err != nil {
return nil, err
}
for k, vl := range query {
if len(vl) == 1 {
parsed[k] = vl[0]
} else {
parsed[k] = vl
}
query.Del(k)
}
b, err := json.Marshal(parsed)
if err != nil {
return nil, err
}
return b, nil
}
48 changes: 48 additions & 0 deletions client/go/internal/cli/cmd/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ package cmd

import (
"net/http"
"os"
"path/filepath"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vespa-engine/vespa/client/go/internal/mock"
)

Expand Down Expand Up @@ -134,6 +137,51 @@ data: {
assertStreamingQuery(t, bodyWithError, bodyWithError, "--format=plain")
}

func TestQueryPostFile(t *testing.T) {
mockResponse := `{"query":"result"}"`
client := &mock.HTTPClient{ReadBody: true}
client.NextResponseString(200, mockResponse)
cli, stdout, _ := newTestCLI(t)
cli.httpClient = client

tmpFileName := filepath.Join(t.TempDir(), "tq1.json")
jsonQuery := []byte(`{"yql": "some yql here"}`)
require.Nil(t, os.WriteFile(tmpFileName, jsonQuery, 0644))

assert.Nil(t, cli.Run("-t", "http://127.0.0.1:8080", "query", "--file", tmpFileName))
assert.Equal(t, mockResponse+"\n", stdout.String())
assert.Equal(t, `{"timeout":"10s","yql":"some yql here"}`, string(client.LastBody))
assert.Equal(t, []string{"application/json"}, client.LastRequest.Header.Values("Content-Type"))
assert.Equal(t, "POST", client.LastRequest.Method)
assert.Equal(t, "http://127.0.0.1:8080/search/", client.LastRequest.URL.String())
}

func TestQueryPostFileWithArgs(t *testing.T) {
mockResponse := `{"query":"result"}"`
client := &mock.HTTPClient{ReadBody: true}
client.NextResponseString(200, mockResponse)
cli, _, _ := newTestCLI(t)
cli.httpClient = client

tmpFileName := filepath.Join(t.TempDir(), "tq2.json")
jsonQuery := []byte(`{"yql": "some yql here"}`)
require.Nil(t, os.WriteFile(tmpFileName, jsonQuery, 0644))

assert.Nil(t, cli.Run(
"-t", "http://foo.bar:1234/",
"query",
"--file", tmpFileName,
"yql=foo bar",
"tracelevel=3",
"dispatch.docsumRetryLimit=42"))
assert.Equal(t,
`{"dispatch.docsumRetryLimit":"42","timeout":"10s","tracelevel":"3","yql":"foo bar"}`,
string(client.LastBody))
assert.Equal(t, []string{"application/json"}, client.LastRequest.Header.Values("Content-Type"))
assert.Equal(t, "POST", client.LastRequest.Method)
assert.Equal(t, "http://foo.bar:1234/search/", client.LastRequest.URL.String())
}

func assertStreamingQuery(t *testing.T, expectedOutput, body string, args ...string) {
t.Helper()
client := &mock.HTTPClient{}
Expand Down

0 comments on commit ece9a08

Please sign in to comment.