From 04bb3ef3923ba4b0931f0940e65f06b29cd53df8 Mon Sep 17 00:00:00 2001 From: MotorBottle <71703952+MotorBottle@users.noreply.github.com> Date: Tue, 6 Aug 2024 23:44:37 +0800 Subject: [PATCH] feat: add Max Tokens and Context Window Setting Options for Ollama Channel (#1694) * Update main.go with max_tokens param * Update model.go with max_tokens param * Update model.go * Update main.go * Update main.go * Adds num_ctx param for Ollama Channel * Added num_ctx param for ollama adapter * Added num_ctx param for ollama adapter * Improved data process logic --- relay/adaptor/ollama/main.go | 8 ++++++-- relay/adaptor/ollama/model.go | 2 ++ relay/model/general.go | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index 6a1d334d1a..43317ff66f 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, + NumPredict: request.MaxTokens, + NumCtx: request.NumCtx, }, Stream: request.Stream, } @@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC common.SetEventStreamHeaders(c) for scanner.Scan() { - data := strings.TrimPrefix(scanner.Text(), "}") - data = data + "}" + data := scanner.Text() + if strings.HasPrefix(data, "}") { + data = strings.TrimPrefix(data, "}") + "}" + } var ollamaResponse ChatResponse err := json.Unmarshal([]byte(data), &ollamaResponse) diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go index 29430e1c7c..7039984fcc 100644 --- a/relay/adaptor/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -7,6 +7,8 @@ type Options struct { TopP float64 `json:"top_p,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } type Message struct { diff --git a/relay/model/general.go b/relay/model/general.go index 229a61c160..c34c1c2d5d 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -29,6 +29,7 @@ type GeneralOpenAIRequest struct { Dimensions int `json:"dimensions,omitempty"` Instruction string `json:"instruction,omitempty"` Size string `json:"size,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string {