Skip to content

Commit

Permalink
test: Distinct stream content type on runtime.Marshaler.
Browse files Browse the repository at this point in the history
  • Loading branch information
huin committed Nov 8, 2024
1 parent a400f8f commit d232cba
Showing 1 changed file with 46 additions and 15 deletions.
61 changes: 46 additions & 15 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,36 +178,68 @@ func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c
func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) }
func (c *CustomMarshaler) ContentType(v interface{}) string { return "Custom-Content-Type" }

// MarshalerStreamContentType implements Marshaler, but with the addition of a custom StreamContentType.
type MarshalerStreamContentType struct {
runtime.Marshaler
CustomStreamContentType string
}

func (m MarshalerStreamContentType) StreamContentType(interface{}) string {
return m.CustomStreamContentType
}

func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
type msg struct {
pb proto.Message
err error
}
marshaler := &CustomMarshaler{&runtime.JSONPb{}}

tests := []struct {
name string
msgs []msg
statusCode int
name string
marshaler runtime.Marshaler
msgs []msg
statusCode int
wantContentType string
}{{
name: "encoding",
name: "encoding",
marshaler: marshaler,
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
{&pb.SimpleMessage{Id: "Two"}, nil},
},
statusCode: http.StatusOK,
statusCode: http.StatusOK,
wantContentType: "Custom-Content-Type",
}, {
name: "empty",
marshaler: marshaler,
statusCode: http.StatusOK,
}, {
name: "error",
msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
statusCode: http.StatusBadRequest,
name: "error",
marshaler: marshaler,
msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
statusCode: http.StatusBadRequest,
wantContentType: "Custom-Content-Type",
}, {
name: "stream_error",
name: "stream_error",
marshaler: marshaler,
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
{nil, status.Errorf(codes.OutOfRange, "400")},
},
statusCode: http.StatusOK,
statusCode: http.StatusOK,
wantContentType: "Custom-Content-Type",
}, {
name: "stream_content_type",
marshaler: MarshalerStreamContentType{
Marshaler: marshaler,
CustomStreamContentType: "Stream-Content-Type",
},
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
},
statusCode: http.StatusOK,
wantContentType: "Stream-Content-Type",
}}

newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
Expand All @@ -224,14 +256,13 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
}
}
ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
marshaler := &CustomMarshaler{&runtime.JSONPb{}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recv := newTestRecv(t, tt.msgs)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, recv)

w := resp.Result()
if w.StatusCode != tt.statusCode {
Expand All @@ -245,16 +276,16 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
t.Errorf("Failed to read response body with %v", err)
}
w.Body.Close()
if len(body) > 0 && w.Header.Get("Content-Type") != "Custom-Content-Type" {
t.Errorf("Content-Type %s want Custom-Content-Type", w.Header.Get("Content-Type"))
if w.Header.Get("Content-Type") != tt.wantContentType {
t.Errorf("Content-Type %q want %q", w.Header.Get("Content-Type"), tt.wantContentType)
}

var want []byte
for _, msg := range tt.msgs {
if msg.err != nil {
t.Skip("checking error encodings")
}
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
b, err := tt.marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
if err != nil {
t.Errorf("marshaler.Marshal() failed %v", err)
}
Expand Down

0 comments on commit d232cba

Please sign in to comment.