Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add textin embedding for ai-cache #1493

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions plugins/wasm-go/extensions/ai-cache/embedding/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

const (
PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_TEXTIN = "textin"
)

type providerInitializer interface {
Expand All @@ -19,6 +20,7 @@ type providerInitializer interface {
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
}
)

Expand All @@ -38,6 +40,15 @@ type ProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
// @Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
// @Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
// @Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
// @Title zh-CN 文本特征提取服务超时时间
// @Description zh-CN 文本特征提取服务超时时间
timeout uint32
Expand All @@ -52,6 +63,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.textinAppId = json.Get("textinAppId").String()
c.textinSecretCode = json.Get("textinSecretCode").String()
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String()
if c.timeout == 0 {
Expand All @@ -63,9 +77,6 @@ func (c *ProviderConfig) Validate() error {
if c.serviceName == "" {
return errors.New("embedding service name is required")
}
if c.apiKey == "" {
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
return errors.New("embedding service API key is required")
}
if c.typ == "" {
return errors.New("embedding service type is required")
}
Expand Down
161 changes: 161 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/embedding/textin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package embedding

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

// Defaults and fixed values for the TextIn embedding service.
const (
	// TEXTIN_DOMAIN is the host name of the TextIn API.
	TEXTIN_DOMAIN = "api.textin.com"
	// TEXTIN_PORT is the port of the TextIn API.
	TEXTIN_PORT = 443
	// TEXTIN_DEFAULT_MODEL_NAME is the name of the TextIn embedding model.
	TEXTIN_DEFAULT_MODEL_NAME = "acge-text-embedding"
	// TEXTIN_ENDPOINT is the request path of the TextIn embedding API.
	TEXTIN_ENDPOINT = "/ai/service/v1/acge_embedding"
)

// textInProviderInitializer is the stateless type whose methods validate
// TextIn settings and create TextIn embedding providers.
type textInProviderInitializer struct{}

// ValidateConfig checks the TextIn-specific settings of config.
//
// It returns an error when the App ID or the Secret Code is empty, or when
// the Matryoshka dimension is zero or negative. It returns nil when all three
// settings are usable.
func (t *textInProviderInitializer) ValidateConfig(config ProviderConfig) error {
	if config.textinAppId == "" {
		return errors.New("embedding service TextIn App ID is required")
	}
	if config.textinSecretCode == "" {
		return errors.New("embedding service TextIn Secret Code is required")
	}
	if config.textinMatryoshkaDim == 0 {
		return errors.New("embedding service TextIn Matryoshka Dim is required")
	}
	// A negative value is not a usable vector dimension; reject it here rather
	// than sending it to the service with every embedding request.
	if config.textinMatryoshkaDim < 0 {
		return errors.New("embedding service TextIn Matryoshka Dim must be positive")
	}
	return nil
}

// CreateProvider builds a TIProvider from c.
//
// c is received by value, so the defaults filled in below (TEXTIN_DOMAIN for
// an empty serviceHost, TEXTIN_PORT for a zero servicePort) only affect the
// copy stored in the returned provider. The returned error is always nil.
func (t *textInProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
	if c.serviceHost == "" {
		c.serviceHost = TEXTIN_DOMAIN
	}
	if c.servicePort == 0 {
		c.servicePort = TEXTIN_PORT
	}

	cluster := wrapper.FQDNCluster{
		FQDN: c.serviceName,
		Host: c.serviceHost,
		Port: int64(c.servicePort),
	}
	provider := &TIProvider{
		config: c,
		client: wrapper.NewClusterClient(cluster),
	}
	return provider, nil
}

// GetProviderType returns the provider type identifier of the TextIn
// provider, PROVIDER_TYPE_TEXTIN.
func (t *TIProvider) GetProviderType() string {
	return PROVIDER_TYPE_TEXTIN
}

// TextInResponse is the JSON body decoded from a TextIn embedding response.
type TextInResponse struct {
	// Code is decoded from the "code" field of the response body.
	Code int `json:"code"`
	// Message is decoded from the "message" field of the response body.
	Message string `json:"message"`
	// Duration is decoded from the "duration" field of the response body.
	Duration float64 `json:"duration"`
	// Result carries the embedding payload from the "result" field.
	Result TextInResult `json:"result"`
}

// TextInResult is the "result" object of a TextIn embedding response.
type TextInResult struct {
	// Embeddings holds the vectors decoded from the "embedding" field.
	Embeddings [][]float64 `json:"embedding"`
	// MatryoshkaDim is decoded from the "matryoshka_dim" field.
	MatryoshkaDim int `json:"matryoshka_dim"`
}

// TextInEmbeddingRequest is the JSON body sent to the TextIn embedding API.
type TextInEmbeddingRequest struct {
	// Input lists the texts to embed; it is encoded as the "input" field.
	Input []string `json:"input"`
	// MatryoshkaDim is encoded as the "matryoshka_dim" field.
	MatryoshkaDim int `json:"matryoshka_dim"`
}

// TIProvider is the embedding provider backed by the TextIn service.
type TIProvider struct {
	// config holds the provider settings (credentials, dimension, timeout).
	config ProviderConfig
	// client is the HTTP client used to post embedding requests.
	client wrapper.HttpClient
}

// constructParameters prepares everything needed for one embedding request.
//
// It returns, in order: the request path (TEXTIN_ENDPOINT), the request
// headers (TextIn credentials plus the JSON content type), the JSON-encoded
// request body for texts, and an error. On failure the error is logged
// through log and the other results are zero values.
func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
	payload, err := json.Marshal(TextInEmbeddingRequest{
		Input:         texts,
		MatryoshkaDim: t.config.textinMatryoshkaDim,
	})
	if err != nil {
		log.Errorf("failed to marshal request data: %v", err)
		return "", nil, nil, err
	}

	// Both credentials are sent as headers, so neither may be empty. The App
	// ID is checked first, matching the order of the headers below.
	var missing string
	switch {
	case t.config.textinAppId == "":
		missing = "textinAppId is empty"
	case t.config.textinSecretCode == "":
		missing = "textinSecretCode is empty"
	}
	if missing != "" {
		credErr := errors.New(missing)
		log.Errorf("failed to construct headers: %v", credErr)
		return "", nil, nil, credErr
	}

	headers := [][2]string{
		{"x-ti-app-id", t.config.textinAppId},
		{"x-ti-secret-code", t.config.textinSecretCode},
		{"Content-Type", "application/json"},
	}
	return TEXTIN_ENDPOINT, headers, payload, nil
}

// parseTextEmbedding decodes responseBody as a TextInResponse.
// It returns nil and the decoding error when responseBody is not valid JSON
// for that type.
func (t *TIProvider) parseTextEmbedding(responseBody []byte) (*TextInResponse, error) {
	parsed := new(TextInResponse)
	if err := json.Unmarshal(responseBody, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}

// GetEmbedding requests the embedding of queryString from the TextIn service
// and reports the outcome through callback.
//
// Parameters:
//   - queryString: the single text to embed.
//   - ctx: not used by this implementation.
//   - log: receives error and debug messages.
//   - callback: called with the first embedding vector and a nil error on
//     success, or with a nil vector and a non-nil error when the HTTP status
//     is not 200, the body cannot be parsed, or the response carries no
//     embedding.
//
// The returned error comes from building the request or from client.Post;
// failures found while handling the response are delivered only to callback.
func (t *TIProvider) GetEmbedding(
	queryString string,
	ctx wrapper.HttpContext,
	log wrapper.Log,
	callback func(emb []float64, err error)) error {
	embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
	if err != nil {
		log.Errorf("failed to construct parameters: %v", err)
		return err
	}

	return t.client.Post(embUrl, embHeaders, embRequestBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			// Log before the status check so that failed responses can be
			// inspected as well.
			log.Debugf("get embedding response: %d, %s", statusCode, responseBody)

			if statusCode != http.StatusOK {
				callback(nil, errors.New("failed to get embedding due to status code: "+strconv.Itoa(statusCode)))
				return
			}

			// Use variables local to the callback: the enclosing function's
			// err belongs to its own return value and must not be written
			// from here.
			resp, parseErr := t.parseTextEmbedding(responseBody)
			if parseErr != nil {
				callback(nil, fmt.Errorf("failed to parse response: %w", parseErr))
				return
			}

			if len(resp.Result.Embeddings) == 0 {
				// Include the code and message from the body; they are the
				// only hint about why no embedding came back.
				callback(nil, fmt.Errorf("no embedding found in response: code=%d, message=%s", resp.Code, resp.Message))
				return
			}

			callback(resp.Result.Embeddings[0], nil)
		}, t.config.timeout)
}
Loading