Skip to content

Commit

Permalink
fix moonshot usage compatible problem (#1568)
Browse files Browse the repository at this point in the history
  • Loading branch information
rinfx authored Dec 5, 2024
1 parent 7ce6d7a commit 22790aa
Showing 1 changed file with 100 additions and 1 deletion.
101 changes: 100 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"net/http"
"github.com/tidwall/sjson"
)

// moonshotProvider is the provider for Moonshot AI service.
Expand Down Expand Up @@ -149,3 +152,99 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
return errors.New("unsupported method: " + method)
}
}

func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...)
}

eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1

defer func() {
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
// Just in case the received chunk is not a complete event.
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
} else {
ctx.SetContext(ctxKeyStreamingBody, nil)
}
}()

var responseBuilder strings.Builder
currentKey := ""
currentEvent := &streamEvent{}
i, length := 0, len(receivedBody)
for i = 0; i < length; i++ {
ch := receivedBody[i]
if ch != '\n' {
if lineStartIndex == -1 {
if eventStartIndex == -1 {
eventStartIndex = i
}
lineStartIndex = i
valueStartIndex = -1
}
if valueStartIndex == -1 {
if ch == ':' {
valueStartIndex = i + 1
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
}
} else if valueStartIndex == i && ch == ' ' {
// Skip leading spaces in data.
valueStartIndex = i + 1
}
continue
}

if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
m.convertStreamEvent(&responseBuilder, currentEvent, log)
// Reset event parsing state.
eventStartIndex = -1
currentEvent = &streamEvent{}
}

// Reset line parsing state.
lineStartIndex = -1
valueStartIndex = -1
currentKey = ""
}

modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}

func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
}

if gjson.Get(event.Data, "choices.0.usage").Exists() {
usageStr := gjson.Get(event.Data, "choices.0.usage").Raw
newData, err := sjson.Delete(event.Data, "choices.0.usage")
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
}
newData, err = sjson.SetRaw(newData, "usage", usageStr)
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
}
event.Data = newData
}
m.appendStreamEvent(responseBuilder, event)
return nil
}

func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")
}

0 comments on commit 22790aa

Please sign in to comment.