Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multipart/form-data transport to support file uploads #268

Merged
merged 21 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 147 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ import (
"fmt"
"io"
"math"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"os"
"strings"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/prometheus/client_golang/prometheus"
"github.com/vektah/gqlparser/v2/ast"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -131,11 +135,11 @@ func (c *GraphQLClient) Request(ctx context.Context, url string, request *Reques
return err
}

var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(request); err != nil {
buf, contentType, err := request.requestBody()
if err != nil {
return traceErr(fmt.Errorf("unable to encode request body: %w", err))
}

}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf)
if err != nil {
return traceErr(fmt.Errorf("unable to create request: %w", err))
Expand All @@ -145,7 +149,7 @@ func (c *GraphQLClient) Request(ctx context.Context, url string, request *Reques
httpReq.Header = request.Headers.Clone()
}

httpReq.Header.Set("Content-Type", "application/json; charset=utf-8")
httpReq.Header.Set("Content-Type", contentType)
httpReq.Header.Set("Accept", "application/json")

if c.UserAgent != "" {
Expand Down Expand Up @@ -246,6 +250,88 @@ func (r *Request) WithVariables(variables map[string]interface{}) *Request {
return r
}

// isMultipart returns true if the request contains a graphql.Upload object
// implying that the downstream request needs to be a multipart/form-data request
func (r *Request) isMultipart() bool {
stack := []map[string]any{r.Variables}
for len(stack) > 0 {
currentItem := stack[len(stack)-1]
stack = stack[:len(stack)-1]
for _, v := range currentItem {
switch v := v.(type) {
case graphql.Upload, *graphql.Upload, []graphql.Upload, []*graphql.Upload:
return true
case map[string]any:
stack = append(stack, v)
}
}
}
return false
}

func (r *Request) requestBody() (bytes.Buffer, string, error) {
var buf bytes.Buffer
var err error
contentType := "application/json; charset=utf-8"
if r.isMultipart() {
buf, contentType, err = multipartBody(r)
if err != nil {
return buf, "", fmt.Errorf("unable to encode multipart request body: %w", err)
}
return buf, contentType, nil
}
if err = json.NewEncoder(&buf).Encode(r); err != nil {
return buf, "", fmt.Errorf("unable to encode request body: %w", err)
}
return buf, contentType, nil
}

func multipartBody(r *Request) (bytes.Buffer, string, error) {
files, fileMap := prepareUploadsFromVariables(r.Variables)

var buf bytes.Buffer
mpw := multipart.NewWriter(&buf)
fw, err := mpw.CreateFormField("operations")
if err != nil {
return buf, "", err
}
if err = json.NewEncoder(fw).Encode(r); err != nil {
return buf, "", err
}
fw, err = mpw.CreateFormField("map")
if err != nil {
return buf, "", err
}
if err = json.NewEncoder(fw).Encode(fileMap); err != nil {
return buf, "", err
}
for fileIndex := range fileMap {
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
"name": fileIndex,
"filename": files[fileIndex].Filename,
}))
if ct := files[fileIndex].ContentType; ct != "" {
h.Set("Content-Type", files[fileIndex].ContentType)
} else {
h.Set("Content-Type", "application/octet-stream")
}
innerFw, fileErr := mpw.CreatePart(h)
if fileErr != nil {
return buf, "", fileErr
}
_, ioErr := io.Copy(innerFw, files[fileIndex].File)
if ioErr != nil {
return buf, "", ioErr
}
}
err = mpw.Close()
if err != nil {
return buf, "", err
}
return buf, mpw.FormDataContentType(), nil
}

// Response is a GraphQL response
type Response struct {
Errors GraphqlErrors `json:"errors"`
Expand Down Expand Up @@ -275,3 +361,60 @@ func (e GraphqlErrors) Error() string {
func GenerateUserAgent(operation string) string {
return fmt.Sprintf("Bramble/%s (%s)", Version, operation)
}

func prepareUploadsFromVariables(variables map[string]any) (map[string]graphql.Upload, map[string][]string) {
type stackItem struct {
path string
data map[string]interface{}
}

stack := []stackItem{{path: "variables", data: variables}}

index := 0
fileMap := map[string][]string{}
files := map[string]graphql.Upload{}
for len(stack) > 0 {
currentItem := stack[len(stack)-1]
stack = stack[:len(stack)-1]

for key, value := range currentItem.data {
currentPath := currentItem.path + "." + key

switch v := value.(type) {
case graphql.Upload, *graphql.Upload:
currentItem.data[key] = nil
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{currentPath}
index += 1
switch v := v.(type) {
case graphql.Upload:
files[fileIndex] = v
case *graphql.Upload:
files[fileIndex] = *v
}
case []graphql.Upload:
currentItem.data[key] = make([]*struct{}, len(v))
for i, file := range v {
elemPath := fmt.Sprintf("%s.%d", currentPath, i)
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{elemPath}
index += 1
files[fileIndex] = file
}
case []*graphql.Upload:
currentItem.data[key] = make([]*struct{}, len(v))
for i, file := range v {
elemPath := fmt.Sprintf("%s.%d", currentPath, i)
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{elemPath}
index += 1
files[fileIndex] = *file
}
case map[string]any:
stack = append(stack, stackItem{data: v, path: currentPath})
default:
}
}
}
return files, fileMap
}
91 changes: 91 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"testing"

"github.com/99designs/gqlgen/graphql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -102,3 +103,93 @@ func TestGraphqlClient(t *testing.T) {
assert.Equal(t, "response exceeded maximum size of 1 bytes", err.Error())
})
}
func TestMultipartClient(t *testing.T) {
nestedMap := map[string]any{
"node1": map[string]any{
"node11": map[string]any{
"leaf111": graphql.Upload{},
"leaf112": "someThing",
"node113": map[string]any{"leaf1131": graphql.Upload{}},
},
"leaf12": 42,
"leaf13": graphql.Upload{},
},
"node2": map[string]any{
"leaf21": false,
"node21": map[string]any{
"leaf211": &graphql.Upload{},
},
},
"node3": graphql.Upload{},
"node4": []graphql.Upload{{}, {}},
"node5": []*graphql.Upload{{}, {}},
}

t.Run("parseMultipartVariables", func(t *testing.T) {
_, fileMap := prepareUploadsFromVariables(nestedMap)
fileMapKeys := []string{}
fileMapValues := []string{}
for k, v := range fileMap {
fileMapKeys = append(fileMapKeys, k)
fileMapValues = append(fileMapValues, v...)
}
assert.ElementsMatch(t, fileMapKeys, []string{"file0", "file1", "file2", "file3", "file4", "file5", "file6", "file7", "file8"})
assert.ElementsMatch(t, fileMapValues, []string{
"variables.node1.node11.node113.leaf1131",
"variables.node1.node11.leaf111",
"variables.node1.leaf13",
"variables.node2.node21.leaf211",
"variables.node3",
"variables.node4.0",
"variables.node4.1",
"variables.node5.0",
"variables.node5.1",
})
assert.Equal(
t,
map[string]any{
"node1": map[string]any{
"node11": map[string]any{
"leaf111": nil,
"leaf112": "someThing",
"node113": map[string]any{"leaf1131": nil},
},
"leaf12": 42,
"leaf13": nil,
},
"node2": map[string]any{
"leaf21": false,
"node21": map[string]any{
"leaf211": nil,
},
},
"node3": nil,
"node4": []*struct{}{nil, nil},
"node5": []*struct{}{nil, nil},
},
nestedMap,
)
})

t.Run("multipart request", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{ "data": {"root": "multipart response"} }`))
}))

c := NewClient()
req := &Request{Headers: make(http.Header)}
req.Headers.Set("Content-Type", "multipart/form-data")

var res struct {
Root string
}
err := c.Request(
context.Background(),
srv.URL,
req,
&res,
)
require.NoError(t, err)
assert.Equal(t, "multipart response", res.Root)
})
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type Config struct {
PollIntervalDuration time.Duration
MaxRequestsPerQuery int64 `json:"max-requests-per-query"`
MaxServiceResponseSize int64 `json:"max-service-response-size"`
MaxFileUploadSize int64 `json:"max-file-upload-size"`
Telemetry TelemetryConfig `json:"telemetry"`
Plugins []PluginConfig
// Config extensions that can be shared among plugins
Expand Down
12 changes: 11 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ services:
retries: 5
expose:
- 8080
gqlgen-multipart-file-upload-service:
build:
context: examples/gqlgen-multipart-file-upload-service
healthcheck: &healthcheck
test: wget -qO - http://localhost:8080/health
interval: 5s
timeout: 1s
retries: 5
expose:
- 8080
graph-gophers-service:
healthcheck: *healthcheck
build:
Expand All @@ -34,7 +44,7 @@ services:
configs: [gateway]
command: ["-config", "gateway"]
environment:
- BRAMBLE_SERVICE_LIST=http://gqlgen-service:8080/query http://graph-gophers-service:8080/query http://slow-service:8080/query http://nodejs-service:8080/query
- BRAMBLE_SERVICE_LIST=http://gqlgen-service:8080/query http://gqlgen-multipart-file-upload-service:8080/query http://graph-gophers-service:8080/query http://slow-service:8080/query http://nodejs-service:8080/query
ports:
- 8082:8082
- 8083:8083
Expand Down
13 changes: 13 additions & 0 deletions docs/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ Add `CORS` headers to queries.
}
```

## Headers

Allow headers to passthrough to downstream services.

```json
{
"name": "headers",
"config": {
"allowed-headers": ["X-Custom-Header"]
}
}
```

## JWT Auth

The JWT auth plugin validates that the request contains a valid JWT and
Expand Down
3 changes: 3 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gqlgen-service
generated.go
models_gen.go
11 changes: 11 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM golang:1.22-alpine3.19

ENV CGO_ENABLED=0

WORKDIR /go/src/app

COPY . .

RUN go generate .
RUN go get
CMD go run .
18 changes: 18 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
ARTIFACT=gqlgen-service
DEF=gqlgen.yml schema.graphql
GEN=models_gen.go generated.go

build: $(ARTIFACT)

.PHONY: clean
clean:
rm -f $(ARTIFACT) $(GEN)

.PHONY: generate
generate: $(GEN)

$(GEN): $(DEF)
go generate

gqlgen-service: $(GEN) $(wildcard *.go)
go build
Loading
Loading