Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cohere embedding for ai-cache #1572

Merged
merged 5 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package embedding

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
COHERE_DOMAIN = "api.cohere.com"
COHERE_PORT = 443
COHERE_DEFAULT_MODEL_NAME = "embed-english-v2.0"
COHERE_ENDPOINT = "/v2/embed"
)

type cohereProviderInitializer struct {
}

var cohereConfig cohereProviderConfig

type cohereProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *cohereProviderInitializer) InitConfig(json gjson.Result) {
cohereConfig.apiKey = json.Get("apiKey").String()
}
func (c *cohereProviderInitializer) ValidateConfig() error {
if cohereConfig.apiKey == "" {
return errors.New("[Cohere] apiKey is required")
}
return nil
}

func (t *cohereProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = COHERE_PORT
}
if c.serviceHost == "" {
c.serviceHost = COHERE_DOMAIN
}
return &CohereProvider{
config: c,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.serviceName,
Host: c.serviceHost,
Port: int64(c.servicePort),
}),
}, nil
}

type cohereResponse struct {
Embeddings cohereEmbeddings `json:"embeddings"`
}

type cohereEmbeddings struct {
FloatTypeEebedding [][]float64 `json:"float"`
}

type cohereEmbeddingRequest struct {
Texts []string `json:"texts"`
Model string `json:"model"`
InputType string `json:"input_type"`
EmbeddingTypes []string `json:"embedding_types"`
}

type CohereProvider struct {
config ProviderConfig
client wrapper.HttpClient
}

func (t *CohereProvider) GetProviderType() string {
return PROVIDER_TYPE_COHERE
}
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
model := t.config.model

if model == "" {
model = COHERE_DEFAULT_MODEL_NAME
}
data := cohereEmbeddingRequest{
Texts: texts,
Model: model,
InputType: "search_document",
EmbeddingTypes: []string{"float"},
}

requestBody, err := json.Marshal(data)
if err != nil {
log.Errorf("failed to marshal request data: %v", err)
return "", nil, nil, err
}

headers := [][2]string{
{"Authorization", fmt.Sprintf("BEARER %s", cohereConfig.apiKey)},
{"Content-Type", "application/json"},
}

return COHERE_ENDPOINT, headers, requestBody, nil
}

func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereResponse, error) {
var resp cohereResponse
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}

func (t *CohereProvider) GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error {
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
if err != nil {
log.Errorf("failed to construct parameters: %v", err)
return err
}

var resp *cohereResponse
err = t.client.Post(embUrl, embHeaders, embRequestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {

if statusCode != http.StatusOK {
err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode))
callback(nil, err)
return
}

log.Debugf("get embedding response: %d, %s", statusCode, responseBody)

resp, err = t.parseTextEmbedding(responseBody)
if err != nil {
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}

if len(resp.Embeddings.FloatTypeEebedding) == 0 {
err = errors.New("no embedding found in response")
callback(nil, err)
return
}

callback(resp.Embeddings.FloatTypeEebedding[0], nil)

}, t.config.timeout)
return err
}
20 changes: 16 additions & 4 deletions plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
Expand All @@ -17,11 +18,22 @@ const (
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)

var dashScopeConfig dashScopeProviderConfig

type dashScopeProviderInitializer struct {
}
type dashScopeProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
dashScopeConfig.apiKey = json.Get("apiKey").String()
}

func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiKey == "" {
func (c *dashScopeProviderInitializer) ValidateConfig() error {
if dashScopeConfig.apiKey == "" {
return errors.New("[DashScope] apiKey is required")
}
return nil
Expand Down Expand Up @@ -114,14 +126,14 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
return "", nil, nil, err
}

if d.config.apiKey == "" {
if dashScopeConfig.apiKey == "" {
err := errors.New("dashScopeKey is empty")
log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err
}

headers := [][2]string{
{"Authorization", "Bearer " + d.config.apiKey},
{"Authorization", "Bearer " + dashScopeConfig.apiKey},
{"Content-Type", "application/json"},
}

Expand Down
25 changes: 18 additions & 7 deletions plugins/wasm-go/extensions/ai-cache/embedding/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
"github.com/tidwall/gjson"
)

const (
Expand All @@ -18,11 +19,21 @@ const (
type openAIProviderInitializer struct {
}

func (t *openAIProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiKey == "" {
return errors.New("[OpenAI] embedding service ApiKey is required")
}
return nil
var openAIConfig openAIProviderConfig

type openAIProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *openAIProviderInitializer) InitConfig(json gjson.Result) {
openAIConfig.apiKey = json.Get("apiKey").String()
}

func (c *openAIProviderInitializer) ValidateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openAI] apiKey is required")
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
}

func (t *openAIProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
Expand Down Expand Up @@ -97,7 +108,7 @@ func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (stri
}

headers := [][2]string{
{"Authorization", fmt.Sprintf("Bearer %s", t.config.apiKey)},
{"Authorization", fmt.Sprintf("Bearer %s", openAIConfig.apiKey)},
{"Content-Type", "application/json"},
}

Expand Down
32 changes: 13 additions & 19 deletions plugins/wasm-go/extensions/ai-cache/embedding/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,21 @@ import (
const (
PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_TEXTIN = "textin"
PROVIDER_TYPE_COHERE = "cohere"
PROVIDER_TYPE_OPENAI = "openai"
)

type providerInitializer interface {
ValidateConfig(ProviderConfig) error
InitConfig(json gjson.Result)
ValidateConfig() error
CreateProvider(ProviderConfig) (Provider, error)
}

var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
PROVIDER_TYPE_COHERE: &cohereProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)
Expand All @@ -39,35 +42,26 @@ type ProviderConfig struct {
// @Title zh-CN 文本特征提取服务端口
// @Description zh-CN 文本特征提取服务端口
servicePort int64
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
//@Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
//@Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
//@Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
// @Title zh-CN 文本特征提取服务超时时间
// @Description zh-CN 文本特征提取服务超时时间
timeout uint32
// @Title zh-CN 文本特征提取服务使用的模型
// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
model string

initializer providerInitializer
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
i, has := providerInitializers[c.typ]
if has {
i.InitConfig(json)
c.initializer = i
}
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.textinAppId = json.Get("textinAppId").String()
c.textinSecretCode = json.Get("textinSecretCode").String()
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String()
if c.timeout == 0 {
Expand All @@ -82,11 +76,11 @@ func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("embedding service type is required")
}
initializer, has := providerInitializers[c.typ]
_, has := providerInitializers[c.typ]
if !has {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个判定方法是不是需要改一下,直接判断 c.initializer 就行了吧?

return errors.New("unknown embedding service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
if err := c.initializer.ValidateConfig(); err != nil {
return err
}
return nil
Expand Down
Loading