diff --git a/client/client.go b/client/client.go index 52bee6c..5be35d1 100644 --- a/client/client.go +++ b/client/client.go @@ -60,12 +60,17 @@ func (c *Client) CallChatCompletionsChat(ctx context.Context, chatReq *request.C } defer respBody.Close() - chatResp := &response.ChatCompletionsResponse{} - err = json.NewDecoder(respBody).Decode(chatResp) + body, err := io.ReadAll(respBody) if err != nil { return nil, err } + if len(body) == 0 { + return nil, errors.New("err: service unavailable") + } + + chatResp := &response.ChatCompletionsResponse{} + err = json.Unmarshal(body, chatResp) return chatResp, err } @@ -116,12 +121,17 @@ func (c *Client) CallChatCompletionsReasoner(ctx context.Context, chatReq *reque } defer respBody.Close() - chatResp := &response.ChatCompletionsResponse{} - err = json.NewDecoder(respBody).Decode(chatResp) + body, err := io.ReadAll(respBody) if err != nil { return nil, err } + if len(body) == 0 { + return nil, errors.New("err: service unavailable") + } + + chatResp := &response.ChatCompletionsResponse{} + err = json.Unmarshal(body, chatResp) return chatResp, err } diff --git a/deepseek_test/deepseek_test.go b/deepseek_test/deepseek_test.go index 9d84f12..601667c 100644 --- a/deepseek_test/deepseek_test.go +++ b/deepseek_test/deepseek_test.go @@ -84,7 +84,7 @@ func TestStreamChat(t *testing.T) { if err == io.EOF { break } - panic(err) + t.Fatal(err) } assert.NotNil(t, resp) assert.NotEmpty(t, resp.Id) @@ -93,7 +93,7 @@ func TestStreamChat(t *testing.T) { } func TestCallReasoner(t *testing.T) { - // ts := NewFakeServer("testdata/01_resp_basic_chat.json") + // ts := NewFakeServer("testdata/03_resp_basic_reasoner.json") // defer ts.Close() client, err := deepseek.NewClient(GetApiKey()) @@ -139,7 +139,7 @@ func TestStreamReasoner(t *testing.T) { if err == io.EOF { break } - panic(err) + t.Fatal(err) } assert.NotNil(t, resp) assert.NotEmpty(t, resp.Id) diff --git a/deepseek_test/testdata/04_resp_stream_reasoner.json b/deepseek_test/testdata/04_resp_stream_reasoner.json index de285c3..b7cbff3 100644 --- a/deepseek_test/testdata/04_resp_stream_reasoner.json +++ b/deepseek_test/testdata/04_resp_stream_reasoner.json @@ -3516,4 +3516,4 @@ data: {"id":"520e6457-b0f2-4203-ae18-0d651a6b76da","object":"chat.completion.chu data: {"id":"520e6457-b0f2-4203-ae18-0d651a6b76da","object":"chat.completion.chunk","created":1738144688,"model":"deepseek-reasoner","system_fingerprint":"fp_7e73fd9a08","choices":[{"index":0,"delta":{"content":"","reasoning_content":null},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":18,"completion_tokens":1758,"total_tokens":1776,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":1336},"prompt_cache_hit_tokens":0,"prompt_cache_miss_tokens":18}} -data: [DONE] \ No newline at end of file +data: [DONE] diff --git a/response/stream_reader.go b/response/stream_reader.go index 120d25f..7ff73f8 100644 --- a/response/stream_reader.go +++ b/response/stream_reader.go @@ -3,9 +3,14 @@ package response import ( "bufio" "encoding/json" + "errors" "io" ) +const KEEP_ALIVE = `: keep-alive` + +const KEEP_ALIVE_LEN = len(KEEP_ALIVE) + type StreamReader interface { Read() (*ChatCompletionsResponse, error) } @@ -45,21 +50,40 @@ func (m *streamReader) process(stream io.ReadCloser) { if len(bytes) <= 1 { continue } - bytes = trimDataPrefix(bytes) - if len(bytes) > 1 && bytes[0] == '[' { - str := string(bytes) - if str == "[DONE]" { - m.respCh <- &streamResponse{nil, io.EOF} // io.EOF to indicate end - close(m.respCh) - return - } + chatResp, err := processResponse(bytes) + if err != nil { + m.respCh <- &streamResponse{nil, err} + close(m.respCh) + return } - chatResp := &ChatCompletionsResponse{} - err = json.Unmarshal(bytes, chatResp) m.respCh <- &streamResponse{chatResp, err} } } +func processResponse(bytes []byte) (*ChatCompletionsResponse, error) { + // handle keep-alive response + if len(bytes) == KEEP_ALIVE_LEN { + if string(bytes) == KEEP_ALIVE { + err := errors.New("err: service unavailable") + return nil, err + } + } + + // handle response end + bytes = trimDataPrefix(bytes) + if len(bytes) > 1 && bytes[0] == '[' { + str := string(bytes) + if str == "[DONE]" { + return nil, io.EOF // io.EOF to indicate end + } + } + + // parse response + chatResp := &ChatCompletionsResponse{} + err := json.Unmarshal(bytes, chatResp) + return chatResp, err +} + func trimDataPrefix(content []byte) []byte { trimIndex := 6 if len(content) > trimIndex { diff --git a/response/stream_reader_unit_test.go b/response/stream_reader_unit_test.go new file mode 100644 index 0000000..6b544c6 --- /dev/null +++ b/response/stream_reader_unit_test.go @@ -0,0 +1,49 @@ +package response + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProcessResponse(t *testing.T) { + t.Run("response keep-alive return error", func(t *testing.T) { + respBody := []byte(KEEP_ALIVE) + _, err := processResponse(respBody) + assert.Error(t, err) + }) + + t.Run("response done return error", func(t *testing.T) { + respBody := []byte(`data: [DONE]`) + _, err := processResponse(respBody) + assert.Error(t, err) + assert.Equal(t, err, io.EOF) + }) + + t.Run("response json return chat response", func(t *testing.T) { + respBody := []byte(`data: {"id":"aceb72f7-ffab-422a-b498-62c9b4034f84","object":"chat.completion.chunk","created":1738119601,"model":"deepseek-chat","system_fingerprint":"fp_3a5770e1b4","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}`) + chatResp, err := processResponse(respBody) + assert.NoError(t, err) + assert.NotNil(t, chatResp) + assert.Equal(t, "aceb72f7-ffab-422a-b498-62c9b4034f84", chatResp.Id) + }) +} + +func TestTrimDataPrefix(t *testing.T) { + t.Run("data prefix trimmed from json response", func(t *testing.T) { + dataPrefix := `data: ` + jsonResp := `{"id":"aceb72f7-ffab-422a-b498-62c9b4034f84","object":"chat.completion.chunk","created":1738119601,"model":"deepseek-chat","system_fingerprint":"fp_3a5770e1b4","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}` + respBody := []byte(dataPrefix + jsonResp) + gotBody := trimDataPrefix(respBody) + assert.Equal(t, jsonResp, string(gotBody)) + }) + + t.Run("data prefix not trimmed from done response", func(t *testing.T) { + dataPrefix := `data: ` + doneResp := `[DONE]` + respBody := []byte(dataPrefix + doneResp) + gotBody := trimDataPrefix(respBody) + assert.Equal(t, doneResp, string(gotBody)) + }) +}