diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 983dfbb25a..7fbd4954e2 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -101,55 +101,58 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun } func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break + content := "" + for _, chunk := range strings.Split(sseMessage, "\n\n") { + log.Infof("chunk _ : %s", chunk) + subMessages := strings.Split(chunk, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message) } - } - if len(message) < 6 { - return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) - } - // skip the prefix "data:" - bodyJson := message[5:] + // skip the prefix "data:" + bodyJson := message[5:] - if strings.TrimSpace(bodyJson) == "[DONE]" { - return "", nil - } + if strings.TrimSpace(bodyJson) == "[DONE]" { + return content, nil + } - // Extract values from JSON fields - responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) - toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) + // Extract values from JSON fields + responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) - if toolCalls.Exists() { - // TODO: Temporarily store the tool_calls value in the context for processing - ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) - } - - // Check if the ResponseBody field exists - if !responseBody.Exists() { - if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { - log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) - return "", nil + if toolCalls.Exists() { + // TODO: Temporarily store the tool_calls value in the context for processing + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) } - return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) - } else { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - // If there is no content in the cache, initialize and set the content - if tempContentI == nil { - content := responseBody.String() - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content, nil - } + // Check if the ResponseBody field exists + if !responseBody.Exists() { + if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { + log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) + return content, nil + } + return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) + } else { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - // Update the content in the cache - appendMsg := responseBody.String() - content := tempContentI.(string) + appendMsg - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content, nil + // If there is no content in the cache, initialize and set the content + if tempContentI == nil { + content = responseBody.String() + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + } else { + // Update the content in the cache + appendMsg := responseBody.String() + content = tempContentI.(string) + appendMsg + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + } + } } + return content, nil } diff --git a/plugins/wasm-go/extensions/ai-history/go.sum b/plugins/wasm-go/extensions/ai-history/go.sum index 6b1c2c3cd7..b4ab172fe2 100644 --- a/plugins/wasm-go/extensions/ai-history/go.sum +++ b/plugins/wasm-go/extensions/ai-history/go.sum @@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/plugins/wasm-go/extensions/ai-history/main.go b/plugins/wasm-go/extensions/ai-history/main.go index 512e13f1c6..3f728dd96d 100644 --- a/plugins/wasm-go/extensions/ai-history/main.go +++ b/plugins/wasm-go/extensions/ai-history/main.go @@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte ctx.SetContext(StreamContextKey, struct{}{}) } identityKey := ctx.GetStringContext(IdentityKey, "") + question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) + if question == "" { + log.Debug("parse question from request body failed") + return types.ActionContinue + } + ctx.SetContext(QuestionContextKey, question) err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { if err := response.Error(); err != nil { log.Errorf("redis get failed, err:%v", err) @@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) return } - question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) - if question == "" { - log.Debug("parse question from request body failed") - _ = proxywasm.ResumeHttpRequest() - return - } - ctx.SetContext(QuestionContextKey, question) fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 currJson := bodyJson.Get("messages").String() var currMessage []ChatHistory @@ -317,38 +316,39 @@ func getIntQueryParameter(name string, path string, defaultValue int) int { } func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break + content := "" + for _, chunk := range strings.Split(sseMessage, "\n\n") { + subMessages := strings.Split(chunk, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } } - } - if len(message) < 6 { - log.Errorf("invalid message:%s", message) - return "" - } - // skip the prefix "data:" - bodyJson := message[5:] - if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { - tempContentI := ctx.GetContext(AnswerContentContextKey) - if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) - ctx.SetContext(AnswerContentContextKey, content) + if len(message) < 6 { + log.Errorf("invalid message:%s", message) return content } - append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append - ctx.SetContext(AnswerContentContextKey, content) - return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(ToolCallsContextKey, struct{}{}) - return "" + // skip the prefix "data:" + bodyJson := message[5:] + if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI == nil { + content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + ctx.SetContext(AnswerContentContextKey, content) + } else { + append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + content = tempContentI.(string) + append + ctx.SetContext(AnswerContentContextKey, content) + } + } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { + // TODO: compatible with other providers + ctx.SetContext(ToolCallsContextKey, struct{}{}) + } + log.Debugf("unknown message:%s", bodyJson) } - log.Debugf("unknown message:%s", bodyJson) - return "" + return content } func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {